Native Model Parallelism in equinox #825
So this is something which Haliax, a downstream library of Equinox, explicitly adds. I'd be open to suggestions on how to accomplish something similar within Equinox, although I'm also conscious of not wanting to step on Haliax's toes.
The problem with depending on downstream libs is that the maintainers often move on with life and are unable to contribute as much, for whatever reasons - so having crucial features tied to the core framework would make them more sustainable. (Not to mention that porting an existing codebase over is a lot of work.)

I'm not sure what a good API here is, tbh. Personally, I'm strongly considering spinning off my own lib for parallelization w/ equinox, but I'd rather prefer to see some sort of interface for autoparallelization within equinox itself.

I wonder - for a quick and dirty hack that is compatible with current codebases, could we just somehow integrate a sharding-constraint mechanism?
(Disclaimer: Haliax is mine.) I really think you have to do something like what Haliax (or really even Flax) does in the general case, which is associate semantic names with at least a subset of axes and then map those to the physical mesh axes. In Haliax I decided to go "all in" on names (which is not proving to be super popular), but there is a middle ground like Flax, where you could use a jmp-style object (or global state, as in Flax) to hold the semantic-to-physical mapping and then (like Flax) make a way of annotating parameters with their semantic axes.

(Totally get where you're coming from w.r.t. people moving on. I will point out that I still merge PRs into my last "big" library, Breeze, and it's been going for almost 15 years, so I have a decent track record of not totally abandoning things!)
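As a rough sketch of the middle-ground idea (the names here are illustrative, not Haliax's or Flax's actual API): hold a jmp-style mapping from semantic axis names to physical mesh axes, and resolve logical axes to a `PartitionSpec` on demand.

```python
from jax.sharding import PartitionSpec as P

# Hypothetical jmp-style object: semantic axis name -> physical mesh axis
# (None means "replicated along this axis").
AXIS_MAPPING = {"batch": "data", "embed": None, "mlp": "model"}

def logical_to_spec(*logical_axes):
    """Resolve semantic axis names to a physical PartitionSpec."""
    return P(*(AXIS_MAPPING.get(name) for name in logical_axes))

# e.g. a weight annotated with ("batch", "embed") resolves to P("data", None)
spec = logical_to_spec("batch", "embed")
```

The appeal of this split is that models only ever mention semantic names, while the physical layout lives in one swappable object.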
Thinking more about it: I had already been in the process of refactoring how Haliax works to be a bit more jmp-like anyway (making a "mesh env" that holds a mesh and an axis mapping). With a bit more work I could break it out into a library, if you wanted to collaborate on that.
If the question is directed at me -- happy to help out with anything needed from the Equinox side / if you think it's best to upstream anything into Equinox itself. If you're planning a totally separate library then I think I have enough commitments right now, and will politely decline 😅
No, I meant @neel04, but would be happy to discuss upstreaming into Equinox. 😄 IMHO the right thing to do is to make it a separate lib and iterate until we get it right. Then, if you decide you like it, you can bring it into equinox.
Haha! In that case, SGTM!
Haha I know - this wasn't in reference to you 🙂 Rather, I just feel that the approach of leaning on 3rd-party support almost always leads to fragmentation. I have strong opinions about torch's ecosystem precisely because of this - 3rd-party libs often don't operate on shared abstractions, which makes them tricky to compose with other libs, leading to unforeseen edge cases and an awful dev experience - it feels like you're perpetually in 'integration hell'. I feel like equinox might be slipping into a similar mistake here.

Parallelization is the lifeblood of JAX, so it should have first-class support in any JAX-based framework - ideally, support as strong as what flax offers.

Hmm.. I wonder - is there any iterative way we can inject sharding constraints? I'll have to think more about it, but I'm imagining it like a wrapper hook on the forward, effectively, wherein we insert sharding constraints on the intermediate activations.
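One minimal way to sketch that "wrapper hook" idea (a sketch only; the wrapper name is made up, and a real version would likely be an `eqx.Module` rather than a closure) is to wrap a layer's forward pass so its output is pinned with `jax.lax.with_sharding_constraint` inside `jit`:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def with_output_constraint(forward, sharding):
    """Hypothetical hook: constrain a layer's output sharding without
    touching the layer itself."""
    def wrapped(*args, **kwargs):
        out = forward(*args, **kwargs)
        return jax.lax.with_sharding_constraint(out, sharding)
    return wrapped

# Demo on a trivial single-device mesh so the sketch runs anywhere.
mesh = Mesh(np.array(jax.devices()[:1]), ("data",))
sharding = NamedSharding(mesh, P())  # fully replicated
layer = lambda x: x * 2
f = jax.jit(with_output_constraint(layer, sharding))
y = f(jnp.ones(4))
```

Because `jit` propagates constraints through the program, a few such hooks on key activations can be enough to steer GSPMD toward the intended layout.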
I've been playing around with a couple of approaches in this Colab, and I've done a minimal implementation of one of them. For consecutive layers, we need to alternate the sharding to minimize communication, as taken from the Megatron paper. So we need a way to maintain state and track the index of the current layer.
Maybe one can develop some custom stateful wrapper to handle that tracking.
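The alternation itself is easy to sketch (a sketch, with an assumed `"tp"` mesh axis name): even layers get column-parallel specs (output dim sharded), odd layers get row-parallel specs (input dim sharded), so each consecutive pair needs only one collective, as in Megatron. The stateful part can be as simple as enumerating layers at model-construction time.

```python
from jax.sharding import PartitionSpec as P

def linear_spec(layer_index: int) -> P:
    """Alternate Megatron-style sharding by layer index: even layers are
    column-parallel (shard the output dim over 'tp'), odd layers are
    row-parallel (shard the input dim over 'tp')."""
    return P(None, "tp") if layer_index % 2 == 0 else P("tp", None)

# "Tracking the index" reduces to enumerating layers when building the model.
specs = [linear_spec(i) for i in range(4)]
```

This keeps the state out of the forward pass entirely, which sidesteps the need for a stateful wrapper at call time.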
Here's how a sample spec would look:

```python
model_sharding_rule = TreePathShardingRule(
    ('embedding', P('fsdp', 'tp')),
    ('lm_head/kernel', P('tp', 'fsdp')),
    # Megatron-style feedforward sharding
    ('mlp/(up|gate)_proj/kernel', P('fsdp', 'tp')),
    ('mlp/down_proj/kernel', P('tp', 'fsdp')),
    # Attention should be sharded by heads
    ('self_attn/(k|q|v)_proj/kernel', P('fsdp', 'tp')),
    ('self_attn/o_proj/kernel', P('tp', 'fsdp')),
    ('norm', P()),
)
```

None of these approaches feel right to me, especially the stateful one. I can use some combination for my personal codebase, but I think we're just hitting the limitations of the toolset. What do you guys think?
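For context on what such a rule system does under the hood, here is a minimal, self-contained sketch of scalax-style path matching (the `match_rule` helper is hypothetical; scalax's real `TreePathShardingRule` does more): each rule is a regex tried against a `/`-joined leaf path, first match wins, unmatched leaves replicate.

```python
import re
from jax.sharding import PartitionSpec as P

def match_rule(rules, path):
    """Return the spec of the first rule whose regex matches the leaf path.
    `rules` mirrors the spec format above; `path` is a '/'-joined key path."""
    for pattern, spec in rules:
        if re.search(pattern, path):
            return spec
    return P()  # replicate by default

rules = (
    ('mlp/(up|gate)_proj/kernel', P('fsdp', 'tp')),
    ('mlp/down_proj/kernel', P('tp', 'fsdp')),
    ('self_attn/o_proj/kernel', P('tp', 'fsdp')),
)
```

The nice property is that one short rule covers every matching leaf in every layer, so the spec stays model-size-independent.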
Oftentimes, one wants to do a more general `n`-way data parallelism, `m`-way model parallelism, as helpfully explained in the official JAX docs. Here, the common convention is to alternately shard the layers, as laid out in the Megatron paper. The linked JAX example also uses this, which reduces the communication required.
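Concretely, the `n`-way data / `m`-way model setup is just a 2D device mesh plus a few recurring `PartitionSpec` patterns. A sketch (axis names `"data"`/`"model"` are a common convention, not anything equinox defines; the reshape falls back to however many devices are visible so it runs anywhere):

```python
import numpy as np
import jax
from jax.sharding import Mesh, PartitionSpec as P

# Hypothetical 2D mesh: with 8 devices and n=4, m=2 this would be a 4x2
# grid; here we use an (n_devices x 1) grid so the sketch runs on any host.
devices = np.array(jax.devices())
mesh = Mesh(devices.reshape(len(devices), 1), axis_names=("data", "model"))

batch_spec = P("data", None)        # shard the batch over the data axis
column_parallel = P(None, "model")  # shard a weight's output dim
row_parallel = P("model", None)     # shard a weight's input dim
```

The Megatron alternation is then just pairing `column_parallel` with `row_parallel` on consecutive linears.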
We don't really have an API for that in equinox. While an external library might be a better fit, ultimately I think this is such a common use case that it should be a core feature, IMO.
Scalax has a rule-based system where it tries to guess the "correct" axis, like the `FSDPShardingRule`, and for a while I was thinking of simply porting it to work with `eqx`. But it was proving rather hairy - it does everything explicitly, because the primary use case is operating with `flax`, and thus requires so much fiddling, plumbing, and replacing things with filtered versions that I ultimately gave up.

I think `eqx` should adopt a simpler, flexible API wherein one can configure sharding through a configurable system and apply it to arbitrary PyTrees, without needing to explicitly provide a sharding for every leaf. What do you think?
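Since `eqx.Module`s are ordinary pytrees, one possible shape for such an API (a sketch under that assumption; `shard_tree` and its rule format are made up, not an existing equinox or scalax function) is a generic walk that builds a matching tree of `PartitionSpec`s from regex rules, replicating anything no rule names:

```python
import re
import jax
from jax.sharding import PartitionSpec as P

def shard_tree(tree, rules, default=P()):
    """Map regex path rules over any pytree, producing a tree of specs.
    Key-path handling is simplified to dict/sequence/attr keys."""
    def _path_str(path):
        return "/".join(
            str(getattr(p, "key", getattr(p, "idx", getattr(p, "name", p))))
            for p in path
        )

    def spec_for(path, leaf):
        key = _path_str(path)
        for pattern, spec in rules:
            if re.search(pattern, key):
                return spec
        return default  # unmatched leaves are replicated

    return jax.tree_util.tree_map_with_path(spec_for, tree)

# Usage on a toy params pytree:
params = {"mlp": {"up_proj": {"kernel": 1}, "down_proj": {"kernel": 2}},
          "norm": {"scale": 3}}
rules = (("mlp/up_proj/kernel", P(None, "tp")),
         ("mlp/down_proj/kernel", P("tp", None)))
specs = shard_tree(params, rules)
```

For an actual `eqx.Module` one would presumably first split off the array leaves with `eqx.partition` and apply the rules to that half only.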