andyElking/diffrax_STLA

Diffrax with new SDE-solving capabilities

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable.

Diffrax is a JAX-based library providing numerical differential equation solvers. This fork adds methods for solving stochastic differential equations (SDEs), including a third-order method for the Underdamped Langevin Diffusion (ULD) and a method for generating Brownian paths that allows high-order SDE solvers to be used with adaptive time-stepping. These improvements are based on the paper:

@misc{jelinčič2024singleseed,
    title={Single-seed generation of Brownian paths and integrals
    for adaptive and high order SDE solvers},
    author={Andraž Jelinčič and James Foster and Patrick Kidger},
    year={2024},
    eprint={2405.06464},
    archivePrefix={arXiv},
    primaryClass={math.NA}
}

The original Diffrax library is described below.

Features include:

  • ODE/SDE/CDE (ordinary/stochastic/controlled) solvers;
  • lots of different solvers (including Tsit5, Dopri8, symplectic solvers, implicit solvers);
  • vmappable everything (including the region of integration);
  • using a PyTree as the state;
  • dense solutions;
  • multiple adjoint methods for backpropagation;
  • support for neural differential equations.

From a technical point of view, the internal structure of the library is pretty cool -- all kinds of equations (ODEs, SDEs, CDEs) are solved in a unified way (rather than being treated separately), producing a small, tightly written library.

Installation

pip install diffrax

Requires Python 3.9+, JAX 0.4.13+, and Equinox 0.10.11+.

Documentation

Available at https://docs.kidger.site/diffrax.

Quick example

from diffrax import diffeqsolve, ODETerm, Dopri5
import jax.numpy as jnp

def f(t, y, args):
    return -y

term = ODETerm(f)
solver = Dopri5()
y0 = jnp.array([2., 3.])
solution = diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0)

Here, Dopri5 refers to the Dormand--Prince 5(4) numerical differential equation solver, which is a standard choice for many problems.

SDE example

For guidance on how to simulate SDEs in Diffrax, see notebooks2/sde_exmaple.ipynb.

Citation

If you found this library useful in academic research, please cite: (arXiv link)

@phdthesis{kidger2021on,
    title={{O}n {N}eural {D}ifferential {E}quations},
    author={Patrick Kidger},
    year={2021},
    school={University of Oxford},
}

(Also consider starring the project on GitHub.)

See also: other libraries in the JAX ecosystem

Always useful
Equinox: neural networks and everything not already in core JAX!
jaxtyping: type annotations for shape/dtype of arrays.

Deep learning
Optax: first-order gradient (SGD, Adam, ...) optimisers.
Orbax: checkpointing (async/multi-host/multi-device).
Levanter: scalable+reliable training of foundation models (e.g. LLMs).

Scientific computing
Optimistix: root finding, minimisation, fixed points, and least squares.
Lineax: linear solvers.
BlackJAX: probabilistic+Bayesian sampling.
sympy2jax: SymPy<->JAX conversion; train symbolic expressions via gradient descent.
PySR: symbolic regression. (Non-JAX honourable mention!)

Awesome JAX
Awesome JAX: a longer list of other JAX projects.

About

Space-time Lévy area extension for Diffrax
