
Understand diffrax.diffeqsolve #467

Open
zhengqigao opened this issue Jul 24, 2024 · 1 comment


zhengqigao commented Jul 24, 2024

Hi,

Thanks for the nice package. I have a question about the implementation of diffrax.diffeqsolve. Specifically, I'd like to know more about the ODE solver, in particular the key factors that make diffeqsolve fast. In a nutshell, could you list a few bullet points on the implementation optimizations that have been made to accelerate ODE solving in Diffrax?

patrick-kidger (Owner) commented

There have been a lot of JAX-specific tricks that have gone into this: for example, knowing when the compiler will want to make a copy of a buffer, and avoiding those cases; or knowing that vmap-of-cond (with a batched predicate) becomes jnp.where, which evaluates both branches, and likewise knowing to avoid that.
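A small illustration of the vmap-of-cond point, using only standard JAX (`jax.vmap`, `jax.lax.cond`, `jax.make_jaxpr`). The functions `branchy` and `branchy_shared` are hypothetical and nothing here is taken from Diffrax's source. The sketch inspects the traced program: when the predicate is batched, `vmap` turns the `cond` into a `select_n` (the primitive `jnp.where` also lowers to), so both branches are computed; when the predicate is unbatched, a real `cond` survives.

```python
import jax
import jax.numpy as jnp
from jax import lax


def branchy(x):
    # Hypothetical example: the predicate depends on x, so under vmap it is batched.
    return lax.cond(x > 0, jnp.sin, jnp.cos, x)


def branchy_shared(pred, x):
    # Hypothetical example: the predicate is a separate argument we can leave unbatched.
    return lax.cond(pred, jnp.sin, jnp.cos, x)


xs = jnp.linspace(-1.0, 1.0, 8)

# Batched predicate: vmap computes both branches, then selects per element.
batched_pred_jaxpr = jax.make_jaxpr(jax.vmap(branchy))(xs)

# Unbatched predicate: the conditional is kept, so only one branch runs.
shared_pred_jaxpr = jax.make_jaxpr(
    jax.vmap(branchy_shared, in_axes=(None, 0))
)(True, xs)

print(batched_pred_jaxpr)  # contains select_n and no cond
print(shared_pred_jaxpr)   # contains cond
```

For cheap branches this hardly matters, but inside a solver step where a branch may be expensive, computing both branches for every batch element is exactly the cost this trick avoids.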

I think one interesting trick here is the way we make sure to be vmap-friendly. That is, if we do vmap(diffeqsolve), we'd like the result to be fast. The key trick is that we don't do a loop over output times (integrating in between them), and instead do a loop over steps (saving outputs as we pass each output time). This avoids a nested (double) while loop, in which the inner loop would sit and wait at every output time until every batch element had reached it.
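A minimal sketch of the loop-over-steps structure, under loud assumptions: it uses forward Euler on the toy problem dy/dt = -k*y, a hypothetical step size `dt = 0.01 / k` standing in for an adaptive controller (so different batch elements take different numbers of steps), and linear interpolation within a step in place of a proper dense output. The function `solve` is hypothetical and is not Diffrax's implementation; it only shows one `lax.while_loop` over steps that fills in output times as it passes them.

```python
import jax
import jax.numpy as jnp
from jax import lax


def solve(k, ts, y0=1.0, t0=0.0):
    """Hypothetical sketch: forward Euler for dy/dt = -k * y, saving at `ts`.

    One while loop over *steps*; each step writes whichever output times it
    passed over, using linear interpolation within the step.
    """
    dtype = ts.dtype
    t_end = ts[-1]
    # Assumed stand-in for an adaptive step-size controller: the step size,
    # and so the number of steps, differs between batch elements.
    dt = 0.01 / k
    t_init = jnp.asarray(t0, dtype)
    y_init = jnp.asarray(y0, dtype)
    # Output times at (or before) t0 get the initial condition; the rest
    # are filled in as the loop passes them.
    ys_init = jnp.where(ts <= t_init, y_init, jnp.full_like(ts, jnp.nan))

    def keep_stepping(carry):
        t, _, _ = carry
        return t < t_end

    def step(carry):
        t, y, ys = carry
        h = jnp.minimum(dt, t_end - t)  # never step past the final time
        t_next = t + h
        y_next = y + h * (-k * y)  # one Euler step
        # Record every output time in (t, t_next] with one masked write,
        # rather than an inner loop over output times.
        passed = (ts > t) & (ts <= t_next)
        frac = (ts - t) / h
        ys = jnp.where(passed, y + frac * (y_next - y), ys)
        return t_next, y_next, ys

    _, _, ys = lax.while_loop(keep_stepping, step, (t_init, y_init, ys_init))
    return ys


ks = jnp.array([0.5, 1.0, 3.0])  # each element takes a different number of steps
ts = jnp.linspace(0.0, 1.0, 5)
ys = jax.vmap(solve, in_axes=(0, None))(ks, ts)
```

Under `jax.vmap`, the single while loop just runs until the slowest batch element finishes, with finished elements held fixed. A loop over output times would instead nest a while loop per output interval, and each of those would wait for the slowest element before any element could move on.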

Another trick that comes to mind is the use of Stumm-Walther-Wang-Moin checkpointing, to make autodiff fast: https://github.com/patrick-kidger/equinox/blob/main/equinox/internal/_loop/checkpointed.py
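The linked Stumm-Walther-Wang-Moin scheme is an online checkpointing algorithm for loops whose number of steps is not known in advance. The sketch below swaps in a much simpler technique to illustrate only the underlying idea (trading recomputation for memory in reverse-mode autodiff): fixed-length, two-level checkpointing with `jax.checkpoint` around segments of a `lax.scan`. The Euler-on-dy/dt=-k*y setup and all function names are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax


def euler_step(y, k, dt):
    # One forward-Euler step of dy/dt = -k * y.
    return y + dt * (-k * y)


def final_state_plain(k, y0, dt, n_steps):
    # Reverse mode through this scan keeps every intermediate state: O(n) memory.
    def body(y, _):
        return euler_step(y, k, dt), None

    y, _ = lax.scan(body, y0, None, length=n_steps)
    return y


def final_state_checkpointed(k, y0, dt, n_outer, n_inner):
    # Two-level checkpointing: the forward pass stores only the n_outer
    # segment boundaries; each segment of n_inner steps is recomputed on the
    # backward pass, so peak memory is roughly O(n_outer + n_inner).
    @jax.checkpoint
    def segment(y, k):
        def inner(y, _):
            return euler_step(y, k, dt), None

        y, _ = lax.scan(inner, y, None, length=n_inner)
        return y

    def outer(y, _):
        return segment(y, k), None

    y, _ = lax.scan(outer, y0, None, length=n_outer)
    return y


k, y0, dt = 2.0, 1.0, 0.01
grad_plain = jax.grad(lambda k: final_state_plain(k, y0, dt, 100))(k)
grad_ckpt = jax.grad(lambda k: final_state_checkpointed(k, y0, dt, 10, 10))(k)
print(grad_plain, grad_ckpt)  # same gradient, different memory profile
```

The gradients agree; only the memory/recompute trade-off changes. An online scheme like the one linked is needed when the step count is only known at runtime (as with adaptive step sizes), which this fixed-length sketch does not handle.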

There's probably a bunch of other things too, but this is what comes to mind!
