Updating initial guess when using nonlinear solver inside ODE term #199

Open · slishak opened this issue Dec 2, 2022 · 3 comments
Labels: feature (New feature), question (User queries)

Comments

slishak (Contributor) commented Dec 2, 2022

Hi,

I think this is partially related to #60 as it involves storing some information after an accepted step, but the difference is that I actually need to access the last known information from inside the ODE function.

I have an ODE function that requires a nonlinear solver to compute the derivatives. At the moment I'm using a fixed initial guess for NewtonNonlinearSolver, but this is inefficient. What I'd like to do is, after an accepted step, store the root that was found and use it as the initial guess for the next integration step. I was doing this successfully in torchdiffeq, but I can't see an equivalent way to do it in Diffrax.

As a (contrived) example: the code below performs some sort of nonlinear solve, but each time with a poor initial guess (meaning it takes 10 iterations to converge at each call to the ODE function). If I set init_x = 0.9, which is a much better guess in this case, it takes two or three iterations, so the potential benefit is clear (especially for more expensive nonlinear functions). In this case, I wouldn't expect to run into weird issues with gradients, because backpropagating through NewtonNonlinearSolver shouldn't depend on the initial guess.

Thanks!

from diffrax import diffeqsolve, ODETerm, Dopri5, NewtonNonlinearSolver
import jax.debug
import jax.numpy as jnp


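# Fixed (deliberately poor) initial guess used for every nonlinear solve.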
init_x = 0.1
nl_solver = NewtonNonlinearSolver(rtol=1e-3, atol=1e-6)


def f_nonlinear(x, y):
    return jnp.cos(y * x) - x**3


def f(t, y, args):
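    # Each evaluation of the vector field re-runs the Newton solve from the
    # same fixed initial guess.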
    sol = nl_solver(f_nonlinear, init_x, y)

    jax.debug.print(
        "t=t{t}, {n} iterations, x={x}",
        t=t,
        n=sol.num_steps,
        x=sol.root,
    )

    return -sol.root


term = ODETerm(f)
solver = Dopri5()
y0 = 1.0
solution = diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0)

slishak (Contributor, Author) commented Dec 2, 2022

Aside: this made-up example is actually quite severe: printing sol.result shows that the solve usually fails to converge with init_x = 0.1, but always converges with init_x = 0.9, so it would particularly benefit from the ability to update the initial guess.

Second aside: I misunderstood the documentation at https://docs.kidger.site/diffrax/api/nonlinear_solver/. I was expecting the default option tolerate_nonconvergence=False to set root to something like jnp.nan in the case of nonconvergence, so I was lazily not checking the return code. Additionally, the documentation could be clearer about what the return codes actually mean (e.g. by referring to diffrax.RESULTS). Also, the descriptions in RESULTS refer specifically to "implicit methods", which may be confusing when they come from the Newton solver outside the context of implicit solvers. I can contribute a PR to try to clear up some of this if you like.
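(For reference, a minimal sketch of the check I should have been doing, assuming sol.result can be compared against diffrax.RESULTS.successful:)

from diffrax import RESULTS

sol = nl_solver(f_nonlinear, init_x, y)
# Propagate NaN rather than a spurious root if the solve did not converge.
root = jnp.where(sol.result == RESULTS.successful, sol.root, jnp.nan)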

patrick-kidger (Owner) commented

So you'd like to pass data between vector field evaluations. It's worth noting that this isn't really a clearly-defined notion, mathematically speaking: a diffeq solver may evaluate the vector field nonmonotonically (not just forward in time), in particular when using an adaptive solver that may reject steps.

That said, I agree that it can be very useful to be able to do this!

Supporting this kind of side effect hasn't been a priority for JAX so far, as side effects are quite a complicated thing to make happen in a functional framework. Nonetheless, I think this should be possible using an upcoming JAX API that provides for stateful operations.

I've not yet tried it myself -- and it's not documented yet -- but it might suffice for this task. The operations are available here, and you can see an example of them being used in for_loop.
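To illustrate the difficulty, here's a sketch (purely illustrative, not an API from anywhere) of why the naive approach of mutating a Python-level cache from inside the vector field doesn't work: under jax.jit the read and write happen once, at trace time, rather than on every call.

import jax
import jax.numpy as jnp

# Naive sketch: cache the last root in a Python-level dict.
cache = {"guess": jnp.array(0.1)}

@jax.jit
def vector_field(y):
    guess = cache["guess"]  # read once, when the function is traced
    root = guess + y  # stand-in for the nonlinear solve
    cache["guess"] = root  # stores an abstract tracer, not a runtime value
    return root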


As for the nonlinear solvers, I actually have an overhaul of these planned myself. These are going to be dramatically improved soon.

patrick-kidger added the feature (New feature) and question (User queries) labels on Dec 7, 2022

slishak (Contributor, Author) commented Dec 14, 2022

Thank you for the response - I'll keep an eye out for when that API is documented as it should be useful, but I don't currently have the time to try and figure it out for myself!

Regarding variable-step solvers, just for clarity: I was envisaging only updating the state after an accepted step, rather than on every call to the ODE function. Along similar lines, if a nonlinear solve inside the ODE function fails, it would be nice to have a mechanism to force an adaptive solver to reject that step (which is what I was doing in rtqichen/torchdiffeq#210).
(Edit: I've now seen there are already ways to do this in Diffrax: #200 (comment), #194 (comment))
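(A minimal sketch of that pattern, on the assumption, per those comments, that returning a non-finite derivative causes an adaptive step size controller to reject the step:)

import jax.numpy as jnp
from diffrax import RESULTS

def f(t, y, args):
    sol = nl_solver(f_nonlinear, init_x, y)
    dy = -sol.root
    # If the Newton solve failed, return NaN so that an adaptive controller
    # (e.g. PIDController) rejects the step and retries with a smaller dt.
    return jnp.where(sol.result == RESULTS.successful, dy, jnp.nan)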

Good to hear about the upcoming nonlinear solver changes too!
