Skip to content

Commit

Permalink
Refactor aperture code
Browse files Browse the repository at this point in the history
  • Loading branch information
jank324 committed Sep 3, 2023
1 parent c1553ec commit 032e6f3
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 46 deletions.
76 changes: 39 additions & 37 deletions cheetah/accelerator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from copy import deepcopy
from dataclasses import dataclass
from typing import Optional
from typing import Literal, Optional

import matplotlib
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -914,33 +914,32 @@ class Aperture(Element):
"""
Physical aperture.
:param xmax: half size horizontal offset in [m]
:param ymax: half size vertical offset in [m]
:param type: Aperture shape, "rect" for rectangular and "ellip" for elliptical.
:param x_max: half size horizontal offset in [m]
:param y_max: half size vertical offset in [m]
:param shape: Shape of the aperture. Can be "rectangular" or "elliptical".
:param name: Unique identifier of the element.
"""

xmax: float = np.inf
ymax: float = np.inf
type: str = "rect"
x_max: float = np.inf
y_max: float = np.inf
shape: str = "rect"

def __init__(
self,
xmax: float = np.inf,
ymax: float = np.inf,
type: str = "rect", # TODO: Better strings ellipciatl and rectangular
x_max: float = np.inf,
y_max: float = np.inf,
shape: Literal["rectangular", "elliptical"] = "rectangular",
name: Optional[str] = None,
**kwargs,
) -> None:
assert xmax >= 0 and ymax >= 0
self.xmax = xmax
self.ymax = ymax
if type != "rect" and type != "ellipt":
raise ValueError('Unknown aperture type, use "rect" or "ellipt"')
self.type = type
self.lost_particle = None
super().__init__(name, **kwargs)

self.x_max = x_max
self.y_max = y_max
self.shape = shape

self.lost_particles = None

@property
def is_skippable(self) -> bool: # TODO: Aperatures should always be active
return not self.is_active
Expand All @@ -949,28 +948,31 @@ def transfer_map(self, energy: float) -> torch.Tensor:
return torch.eye(7, device=self.device)

def __call__(self, incoming: Beam) -> Beam:
if self.is_active and isinstance(incoming, ParticleBeam):
x = incoming.particles[:, 0]
y = incoming.particles[:, 2]
if self.type == "rect":
survived_mask = torch.logical_and(
torch.logical_and(x > -self.xmax, x < self.xmax),
torch.logical_and(y > -self.ymax, y < self.ymax),
)
elif self.type == "ellipt":
survived_mask = (
x**2 / self.xmax**2 + y**2 / self.ymax**2
) <= 1.0
outgoing_particles = incoming.particles[survived_mask]

self.lost_particles = incoming.particles[torch.logical_not(survived_mask)]

return ParticleBeam(
outgoing_particles, incoming.energy, device=incoming.device
)
else:
# Only apply aperture to particle beams and if the element is active
if not (isinstance(incoming, ParticleBeam) and self.is_active):
return incoming

assert self.x_max >= 0 and self.y_max >= 0
assert self.shape in [
"rectangular",
"elliptical",
], f"Unknown aperture shape {self.shape}"

if self.shape == "rectangular":
survived_mask = torch.logical_and(
torch.logical_and(incoming.xs > -self.x_max, incoming.xs < self.x_max),
torch.logical_and(incoming.ys > -self.y_max, incoming.ys < self.y_max),
)
elif self.shape == "elliptical":
survived_mask = (
incoming.xs**2 / self.x_max**2 + incoming.ys**2 / self.y_max**2
) <= 1.0
outgoing_particles = incoming.particles[survived_mask]

self.lost_particles = incoming.particles[torch.logical_not(survived_mask)]

return ParticleBeam(outgoing_particles, incoming.energy, device=incoming.device)


@dataclass
class Undulator(Element):
Expand Down
2 changes: 1 addition & 1 deletion cheetah/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def parse_cheetah_element(element: acc.Element):
}
elif isinstance(element, acc.Aperture):
element_class = "Aperture"
params = {"xmax": element.xmax, "ymax": element.ymax, "type": element.type}
params = {"x_max": element.x_max, "y_max": element.y_max, "type": element.shape}
elif isinstance(element, acc.Solenoid):
element_class = "Solenoid"
params = {
Expand Down
10 changes: 2 additions & 8 deletions test/test_compare_ocelot.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,7 @@ def test_aperture():
cheetah_segment = cheetah.Segment(
[
cheetah.Aperture(
xmax=2e-4,
ymax=2e-4,
type="rect",
name="aperture", # TODO: Don't use type keyword
x_max=2e-4, y_max=2e-4, shape="rectangular", name="aperture"
), # TODO: is_active on init
cheetah.Drift(length=0.1),
]
Expand Down Expand Up @@ -119,10 +116,7 @@ def test_aperture_elliptical():
cheetah_segment = cheetah.Segment(
[
cheetah.Aperture(
xmax=2e-4,
ymax=2e-4,
type="ellipt",
name="aperture", # TODO: Don't use type keyword
x_max=2e-4, y_max=2e-4, shape="elliptical", name="aperture"
), # TODO: is_active on init
cheetah.Drift(length=0.1),
]
Expand Down

0 comments on commit 032e6f3

Please sign in to comment.