Skip to content

Commit

Permalink
Fix plotting in 0-dimensional case
Browse files Browse the repository at this point in the history
  • Loading branch information
jank324 committed Oct 2, 2024
1 parent 06eda4e commit e7b12fa
Show file tree
Hide file tree
Showing 11 changed files with 18 additions and 31 deletions.
2 changes: 1 addition & 1 deletion cheetah/accelerator/cavity.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def plot(self, ax: plt.Axes, s: float) -> None:
height = 0.4

patch = Rectangle(
(s, 0), self.length[0], height, color="gold", alpha=alpha, zorder=2
(s, 0), self.length, height, color="gold", alpha=alpha, zorder=2
)
ax.add_patch(patch)

Expand Down
2 changes: 1 addition & 1 deletion cheetah/accelerator/custom_transfer_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,5 +105,5 @@ def split(self, resolution: torch.Tensor) -> list[Element]:
def plot(self, ax: plt.Axes, s: float) -> None:
height = 0.4

patch = Rectangle((s, 0), self.length[0], height, color="tab:olive", zorder=2)
patch = Rectangle((s, 0), self.length, height, color="tab:olive", zorder=2)
ax.add_patch(patch)
4 changes: 2 additions & 2 deletions cheetah/accelerator/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,9 +482,9 @@ def defining_features(self) -> list[str]:

def plot(self, ax: plt.Axes, s: float) -> None:
alpha = 1 if self.is_active else 0.2
height = 0.8 * (np.sign(self.angle[0]) if self.is_active else 1)
height = 0.8 * (np.sign(self.angle) if self.is_active else 1)

patch = Rectangle(
(s, 0), self.length[0], height, color="tab:green", alpha=alpha, zorder=2
(s, 0), self.length, height, color="tab:green", alpha=alpha, zorder=2
)
ax.add_patch(patch)
4 changes: 2 additions & 2 deletions cheetah/accelerator/horizontal_corrector.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,10 @@ def split(self, resolution: torch.Tensor) -> list[Element]:

def plot(self, ax: plt.Axes, s: float) -> None:
alpha = 1 if self.is_active else 0.2
height = 0.8 * (np.sign(self.angle[0]) if self.is_active else 1)
height = 0.8 * (np.sign(self.angle) if self.is_active else 1)

patch = Rectangle(
(s, 0), self.length[0], height, color="tab:blue", alpha=alpha, zorder=2
(s, 0), self.length, height, color="tab:blue", alpha=alpha, zorder=2
)
ax.add_patch(patch)

Expand Down
4 changes: 2 additions & 2 deletions cheetah/accelerator/quadrupole.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,9 @@ def split(self, resolution: torch.Tensor) -> list[Element]:

def plot(self, ax: plt.Axes, s: float) -> None:
alpha = 1 if self.is_active else 0.2
height = 0.8 * (np.sign(self.k1[0]) if self.is_active else 1)
height = 0.8 * (np.sign(self.k1) if self.is_active else 1)
patch = Rectangle(
(s, 0), self.length[0], height, color="tab:red", alpha=alpha, zorder=2
(s, 0), self.length, height, color="tab:red", alpha=alpha, zorder=2
)
ax.add_patch(patch)

Expand Down
8 changes: 4 additions & 4 deletions cheetah/accelerator/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def split(self, resolution: torch.Tensor) -> list[Element]:
]

def plot(self, ax: plt.Axes, s: float) -> None:
element_lengths = [element.length[0] for element in self.elements]
element_lengths = [element.length for element in self.elements]
element_ss = [0] + [
sum(element_lengths[: i + 1]) for i, _ in enumerate(element_lengths)
]
Expand Down Expand Up @@ -423,7 +423,7 @@ def plot_reference_particle_traces(
reference_segment = deepcopy(self)
splits = reference_segment.split(resolution=torch.tensor(resolution))

split_lengths = [split.length[0] for split in splits]
split_lengths = [split.length for split in splits]
ss = [0] + [sum(split_lengths[: i + 1]) for i, _ in enumerate(split_lengths)]

references = []
Expand Down Expand Up @@ -464,7 +464,7 @@ def plot_reference_particle_traces(

for particle_index in range(num_particles):
xs = [
float(reference_beam.x[0, particle_index].cpu())
reference_beam.x[particle_index]
for reference_beam in references
if reference_beam is not Beam.empty
]
Expand All @@ -475,7 +475,7 @@ def plot_reference_particle_traces(

for particle_index in range(num_particles):
ys = [
float(reference_beam.ys[0, particle_index].cpu())
reference_beam.y[particle_index]
for reference_beam in references
if reference_beam is not Beam.empty
]
Expand Down
2 changes: 1 addition & 1 deletion cheetah/accelerator/solenoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def plot(self, ax: plt.Axes, s: float) -> None:
height = 0.8

patch = Rectangle(
(s, 0), self.length[0], height, color="tab:orange", alpha=alpha, zorder=2
(s, 0), self.length, height, color="tab:orange", alpha=alpha, zorder=2
)
ax.add_patch(patch)

Expand Down
2 changes: 1 addition & 1 deletion cheetah/accelerator/transverse_deflecting_cavity.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def plot(self, ax: plt.Axes, s: float) -> None:
height = 0.4

patch = Rectangle(
(s, 0), self.length[0], height, color="olive", alpha=alpha, zorder=2
(s, 0), self.length, height, color="olive", alpha=alpha, zorder=2
)
ax.add_patch(patch)

Expand Down
2 changes: 1 addition & 1 deletion cheetah/accelerator/undulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def plot(self, ax: plt.Axes, s: float) -> None:
height = 0.4

patch = Rectangle(
(s, 0), self.length[0], height, color="tab:purple", alpha=alpha, zorder=2
(s, 0), self.length, height, color="tab:purple", alpha=alpha, zorder=2
)
ax.add_patch(patch)

Expand Down
4 changes: 2 additions & 2 deletions cheetah/accelerator/vertical_corrector.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ def split(self, resolution: torch.Tensor) -> list[Element]:

def plot(self, ax: plt.Axes, s: float) -> None:
alpha = 1 if self.is_active else 0.2
height = 0.8 * (np.sign(self.angle[0]) if self.is_active else 1)
height = 0.8 * (np.sign(self.angle) if self.is_active else 1)

patch = Rectangle(
(s, 0), self.length[0], height, color="tab:cyan", alpha=alpha, zorder=2
(s, 0), self.length, height, color="tab:cyan", alpha=alpha, zorder=2
)
ax.add_patch(patch)

Expand Down
15 changes: 1 addition & 14 deletions docs/examples/simple.ipynb

Large diffs are not rendered by default.

0 comments on commit e7b12fa

Please sign in to comment.