
Expose attention weights' head dimension #830

Open
neel04 opened this issue Sep 5, 2024 · 5 comments

Comments

neel04 commented Sep 5, 2024

Currently, Equinox handles attention heads opaquely: the `_project` method reshapes Q, K, and V to add the heads dimension.

However, sharding along the heads dimension is a common way to parallelize the model.

I feel that the `{query | key | value}_proj` weights should be split to expose the head dimension.
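
Roughly, the current pattern looks like this (a minimal sketch, not Equinox's actual internals; the names and sizes are purely illustrative):

```python
import jax
import jax.numpy as jnp

# Illustrative sketch of the current behaviour: the head dimension is folded
# into the projection's output axis and only recovered via a reshape.
num_heads, qk_size, query_size, seq_len = 4, 16, 32, 10
key = jax.random.PRNGKey(0)

w_q = jax.random.normal(key, (query_size, num_heads * qk_size))  # 2D weight
x = jax.random.normal(key, (seq_len, query_size))

q = x @ w_q                                 # (seq_len, num_heads * qk_size)
q = q.reshape(seq_len, num_heads, qk_size)  # heads only appear after reshape
```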

WDYT?

patrick-kidger (Owner) commented:

Sorry, it's not totally clear to me what change you're suggesting. Can you expand?

neel04 changed the title from "Expose Q, K, V projections' heads" to "Expose attention weights' head dimension" on Sep 6, 2024
neel04 (Author) commented Sep 6, 2024

The `W_{q | k | v}` projections are 2D, with the general shape `(query_size, num_heads * qk_size)`.

I feel that the head dimension should be explicit here, so the weight would instead be 3D with shape `(query_size, num_heads, qk_size)`.

This might be a bit tricky to incorporate, I suppose, but it would be quite helpful, for example for sharding along the head dimension or for weight sharing.
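
For concreteness, a minimal sketch of what the exposed-head layout could look like (hypothetical shapes, not the current Equinox API):

```python
import jax
import jax.numpy as jnp

# Hypothetical 3D weight with the head axis made explicit (not current Equinox).
num_heads, qk_size, query_size, seq_len = 4, 16, 32, 10
key = jax.random.PRNGKey(0)

w_q = jax.random.normal(key, (query_size, num_heads, qk_size))  # 3D weight
x = jax.random.normal(key, (seq_len, query_size))

# Project directly into per-head subspaces: no reshape needed, and the head
# axis is directly addressable for sharding or weight tying.
q = jnp.einsum("sd,dhk->shk", x, w_q)  # (seq_len, num_heads, qk_size)
```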

patrick-kidger (Owner) commented:

Ah, I see what you're saying!
So I think that, much like QKV fusion, this would unfortunately be a backward-incompatible change.

For the specific purpose of sharding, I think what we should do depends on what you and dlwh come up with in #825.

neel04 (Author) commented Sep 7, 2024

🤷 Even with a dedicated sharding API, ideally one should only need to deal with the model PyTree. If someone wants to shard along the heads dimension, they would still have to insert explicit reshapes during the MHA computation to convert the 3D array created for sharding back to 2D for the `_project` method.

I suppose one could add a check in MHA to fold the heads dimension whenever the array is 3D... but that seems janky.
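
To make that concrete, a rough sketch with `jax.sharding` (axis names and shapes are illustrative, and the fold back to 2D is exactly the step I'd like to avoid):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Rough sketch: shard a 3D weight along an explicit head axis, then fold it
# back to 2D because the projection inside MHA expects a 2D weight.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("heads",))

num_heads, qk_size, query_size = len(devices), 16, 64
w_q_3d = jnp.zeros((query_size, num_heads, qk_size))

# Shard the explicit head axis across devices...
w_q_3d = jax.device_put(w_q_3d, NamedSharding(mesh, P(None, "heads", None)))

# ...but the 2D projection path still forces this reshape before the matmul.
w_q_2d = w_q_3d.reshape(query_size, num_heads * qk_size)
```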

neel04 (Author) commented Sep 7, 2024

I wonder if, at this point, it might be better to add another, optimized version of `eqx.nn.MultiHeadAttention`: one that internally uses JAX's SDPA for better performance, exposes heads for sharding and customization, fuses QKV, and adds @Artur-Galstyan's cache as well.

Users who want to explicitly adopt the newer features could switch over.

Or I suppose one needs a separate library of Equinox utilities with such feature-complete modules.
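
For reference, a minimal sketch of what the SDPA call could look like (assuming a recent JAX release that provides `jax.nn.dot_product_attention`; shapes are illustrative):

```python
import jax
import jax.numpy as jnp

# Sketch: jax.nn.dot_product_attention takes (batch, seq, num_heads, head_dim)
# arrays, i.e. exactly the layout an "exposed heads" MHA would already have.
batch, seq, num_heads, head_dim = 2, 10, 4, 16
key = jax.random.PRNGKey(0)
q = k = v = jax.random.normal(key, (batch, seq, num_heads, head_dim))

out = jax.nn.dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # (2, 10, 4, 16)
```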
