Skip to content

Commit

Permalink
Merge branch 'master' into 257-update-latticejson-version-tag
Browse files Browse the repository at this point in the history
  • Loading branch information
jank324 authored Oct 1, 2024
2 parents 125819e + 8a38b63 commit cd81614
Show file tree
Hide file tree
Showing 53 changed files with 1,336 additions and 1,472 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ cheetah.egg-info
.vscode
dist
.coverage
.idea

*.egg-info

Expand All @@ -14,4 +15,5 @@ build
distributions

docs/_build
dev*

dev*
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@

## v0.7.0 [🚧 Work in Progress]

This is a major release with significant upgrades under the hood of Cheetah. Despite extensive testing, you might still encounter a few bugs. Please report them by opening an issue, so we can fix them as soon as possible and improve the experience for everyone.

### 🚨 Breaking Changes

- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #215, #218, #229, #233) (@jank324, @cr-xu, @hespe)
- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #208, #215, #218, #229, #233) (@jank324, @cr-xu, @hespe, @roussel-ryan)
- The fifth particle coordinate `s` is renamed to `tau`. Now Cheetah uses the canonical variables in phase space $(x,px=\frac{P_x}{p_0},y,py, \tau=c\Delta t, \delta=\Delta E/{p_0 c})$. In addition, the trailing "s" was removed from some beam property names (e.g. `beam.xs` becomes `beam.x`). (see #163) (@cr-xu)
- `Screen` no longer blocks the beam (by default). To return to old behaviour, set `Screen.is_blocking = True`. (see #208) (@jank324, @roussel-ryan)

### 🚀 Features

Expand Down
6 changes: 3 additions & 3 deletions cheetah/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import cheetah.converters # noqa: F401
from cheetah.accelerator import ( # noqa: F401
from . import converters # noqa: F401
from .accelerator import ( # noqa: F401
BPM,
Aperture,
Cavity,
Expand All @@ -18,4 +18,4 @@
Undulator,
VerticalCorrector,
)
from cheetah.particles import ParameterBeam, ParticleBeam # noqa: F401
from .particles import ParameterBeam, ParticleBeam # noqa: F401
20 changes: 3 additions & 17 deletions cheetah/accelerator/aperture.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
import matplotlib.pyplot as plt
import torch
from matplotlib.patches import Rectangle
from torch import Size, nn

from cheetah.particles import Beam, ParticleBeam
from cheetah.utils import UniqueNameGenerator
from torch import nn

from ..particles import Beam, ParticleBeam
from ..utils import UniqueNameGenerator
from .element import Element

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")
Expand Down Expand Up @@ -110,19 +109,6 @@ def track(self, incoming: Beam) -> Beam:
else ParticleBeam.empty
)

def broadcast(self, shape: Size) -> Element:
new_aperture = self.__class__(
x_max=self.x_max.repeat(shape),
y_max=self.y_max.repeat(shape),
shape=self.shape,
is_active=self.is_active,
name=self.name,
device=self.x_max.device,
dtype=self.x_max.dtype,
)
new_aperture.length = self.length.repeat(shape)
return new_aperture

def split(self, resolution: torch.Tensor) -> list[Element]:
# TODO: Implement splitting for aperture properly, for now just return self
return [self]
Expand Down
11 changes: 2 additions & 9 deletions cheetah/accelerator/bpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
import matplotlib.pyplot as plt
import torch
from matplotlib.patches import Rectangle
from torch import Size

from cheetah.particles import Beam, ParameterBeam, ParticleBeam
from cheetah.utils import UniqueNameGenerator

from ..particles import Beam, ParameterBeam, ParticleBeam
from ..utils import UniqueNameGenerator
from .element import Element

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")
Expand Down Expand Up @@ -50,11 +48,6 @@ def track(self, incoming: Beam) -> Beam:

return deepcopy(incoming)

def broadcast(self, shape: Size) -> Element:
new_bpm = self.__class__(is_active=self.is_active, name=self.name)
new_bpm.length = self.length.repeat(shape)
return new_bpm

def split(self, resolution: torch.Tensor) -> list[Element]:
return [self]

Expand Down
69 changes: 27 additions & 42 deletions cheetah/accelerator/cavity.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
from matplotlib.patches import Rectangle
from scipy import constants
from scipy.constants import physical_constants
from torch import Size, nn

from cheetah.particles import Beam, ParameterBeam, ParticleBeam
from cheetah.track_methods import base_rmatrix
from cheetah.utils import UniqueNameGenerator
from torch import nn

from ..particles import Beam, ParameterBeam, ParticleBeam
from ..track_methods import base_rmatrix
from ..utils import UniqueNameGenerator, compute_relativistic_factors
from .element import Element

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")
Expand Down Expand Up @@ -110,14 +109,7 @@ def _track_beam(self, incoming: Beam) -> Beam:
Track particles through the cavity. The input can be a `ParameterBeam` or a
`ParticleBeam`.
"""
beta0 = torch.full_like(self.length, 1.0)
igamma2 = torch.full_like(self.length, 0.0)
g0 = torch.full_like(self.length, 1e10)

mask = incoming.energy != 0
g0[mask] = incoming.energy[mask] / electron_mass_eV
igamma2[mask] = 1 / g0[mask] ** 2
beta0[mask] = torch.sqrt(1 - igamma2[mask])
gamma0, igamma2, beta0 = compute_relativistic_factors(incoming.energy)

phi = torch.deg2rad(self.phase)

Expand All @@ -138,8 +130,7 @@ def _track_beam(self, incoming: Beam) -> Beam:
if torch.any(incoming.energy + delta_energy > 0):
k = 2 * torch.pi * self.frequency / constants.speed_of_light
outgoing_energy = incoming.energy + delta_energy
g1 = outgoing_energy / electron_mass_eV
beta1 = torch.sqrt(1 - 1 / g1**2)
gamma1, _, beta1 = compute_relativistic_factors(outgoing_energy)

if isinstance(incoming, ParameterBeam):
outgoing_mu[..., 5] = incoming._mu[..., 5] * incoming.energy * beta0 / (
Expand Down Expand Up @@ -174,18 +165,18 @@ def _track_beam(self, incoming: Beam) -> Beam:
if torch.any(delta_energy > 0):
T566 = (
self.length
* (beta0**3 * g0**3 - beta1**3 * g1**3)
/ (2 * beta0 * beta1**3 * g0 * (g0 - g1) * g1**3)
* (beta0**3 * gamma0**3 - beta1**3 * gamma1**3)
/ (2 * beta0 * beta1**3 * gamma0 * (gamma0 - gamma1) * gamma1**3)
)
T556 = (
beta0
* k
* self.length
* dgamma
* g0
* (beta1**3 * g1**3 + beta0 * (g0 - g1**3))
* gamma0
* (beta1**3 * gamma1**3 + beta0 * (gamma0 - gamma1**3))
* torch.sin(phi)
/ (beta1**3 * g1**3 * (g0 - g1) ** 2)
/ (beta1**3 * gamma1**3 * (gamma0 - gamma1) ** 2)
)
T555 = (
beta0**2
Expand All @@ -196,15 +187,15 @@ def _track_beam(self, incoming: Beam) -> Beam:
* (
dgamma
* (
2 * g0 * g1**3 * (beta0 * beta1**3 - 1)
+ g0**2
+ 3 * g1**2
2 * gamma0 * gamma1**3 * (beta0 * beta1**3 - 1)
+ gamma0**2
+ 3 * gamma1**2
- 2
)
/ (beta1**3 * g1**3 * (g0 - g1) ** 3)
/ (beta1**3 * gamma1**3 * (gamma0 - gamma1) ** 3)
* torch.sin(phi) ** 2
- (g1 * g0 * (beta1 * beta0 - 1) + 1)
/ (beta1 * g1 * (g0 - g1) ** 2)
- (gamma1 * gamma0 * (beta1 * beta0 - 1) + 1)
/ (beta1 * gamma1 * (gamma0 - gamma1) ** 2)
* torch.cos(phi)
)
)
Expand Down Expand Up @@ -237,9 +228,9 @@ def _track_beam(self, incoming: Beam) -> Beam:

if isinstance(incoming, ParameterBeam):
outgoing = ParameterBeam(
outgoing_mu,
outgoing_cov,
outgoing_energy,
mu=outgoing_mu,
cov=outgoing_cov,
energy=outgoing_energy,
total_charge=incoming.total_charge,
device=outgoing_mu.device,
dtype=outgoing_mu.dtype,
Expand Down Expand Up @@ -302,7 +293,7 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor:
beta1 = torch.tensor(1.0)

k = 2 * torch.pi * self.frequency / torch.tensor(constants.speed_of_light)
r55_cor = 0.0
r55_cor = torch.tensor(0.0)
if torch.any((self.voltage != 0) & (energy != 0)): # TODO: Do we need this if?
beta0 = torch.sqrt(1 - 1 / Ei**2)
beta1 = torch.sqrt(1 - 1 / Ef**2)
Expand All @@ -324,7 +315,12 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor:
r66 = Ei / Ef * beta0 / beta1
r65 = k * torch.sin(phi) * self.voltage / (Ef * beta1 * electron_mass_eV)

R = torch.eye(7, device=device, dtype=dtype).repeat((*self.length.shape, 1, 1))
# Make sure that all matrix elements have the same shape
r11, r12, r21, r22, r55_cor, r56, r65, r66 = torch.broadcast_tensors(
r11, r12, r21, r22, r55_cor, r56, r65, r66
)

R = torch.eye(7, device=device, dtype=dtype).repeat((*r11.shape, 1, 1))
R[..., 0, 0] = r11
R[..., 0, 1] = r12
R[..., 1, 0] = r21
Expand All @@ -340,17 +336,6 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor:

return R

def broadcast(self, shape: Size) -> Element:
return self.__class__(
length=self.length.repeat(shape),
voltage=self.voltage.repeat(shape),
phase=self.phase.repeat(shape),
frequency=self.frequency.repeat(shape),
name=self.name,
device=self.length.device,
dtype=self.length.dtype,
)

def split(self, resolution: torch.Tensor) -> list[Element]:
# TODO: Implement splitting for cavity properly, for now just returns the
# element itself
Expand Down
16 changes: 3 additions & 13 deletions cheetah/accelerator/custom_transfer_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
import matplotlib.pyplot as plt
import torch
from matplotlib.patches import Rectangle
from torch import Size, nn

from cheetah.particles import Beam
from cheetah.utils import UniqueNameGenerator
from torch import nn

from ..particles import Beam
from ..utils import UniqueNameGenerator
from .element import Element

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")
Expand Down Expand Up @@ -86,15 +85,6 @@ def from_merging_elements(
def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
return self._transfer_map

def broadcast(self, shape: Size) -> Element:
return self.__class__(
self._transfer_map.repeat((*shape, 1, 1)),
length=self.length.repeat(shape),
name=self.name,
device=self._transfer_map.device,
dtype=self._transfer_map.dtype,
)

@property
def is_skippable(self) -> bool:
return True
Expand Down
Loading

0 comments on commit cd81614

Please sign in to comment.