From 6c36b4a03ffaa54287830575759d81dc537d4b74 Mon Sep 17 00:00:00 2001 From: Theo West Date: Fri, 3 Mar 2023 09:17:17 +0100 Subject: [PATCH] Add additional explicit solvers --- torchdiffeq/_impl/fixed_grid.py | 56 ++++++++++++++++++++++++++++++++- torchdiffeq/_impl/odeint.py | 7 ++++- 2 files changed, 61 insertions(+), 2 deletions(-) diff --git a/torchdiffeq/_impl/fixed_grid.py b/torchdiffeq/_impl/fixed_grid.py index 7578627c4..f1cf4e469 100644 --- a/torchdiffeq/_impl/fixed_grid.py +++ b/torchdiffeq/_impl/fixed_grid.py @@ -1,5 +1,5 @@ from .solvers import FixedGridODESolver -from .rk_common import rk4_alt_step_func +from .rk_common import rk4_step_func, rk4_alt_step_func from .misc import Perturb @@ -11,6 +11,16 @@ def _step_func(self, func, t0, dt, t1, y0): return dt * f0, f0 +class Heun(FixedGridODESolver): + order = 2 + + def _step_func(self, func, t0, dt, t1, y0): + half_dt = 0.5 * dt + f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE) + f1 = func(t0 + dt, y0 + dt * f0) + return half_dt * (f0 + f1), f0 + + class Midpoint(FixedGridODESolver): order = 2 @@ -21,9 +31,53 @@ def _step_func(self, func, t0, dt, t1, y0): return dt * func(t0 + half_dt, y_mid), f0 +class Ralston(FixedGridODESolver): + order = 2 + + def _step_func(self, func, t0, dt, t1, y0): + fourth_dt = 0.25 * dt + double_dt = 2 * dt + f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE) + f1 = func(t0 + double_dt / 3, y0 + double_dt * f0 / 3) + return fourth_dt * (f0 + 3 * f1), f0 + + +class RK3(FixedGridODESolver): + order = 3 + + def _step_func(self, func, t0, dt, t1, y0): + half_dt = 0.5 * dt + double_dt = 2 * dt + sixth_dt = (1 / 6) * dt + f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE) + f1 = func(t0 + half_dt, y0 + half_dt * f0) + f2 = func(t0 + dt, y0 - dt * f0 + double_dt * f1) + return sixth_dt * (f0 + 4 * f1 + f2), f0 + + +class SSPRK3(FixedGridODESolver): + order = 3 + + def _step_func(self, func, t0, dt, t1, y0): + fourth_dt = 0.25 * dt + sixth_dt = (1 / 6) * dt + f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE) + f1 = func(t0 + dt, y0 + dt * f0) + f2 = func(t0 + fourth_dt, y0 + fourth_dt * (f0 + f1)) + return sixth_dt * (f0 + f1 + 4 * f2), f0 + + class RK4(FixedGridODESolver): order = 4 def _step_func(self, func, t0, dt, t1, y0): f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE) return rk4_alt_step_func(func, t0, dt, t1, y0, f0=f0, perturb=self.perturb), f0 + + +class ClassicRK4(FixedGridODESolver): + order = 4 + + def _step_func(self, func, t0, dt, t1, y0): + f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE) + return rk4_step_func(func, t0, dt, t1, y0, f0=f0, perturb=self.perturb), f0 diff --git a/torchdiffeq/_impl/odeint.py b/torchdiffeq/_impl/odeint.py index a174219ad..057cb2930 100644 --- a/torchdiffeq/_impl/odeint.py +++ b/torchdiffeq/_impl/odeint.py @@ -4,7 +4,7 @@ from .bosh3 import Bosh3Solver from .adaptive_heun import AdaptiveHeunSolver from .fehlberg2 import Fehlberg2 -from .fixed_grid import Euler, Midpoint, RK4 +from .fixed_grid import Euler, Heun, Midpoint, Ralston, RK3, SSPRK3, ClassicRK4, RK4 from .fixed_adams import AdamsBashforth, AdamsBashforthMoulton from .dopri8 import Dopri8Solver from .scipy_wrapper import ScipyWrapperODESolver @@ -17,7 +17,12 @@ 'fehlberg2': Fehlberg2, 'adaptive_heun': AdaptiveHeunSolver, 'euler': Euler, + 'heun': Heun, 'midpoint': Midpoint, + 'ralston': Ralston, + 'rk3': RK3, + 'ssprk3': SSPRK3, + 'crk4': ClassicRK4, 'rk4': RK4, 'explicit_adams': AdamsBashforth, 'implicit_adams': AdamsBashforthMoulton,