Skip to content

Commit

Permalink
type: type-check operator group implementation against its protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
skim0119 committed Jun 30, 2024
1 parent 4bca3a0 commit abd5ba5
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 35 deletions.
11 changes: 8 additions & 3 deletions elastica/modules/operator_group.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import TypeVar, Generic, Iterator, Callable

from collections.abc import Iterable
from typing import TYPE_CHECKING, TypeVar, Generic, Callable, Any
from collections.abc import Iterable, Iterator

import itertools

Expand Down Expand Up @@ -80,3 +79,9 @@ def add_operators(self, feature: F, operators: list[T]) -> None:
def is_last(self, feature: F) -> bool:
"""Checks if the feature is the last feature in the FIFO."""
return id(feature) == self._operator_ids[-1]


if TYPE_CHECKING:
from elastica.typing import OperatorType

_: Iterable[OperatorType] = OperatorGroupFIFO[OperatorType, Any]()
38 changes: 7 additions & 31 deletions elastica/timestepper/__init__.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,21 @@
__doc__ = """Timestepping utilities to be used with Rod and RigidBody classes"""

from typing import Tuple, List, Callable, Type, Any, overload, cast
from elastica.typing import SystemType, SystemCollectionType, SteppersOperatorsType
from typing import Callable
from elastica.typing import SystemCollectionType, SteppersOperatorsType

import numpy as np
from tqdm import tqdm

from elastica.systems import is_system_a_collection

from .symplectic_steppers import PositionVerlet, PEFRL
from .protocol import StepperProtocol, SymplecticStepperProtocol
from .protocol import StepperProtocol


# Deprecated: Remove in the future version
# Many script still uses this method to control timestep. Keep it for backward compatibility
def extend_stepper_interface(
stepper: StepperProtocol, system_collection: SystemCollectionType
) -> Tuple[
) -> tuple[
Callable[
[StepperProtocol, SystemCollectionType, np.float64, np.float64], np.float64
],
Expand All @@ -30,32 +29,10 @@ def extend_stepper_interface(
return do_step_method, stepper_methods


@overload
def integrate(
stepper: StepperProtocol,
systems: SystemType,
final_time: float,
n_steps: int,
restart_time: float,
progress_bar: bool,
) -> float: ...


@overload
def integrate(
stepper: StepperProtocol,
systems: SystemCollectionType,
final_time: float,
n_steps: int,
restart_time: float,
progress_bar: bool,
) -> float: ...


def integrate(
stepper: StepperProtocol,
systems: "SystemType | SystemCollectionType",
final_time: float,
n_steps: int = 1000,
restart_time: float = 0.0,
progress_bar: bool = True,
Expand All @@ -66,7 +43,7 @@ def integrate(
----------
stepper : StepperProtocol
Stepper algorithm to use.
systems : SystemType | SystemCollectionType
systems : SystemCollectionType
The elastica-system to simulate.
final_time : float
Total simulation time. The timestep is determined by final_time / n_steps.
Expand All @@ -84,13 +61,12 @@ def integrate(
time = np.float64(restart_time)

if is_system_a_collection(systems):
systems = cast(SystemCollectionType, systems)
for i in tqdm(range(n_steps), disable=(not progress_bar)):
time = stepper.step(systems, time, dt)
else:
systems = cast(SystemType, systems)
# Typing is ignored since this part only exist for unit-testing
for i in tqdm(range(n_steps), disable=(not progress_bar)):
time = stepper.step_single_instance(systems, time, dt)
time = stepper.step_single_instance(systems, time, dt) # type: ignore[arg-type]

print("Final time of simulation is : ", time)
return float(time)
2 changes: 1 addition & 1 deletion elastica/timestepper/symplectic_steppers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

class SymplecticStepperMixin:
def __init__(self: SymplecticStepperProtocol):
self.steps_and_prefactors: Final[SteppersOperatorsType] = self.step_methods()
self.steps_and_prefactors: SteppersOperatorsType = self.step_methods()

def step_methods(self: SymplecticStepperProtocol) -> SteppersOperatorsType:
# Let the total number of steps for the Symplectic method
Expand Down
75 changes: 75 additions & 0 deletions tests/test_modules/test_feature_grouping.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from elastica.modules.operator_group import OperatorGroupFIFO
import functools


def test_add_ids():
Expand Down Expand Up @@ -65,3 +66,77 @@ def test_is_last():

assert group.is_last(1) == False
assert group.is_last(2) == True


class TestOperatorGroupingWithCallableModules:
class OperatorTypeA:
def __init__(self):
self.value = 0

def apply(self) -> None:
self.value += 1

class OperatorTypeB:
def __init__(self):
self.value2 = 0

def apply(self) -> None:
self.value2 -= 1

# def test_lambda(self):
# feature_group = OperatorGroupFIFO()

# op_a = self.OperatorTypeA()
# feature_group.append_id(op_a)
# op_b = self.OperatorTypeB()
# feature_group.append_id(op_b)

# for op in [op_a, op_b]:
# func = functools.partial(lambda t: op.apply())
# feature_group.add_operators(op, [func])

# for operator in feature_group:
# operator(t=0)

# assert op_a.value == 1
# assert op_b.value2 == -1

# def test_def(self):
# feature_group = OperatorGroupFIFO()

# op_a = self.OperatorTypeA()
# feature_group.append_id(op_a)
# op_b = self.OperatorTypeB()
# feature_group.append_id(op_b)

# for op in [op_a, op_b]:
# def func(t):
# op.apply()
# feature_group.add_operators(op, [func])

# for operator in feature_group:
# operator(t=0)

# assert op_a.value == 1
# assert op_b.value2 == -1

def test_partial(self):
feature_group = OperatorGroupFIFO()

op_a = self.OperatorTypeA()
feature_group.append_id(op_a)
op_b = self.OperatorTypeB()
feature_group.append_id(op_b)

def _func(t, op):
op.apply()

for op in [op_a, op_b]:
func = functools.partial(_func, op=op)
feature_group.add_operators(op, [func])

for operator in feature_group:
operator(t=0)

assert op_a.value == 1
assert op_b.value2 == -1

0 comments on commit abd5ba5

Please sign in to comment.