
Making subsaveat consider previous saves and not just the current one #472

Open
etienney opened this issue Jul 31, 2024 · 4 comments
Labels
question User queries


@etienney

etienney commented Jul 31, 2024

When using SaveAt, we have the option to pass a function fn, which the documentation describes as:

fn: A function fn(t, y, args) which specifies what to save into sol.ys when using t0, t1, ts or steps. Defaults to fn(t, y, args) -> y, so that the evolving solution is saved. For example this can be useful to save only statistics of your solution, so as to reduce memory usage.

But why not change it to something like fn(save_state, t, y, args)? In _integrate.py it is called on line 218, inside:

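```python
# Excerpt from diffrax/_integrate.py. FloatScalarLike and SaveState are
# internal diffrax types and are not imported here.
from collections.abc import Callable

import equinox as eqx
import jax.tree_util as jtu
from jaxtyping import Array, PyTree


def _save(
    t: FloatScalarLike,
    y: PyTree[Array],
    args: PyTree,
    fn: Callable,
    save_state: SaveState,
) -> SaveState:
    ts = save_state.ts
    ys = save_state.ys
    save_index = save_state.save_index

    # Write t and fn's output into slot save_index of the preallocated buffers.
    ts = ts.at[save_index].set(t)
    ys = jtu.tree_map(lambda ys_, y_: ys_.at[save_index].set(y_), ys, fn(t, y, args))
    save_index = save_index + 1

    # Return a copy of save_state with the updated fields.
    return eqx.tree_at(
        lambda s: [s.ts, s.ys, s.save_index], save_state, [ts, ys, save_index]
    )
```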

We could then save a quantity that depends not only on the current state but also on the previous saves. This would be useful in my case, and I believe it is more general than the current version "for free".
We would also need to change the default fn in _saveat.py from

```python
def save_y(t, y, args):
    return y
```

to

```python
def save_y(save_state, t, y, args):
    return y
```

and I think that should be enough?
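The proposal can be sketched in plain JAX as follows. SaveState here is a hypothetical, simplified NamedTuple standing in for diffrax's internal class. save_running_sum is a hypothetical history-dependent fn that reads the previously saved output. This sketch does not reproduce how diffrax's real solver loop manages these buffers under autodifferentiation.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp
import jax.tree_util as jtu


class SaveState(NamedTuple):
    # Hypothetical, simplified stand-in for diffrax's internal SaveState.
    ts: jax.Array
    ys: jax.Array  # in diffrax this may be any PyTree of arrays
    save_index: jax.Array


def save_y(save_state, t, y, args):
    # Proposed new default: ignore the history and save y itself.
    return y


def _save(t, y, args, fn: Callable, save_state: SaveState) -> SaveState:
    # Same logic as the excerpt above, except that fn also receives save_state,
    # so it can read earlier entries of ts/ys before the current slot is written.
    i = save_state.save_index
    ts = save_state.ts.at[i].set(t)
    ys = jtu.tree_map(
        lambda ys_, y_: ys_.at[i].set(y_), save_state.ys, fn(save_state, t, y, args)
    )
    return SaveState(ts=ts, ys=ys, save_index=i + 1)


def save_running_sum(save_state, t, y, args):
    # Hypothetical history-dependent fn: running sum of y over the saves so far.
    i = save_state.save_index
    prev = save_state.ys[jnp.maximum(i - 1, 0)]
    return jnp.where(i > 0, prev + y, y)


state = SaveState(ts=jnp.zeros(3), ys=jnp.zeros(3), save_index=jnp.array(0))
for t, y in [(0.0, 1.0), (0.5, 2.0), (1.0, 3.0)]:
    state = _save(t, jnp.array(y), None, save_running_sum, state)
print(state.ys)  # [1. 3. 6.]
```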
Of course I could make this change myself and work with a patched diffrax. However, I'm working on a library that depends on another library, which itself depends on diffrax. So I'd like this to be in the "real" diffrax, so that my library stays compatible with the version used by the library above me, haha.

@patrick-kidger
Owner

So save_state is an internal implementation detail that I don't think should be exposed to users -- it may change from release to release.

More generally, letting saved values depend on earlier saves introduces dependencies that I'm not confident will play well with autodifferentiation. Diffrax does some fairly complicated things here so that it can fill in the output buffers efficiently during the iteration, whilst also remaining autodiff-friendly.

What's your use case?

@etienney
Author

Okay, save_state may be too much to expose to users, but maybe save_index, save_state.ts and save_state.ys would be fine?

The use case is to compute, along the simulation, functions of the form f(t_n, y_n, args_n, t_{n-1}, y_{n-1}, args_{n-1}, n), or more generally functions depending on any earlier index n-i. Here n is the iteration number, so t_n would be ts[save_index] in the current formalism.
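As a concrete illustration of such an f, the sketch below uses a hypothetical backward-difference estimate of dy/dt at save n from saves n and n-1. For brevity it drops args, which is often constant. Once the saves exist, each n ≥ 1 can be paired with n-1 by shifting the saved arrays by one. The values of ts and ys are made up.

```python
import jax
import jax.numpy as jnp


def f(t_n, y_n, t_prev, y_prev, n):
    # Hypothetical example: backward-difference estimate of dy/dt at save n.
    return (y_n - y_prev) / (t_n - t_prev)


# Made-up saved values (y = t**2 sampled every 0.5).
ts = jnp.array([0.0, 0.5, 1.0, 1.5])
ys = jnp.array([0.0, 0.25, 1.0, 2.25])
ns = jnp.arange(1, ts.shape[0])

# Pair save n (for n >= 1) with save n-1 by slicing, then map f over the pairs.
out = jax.vmap(f)(ts[1:], ys[1:], ts[:-1], ys[:-1], ns)
print(out)  # [0.5 1.5 2.5]
```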

@patrick-kidger
Owner

I think for this case it'd be best to just save output as normal, and then do an additional scan over the saved values (after the diffeqsolve) to compute your desired output.
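A minimal sketch of this post-processing approach with jax.lax.scan follows. It assumes ts and ys are what sol.ts and sol.ys would hold after a diffeqsolve with SaveAt(ts=...). Here they are stand-in values for y(t) = 2t. The running trapezoidal integral is only an illustrative choice of output that depends on save n, save n-1 and the previous output. The scan carry holds whatever history the computation needs.

```python
import jax
import jax.numpy as jnp

# Stand-ins for sol.ts and sol.ys; here the "solution" is y(t) = 2t.
ts = jnp.linspace(0.0, 1.0, 11)
ys = 2.0 * ts


def step(carry, x):
    # carry holds the previous save (t_{n-1}, y_{n-1}) and the previous output.
    t_prev, y_prev, integral = carry
    t_n, y_n = x
    # Trapezoidal update: uses save n, save n-1 and the previous output.
    integral = integral + 0.5 * (t_n - t_prev) * (y_n + y_prev)
    return (t_n, y_n, integral), integral


init = (ts[0], ys[0], jnp.zeros(()))
_, integrals = jax.lax.scan(step, init, (ts[1:], ys[1:]))
print(integrals[-1])  # approximately 1.0, the integral of 2t over [0, 1]
```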

@etienney
Author

etienney commented Jul 31, 2024

But then you would not be able to use that output in an event.
Of course this can be done at the end, but the idea is to compute it during the solve precisely for that reason.

@patrick-kidger patrick-kidger added the question User queries label Aug 5, 2024