InitVar and deepcopy break PyTreeDef equality #857

Open

HGangloff opened this issue Sep 18, 2024 · 2 comments

@HGangloff

Hi,

In an optimization procedure, I want to compare new parameter values against old ones that I stored with deepcopy. I get an error when jitting the optimization function because the tree structure of my parameters has been modified. Below is a MWE showing where the tree structure equality is lost, which is the root of the trouble in my complete program; a sketch of how such a mismatch can surface follows the MWE.

from dataclasses import InitVar
from copy import deepcopy

import jax
import equinox as eqx
from jaxtyping import Key

class MLP2(eqx.Module):

    key: InitVar[Key] = eqx.field(kw_only=True)
    layers: list = eqx.field(init=False)

    def __post_init__(self, key):
        self.layers = [eqx.nn.Linear(1, 50, key=key), jax.nn.relu]

    def __call__(self, t):
        for layer in self.layers:
            t = layer(t)
        return t


key = jax.random.PRNGKey(0)
mlp = MLP2(key=key)
params, static = eqx.partition(mlp, eqx.is_inexact_array)
print(jax.tree.flatten(params)[1] == jax.tree.flatten(deepcopy(params))[1]) # returns False!
print(jax.tree.flatten(params)[1], jax.tree.flatten(deepcopy(params))[1])
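
For context, here is a hypothetical sketch (not my actual optimization code) of one way such a mismatch can surface: jax.tree.map requires all of its pytree arguments to share the same tree structure, so combining the original and the deepcopied parameters fails once their PyTreeDefs differ.

# Continuing from the MWE above: combining the two pytrees fails because
# their PyTreeDefs are no longer equal.
try:
    jax.tree.map(lambda a, b: a - b, params, deepcopy(params))
except Exception as e:
    # Typically a ValueError complaining about mismatched (custom) node data.
    print(type(e).__name__, e)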

Note that I found the bug disappears when not using InitVar (though this is arguably less elegant):

from copy import deepcopy

import jax
import equinox as eqx

class MLP1(eqx.Module):

    layers: list = eqx.field(init=False)

    def __init__(self, key):        
        self.layers = [eqx.nn.Linear(1, 50, key=key), jax.nn.relu]

    def __call__(self, t):
        for layer in self.layers:
            t = layer(t)
        return t

key = jax.random.PRNGKey(0)
mlp = MLP1(key=key)
params, static = eqx.partition(mlp, eqx.is_inexact_array)
print(jax.tree.flatten(params)[1] == jax.tree.flatten(deepcopy(params))[1]) # returns True!
print(jax.tree.flatten(params)[1], jax.tree.flatten(deepcopy(params))[1])

Is the problem really due to InitVar? Should I use something other than deepcopy?

Thanks!

@HGangloff HGangloff changed the title InitVar and tree structure equality InitVar and deepcopy break PyTreeDef equality Sep 18, 2024
@SimonKoop

Why do you need a copy of the parameters? All jittable functions should be free of side effects, and JAX arrays are immutable, so you can just store the original array and compare the new array to the old one instead of to some deep copy.
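
For example, something along these lines (a rough sketch with a made-up params pytree; eqx.tree_equal is just one convenient way to compare the values of two pytrees):

import jax
from jax import numpy as jnp
import equinox as eqx

# A made-up "parameters" pytree; any pytree of JAX arrays behaves the same way.
params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}

# JAX arrays are immutable, so keeping a reference around is enough; no copy needed.
old_params = params

# A placeholder standing in for one optimization step producing new values.
new_params = jax.tree.map(lambda x: x - 0.1, params)

print(eqx.tree_equal(old_params, params))      # same values -> True
print(eqx.tree_equal(old_params, new_params))  # values changed -> False
print(jax.tree.flatten(old_params)[1] == jax.tree.flatten(new_params)[1])  # same structure -> True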

If you really do need to copy arrays, you could use https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.copy.html I guess?
So then you could do something like:

from dataclasses import InitVar
import jax
from jax import numpy as jnp
import equinox as eqx
from jaxtyping import Key

class MLP2(eqx.Module):

    key: InitVar[Key] = eqx.field(kw_only=True)
    layers: list = eqx.field(init=False)

    def __post_init__(self, key):
        self.layers = [eqx.nn.Linear(1, 50, key=key), jax.nn.relu]

    def __call__(self, t):
        for layer in self.layers:
            t = layer(t)
        return t


key = jax.random.PRNGKey(0)
mlp = MLP2(key=key)
params, static = eqx.partition(mlp, eqx.is_inexact_array)

params_copy = jax.tree.map(lambda x: jnp.copy(x) if isinstance(x, jax.Array) else x, params)

print(jax.tree.flatten(params)[1] == jax.tree.flatten(params_copy)[1]) # returns True!
print(jax.tree.flatten(params)[1], jax.tree.flatten(params_copy)[1])

@patrick-kidger
Owner

Hmm. This is really weird! I've poked at this a little bit and you're right, it's specifically the interaction of InitVar[...] and deepcopy.
I have no idea why this should be the case.
