Hey, runtime type-checking seems to fail when a Flax dataclass is passed to a vmapped function. I wasn't able to find any related resources. Here is a minimal reproduction with the associated error.
```
E jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of Data.
E The problem arose whilst typechecking argument 'a'.
E Called with arguments: {'self': Data(...), 'a': <object object at 0x7fc7c87e8fc0>}
E Parameter annotations: (self: Any, a: jax.Array).
```
Here are the versions I'm using:
flax==0.8.1
jax==0.4.21
jaxtyping==0.2.25
I tested it, and it works with chex.dataclass and equinox.Module, but I have no choice but to use Flax dataclasses in my case. I would love to find a workaround. Thanks!!
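One plausible explanation, offered as an assumption rather than a confirmed diagnosis: when `jax.vmap` matches `in_axes` against an argument pytree, JAX rebuilds that pytree with placeholder `object()` leaves, which is consistent with the `'a': <object object at ...>` in the error. Flax's `struct.dataclass` appears to rebuild instances by calling the class constructor, so a check on `__init__` sees the placeholder, whereas `equinox.Module` reconstructs instances without calling `__init__`. The pure-Python sketch below mimics only that mechanism. `FakeArray`, `TypeCheckError`, `unflatten_via_init` and `unflatten_bypassing_init` are hypothetical stand-ins; it does not import Flax, JAX, or jaxtyping.

```python
from dataclasses import dataclass, fields


class TypeCheckError(TypeError):
    """Hypothetical stand-in for the error a runtime type checker raises."""


class FakeArray:
    """Hypothetical stand-in for jax.Array so the sketch runs without JAX."""


@dataclass(frozen=True)
class Data:
    a: FakeArray

    def __post_init__(self):
        # Mimics a runtime type check that runs whenever the constructor is called.
        if not isinstance(self.a, FakeArray):
            raise TypeCheckError(f"argument 'a' has type {type(self.a).__name__}")


def unflatten_via_init(cls, leaves):
    # Flax-style (assumed): rebuild the node by calling the constructor, so checks run.
    names = [f.name for f in fields(cls)]
    return cls(**dict(zip(names, leaves)))


def unflatten_bypassing_init(cls, leaves):
    # Equinox-style (assumed): allocate without __init__, then set fields directly.
    obj = object.__new__(cls)
    for f, leaf in zip(fields(cls), leaves):
        object.__setattr__(obj, f.name, leaf)  # bypasses the frozen-dataclass guard
    return obj


# Roughly what vmap does when matching in_axes: rebuild with placeholder leaves.
placeholder = object()

try:
    unflatten_via_init(Data, [placeholder])
    init_path_failed = False
except TypeCheckError:
    init_path_failed = True

bypassed = unflatten_bypassing_init(Data, [placeholder])
print(init_path_failed, bypassed.a is placeholder)  # prints: True True
```

If this explanation holds, the constructor-based rebuild is what trips the check, not the vmapped computation itself.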
That's odd -- I've just tried running your code (with the same versions of each library) and don't see the same issue. Can you perhaps double-check in a new environment?
The error above was raised with beartype as the runtime type checker.