Skip to content

Commit

Permalink
Replace Element __init__s with pydantic adapted versions
Browse files Browse the repository at this point in the history
  • Loading branch information
jank324 committed Sep 12, 2023
1 parent a19a1ad commit 198e9bb
Showing 1 changed file with 22 additions and 193 deletions.
215 changes: 22 additions & 193 deletions cheetah/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,11 @@ class Element(ABC, pydantic.BaseModel):
name: str = "unnamed"
device: torch.device

def __init__(self, name: str = "unnamed", device: str = "auto") -> None:
self.name = name
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

if device == "auto":
device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = device
if self.device == "auto":
self.device = "cuda" if torch.cuda.is_available() else "cpu"

def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -174,13 +173,6 @@ class Drift(Element):

length: torch.Tensor

def __init__(
self, length: torch.Tensor, name: Optional[str] = None, device: str = "auto"
) -> None:
super().__init__(name=name, device=device)

self.length = length

def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
gamma = energy / REST_ENERGY
igamma2 = 1 / gamma**2 if gamma != 0 else torch.tensor(0.0)
Expand Down Expand Up @@ -245,22 +237,6 @@ class Quadrupole(Element):
misalignment: torch.Tensor = torch.tensor([0.0, 0.0])
tilt: torch.Tensor = torch.tensor(0.0)

def __init__(
self,
length: torch.Tensor,
k1: torch.Tensor = torch.tensor(0.0),
misalignment: torch.Tensor = torch.tensor([0.0, 0.0]),
tilt: torch.Tensor = torch.tensor(0.0),
name: Optional[str] = None,
device: str = "auto",
) -> None:
super().__init__(name=name, device=device)

self.length = length
self.k1 = k1
self.misalignment = misalignment
self.tilt = tilt

