diff --git a/elastica/experimental/timestepper/explicit_steppers.py b/elastica/experimental/timestepper/explicit_steppers.py index 6ebda1da..3705ccd2 100644 --- a/elastica/experimental/timestepper/explicit_steppers.py +++ b/elastica/experimental/timestepper/explicit_steppers.py @@ -12,8 +12,11 @@ SteppersOperatorsType, StateType, ) -from elastica.systems.protocol import ExplicitSystemProtocol -from elastica.timestepper.protocol import ExplicitStepperProtocol, MemoryProtocol +from elastica.experimental.timestepper.protocol import ( + ExplicitSystemProtocol, + ExplicitStepperProtocol, + MemoryProtocol, +) """ diff --git a/elastica/experimental/timestepper/memory.py b/elastica/experimental/timestepper/memory.py index 0948a5ee..b63931aa 100644 --- a/elastica/experimental/timestepper/memory.py +++ b/elastica/experimental/timestepper/memory.py @@ -1,11 +1,10 @@ from typing import Iterator, TypeVar, Generic, Type -from elastica.timestepper.protocol import ExplicitStepperProtocol from elastica.typing import SystemCollectionType from elastica.experimental.timestepper.explicit_steppers import ( RungeKutta4, EulerForward, ) - +from elastica.experimental.timestepper.protocol import ExplicitStepperProtocol from copy import copy diff --git a/elastica/experimental/timestepper/protocol.py b/elastica/experimental/timestepper/protocol.py new file mode 100644 index 00000000..b8d489fa --- /dev/null +++ b/elastica/experimental/timestepper/protocol.py @@ -0,0 +1,86 @@ +from typing import Protocol + +from elastica.typing import StepType, StateType +from elastica.systems.protocol import SystemProtocol, SlenderBodyGeometryProtocol +from elastica.timestepper.protocol import StepperProtocol + +import numpy as np + + +class ExplicitSystemProtocol(SystemProtocol, SlenderBodyGeometryProtocol, Protocol): + # TODO: Temporarily made to handle explicit stepper. + # Need to be refactored as the explicit stepper is further developed. + def __call__(self, time: np.float64, dt: np.float64) -> np.float64: ... + @property + def state(self) -> StateType: ... + @state.setter + def state(self, state: StateType) -> None: ... + @property + def n_elems(self) -> int: ... + + +class MemoryProtocol(Protocol): + @property + def initial_state(self) -> bool: ... + + +class ExplicitStepperProtocol(StepperProtocol, Protocol): + """symplectic stepper protocol.""" + + def get_stages(self) -> list[StepType]: ... + + def get_updates(self) -> list[StepType]: ... + + +# class _LinearExponentialIntegratorMixin: +# """ +# Linear Exponential integrator mixin wrapper. +# """ +# +# def __init__(self): +# pass +# +# def _do_stage(self, System, Memory, time, dt): +# # TODO : Make more general, system should not be calculating what the state +# # transition matrix directly is, but rather it should just give +# Memory.linear_operator = System.get_linear_state_transition_operator(time, dt) +# +# def _do_update(self, System, Memory, time, dt): +# # FIXME What's the right formula when doing update? +# # System.linearly_evolving_state = _batch_matmul( +# # System.linearly_evolving_state, +# # Memory.linear_operator +# # ) +# System.linearly_evolving_state = np.einsum( +# "ijk,ljk->ilk", System.linearly_evolving_state, Memory.linear_operator +# ) +# return time + dt +# +# def _first_prefactor(self, dt): +# """Prefactor call to satisfy interface of SymplecticStepper. Should never +# be used in actual code. +# +# Parameters +# ---------- +# dt : the time step of simulation +# +# Raises +# ------ +# RuntimeError +# """ +# raise RuntimeError( +# "Symplectic prefactor of LinearExponentialIntegrator should not be called!" +# ) +# +# # Code repeat! +# # Easy to avoid, but keep for performance. +# def _do_one_step(self, System, time, prefac): +# System.linearly_evolving_state = np.einsum( +# "ijk,ljk->ilk", +# System.linearly_evolving_state, +# System.get_linear_state_transition_operator(time, prefac), +# ) +# return ( +# time # TODO fix hack that treats time separately here. Shuold be time + dt +# ) +# # return time + dt diff --git a/elastica/systems/protocol.py b/elastica/systems/protocol.py index 254cdcaf..89d52d92 100644 --- a/elastica/systems/protocol.py +++ b/elastica/systems/protocol.py @@ -68,15 +68,3 @@ def kinematic_rates( def dynamic_rates( self, time: np.float64, prefac: np.float64 ) -> NDArray[np.float64]: ... - - -class ExplicitSystemProtocol(SystemProtocol, SlenderBodyGeometryProtocol, Protocol): - # TODO: Temporarily made to handle explicit stepper. - # Need to be refactored as the explicit stepper is further developed. - def __call__(self, time: np.float64, dt: np.float64) -> np.float64: ... - @property - def state(self) -> StateType: ... - @state.setter - def state(self, state: StateType) -> None: ... - @property - def n_elems(self) -> int: ... diff --git a/elastica/timestepper/protocol.py b/elastica/timestepper/protocol.py index 1a64c725..18a92fc4 100644 --- a/elastica/timestepper/protocol.py +++ b/elastica/timestepper/protocol.py @@ -3,11 +3,11 @@ from typing import Protocol from elastica.typing import ( - SystemType, SteppersOperatorsType, - OperatorType, + StepType, SystemCollectionType, ) +from elastica.systems.protocol import SymplecticSystemProtocol import numpy as np @@ -29,80 +29,13 @@ def step( ) -> np.float64: ... def step_single_instance( - self, SystemCollection: SystemType, time: np.float64, dt: np.float64 + self, System: SymplecticSystemProtocol, time: np.float64, dt: np.float64 ) -> np.float64: ... class SymplecticStepperProtocol(StepperProtocol, Protocol): """symplectic stepper protocol.""" - def get_steps(self) -> list[OperatorType]: ... + def get_steps(self) -> list[StepType]: ... - def get_prefactors(self) -> list[OperatorType]: ... - - -class MemoryProtocol(Protocol): - @property - def initial_state(self) -> bool: ... - - -class ExplicitStepperProtocol(StepperProtocol, Protocol): - """symplectic stepper protocol.""" - - def get_stages(self) -> list[OperatorType]: ... - - def get_updates(self) -> list[OperatorType]: ... - - -# class _LinearExponentialIntegratorMixin: -# """ -# Linear Exponential integrator mixin wrapper. -# """ -# -# def __init__(self): -# pass -# -# def _do_stage(self, System, Memory, time, dt): -# # TODO : Make more general, system should not be calculating what the state -# # transition matrix directly is, but rather it should just give -# Memory.linear_operator = System.get_linear_state_transition_operator(time, dt) -# -# def _do_update(self, System, Memory, time, dt): -# # FIXME What's the right formula when doing update? -# # System.linearly_evolving_state = _batch_matmul( -# # System.linearly_evolving_state, -# # Memory.linear_operator -# # ) -# System.linearly_evolving_state = np.einsum( -# "ijk,ljk->ilk", System.linearly_evolving_state, Memory.linear_operator -# ) -# return time + dt -# -# def _first_prefactor(self, dt): -# """Prefactor call to satisfy interface of SymplecticStepper. Should never -# be used in actual code. -# -# Parameters -# ---------- -# dt : the time step of simulation -# -# Raises -# ------ -# RuntimeError -# """ -# raise RuntimeError( -# "Symplectic prefactor of LinearExponentialIntegrator should not be called!" -# ) -# -# # Code repeat! -# # Easy to avoid, but keep for performance. -# def _do_one_step(self, System, time, prefac): -# System.linearly_evolving_state = np.einsum( -# "ijk,ljk->ilk", -# System.linearly_evolving_state, -# System.get_linear_state_transition_operator(time, prefac), -# ) -# return ( -# time # TODO fix hack that treats time separately here. Shuold be time + dt -# ) -# # return time + dt + def get_prefactors(self) -> list[StepType]: ... diff --git a/elastica/typing.py b/elastica/typing.py index 12d61e92..f77cd015 100644 --- a/elastica/typing.py +++ b/elastica/typing.py @@ -23,12 +23,10 @@ SystemProtocol, StaticSystemProtocol, SymplecticSystemProtocol, - ExplicitSystemProtocol, ) from .timestepper.protocol import ( StepperProtocol, SymplecticStepperProtocol, - MemoryProtocol, ) from .memory_block.protocol import BlockSystemProtocol