
Approaching multi trajectory adaptive stepping #481

Open
lockwo opened this issue Aug 9, 2024 · 1 comment
Labels
feature New feature

Comments

@lockwo
Contributor

lockwo commented Aug 9, 2024

In line with some of the weak solvers we are working on getting into a PR for diffrax, we are implementing a variety of adaptive methods. One of these schemes relies on estimating errors from multiple trajectories (https://onlinelibrary.wiley.com/doi/abs/10.1002/pamm.200410005): the error estimate is a quantity computed across many simultaneously simulated trajectories, rather than from a single one.

I wanted to think about how best to integrate this into diffrax philosophically. Our current code works as a wrapper on top of diffrax, but the scheme isn't trivial to implement inside the framework itself. Since integrate.py conceptually works over a single trajectory, the usual way to get multiple trajectories is just to vmap. So I was thinking of playing around inside that and writing "unvmapped" versions of the computations we needed, but defining custom unvmaps seemed very hacky. Have you thought about this more, and do you have opinions on adaptive schemes that rely on multiple trajectories?
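For context, here is a minimal sketch of the "just vmap" baseline described above. To stay runnable without diffrax, it hand-rolls a fixed-step Euler–Maruyama solver in pure JAX for a hypothetical toy SDE (dy = -y dt + 0.5 dW); in diffrax the analogous pattern would be `jax.vmap` over a function that calls `diffrax.diffeqsolve` with a per-trajectory Brownian key. The names `euler_maruyama`, `keys`, and `finals` are illustrative only.

```python
import jax
import jax.numpy as jnp


def euler_maruyama(key, y0, t1=1.0, num_steps=100):
    """Integrate the toy SDE dy = -y dt + 0.5 dW for ONE trajectory."""
    dt = t1 / num_steps
    # Pre-draw every Brownian increment for this single trajectory.
    dws = jnp.sqrt(dt) * jax.random.normal(key, (num_steps,) + y0.shape)

    def step(y, dw):
        # One Euler-Maruyama step; nothing here can see other trajectories.
        return y + (-y) * dt + 0.5 * dw, None

    y_final, _ = jax.lax.scan(step, y0, dws)
    return y_final


# Many trajectories = vmap over PRNG keys; y0 is shared (in_axes=None).
keys = jax.random.split(jax.random.PRNGKey(0), 1000)
y0 = jnp.ones(2)
finals = jax.vmap(euler_maruyama, in_axes=(0, None))(keys, y0)
print(finals.shape)  # (1000, 2)
```

Each batch element runs its own step loop in isolation, which is exactly why an error estimate that needs statistics over *all* trajectories does not fit naturally into this structure.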

patrick-kidger added the feature (New feature) label on Aug 10, 2024
@patrick-kidger
Owner

Hmm. You've got a couple of options, I think.
The first would be to bundle multiple trajectories together into one gigantic vector field (with each piece independent of the others). Diffrax then just sees a single integration like normal. This would mean that a batch of solves gets fairly large (each batch element has its own 'inner batch' of trajectories). It would, however, preserve batch independence.
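A minimal sketch of the bundling option, under stated assumptions. The state has shape `(n_traj, d)` and the drift and diffusion act row-wise, so each trajectory is independent; in diffrax this would correspond to passing such a `y0` to a single `diffeqsolve` call. Because the solver sees one state, its error estimate may reduce across the trajectory axis. The adaptive loop is hand-rolled for illustration, and the estimator (step doubling compared through sample means) is a hypothetical stand-in, not the estimator from the cited paper and not diffrax's controller API. All function names are hypothetical.

```python
import jax
import jax.numpy as jnp


# Hypothetical toy SDE applied independently to every row: dy = -y dt + 0.5 dW
def drift(y):
    return -y


def diffusion(y):
    return 0.5 * jnp.ones_like(y)


def em_step(y, h, dw):
    """One Euler-Maruyama step on the whole bundle, shape (n_traj, d)."""
    return y + drift(y) * h + diffusion(y) * dw


def bundled_step_with_error(key, y, h):
    """Step doubling on the bundle with a cross-trajectory error estimate.

    Assumption: error = max |mean(y_full) - mean(y_half)| over the
    trajectory axis. This is illustrative, not the referenced scheme.
    """
    k1, k2 = jax.random.split(key)
    dw1 = jnp.sqrt(h / 2) * jax.random.normal(k1, y.shape)
    dw2 = jnp.sqrt(h / 2) * jax.random.normal(k2, y.shape)
    y_full = em_step(y, h, dw1 + dw2)                     # one coarse step
    y_half = em_step(em_step(y, h / 2, dw1), h / 2, dw2)  # two fine steps, same noise
    # Reducing over axis 0 is only possible because the bundle is one state.
    err = jnp.max(jnp.abs(y_full.mean(axis=0) - y_half.mean(axis=0)))
    return y_half, err


def solve_bundle(key, y0, t1=1.0, h0=0.1, tol=1e-3):
    """Crude adaptive loop over the bundle (Python loop, for illustration)."""
    t, y, h = 0.0, y0, h0
    while t1 - t > 1e-12:
        h = min(h, t1 - t)
        key, subkey = jax.random.split(key)
        # Simplification: fresh noise is drawn after a rejection; a Brownian
        # tree/bridge would keep the sampled path consistent instead.
        y_new, err = bundled_step_with_error(subkey, y, h)
        if err <= tol:
            t, y = t + h, y_new  # accept the step
        # Error scales like h**2 here, hence the 1/2 exponent.
        h = h * float(jnp.clip(0.9 * (tol / (err + 1e-12)) ** 0.5, 0.2, 2.0))
    return y


y_out = solve_bundle(jax.random.PRNGKey(0), jnp.ones((1000, 2)))
```

The trade-off matches the comment: every outer batch element now carries its own inner batch of trajectories, but no vmapped dimension ever has to communicate with another.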

The alternative would be to reach across the batch and explicitly create a cross-batch dependence. JAX provides tools to do this in the form of its collective operations, jax.lax.p{sum, ...}. Take a look at eqx.nn.BatchNorm for an example. Typically you give a particular vmap a name (via its axis_name argument) and then reduce over that named axis.
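A minimal sketch of the named-vmap option, assuming the same hypothetical toy SDE and step-doubling estimator as before (neither is taken from the cited paper). The step function is written for a single trajectory; `jax.vmap(..., axis_name=...)` names the batch axis, and `jax.lax.pmean` over that name averages across every trajectory in the batch, in the same spirit as giving `eqx.nn.BatchNorm` an `axis_name`. The axis name `"trajectories"` and the function names are hypothetical.

```python
import jax
import jax.numpy as jnp

AXIS = "trajectories"  # hypothetical name for the vmapped batch axis


def per_trajectory_step(key, y, h):
    """Written for ONE trajectory; vmap adds the batch dimension."""
    k1, k2 = jax.random.split(key)
    dw1 = jnp.sqrt(h / 2) * jax.random.normal(k1, y.shape)
    dw2 = jnp.sqrt(h / 2) * jax.random.normal(k2, y.shape)
    # Toy SDE dy = -y dt + 0.5 dW: one coarse step vs two fine steps.
    y_full = y - y * h + 0.5 * (dw1 + dw2)
    y_mid = y - y * (h / 2) + 0.5 * dw1
    y_half = y_mid - y_mid * (h / 2) + 0.5 * dw2
    # Cross-batch dependence: mean over all trajectories in the named vmap.
    mean_full = jax.lax.pmean(y_full, axis_name=AXIS)
    mean_half = jax.lax.pmean(y_half, axis_name=AXIS)
    err = jnp.max(jnp.abs(mean_full - mean_half))  # identical for every trajectory
    return y_half, err


# Map over keys and states; the step size h is shared (in_axes=None).
batched_step = jax.vmap(per_trajectory_step, in_axes=(0, 0, None), axis_name=AXIS)

keys = jax.random.split(jax.random.PRNGKey(1), 500)
y_new, err = batched_step(keys, jnp.ones((500, 2)), 0.1)
```

Because every batch element receives the same `err`, an adaptive controller fed by it would choose the same step for all trajectories. The cost is that the solves are no longer independent: calling this function only makes sense inside a vmap carrying that axis name.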

2 participants