def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
R = base_rmatrix(
length=self.length,
Expand Down Expand Up @@ -350,38 +326,13 @@ class Dipole(Element):
fringe_integral_exit: Optional[torch.Tensor] = None
gap: torch.Tensor = torch.tensor(0.0)

def __init__(
self,
length: torch.Tensor,
angle: torch.Tensor = torch.tensor(0.0),
e1: torch.Tensor = torch.tensor(0.0),
e2: torch.Tensor = torch.tensor(0.0),
tilt: torch.Tensor = torch.tensor(0.0),
fringe_integral: torch.Tensor = torch.tensor(0.0),
fringe_integral_exit: Optional[torch.Tensor] = None,
gap: torch.Tensor = torch.tensor(0.0),
name: Optional[str] = None,
device: str = "auto",
):
super().__init__(name=name, device=device)

self.length = length
self.angle = angle
self.gap = gap
self.tilt = tilt
self.name = name
self.fringe_integral = fringe_integral
self.fringe_integral_exit = (
fringe_integral if fringe_integral_exit is None else fringe_integral_exit
)
# Rectangular bend
self.e1 = e1
self.e2 = e2
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

if self.length == 0.0:
self.hx = torch.tensor(0.0)
else:
self.hx = self.angle / self.length
if self.fringe_integral_exit is None:
self.fringe_integral_exit = self.fringe_integral

self.hx = self.angle / self.length if self.length != 0 else torch.tensor(0.0)

@property
def is_skippable(self) -> bool:
Expand Down Expand Up @@ -540,33 +491,16 @@ class RBend(Dipole):
"""

def __init__(
self,
length: torch.Tensor,
*args,
angle: torch.Tensor = torch.tensor(0.0),
e1: torch.Tensor = torch.tensor(0.0),
e2: torch.Tensor = torch.tensor(0.0),
tilt: torch.Tensor = torch.tensor(0.0),
fringe_integral: torch.Tensor = torch.tensor(0.0),
fringe_integral_exit: Optional[torch.Tensor] = None,
gap: torch.Tensor = torch.tensor(0.0),
name: Optional[str] = None,
device: str = "auto",
):
**kwargs,
) -> None:
e1 = e1 + angle / 2
e2 = e2 + angle / 2

super().__init__(
length=length,
angle=angle,
e1=e1,
e2=e2,
tilt=tilt,
fringe_integral=fringe_integral,
fringe_integral_exit=fringe_integral_exit,
gap=gap,
name=name,
device=device,
)
super().__init__(*args, angle=angle, e1=e1, e2=e2, **kwargs)


class HorizontalCorrector(Element):
Expand All @@ -583,18 +517,6 @@ class HorizontalCorrector(Element):
length: torch.Tensor
angle: torch.Tensor = torch.tensor(0.0)

def __init__(
self,
length: torch.Tensor,
angle: torch.Tensor = torch.tensor(0.0),
name: Optional[str] = None,
device: str = "auto",
) -> None:
super().__init__(name=name, device=device)

self.length = length
self.angle = angle

def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
return torch.tensor(
[
Expand Down Expand Up @@ -666,18 +588,6 @@ class VerticalCorrector(Element):
length: torch.Tensor
angle: torch.Tensor = torch.tensor(0.0)

def __init__(
self,
length: torch.Tensor,
angle: torch.Tensor = torch.tensor(0.0),
name: Optional[str] = None,
device: str = "auto",
) -> None:
super().__init__(name=name, device=device)

self.length = length
self.angle = angle

def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
return torch.tensor(
[
Expand Down Expand Up @@ -753,22 +663,6 @@ class Cavity(Element):
phase: torch.Tensor = torch.tensor(0.0)
frequency: torch.Tensor = torch.tensor(0.0)

def __init__(
self,
length: torch.Tensor,
voltage: torch.Tensor = torch.tensor(0.0),
phase: torch.Tensor = torch.tensor(0.0),
frequency: torch.Tensor = torch.tensor(0.0),
name: Optional[str] = None,
device: str = "auto",
) -> None:
super().__init__(name=name, device=device)

self.length = length
self.voltage = voltage
self.phase = phase
self.frequency = frequency

@property
def is_active(self) -> bool:
return self.voltage != 0
Expand Down Expand Up @@ -1057,8 +951,8 @@ class BPM(Element):
CUDA GPU is selected if available. The CPU is used otherwise.
"""

def __init__(self, name: Optional[str] = None, device: str = "auto") -> None:
super().__init__(name=name, device=device)
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

self.reading = (None, None)

Expand Down Expand Up @@ -1109,9 +1003,6 @@ class Marker(Element):
CUDA GPU is selected if available. The CPU is used otherwise.
"""

def __init__(self, name: Optional[str] = None, device: str = "auto") -> None:
super().__init__(name=name, device=device)

def transfer_map(self, energy):
return torch.eye(7, device=self.device)

Expand Down Expand Up @@ -1162,23 +1053,8 @@ class Screen(Element):
misalignment: torch.Tensor = torch.tensor((0.0, 0.0))
is_active: bool = False

def __init__(
self,
resolution: torch.Tensor = torch.tensor((1024, 1024)),
pixel_size: torch.Tensor = torch.tensor((1e-3, 1e-3)),
binning: torch.Tensor = torch.tensor(1),
misalignment: torch.Tensor = torch.tensor((0.0, 0.0)),
is_active: bool = False,
name: Optional[str] = None,
device: str = "auto",
) -> None:
super().__init__(name=name, device=device)

self.resolution = tuple(resolution)
self.pixel_size = tuple(pixel_size)
self.binning = binning
self.misalignment = tuple(misalignment)
self.is_active = is_active
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

self.read_beam = None
self.cached_reading = None
Expand Down Expand Up @@ -1347,21 +1223,8 @@ class Aperture(Element):
shape: Literal["rectangular", "elliptical"] = "rectangular"
is_active: bool = True

def __init__(
self,
x_max: torch.Tensor = torch.inf,
y_max: torch.Tensor = torch.inf,
shape: Literal["rectangular", "elliptical"] = "rectangular",
is_active: bool = True,
name: Optional[str] = None,
device: str = "auto",
) -> None:
super().__init__(name=name, device=device)

self.x_max = x_max
self.y_max = y_max
self.shape = shape
self.is_active = is_active
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

self.lost_particles = None

Expand Down Expand Up @@ -1454,18 +1317,6 @@ class Undulator(Element):
length: torch.Tensor
is_active: bool = False

def __init__(
self,
length: torch.Tensor,
is_active: bool = False,
name: Optional[str] = None,
device: str = "auto",
) -> None:
super().__init__(name=name, device=device)

self.length = length
self.is_active = is_active

def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
gamma = energy / REST_ENERGY
igamma2 = 1 / gamma**2 if gamma != 0 else torch.tensor(0.0)
Expand Down Expand Up @@ -1534,20 +1385,6 @@ class Solenoid(Element):
k: torch.Tensor = torch.tensor(0.0)
misalignment: torch.Tensor = torch.tensor((0.0, 0.0))

def __init__(
self,
length: torch.Tensor = torch.tensor(0.0),
k: torch.Tensor = torch.tensor(0.0),
misalignment: torch.Tensor = torch.tensor((0.0, 0.0)),
name: Optional[str] = None,
device: str = "auto",
) -> None:
super().__init__(name=name, device=device)

self.length = length
self.k = k
self.misalignment = tuple(misalignment)

def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
gamma = energy / REST_ENERGY
c = torch.cos(self.length * self.k)
Expand Down Expand Up @@ -1627,16 +1464,8 @@ class Segment(Element):

elements: list[Element]

def __init__(
self, cell: list[Element], name: str = "unnamed", device: str = "auto"
) -> None:
self.name = name

if device == "auto":
device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = device

self.elements = cell
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

for element in self.elements:
element.device = self.device
Expand Down

0 comments on commit 198e9bb

Please sign in to comment.