Skip to content

Commit

Permalink
Merge branch 'master' into 253-default-dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
Hespe committed Oct 7, 2024
2 parents 8d91db7 + 73f53e6 commit 91296db
Show file tree
Hide file tree
Showing 64 changed files with 2,217 additions and 1,743 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ cheetah.egg-info
.vscode
dist
.coverage
.idea

*.egg-info

Expand All @@ -14,4 +15,5 @@ build
distributions

docs/_build
dev*

dev*
8 changes: 6 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@

## v0.7.0 [🚧 Work in Progress]

This is a major release with significant upgrades under the hood of Cheetah. Despite extensive testing, you might still encounter a few bugs. Please report them by opening an issue, so we can fix them as soon as possible and improve the experience for everyone.

### 🚨 Breaking Changes

- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #215, #218, #229, #233) (@jank324, @cr-xu, @hespe)
- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #208, #213, #215, #218, #229, #233, #258, #265) (@jank324, @cr-xu, @hespe, @roussel-ryan)
- The fifth particle coordinate `s` is renamed to `tau`. Now Cheetah uses the canonical variables in phase space $(x,px=\frac{P_x}{p_0},y,py, \tau=c\Delta t, \delta=\Delta E/{p_0 c})$. In addition, the trailing "s" was removed from some beam property names (e.g. `beam.xs` becomes `beam.x`). (see #163) (@cr-xu)
- `Screen` no longer blocks the beam (by default). To return to old behaviour, set `Screen.is_blocking = True`. (see #208) (@jank324, @roussel-ryan)

### 🚀 Features

Expand All @@ -20,7 +23,7 @@
- Port Bmad-X tracking methods to Cheetah for `Quadrupole`, `Drift`, and `Dipole` (see #153, #240) (@jp-ga, @jank324)
- Add `TransverseDeflectingCavity` element (following the Bmad-X implementation) (see #240) (@jp-ga)
- `Dipole` and `RBend` now take a focusing moment `k1` (see #235, #247) (@hespe)
- Implement a converter for lattice files imported from Elegant (see #222) (@hespe)
- Implement a converter for lattice files imported from Elegant (see #222, #251) (@hespe)

### 🐛 Bug fixes

Expand All @@ -33,6 +36,7 @@
- Fix an issue where splitting elements would result in splits with a different `dtype` (see #211) (@jank324)
- Fix issue in Bmad import where collimators had no length by interpreting them as `Drift` + `Aperture` (see #249) (@jank324)
- Fix NumPy 2 compatibility issues with PyTorch on Windows (see #220, #242) (@hespe)
- Fix issue with Dipole hgap conversion in Bmad import (see #261) (@cr-xu)

### 🐆 Other

Expand Down
6 changes: 3 additions & 3 deletions cheetah/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import cheetah.converters # noqa: F401
from cheetah.accelerator import ( # noqa: F401
from . import converters # noqa: F401
from .accelerator import ( # noqa: F401
BPM,
Aperture,
Cavity,
Expand All @@ -18,4 +18,4 @@
Undulator,
VerticalCorrector,
)
from cheetah.particles import ParameterBeam, ParticleBeam # noqa: F401
from .particles import ParameterBeam, ParticleBeam # noqa: F401
26 changes: 7 additions & 19 deletions cheetah/accelerator/aperture.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
import matplotlib.pyplot as plt
import torch
from matplotlib.patches import Rectangle
from torch import Size, nn

from cheetah.particles import Beam, ParticleBeam
from cheetah.utils import UniqueNameGenerator, verify_device_and_dtype
from torch import nn

from ..particles import Beam, ParticleBeam
from ..utils import UniqueNameGenerator, verify_device_and_dtype
from .element import Element

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")
Expand Down Expand Up @@ -111,31 +110,20 @@ def track(self, incoming: Beam) -> Beam:
else ParticleBeam.empty
)

def broadcast(self, shape: Size) -> Element:
new_aperture = self.__class__(
x_max=self.x_max.repeat(shape),
y_max=self.y_max.repeat(shape),
shape=self.shape,
is_active=self.is_active,
name=self.name,
device=self.x_max.device,
dtype=self.x_max.dtype,
)
new_aperture.length = self.length.repeat(shape)
return new_aperture

def split(self, resolution: torch.Tensor) -> list[Element]:
# TODO: Implement splitting for aperture properly, for now just return self
return [self]

def plot(self, ax: plt.Axes, s: float) -> None:
def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> None:
plot_s = s[vector_idx] if s.dim() > 0 else s

alpha = 1 if self.is_active else 0.2
height = 0.4

dummy_length = 0.0

patch = Rectangle(
(s, 0), dummy_length, height, color="tab:pink", alpha=alpha, zorder=2
(plot_s, 0), dummy_length, height, color="tab:pink", alpha=alpha, zorder=2
)
ax.add_patch(patch)

Expand Down
17 changes: 6 additions & 11 deletions cheetah/accelerator/bpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
import matplotlib.pyplot as plt
import torch
from matplotlib.patches import Rectangle
from torch import Size

from cheetah.particles import Beam, ParameterBeam, ParticleBeam
from cheetah.utils import UniqueNameGenerator

from ..particles import Beam, ParameterBeam, ParticleBeam
from ..utils import UniqueNameGenerator
from .element import Element

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")
Expand Down Expand Up @@ -50,18 +48,15 @@ def track(self, incoming: Beam) -> Beam:

return deepcopy(incoming)

def broadcast(self, shape: Size) -> Element:
new_bpm = self.__class__(is_active=self.is_active, name=self.name)
new_bpm.length = self.length.repeat(shape)
return new_bpm

def split(self, resolution: torch.Tensor) -> list[Element]:
return [self]

def plot(self, ax: plt.Axes, s: float) -> None:
def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> None:
plot_s = s[vector_idx] if s.dim() > 0 else s

alpha = 1 if self.is_active else 0.2
patch = Rectangle(
(s, -0.3), 0, 0.3 * 2, color="darkkhaki", alpha=alpha, zorder=2
(plot_s, -0.3), 0, 0.3 * 2, color="darkkhaki", alpha=alpha, zorder=2
)
ax.add_patch(patch)

Expand Down
82 changes: 37 additions & 45 deletions cheetah/accelerator/cavity.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@
from matplotlib.patches import Rectangle
from scipy import constants
from scipy.constants import physical_constants
from torch import Size, nn

from cheetah.particles import Beam, ParameterBeam, ParticleBeam
from cheetah.track_methods import base_rmatrix
from cheetah.utils import UniqueNameGenerator, verify_device_and_dtype

from torch import nn

from ..particles import Beam, ParameterBeam, ParticleBeam
from ..track_methods import base_rmatrix
from ..utils import (
UniqueNameGenerator,
compute_relativistic_factors,
verify_device_and_dtype,
)
from .element import Element

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")
Expand Down Expand Up @@ -113,14 +116,7 @@ def _track_beam(self, incoming: Beam) -> Beam:
Track particles through the cavity. The input can be a `ParameterBeam` or a
`ParticleBeam`.
"""
beta0 = torch.full_like(self.length, 1.0)
igamma2 = torch.full_like(self.length, 0.0)
g0 = torch.full_like(self.length, 1e10)

mask = incoming.energy != 0
g0[mask] = incoming.energy[mask] / electron_mass_eV
igamma2[mask] = 1 / g0[mask] ** 2
beta0[mask] = torch.sqrt(1 - igamma2[mask])
gamma0, igamma2, beta0 = compute_relativistic_factors(incoming.energy)

phi = torch.deg2rad(self.phase)

Expand All @@ -141,8 +137,7 @@ def _track_beam(self, incoming: Beam) -> Beam:
if torch.any(incoming.energy + delta_energy > 0):
k = 2 * torch.pi * self.frequency / constants.speed_of_light
outgoing_energy = incoming.energy + delta_energy
g1 = outgoing_energy / electron_mass_eV
beta1 = torch.sqrt(1 - 1 / g1**2)
gamma1, _, beta1 = compute_relativistic_factors(outgoing_energy)

if isinstance(incoming, ParameterBeam):
outgoing_mu[..., 5] = incoming._mu[..., 5] * incoming.energy * beta0 / (
Expand Down Expand Up @@ -177,18 +172,18 @@ def _track_beam(self, incoming: Beam) -> Beam:
if torch.any(delta_energy > 0):
T566 = (
self.length
* (beta0**3 * g0**3 - beta1**3 * g1**3)
/ (2 * beta0 * beta1**3 * g0 * (g0 - g1) * g1**3)
* (beta0**3 * gamma0**3 - beta1**3 * gamma1**3)
/ (2 * beta0 * beta1**3 * gamma0 * (gamma0 - gamma1) * gamma1**3)
)
T556 = (
beta0
* k
* self.length
* dgamma
* g0
* (beta1**3 * g1**3 + beta0 * (g0 - g1**3))
* gamma0
* (beta1**3 * gamma1**3 + beta0 * (gamma0 - gamma1**3))
* torch.sin(phi)
/ (beta1**3 * g1**3 * (g0 - g1) ** 2)
/ (beta1**3 * gamma1**3 * (gamma0 - gamma1) ** 2)
)
T555 = (
beta0**2
Expand All @@ -199,15 +194,15 @@ def _track_beam(self, incoming: Beam) -> Beam:
* (
dgamma
* (
2 * g0 * g1**3 * (beta0 * beta1**3 - 1)
+ g0**2
+ 3 * g1**2
2 * gamma0 * gamma1**3 * (beta0 * beta1**3 - 1)
+ gamma0**2
+ 3 * gamma1**2
- 2
)
/ (beta1**3 * g1**3 * (g0 - g1) ** 3)
/ (beta1**3 * gamma1**3 * (gamma0 - gamma1) ** 3)
* torch.sin(phi) ** 2
- (g1 * g0 * (beta1 * beta0 - 1) + 1)
/ (beta1 * g1 * (g0 - g1) ** 2)
- (gamma1 * gamma0 * (beta1 * beta0 - 1) + 1)
/ (beta1 * gamma1 * (gamma0 - gamma1) ** 2)
* torch.cos(phi)
)
)
Expand Down Expand Up @@ -240,9 +235,9 @@ def _track_beam(self, incoming: Beam) -> Beam:

if isinstance(incoming, ParameterBeam):
outgoing = ParameterBeam(
outgoing_mu,
outgoing_cov,
outgoing_energy,
mu=outgoing_mu,
cov=outgoing_cov,
energy=outgoing_energy,
total_charge=incoming.total_charge,
device=outgoing_mu.device,
dtype=outgoing_mu.dtype,
Expand Down Expand Up @@ -309,7 +304,7 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor:
* self.frequency
/ torch.tensor(constants.speed_of_light, **factory_kwargs)
)
r55_cor = 0.0
r55_cor = torch.tensor(0.0, **factory_kwargs)
if torch.any((self.voltage != 0) & (energy != 0)): # TODO: Do we need this if?
beta0 = torch.sqrt(1 - 1 / Ei**2)
beta1 = torch.sqrt(1 - 1 / Ef**2)
Expand All @@ -331,7 +326,12 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor:
r66 = Ei / Ef * beta0 / beta1
r65 = k * torch.sin(phi) * self.voltage / (Ef * beta1 * electron_mass_eV)

R = torch.eye(7, **factory_kwargs).repeat((*self.length.shape, 1, 1))
# Make sure that all matrix elements have the same shape
r11, r12, r21, r22, r55_cor, r56, r65, r66 = torch.broadcast_tensors(
r11, r12, r21, r22, r55_cor, r56, r65, r66
)

R = torch.eye(7, **factory_kwargs).repeat((*r11.shape, 1, 1))
R[..., 0, 0] = r11
R[..., 0, 1] = r12
R[..., 1, 0] = r21
Expand All @@ -347,28 +347,20 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor:

return R

def broadcast(self, shape: Size) -> Element:
return self.__class__(
length=self.length.repeat(shape),
voltage=self.voltage.repeat(shape),
phase=self.phase.repeat(shape),
frequency=self.frequency.repeat(shape),
name=self.name,
device=self.length.device,
dtype=self.length.dtype,
)

def split(self, resolution: torch.Tensor) -> list[Element]:
# TODO: Implement splitting for cavity properly, for now just returns the
# element itself
return [self]

def plot(self, ax: plt.Axes, s: float) -> None:
def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> None:
plot_s = s[vector_idx] if s.dim() > 0 else s
plot_length = self.length[vector_idx] if self.length.dim() > 0 else self.length

alpha = 1 if self.is_active else 0.2
height = 0.4

patch = Rectangle(
(s, 0), self.length[0], height, color="gold", alpha=alpha, zorder=2
(plot_s, 0), plot_length, height, color="gold", alpha=alpha, zorder=2
)
ax.add_patch(patch)

Expand Down
23 changes: 8 additions & 15 deletions cheetah/accelerator/custom_transfer_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
import matplotlib.pyplot as plt
import torch
from matplotlib.patches import Rectangle
from torch import Size, nn

from cheetah.particles import Beam
from cheetah.utils import UniqueNameGenerator, verify_device_and_dtype
from torch import nn

from ..particles import Beam
from ..utils import UniqueNameGenerator, verify_device_and_dtype
from .element import Element

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")
Expand Down Expand Up @@ -87,15 +86,6 @@ def from_merging_elements(
def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
return self._transfer_map

def broadcast(self, shape: Size) -> Element:
return self.__class__(
self._transfer_map.repeat((*shape, 1, 1)),
length=self.length.repeat(shape),
name=self.name,
device=self._transfer_map.device,
dtype=self._transfer_map.dtype,
)

@property
def is_skippable(self) -> bool:
return True
Expand All @@ -113,8 +103,11 @@ def defining_features(self) -> list[str]:
def split(self, resolution: torch.Tensor) -> list[Element]:
return [self]

def plot(self, ax: plt.Axes, s: float) -> None:
def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> None:
plot_s = s[vector_idx] if s.dim() > 0 else s
plot_length = self.length[vector_idx] if self.length.dim() > 0 else self.length

height = 0.4

patch = Rectangle((s, 0), self.length[0], height, color="tab:olive", zorder=2)
patch = Rectangle((plot_s, 0), plot_length, height, color="tab:olive", zorder=2)
ax.add_patch(patch)
Loading

0 comments on commit 91296db

Please sign in to comment.