diff --git a/cheetah/accelerator.py b/cheetah/accelerator.py index 6e62c887..582063bb 100644 --- a/cheetah/accelerator.py +++ b/cheetah/accelerator.py @@ -40,6 +40,8 @@ class Element: Base class for elements of particle accelerators. :param name: Unique identifier of the element. + :param is_active: Whether the element is active or not. This can have different + meanings depending on the element type. :param device: Device to move the beam's particle array to. If set to `"auto"` a CUDA GPU is selected if available. The CPU is used otherwise. """ @@ -50,7 +52,9 @@ class Element: device: str = "auto" length: float = 0 - def __init__(self, name: Optional[str] = None, device: str = "auto") -> None: + def __init__( + self, name: Optional[str] = None, is_active: bool = False, device: str = "auto" + ) -> None: global ELEMENT_COUNT if name is not None: self.name = name @@ -931,6 +935,7 @@ class Aperture(Element): :param y_max: half size vertical offset in [m] :param shape: Shape of the aperture. Can be "rectangular" or "elliptical". :param name: Unique identifier of the element. + :param is_active: If the aperture actually blocks particles. """ x_max: float = np.inf @@ -943,6 +948,7 @@ def __init__( y_max: float = np.inf, shape: Literal["rectangular", "elliptical"] = "rectangular", name: Optional[str] = None, + is_active: bool = True, **kwargs, ) -> None: super().__init__(name, **kwargs) @@ -950,6 +956,7 @@ def __init__( self.x_max = x_max self.y_max = y_max self.shape = shape + self.is_active = is_active self.lost_particles = None diff --git a/test/test_compare_ocelot.py b/test/test_compare_ocelot.py index f105e414..c9074d72 100644 --- a/test/test_compare_ocelot.py +++ b/test/test_compare_ocelot.py @@ -70,12 +70,15 @@ def test_aperture(): cheetah_segment = cheetah.Segment( [ cheetah.Aperture( - x_max=2e-4, y_max=2e-4, shape="rectangular", name="aperture" - ), # TODO: is_active on init + x_max=2e-4, + y_max=2e-4, + shape="rectangular", + name="aperture", + is_active=True, + ), cheetah.Drift(length=0.1), ] ) - cheetah_segment.aperture.is_active = True outgoing_beam = cheetah_segment.track(incoming_beam) # Ocelot @@ -103,12 +106,15 @@ def test_aperture_elliptical(): cheetah_segment = cheetah.Segment( [ cheetah.Aperture( - x_max=2e-4, y_max=2e-4, shape="elliptical", name="aperture" - ), # TODO: is_active on init + x_max=2e-4, + y_max=2e-4, + shape="elliptical", + name="aperture", + is_active=True, + ), cheetah.Drift(length=0.1), ] ) - cheetah_segment.aperture.is_active = True outgoing_beam = cheetah_segment.track(incoming_beam) # Ocelot