Skip to content

Commit

Permalink
is_active in aperture init
Browse files Browse the repository at this point in the history
  • Loading branch information
jank324 committed Sep 3, 2023
1 parent ce82457 commit aae08e4
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
9 changes: 8 additions & 1 deletion cheetah/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class Element:
Base class for elements of particle accelerators.
:param name: Unique identifier of the element.
:param is_active: Whether the element is active or not. This can have different
meanings depending on the element type.
:param device: Device to move the beam's particle array to. If set to `"auto"` a
CUDA GPU is selected if available. The CPU is used otherwise.
"""
Expand All @@ -50,7 +52,9 @@ class Element:
device: str = "auto"
length: float = 0

def __init__(self, name: Optional[str] = None, device: str = "auto") -> None:
def __init__(
self, name: Optional[str] = None, is_active: bool = False, device: str = "auto"
) -> None:
global ELEMENT_COUNT
if name is not None:
self.name = name
Expand Down Expand Up @@ -931,6 +935,7 @@ class Aperture(Element):
:param y_max: half size vertical offset in [m]
:param shape: Shape of the aperture. Can be "rectangular" or "elliptical".
:param name: Unique identifier of the element.
:param is_active: If the aperture actually blocks particles.
"""

x_max: float = np.inf
Expand All @@ -943,13 +948,15 @@ def __init__(
y_max: float = np.inf,
shape: Literal["rectangular", "elliptical"] = "rectangular",
name: Optional[str] = None,
is_active: bool = True,
**kwargs,
) -> None:
super().__init__(name, **kwargs)

self.x_max = x_max
self.y_max = y_max
self.shape = shape
self.is_active = is_active

self.lost_particles = None

Expand Down
18 changes: 12 additions & 6 deletions test/test_compare_ocelot.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,15 @@ def test_aperture():
cheetah_segment = cheetah.Segment(
[
cheetah.Aperture(
x_max=2e-4, y_max=2e-4, shape="rectangular", name="aperture"
), # TODO: is_active on init
x_max=2e-4,
y_max=2e-4,
shape="rectangular",
name="aperture",
is_active=True,
),
cheetah.Drift(length=0.1),
]
)
cheetah_segment.aperture.is_active = True
outgoing_beam = cheetah_segment.track(incoming_beam)

# Ocelot
Expand Down Expand Up @@ -103,12 +106,15 @@ def test_aperture_elliptical():
cheetah_segment = cheetah.Segment(
[
cheetah.Aperture(
x_max=2e-4, y_max=2e-4, shape="elliptical", name="aperture"
), # TODO: is_active on init
x_max=2e-4,
y_max=2e-4,
shape="elliptical",
name="aperture",
is_active=True,
),
cheetah.Drift(length=0.1),
]
)
cheetah_segment.aperture.is_active = True
outgoing_beam = cheetah_segment.track(incoming_beam)

# Ocelot
Expand Down

0 comments on commit aae08e4

Please sign in to comment.