From 198e9bb96d772910c9ce12aee7fc223624c8e9c0 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Tue, 12 Sep 2023 19:40:23 +0200 Subject: [PATCH] Replace Element __init__s with pydantic adapted versions --- cheetah/accelerator.py | 215 +++++------------------------------------ 1 file changed, 22 insertions(+), 193 deletions(-) diff --git a/cheetah/accelerator.py b/cheetah/accelerator.py index a435011a..959f5a64 100644 --- a/cheetah/accelerator.py +++ b/cheetah/accelerator.py @@ -40,12 +40,11 @@ class Element(ABC, pydantic.BaseModel): name: str = "unnamed" device: torch.device - def __init__(self, name: str = "unnamed", device: str = "auto") -> None: - self.name = name + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) - if device == "auto": - device = "cuda" if torch.cuda.is_available() else "cpu" - self.device = device + if self.device == "auto": + self.device = "cuda" if torch.cuda.is_available() else "cpu" def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: """ @@ -174,13 +173,6 @@ class Drift(Element): length: torch.Tensor - def __init__( - self, length: torch.Tensor, name: Optional[str] = None, device: str = "auto" - ) -> None: - super().__init__(name=name, device=device) - - self.length = length - def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: gamma = energy / REST_ENERGY igamma2 = 1 / gamma**2 if gamma != 0 else torch.tensor(0.0) @@ -245,22 +237,6 @@ class Quadrupole(Element): misalignment: torch.Tensor = torch.tensor([0.0, 0.0]) tilt: torch.Tensor = torch.tensor(0.0) - def __init__( - self, - length: torch.Tensor, - k1: torch.Tensor = torch.tensor(0.0), - misalignment: torch.Tensor = torch.tensor([0.0, 0.0]), - tilt: torch.Tensor = torch.tensor(0.0), - name: Optional[str] = None, - device: str = "auto", - ) -> None: - super().__init__(name=name, device=device) - - self.length = length - self.k1 = k1 - self.misalignment = misalignment - self.tilt = tilt - def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: R = base_rmatrix( length=self.length, @@ -350,38 +326,13 @@ class Dipole(Element): fringe_integral_exit: Optional[torch.Tensor] = None gap: torch.Tensor = torch.tensor(0.0) - def __init__( - self, - length: torch.Tensor, - angle: torch.Tensor = torch.tensor(0.0), - e1: torch.Tensor = torch.tensor(0.0), - e2: torch.Tensor = torch.tensor(0.0), - tilt: torch.Tensor = torch.tensor(0.0), - fringe_integral: torch.Tensor = torch.tensor(0.0), - fringe_integral_exit: Optional[torch.Tensor] = None, - gap: torch.Tensor = torch.tensor(0.0), - name: Optional[str] = None, - device: str = "auto", - ): - super().__init__(name=name, device=device) - - self.length = length - self.angle = angle - self.gap = gap - self.tilt = tilt - self.name = name - self.fringe_integral = fringe_integral - self.fringe_integral_exit = ( - fringe_integral if fringe_integral_exit is None else fringe_integral_exit - ) - # Rectangular bend - self.e1 = e1 - self.e2 = e2 + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) - if self.length == 0.0: - self.hx = torch.tensor(0.0) - else: - self.hx = self.angle / self.length + if self.fringe_integral_exit is None: + self.fringe_integral_exit = self.fringe_integral + + self.hx = self.angle / self.length if self.length != 0 else torch.tensor(0.0) @property def is_skippable(self) -> bool: @@ -540,33 +491,16 @@ class RBend(Dipole): """ def __init__( - self, - length: torch.Tensor, + *args, angle: torch.Tensor = torch.tensor(0.0), e1: torch.Tensor = torch.tensor(0.0), e2: torch.Tensor = torch.tensor(0.0), - tilt: torch.Tensor = torch.tensor(0.0), - fringe_integral: torch.Tensor = torch.tensor(0.0), - fringe_integral_exit: Optional[torch.Tensor] = None, - gap: torch.Tensor = torch.tensor(0.0), - name: Optional[str] = None, - device: str = "auto", - ): + **kwargs, + ) -> None: e1 = e1 + angle / 2 e2 = e2 + angle / 2 - super().__init__( - length=length, - angle=angle, - e1=e1, - e2=e2, - tilt=tilt, - fringe_integral=fringe_integral, - fringe_integral_exit=fringe_integral_exit, - gap=gap, - name=name, - device=device, - ) + super().__init__(*args, angle=angle, e1=e1, e2=e2, **kwargs) class HorizontalCorrector(Element): @@ -583,18 +517,6 @@ class HorizontalCorrector(Element): length: torch.Tensor angle: torch.Tensor = torch.tensor(0.0) - def __init__( - self, - length: torch.Tensor, - angle: torch.Tensor = torch.tensor(0.0), - name: Optional[str] = None, - device: str = "auto", - ) -> None: - super().__init__(name=name, device=device) - - self.length = length - self.angle = angle - def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: return torch.tensor( [ @@ -666,18 +588,6 @@ class VerticalCorrector(Element): length: torch.Tensor angle: torch.Tensor = torch.tensor(0.0) - def __init__( - self, - length: torch.Tensor, - angle: torch.Tensor = torch.tensor(0.0), - name: Optional[str] = None, - device: str = "auto", - ) -> None: - super().__init__(name=name, device=device) - - self.length = length - self.angle = angle - def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: return torch.tensor( [ @@ -753,22 +663,6 @@ class Cavity(Element): phase: torch.Tensor = torch.tensor(0.0) frequency: torch.Tensor = torch.tensor(0.0) - def __init__( - self, - length: torch.Tensor, - voltage: torch.Tensor = torch.tensor(0.0), - phase: torch.Tensor = torch.tensor(0.0), - frequency: torch.Tensor = torch.tensor(0.0), - name: Optional[str] = None, - device: str = "auto", - ) -> None: - super().__init__(name=name, device=device) - - self.length = length - self.voltage = voltage - self.phase = phase - self.frequency = frequency - @property def is_active(self) -> bool: return self.voltage != 0 @@ -1057,8 +951,8 @@ class BPM(Element): CUDA GPU is selected if available. The CPU is used otherwise. """ - def __init__(self, name: Optional[str] = None, device: str = "auto") -> None: - super().__init__(name=name, device=device) + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) self.reading = (None, None) @@ -1109,9 +1003,6 @@ class Marker(Element): CUDA GPU is selected if available. The CPU is used otherwise. """ - def __init__(self, name: Optional[str] = None, device: str = "auto") -> None: - super().__init__(name=name, device=device) - def transfer_map(self, energy): return torch.eye(7, device=self.device) @@ -1162,23 +1053,8 @@ class Screen(Element): misalignment: torch.Tensor = torch.tensor((0.0, 0.0)) is_active: bool = False - def __init__( - self, - resolution: torch.Tensor = torch.tensor((1024, 1024)), - pixel_size: torch.Tensor = torch.tensor((1e-3, 1e-3)), - binning: torch.Tensor = torch.tensor(1), - misalignment: torch.Tensor = torch.tensor((0.0, 0.0)), - is_active: bool = False, - name: Optional[str] = None, - device: str = "auto", - ) -> None: - super().__init__(name=name, device=device) - - self.resolution = tuple(resolution) - self.pixel_size = tuple(pixel_size) - self.binning = binning - self.misalignment = tuple(misalignment) - self.is_active = is_active + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) self.read_beam = None self.cached_reading = None @@ -1347,21 +1223,8 @@ class Aperture(Element): shape: Literal["rectangular", "elliptical"] = "rectangular" is_active: bool = True - def __init__( - self, - x_max: torch.Tensor = torch.inf, - y_max: torch.Tensor = torch.inf, - shape: Literal["rectangular", "elliptical"] = "rectangular", - is_active: bool = True, - name: Optional[str] = None, - device: str = "auto", - ) -> None: - super().__init__(name=name, device=device) - - self.x_max = x_max - self.y_max = y_max - self.shape = shape - self.is_active = is_active + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) self.lost_particles = None @@ -1454,18 +1317,6 @@ class Undulator(Element): length: torch.Tensor is_active: bool = False - def __init__( - self, - length: torch.Tensor, - is_active: bool = False, - name: Optional[str] = None, - device: str = "auto", - ) -> None: - super().__init__(name=name, device=device) - - self.length = length - self.is_active = is_active - def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: gamma = energy / REST_ENERGY igamma2 = 1 / gamma**2 if gamma != 0 else torch.tensor(0.0) @@ -1534,20 +1385,6 @@ class Solenoid(Element): k: torch.Tensor = torch.tensor(0.0) misalignment: torch.Tensor = torch.tensor((0.0, 0.0)) - def __init__( - self, - length: torch.Tensor = torch.tensor(0.0), - k: torch.Tensor = torch.tensor(0.0), - misalignment: torch.Tensor = torch.tensor((0.0, 0.0)), - name: Optional[str] = None, - device: str = "auto", - ) -> None: - super().__init__(name=name, device=device) - - self.length = length - self.k = k - self.misalignment = tuple(misalignment) - def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: gamma = energy / REST_ENERGY c = torch.cos(self.length * self.k) @@ -1627,16 +1464,8 @@ class Segment(Element): elements: list[Element] - def __init__( - self, cell: list[Element], name: str = "unnamed", device: str = "auto" - ) -> None: - self.name = name - - if device == "auto": - device = "cuda" if torch.cuda.is_available() else "cpu" - self.device = device - - self.elements = cell + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) for element in self.elements: element.device = self.device