InitVar and deepcopy break PyTreeDef equality #857
Why do you need a copy of the parameters? All jittable functions should be without side-effects, and jax Arrays are immutable, so you can just store the original array and compare the new array to the old one instead of to some deep copy. If you really do need to copy arrays, you could use https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.copy.html I guess?

from dataclasses import InitVar
import jax
from jax import numpy as jnp
import equinox as eqx
from jaxtyping import Key

class MLP2(eqx.Module):
    key: InitVar[Key] = eqx.field(kw_only=True)
    layers: list = eqx.field(init=False)

    def __post_init__(self, key):
        self.layers = [eqx.nn.Linear(1, 50, key=key), jax.nn.relu]

    def __call__(self, t):
        for layer in self.layers:
            t = layer(t)
        return t

key = jax.random.PRNGKey(0)
mlp = MLP2(key=key)
params, static = eqx.partition(mlp, eqx.is_inexact_array)
params_copy = jax.tree.map(lambda x: jnp.copy(x) if isinstance(x, jax.Array) else x, params)
print(jax.tree.flatten(params)[1] == jax.tree.flatten(params_copy)[1]) # return True!
print(jax.tree.flatten(params)[1], jax.tree.flatten(params_copy)[1])
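As a rough illustration of the "just keep the original array" suggestion, here is a minimal sketch. It assumes a plain dict pytree and a made-up `step` function in place of the real model and optimizer, so the names are hypothetical:

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters; a plain dict stands in for the model's pytree.
params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}


@jax.jit
def step(p):
    # Returns a new pytree; the input arrays are never modified in place.
    return jax.tree.map(lambda x: x - 0.1, p)


old_params = params        # keep a reference, no copy needed
new_params = step(params)

# Largest absolute change per leaf between the old and new parameters.
deltas = jax.tree.map(
    lambda a, b: jnp.max(jnp.abs(a - b)), old_params, new_params
)
same_structure = jax.tree.structure(old_params) == jax.tree.structure(new_params)
print(same_structure, deltas)
```

Because no copy is made, the old and new pytrees come from the same object lineage, so there is no second tree whose structure could drift from the first.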
Hmm. This is really weird! I've poked at this a little bit and you're right, it's specifically the interaction of
Hi,
In an optimization process, I want to compare new parameter values to old ones that I stored using a deepcopy. Jitting the optimization function raises an error because the tree structure of my parameters has been modified. See below an MWE in which the tree structure equality is lost, which is the root of the trouble in my complete program.
Note that I found out the bug disappears when not using InitVar (probably less elegant, though):
Is the problem really due to InitVar? Should I use something else rather than deepcopy?
Thanks!
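For anyone wanting to run the same kind of check on their own parameters, here is a minimal diagnostic sketch. It uses a plain nested dict (a hypothetical stand-in, not the Equinox module from this report), so it only shows how to compare tree structures before and after `copy.deepcopy`. It does not reproduce the `InitVar` behaviour described above.

```python
import copy

import jax
import jax.numpy as jnp

# Hypothetical parameters shaped loosely like one linear layer.
params = {"layers": [{"weight": jnp.ones((3, 1)), "bias": jnp.zeros(3)}]}
params_copy = copy.deepcopy(params)

# Compare the PyTreeDefs (structure only, no leaf values).
treedef = jax.tree.structure(params)
treedef_copy = jax.tree.structure(params_copy)
structures_match = treedef == treedef_copy

# Compare the leaf values pairwise.
leaves_match = all(
    bool(jnp.array_equal(a, b))
    for a, b in zip(jax.tree.leaves(params), jax.tree.leaves(params_copy))
)
print(structures_match, leaves_match)
```

If the structure comparison fails on a real model while the leaves match, printing both treedefs side by side is usually the quickest way to see which node differs.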