
RecursiveCheckpointAdjoint not working for two-level minimisation #465

Open
ddrous opened this issue Jul 20, 2024 · 2 comments
Labels
question User queries

Comments


ddrous commented Jul 20, 2024

Hi all, I've edited the introductory Neural ODE example to highlight a problem I'm facing with two-level optimisation: the first (outer) level is with respect to the model, and the second (inner) level is with respect to a parameter alpha. JAX throws a JaxStackTraceBeforeTransformation error if I use RecursiveCheckpointAdjoint, but everything runs if I use DirectAdjoint instead. In line with the recommendations in the documentation, I'd love to use the former adjoint. Please help, thanks!

import equinox as eqx
import diffrax
import jax
import jax.numpy as jnp

data_size = 2

class Func(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self):
        self.mlp = eqx.nn.MLP(
            in_size=data_size+1,
            out_size=data_size,
            width_size=4,
            depth=2,
            activation=jax.nn.softplus,
            key=jax.random.PRNGKey(0),
        )

    def __call__(self, t, y, args):
        alpha = args[0]
        y = jnp.concatenate([y, alpha])
        return self.mlp(y)

class NeuralODE(eqx.Module):
    func: Func

    def __init__(self):
        self.func = Func()

    def __call__(self, ts, y0, alpha):
        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Tsit5(),
            t0=ts[0],
            t1=ts[-1],
            dt0=ts[1] - ts[0],
            y0=y0,
            args=(alpha,),
            # adjoint=diffrax.DirectAdjoint(),               ## works fine ! 🎉
            adjoint=diffrax.RecursiveCheckpointAdjoint(),    ## throws a JaxStackTraceBeforeTransformation 😢
            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
            saveat=diffrax.SaveAt(ts=ts),
        )
        return solution.ys

def loss_fn(model, alpha):
    ts = jnp.linspace(0, 1, 100)
    y0 = jnp.zeros(data_size)
    return jnp.mean(model(ts, y0, alpha) ** 2)

def inner_step(model, alpha):
    # Inner level: gradient of the loss with respect to alpha, reduced to a scalar.
    alpha_grad = eqx.filter_grad(lambda alpha, model: loss_fn(model, alpha))(alpha, model)
    return jnp.mean(alpha_grad)

def outer_step(model, alpha):
    # Outer level: gradient of the inner step (itself a gradient) with respect to the model.
    model_grad = eqx.filter_grad(inner_step)(model, alpha)
    return model_grad


model = NeuralODE()
alpha = jnp.array([1.])

## Run the outer step
outer_step(model, alpha)
patrick-kidger (Owner) commented

You're actually bumping into something that I think is a bit of an open research problem. :) Namely, how to do second-order autodifferentiation whilst using checkpointing! In particular, what you're seeing here is that the backward pass for RecursiveCheckpointAdjoint is not itself reverse-mode autodifferentiable.

I do note that alpha appears to be a scalar. I've not thought through every detail, but for such cases it is usually more efficient to use jax.jvp to perform forward-mode autodifferentiation instead. Typically the goal is to frame the computation as a jvp-of-grad-of-loss. (Such 'forward over reverse' is usually most efficient overall.) This may allow you to sidestep this problem.
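For concreteness, here is a minimal sketch of what the suggested jvp-of-grad framing might look like for the MWE above, reusing its loss_fn, model and alpha; the names grad_wrt_model and mixed are purely illustrative, and this is not guaranteed to avoid the error in this issue.

import jax
import jax.numpy as jnp
import equinox as eqx

def grad_wrt_model(alpha):
    # Reverse-mode: gradient of the loss with respect to the model, at a fixed alpha.
    return eqx.filter_grad(loss_fn)(model, alpha)

# Forward over reverse: push a unit tangent in alpha through the reverse-mode pass,
# giving d/d(alpha) of grad_model(loss) -- the same mixed derivative outer_step targets.
_, mixed = jax.jvp(grad_wrt_model, (alpha,), (jnp.ones_like(alpha),))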

Failing that, using DirectAdjoint is probably the best option available here.

patrick-kidger added the question (User queries) label on Jul 21, 2024

ddrous commented Jul 21, 2024

Thank you @patrick-kidger! It helps to know what the real problem is. Looking forward to any research/development on this in the future.

Using JVPs is not really an option for me, since my parameters are themselves neural nets (I turned alpha into a scalar just for the purposes of an MWE). So it looks like I'm going to have to use DirectAdjoint(), even though I can barely handle its memory requirements (and this is after tweaking max_steps).
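For reference, a minimal sketch of where the max_steps knob sits, mirroring the diffeqsolve call inside NeuralODE.__call__ above; the value 4096 is purely illustrative, not a recommendation, and DirectAdjoint's memory usage grows with this bound.

solution = diffrax.diffeqsolve(
    diffrax.ODETerm(self.func),
    diffrax.Tsit5(),
    t0=ts[0],
    t1=ts[-1],
    dt0=ts[1] - ts[0],
    y0=y0,
    args=(alpha,),
    adjoint=diffrax.DirectAdjoint(),
    stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
    saveat=diffrax.SaveAt(ts=ts),
    max_steps=4096,  # illustrative value; DirectAdjoint records up to max_steps steps
)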
