Skip to content

Commit

Permalink
move explicit_stepper system, memory, and protocols into experimental
Browse files Browse the repository at this point in the history
  • Loading branch information
skim0119 committed Jun 30, 2024
1 parent cddfa60 commit 4bca3a0
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 90 deletions.
7 changes: 5 additions & 2 deletions elastica/experimental/timestepper/explicit_steppers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
SteppersOperatorsType,
StateType,
)
from elastica.systems.protocol import ExplicitSystemProtocol
from elastica.timestepper.protocol import ExplicitStepperProtocol, MemoryProtocol
from elastica.experimental.timestepper.protocol import (
ExplicitSystemProtocol,
ExplicitStepperProtocol,
MemoryProtocol,
)


"""
Expand Down
3 changes: 1 addition & 2 deletions elastica/experimental/timestepper/memory.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import Iterator, TypeVar, Generic, Type
from elastica.timestepper.protocol import ExplicitStepperProtocol
from elastica.typing import SystemCollectionType
from elastica.experimental.timestepper.explicit_steppers import (
RungeKutta4,
EulerForward,
)

from elastica.experimental.timestepper.protocol import ExplicitStepperProtocol

from copy import copy

Expand Down
86 changes: 86 additions & 0 deletions elastica/experimental/timestepper/protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from typing import Protocol

from elastica.typing import StepType, StateType
from elastica.systems.protocol import SystemProtocol, SlenderBodyGeometryProtocol
from elastica.timestepper.protocol import StepperProtocol

import numpy as np


class ExplicitSystemProtocol(SystemProtocol, SlenderBodyGeometryProtocol, Protocol):
# TODO: Temporarily made to handle explicit stepper.
# Need to be refactored as the explicit stepper is further developed.
def __call__(self, time: np.float64, dt: np.float64) -> np.float64: ...
@property
def state(self) -> StateType: ...
@state.setter
def state(self, state: StateType) -> None: ...
@property
def n_elems(self) -> int: ...


class MemoryProtocol(Protocol):
@property
def initial_state(self) -> bool: ...


class ExplicitStepperProtocol(StepperProtocol, Protocol):
"""symplectic stepper protocol."""

def get_stages(self) -> list[StepType]: ...

def get_updates(self) -> list[StepType]: ...


# class _LinearExponentialIntegratorMixin:
# """
# Linear Exponential integrator mixin wrapper.
# """
#
# def __init__(self):
# pass
#
# def _do_stage(self, System, Memory, time, dt):
# # TODO : Make more general, system should not be calculating what the state
# # transition matrix directly is, but rather it should just give
# Memory.linear_operator = System.get_linear_state_transition_operator(time, dt)
#
# def _do_update(self, System, Memory, time, dt):
# # FIXME What's the right formula when doing update?
# # System.linearly_evolving_state = _batch_matmul(
# # System.linearly_evolving_state,
# # Memory.linear_operator
# # )
# System.linearly_evolving_state = np.einsum(
# "ijk,ljk->ilk", System.linearly_evolving_state, Memory.linear_operator
# )
# return time + dt
#
# def _first_prefactor(self, dt):
# """Prefactor call to satisfy interface of SymplecticStepper. Should never
# be used in actual code.
#
# Parameters
# ----------
# dt : the time step of simulation
#
# Raises
# ------
# RuntimeError
# """
# raise RuntimeError(
# "Symplectic prefactor of LinearExponentialIntegrator should not be called!"
# )
#
# # Code repeat!
# # Easy to avoid, but keep for performance.
# def _do_one_step(self, System, time, prefac):
# System.linearly_evolving_state = np.einsum(
# "ijk,ljk->ilk",
# System.linearly_evolving_state,
# System.get_linear_state_transition_operator(time, prefac),
# )
# return (
# time # TODO fix hack that treats time separately here. Shuold be time + dt
# )
# # return time + dt
12 changes: 0 additions & 12 deletions elastica/systems/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,3 @@ def kinematic_rates(
def dynamic_rates(
self, time: np.float64, prefac: np.float64
) -> NDArray[np.float64]: ...


class ExplicitSystemProtocol(SystemProtocol, SlenderBodyGeometryProtocol, Protocol):
# TODO: Temporarily made to handle explicit stepper.
# Need to be refactored as the explicit stepper is further developed.
def __call__(self, time: np.float64, dt: np.float64) -> np.float64: ...
@property
def state(self) -> StateType: ...
@state.setter
def state(self, state: StateType) -> None: ...
@property
def n_elems(self) -> int: ...
77 changes: 5 additions & 72 deletions elastica/timestepper/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from typing import Protocol

from elastica.typing import (
SystemType,
SteppersOperatorsType,
OperatorType,
StepType,
SystemCollectionType,
)
from elastica.systems.protocol import SymplecticSystemProtocol

import numpy as np

Expand All @@ -29,80 +29,13 @@ def step(
) -> np.float64: ...

def step_single_instance(
self, SystemCollection: SystemType, time: np.float64, dt: np.float64
self, System: SymplecticSystemProtocol, time: np.float64, dt: np.float64
) -> np.float64: ...


class SymplecticStepperProtocol(StepperProtocol, Protocol):
"""symplectic stepper protocol."""

def get_steps(self) -> list[OperatorType]: ...
def get_steps(self) -> list[StepType]: ...

def get_prefactors(self) -> list[OperatorType]: ...


class MemoryProtocol(Protocol):
@property
def initial_state(self) -> bool: ...


class ExplicitStepperProtocol(StepperProtocol, Protocol):
"""symplectic stepper protocol."""

def get_stages(self) -> list[OperatorType]: ...

def get_updates(self) -> list[OperatorType]: ...


# class _LinearExponentialIntegratorMixin:
# """
# Linear Exponential integrator mixin wrapper.
# """
#
# def __init__(self):
# pass
#
# def _do_stage(self, System, Memory, time, dt):
# # TODO : Make more general, system should not be calculating what the state
# # transition matrix directly is, but rather it should just give
# Memory.linear_operator = System.get_linear_state_transition_operator(time, dt)
#
# def _do_update(self, System, Memory, time, dt):
# # FIXME What's the right formula when doing update?
# # System.linearly_evolving_state = _batch_matmul(
# # System.linearly_evolving_state,
# # Memory.linear_operator
# # )
# System.linearly_evolving_state = np.einsum(
# "ijk,ljk->ilk", System.linearly_evolving_state, Memory.linear_operator
# )
# return time + dt
#
# def _first_prefactor(self, dt):
# """Prefactor call to satisfy interface of SymplecticStepper. Should never
# be used in actual code.
#
# Parameters
# ----------
# dt : the time step of simulation
#
# Raises
# ------
# RuntimeError
# """
# raise RuntimeError(
# "Symplectic prefactor of LinearExponentialIntegrator should not be called!"
# )
#
# # Code repeat!
# # Easy to avoid, but keep for performance.
# def _do_one_step(self, System, time, prefac):
# System.linearly_evolving_state = np.einsum(
# "ijk,ljk->ilk",
# System.linearly_evolving_state,
# System.get_linear_state_transition_operator(time, prefac),
# )
# return (
# time # TODO fix hack that treats time separately here. Shuold be time + dt
# )
# # return time + dt
def get_prefactors(self) -> list[StepType]: ...
2 changes: 0 additions & 2 deletions elastica/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,10 @@
SystemProtocol,
StaticSystemProtocol,
SymplecticSystemProtocol,
ExplicitSystemProtocol,
)
from .timestepper.protocol import (
StepperProtocol,
SymplecticStepperProtocol,
MemoryProtocol,
)
from .memory_block.protocol import BlockSystemProtocol

Expand Down

0 comments on commit 4bca3a0

Please sign in to comment.