Skip to content

Commit

Permalink
Fix simple example Notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
jank324 committed Sep 14, 2023
1 parent 988dd67 commit edcdb36
Show file tree
Hide file tree
Showing 2 changed files with 184 additions and 30 deletions.
26 changes: 16 additions & 10 deletions cheetah/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -984,15 +984,20 @@ class BPM(Element):
"""
Beam Position Monitor (BPM) in a particle accelerator.
:param is_active: If `True` the BPM is active and will record the beam's position.
If `False` the BPM is inactive and will not record the beam's position.
:param name: Unique identifier of the element.
:param device: Device to move the beam's particle array to. If set to `"auto"` a
CUDA GPU is selected if available. The CPU is used otherwise.
"""

def __init__(self, name: Optional[str] = None, device: str = "auto") -> None:
def __init__(
self, is_active: bool = False, name: Optional[str] = None, device: str = "auto"
) -> None:
super().__init__(name=name, device=device)

self.reading = (None, None)
self.is_active = is_active
self.reading = None

@property
def is_skippable(self) -> bool:
Expand All @@ -1003,17 +1008,16 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:

def track(self, incoming: Beam) -> Beam:
if incoming is Beam.empty:
self.reading = (None, None)
return Beam.empty
self.reading = None
elif isinstance(incoming, ParameterBeam):
self.reading = (incoming.mu_x, incoming.mu_y)
return ParameterBeam(incoming._mu, incoming._cov, incoming.energy)
self.reading = torch.stack([incoming._mu_x, incoming._mu_y])
elif isinstance(incoming, ParticleBeam):
self.reading = (incoming.mu_x, incoming.mu_y)
return ParticleBeam(incoming.particles, incoming.energy, device=self.device)
self.reading = torch.stack([incoming.mu_x, incoming.mu_y])
else:
raise TypeError(f"Parameter incoming is of invalid type {type(incoming)}")

return deepcopy(incoming)

def split(self, resolution: torch.Tensor) -> list[Element]:
return [self]

Expand Down Expand Up @@ -1731,9 +1735,11 @@ def plot_reference_particle_traces(
in the plot.
"""
reference_segment = deepcopy(self)
splits = reference_segment.split(resolution)
splits = reference_segment.split(resolution=torch.tensor(resolution))

split_lengths = [split.length for split in splits]
split_lengths = [
split.length if hasattr(split, "length") else 0.0 for split in splits
]
ss = [0] + [sum(split_lengths[: i + 1]) for i, _ in enumerate(split_lengths)]

references = []
Expand Down
Loading

0 comments on commit edcdb36

Please sign in to comment.