Initialization of large models on multi-hosts environment #778
So one way to do this with the current API is to create the model 'skeleton' via `eqx.filter_eval_shape(SomeModule, ...)` and then fill in the parameters with `eqx.tree_at`. I do this when translating model weights from PyTorch, for example. I'm not sure what the cleanest way of doing this is in general. Ideally it wouldn't require wrapping the constructors: these can be thought of as functions that return pytrees, so something that respects that in some way would be ideal.
I see! I tried using the `eqx.filter_eval_shape` approach, and it works for my case. Would it be worth adding an example of this to the docs?
If this ends up being the best way to do it, then yes! I think having an example would be really useful. Right now I'm not completely convinced it is the best way, though; it's certainly a lot of work. I'm wondering if there's some way we can directly create the arrays in the right way by wrapping the constructor in a filtered/tree-map'd version of the usual JAX operations here.
I agree that if there were an ergonomic way to initialize a sharded model, that would be great. I am trying to figure out how best to do this, and here are some thoughts. The ideal scenario is a function that does the following:

```python
def init_shard_model(ModelClass, mesh, sharding, ...) -> model:
    ...
```

I have so far run into two complications. I think this may be why the devs of other libraries made the design choices they did. I am gonna try a number of things in the coming weeks; any input will be awesome!
This is indeed one of the main reasons I did names. (There are a few more, but I won't belabor them here.) It basically defines this problem away, though it does mean you have to do weird gymnastics when you have square arrays that have semantically identical axes. But it's a price I'm happy to pay. (Obviously happy to have you over in Haliax land, but only if it appeals.)

You could instead follow something more like what flax does, if you wanted. See for example t5x. It's basically the same as Haliax, except the names of the axes aren't carried around with the array, so you occasionally have to sprinkle in more `with_sharding_constraint`s. Obviously I'm partial to Haliax, but tastes vary.

Regardless, either way the basic idea is to do initialization inside jit and use `with_sharding_constraint` and/or `out_shardings` to ensure things are correct. IMHO, re (2), the right way to do any of this is to always, always do as much as you can in a "global" way using jit, and only fall back on `make_array_from_callback` when you absolutely have to (or for data loading).
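As an illustration of the initialize-inside-jit idea (a hedged sketch: the mesh axis name and shapes are invented for the example; `out_shardings` is the standard `jax.jit` argument for placing outputs):

```python
from functools import partial

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A 1D mesh over whatever devices are available (just one on CPU).
mesh = Mesh(np.array(jax.devices()), ("model",))
spec = NamedSharding(mesh, PartitionSpec("model", None))

@partial(jax.jit, out_shardings=spec)
def init_param(key):
    # Initialized inside jit: XLA writes each output shard directly to its
    # device, so the full array is never gathered onto one device when the
    # mesh spans several.
    return jax.random.normal(key, (8, 4))

param = init_param(jax.random.PRNGKey(0))
```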
Thank you both! To riff on @dlwh's last paragraph -- doing as much as possible inside JIT -- would something like

```python
@jax.jit
def f():
    model = SomeModel(**hyperparams)
    return eqx.filter_shard(model, shard_tree)
```

work? By describing an appropriate `shard_tree`, the arrays should be created with the correct sharding from the start. (I can see that this would be a t5x way of doing things rather than a Haliax way of doing things.)
I'd actually say Haliax and flax/t5x are closer to each other than what you're proposing. (Not saying that's necessarily a bad thing, just trying to clarify.)

Haliax looks like this:

```python
class SomeModel(Module):
    param: hax.NamedArray

    # i use static methods but whatever
    def __init__(...):
        self.param = hax.random.normal((Embed, Mlp), ...)  # [Embed, Mlp]
```

and flax/t5x is more like:

```python
class SomeModel(Module):
    @nn.compact
    def __call__(...):
        param = param_with_axes("param", init, (512, 2048), jnp.float32, axes=("embed", "mlp"))
```

and `param_with_axes` would have something like this in it (simplifying a bit):

```python
physical_axes = map_logical_to_physical(global_axis_mapping, axes)  # uses a global context manager
param = with_sharding_constraint(param, PartitionSpec(*physical_axes))
```

so, like Haliax, it injects the sharding constraint at the parameter declaration site.

By comparison, you're proposing keeping sharding an explicit outer step that operates on entire module trees. Haliax also supports doing it that way, but it's not really the default anymore. (In Haliax, when you do it that way, we can still use the `map_logical_to_physical` function because the module tree has the names in it, so it's relatively painless from a user's perspective. I think this would be harder in flax, but I could be wrong.)

My suspicion is that what you're proposing will prove quite cumbersome compared to keeping sharding at the per-parameter declaration site (as in flax/t5x or Haliax) for more complex models, but I'm happy to be proven wrong.
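To make the logical-to-physical mapping step concrete, a toy version (not the real flax/t5x or Haliax implementation) might look like:

```python
from jax.sharding import PartitionSpec

def map_logical_to_physical(axis_mapping, logical_axes):
    """Look up each logical axis name and build a PartitionSpec.

    Unmapped names fall back to None, i.e. replicated along that axis.
    """
    return PartitionSpec(*(axis_mapping.get(name) for name in logical_axes))

# e.g. shard the "mlp" axis over the mesh axis named "model",
# and replicate the "embed" axis:
spec = map_logical_to_physical({"mlp": "model"}, ("embed", "mlp"))
```

The resulting spec would then be handed to `with_sharding_constraint`, as in the snippet above.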
Right! I'm just getting at the Flax-vs-Haliax distinction of whether there's a named array object available at call time. The thing is that I don't see another way to support this kind of per-parameter behaviour in an off-the-shelf library like Equinox. I think if per-parameter behaviour is required, then the user usually needs to control the definition of every layer? (Or else use some kind of wrapper.)

To turn the above into something actionable: I think I'd like to include an example on this topic, as per @kazewong's suggestion, and I'm trying to figure out what approach to recommend. I'd like to include a pure-Equinox solution, whether that's my latest suggestion above, the `eqx.filter_eval_shape` approach, or something else.

(To wax a little philosophical, by the way -- this kind of thing comes up in other scenarios too, such as which parameters to apply a transform to: differentiate/vmap/etc. Now JAX's model here is to specify these things at the call site.)
Yeah, makes sense as a philosophy.
Hi all,
I am wondering what the preferred way is to create a model that is too large to fit on a single device.
As a reference starting point: if I use data parallelism, I will first create per-device data arrays and use `jax.make_array_from_single_device_arrays` to put them on the global mesh (basically following https://jax.readthedocs.io/en/latest/_autosummary/jax.make_array_from_single_device_arrays.html#jax.make_array_from_single_device_arrays).
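The data-parallel pattern just described might be sketched as follows (a single-process illustration; the shapes are made up, and on a real multi-host setup each process would build only its local shards):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = jax.local_devices()
mesh = Mesh(np.array(devices), ("data",))
sharding = NamedSharding(mesh, PartitionSpec("data"))

rows_per_device = 4
global_shape = (rows_per_device * len(devices), 8)

# Each process creates only its own per-device shards...
shards = [jax.device_put(jnp.ones((rows_per_device, 8)), d) for d in devices]

# ...and assembles them into one logically-global array on the mesh.
batch = jax.make_array_from_single_device_arrays(global_shape, sharding, shards)
```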
Since by default the pre-defined modules in `equinox` will initialize the full set of parameters on every device, I cannot just follow the guide in https://docs.kidger.site/equinox/tricks/#custom-parameter-initialisation and update the parameters to the sharded version after I create the model.

My current way of bypassing this is to create a wrapper class around the `eqx.nn` modules I want to use, so I can create the sharded version of the parameters on each device and then combine them as I would in the data-parallelism case. Here is a minimal example wrapping the `eqx.nn.Linear` class: https://gist.github.com/kazewong/c976b48c5870d866496740341382acb5
Since the multi-host interface in JAX is still experimental, we probably don't want to put too much of this into the Equinox core code. To make creating large models easier now, I think a wrapper class or a decorator is probably the easiest way, but I want to see what people think before submitting a PR. @patrick-kidger