Hi all, I've edited the introductory Neural ODE example to highlight a problem I'm facing with two-level optimisation: first (outer) level wrt the model, and second (inner) level wrt a parameter alpha. JAX throws a JaxStackTraceBeforeTransformation error if I use RecursiveCheckpointAdjoint, but everything runs if I use DirectAdjoint instead. In line with the recommendations in the documentation, I'd love to use the former adjoint rule. Please help, Thanks.
ddrous changed the title from "RecursiveCheckpointAdjoint not working for two-level minimisation step" to "RecursiveCheckpointAdjoint not working for two-level minimisation" on Jul 20, 2024.
You're actually bumping into something that I think is a bit of an open research problem. :) Namely, how to do second-order autodifferentiation whilst using checkpointing! In particular, what you're seeing here is that the backward pass for RecursiveCheckpointAdjoint is not itself reverse-mode autodifferentiable.
I do note that alpha appears to be a scalar. I've not thought through every detail, but for such cases it is usually more efficient to use jax.jvp to perform forward-mode autodifferentiation instead. Typically the goal is to frame the computation as a jvp-of-grad-of-loss. (Such 'forward over reverse' is usually most efficient overall.) This may allow you to sidestep this problem.
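To illustrate the jvp-of-grad pattern, here is a minimal sketch with a made-up quadratic `inner_loss` standing in for the real inner objective (the function and data are hypothetical, not from the issue). The outer pass applies `jax.jvp` to the reverse-mode gradient, so the second derivative with respect to a scalar `alpha` is obtained without differentiating through a backward pass a second time:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the inner objective: a quadratic in alpha.
def inner_loss(alpha, x):
    return jnp.sum((alpha * x - 1.0) ** 2)

# Reverse-mode gradient of the inner loss w.r.t. the scalar alpha.
def grad_wrt_alpha(alpha, x):
    return jax.grad(inner_loss)(alpha, x)

x = jnp.array([1.0, 2.0, 3.0])
alpha = 0.5

# 'Forward over reverse': jax.jvp of the gradient, in the tangent
# direction 1.0, yields (dL/dalpha, d^2L/dalpha^2) in a single call.
primal, tangent = jax.jvp(lambda a: grad_wrt_alpha(a, x), (alpha,), (1.0,))
```

Because only a forward-mode pass is stacked on top of the single reverse pass, this sidesteps differentiating through the checkpointed backward sweep altogether.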
Failing that, using DirectAdjoint is probably the best option available here.
Thank you @patrick-kidger It helps to know what the real problem is. Looking forward to any research/development on this in the future.
Using JVPs is not really an option for me since my parameters are themselves neural nets (I turned alpha into a scalar just for the purpose of a MWE). So it looks like I'm going to have to use DirectAdjoint(), even though I can barely handle its memory requirements (this after tweaking max_steps).