Skip to content

Commit

Permalink
Add properties via pydantic
Browse files Browse the repository at this point in the history
  • Loading branch information
jank324 committed Sep 12, 2023
1 parent a0588e6 commit 7204199
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 2 deletions.
53 changes: 52 additions & 1 deletion cheetah/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pydantic
import torch
from matplotlib.patches import Rectangle
from scipy import constants
Expand All @@ -28,7 +29,7 @@
)


class Element(ABC):
class Element(ABC, pydantic.BaseModel):
"""
Base class for elements of particle accelerators.
Expand All @@ -37,6 +38,9 @@ class Element(ABC):
CUDA GPU is selected if available. The CPU is used otherwise.
"""

name: str
device: torch.device

def __init__(self, name: Optional[str] = None, device: str = "auto") -> None:
global ELEMENT_COUNT
if name is not None:
Expand Down Expand Up @@ -174,6 +178,8 @@ class Drift(Element):
CUDA GPU is selected if available. The CPU is used otherwise.
"""

length: torch.Tensor

def __init__(
self, length: torch.Tensor, name: Optional[str] = None, device: str = "auto"
) -> None:
Expand Down Expand Up @@ -240,6 +246,11 @@ class Quadrupole(Element):
CUDA GPU is selected if available. The CPU is used otherwise.
"""

length: torch.Tensor
k1: torch.Tensor = torch.tensor(0.0)
misalignment: torch.Tensor = torch.tensor([0.0, 0.0])
tilt: torch.Tensor = torch.tensor(0.0)

def __init__(
self,
length: torch.Tensor,
Expand Down Expand Up @@ -336,6 +347,15 @@ class Dipole(Element):
CUDA GPU is selected if available. The CPU is used otherwise.
"""

length: torch.Tensor
angle: torch.Tensor = torch.tensor(0.0)
e1: torch.Tensor = torch.tensor(0.0)
e2: torch.Tensor = torch.tensor(0.0)
tilt: torch.Tensor = torch.tensor(0.0)
fringe_integral: torch.Tensor = torch.tensor(0.0)
fringe_integral_exit: Optional[torch.Tensor] = None
gap: torch.Tensor = torch.tensor(0.0)

def __init__(
self,
length: torch.Tensor,
Expand Down Expand Up @@ -566,6 +586,9 @@ class HorizontalCorrector(Element):
CUDA GPU is selected if available. The CPU is used otherwise.
"""

length: torch.Tensor
angle: torch.Tensor = torch.tensor(0.0)

def __init__(
self,
length: torch.Tensor,
Expand Down Expand Up @@ -646,6 +669,9 @@ class VerticalCorrector(Element):
CUDA GPU is selected if available. The CPU is used otherwise.
"""

length: torch.Tensor
angle: torch.Tensor = torch.tensor(0.0)

def __init__(
self,
length: torch.Tensor,
Expand Down Expand Up @@ -728,6 +754,11 @@ class Cavity(Element):
CUDA GPU is selected if available. The CPU is used otherwise.
"""

length: torch.Tensor
voltage: torch.Tensor = torch.tensor(0.0)
phase: torch.Tensor = torch.tensor(0.0)
frequency: torch.Tensor = torch.tensor(0.0)

def __init__(
self,
length: torch.Tensor,
Expand Down Expand Up @@ -1131,6 +1162,12 @@ class Screen(Element):
CUDA GPU is selected if available. The CPU is used otherwise.
"""

resolution: torch.Tensor = torch.tensor((1024, 1024))
pixel_size: torch.Tensor = torch.tensor((1e-3, 1e-3))
binning: torch.Tensor = torch.tensor(1)
misalignment: torch.Tensor = torch.tensor((0.0, 0.0))
is_active: bool = False

def __init__(
self,
resolution: torch.Tensor = torch.tensor((1024, 1024)),
Expand Down Expand Up @@ -1311,6 +1348,11 @@ class Aperture(Element):
CUDA GPU is selected if available. The CPU is used otherwise.
"""

x_max: torch.Tensor = torch.inf
y_max: torch.Tensor = torch.inf
shape: Literal["rectangular", "elliptical"] = "rectangular"
is_active: bool = True

def __init__(
self,
x_max: torch.Tensor = torch.inf,
Expand Down Expand Up @@ -1415,6 +1457,9 @@ class Undulator(Element):
CUDA GPU is selected if available. The CPU is used otherwise.
"""

length: torch.Tensor
is_active: bool = False

def __init__(
self,
length: torch.Tensor,
Expand Down Expand Up @@ -1491,6 +1536,10 @@ class Solenoid(Element):
CUDA GPU is selected if available. The CPU is used otherwise.
"""

length: torch.Tensor
k: torch.Tensor = torch.tensor(0.0)
misalignment: torch.Tensor = torch.tensor((0.0, 0.0))

def __init__(
self,
length: torch.Tensor = torch.tensor(0.0),
Expand Down Expand Up @@ -1582,6 +1631,8 @@ class Segment(Element):
CUDA GPU is selected if available. The CPU is used otherwise.
"""

elements: list[Element]

def __init__(
self, cell: list[Element], name: Optional[str] = None, device: str = "auto"
) -> None:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@
long_description_content_type="text/markdown",
packages=["cheetah"],
python_requires=">=3.9",
install_requires=["torch", "matplotlib", "numpy", "scipy"],
install_requires=["matplotlib", "numpy", "pydantic", "scipy", "torch"],
)

0 comments on commit 7204199

Please sign in to comment.