
Batched pytree dataclasses #242

Open
JeyRunner opened this issue Aug 8, 2024 · 5 comments
Labels
feature New feature

Comments

@JeyRunner

Hi, I really like this library for JAX type annotations.
But one important use case is missing for me: annotating extra dimensions added to custom pytree dataclasses.

Here is an example (adding a batch dimension to a dataclass):

```python
import flax.struct
import jax
from jaxtyping import Array, Float, Shaped

@flax.struct.dataclass
class MyType:
    a: Float[Array, 'dims']

def fn_normal(data: MyType):
    ...

# I would like to specify something like this (one extra leading dimension added to MyType)
def fn_batched(data: Shaped[MyType, 'batch_size']):
    return jax.vmap(fn_normal)(data)
```

Sadly the code above does not work. Is it possible to specify this in another way?

@patrick-kidger
Owner

Not right now, although this seems like a reasonable request! Thinking this over, I don't think we want Shaped[some_dataclass, 'some dims'] to dynamically create a new class with the appropriate type annotations -- introspecting the type hints to add the extra dimensions is tricky.

Rather, I'm thinking we could arrange for the __instancecheck__ to actually use jax.vmap internally, and do the isinstance check from within the jax.vmap region. This would (a) probably be easiest and (b) allow us to support any pytree -- Shaped[some_pytree, 'some dims'] -- not just any dataclass.

I'd be happy to consider a PR on this!

(Independently of the above, I'd really recommend using equinox.Module instead of flax.struct.dataclass. The former fixes a lot of footguns that the latter doesn't, e.g. around bound methods and inheritance.)

@patrick-kidger patrick-kidger added the feature New feature label Aug 8, 2024
@JeyRunner
Author

Thanks @patrick-kidger for the fast response :)
I will take a look at __instancecheck__. Do you mean this function?

def __instancecheck__(cls, obj: Any) -> bool:
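For context on the hook being asked about: `isinstance(obj, cls)` is routed to `type(cls).__instancecheck__(cls, obj)`, so a metaclass that defines this method decides what counts as an instance. That is what lets a subscripted annotation check `.shape` rather than only the class. A stdlib-only sketch of the mechanism; `_ShapedMeta`, `make_ndim_type` and `FakeArray` are hypothetical names, not jaxtyping's implementation.

```python
from typing import Any


class _ShapedMeta(type):
    # isinstance(obj, cls) calls type(cls).__instancecheck__(cls, obj), so this
    # method decides membership for every class created with this metaclass.
    def __instancecheck__(cls, obj: Any) -> bool:
        shape = getattr(obj, "shape", None)
        return isinstance(shape, tuple) and len(shape) == cls.ndim


def make_ndim_type(ndim: int) -> type:
    # Hypothetical helper: a class that accepts any object with a rank-`ndim` .shape.
    return _ShapedMeta(f"NDim{ndim}", (), {"ndim": ndim})


class FakeArray:
    # Minimal duck-typed stand-in for an array: only a .shape attribute.
    def __init__(self, shape: tuple):
        self.shape = shape


Matrix = make_ndim_type(2)
print(isinstance(FakeArray((3, 4)), Matrix))  # True
print(isinstance(FakeArray((3,)), Matrix))    # False
print(isinstance(5, Matrix))                  # False
```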

And I don't see why vmap would be used here. Did you perhaps mean tree_map, to check all the leaves of the pytree?
Or is the idea to have vmap remove the batch dimension? (In that case, multiple batch dimensions would need multiple nested vmaps, and I'm not sure about vmap's performance overhead.)

(Note that I am not very familiar with the codebase, so maybe you meant the right thing and I am just missing the context 😅)

@patrick-kidger
Owner

Yup, exactly that! And indeed I meant vmap so as to remove the extra batch dimensions. Performance overhead is probably 'not great' but there isn't much we can do about that. (But at least under JIT this shouldn't matter, since it all gets compiled out.)

FWIW I think this is probably a fairly tricky thing to add, which is why I'm not jumping on quickly doing it myself 😄

(Fun side note. jaxtyping started out as supporting just Shaped[jax.Array, ...] only. And since then the list of things that people want to add shape annotations to keeps growing! We now also support torch tensors, tensorflow tensors, numpy arrays, numpy scalars, Python numeric builtins, duck-typed objects providing just shape and dtype, typevars, Any, Union[anything else in this list], and I think still a few more besides! So this request is one which sits in good company.)

@benjamin-macadam

benjamin-macadam commented Aug 20, 2024

In a similar vein: if there were syntax a bit like Python's generics, so that you could include the dimensions as parameters of a (probably Equinox) dataclass:

```python
@jaxtyped(typechecker=typechecker)
class MyClass['*dims a b'](eqx.Module):
    field_a: Float[Array, '*dims a 2']
    field_b: Float[Array, '*dims b 5']
```

you might be able to get the same effect?

@JeyRunner
Author

Thanks @patrick-kidger,
I'll have a look!

@benjamin-macadam this also looks interesting. However, with this approach I couldn't add batch dimensions to dataclasses that aren't part of my own code (e.g. ones defined in a library I'm using).

Projects
None yet
Development

No branches or pull requests

3 participants