diff --git a/.gitignore b/.gitignore index 400b8bb0..a321426f 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ cheetah.egg-info .vscode dist .coverage +.idea *.egg-info @@ -14,4 +15,5 @@ build distributions docs/_build -dev* \ No newline at end of file + +dev* diff --git a/CHANGELOG.md b/CHANGELOG.md index 96765227..9dd601e6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,10 +2,13 @@ ## v0.7.0 [🚧 Work in Progress] +This is a major release with significant upgrades under the hood of Cheetah. Despite extensive testing, you might still encounter a few bugs. Please report them by opening an issue, so we can fix them as soon as possible and improve the experience for everyone. + ### 🚨 Breaking Changes -- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #215, #218, #229, #233) (@jank324, @cr-xu, @hespe) +- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #208, #215, #218, #229, #233) (@jank324, @cr-xu, @hespe, @roussel-ryan) - The fifth particle coordinate `s` is renamed to `tau`. Now Cheetah uses the canonical variables in phase space $(x,px=\frac{P_x}{p_0},y,py, \tau=c\Delta t, \delta=\Delta E/{p_0 c})$. In addition, the trailing "s" was removed from some beam property names (e.g. `beam.xs` becomes `beam.x`). (see #163) (@cr-xu) +- `Screen` no longer blocks the beam (by default). To return to old behaviour, set `Screen.is_blocking = True`. (see #208) (@jank324, @roussel-ryan) ### 🚀 Features diff --git a/cheetah/__init__.py b/cheetah/__init__.py index fc3005fa..bd3c3990 100644 --- a/cheetah/__init__.py +++ b/cheetah/__init__.py @@ -1,5 +1,5 @@ -import cheetah.converters # noqa: F401 -from cheetah.accelerator import ( # noqa: F401 +from . import converters # noqa: F401 +from .accelerator import ( # noqa: F401 BPM, Aperture, Cavity, @@ -18,4 +18,4 @@ Undulator, VerticalCorrector, ) -from cheetah.particles import ParameterBeam, ParticleBeam # noqa: F401 +from .particles import ParameterBeam, ParticleBeam # noqa: F401 diff --git a/cheetah/accelerator/aperture.py b/cheetah/accelerator/aperture.py index 309d4219..2d1d9ad0 100644 --- a/cheetah/accelerator/aperture.py +++ b/cheetah/accelerator/aperture.py @@ -3,11 +3,10 @@ import matplotlib.pyplot as plt import torch from matplotlib.patches import Rectangle -from torch import Size, nn - -from cheetah.particles import Beam, ParticleBeam -from cheetah.utils import UniqueNameGenerator +from torch import nn +from ..particles import Beam, ParticleBeam +from ..utils import UniqueNameGenerator from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -110,19 +109,6 @@ def track(self, incoming: Beam) -> Beam: else ParticleBeam.empty ) - def broadcast(self, shape: Size) -> Element: - new_aperture = self.__class__( - x_max=self.x_max.repeat(shape), - y_max=self.y_max.repeat(shape), - shape=self.shape, - is_active=self.is_active, - name=self.name, - device=self.x_max.device, - dtype=self.x_max.dtype, - ) - new_aperture.length = self.length.repeat(shape) - return new_aperture - def split(self, resolution: torch.Tensor) -> list[Element]: # TODO: Implement splitting for aperture properly, for now just return self return [self] diff --git a/cheetah/accelerator/bpm.py b/cheetah/accelerator/bpm.py index 5aede5ef..945c9d38 100644 --- a/cheetah/accelerator/bpm.py +++ b/cheetah/accelerator/bpm.py @@ -4,11 +4,9 @@ import matplotlib.pyplot as plt import torch from matplotlib.patches import Rectangle -from torch import Size - -from cheetah.particles import Beam, ParameterBeam, ParticleBeam -from cheetah.utils import UniqueNameGenerator +from ..particles import Beam, ParameterBeam, ParticleBeam +from ..utils import UniqueNameGenerator from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -50,11 +48,6 @@ def track(self, incoming: Beam) -> Beam: return deepcopy(incoming) - def broadcast(self, shape: Size) -> Element: - new_bpm = self.__class__(is_active=self.is_active, name=self.name) - new_bpm.length = self.length.repeat(shape) - return new_bpm - def split(self, resolution: torch.Tensor) -> list[Element]: return [self] diff --git a/cheetah/accelerator/cavity.py b/cheetah/accelerator/cavity.py index 60385e03..c1b53b17 100644 --- a/cheetah/accelerator/cavity.py +++ b/cheetah/accelerator/cavity.py @@ -5,12 +5,11 @@ from matplotlib.patches import Rectangle from scipy import constants from scipy.constants import physical_constants -from torch import Size, nn - -from cheetah.particles import Beam, ParameterBeam, ParticleBeam -from cheetah.track_methods import base_rmatrix -from cheetah.utils import UniqueNameGenerator +from torch import nn +from ..particles import Beam, ParameterBeam, ParticleBeam +from ..track_methods import base_rmatrix +from ..utils import UniqueNameGenerator, compute_relativistic_factors from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -110,14 +109,7 @@ def _track_beam(self, incoming: Beam) -> Beam: Track particles through the cavity. The input can be a `ParameterBeam` or a `ParticleBeam`. """ - beta0 = torch.full_like(self.length, 1.0) - igamma2 = torch.full_like(self.length, 0.0) - g0 = torch.full_like(self.length, 1e10) - - mask = incoming.energy != 0 - g0[mask] = incoming.energy[mask] / electron_mass_eV - igamma2[mask] = 1 / g0[mask] ** 2 - beta0[mask] = torch.sqrt(1 - igamma2[mask]) + gamma0, igamma2, beta0 = compute_relativistic_factors(incoming.energy) phi = torch.deg2rad(self.phase) @@ -138,8 +130,7 @@ def _track_beam(self, incoming: Beam) -> Beam: if torch.any(incoming.energy + delta_energy > 0): k = 2 * torch.pi * self.frequency / constants.speed_of_light outgoing_energy = incoming.energy + delta_energy - g1 = outgoing_energy / electron_mass_eV - beta1 = torch.sqrt(1 - 1 / g1**2) + gamma1, _, beta1 = compute_relativistic_factors(outgoing_energy) if isinstance(incoming, ParameterBeam): outgoing_mu[..., 5] = incoming._mu[..., 5] * incoming.energy * beta0 / ( @@ -174,18 +165,18 @@ def _track_beam(self, incoming: Beam) -> Beam: if torch.any(delta_energy > 0): T566 = ( self.length - * (beta0**3 * g0**3 - beta1**3 * g1**3) - / (2 * beta0 * beta1**3 * g0 * (g0 - g1) * g1**3) + * (beta0**3 * gamma0**3 - beta1**3 * gamma1**3) + / (2 * beta0 * beta1**3 * gamma0 * (gamma0 - gamma1) * gamma1**3) ) T556 = ( beta0 * k * self.length * dgamma - * g0 - * (beta1**3 * g1**3 + beta0 * (g0 - g1**3)) + * gamma0 + * (beta1**3 * gamma1**3 + beta0 * (gamma0 - gamma1**3)) * torch.sin(phi) - / (beta1**3 * g1**3 * (g0 - g1) ** 2) + / (beta1**3 * gamma1**3 * (gamma0 - gamma1) ** 2) ) T555 = ( beta0**2 @@ -196,15 +187,15 @@ def _track_beam(self, incoming: Beam) -> Beam: * ( dgamma * ( - 2 * g0 * g1**3 * (beta0 * beta1**3 - 1) - + g0**2 - + 3 * g1**2 + 2 * gamma0 * gamma1**3 * (beta0 * beta1**3 - 1) + + gamma0**2 + + 3 * gamma1**2 - 2 ) - / (beta1**3 * g1**3 * (g0 - g1) ** 3) + / (beta1**3 * gamma1**3 * (gamma0 - gamma1) ** 3) * torch.sin(phi) ** 2 - - (g1 * g0 * (beta1 * beta0 - 1) + 1) - / (beta1 * g1 * (g0 - g1) ** 2) + - (gamma1 * gamma0 * (beta1 * beta0 - 1) + 1) + / (beta1 * gamma1 * (gamma0 - gamma1) ** 2) * torch.cos(phi) ) ) @@ -237,9 +228,9 @@ def _track_beam(self, incoming: Beam) -> Beam: if isinstance(incoming, ParameterBeam): outgoing = ParameterBeam( - outgoing_mu, - outgoing_cov, - outgoing_energy, + mu=outgoing_mu, + cov=outgoing_cov, + energy=outgoing_energy, total_charge=incoming.total_charge, device=outgoing_mu.device, dtype=outgoing_mu.dtype, @@ -302,7 +293,7 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor: beta1 = torch.tensor(1.0) k = 2 * torch.pi * self.frequency / torch.tensor(constants.speed_of_light) - r55_cor = 0.0 + r55_cor = torch.tensor(0.0) if torch.any((self.voltage != 0) & (energy != 0)): # TODO: Do we need this if? beta0 = torch.sqrt(1 - 1 / Ei**2) beta1 = torch.sqrt(1 - 1 / Ef**2) @@ -324,7 +315,12 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor: r66 = Ei / Ef * beta0 / beta1 r65 = k * torch.sin(phi) * self.voltage / (Ef * beta1 * electron_mass_eV) - R = torch.eye(7, device=device, dtype=dtype).repeat((*self.length.shape, 1, 1)) + # Make sure that all matrix elements have the same shape + r11, r12, r21, r22, r55_cor, r56, r65, r66 = torch.broadcast_tensors( + r11, r12, r21, r22, r55_cor, r56, r65, r66 + ) + + R = torch.eye(7, device=device, dtype=dtype).repeat((*r11.shape, 1, 1)) R[..., 0, 0] = r11 R[..., 0, 1] = r12 R[..., 1, 0] = r21 @@ -340,17 +336,6 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor: return R - def broadcast(self, shape: Size) -> Element: - return self.__class__( - length=self.length.repeat(shape), - voltage=self.voltage.repeat(shape), - phase=self.phase.repeat(shape), - frequency=self.frequency.repeat(shape), - name=self.name, - device=self.length.device, - dtype=self.length.dtype, - ) - def split(self, resolution: torch.Tensor) -> list[Element]: # TODO: Implement splitting for cavity properly, for now just returns the # element itself diff --git a/cheetah/accelerator/custom_transfer_map.py b/cheetah/accelerator/custom_transfer_map.py index ad20bec2..2f271af8 100644 --- a/cheetah/accelerator/custom_transfer_map.py +++ b/cheetah/accelerator/custom_transfer_map.py @@ -3,11 +3,10 @@ import matplotlib.pyplot as plt import torch from matplotlib.patches import Rectangle -from torch import Size, nn - -from cheetah.particles import Beam -from cheetah.utils import UniqueNameGenerator +from torch import nn +from ..particles import Beam +from ..utils import UniqueNameGenerator from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -86,15 +85,6 @@ def from_merging_elements( def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: return self._transfer_map - def broadcast(self, shape: Size) -> Element: - return self.__class__( - self._transfer_map.repeat((*shape, 1, 1)), - length=self.length.repeat(shape), - name=self.name, - device=self._transfer_map.device, - dtype=self._transfer_map.dtype, - ) - @property def is_skippable(self) -> bool: return True diff --git a/cheetah/accelerator/dipole.py b/cheetah/accelerator/dipole.py index 573abf26..d6aab27c 100644 --- a/cheetah/accelerator/dipole.py +++ b/cheetah/accelerator/dipole.py @@ -5,12 +5,11 @@ import torch from matplotlib.patches import Rectangle from scipy.constants import physical_constants -from torch import Size, nn - -from cheetah.particles import Beam, ParticleBeam -from cheetah.track_methods import base_rmatrix, rotation_matrix -from cheetah.utils import UniqueNameGenerator, bmadx +from torch import nn +from ..particles import Beam, ParticleBeam +from ..track_methods import base_rmatrix, rotation_matrix +from ..utils import UniqueNameGenerator, bmadx from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -147,11 +146,7 @@ def __init__( @property def hx(self) -> torch.Tensor: - value = torch.zeros_like(self.length) - value[self.length != 0] = ( - self.angle[self.length != 0] / self.length[self.length != 0] - ) - return value + return torch.where(self.length == 0.0, 0.0, self.angle / self.length) @property def is_skippable(self) -> bool: @@ -189,6 +184,11 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: `ParticleBeam`. :return: Beam exiting the element. """ + # TODO: The renaming of the compinents of `incoming` to just the component name + # makes things hard to read. The resuse and overwriting of those component names + # throughout the function makes it even hard, is bad practice and should really + # be fixed! + # Compute Bmad coordinates and p0c x = incoming.x px = incoming.px @@ -224,9 +224,14 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: z, pz, p0c, electron_mass_eV ) + # Broadcast to align their shapes so that they can be stacked + x, px, y, py, tau, delta = torch.broadcast_tensors(x, px, y, py, tau, delta) + outgoing_beam = ParticleBeam( - torch.stack((x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1), - ref_energy, + particles=torch.stack( + (x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1 + ), + energy=ref_energy, particle_charges=incoming.particle_charges, device=incoming.particles.device, dtype=incoming.particles.dtype, @@ -260,48 +265,70 @@ def _bmadx_body( px_norm = torch.sqrt((1 + pz) ** 2 - py**2) # For simplicity phi1 = torch.arcsin(px / px_norm) g = self.angle / self.length - gp = g / px_norm + gp = g.unsqueeze(-1) / px_norm alpha = ( 2 - * (1 + g * x) - * torch.sin(self.angle + phi1) - * self.length - * bmadx.sinc(self.angle) - - gp * ((1 + g * x) * self.length * bmadx.sinc(self.angle)) ** 2 + * (1 + g.unsqueeze(-1) * x) + * torch.sin(self.angle.unsqueeze(-1) + phi1) + * self.length.unsqueeze(-1) + * bmadx.sinc(self.angle).unsqueeze(-1) + - gp + * ( + (1 + g.unsqueeze(-1) * x) + * self.length.unsqueeze(-1) + * bmadx.sinc(self.angle).unsqueeze(-1) + ) + ** 2 ) - x2_t1 = x * torch.cos(self.angle) + self.length**2 * g * bmadx.cosc(self.angle) + x2_t1 = x * torch.cos(self.angle.unsqueeze(-1)) + self.length.unsqueeze( + -1 + ) ** 2 * g.unsqueeze(-1) * bmadx.cosc(self.angle.unsqueeze(-1)) - x2_t2 = torch.sqrt((torch.cos(self.angle + phi1) ** 2) + gp * alpha) - x2_t3 = torch.cos(self.angle + phi1) + x2_t2 = torch.sqrt( + (torch.cos(self.angle.unsqueeze(-1) + phi1) ** 2) + gp * alpha + ) + x2_t3 = torch.cos(self.angle.unsqueeze(-1) + phi1) c1 = x2_t1 + alpha / (x2_t2 + x2_t3) c2 = x2_t1 + (x2_t2 - x2_t3) / gp - temp = torch.abs(self.angle + phi1) + temp = torch.abs(self.angle.unsqueeze(-1) + phi1) x2 = c1 * (temp < torch.pi / 2) + c2 * (temp >= torch.pi / 2) Lcu = ( - x2 - self.length**2 * g * bmadx.cosc(self.angle) - x * torch.cos(self.angle) + x2 + - self.length.unsqueeze(-1) ** 2 + * g.unsqueeze(-1) + * bmadx.cosc(self.angle.unsqueeze(-1)) + - x * torch.cos(self.angle.unsqueeze(-1)) ) - Lcv = -self.length * bmadx.sinc(self.angle) - x * torch.sin(self.angle) + Lcv = -self.length.unsqueeze(-1) * bmadx.sinc( + self.angle.unsqueeze(-1) + ) - x * torch.sin(self.angle.unsqueeze(-1)) - theta_p = 2 * (self.angle + phi1 - torch.pi / 2 - torch.arctan2(Lcv, Lcu)) + theta_p = 2 * ( + self.angle.unsqueeze(-1) + phi1 - torch.pi / 2 - torch.arctan2(Lcv, Lcu) + ) Lc = torch.sqrt(Lcu**2 + Lcv**2) Lp = Lc / bmadx.sinc(theta_p / 2) - P = p0c * (1 + pz) # In eV + P = p0c.unsqueeze(-1) * (1 + pz) # In eV E = torch.sqrt(P**2 + mc2**2) # In eV E0 = torch.sqrt(p0c**2 + mc2**2) # In eV beta = P / E beta0 = p0c / E0 x_f = x2 - px_f = px_norm * torch.sin(self.angle + phi1 - theta_p) + px_f = px_norm * torch.sin(self.angle.unsqueeze(-1) + phi1 - theta_p) y_f = y + py * Lp / px_norm - z_f = z + (beta * self.length / beta0) - ((1 + pz) * Lp / px_norm) + z_f = ( + z + + (beta * self.length.unsqueeze(-1) / beta0.unsqueeze(-1)) + - ((1 + pz) * Lp / px_norm) + ) return x_f, px_f, y_f, py, z_f, pz @@ -336,8 +363,8 @@ def _bmadx_fringe_linear( hy = -g * torch.tan( e - 2 * f_int * h_gap * g * (1 + torch.sin(e) ** 2) / torch.cos(e) ) - px_f = px + x * hx - py_f = py + y * hy + px_f = px + x * hx.unsqueeze(-1) + py_f = py + y * hy.unsqueeze(-1) return px_f, py_f @@ -412,22 +439,6 @@ def _transfer_map_exit(self) -> torch.Tensor: return tm - def broadcast(self, shape: Size) -> Element: - return self.__class__( - length=self.length.repeat(shape), - angle=self.angle.repeat(shape), - k1=self.k1.repeat(shape), - e1=self.e1.repeat(shape), - e2=self.e2.repeat(shape), - tilt=self.tilt.repeat(shape), - fringe_integral=self.fringe_integral.repeat(shape), - fringe_integral_exit=self.fringe_integral_exit.repeat(shape), - gap=self.gap.repeat(shape), - name=self.name, - device=self.length.device, - dtype=self.length.dtype, - ) - def split(self, resolution: torch.Tensor) -> list[Element]: # TODO: Implement splitting for dipole properly, for now just returns the # element itself diff --git a/cheetah/accelerator/drift.py b/cheetah/accelerator/drift.py index 86049373..4438c376 100644 --- a/cheetah/accelerator/drift.py +++ b/cheetah/accelerator/drift.py @@ -3,11 +3,10 @@ import matplotlib.pyplot as plt import torch from scipy.constants import physical_constants -from torch import Size, nn - -from cheetah.particles import Beam, ParticleBeam -from cheetah.utils import UniqueNameGenerator, bmadx +from torch import nn +from ..particles import Beam, ParticleBeam +from ..utils import UniqueNameGenerator, bmadx, compute_relativistic_factors from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -42,19 +41,14 @@ def __init__( self.tracking_method = tracking_method def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: - assert ( - energy.shape == self.length.shape - ), f"Beam shape {energy.shape} does not match element shape {self.length.shape}" - device = self.length.device dtype = self.length.dtype - gamma = energy / electron_mass_eV - igamma2 = torch.zeros_like(gamma) # TODO: Effect on gradients? - igamma2[gamma != 0] = 1 / gamma[gamma != 0] ** 2 - beta = torch.sqrt(1 - igamma2) + _, igamma2, beta = compute_relativistic_factors(energy) + + vector_shape = torch.broadcast_shapes(self.length.shape, igamma2.shape) - tm = torch.eye(7, device=device, dtype=dtype).repeat((*self.length.shape, 1, 1)) + tm = torch.eye(7, device=device, dtype=dtype).repeat((*vector_shape, 1, 1)) tm[..., 0, 1] = self.length tm[..., 2, 3] = self.length tm[..., 4, 5] = -self.length / beta**2 * igamma2 @@ -112,24 +106,20 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: z, pz, p0c, electron_mass_eV ) + # Broadcast to align their shapes so that they can be stacked + x, px, y, py, tau, delta = torch.broadcast_tensors(x, px, y, py, tau, delta) + outgoing_beam = ParticleBeam( - torch.stack((x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1), - ref_energy, + particles=torch.stack( + [x, px, y, py, tau, delta, torch.ones_like(x)], dim=-1 + ), + energy=ref_energy, particle_charges=incoming.particle_charges, device=incoming.particles.device, dtype=incoming.particles.dtype, ) return outgoing_beam - def broadcast(self, shape: Size) -> Element: - return self.__class__( - length=self.length.repeat(shape), - tracking_method=self.tracking_method, - name=self.name, - device=self.length.device, - dtype=self.length.dtype, - ) - @property def is_skippable(self) -> bool: return self.tracking_method == "cheetah" diff --git a/cheetah/accelerator/element.py b/cheetah/accelerator/element.py index 7ef4cd68..6cbf433a 100644 --- a/cheetah/accelerator/element.py +++ b/cheetah/accelerator/element.py @@ -5,8 +5,8 @@ import torch from torch import nn -from cheetah.particles import Beam, ParameterBeam, ParticleBeam -from cheetah.utils import UniqueNameGenerator +from ..particles import Beam, ParameterBeam, ParticleBeam +from ..utils import UniqueNameGenerator generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -22,7 +22,7 @@ def __init__(self, name: Optional[str] = None) -> None: super().__init__() self.name = name if name is not None else generate_unique_name() - self.register_buffer("length", torch.zeros((1,))) + self.register_buffer("length", torch.tensor(0.0)) def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: r""" @@ -91,10 +91,6 @@ def forward(self, incoming: Beam) -> Beam: """Forward function required by `torch.nn.Module`. Simply calls `track`.""" return self.track(incoming) - def broadcast(self, shape: torch.Size) -> "Element": - """Broadcast the element to higher batch dimensions.""" - raise NotImplementedError - @property @abstractmethod def is_skippable(self) -> bool: diff --git a/cheetah/accelerator/horizontal_corrector.py b/cheetah/accelerator/horizontal_corrector.py index cbb2bf14..e17837d6 100644 --- a/cheetah/accelerator/horizontal_corrector.py +++ b/cheetah/accelerator/horizontal_corrector.py @@ -4,17 +4,13 @@ import numpy as np import torch from matplotlib.patches import Rectangle -from scipy.constants import physical_constants -from torch import Size, nn - -from cheetah.utils import UniqueNameGenerator +from torch import nn +from ..utils import UniqueNameGenerator, compute_relativistic_factors from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") -electron_mass_eV = physical_constants["electron mass energy equivalent in MeV"][0] * 1e6 - class HorizontalCorrector(Element): """ @@ -52,12 +48,11 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: device = self.length.device dtype = self.length.dtype - gamma = energy / electron_mass_eV - igamma2 = torch.zeros_like(gamma) # TODO: Effect on gradients? - igamma2[gamma != 0] = 1 / gamma[gamma != 0] ** 2 - beta = torch.sqrt(1 - igamma2) + _, igamma2, beta = compute_relativistic_factors(energy) + + vector_shape = torch.broadcast_shapes(self.length.shape, igamma2.shape) - tm = torch.eye(7, device=device, dtype=dtype).repeat((*self.length.shape, 1, 1)) + tm = torch.eye(7, device=device, dtype=dtype).repeat((*vector_shape, 1, 1)) tm[..., 0, 1] = self.length tm[..., 1, 6] = self.angle tm[..., 2, 3] = self.length @@ -65,22 +60,13 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: return tm - def broadcast(self, shape: Size) -> Element: - return self.__class__( - length=self.length.repeat(shape), - angle=self.angle, - name=self.name, - device=self.length.device, - dtype=self.length.dtype, - ) - @property def is_skippable(self) -> bool: return True @property def is_active(self) -> bool: - return any(self.angle != 0) + return torch.any(self.angle != 0) def split(self, resolution: torch.Tensor) -> list[Element]: num_splits = torch.ceil(torch.max(self.length) / resolution).int() diff --git a/cheetah/accelerator/marker.py b/cheetah/accelerator/marker.py index 605c81df..643d4f46 100644 --- a/cheetah/accelerator/marker.py +++ b/cheetah/accelerator/marker.py @@ -2,11 +2,9 @@ import matplotlib.pyplot as plt import torch -from torch import Size - -from cheetah.particles import Beam -from cheetah.utils import UniqueNameGenerator +from ..particles import Beam +from ..utils import UniqueNameGenerator from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -32,11 +30,6 @@ def track(self, incoming: Beam) -> Beam: # Markers would be able to record the beam tracked through them. return incoming - def broadcast(self, shape: Size) -> Element: - new_marker = self.__class__(name=self.name) - new_marker.length = self.length.repeat(shape) - return new_marker - @property def is_skippable(self) -> bool: return True diff --git a/cheetah/accelerator/quadrupole.py b/cheetah/accelerator/quadrupole.py index a7c5f4cf..4123121c 100644 --- a/cheetah/accelerator/quadrupole.py +++ b/cheetah/accelerator/quadrupole.py @@ -5,12 +5,11 @@ import torch from matplotlib.patches import Rectangle from scipy.constants import physical_constants -from torch import Size, nn - -from cheetah.particles import Beam, ParticleBeam -from cheetah.track_methods import base_rmatrix, misalignment_matrix -from cheetah.utils import UniqueNameGenerator, bmadx +from torch import nn +from ..particles import Beam, ParticleBeam +from ..track_methods import base_rmatrix, misalignment_matrix +from ..utils import UniqueNameGenerator, bmadx from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -175,6 +174,9 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: x, px, y, py = bmadx.offset_particle_unset( x_offset, y_offset, self.tilt, x, px, y, py ) + + # pz is unaffected by tracking, therefore needs to match vector dimensions + pz = pz * torch.ones_like(x) # End of Bmad-X tracking # Convert back to Cheetah coordinates @@ -191,25 +193,13 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: ) return outgoing_beam - def broadcast(self, shape: Size) -> Element: - return self.__class__( - length=self.length.repeat(shape), - k1=self.k1.repeat(shape), - misalignment=self.misalignment.repeat((*shape, 1)), - tilt=self.tilt.repeat(shape), - tracking_method=self.tracking_method, - name=self.name, - device=self.length.device, - dtype=self.length.dtype, - ) - @property def is_skippable(self) -> bool: return self.tracking_method == "cheetah" @property def is_active(self) -> bool: - return any(self.k1 != 0) + return torch.any(self.k1 != 0) def split(self, resolution: torch.Tensor) -> list[Element]: num_splits = torch.ceil(torch.max(self.length) / resolution).int() diff --git a/cheetah/accelerator/rbend.py b/cheetah/accelerator/rbend.py index e7c21586..9bb20b26 100644 --- a/cheetah/accelerator/rbend.py +++ b/cheetah/accelerator/rbend.py @@ -3,8 +3,7 @@ import torch from torch import nn -from cheetah.utils import UniqueNameGenerator - +from ..utils import UniqueNameGenerator from .dipole import Dipole generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") diff --git a/cheetah/accelerator/screen.py b/cheetah/accelerator/screen.py index 7227cdcc..81a946eb 100644 --- a/cheetah/accelerator/screen.py +++ b/cheetah/accelerator/screen.py @@ -4,12 +4,11 @@ import matplotlib.pyplot as plt import torch from matplotlib.patches import Rectangle -from torch import Size, nn +from torch import nn from torch.distributions import MultivariateNormal -from cheetah.particles import Beam, ParameterBeam, ParticleBeam -from cheetah.utils import UniqueNameGenerator, kde_histogram_2d - +from ..particles import Beam, ParameterBeam, ParticleBeam +from ..utils import UniqueNameGenerator, kde_histogram_2d from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -26,15 +25,20 @@ class Screen(Element): :param binning: Binning used by the camera. :param misalignment: Misalignment of the screen in meters given as a Tensor `(x, y)`. + :param method: Method used to generate the screen's reading. Can be either + "histogram" or "kde", defaults to "histogram". KDE will be slower but allows + backward differentiation. :param kde_bandwidth: Bandwidth used for the kernel density estimation in meters. Controls the smoothness of the distribution. + :param is_blocking: If `True` the screen is blocking and will stop the beam. :param is_active: If `True` the screen is active and will record the beam's distribution. If `False` the screen is inactive and will not record the beam's distribution. - :param method: Method used to generate the screen's reading. Can be either - "histogram" or "kde", defaults to "histogram". KDE will be slower but allows - backward differentiation. :param name: Unique identifier of the element. + + NOTE: `method='histogram'` currently does not support vectorisation. Please use + `method=`kde` instead. Similarly, `ParameterBeam` can also not be vectorised. + Please use `ParticleBeam` instead. """ def __init__( @@ -43,9 +47,10 @@ def __init__( pixel_size: Optional[Union[torch.Tensor, nn.Parameter]] = None, binning: Optional[Union[torch.Tensor, nn.Parameter]] = None, misalignment: Optional[Union[torch.Tensor, nn.Parameter]] = None, + method: Literal["histogram", "kde"] = "histogram", kde_bandwidth: Optional[Union[torch.Tensor, nn.Parameter]] = None, + is_blocking: bool = False, is_active: bool = False, - method: Literal["histogram", "kde"] = "histogram", name: Optional[str] = None, device=None, dtype=torch.float32, @@ -82,7 +87,7 @@ def __init__( ( torch.as_tensor(misalignment, **factory_kwargs) if misalignment is not None - else torch.tensor([(0.0, 0.0)], **factory_kwargs) + else torch.tensor((0.0, 0.0), **factory_kwargs) ), ) self.register_buffer( @@ -102,6 +107,7 @@ def __init__( else torch.clone(self.pixel_size[0]) ), ) + self.is_blocking = is_blocking self.is_active = is_active self.set_read_beam(None) @@ -163,33 +169,50 @@ def track(self, incoming: Beam) -> Beam: copy_of_incoming = deepcopy(incoming) if isinstance(incoming, ParameterBeam): + copy_of_incoming._mu, _ = torch.broadcast_tensors( + copy_of_incoming._mu, self.misalignment[..., 0] + ) + copy_of_incoming._mu = copy_of_incoming._mu.clone() + copy_of_incoming._mu[..., 0] -= self.misalignment[..., 0] copy_of_incoming._mu[..., 2] -= self.misalignment[..., 1] elif isinstance(incoming, ParticleBeam): - copy_of_incoming.particles[..., :, 0] -= self.misalignment[..., 0] - copy_of_incoming.particles[..., :, 1] -= self.misalignment[..., 1] + copy_of_incoming.particles, _ = torch.broadcast_tensors( + copy_of_incoming.particles, + self.misalignment[..., 0].unsqueeze(-1).unsqueeze(-1), + ) + copy_of_incoming.particles = copy_of_incoming.particles.clone() + + copy_of_incoming.particles[..., 0] -= self.misalignment[ + ..., 0 + ].unsqueeze(-1) + copy_of_incoming.particles[..., 1] -= self.misalignment[ + ..., 1 + ].unsqueeze(-1) self.set_read_beam(copy_of_incoming) - return Beam.empty - else: - return incoming + return Beam.empty if self.is_blocking else incoming @property def reading(self) -> torch.Tensor: + image = None if self.cached_reading is not None: return self.cached_reading read_beam = self.get_read_beam() if read_beam is Beam.empty or read_beam is None: image = torch.zeros( - ( - *self.misalignment.shape[:-1], - int(self.effective_resolution[1]), - int(self.effective_resolution[0]), - ) + (int(self.effective_resolution[1]), int(self.effective_resolution[0])) ) elif isinstance(read_beam, ParameterBeam): + if torch.numel(read_beam._mu[..., 0]) > 1: + raise NotImplementedError( + "`Screen` does not support vectorization of `ParameterBeam`. " + "Please use `ParticleBeam` instead. If this is a feature you would " + "like to see, please open an issue on GitHub." + ) + transverse_mu = torch.stack( [read_beam._mu[..., 0], read_beam._mu[..., 2]], dim=-1 ) @@ -204,14 +227,9 @@ def reading(self) -> torch.Tensor: ], dim=-1, ) - dist = [ - MultivariateNormal( - loc=transverse_mu_sample, covariance_matrix=transverse_cov_sample - ) - for transverse_mu_sample, transverse_cov_sample in zip( - transverse_mu.cpu(), transverse_cov.cpu() - ) - ] + dist = MultivariateNormal( + loc=transverse_mu, covariance_matrix=transverse_cov + ) left = self.extent[0] right = self.extent[1] @@ -225,30 +243,26 @@ def reading(self) -> torch.Tensor: indexing="ij", ) pos = torch.dstack((x, y)) - image = torch.stack( - [dist_sample.log_prob(pos).exp() for dist_sample in dist] - ) + image = dist.log_prob(pos).exp() image = torch.flip(image, dims=[1]) elif isinstance(read_beam, ParticleBeam): - if self.method == "histogram": - image = torch.zeros( - ( - *self.misalignment.shape[:-1], - int(self.effective_resolution[1]), - int(self.effective_resolution[0]), - ) - ) - - for i, (x_sample, y_sample) in enumerate(zip(read_beam.x, read_beam.y)): - image_sample, _ = torch.histogramdd( - torch.stack((x_sample, y_sample)).T.cpu(), - bins=self.pixel_bin_edges, + # Catch vectorisation, which is currently not supported by "histogram" + if ( + len(read_beam.particles.shape) > 2 + or len(read_beam.particle_charges.shape) > 1 + or len(read_beam.energy.shape) > 0 + ): + raise NotImplementedError( + "The `'histogram'` method of `Screen` does not support " + "vectorization. Use `'kde'` instead. If this is a feature you " + "would like to see, please open an issue on GitHub." ) - image_sample = torch.flipud(image_sample.T) - image_sample = image_sample.cpu() - image[i] = image_sample + image, _ = torch.histogramdd( + torch.stack((read_beam.x, read_beam.y)).T, bins=self.pixel_bin_edges + ) + image = torch.flipud(image.T) elif self.method == "kde": image = kde_histogram_2d( x1=read_beam.x, @@ -259,7 +273,7 @@ def reading(self) -> torch.Tensor: ) # Change the x, y positions image = torch.transpose(image, -2, -1) - # Flip up an down, now row 0 corresponds to the top + # Flip up and down, now row 0 corresponds to the top image = torch.flip(image, dims=[-2]) else: raise TypeError(f"Read beam is of invalid type {type(read_beam)}") @@ -271,29 +285,15 @@ def get_read_beam(self) -> Beam: # Using these get and set methods instead of Python's property decorator to # prevent `nn.Module` from intercepting the read beam, which is itself an # `nn.Module`, and registering it as a submodule of the screen. - return self._read_beam[0] if self._read_beam is not None else None + return self._read_beam def set_read_beam(self, value: Beam) -> None: # Using these get and set methods instead of Python's property decorator to # prevent `nn.Module` from intercepting the read beam, which is itself an # `nn.Module`, and registering it as a submodule of the screen. - self._read_beam = [value] + self._read_beam = value self.cached_reading = None - def broadcast(self, shape: Size) -> Element: - new_screen = self.__class__( - resolution=self.resolution, - pixel_size=self.pixel_size, - binning=self.binning, - misalignment=self.misalignment.repeat((*shape, 1)), - is_active=self.is_active, - name=self.name, - device=self.resolution.device, - dtype=self.resolution.dtype, - ) - new_screen.length = self.length.repeat(shape) - return new_screen - def split(self, resolution: torch.Tensor) -> list[Element]: return [self] diff --git a/cheetah/accelerator/segment.py b/cheetah/accelerator/segment.py index d362a4b1..c2aaf65d 100644 --- a/cheetah/accelerator/segment.py +++ b/cheetah/accelerator/segment.py @@ -1,17 +1,17 @@ from copy import deepcopy +from functools import reduce from pathlib import Path from typing import Any, Optional, Union import matplotlib import matplotlib.pyplot as plt import torch -from torch import Size, nn - -from cheetah.converters import bmad, elegant, nxtables -from cheetah.latticejson import load_cheetah_model, save_cheetah_model -from cheetah.particles import Beam, ParticleBeam -from cheetah.utils import UniqueNameGenerator +from torch import nn +from ..converters import bmad, elegant, nxtables +from ..latticejson import load_cheetah_model, save_cheetah_model +from ..particles import Beam, ParticleBeam +from ..utils import UniqueNameGenerator from .custom_transfer_map import CustomTransferMap from .drift import Drift from .element import Element @@ -348,17 +348,12 @@ def is_skippable(self) -> bool: @property def length(self) -> torch.Tensor: - lengths = torch.stack( - [element.length for element in self.elements], - dim=1, - ) - return torch.sum(lengths, dim=1) + lengths = [element.length for element in self.elements] + return reduce(torch.add, lengths) def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: if self.is_skippable: - tm = torch.eye(7, device=energy.device, dtype=energy.dtype).repeat( - (*self.length.shape, 1, 1) - ) + tm = torch.eye(7, device=energy.device, dtype=energy.dtype) for element in self.elements: tm = torch.matmul(element.transfer_map(energy), tm) return tm @@ -383,12 +378,6 @@ def track(self, incoming: Beam) -> Beam: return incoming - def broadcast(self, shape: Size) -> Element: - return self.__class__( - elements=[element.broadcast(shape) for element in self.elements], - name=self.name, - ) - def split(self, resolution: torch.Tensor) -> list[Element]: return [ split_element diff --git a/cheetah/accelerator/solenoid.py b/cheetah/accelerator/solenoid.py index 4d2c8770..2d89e208 100644 --- a/cheetah/accelerator/solenoid.py +++ b/cheetah/accelerator/solenoid.py @@ -4,11 +4,10 @@ import torch from matplotlib.patches import Rectangle from scipy.constants import physical_constants -from torch import Size, nn - -from cheetah.track_methods import misalignment_matrix -from cheetah.utils import UniqueNameGenerator +from torch import nn +from ..track_methods import misalignment_matrix +from ..utils import UniqueNameGenerator, compute_relativistic_factors from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -64,19 +63,21 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: device = self.length.device dtype = self.length.dtype - gamma = energy / electron_mass_eV + gamma, _, _ = compute_relativistic_factors(energy) c = torch.cos(self.length * self.k) s = torch.sin(self.length * self.k) - s_k = torch.empty_like(self.length) - s_k[self.k == 0] = self.length[self.k == 0] - s_k[self.k != 0] = s[self.k != 0] / self.k[self.k != 0] + s_k = torch.where(self.k == 0.0, self.length, s / self.k) + + vector_shape = torch.broadcast_shapes( + self.length.shape, self.k.shape, energy.shape + ) r56 = torch.where( gamma != 0, self.length / (1 - gamma**2), torch.zeros_like(self.length) ) - R = torch.eye(7, device=device, dtype=dtype).repeat((*self.length.shape, 1, 1)) + R = torch.eye(7, device=device, dtype=dtype).repeat((*vector_shape, 1, 1)) R[..., 0, 0] = c**2 R[..., 0, 1] = c * s_k R[..., 0, 2] = s * c @@ -104,19 +105,9 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: R = torch.einsum("...ij,...jk,...kl->...il", R_exit, R, R_entry) return R - def broadcast(self, shape: Size) -> Element: - return self.__class__( - length=self.length.repeat(shape), - k=self.k.repeat(shape), - misalignment=self.misalignment.repeat(shape), - name=self.name, - device=self.length.device, - dtype=self.length.dtype, - ) - @property def is_active(self) -> bool: - return any(self.k != 0) + return torch.any(self.k != 0) def is_skippable(self) -> bool: return True diff --git a/cheetah/accelerator/space_charge_kick.py b/cheetah/accelerator/space_charge_kick.py index 2ab6628f..51c97465 100644 --- a/cheetah/accelerator/space_charge_kick.py +++ b/cheetah/accelerator/space_charge_kick.py @@ -5,8 +5,7 @@ from scipy.constants import elementary_charge, epsilon_0, speed_of_light from torch import nn -from cheetah.particles import Beam, ParticleBeam - +from ..particles import Beam, ParticleBeam from .element import Element @@ -441,8 +440,8 @@ def _compute_forces( ) -> torch.Tensor: """ Interpolates the space charge force from the grid onto the macroparticles. - Reciprocal function of _deposit_charge_on_grid. - Beam needs to have a flattened batch shape. + Reciprocal function of _deposit_charge_on_grid. `beam` needs to have a flattened + vector shape. """ grad_x, grad_y, grad_z = self._E_plus_vB_field( beam, xp_coordinates, cell_size, grid_dimensions @@ -506,9 +505,9 @@ def _compute_forces( # Keep dimensions, and set F to zero if non-valid force_indices = ( idx_vector, - torch.clamp(idx_x, max=grid_shape[0] - 1), - torch.clamp(idx_y, max=grid_shape[1] - 1), - torch.clamp(idx_tau, max=grid_shape[2] - 1), + torch.clamp(idx_x, min=0, max=grid_shape[0] - 1), + torch.clamp(idx_y, min=0, max=grid_shape[1] - 1), + torch.clamp(idx_tau, min=0, max=grid_shape[2] - 1), ) Fx_values = torch.where(valid_mask, grad_x[force_indices], 0) @@ -553,12 +552,31 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: elif isinstance(incoming, ParticleBeam): # This flattening is a hack to only think about one vector dimension in the # following code. It is reversed at the end of the function. + + # Make sure that the incoming beam has at least one vector dimension + if len(incoming.particles.shape) == 2: + is_incoming_vectorized = False + + vectorized_incoming = ParticleBeam( + particles=incoming.particles.unsqueeze(0), + energy=incoming.energy.unsqueeze(0), + particle_charges=incoming.particle_charges.unsqueeze(0), + device=incoming.particles.device, + dtype=incoming.particles.dtype, + ) + else: + is_incoming_vectorized = True + + vectorized_incoming = incoming + flattened_incoming = ParticleBeam( - particles=incoming.particles.flatten(end_dim=-3), - energy=incoming.energy.flatten(end_dim=-1), - particle_charges=incoming.particle_charges.flatten(end_dim=-2), - device=incoming.particles.device, - dtype=incoming.particles.dtype, + particles=vectorized_incoming.particles.flatten(end_dim=-3), + energy=vectorized_incoming.energy.flatten(end_dim=-1), + particle_charges=vectorized_incoming.particle_charges.flatten( + end_dim=-2 + ), + device=vectorized_incoming.particles.device, + dtype=vectorized_incoming.particles.dtype, ) flattened_length_effect = self.effect_length.flatten(end_dim=-1) @@ -591,40 +609,30 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: ..., 2 ] * dt.unsqueeze(-1) - outgoing = ParticleBeam.from_xyz_pxpypz( - xp_coordinates.unflatten(dim=0, sizes=incoming.particles.shape[:-2]), - incoming.energy, - incoming.particle_charges, - incoming.particles.device, - incoming.particles.dtype, - ) - + if not is_incoming_vectorized: + # Reshape to the original non-vectorised shape + outgoing = ParticleBeam.from_xyz_pxpypz( + xp_coordinates.squeeze(0), + vectorized_incoming.energy.squeeze(0), + vectorized_incoming.particle_charges.squeeze(0), + vectorized_incoming.particles.device, + vectorized_incoming.particles.dtype, + ) + else: + # Reverse the flattening of the vector dimensions + outgoing = ParticleBeam.from_xyz_pxpypz( + xp_coordinates.unflatten( + dim=0, sizes=vectorized_incoming.particles.shape[:-2] + ), + vectorized_incoming.energy, + vectorized_incoming.particle_charges, + vectorized_incoming.particles.device, + vectorized_incoming.particles.dtype, + ) return outgoing else: raise TypeError(f"Parameter incoming is of invalid type {type(incoming)}") - def broadcast(self, shape: torch.Size) -> "SpaceChargeKick": - """ - Broadcast the element to higher batch dimensions. - - :param shape: Shape to broadcast the element to. - :returns: Broadcasted element. - """ - new_space_charge_kick = self.__class__( - effect_length=self.effect_length, - num_grid_points_x=self.grid_shape[0], - num_grid_points_y=self.grid_shape[1], - num_grid_points_tau=self.grid_shape[2], - grid_extend_x=self.grid_extend_x, - grid_extend_y=self.grid_extend_y, - grid_extend_tau=self.grid_extend_tau, - name=self.name, - device=self.effect_length.device, - dtype=self.effect_length.dtype, - ) - new_space_charge_kick.length = self.length.repeat(shape) - return new_space_charge_kick - def split(self, resolution: torch.Tensor) -> list[Element]: # TODO: Implement splitting for SpaceCharge properly, for now just returns the # element itself diff --git a/cheetah/accelerator/transverse_deflecting_cavity.py b/cheetah/accelerator/transverse_deflecting_cavity.py index a4542b2c..5ac9d6b3 100644 --- a/cheetah/accelerator/transverse_deflecting_cavity.py +++ b/cheetah/accelerator/transverse_deflecting_cavity.py @@ -4,7 +4,7 @@ import torch from matplotlib.patches import Rectangle from scipy.constants import physical_constants, speed_of_light -from torch import Size, nn +from torch import nn from cheetah.particles import Beam, ParticleBeam from cheetah.utils import UniqueNameGenerator, bmadx @@ -170,16 +170,24 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: ) ) - px = px + voltage * torch.sin(phase) + # TODO: Assigning px to px is really bad practice and should be separated into + # two separate variables + px = px + voltage.unsqueeze(-1) * torch.sin(phase) - beta = (1 + pz) * p0c / torch.sqrt(((1 + pz) * p0c) ** 2 + electron_mass_eV**2) + beta = ( + (1 + pz) + * p0c.unsqueeze(-1) + / torch.sqrt(((1 + pz) * p0c.unsqueeze(-1)) ** 2 + electron_mass_eV**2) + ) beta_old = beta - E_old = (1 + pz) * p0c / beta_old - E_new = E_old + voltage * torch.cos(phase) * k_rf * x * p0c + E_old = (1 + pz) * p0c.unsqueeze(-1) / beta_old + E_new = E_old + voltage.unsqueeze(-1) * torch.cos( + phase + ) * k_rf * x * p0c.unsqueeze(-1) pc = torch.sqrt(E_new**2 - electron_mass_eV**2) beta = pc / E_new - pz = (pc - p0c) / p0c + pz = (pc - p0c.unsqueeze(-1)) / p0c.unsqueeze(-1) z = z * beta / beta_old x, y, z = bmadx.track_a_drift( @@ -205,20 +213,6 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: ) return outgoing_beam - def broadcast(self, shape: Size) -> Element: - return self.__class__( - length=self.length.repeat(shape), - voltage=self.voltage.repeat(shape), - phase=self.phase.repeat(shape), - frequency=self.frequency.repeat(shape), - misalignment=self.misalignment.repeat((*shape, 1)), - tilt=self.tilt.repeat(shape), - tracking_method=self.tracking_method, - name=self.name, - device=self.length.device, - dtype=self.length.dtype, - ) - def split(self, resolution: torch.Tensor) -> list[Element]: # TODO: Implement splitting for cavity properly, for now just returns the # element itself diff --git a/cheetah/accelerator/undulator.py b/cheetah/accelerator/undulator.py index dd6a6256..c4c72e2c 100644 --- a/cheetah/accelerator/undulator.py +++ b/cheetah/accelerator/undulator.py @@ -4,10 +4,9 @@ import torch from matplotlib.patches import Rectangle from scipy.constants import physical_constants -from torch import Size, nn - -from cheetah.utils import UniqueNameGenerator +from torch import nn +from ..utils import UniqueNameGenerator from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -48,21 +47,15 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: gamma = energy / electron_mass_eV igamma2 = torch.where(gamma != 0, 1 / gamma**2, torch.zeros_like(gamma)) - tm = torch.eye(7, device=device, dtype=dtype).repeat((*energy.shape, 1, 1)) + vector_shape = torch.broadcast_shapes(self.length.shape, igamma2.shape) + + tm = torch.eye(7, device=device, dtype=dtype).repeat((*vector_shape, 1, 1)) tm[..., 0, 1] = self.length tm[..., 2, 3] = self.length tm[..., 4, 5] = self.length * igamma2 return tm - def broadcast(self, shape: Size) -> Element: - return self.__class__( - length=self.length.repeat(shape), - is_active=self.is_active, - name=self.name, - device=self.length.device, - ) - @property def is_skippable(self) -> bool: return True diff --git a/cheetah/accelerator/vertical_corrector.py b/cheetah/accelerator/vertical_corrector.py index 9ce717ec..bd78e367 100644 --- a/cheetah/accelerator/vertical_corrector.py +++ b/cheetah/accelerator/vertical_corrector.py @@ -5,10 +5,9 @@ import torch from matplotlib.patches import Rectangle from scipy.constants import physical_constants -from torch import Size, nn - -from cheetah.utils import UniqueNameGenerator +from torch import nn +from ..utils import UniqueNameGenerator, compute_relativistic_factors from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -52,26 +51,17 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: device = self.length.device dtype = self.length.dtype - gamma = energy / electron_mass_eV - igamma2 = torch.zeros_like(gamma) # TODO: Effect on gradients? - igamma2[gamma != 0] = 1 / gamma[gamma != 0] ** 2 - beta = torch.sqrt(1 - igamma2) + _, igamma2, beta = compute_relativistic_factors(energy) + + vector_shape = torch.broadcast_shapes(self.length.shape, igamma2.shape) - tm = torch.eye(7, device=device, dtype=dtype).repeat((*self.length.shape, 1, 1)) + tm = torch.eye(7, device=device, dtype=dtype).repeat((*vector_shape, 1, 1)) tm[..., 0, 1] = self.length tm[..., 2, 3] = self.length tm[..., 3, 6] = self.angle tm[..., 4, 5] = -self.length / beta**2 * igamma2 - return tm - def broadcast(self, shape: Size) -> Element: - return self.__class__( - length=self.length.repeat(shape), - angle=self.angle, - name=self.name, - device=self.length.device, - dtype=self.length.dtype, - ) + return tm @property def is_skippable(self) -> bool: @@ -79,7 +69,7 @@ def is_skippable(self) -> bool: @property def is_active(self) -> bool: - return any(self.angle != 0) + return torch.any(self.angle != 0) def split(self, resolution: torch.Tensor) -> list[Element]: num_splits = torch.ceil(torch.max(self.length) / resolution).int() diff --git a/cheetah/converters/__init__.py b/cheetah/converters/__init__.py index 154a9097..e64f8749 100644 --- a/cheetah/converters/__init__.py +++ b/cheetah/converters/__init__.py @@ -1,2 +1 @@ -# flake8: noqa -from cheetah.converters import astra, bmad, elegant, nxtables, ocelot +from . import astra, bmad, elegant, nxtables, ocelot # noqa: F401 diff --git a/cheetah/converters/bmad.py b/cheetah/converters/bmad.py index 7bb249a2..4b76804b 100644 --- a/cheetah/converters/bmad.py +++ b/cheetah/converters/bmad.py @@ -61,7 +61,7 @@ def convert_element( ) if "l" in bmad_parsed: return cheetah.Drift( - length=torch.tensor([bmad_parsed["l"]]), + length=torch.tensor(bmad_parsed["l"]), name=name, device=device, dtype=dtype, @@ -74,7 +74,7 @@ def convert_element( ) if "l" in bmad_parsed: return cheetah.Drift( - length=torch.tensor([bmad_parsed["l"]]), + length=torch.tensor(bmad_parsed["l"]), name=name, device=device, dtype=dtype, @@ -86,7 +86,7 @@ def convert_element( ["element_type", "alias", "type", "l", "descrip"], bmad_parsed ) return cheetah.Drift( - length=torch.tensor([bmad_parsed["l"]]), + length=torch.tensor(bmad_parsed["l"]), name=name, device=device, dtype=dtype, @@ -96,7 +96,7 @@ def convert_element( ["element_type", "l", "type", "descrip"], bmad_parsed ) return cheetah.Drift( - length=torch.tensor([bmad_parsed["l"]]), + length=torch.tensor(bmad_parsed["l"]), name=name, device=device, dtype=dtype, @@ -106,8 +106,8 @@ def convert_element( ["element_type", "type", "alias"], bmad_parsed ) return cheetah.HorizontalCorrector( - length=torch.tensor([bmad_parsed.get("l", 0.0)]), - angle=torch.tensor([bmad_parsed.get("kick", 0.0)]), + length=torch.tensor(bmad_parsed.get("l", 0.0)), + angle=torch.tensor(bmad_parsed.get("kick", 0.0)), name=name, device=device, dtype=dtype, @@ -117,8 +117,8 @@ def convert_element( ["element_type", "type", "alias"], bmad_parsed ) return cheetah.VerticalCorrector( - length=torch.tensor([bmad_parsed.get("l", 0.0)]), - angle=torch.tensor([bmad_parsed.get("kick", 0.0)]), + length=torch.tensor(bmad_parsed.get("l", 0.0)), + angle=torch.tensor(bmad_parsed.get("kick", 0.0)), name=name, device=device, dtype=dtype, @@ -144,15 +144,15 @@ def convert_element( bmad_parsed, ) return cheetah.Dipole( - length=torch.tensor([bmad_parsed["l"]]), - gap=torch.tensor([bmad_parsed.get("hgap", 0.0)]), - angle=torch.tensor([bmad_parsed.get("angle", 0.0)]), - e1=torch.tensor([bmad_parsed["e1"]]), - e2=torch.tensor([bmad_parsed.get("e2", 0.0)]), - tilt=torch.tensor([bmad_parsed.get("ref_tilt", 0.0)]), - fringe_integral=torch.tensor([bmad_parsed.get("fint", 0.0)]), + length=torch.tensor(bmad_parsed["l"]), + gap=torch.tensor(bmad_parsed.get("hgap", 0.0)), + angle=torch.tensor(bmad_parsed.get("angle", 0.0)), + e1=torch.tensor(bmad_parsed["e1"]), + e2=torch.tensor(bmad_parsed.get("e2", 0.0)), + tilt=torch.tensor(bmad_parsed.get("ref_tilt", 0.0)), + fringe_integral=torch.tensor(bmad_parsed.get("fint", 0.0)), fringe_integral_exit=( - torch.tensor([bmad_parsed["fintx"]]) + torch.tensor(bmad_parsed["fintx"]) if "fintx" in bmad_parsed else None ), @@ -167,9 +167,9 @@ def convert_element( bmad_parsed, ) return cheetah.Quadrupole( - length=torch.tensor([bmad_parsed["l"]]), - k1=torch.tensor([bmad_parsed["k1"]]), - tilt=torch.tensor([bmad_parsed.get("tilt", 0.0)]), + length=torch.tensor(bmad_parsed["l"]), + k1=torch.tensor(bmad_parsed["k1"]), + tilt=torch.tensor(bmad_parsed.get("tilt", 0.0)), name=name, device=device, dtype=dtype, @@ -179,8 +179,8 @@ def convert_element( ["element_type", "l", "ks", "alias"], bmad_parsed ) return cheetah.Solenoid( - length=torch.tensor([bmad_parsed["l"]]), - k=torch.tensor([bmad_parsed["ks"]]), + length=torch.tensor(bmad_parsed["l"]), + k=torch.tensor(bmad_parsed["ks"]), name=name, device=device, dtype=dtype, @@ -201,12 +201,12 @@ def convert_element( bmad_parsed, ) return cheetah.Cavity( - length=torch.tensor([bmad_parsed["l"]]), - voltage=torch.tensor([bmad_parsed.get("voltage", 0.0)]), + length=torch.tensor(bmad_parsed["l"]), + voltage=torch.tensor(bmad_parsed.get("voltage", 0.0)), phase=torch.tensor( - [-np.degrees(bmad_parsed.get("phi0", 0.0) * 2 * np.pi)] + -np.degrees(bmad_parsed.get("phi0", 0.0) * 2 * np.pi) ), - frequency=torch.tensor([bmad_parsed["rf_frequency"]]), + frequency=torch.tensor(bmad_parsed["rf_frequency"]), name=name, device=device, dtype=dtype, @@ -219,14 +219,14 @@ def convert_element( return cheetah.Segment( elements=[ cheetah.Drift( - length=torch.tensor([bmad_parsed.get("l", 0.0)]), + length=torch.tensor(bmad_parsed.get("l", 0.0)), name=name + "_drift", device=device, dtype=dtype, ), cheetah.Aperture( - x_max=torch.tensor([bmad_parsed.get("x_limit", np.inf)]), - y_max=torch.tensor([bmad_parsed.get("y_limit", np.inf)]), + x_max=torch.tensor(bmad_parsed.get("x_limit", np.inf)), + y_max=torch.tensor(bmad_parsed.get("y_limit", np.inf)), shape="rectangular", name=name + "_aperture", device=device, @@ -243,14 +243,14 @@ def convert_element( return cheetah.Segment( elements=[ cheetah.Drift( - length=torch.tensor([bmad_parsed.get("l", 0.0)]), + length=torch.tensor(bmad_parsed.get("l", 0.0)), name=name + "_drift", device=device, dtype=dtype, ), cheetah.Aperture( - x_max=torch.tensor([bmad_parsed.get("x_limit", np.inf)]), - y_max=torch.tensor([bmad_parsed.get("y_limit", np.inf)]), + x_max=torch.tensor(bmad_parsed.get("x_limit", np.inf)), + y_max=torch.tensor(bmad_parsed.get("y_limit", np.inf)), shape="elliptical", name=name + "_aperture", device=device, @@ -274,7 +274,7 @@ def convert_element( bmad_parsed, ) return cheetah.Undulator( - length=torch.tensor([bmad_parsed["l"]]), + length=torch.tensor(bmad_parsed["l"]), name=name, device=device, dtype=dtype, @@ -283,7 +283,7 @@ def convert_element( # TODO: Does this need to be implemented in Cheetah in a more proper way? validate_understood_properties(["element_type", "tilt"], bmad_parsed) return cheetah.Drift( - length=torch.tensor([bmad_parsed.get("l", 0.0)]), + length=torch.tensor(bmad_parsed.get("l", 0.0)), name=name, device=device, dtype=dtype, @@ -296,7 +296,7 @@ def convert_element( # TODO: Remove the length if by adding markers to Cheeath return cheetah.Drift( name=name, - length=torch.tensor([bmad_parsed.get("l", 0.0)]), + length=torch.tensor(bmad_parsed.get("l", 0.0)), device=device, dtype=dtype, ) diff --git a/cheetah/converters/nxtables.py b/cheetah/converters/nxtables.py index d449a023..03af4e93 100644 --- a/cheetah/converters/nxtables.py +++ b/cheetah/converters/nxtables.py @@ -53,10 +53,10 @@ def translate_element(row: list[str], header: list[str]) -> Optional[Dict]: elif class_name == "MCXG": # TODO: Check length with Willi assert name[6] == "X" horizontal_coil = cheetah.HorizontalCorrector( - name=name[:6] + "H" + name[6 + 1 :], length=torch.tensor([5e-05]) + name=name[:6] + "H" + name[6 + 1 :], length=torch.tensor(5e-05) ) vertical_coil = cheetah.VerticalCorrector( - name=name[:6] + "V" + name[6 + 1 :], length=torch.tensor([5e-05]) + name=name[:6] + "V" + name[6 + 1 :], length=torch.tensor(5e-05) ) element = cheetah.Segment(elements=[horizontal_coil, vertical_coil], name=name) elif class_name == "BSCX": @@ -115,57 +115,57 @@ def translate_element(row: list[str], header: list[str]) -> Optional[Dict]: elif class_name == "SLHG": element = cheetah.Aperture( # TODO: Ask for actual size and shape name=name, - x_max=torch.tensor([float("inf")]), - y_max=torch.tensor([float("inf")]), + x_max=torch.tensor(float("inf")), + y_max=torch.tensor(float("inf")), shape="elliptical", ) elif class_name == "SLHB": element = cheetah.Aperture( # TODO: Ask for actual size and shape name=name, - x_max=torch.tensor([float("inf")]), - y_max=torch.tensor([float("inf")]), + x_max=torch.tensor(float("inf")), + y_max=torch.tensor(float("inf")), shape="rectangular", ) elif class_name == "SLHS": element = cheetah.Aperture( # TODO: Ask for actual size and shape name=name, - x_max=torch.tensor([float("inf")]), - y_max=torch.tensor([float("inf")]), + x_max=torch.tensor(float("inf")), + y_max=torch.tensor(float("inf")), shape="rectangular", ) elif class_name == "MCHM": - element = cheetah.HorizontalCorrector(name=name, length=torch.tensor([0.02])) + element = cheetah.HorizontalCorrector(name=name, length=torch.tensor(0.02)) elif class_name == "MCVM": - element = cheetah.VerticalCorrector(name=name, length=torch.tensor([0.02])) + element = cheetah.VerticalCorrector(name=name, length=torch.tensor(0.02)) elif class_name == "MBHL": - element = cheetah.Dipole(name=name, length=torch.tensor([0.322])) + element = cheetah.Dipole(name=name, length=torch.tensor(0.322)) elif class_name == "MBHB": - element = cheetah.Dipole(name=name, length=torch.tensor([0.22])) + element = cheetah.Dipole(name=name, length=torch.tensor(0.22)) elif class_name == "MBHO": element = cheetah.Dipole( name=name, - length=torch.tensor([0.43852543421396856]), - angle=torch.tensor([0.8203047484373349]), - e2=torch.tensor([-0.7504915783575616]), + length=torch.tensor(0.43852543421396856), + angle=torch.tensor(0.8203047484373349), + e2=torch.tensor(-0.7504915783575616), ) elif class_name == "MQZM": - element = cheetah.Quadrupole(name=name, length=torch.tensor([0.122])) + element = cheetah.Quadrupole(name=name, length=torch.tensor(0.122)) elif class_name == "RSBL": element = cheetah.Cavity( name=name, - length=torch.tensor([4.139]), - frequency=torch.tensor([2.998e9]), - voltage=torch.tensor([76e6]), + length=torch.tensor(4.139), + frequency=torch.tensor(2.998e9), + voltage=torch.tensor(76e6), ) elif class_name == "RXBD": element = cheetah.Cavity( # TODO: TD? and tilt? name=name, - length=torch.tensor([1.0]), - frequency=torch.tensor([11.9952e9]), - voltage=torch.tensor([0.0]), + length=torch.tensor(1.0), + frequency=torch.tensor(11.9952e9), + voltage=torch.tensor(0.0), ) elif class_name == "UNDA": # TODO: Figure out actual length - element = cheetah.Undulator(name=name, length=torch.tensor([0.25])) + element = cheetah.Undulator(name=name, length=torch.tensor(0.25)) elif class_name in [ "SOLG", "BCMG", diff --git a/cheetah/converters/ocelot.py b/cheetah/converters/ocelot.py index 83b26a53..27ea9221 100644 --- a/cheetah/converters/ocelot.py +++ b/cheetah/converters/ocelot.py @@ -31,91 +31,91 @@ def convert_element_to_cheetah( if isinstance(element, ocelot.Drift): return cheetah.Drift( - length=torch.tensor([element.l], dtype=torch.float32), + length=torch.tensor(element.l, dtype=torch.float32), name=element.id, device=device, dtype=dtype, ) elif isinstance(element, ocelot.Quadrupole): return cheetah.Quadrupole( - length=torch.tensor([element.l], dtype=torch.float32), - k1=torch.tensor([element.k1], dtype=torch.float32), + length=torch.tensor(element.l, dtype=torch.float32), + k1=torch.tensor(element.k1, dtype=torch.float32), name=element.id, device=device, dtype=dtype, ) elif isinstance(element, ocelot.Solenoid): return cheetah.Solenoid( - length=torch.tensor([element.l], dtype=torch.float32), - k=torch.tensor([element.k], dtype=torch.float32), + length=torch.tensor(element.l, dtype=torch.float32), + k=torch.tensor(element.k, dtype=torch.float32), name=element.id, device=device, dtype=dtype, ) elif isinstance(element, ocelot.Hcor): return cheetah.HorizontalCorrector( - length=torch.tensor([element.l], dtype=torch.float32), - angle=torch.tensor([element.angle], dtype=torch.float32), + length=torch.tensor(element.l, dtype=torch.float32), + angle=torch.tensor(element.angle, dtype=torch.float32), name=element.id, device=device, dtype=dtype, ) elif isinstance(element, ocelot.Vcor): return cheetah.VerticalCorrector( - length=torch.tensor([element.l], dtype=torch.float32), - angle=torch.tensor([element.angle], dtype=torch.float32), + length=torch.tensor(element.l, dtype=torch.float32), + angle=torch.tensor(element.angle, dtype=torch.float32), name=element.id, device=device, dtype=dtype, ) elif isinstance(element, ocelot.Bend): return cheetah.Dipole( - length=torch.tensor([element.l], dtype=torch.float32), - angle=torch.tensor([element.angle], dtype=torch.float32), - e1=torch.tensor([element.e1], dtype=torch.float32), - e2=torch.tensor([element.e2], dtype=torch.float32), - tilt=torch.tensor([element.tilt], dtype=torch.float32), - fringe_integral=torch.tensor([element.fint], dtype=torch.float32), - fringe_integral_exit=torch.tensor([element.fintx], dtype=torch.float32), - gap=torch.tensor([element.gap], dtype=torch.float32), + length=torch.tensor(element.l, dtype=torch.float32), + angle=torch.tensor(element.angle, dtype=torch.float32), + e1=torch.tensor(element.e1, dtype=torch.float32), + e2=torch.tensor(element.e2, dtype=torch.float32), + tilt=torch.tensor(element.tilt, dtype=torch.float32), + fringe_integral=torch.tensor(element.fint, dtype=torch.float32), + fringe_integral_exit=torch.tensor(element.fintx, dtype=torch.float32), + gap=torch.tensor(element.gap, dtype=torch.float32), name=element.id, device=device, dtype=dtype, ) elif isinstance(element, ocelot.SBend): return cheetah.Dipole( - length=torch.tensor([element.l], dtype=torch.float32), - angle=torch.tensor([element.angle], dtype=torch.float32), - e1=torch.tensor([element.e1], dtype=torch.float32), - e2=torch.tensor([element.e2], dtype=torch.float32), - tilt=torch.tensor([element.tilt], dtype=torch.float32), - fringe_integral=torch.tensor([element.fint], dtype=torch.float32), - fringe_integral_exit=torch.tensor([element.fintx], dtype=torch.float32), - gap=torch.tensor([element.gap], dtype=torch.float32), + length=torch.tensor(element.l, dtype=torch.float32), + angle=torch.tensor(element.angle, dtype=torch.float32), + e1=torch.tensor(element.e1, dtype=torch.float32), + e2=torch.tensor(element.e2, dtype=torch.float32), + tilt=torch.tensor(element.tilt, dtype=torch.float32), + fringe_integral=torch.tensor(element.fint, dtype=torch.float32), + fringe_integral_exit=torch.tensor(element.fintx, dtype=torch.float32), + gap=torch.tensor(element.gap, dtype=torch.float32), name=element.id, device=device, dtype=dtype, ) elif isinstance(element, ocelot.RBend): return cheetah.RBend( - length=torch.tensor([element.l], dtype=torch.float32), - angle=torch.tensor([element.angle], dtype=torch.float32), - e1=torch.tensor([element.e1], dtype=torch.float32) - element.angle / 2, - e2=torch.tensor([element.e2], dtype=torch.float32) - element.angle / 2, - tilt=torch.tensor([element.tilt], dtype=torch.float32), - fringe_integral=torch.tensor([element.fint], dtype=torch.float32), - fringe_integral_exit=torch.tensor([element.fintx], dtype=torch.float32), - gap=torch.tensor([element.gap], dtype=torch.float32), + length=torch.tensor(element.l, dtype=torch.float32), + angle=torch.tensor(element.angle, dtype=torch.float32), + e1=torch.tensor(element.e1, dtype=torch.float32) - element.angle / 2, + e2=torch.tensor(element.e2, dtype=torch.float32) - element.angle / 2, + tilt=torch.tensor(element.tilt, dtype=torch.float32), + fringe_integral=torch.tensor(element.fint, dtype=torch.float32), + fringe_integral_exit=torch.tensor(element.fintx, dtype=torch.float32), + gap=torch.tensor(element.gap, dtype=torch.float32), name=element.id, device=device, dtype=dtype, ) elif isinstance(element, ocelot.Cavity): return cheetah.Cavity( - length=torch.tensor([element.l], dtype=torch.float32), - voltage=torch.tensor([element.v], dtype=torch.float32) * 1e9, - frequency=torch.tensor([element.freq], dtype=torch.float32), - phase=torch.tensor([element.phi], dtype=torch.float32), + length=torch.tensor(element.l, dtype=torch.float32), + voltage=torch.tensor(element.v, dtype=torch.float32) * 1e9, + frequency=torch.tensor(element.freq, dtype=torch.float32), + phase=torch.tensor(element.phi, dtype=torch.float32), name=element.id, device=device, dtype=dtype, @@ -123,10 +123,10 @@ def convert_element_to_cheetah( elif isinstance(element, ocelot.TDCavity): # TODO: Better replacement at some point? return cheetah.Cavity( - length=torch.tensor([element.l], dtype=torch.float32), - voltage=torch.tensor([element.v], dtype=torch.float32) * 1e9, - frequency=torch.tensor([element.freq], dtype=torch.float32), - phase=torch.tensor([element.phi], dtype=torch.float32), + length=torch.tensor(element.l, dtype=torch.float32), + voltage=torch.tensor(element.v, dtype=torch.float32) * 1e9, + frequency=torch.tensor(element.freq, dtype=torch.float32), + phase=torch.tensor(element.phi, dtype=torch.float32), name=element.id, device=device, dtype=dtype, @@ -162,8 +162,8 @@ def convert_element_to_cheetah( elif isinstance(element, ocelot.Aperture): shape_translation = {"rect": "rectangular", "elip": "elliptical"} return cheetah.Aperture( - x_max=torch.tensor([element.xmax], dtype=torch.float32), - y_max=torch.tensor([element.ymax], dtype=torch.float32), + x_max=torch.tensor(element.xmax, dtype=torch.float32), + y_max=torch.tensor(element.ymax, dtype=torch.float32), shape=shape_translation[element.type], is_active=True, name=element.id, @@ -177,7 +177,7 @@ def convert_element_to_cheetah( " replacing with drift section." ) return cheetah.Drift( - length=torch.tensor([element.l], dtype=torch.float32), + length=torch.tensor(element.l, dtype=torch.float32), name=element.id, device=device, dtype=dtype, diff --git a/cheetah/particles/beam.py b/cheetah/particles/beam.py index a8755e3e..3b1601b7 100644 --- a/cheetah/particles/beam.py +++ b/cheetah/particles/beam.py @@ -4,9 +4,7 @@ from scipy.constants import physical_constants from torch import nn -electron_mass_eV = torch.tensor( - physical_constants["electron mass energy equivalent in MeV"][0] * 1e6 -) +electron_mass_eV = physical_constants["electron mass energy equivalent in MeV"][0] * 1e6 class Beam(nn.Module): @@ -163,8 +161,8 @@ def transformed_to( :param energy: Reference energy of the beam in eV. :param total_charge: Total charge of the beam in C. """ - # Figure out batch size of the original beam and check that passed arguments - # have the same batch size + # Figure out vector dimensions of the original beam and check that passed + # arguments have the same vector dimensions. shape = self.mu_x.shape not_nones = [ argument @@ -359,10 +357,6 @@ def alpha_y(self) -> torch.Tensor: """Alpha function in y direction, dimensionless.""" return -self.sigma_ypy / self.emittance_y - def broadcast(self, shape: torch.Size) -> "Beam": - """Broadcast beam to new shape.""" - raise NotImplementedError - def __repr__(self) -> str: return ( f"{self.__class__.__name__}(mu_x={self.mu_x}, mu_px={self.mu_px}," diff --git a/cheetah/particles/parameter_beam.py b/cheetah/particles/parameter_beam.py index 336e43e9..ccebd331 100644 --- a/cheetah/particles/parameter_beam.py +++ b/cheetah/particles/parameter_beam.py @@ -63,67 +63,59 @@ def from_parameters( device=None, dtype=torch.float32, ) -> "ParameterBeam": - # Figure out if arguments were passed, figure out their shape - not_nones = [ - argument - for argument in [ - mu_x, - mu_px, - mu_y, - mu_py, - sigma_x, - sigma_px, - sigma_y, - sigma_py, - sigma_tau, - sigma_p, - cor_x, - cor_y, - cor_tau, - energy, - total_charge, - ] - if argument is not None - ] - shape = not_nones[0].shape if len(not_nones) > 0 else torch.Size([1]) - if len(not_nones) > 1: - assert all( - argument.shape == shape for argument in not_nones - ), "Arguments must have the same shape." - # Set default values without function call in function signature - mu_x = mu_x if mu_x is not None else torch.full(shape, 0.0) - mu_px = mu_px if mu_px is not None else torch.full(shape, 0.0) - mu_y = mu_y if mu_y is not None else torch.full(shape, 0.0) - mu_py = mu_py if mu_py is not None else torch.full(shape, 0.0) - sigma_x = sigma_x if sigma_x is not None else torch.full(shape, 175e-9) - sigma_px = sigma_px if sigma_px is not None else torch.full(shape, 2e-7) - sigma_y = sigma_y if sigma_y is not None else torch.full(shape, 175e-9) - sigma_py = sigma_py if sigma_py is not None else torch.full(shape, 2e-7) - sigma_tau = sigma_tau if sigma_tau is not None else torch.full(shape, 1e-6) - sigma_p = sigma_p if sigma_p is not None else torch.full(shape, 1e-6) - cor_x = cor_x if cor_x is not None else torch.full(shape, 0.0) - cor_y = cor_y if cor_y is not None else torch.full(shape, 0.0) - cor_tau = cor_tau if cor_tau is not None else torch.full(shape, 0.0) - energy = energy if energy is not None else torch.full(shape, 1e8) - total_charge = ( - total_charge if total_charge is not None else torch.full(shape, 0.0) - ) - + mu_x = mu_x if mu_x is not None else torch.tensor(0.0) + mu_px = mu_px if mu_px is not None else torch.tensor(0.0) + mu_y = mu_y if mu_y is not None else torch.tensor(0.0) + mu_py = mu_py if mu_py is not None else torch.tensor(0.0) + sigma_x = sigma_x if sigma_x is not None else torch.tensor(175e-9) + sigma_px = sigma_px if sigma_px is not None else torch.tensor(2e-7) + sigma_y = sigma_y if sigma_y is not None else torch.tensor(175e-9) + sigma_py = sigma_py if sigma_py is not None else torch.tensor(2e-7) + sigma_tau = sigma_tau if sigma_tau is not None else torch.tensor(1e-6) + sigma_p = sigma_p if sigma_p is not None else torch.tensor(1e-6) + cor_x = cor_x if cor_x is not None else torch.tensor(0.0) + cor_y = cor_y if cor_y is not None else torch.tensor(0.0) + cor_tau = cor_tau if cor_tau is not None else torch.tensor(0.0) + energy = energy if energy is not None else torch.tensor(1e8) + total_charge = total_charge if total_charge is not None else torch.tensor(0.0) + + mu_x, mu_px, mu_y, mu_py = torch.broadcast_tensors(mu_x, mu_px, mu_y, mu_py) mu = torch.stack( [ mu_x, mu_px, mu_y, mu_py, - torch.full(shape, 0.0), - torch.full(shape, 0.0), - torch.full(shape, 1.0), + torch.zeros_like(mu_x), + torch.zeros_like(mu_x), + torch.ones_like(mu_x), ], dim=-1, ) - cov = torch.zeros(*shape, 7, 7) + ( + sigma_x, + cor_x, + sigma_px, + sigma_y, + cor_y, + sigma_py, + sigma_tau, + cor_tau, + sigma_p, + ) = torch.broadcast_tensors( + sigma_x, + cor_x, + sigma_px, + sigma_y, + cor_y, + sigma_py, + sigma_tau, + cor_tau, + sigma_p, + ) + cov = torch.zeros(*sigma_x.shape, 7, 7) cov[..., 0, 0] = sigma_x**2 cov[..., 0, 1] = cor_x cov[..., 1, 0] = cor_x @@ -206,10 +198,10 @@ def from_twiss( total_charge if total_charge is not None else torch.full(shape, 0.0) ) - assert all( + assert torch.all( beta_x > 0 ), "Beta function in x direction must be larger than 0 everywhere." - assert all( + assert torch.all( beta_y > 0 ), "Beta function in y direction must be larger than 0 everywhere." @@ -271,10 +263,10 @@ def from_astra(cls, path: str, device=None, dtype=torch.float32) -> "ParameterBe total_charge = torch.tensor(np.sum(particle_charges), dtype=torch.float32) return cls( - mu=mu.unsqueeze(0), - cov=cov.unsqueeze(0), - energy=torch.tensor(energy, dtype=torch.float32).unsqueeze(0), - total_charge=total_charge.unsqueeze(0), + mu=mu, + cov=cov, + energy=torch.tensor(energy, dtype=torch.float32), + total_charge=total_charge, device=device, dtype=dtype, ) @@ -318,32 +310,6 @@ def transformed_to( device = device if device is not None else self.mu_x.device dtype = dtype if dtype is not None else self.mu_x.dtype - # Figure out batch size of the original beam and check that passed arguments - # have the same batch size - shape = self.mu_x.shape - not_nones = [ - argument - for argument in [ - mu_x, - mu_px, - mu_y, - mu_py, - sigma_x, - sigma_px, - sigma_y, - sigma_py, - sigma_tau, - sigma_p, - energy, - total_charge, - ] - if argument is not None - ] - if len(not_nones) > 0: - assert all( - argument.shape == shape for argument in not_nones - ), "Arguments must have the same shape." - mu_x = mu_x if mu_x is not None else self.mu_x mu_px = mu_px if mu_px is not None else self.mu_px mu_y = mu_y if mu_y is not None else self.mu_y @@ -430,16 +396,6 @@ def sigma_xpx(self) -> torch.Tensor: def sigma_ypy(self) -> torch.Tensor: return self._cov[..., 2, 3] - def broadcast(self, shape: torch.Size) -> "ParameterBeam": - return self.__class__( - mu=self._mu.repeat((*shape, 1)), - cov=self._cov.repeat((*shape, 1, 1)), - energy=self.energy.repeat(shape), - total_charge=self.total_charge.repeat(shape), - device=self._mu.device, - dtype=self._mu.dtype, - ) - def __repr__(self) -> str: return ( f"{self.__class__.__name__}(mu_x={repr(self.mu_x)}," diff --git a/cheetah/particles/particle_beam.py b/cheetah/particles/particle_beam.py index 28cf8bde..afe24238 100644 --- a/cheetah/particles/particle_beam.py +++ b/cheetah/particles/particle_beam.py @@ -2,14 +2,15 @@ import torch from scipy import constants +from scipy.constants import physical_constants from torch.distributions import MultivariateNormal from .beam import Beam speed_of_light = torch.tensor(constants.speed_of_light) # In m/s electron_mass = torch.tensor(constants.electron_mass) # In kg -electron_mass_eV = torch.tensor( - constants.physical_constants["electron mass energy equivalent in MeV"][0] * 1e6 +electron_mass_eV = ( + physical_constants["electron mass energy equivalent in MeV"][0] * 1e6 ) # In eV @@ -98,67 +99,60 @@ def from_parameters( :param device: Device to move the beam's particle array to. If set to `"auto"` a CUDA GPU is selected if available. The CPU is used otherwise. """ - # Figure out if arguments were passed, figure out their shape - not_nones = [ - argument - for argument in [ - mu_x, - mu_px, - mu_y, - mu_py, - sigma_x, - sigma_px, - sigma_y, - sigma_py, - sigma_tau, - sigma_p, - cor_x, - cor_y, - cor_tau, - energy, - total_charge, - ] - if argument is not None - ] - shape = not_nones[0].shape if len(not_nones) > 0 else torch.Size([1]) - if len(not_nones) > 1: - assert all( - argument.shape == shape for argument in not_nones - ), "Arguments must have the same shape." # Set default values without function call in function signature num_particles = ( num_particles if num_particles is not None else torch.tensor(100_000) ) - mu_x = mu_x if mu_x is not None else torch.full(shape, 0.0) - mu_px = mu_px if mu_px is not None else torch.full(shape, 0.0) - mu_y = mu_y if mu_y is not None else torch.full(shape, 0.0) - mu_py = mu_py if mu_py is not None else torch.full(shape, 0.0) - sigma_x = sigma_x if sigma_x is not None else torch.full(shape, 175e-9) - sigma_px = sigma_px if sigma_px is not None else torch.full(shape, 2e-7) - sigma_y = sigma_y if sigma_y is not None else torch.full(shape, 175e-9) - sigma_py = sigma_py if sigma_py is not None else torch.full(shape, 2e-7) - sigma_tau = sigma_tau if sigma_tau is not None else torch.full(shape, 1e-6) - sigma_p = sigma_p if sigma_p is not None else torch.full(shape, 1e-6) - cor_x = cor_x if cor_x is not None else torch.full(shape, 0.0) - cor_y = cor_y if cor_y is not None else torch.full(shape, 0.0) - cor_tau = cor_tau if cor_tau is not None else torch.full(shape, 0.0) - energy = energy if energy is not None else torch.full(shape, 1e8) - total_charge = ( - total_charge if total_charge is not None else torch.full(shape, 0.0) - ) + mu_x = mu_x if mu_x is not None else torch.tensor(0.0) + mu_px = mu_px if mu_px is not None else torch.tensor(0.0) + mu_y = mu_y if mu_y is not None else torch.tensor(0.0) + mu_py = mu_py if mu_py is not None else torch.tensor(0.0) + sigma_x = sigma_x if sigma_x is not None else torch.tensor(175e-9) + sigma_px = sigma_px if sigma_px is not None else torch.tensor(2e-7) + sigma_y = sigma_y if sigma_y is not None else torch.tensor(175e-9) + sigma_py = sigma_py if sigma_py is not None else torch.tensor(2e-7) + sigma_tau = sigma_tau if sigma_tau is not None else torch.tensor(1e-6) + sigma_p = sigma_p if sigma_p is not None else torch.tensor(1e-6) + cor_x = cor_x if cor_x is not None else torch.tensor(0.0) + cor_y = cor_y if cor_y is not None else torch.tensor(0.0) + cor_tau = cor_tau if cor_tau is not None else torch.tensor(0.0) + energy = energy if energy is not None else torch.tensor(1e8) + total_charge = total_charge if total_charge is not None else torch.tensor(0.0) particle_charges = ( - torch.ones((*shape, num_particles), device=device, dtype=dtype) + torch.ones((*total_charge.shape, num_particles)) * total_charge.unsqueeze(-1) / num_particles ) + mu_x, mu_px, mu_y, mu_py = torch.broadcast_tensors(mu_x, mu_px, mu_y, mu_py) mean = torch.stack( - [mu_x, mu_px, mu_y, mu_py, torch.full(shape, 0.0), torch.full(shape, 0.0)], + [mu_x, mu_px, mu_y, mu_py, torch.zeros_like(mu_x), torch.zeros_like(mu_x)], dim=-1, ) - cov = torch.zeros(*shape, 6, 6) + ( + sigma_x, + cor_x, + sigma_px, + sigma_y, + cor_y, + sigma_py, + sigma_tau, + cor_tau, + sigma_p, + ) = torch.broadcast_tensors( + sigma_x, + cor_x, + sigma_px, + sigma_y, + cor_y, + sigma_py, + sigma_tau, + cor_tau, + sigma_p, + ) + cov = torch.zeros(*sigma_x.shape, 6, 6) cov[..., 0, 0] = sigma_x**2 cov[..., 0, 1] = cor_x cov[..., 1, 0] = cor_x @@ -172,7 +166,7 @@ def from_parameters( cov[..., 5, 4] = cor_tau cov[..., 5, 5] = sigma_p**2 - particles = torch.ones((*shape, num_particles, 7)) + particles = torch.ones((*mean.shape[:-1], num_particles, 7)) distributions = [ MultivariateNormal(sample_mean, covariance_matrix=sample_cov) for sample_mean, sample_cov in zip(mean.view(-1, 6), cov.view(-1, 6, 6)) @@ -180,7 +174,7 @@ def from_parameters( particles[..., :6] = torch.stack( [distribution.sample((num_particles,)) for distribution in distributions], dim=0, - ).view(*shape, num_particles, 6) + ).view(*particles.shape[:-2], num_particles, 6) return cls( particles, @@ -300,7 +294,7 @@ def uniform_3d_ellipsoid( Note that: - The generated particles do not have correlation in the momentum directions, and by default a cold beam with no divergence is generated. - - For batched generation, parameters that are not `None` must have the same + - For vectorised generation, parameters that are not `None` must have the same shape. :param num_particles: Number of particles to generate. @@ -336,26 +330,41 @@ def uniform_3d_ellipsoid( ] if argument is not None ] - shape = not_nones[0].shape if len(not_nones) > 0 else torch.Size([1]) + shape = not_nones[0].shape if len(not_nones) > 0 else torch.Size([]) if len(not_nones) > 1: assert all( argument.shape == shape for argument in not_nones ), "Arguments must have the same shape." + # Expand to vectorised version for beam creation + vector_shape = shape if len(shape) > 0 else torch.Size([1]) + # Set default values without function call in function signature # NOTE that this does not need to be done for values that are passed to the # Gaussian beam generation. num_particles = ( num_particles if num_particles is not None else torch.tensor(1_000_000) ) - radius_x = radius_x if radius_x is not None else torch.full(shape, 1e-3) - radius_y = radius_y if radius_y is not None else torch.full(shape, 1e-3) - radius_tau = radius_tau if radius_tau is not None else torch.full(shape, 1e-3) + radius_x = ( + radius_x.expand(vector_shape) + if radius_x is not None + else torch.full(vector_shape, 1e-3) + ) + radius_y = ( + radius_y.expand(vector_shape) + if radius_y is not None + else torch.full(vector_shape, 1e-3) + ) + radius_tau = ( + radius_tau.expand(vector_shape) + if radius_tau is not None + else torch.full(vector_shape, 1e-3) + ) # Generate x, y and ss within the ellipsoid - flattened_x = torch.empty(*shape, num_particles).flatten(end_dim=-2) - flattened_y = torch.empty(*shape, num_particles).flatten(end_dim=-2) - flattened_tau = torch.empty(*shape, num_particles).flatten(end_dim=-2) + flattened_x = torch.empty(*vector_shape, num_particles).flatten(end_dim=-2) + flattened_y = torch.empty(*vector_shape, num_particles).flatten(end_dim=-2) + flattened_tau = torch.empty(*vector_shape, num_particles).flatten(end_dim=-2) for i, (r_x, r_y, r_tau) in enumerate( zip(radius_x.flatten(), radius_y.flatten(), radius_tau.flatten()) ): @@ -590,9 +599,9 @@ def from_astra(cls, path: str, device=None, dtype=torch.float32) -> "ParticleBea particles_7d[:, :6] = torch.from_numpy(particles) particle_charges = torch.from_numpy(particle_charges) return cls( - particles=particles_7d.unsqueeze(0), - energy=torch.tensor(energy).unsqueeze(0), - particle_charges=particle_charges.unsqueeze(0), + particles=particles_7d, + energy=torch.tensor(energy), + particle_charges=particle_charges, device=device, dtype=dtype, ) @@ -639,32 +648,6 @@ def transformed_to( device = device if device is not None else self.mu_x.device dtype = dtype if dtype is not None else self.mu_x.dtype - # Figure out batch size of the original beam and check that passed arguments - # have the same batch size - shape = self.mu_x.shape - not_nones = [ - argument - for argument in [ - mu_x, - mu_px, - mu_y, - mu_py, - sigma_x, - sigma_px, - sigma_y, - sigma_py, - sigma_tau, - sigma_p, - energy, - total_charge, - ] - if argument is not None - ] - if len(not_nones) > 0: - assert all( - argument.shape == shape for argument in not_nones - ), "Arguments must have the same shape." - mu_x = mu_x if mu_x is not None else self.mu_x mu_y = mu_y if mu_y is not None else self.mu_y mu_px = mu_px if mu_px is not None else self.mu_px @@ -690,12 +673,18 @@ def transformed_to( / self.particle_charges.shape[-1] ) + mu_x, mu_px, mu_y, mu_py = torch.broadcast_tensors(mu_x, mu_px, mu_y, mu_py) new_mu = torch.stack( - [mu_x, mu_px, mu_y, mu_py, torch.full(shape, 0.0), torch.full(shape, 0.0)], - dim=1, + [mu_x, mu_px, mu_y, mu_py, torch.zeros_like(mu_x), torch.zeros_like(mu_x)], + dim=-1, + ) + sigma_x, sigma_px, sigma_y, sigma_py, sigma_tau, sigma_p = ( + torch.broadcast_tensors( + sigma_x, sigma_px, sigma_y, sigma_py, sigma_tau, sigma_p + ) ) new_sigma = torch.stack( - [sigma_x, sigma_px, sigma_y, sigma_py, sigma_tau, sigma_p], dim=1 + [sigma_x, sigma_px, sigma_y, sigma_py, sigma_tau, sigma_p], dim=-1 ) old_mu = torch.stack( @@ -704,10 +693,10 @@ def transformed_to( self.mu_px, self.mu_y, self.mu_py, - torch.full(shape, 0.0), - torch.full(shape, 0.0), + torch.zeros_like(self.mu_x), + torch.zeros_like(self.mu_x), ], - dim=1, + dim=-1, ) old_sigma = torch.stack( [ @@ -718,16 +707,19 @@ def transformed_to( self.sigma_tau, self.sigma_p, ], - dim=1, + dim=-1, ) - phase_space = self.particles[:, :, :6] - phase_space = (phase_space - old_mu.unsqueeze(1)) / old_sigma.unsqueeze( - 1 - ) * new_sigma.unsqueeze(1) + new_mu.unsqueeze(1) + phase_space = self.particles[..., :6] + phase_space = ( + (phase_space.transpose(-2, -1) - old_mu.unsqueeze(-1)) + / old_sigma.unsqueeze(-1) + * new_sigma.unsqueeze(-1) + + new_mu.unsqueeze(-1) + ).transpose(-2, -1) - particles = torch.ones_like(self.particles) - particles[:, :, :6] = phase_space + particles = torch.ones(*phase_space.shape[:-1], 7) + particles[..., :6] = phase_space return self.__class__( particles=particles, @@ -740,7 +732,7 @@ def transformed_to( @classmethod def from_xyz_pxpypz( cls, - xp_coords: torch.Tensor, + xp_coordinates: torch.Tensor, energy: torch.Tensor, particle_charges: Optional[torch.Tensor] = None, device=None, @@ -752,7 +744,7 @@ def from_xyz_pxpypz( is the moment vector $(x, p_x, y, p_y, z, p_z, 1)$. """ beam = cls( - particles=xp_coords.clone(), + particles=xp_coordinates.clone(), energy=energy, particle_charges=particle_charges, device=device, @@ -766,15 +758,17 @@ def from_xyz_pxpypz( * speed_of_light ) p = torch.sqrt( - xp_coords[..., 1] ** 2 + xp_coords[..., 3] ** 2 + xp_coords[..., 5] ** 2 + xp_coordinates[..., 1] ** 2 + + xp_coordinates[..., 3] ** 2 + + xp_coordinates[..., 5] ** 2 ) gamma = torch.sqrt(1 + (p / (electron_mass * speed_of_light)) ** 2) - beam.particles[..., 1] = xp_coords[..., 1] / p0.unsqueeze(-1) - beam.particles[..., 3] = xp_coords[..., 3] / p0.unsqueeze(-1) - beam.particles[..., 4] = -xp_coords[..., 4] / beam.relativistic_beta.unsqueeze( - -1 - ) + beam.particles[..., 1] = xp_coordinates[..., 1] / p0.unsqueeze(-1) + beam.particles[..., 3] = xp_coordinates[..., 3] / p0.unsqueeze(-1) + beam.particles[..., 4] = -xp_coordinates[ + ..., 4 + ] / beam.relativistic_beta.unsqueeze(-1) beam.particles[..., 5] = (gamma - beam.relativistic_gamma.unsqueeze(-1)) / ( (beam.relativistic_beta * beam.relativistic_gamma).unsqueeze(-1) ) @@ -944,15 +938,6 @@ def momenta(self) -> torch.Tensor: """Momenta of the individual particles.""" return torch.sqrt(self.energies**2 - electron_mass_eV**2) - def broadcast(self, shape: torch.Size) -> "ParticleBeam": - return self.__class__( - particles=self.particles.repeat((*shape, 1, 1)), - energy=self.energy.repeat(shape), - particle_charges=self.particle_charges.repeat((*shape, 1)), - device=self.particles.device, - dtype=self.particles.dtype, - ) - def __repr__(self) -> str: return ( f"{self.__class__.__name__}(n={repr(self.num_particles)}," diff --git a/cheetah/track_methods.py b/cheetah/track_methods.py index 9101f248..1f6b6ba1 100644 --- a/cheetah/track_methods.py +++ b/cheetah/track_methods.py @@ -1,15 +1,14 @@ -"""Utility functions for creating transfer maps for the elements.""" +"""Utility functions for creating transfer maps for elements.""" from typing import Optional import torch -from scipy.constants import physical_constants -electron_mass_eV = physical_constants["electron mass energy equivalent in MeV"][0] * 1e6 +from .utils import compute_relativistic_factors def rotation_matrix(angle: torch.Tensor) -> torch.Tensor: - """Rotate the transfer map in x-y plane + """Rotate the transfer map in x-y plane. :param angle: Rotation angle in rad, for example `angle = np.pi/2` for vertical = dipole. @@ -51,14 +50,12 @@ def base_rmatrix( device = length.device dtype = length.dtype - tilt = tilt if tilt is not None else torch.zeros_like(length) - energy = energy if energy is not None else torch.zeros_like(length) + tilt = tilt if tilt is not None else torch.tensor(0.0, device=device, dtype=dtype) + energy = ( + energy if energy is not None else torch.tensor(0.0, device=device, dtype=dtype) + ) - gamma = energy / electron_mass_eV - igamma2 = torch.ones_like(length) - igamma2[gamma != 0] = 1 / gamma[gamma != 0] ** 2 - - beta = torch.sqrt(1 - igamma2) + _, igamma2, beta = compute_relativistic_factors(energy) # Avoid division by zero k1 = k1.clone() @@ -70,16 +67,18 @@ def base_rmatrix( ky = torch.sqrt(torch.complex(ky2, torch.tensor(0.0, device=device, dtype=dtype))) cx = torch.cos(kx * length).real cy = torch.cos(ky * length).real - sy = torch.clone(length) - sy[ky != 0] = (torch.sin(ky[ky != 0] * length[ky != 0]) / ky[ky != 0]).real - + sy = (torch.sin(ky * length) / ky).real sx = (torch.sin(kx * length) / kx).real dx = hx / kx2 * (1.0 - cx) r56 = hx**2 * (length - sx) / kx2 / beta**2 r56 = r56 - length / beta**2 * igamma2 - R = torch.eye(7, dtype=dtype, device=device).repeat(*length.shape, 1, 1) + vector_shape = torch.broadcast_shapes( + length.shape, k1.shape, hx.shape, tilt.shape, energy.shape + ) + + R = torch.eye(7, dtype=dtype, device=device).repeat(*vector_shape, 1, 1) R[..., 0, 0] = cx R[..., 0, 1] = sx R[..., 0, 5] = dx / beta @@ -105,16 +104,17 @@ def base_rmatrix( def misalignment_matrix( misalignment: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - """Shift the beam for tracking beam through misaligned elements""" + """Shift the beam for tracking beam through misaligned elements.""" device = misalignment.device dtype = misalignment.dtype - batch_shape = misalignment.shape[:-1] - R_exit = torch.eye(7, device=device, dtype=dtype).repeat(*batch_shape, 1, 1) + vector_shape = misalignment.shape[:-1] + + R_exit = torch.eye(7, device=device, dtype=dtype).repeat(*vector_shape, 1, 1) R_exit[..., 0, 6] = misalignment[..., 0] R_exit[..., 2, 6] = misalignment[..., 1] - R_entry = torch.eye(7, device=device, dtype=dtype).repeat(*batch_shape, 1, 1) + R_entry = torch.eye(7, device=device, dtype=dtype).repeat(*vector_shape, 1, 1) R_entry[..., 0, 6] = -misalignment[..., 0] R_entry[..., 2, 6] = -misalignment[..., 1] diff --git a/cheetah/utils/__init__.py b/cheetah/utils/__init__.py index ce4a5002..eb472409 100644 --- a/cheetah/utils/__init__.py +++ b/cheetah/utils/__init__.py @@ -1,4 +1,5 @@ from . import bmadx # noqa: F401 from .device import is_mps_available_and_functional # noqa: F401 from .kde import kde_histogram_1d, kde_histogram_2d # noqa: F401 +from .physics import compute_relativistic_factors # noqa: F401 from .unique_name_generator import UniqueNameGenerator # noqa: F401 diff --git a/cheetah/utils/bmadx.py b/cheetah/utils/bmadx.py index 7508f0e5..b5dda62f 100644 --- a/cheetah/utils/bmadx.py +++ b/cheetah/utils/bmadx.py @@ -31,7 +31,7 @@ def cheetah_to_bmad_z_pz( def bmad_to_cheetah_z_pz( z: torch.Tensor, pz: torch.Tensor, p0c: torch.Tensor, mc2: float -) -> torch.Tensor: +) -> tuple[torch.Tensor]: """ Transforms Bmad longitudinal coordinates to Cheetah coordinates and computes reference energy. @@ -284,14 +284,16 @@ def track_a_drift( Pxy2 = Px**2 + Py**2 # Particle's transverse mometum^2 over p0^2 Pl = torch.sqrt(1.0 - Pxy2) # Particle's longitudinal momentum over p0 - # z = z + L * ( beta/beta_ref - 1.0/Pl ) but numerically accurate: - dz = length * ( - sqrt_one((mc2**2 * (2 * pz_in + pz_in**2)) / ((p0c * P) ** 2 + mc2**2)) + # z = z + L * ( beta / beta_ref - 1.0 / Pl ) but numerically accurate: + dz = length.unsqueeze(-1) * ( + sqrt_one( + (mc2**2 * (2 * pz_in + pz_in**2)) / ((p0c.unsqueeze(-1) * P) ** 2 + mc2**2) + ) + sqrt_one(-Pxy2) / Pl ) - x_out = x_in + length * Px / Pl - y_out = y_in + length * Py / Pl + x_out = x_in + length.unsqueeze(-1) * Px / Pl + y_out = y_in + length.unsqueeze(-1) * Py / Pl z_out = z_in + dz return x_out, y_out, z_out @@ -299,7 +301,11 @@ def track_a_drift( def particle_rf_time(z, pz, p0c, mc2): """Returns rf time of Particle p.""" - beta = (1 + pz) * p0c / torch.sqrt(((1 + pz) * p0c) ** 2 + mc2**2) + beta = ( + (1 + pz) + * p0c.unsqueeze(-1) + / torch.sqrt(((1 + pz) * p0c.unsqueeze(-1)) ** 2 + mc2**2) + ) time = -z / (beta * speed_of_light) return time diff --git a/cheetah/utils/kde.py b/cheetah/utils/kde.py index 1aba41f2..36634fff 100644 --- a/cheetah/utils/kde.py +++ b/cheetah/utils/kde.py @@ -12,11 +12,11 @@ def _kde_marginal_pdf( epsilon: Union[torch.Tensor, float] = 1e-10, ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Calculate the 1D marginal probability distribution function of the input tensor - based on the number of histogram bins. + Compute the 1D marginal probability distribution function of the input tensor based + on the number of histogram bins. - :param values: Input tensor with shape :math:`(B, N, 1)`. `B` is the batch shape. - :param bins: Positions of the bins where KDE is calculated. + :param values: Input tensor with shape :math:`(B, N)`. `B` is the vector shape. + :param bins: Positions of the bins where KDE is computed. Shape :math:`(N_{bins})`. :param sigma: Gaussian smoothing factor with shape `(1,)`. :param weights: Input data weights of shape :math:`(B, N)`. Default to None. @@ -46,6 +46,8 @@ def _kde_marginal_pdf( if not sigma.dim() == 0: raise ValueError(f"Input sigma must be a of the shape (1,). Got {sigma.shape}") + values = values.unsqueeze(-1) + if weights is None: weights = torch.ones_like(values) else: @@ -78,7 +80,7 @@ def _kde_joint_pdf_2d( epsilon: Union[torch.Tensor, float] = 1e-10, ) -> torch.Tensor: """ - Calculate the joint probability distribution function of the input tensors based on + Compute the joint probability distribution function of the input tensors based on the number of histogram bins. :param kernel_values1: shape :math:`(B, N, N_{bins})`. @@ -120,7 +122,7 @@ def kde_histogram_1d( """ Estimate the histogram using KDE of the input tensor. - The calculation uses kernel density estimation which requires a bandwidth + The computation uses kernel density estimation which requires a bandwidth (smoothing) parameter. :param x: Input tensor to compute the histogram with shape :math:`(B, D)`. @@ -139,7 +141,7 @@ def kde_histogram_1d( """ pdf, _ = _kde_marginal_pdf( - values=x.unsqueeze(-1), + values=x, bins=bins, sigma=bandwidth, weights=weights, @@ -161,7 +163,7 @@ def kde_histogram_2d( """ Estimate the 2D histogram of the input tensor. - The calculation uses kernel density estimation which requires a bandwidth + The computation uses kernel density estimation which requires a bandwidth (smoothing) parameter. This is a modified version of the `kornia.enhance.histogram` implementation. @@ -184,13 +186,13 @@ def kde_histogram_2d( """ _, kernel_values1 = _kde_marginal_pdf( - values=x1.unsqueeze(-1), + values=x1, bins=bins1, sigma=bandwidth, weights=weights, ) _, kernel_values2 = _kde_marginal_pdf( - values=x2.unsqueeze(-1), + values=x2, bins=bins2, sigma=bandwidth, weights=None, diff --git a/cheetah/utils/physics.py b/cheetah/utils/physics.py new file mode 100644 index 00000000..4c16cac7 --- /dev/null +++ b/cheetah/utils/physics.py @@ -0,0 +1,21 @@ +import torch +from scipy.constants import physical_constants + +electron_mass_eV = physical_constants["electron mass energy equivalent in MeV"][0] * 1e6 + + +def compute_relativistic_factors( + energy: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes the relativistic factors gamma, inverse gamma squared and beta for + electrons. + + :param energy: Energy in eV. + :return: gamma, igamma2, beta. + """ + gamma = energy / electron_mass_eV + igamma2 = torch.where(gamma == 0.0, 0.0, 1 / gamma**2) + beta = torch.sqrt(1 - igamma2) + + return gamma, igamma2, beta diff --git a/tests/test_bpm.py b/tests/test_bpm.py index 4b53897a..6d58ead5 100644 --- a/tests/test_bpm.py +++ b/tests/test_bpm.py @@ -10,9 +10,9 @@ def test_no_tracking_error(is_bpm_active, beam_class): """Test that tracking a beam through an inactive BPM does not raise an error.""" segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([1.0])), + cheetah.Drift(length=torch.tensor(1.0)), cheetah.BPM(name="my_bpm"), - cheetah.Drift(length=torch.tensor([1.0])), + cheetah.Drift(length=torch.tensor(1.0)), ], ) beam = beam_class.from_astra("tests/resources/ACHIP_EA1_2021.1351.001") diff --git a/tests/test_cavity.py b/tests/test_cavity.py index bb2a695e..08652533 100644 --- a/tests/test_cavity.py +++ b/tests/test_cavity.py @@ -26,8 +26,8 @@ def test_assert_ei_greater_zero(): name="k26_2a", ) beam = cheetah.ParticleBeam.from_parameters( - num_particles=100_000, sigma_x=torch.tensor([1e-5]) - ).broadcast((3,)) + num_particles=100_000, sigma_x=torch.tensor(1e-5) + ) _ = cavity.track(beam) @@ -39,7 +39,7 @@ def test_assert_ei_greater_zero(): def test_vectorized_cavity_zero_voltage(voltage): """ Tests that a vectorised cavity with zero voltage does not produce NaNs and that - zero voltage can be batched with non-zero voltage. + zero voltage can be vectorised with non-zero voltage. This was a bug introduced during the vectorisation of Cheetah, when the special case of zero was removed and the `_cavity_rmatrix` method was also used in the case @@ -55,20 +55,20 @@ def test_vectorized_cavity_zero_voltage(voltage): dtype=torch.float64, ) incoming = cheetah.ParameterBeam.from_parameters( - mu_x=torch.tensor([0.0]), - mu_px=torch.tensor([0.0]), - mu_y=torch.tensor([0.0]), - mu_py=torch.tensor([0.0]), - sigma_x=torch.tensor([4.8492e-06]), - sigma_px=torch.tensor([1.5603e-07]), - sigma_y=torch.tensor([4.1209e-07]), - sigma_py=torch.tensor([1.1035e-08]), - sigma_tau=torch.tensor([1.0000e-10]), - sigma_p=torch.tensor([1.0000e-06]), - energy=torch.tensor([8.0000e09]), - total_charge=torch.tensor([0.0]), + mu_x=torch.tensor(0.0), + mu_px=torch.tensor(0.0), + mu_y=torch.tensor(0.0), + mu_py=torch.tensor(0.0), + sigma_x=torch.tensor(4.8492e-06), + sigma_px=torch.tensor(1.5603e-07), + sigma_y=torch.tensor(4.1209e-07), + sigma_py=torch.tensor(1.1035e-08), + sigma_tau=torch.tensor(1.0000e-10), + sigma_p=torch.tensor(1.0000e-06), + energy=torch.tensor(8.0000e09), + total_charge=torch.tensor(0.0), dtype=torch.float64, - ).broadcast((2,)) + ) outgoing = cavity.track(incoming) diff --git a/tests/test_compare_beam_type.py b/tests/test_compare_beam_type.py index 24c747a8..6a6d3e28 100644 --- a/tests/test_compare_beam_type.py +++ b/tests/test_compare_beam_type.py @@ -12,25 +12,25 @@ def test_from_twiss(): Test that a beams created from Twiss parameters have the same properties. """ parameter_beam = cheetah.ParameterBeam.from_twiss( - beta_x=torch.tensor([5.91253676811640894]), - alpha_x=torch.tensor([3.55631307633660354]), - emittance_x=torch.tensor([3.494768647122823e-09]), - beta_y=torch.tensor([5.91253676811640982]), - alpha_y=torch.tensor([2e-7]), - emittance_y=torch.tensor([3.497810737006068e-09]), - energy=torch.tensor([6e6]), + beta_x=torch.tensor(5.91253676811640894), + alpha_x=torch.tensor(3.55631307633660354), + emittance_x=torch.tensor(3.494768647122823e-09), + beta_y=torch.tensor(5.91253676811640982), + alpha_y=torch.tensor(2e-7), + emittance_y=torch.tensor(3.497810737006068e-09), + energy=torch.tensor(6e6), ) particle_beam = cheetah.ParticleBeam.from_twiss( num_particles=torch.tensor( [10_000_000] ), # Large number of particles reduces noise - beta_x=torch.tensor([5.91253676811640894]), - alpha_x=torch.tensor([3.55631307633660354]), - emittance_x=torch.tensor([3.494768647122823e-09]), - beta_y=torch.tensor([5.91253676811640982]), - alpha_y=torch.tensor([2e-7]), - emittance_y=torch.tensor([3.497810737006068e-09]), - energy=torch.tensor([6e6]), + beta_x=torch.tensor(5.91253676811640894), + alpha_x=torch.tensor(3.55631307633660354), + emittance_x=torch.tensor(3.494768647122823e-09), + beta_y=torch.tensor(5.91253676811640982), + alpha_y=torch.tensor(2e-7), + emittance_y=torch.tensor(3.497810737006068e-09), + energy=torch.tensor(6e6), ) assert torch.isclose(parameter_beam.mu_x, particle_beam.mu_x, atol=1e-6) @@ -51,7 +51,7 @@ def test_drift(): """Test that the drift output for both beam types is roughly the same.""" # Set up lattice - cheetah_drift = cheetah.Drift(length=torch.tensor([1.0])) + cheetah_drift = cheetah.Drift(length=torch.tensor(1.0)) # Parameter beam incoming_parameter_beam = cheetah.ParameterBeam.from_astra( @@ -98,7 +98,7 @@ def test_quadrupole(): # Set up lattice cheetah_quadrupole = cheetah.Quadrupole( - length=torch.tensor([0.15]), k1=torch.tensor([4.2]) + length=torch.tensor(0.15), k1=torch.tensor(4.2) ) # Parameter beam @@ -149,10 +149,10 @@ def test_cavity_from_astra(): # Set up lattice cheetah_cavity = cheetah.Cavity( - length=torch.tensor([1.0377]), - voltage=torch.tensor([0.01815975e9]), - frequency=torch.tensor([1.3e9]), - phase=torch.tensor([0.0]), + length=torch.tensor(1.0377), + voltage=torch.tensor(0.01815975e9), + frequency=torch.tensor(1.3e9), + phase=torch.tensor(0.0), ) # Parameter beam @@ -221,33 +221,33 @@ def test_cavity_from_twiss(): # Set up lattice cheetah_cavity = cheetah.Cavity( - length=torch.tensor([1.0377]), - voltage=torch.tensor([0.01815975e9]), - frequency=torch.tensor([1.3e9]), - phase=torch.tensor([0.0]), + length=torch.tensor(1.0377), + voltage=torch.tensor(0.01815975e9), + frequency=torch.tensor(1.3e9), + phase=torch.tensor(0.0), ) # Parameter beam incoming_parameter_beam = cheetah.ParameterBeam.from_twiss( - beta_x=torch.tensor([5.91253677]), - alpha_x=torch.tensor([3.55631308]), - beta_y=torch.tensor([5.91253677]), - alpha_y=torch.tensor([3.55631308]), - emittance_x=torch.tensor([3.494768647122823e-09]), - emittance_y=torch.tensor([3.497810737006068e-09]), - energy=torch.tensor([6e6]), + beta_x=torch.tensor(5.91253677), + alpha_x=torch.tensor(3.55631308), + beta_y=torch.tensor(5.91253677), + alpha_y=torch.tensor(3.55631308), + emittance_x=torch.tensor(3.494768647122823e-09), + emittance_y=torch.tensor(3.497810737006068e-09), + energy=torch.tensor(6e6), ) outgoing_parameter_beam = cheetah_cavity.track(incoming_parameter_beam) # Particle beam incoming_particle_beam = cheetah.ParticleBeam.from_twiss( - beta_x=torch.tensor([5.91253677]), - alpha_x=torch.tensor([3.55631308]), - beta_y=torch.tensor([5.91253677]), - alpha_y=torch.tensor([3.55631308]), - emittance_x=torch.tensor([3.494768647122823e-09]), - emittance_y=torch.tensor([3.497810737006068e-09]), - energy=torch.tensor([6e6]), + beta_x=torch.tensor(5.91253677), + alpha_x=torch.tensor(3.55631308), + beta_y=torch.tensor(5.91253677), + alpha_y=torch.tensor(3.55631308), + emittance_x=torch.tensor(3.494768647122823e-09), + emittance_y=torch.tensor(3.497810737006068e-09), + energy=torch.tensor(6e6), ) outgoing_particle_beam = cheetah_cavity.track(incoming_particle_beam) diff --git a/tests/test_compare_ocelot.py b/tests/test_compare_ocelot.py index d4e0c0bc..1478454e 100644 --- a/tests/test_compare_ocelot.py +++ b/tests/test_compare_ocelot.py @@ -18,9 +18,7 @@ def test_dipole(): incoming_beam = cheetah.ParticleBeam.from_astra( "tests/resources/ACHIP_EA1_2021.1351.001" ) - cheetah_dipole = cheetah.Dipole( - length=torch.tensor([0.1]), angle=torch.tensor([0.1]) - ) + cheetah_dipole = cheetah.Dipole(length=torch.tensor(0.1), angle=torch.tensor(0.1)) outgoing_beam = cheetah_dipole.track(incoming_beam) # Ocelot @@ -33,7 +31,7 @@ def test_dipole(): _, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator) assert np.allclose( - outgoing_beam.particles[0, :, :6].cpu().numpy(), + outgoing_beam.particles[:, :6].cpu().numpy(), outgoing_p_array.rparticles.transpose(), ) @@ -48,9 +46,7 @@ def test_dipole_with_float64(): "tests/resources/ACHIP_EA1_2021.1351.001", dtype=torch.float64 ) cheetah_dipole = cheetah.Dipole( - length=torch.tensor([0.1]), - angle=torch.tensor([0.1]), - dtype=torch.float64, + length=torch.tensor(0.1), angle=torch.tensor(0.1), dtype=torch.float64 ) outgoing_beam = cheetah_dipole.track(incoming_beam) @@ -64,7 +60,7 @@ def test_dipole_with_float64(): _, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator) assert np.allclose( - outgoing_beam.particles[0, :, :6].cpu().numpy(), + outgoing_beam.particles[:, :6].cpu().numpy(), outgoing_p_array.rparticles.transpose(), ) @@ -79,10 +75,10 @@ def test_dipole_with_fringe_field(): "tests/resources/ACHIP_EA1_2021.1351.001" ) cheetah_dipole = cheetah.Dipole( - length=torch.tensor([0.1]), - angle=torch.tensor([0.1]), - fringe_integral=torch.tensor([0.1]), - gap=torch.tensor([0.2]), + length=torch.tensor(0.1), + angle=torch.tensor(0.1), + fringe_integral=torch.tensor(0.1), + gap=torch.tensor(0.2), ) outgoing_beam = cheetah_dipole.track(incoming_beam) @@ -96,7 +92,7 @@ def test_dipole_with_fringe_field(): _, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator) assert np.allclose( - outgoing_beam.particles[0, :, :6].cpu().numpy(), + outgoing_beam.particles[:, :6].cpu().numpy(), outgoing_p_array.rparticles.transpose(), ) @@ -114,13 +110,13 @@ def test_dipole_with_fringe_field_and_tilt(): "tests/resources/ACHIP_EA1_2021.1351.001" ) cheetah_dipole = cheetah.Dipole( - length=torch.tensor([1.0]), - angle=torch.tensor([bend_angle]), - fringe_integral=torch.tensor([0.1]), - gap=torch.tensor([0.2]), - tilt=torch.tensor([tilt_angle]), - e1=torch.tensor([bend_angle / 2]), - e2=torch.tensor([bend_angle / 2]), + length=torch.tensor(1.0), + angle=torch.tensor(bend_angle), + fringe_integral=torch.tensor(0.1), + gap=torch.tensor(0.2), + tilt=torch.tensor(tilt_angle), + e1=torch.tensor(bend_angle / 2), + e2=torch.tensor(bend_angle / 2), ) outgoing_beam = cheetah_dipole(incoming_beam) @@ -142,7 +138,7 @@ def test_dipole_with_fringe_field_and_tilt(): _, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator) assert np.allclose( - outgoing_beam.particles[0, :, :6].cpu().numpy(), + outgoing_beam.particles[:, :6].cpu().numpy(), outgoing_p_array.rparticles.transpose(), ) @@ -159,13 +155,13 @@ def test_aperture(): cheetah_segment = cheetah.Segment( [ cheetah.Aperture( - x_max=torch.tensor([2e-4]), - y_max=torch.tensor([2e-4]), + x_max=torch.tensor(2e-4), + y_max=torch.tensor(2e-4), shape="rectangular", name="aperture", is_active=True, ), - cheetah.Drift(length=torch.tensor([0.1])), + cheetah.Drift(length=torch.tensor(0.1)), ] ) outgoing_beam = cheetah_segment.track(incoming_beam) @@ -195,13 +191,13 @@ def test_aperture_elliptical(): cheetah_segment = cheetah.Segment( [ cheetah.Aperture( - x_max=torch.tensor([2e-4]), - y_max=torch.tensor([2e-4]), + x_max=torch.tensor(2e-4), + y_max=torch.tensor(2e-4), shape="elliptical", name="aperture", is_active=True, ), - cheetah.Drift(length=torch.tensor([0.1])), + cheetah.Drift(length=torch.tensor(0.1)), ] ) outgoing_beam = cheetah_segment.track(incoming_beam) @@ -231,9 +227,7 @@ def test_solenoid(): incoming_beam = cheetah.ParticleBeam.from_astra( "tests/resources/ACHIP_EA1_2021.1351.001" ) - cheetah_solenoid = cheetah.Solenoid( - length=torch.tensor([0.5]), k=torch.tensor([5.0]) - ) + cheetah_solenoid = cheetah.Solenoid(length=torch.tensor(0.5), k=torch.tensor(5.0)) outgoing_beam = cheetah_solenoid.track(incoming_beam) # Ocelot @@ -246,7 +240,7 @@ def test_solenoid(): _, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator) assert np.allclose( - outgoing_beam.particles[0, :, :6].cpu().numpy(), + outgoing_beam.particles[..., :6].cpu().numpy(), outgoing_p_array.rparticles.transpose(), ) @@ -384,7 +378,7 @@ def test_astra_import(): p_array = ocelot.astraBeam2particleArray("tests/resources/ACHIP_EA1_2021.1351.001") assert np.allclose( - beam.particles[0, :, :6].cpu().numpy(), p_array.rparticles.transpose() + beam.particles[:, :6].cpu().numpy(), p_array.rparticles.transpose() ) assert np.isclose(beam.energy.cpu().numpy(), (p_array.E * 1e9)) @@ -399,13 +393,13 @@ def test_quadrupole(): "tests/resources/ACHIP_EA1_2021.1351.001" ) cheetah_quadrupole = cheetah.Quadrupole( - length=torch.tensor([0.23]), k1=torch.tensor([5.0]) + length=torch.tensor(0.23), k1=torch.tensor(5.0) ) cheetah_segment = cheetah.Segment( [ - cheetah.Drift(length=torch.tensor([0.1])), + cheetah.Drift(length=torch.tensor(0.1)), cheetah_quadrupole, - cheetah.Drift(length=torch.tensor([0.1])), + cheetah.Drift(length=torch.tensor(0.1)), ] ) outgoing_beam = cheetah_segment.track(incoming_beam) @@ -423,7 +417,7 @@ def test_quadrupole(): # Split in order to allow for different tolerances for each particle dimension assert np.allclose( - outgoing_beam.particles[0, :, :6].cpu().numpy(), + outgoing_beam.particles[:, :6].cpu().numpy(), outgoing_p_array.rparticles.transpose(), ) assert np.allclose( @@ -441,13 +435,13 @@ def test_tilted_quadrupole(): "tests/resources/ACHIP_EA1_2021.1351.001" ) cheetah_quadrupole = cheetah.Quadrupole( - length=torch.tensor([0.23]), k1=torch.tensor([5.0]), tilt=torch.tensor([0.79]) + length=torch.tensor(0.23), k1=torch.tensor(5.0), tilt=torch.tensor(0.79) ) cheetah_segment = cheetah.Segment( [ - cheetah.Drift(length=torch.tensor([0.1])), + cheetah.Drift(length=torch.tensor(0.1)), cheetah_quadrupole, - cheetah.Drift(length=torch.tensor([0.1])), + cheetah.Drift(length=torch.tensor(0.1)), ] ) outgoing_beam = cheetah_segment.track(incoming_beam) @@ -464,7 +458,7 @@ def test_tilted_quadrupole(): _, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator) assert np.allclose( - outgoing_beam.particles[0, :, :6].cpu().numpy(), + outgoing_beam.particles[:, :6].cpu().numpy(), outgoing_p_array.rparticles.transpose(), ) assert np.allclose( @@ -481,14 +475,12 @@ def test_sbend(): incoming_beam = cheetah.ParticleBeam.from_astra( "tests/resources/ACHIP_EA1_2021.1351.001" ) - cheetah_dipole = cheetah.Dipole( - length=torch.tensor([0.1]), angle=torch.tensor([0.2]) - ) + cheetah_dipole = cheetah.Dipole(length=torch.tensor(0.1), angle=torch.tensor(0.2)) cheetah_segment = cheetah.Segment( [ - cheetah.Drift(length=torch.tensor([0.1])), + cheetah.Drift(length=torch.tensor(0.1)), cheetah_dipole, - cheetah.Drift(length=torch.tensor([0.1])), + cheetah.Drift(length=torch.tensor(0.1)), ] ) outgoing_beam = cheetah_segment.track(incoming_beam) @@ -507,7 +499,7 @@ def test_sbend(): ) assert np.allclose( - outgoing_beam.particles[0, :, :6].cpu().numpy(), + outgoing_beam.particles[:, :6].cpu().numpy(), outgoing_p_array.rparticles.transpose(), ) assert np.allclose( @@ -525,16 +517,16 @@ def test_rbend(): "tests/resources/ACHIP_EA1_2021.1351.001" ) cheetah_dipole = cheetah.RBend( - length=torch.tensor([0.1]), - angle=torch.tensor([0.2]), - fringe_integral=torch.tensor([0.1]), - gap=torch.tensor([0.2]), + length=torch.tensor(0.1), + angle=torch.tensor(0.2), + fringe_integral=torch.tensor(0.1), + gap=torch.tensor(0.2), ) cheetah_segment = cheetah.Segment( [ - cheetah.Drift(length=torch.tensor([0.1])), + cheetah.Drift(length=torch.tensor(0.1)), cheetah_dipole, - cheetah.Drift(length=torch.tensor([0.1])), + cheetah.Drift(length=torch.tensor(0.1)), ] ) outgoing_beam = cheetah_segment.track(incoming_beam) @@ -553,7 +545,7 @@ def test_rbend(): ) assert np.allclose( - outgoing_beam.particles[0, :, :6].cpu().numpy(), + outgoing_beam.particles[:, :6].cpu().numpy(), outgoing_p_array.rparticles.transpose(), ) assert np.allclose( @@ -591,7 +583,7 @@ def test_convert_rbend(): outgoing_beam = cheetah_segment.track(incoming_beam) assert np.allclose( - outgoing_beam.particles[0, :, :6].cpu().numpy(), + outgoing_beam.particles[:, :6].cpu().numpy(), outgoing_p_array.rparticles.transpose(), ) assert np.allclose( @@ -628,7 +620,7 @@ def test_asymmetric_bend(): outgoing_beam = cheetah_segment.track(incoming_beam) assert np.allclose( - outgoing_beam.particles[0, :, :6].cpu().numpy(), + outgoing_beam.particles[:, :6].cpu().numpy(), outgoing_p_array.rparticles.transpose(), ) assert np.allclose( @@ -691,10 +683,10 @@ def test_cavity(): parray=p_array, dtype=torch.float64 ) cheetah_cavity = cheetah.Cavity( - length=torch.tensor([1.0377]), - voltage=torch.tensor([0.01815975e9]), - frequency=torch.tensor([1.3e9]), - phase=torch.tensor([0.0]), + length=torch.tensor(1.0377), + voltage=torch.tensor(0.01815975e9), + frequency=torch.tensor(1.3e9), + phase=torch.tensor(0.0), dtype=torch.float64, ) outgoing_beam = cheetah_cavity.track(incoming_beam) @@ -745,10 +737,10 @@ def test_cavity_non_zero_phase(): parray=p_array, dtype=torch.float64 ) cheetah_cavity = cheetah.Cavity( - length=torch.tensor([1.0377]), - voltage=torch.tensor([0.01815975e9]), - frequency=torch.tensor([1.3e9]), - phase=torch.tensor([30.0]), + length=torch.tensor(1.0377), + voltage=torch.tensor(0.01815975e9), + frequency=torch.tensor(1.3e9), + phase=torch.tensor(30.0), dtype=torch.float64, ) outgoing_beam = cheetah_cavity.track(incoming_beam) diff --git a/tests/test_device_dtype.py b/tests/test_device_dtype.py index 67e61205..fa3e0a45 100644 --- a/tests/test_device_dtype.py +++ b/tests/test_device_dtype.py @@ -25,7 +25,7 @@ def test_move_quadrupole_to_device(target_device: torch.device): """Test that a quadrupole magnet can be successfully moved to a different device.""" quad = cheetah.Quadrupole( - length=torch.tensor([0.2]), k1=torch.tensor([4.2]), name="my_quad" + length=torch.tensor(0.2), k1=torch.tensor(4.2), name="my_quad" ) # Test that by default the quadrupole is on the CPU @@ -49,7 +49,7 @@ def test_change_quadrupole_dtype(): Test that a quadrupole magnet can be successfully changed to a different dtype. """ quad = cheetah.Quadrupole( - length=torch.tensor([0.2]), k1=torch.tensor([4.2]), name="my_quad" + length=torch.tensor(0.2), k1=torch.tensor(4.2), name="my_quad" ) # Test that by default the quadrupole is of dtype float32 diff --git a/tests/test_differentiable.py b/tests/test_differentiable.py index 1d2eb302..c60e07dd 100644 --- a/tests/test_differentiable.py +++ b/tests/test_differentiable.py @@ -13,13 +13,13 @@ def test_simple_quadrupole(): """ segment = cheetah.Segment( [ - cheetah.Drift(length=torch.tensor([1.0])), + cheetah.Drift(length=torch.tensor(1.0)), cheetah.Quadrupole( - length=torch.tensor([0.2]), - k1=nn.Parameter(torch.tensor([3.142])), + length=torch.tensor(0.2), + k1=nn.Parameter(torch.tensor(3.142)), name="my_quad", ), - cheetah.Drift(length=torch.tensor([1.0])), + cheetah.Drift(length=torch.tensor(1.0)), ] ) incoming_beam = cheetah.ParticleBeam.from_astra( diff --git a/tests/test_dipole.py b/tests/test_dipole.py index 356f607c..08d016e4 100644 --- a/tests/test_dipole.py +++ b/tests/test_dipole.py @@ -16,15 +16,15 @@ def test_dipole_off(): """ Test that a dipole with angle=0 behaves still like a drift. """ - dipole = Dipole(length=torch.tensor([1.0]), angle=torch.tensor([0.0])) - drift = Drift(length=torch.tensor([1.0])) + dipole = Dipole(length=torch.tensor(1.0), angle=torch.tensor(0.0)) + drift = Drift(length=torch.tensor(1.0)) incoming_beam = ParameterBeam.from_parameters( - sigma_px=torch.tensor([2e-7]), sigma_py=torch.tensor([2e-7]) + sigma_px=torch.tensor(2e-7), sigma_py=torch.tensor(2e-7) ) outbeam_dipole_off = dipole(incoming_beam) outbeam_drift = drift(incoming_beam) - dipole.angle = torch.tensor([1.0], device=dipole.angle.device) + dipole.angle = torch.tensor(1.0, device=dipole.angle.device) outbeam_dipole_on = dipole(incoming_beam) assert dipole.name is not None @@ -53,33 +53,63 @@ def test_dipole_focussing(): @pytest.mark.parametrize("DipoleType", [Dipole, RBend]) -def test_dipole_batched_execution(DipoleType): +def test_dipole_vectorized_execution(DipoleType): """ - Test that a dipole with batch dimensions behaves as expected. + Test that a dipole with vector dimensions behaves as expected. """ - batch_shape = torch.Size([6]) incoming = ParticleBeam.from_parameters( - num_particles=torch.tensor(1_000_000), - energy=torch.tensor([1e9]), - mu_x=torch.tensor([1e-5]), - ).broadcast(batch_shape) + num_particles=torch.tensor(100), + energy=torch.tensor(1e9), + mu_x=torch.tensor(1e-5), + ) + + # Test vectorisation to generate 3 beam lines segment = Segment( [ DipoleType( length=torch.tensor([0.5, 0.5, 0.5]), angle=torch.tensor([0.1, 0.2, 0.1]), - ).broadcast((2,)), - Drift(length=torch.tensor([0.5])).broadcast(batch_shape), + ), + Drift(length=torch.tensor(0.5)), ] ) outgoing = segment(incoming) + assert outgoing.particles.shape == torch.Size([3, 100, 7]) + assert outgoing.mu_x.shape == torch.Size([3]) + # Check that dipole with same bend angle produce same output assert torch.allclose(outgoing.particles[0], outgoing.particles[2]) # Check different angles do make a difference assert not torch.allclose(outgoing.particles[0], outgoing.particles[1]) + # Test vectorisation to generate 18 beamlines + segment = Segment( + [ + Dipole( + length=torch.tensor([0.5, 0.5, 0.5]).reshape(3, 1), + angle=torch.tensor([0.1, 0.2, 0.1]).reshape(1, 3), + ), + Drift(length=torch.tensor([0.5, 1.0]).reshape(2, 1, 1)), + ] + ) + outgoing = segment(incoming) + assert outgoing.particles.shape == torch.Size([2, 3, 3, 100, 7]) + + # Test improper vectorisation -- this does not obey torch broadcasting rules + segment = Segment( + [ + Dipole( + length=torch.tensor([0.5, 0.5, 0.5]).reshape(3, 1), + angle=torch.tensor([0.1, 0.2, 0.1]).reshape(1, 3), + ), + Drift(length=torch.tensor([0.5, 1.0]).reshape(2, 1)), + ] + ) + with pytest.raises(RuntimeError): + segment(incoming) + @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) def test_dipole_bmadx_tracking(dtype): diff --git a/tests/test_drift.py b/tests/test_drift.py index 8c7a32c4..f0115216 100644 --- a/tests/test_drift.py +++ b/tests/test_drift.py @@ -9,9 +9,9 @@ def test_diverging_parameter_beam(): Test that that a parameter beam with sigma_px > 0 and sigma_py > 0 increases in size in both dimensions when travelling through a drift section. """ - drift = cheetah.Drift(length=torch.tensor([1.0])) + drift = cheetah.Drift(length=torch.tensor(1.0)) incoming_beam = cheetah.ParameterBeam.from_parameters( - sigma_px=torch.tensor([2e-7]), sigma_py=torch.tensor([2e-7]) + sigma_px=torch.tensor(2e-7), sigma_py=torch.tensor(2e-7) ) outgoing_beam = drift.track(incoming_beam) @@ -25,11 +25,11 @@ def test_diverging_particle_beam(): Test that that a particle beam with sigma_px > 0 and sigma_py > 0 increases in size in both dimensions when travelling through a drift section. """ - drift = cheetah.Drift(length=torch.tensor([1.0])) + drift = cheetah.Drift(length=torch.tensor(1.0)) incoming_beam = cheetah.ParticleBeam.from_parameters( - num_particles=torch.tensor(1000), - sigma_px=torch.tensor([2e-7]), - sigma_py=torch.tensor([2e-7]), + num_particles=torch.tensor(1_000), + sigma_px=torch.tensor(2e-7), + sigma_py=torch.tensor(2e-7), ) outgoing_beam = drift.track(incoming_beam) @@ -52,7 +52,7 @@ def test_device_like_torch_module(): if not torch.cuda.is_available(): return - element = cheetah.Drift(length=torch.tensor([0.2]), device="cuda") + element = cheetah.Drift(length=torch.tensor(0.2), device="cuda") assert element.length.device.type == "cuda" diff --git a/tests/test_kde.py b/tests/test_kde.py index 12ad6a64..da08b473 100644 --- a/tests/test_kde.py +++ b/tests/test_kde.py @@ -1,12 +1,14 @@ +import pytest import torch +from torch import Size from cheetah.utils import kde_histogram_1d, kde_histogram_2d def test_weighted_samples_1d(): """ - Test that the 1d KDE histogram implementation correctly handles - heterogeneously weighted samples. + Test that the 1D KDE histogram implementation correctly handles heterogeneously + weighted samples. """ x_unweighted = torch.tensor([1.0, 1.0, 1.0, 2.0]) x_weighted = torch.tensor([1.0, 2.0]) @@ -29,8 +31,8 @@ def test_weighted_samples_1d(): def test_weighted_samples_2d(): """ - Test that the 2d KDE histogram implementation correctly handles - heterogeneously weighted samples. + Test that the 2D KDE histogram implementation correctly handles heterogeneously + weighted samples. """ x_unweighted = torch.tensor([[1.0, 2.0], [1.0, 2.0], [1.0, 2.0], [2.0, 1.0]]) x_weighted = torch.tensor([[1.0, 2.0], [2.0, 1.0]]) @@ -56,8 +58,50 @@ def test_weighted_samples_2d(): hist_neglect_weights = kde_histogram_2d( x_weighted[:, 0], x_weighted[:, 1], bins1, bins2, sigma ) - print(hist_unweighted[5]) - print(hist_weighted[5]) - print(hist_neglect_weights[5]) + assert torch.allclose(hist_unweighted, hist_weighted) assert not torch.allclose(hist_weighted, hist_neglect_weights) + + +def test_kde_1d_basic_usage(): + """ + Test that basic usage of the 1D KDE histogram implementation works, and that the + output has the correct shape. + """ + data = torch.randn((5, 100)) # 5 beamline states, 100 particles in 1D + bins = torch.linspace(0, 1, 10) # A single histogram + sigma = torch.tensor(0.1) # A single bandwidth + + pdf = kde_histogram_1d(data, bins, sigma) + + assert pdf.shape == Size([5, 10]) # 5 histograms at 10 points + + +def test_kde_1d_enforce_bins_shape(): + """ + Test that the 1D KDE histogram implementation correctly enforces the shape of the + bins tensor, i.e. throws an error if the bins tensor has the wrong shape. + """ + data = torch.randn((5, 100)) # 5 beamline states, 100 particles in 1D + bins = torch.linspace(0, 1, 10) # A single histogram + + with pytest.raises(ValueError): + kde_histogram_1d(data, bins, torch.rand(3) + 0.1) + + +def test_kde_2d_vectorized_basic_usage(): + """ + Test that the 2D KDE histogram implementation correctly handles vectorized inputs, + and that the output has the correct shape. + """ + # 2 diagnostic paths, 3 states per diagnostic path, 100 particles in 6D space + data = torch.randn((3, 2, 100, 6)) + # Two different bins (1 per path) + num_bins = 30 + bins_x = torch.linspace(-20, 20, num_bins) + # A single bandwidth + sigma = torch.tensor(0.1) + + pdf = kde_histogram_2d(data[..., 0], data[..., 1], bins_x, bins_x, sigma) + + assert pdf.shape == Size([3, 2, num_bins, num_bins]) diff --git a/tests/test_parameter_beam.py b/tests/test_parameter_beam.py index 475ada93..3f71f1f0 100644 --- a/tests/test_parameter_beam.py +++ b/tests/test_parameter_beam.py @@ -9,20 +9,20 @@ def test_create_from_parameters(): Test that a `ParameterBeam` created from parameters actually has those parameters. """ beam = ParameterBeam.from_parameters( - mu_x=torch.tensor([1e-5]), - mu_px=torch.tensor([1e-7]), - mu_y=torch.tensor([2e-5]), - mu_py=torch.tensor([2e-7]), - sigma_x=torch.tensor([1.75e-7]), - sigma_px=torch.tensor([2e-7]), - sigma_y=torch.tensor([1.75e-7]), - sigma_py=torch.tensor([2e-7]), - sigma_tau=torch.tensor([0.000001]), - sigma_p=torch.tensor([0.000001]), - cor_x=torch.tensor([0.0]), - cor_y=torch.tensor([0.0]), - cor_tau=torch.tensor([0.0]), - energy=torch.tensor([1e7]), + mu_x=torch.tensor(1e-5), + mu_px=torch.tensor(1e-7), + mu_y=torch.tensor(2e-5), + mu_py=torch.tensor(2e-7), + sigma_x=torch.tensor(1.75e-7), + sigma_px=torch.tensor(2e-7), + sigma_y=torch.tensor(1.75e-7), + sigma_py=torch.tensor(2e-7), + sigma_tau=torch.tensor(0.000001), + sigma_p=torch.tensor(0.000001), + cor_x=torch.tensor(0.0), + cor_y=torch.tensor(0.0), + cor_tau=torch.tensor(0.0), + energy=torch.tensor(1e7), ) assert np.isclose(beam.mu_x.cpu().numpy(), 1e-5) @@ -45,18 +45,18 @@ def test_transform_to(): """ original_beam = ParameterBeam.from_parameters() transformed_beam = original_beam.transformed_to( - mu_x=torch.tensor([1e-5]), - mu_px=torch.tensor([1e-7]), - mu_y=torch.tensor([2e-5]), - mu_py=torch.tensor([2e-7]), - sigma_x=torch.tensor([1.75e-7]), - sigma_px=torch.tensor([2e-7]), - sigma_y=torch.tensor([1.75e-7]), - sigma_py=torch.tensor([2e-7]), - sigma_tau=torch.tensor([0.000001]), - sigma_p=torch.tensor([0.000001]), - energy=torch.tensor([1e7]), - total_charge=torch.tensor([1e-9]), + mu_x=torch.tensor(1e-5), + mu_px=torch.tensor(1e-7), + mu_y=torch.tensor(2e-5), + mu_py=torch.tensor(2e-7), + sigma_x=torch.tensor(1.75e-7), + sigma_px=torch.tensor(2e-7), + sigma_y=torch.tensor(1.75e-7), + sigma_py=torch.tensor(2e-7), + sigma_tau=torch.tensor(0.000001), + sigma_p=torch.tensor(0.000001), + energy=torch.tensor(1e7), + total_charge=torch.tensor(1e-9), ) assert isinstance(transformed_beam, ParameterBeam) @@ -80,13 +80,13 @@ def test_from_twiss_to_twiss(): parameters. """ beam = ParameterBeam.from_twiss( - beta_x=torch.tensor([5.91253676811640894]), - alpha_x=torch.tensor([3.55631307633660354]), - emittance_x=torch.tensor([3.494768647122823e-09]), - beta_y=torch.tensor([5.91253676811640982]), - alpha_y=torch.tensor([2e-7]), - emittance_y=torch.tensor([3.497810737006068e-09]), - energy=torch.tensor([6e6]), + beta_x=torch.tensor(5.91253676811640894), + alpha_x=torch.tensor(3.55631307633660354), + emittance_x=torch.tensor(3.494768647122823e-09), + beta_y=torch.tensor(5.91253676811640982), + alpha_y=torch.tensor(2e-7), + emittance_y=torch.tensor(3.497810737006068e-09), + energy=torch.tensor(6e6), ) assert np.isclose(beam.beta_x.cpu().numpy(), 5.91253676811640894) @@ -103,13 +103,13 @@ def test_from_twiss_dtype(): Test that a `ParameterBeam` created from twiss parameters has the requested `dtype`. """ beam = ParameterBeam.from_twiss( - beta_x=torch.tensor([5.91253676811640894]), - alpha_x=torch.tensor([3.55631307633660354]), - emittance_x=torch.tensor([3.494768647122823e-09]), - beta_y=torch.tensor([5.91253676811640982]), - alpha_y=torch.tensor([2e-7]), - emittance_y=torch.tensor([3.497810737006068e-09]), - energy=torch.tensor([6e6]), + beta_x=torch.tensor(5.91253676811640894), + alpha_x=torch.tensor(3.55631307633660354), + emittance_x=torch.tensor(3.494768647122823e-09), + beta_y=torch.tensor(5.91253676811640982), + alpha_y=torch.tensor(2e-7), + emittance_y=torch.tensor(3.497810737006068e-09), + energy=torch.tensor(6e6), dtype=torch.float64, ) diff --git a/tests/test_particle_beam.py b/tests/test_particle_beam.py index f22b68c6..3d3f448f 100644 --- a/tests/test_particle_beam.py +++ b/tests/test_particle_beam.py @@ -9,22 +9,22 @@ def test_create_from_parameters(): Test that a `ParticleBeam` created from parameters actually has those parameters. """ beam = ParticleBeam.from_parameters( - num_particles=torch.tensor([1_000_000]), - mu_x=torch.tensor([1e-5]), - mu_px=torch.tensor([1e-7]), - mu_y=torch.tensor([2e-5]), - mu_py=torch.tensor([2e-7]), - sigma_x=torch.tensor([1.75e-7]), - sigma_px=torch.tensor([2e-7]), - sigma_y=torch.tensor([1.75e-7]), - sigma_py=torch.tensor([2e-7]), - sigma_tau=torch.tensor([0.000001]), - sigma_p=torch.tensor([0.000001]), - cor_x=torch.tensor([0.0]), - cor_y=torch.tensor([0.0]), - cor_tau=torch.tensor([0.0]), - energy=torch.tensor([1e7]), - total_charge=torch.tensor([1e-9]), + num_particles=torch.tensor(1_000_000), + mu_x=torch.tensor(1e-5), + mu_px=torch.tensor(1e-7), + mu_y=torch.tensor(2e-5), + mu_py=torch.tensor(2e-7), + sigma_x=torch.tensor(1.75e-7), + sigma_px=torch.tensor(2e-7), + sigma_y=torch.tensor(1.75e-7), + sigma_py=torch.tensor(2e-7), + sigma_tau=torch.tensor(0.000001), + sigma_p=torch.tensor(0.000001), + cor_x=torch.tensor(0.0), + cor_y=torch.tensor(0.0), + cor_tau=torch.tensor(0.0), + energy=torch.tensor(1e7), + total_charge=torch.tensor(1e-9), ) assert beam.num_particles == 1_000_000 @@ -49,18 +49,18 @@ def test_transform_to(): """ original_beam = ParticleBeam.from_parameters() transformed_beam = original_beam.transformed_to( - mu_x=torch.tensor([1e-5]), - mu_px=torch.tensor([1e-7]), - mu_y=torch.tensor([2e-5]), - mu_py=torch.tensor([2e-7]), - sigma_x=torch.tensor([1.75e-7]), - sigma_px=torch.tensor([2e-7]), - sigma_y=torch.tensor([1.75e-7]), - sigma_py=torch.tensor([2e-7]), - sigma_tau=torch.tensor([0.000001]), - sigma_p=torch.tensor([0.000001]), - energy=torch.tensor([1e7]), - total_charge=torch.tensor([1e-9]), + mu_x=torch.tensor(1e-5), + mu_px=torch.tensor(1e-7), + mu_y=torch.tensor(2e-5), + mu_py=torch.tensor(2e-7), + sigma_x=torch.tensor(1.75e-7), + sigma_px=torch.tensor(2e-7), + sigma_y=torch.tensor(1.75e-7), + sigma_py=torch.tensor(2e-7), + sigma_tau=torch.tensor(0.000001), + sigma_p=torch.tensor(0.000001), + energy=torch.tensor(1e7), + total_charge=torch.tensor(1e-9), ) assert isinstance(transformed_beam, ParticleBeam) @@ -86,14 +86,14 @@ def test_from_twiss_to_twiss(): parameters. """ beam = ParticleBeam.from_twiss( - num_particles=torch.tensor([10_000_000]), - beta_x=torch.tensor([5.91253676811640894]), - alpha_x=torch.tensor([3.55631307633660354]), - emittance_x=torch.tensor([3.494768647122823e-09]), - beta_y=torch.tensor([5.91253676811640982]), - alpha_y=torch.tensor([1.0]), # TODO: set realistic value - emittance_y=torch.tensor([3.497810737006068e-09]), - energy=torch.tensor([6e6]), + num_particles=torch.tensor(10_000_000), + beta_x=torch.tensor(5.91253676811640894), + alpha_x=torch.tensor(3.55631307633660354), + emittance_x=torch.tensor(3.494768647122823e-09), + beta_y=torch.tensor(5.91253676811640982), + alpha_y=torch.tensor(1.0), # TODO: set realistic value + emittance_y=torch.tensor(3.497810737006068e-09), + energy=torch.tensor(6e6), ) # rather loose rtol is needed here due to the random sampling of the beam assert np.isclose(beam.beta_x.cpu().numpy(), 5.91253676811640894, rtol=1e-2) @@ -105,7 +105,7 @@ def test_from_twiss_to_twiss(): assert np.isclose(beam.energy.cpu().numpy(), 6e6) -def test_generate_uniform_ellipsoid_batched(): +def test_generate_uniform_ellipsoid_vectorized(): """ Test that a `ParticleBeam` generated from a uniform 3D ellipsoid has the correct parameters, i.e. the all particles are within the ellipsoid, and that the other diff --git a/tests/test_quadrupole.py b/tests/test_quadrupole.py index 0dc599a8..dbdedac3 100644 --- a/tests/test_quadrupole.py +++ b/tests/test_quadrupole.py @@ -8,37 +8,36 @@ def test_quadrupole_off(): """ Test that a quadrupole with k1=0 behaves still like a drift. """ - quadrupole = Quadrupole(length=torch.tensor([1.0]), k1=torch.tensor([0.0])) - drift = Drift(length=torch.tensor([1.0])) + quadrupole = Quadrupole(length=torch.tensor(1.0), k1=torch.tensor(0.0)) + drift = Drift(length=torch.tensor(1.0)) incoming_beam = ParameterBeam.from_parameters( - sigma_px=torch.tensor([2e-7]), sigma_py=torch.tensor([2e-7]) + sigma_px=torch.tensor(2e-7), sigma_py=torch.tensor(2e-7) ) outbeam_quad = quadrupole(incoming_beam) outbeam_drift = drift(incoming_beam) - quadrupole.k1 = torch.tensor([1.0], device=quadrupole.k1.device) + quadrupole.k1 = torch.tensor(1.0, device=quadrupole.k1.device) outbeam_quad_on = quadrupole(incoming_beam) assert torch.allclose(outbeam_quad.sigma_x, outbeam_drift.sigma_x) assert not torch.allclose(outbeam_quad_on.sigma_x, outbeam_drift.sigma_x) -def test_quadrupole_with_misalignments_batched(): +def test_quadrupole_with_misalignments_vectorized(): """ Test that a quadrupole with misalignments behaves as expected. """ - quad_with_misalignment = Quadrupole( - length=torch.tensor([1.0]), - k1=torch.tensor([1.0]), - misalignment=torch.tensor([[0.1, 0.1]]), + length=torch.tensor(1.0), + k1=torch.tensor(1.0), + misalignment=torch.tensor([0.1, 0.1]).unsqueeze(0), ) quad_without_misalignment = Quadrupole( - length=torch.tensor([1.0]), k1=torch.tensor([1.0]) + length=torch.tensor(1.0), k1=torch.tensor(1.0) ) incoming_beam = ParameterBeam.from_parameters( - sigma_px=torch.tensor([2e-7]), sigma_py=torch.tensor([2e-7]) + sigma_px=torch.tensor(2e-7), sigma_py=torch.tensor(2e-7) ) outbeam_quad_with_misalignment = quad_with_misalignment(incoming_beam) outbeam_quad_without_misalignment = quad_without_misalignment(incoming_beam) @@ -49,89 +48,125 @@ def test_quadrupole_with_misalignments_batched(): ) -def test_quadrupole_with_misalignments_multiple_batch_dimension(): +def test_quadrupole_with_misalignments_multiple_vector_dimensions(): """ - Test that a quadrupole with misalignments with multiple batch dimension. + Test that a quadrupole with misalignments that have multiple vector dimensions does + not raise an error and behaves as expected. """ - batch_shape = torch.Size([4, 3]) quad_with_misalignment = Quadrupole( - length=torch.tensor([1.0]), - k1=torch.tensor([1.0]), - misalignment=torch.tensor([[0.1, 0.1]]), - ).broadcast(batch_shape) - + length=torch.tensor(1.0), + k1=torch.tensor(1.0), + misalignment=torch.randn((4, 3, 2)) * 5e-4, + ) quad_without_misalignment = Quadrupole( - length=torch.tensor([1.0]), k1=torch.tensor([1.0]) - ).broadcast(batch_shape) - incoming_beam = ParameterBeam.from_parameters( - sigma_px=torch.tensor([2e-7]), sigma_py=torch.tensor([2e-7]) - ).broadcast(batch_shape) - outbeam_quad_with_misalignment = quad_with_misalignment(incoming_beam) - outbeam_quad_without_misalignment = quad_without_misalignment(incoming_beam) + length=torch.tensor(1.0), k1=torch.tensor(1.0) + ) + + incoming = ParameterBeam.from_parameters( + sigma_px=torch.tensor(2e-7), sigma_py=torch.tensor(2e-7) + ) + + outgoing_with_misalignment = quad_with_misalignment(incoming) + outgoing_without_misalignment = quad_without_misalignment(incoming) # Check that the misalignment has an effect assert not torch.allclose( - outbeam_quad_with_misalignment.mu_x, - outbeam_quad_without_misalignment.mu_x, + outgoing_with_misalignment.mu_x, outgoing_without_misalignment.mu_x ) # Check that the output shape is correct - assert outbeam_quad_with_misalignment.mu_x.shape == batch_shape + assert outgoing_with_misalignment.mu_x.shape == (4, 3) -def test_tilted_quadrupole_batch(): +def test_tilted_quadrupole_vectorized(): """ Test that a quadrupole with a tilt behaves as expected in vectorised mode. """ - batch_shape = torch.Size([3]) incoming = ParticleBeam.from_parameters( - num_particles=torch.tensor(1000000), - energy=torch.tensor([1e9]), - mu_x=torch.tensor([1e-5]), - ).broadcast(batch_shape) + num_particles=torch.tensor(1_000_000), + energy=torch.tensor(1e9), + mu_x=torch.tensor(1e-5), + ) segment = Segment( [ Quadrupole( - length=torch.tensor([0.5, 0.5, 0.5]), - k1=torch.tensor([1.0, 1.0, 1.0]), + length=torch.tensor(0.5), + k1=torch.tensor(1.0), tilt=torch.tensor([torch.pi / 4, torch.pi / 2, torch.pi * 5 / 4]), ), - Drift(length=torch.tensor([0.5])).broadcast(batch_shape), + Drift(length=torch.tensor(0.5)), ] ) outgoing = segment(incoming) - # Check pi/4 and 5/4*pi rotations is the same for quadrupole + # Check that pi/4 and 5/4*pi rotations is the same for quadrupole assert torch.allclose(outgoing.particles[0], outgoing.particles[2]) - # Check pi/2 rotation is different + # Check that pi/2 rotation is different assert not torch.allclose(outgoing.particles[0], outgoing.particles[1]) -def test_tilted_quadrupole_multiple_batch_dimension(): +def test_tilted_quadrupole_multiple_vector_dimensions(): """ - Test that a quadrupole with a tilt behaves as expected in vectorised mode with - multiple vectorisation dimensions. + Test that a quadrupole with tilts that have multiple vectorisation dimensions does + not raise an error and behaves as expected. """ - batch_shape = torch.Size([3, 2]) - incoming = ParticleBeam.from_parameters( - num_particles=torch.tensor(10000), - energy=torch.tensor([1e9]), - mu_x=torch.tensor([1e-5]), - ).broadcast(batch_shape) segment = Segment( [ Quadrupole( - length=torch.tensor([0.5]), - k1=torch.tensor([1.0]), - tilt=torch.tensor([torch.pi / 4]), + length=torch.tensor(0.5), + k1=torch.tensor(1.0), + tilt=torch.tensor( + [ + [torch.pi / 4, torch.pi / 2, torch.pi * 5 / 4], + [torch.pi * 5 / 4, torch.pi / 2, torch.pi / 4], + ] + ), ), - Drift(length=torch.tensor([0.5])), + Drift(length=torch.tensor(0.5)), + ] + ) + + incoming = ParticleBeam.from_parameters( + num_particles=torch.tensor(10_000), + energy=torch.tensor(1e9), + mu_x=torch.tensor(1e-5), + ) + + outgoing = segment(incoming) + + # Test that shape is correct + assert outgoing.particles.shape == (2, 3, 10_000, 7) + + # Check that same tilts give same results + assert torch.allclose(outgoing.particles[0, 0], outgoing.particles[1, 2]) + assert torch.allclose(outgoing.particles[0, 1], outgoing.particles[1, 1]) + assert torch.allclose(outgoing.particles[0, 2], outgoing.particles[1, 0]) + + +def test_quadrupole_length_multiple_vector_dimensions(): + """ + Test that a quadrupole with lengths that have multiple vectorisation dimensions does + not raise an error and behaves as expected. + """ + lengths = torch.tensor([[0.2, 0.3, 0.4], [0.5, 0.4, 0.7]]) + segment = Segment( + [ + Quadrupole(length=lengths, k1=torch.tensor(4.2)), + Drift(length=lengths * 2), ] - ).broadcast(batch_shape) + ) + + incoming = ParticleBeam.from_parameters( + num_particles=torch.tensor(10_000), + energy=torch.tensor(1e9), + mu_x=torch.tensor(1e-5), + ) + outgoing = segment(incoming) - assert torch.allclose(outgoing.particles[0, 0], outgoing.particles[0, 1]) + assert outgoing.particles.shape == (2, 3, 10_000, 7) + assert torch.allclose(outgoing.particles[0, 2], outgoing.particles[1, 1]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) @@ -144,10 +179,10 @@ def test_quadrupole_bmadx_tracking(dtype): dtype ) quadrupole = Quadrupole( - length=torch.tensor([1.0]), - k1=torch.tensor([10.0]), - misalignment=torch.tensor([[0.01, -0.02]], dtype=dtype), - tilt=torch.tensor([0.5]), + length=torch.tensor(1.0), + k1=torch.tensor(10.0), + misalignment=torch.tensor([0.01, -0.02], dtype=dtype), + tilt=torch.tensor(0.5), num_steps=10, tracking_method="bmadx", dtype=dtype, @@ -165,7 +200,7 @@ def test_quadrupole_bmadx_tracking(dtype): assert torch.allclose( outgoing.particles, outgoing_bmadx.to(dtype), - atol=1e-14 if dtype == torch.float64 else 0.00001, + atol=1e-14 if dtype == torch.float64 else 1e-5, rtol=1e-14 if dtype == torch.float64 else 1e-6, ) @@ -179,8 +214,8 @@ def test_tracking_method_vectorization(tracking_method): quadrupole = Quadrupole( length=torch.tensor([[0.2, 0.25], [0.3, 0.35], [0.4, 0.45]]), k1=torch.tensor([[4.2, 4.2], [4.3, 4.3], [4.4, 4.4]]), - misalignment=torch.zeros((3, 2, 2)), - tilt=torch.zeros((3, 2)), + misalignment=torch.zeros(2), + tilt=torch.tensor(0.0), tracking_method=tracking_method, ) incoming = ParticleBeam.from_parameters( @@ -199,5 +234,5 @@ def test_tracking_method_vectorization(tracking_method): assert outgoing.sigma_py.shape == (3, 2) assert outgoing.sigma_tau.shape == (3, 2) assert outgoing.sigma_p.shape == (3, 2) - assert outgoing.energy.shape == (3, 2) - assert outgoing.total_charge.shape == (3, 2) + assert outgoing.energy.shape == torch.Size([]) + assert outgoing.total_charge.shape == torch.Size([]) diff --git a/tests/test_reading_nx_tables.py b/tests/test_reading_nx_tables.py index 5a119610..3689f65b 100644 --- a/tests/test_reading_nx_tables.py +++ b/tests/test_reading_nx_tables.py @@ -1,3 +1,5 @@ +import torch + import cheetah @@ -22,4 +24,4 @@ def test_length(): """ segment = cheetah.Segment.from_nx_tables("tests/resources/Stage4v3_9.txt") - assert segment.length == 44.2215 + assert torch.allclose(segment.length, torch.tensor(44.2215)) diff --git a/tests/test_screen.py b/tests/test_screen.py index 85383a8f..0a1a43c3 100644 --- a/tests/test_screen.py +++ b/tests/test_screen.py @@ -14,7 +14,7 @@ def test_reading_shows_beam_particle(screen_method): """ segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([1.0])), + cheetah.Drift(length=torch.tensor(1.0)), cheetah.Screen( resolution=torch.tensor((100, 100)), pixel_size=torch.tensor((1e-5, 1e-5)), @@ -27,13 +27,13 @@ def test_reading_shows_beam_particle(screen_method): beam = cheetah.ParticleBeam.from_astra("tests/resources/ACHIP_EA1_2021.1351.001") assert isinstance(segment.my_screen.reading, torch.Tensor) - assert segment.my_screen.reading.shape == (1, 100, 100) + assert segment.my_screen.reading.shape == (100, 100) assert np.allclose(segment.my_screen.reading, 0.0) _ = segment.track(beam) assert isinstance(segment.my_screen.reading, torch.Tensor) - assert segment.my_screen.reading.shape == (1, 100, 100) + assert segment.my_screen.reading.shape == (100, 100) assert torch.all(segment.my_screen.reading >= 0.0) assert torch.any(segment.my_screen.reading > 0.0) @@ -44,7 +44,7 @@ def test_screen_kde_bandwidth(kde_bandwidth): segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([1.0])), + cheetah.Drift(length=torch.tensor(1.0)), cheetah.Screen( resolution=torch.tensor((100, 100)), pixel_size=torch.tensor((1e-5, 1e-5)), @@ -58,13 +58,13 @@ def test_screen_kde_bandwidth(kde_bandwidth): beam = cheetah.ParticleBeam.from_astra("tests/resources/ACHIP_EA1_2021.1351.001") assert isinstance(segment.my_screen.reading, torch.Tensor) - assert segment.my_screen.reading.shape == (1, 100, 100) + assert segment.my_screen.reading.shape == (100, 100) assert np.allclose(segment.my_screen.reading, 0.0) _ = segment.track(beam) assert isinstance(segment.my_screen.reading, torch.Tensor) - assert segment.my_screen.reading.shape == (1, 100, 100) + assert segment.my_screen.reading.shape == (100, 100) assert torch.all(segment.my_screen.reading >= 0.0) assert torch.any(segment.my_screen.reading > 0.0) @@ -76,7 +76,7 @@ def test_reading_shows_beam_parameter(screen_method): """ segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([1.0])), + cheetah.Drift(length=torch.tensor(1.0)), cheetah.Screen( resolution=torch.tensor((100, 100)), pixel_size=torch.tensor((1e-5, 1e-5)), @@ -90,13 +90,13 @@ def test_reading_shows_beam_parameter(screen_method): beam = cheetah.ParameterBeam.from_astra("tests/resources/ACHIP_EA1_2021.1351.001") assert isinstance(segment.my_screen.reading, torch.Tensor) - assert segment.my_screen.reading.shape == (1, 100, 100) + assert segment.my_screen.reading.shape == (100, 100) assert np.allclose(segment.my_screen.reading, 0.0) _ = segment.track(beam) assert isinstance(segment.my_screen.reading, torch.Tensor) - assert segment.my_screen.reading.shape == (1, 100, 100) + assert segment.my_screen.reading.shape == (100, 100) assert torch.all(segment.my_screen.reading >= 0.0) assert torch.any(segment.my_screen.reading > 0.0) @@ -125,12 +125,12 @@ def test_reading_shows_beam_ares(screen_method): segment.AREABSCR1.is_active = True assert isinstance(segment.AREABSCR1.reading, torch.Tensor) - assert segment.AREABSCR1.reading.shape == (1, 2040, 2448) + assert segment.AREABSCR1.reading.shape == (2040, 2448) assert np.allclose(segment.AREABSCR1.reading, 0.0) _ = segment.track(beam) assert isinstance(segment.AREABSCR1.reading, torch.Tensor) - assert segment.AREABSCR1.reading.shape == (1, 2040, 2448) + assert segment.AREABSCR1.reading.shape == (2040, 2448) assert torch.all(segment.AREABSCR1.reading >= 0.0) assert torch.any(segment.AREABSCR1.reading > 0.0) diff --git a/tests/test_space_charge_kick.py b/tests/test_space_charge_kick.py index 2eab7a07..8d91cb32 100644 --- a/tests/test_space_charge_kick.py +++ b/tests/test_space_charge_kick.py @@ -19,8 +19,8 @@ def test_cold_uniform_beam_expansion(): torch.manual_seed(42) # Simulation parameters - R0 = torch.tensor([0.001]) - energy = torch.tensor([2.5e8]) + R0 = torch.tensor(0.001) + energy = torch.tensor(2.5e8) rest_energy = torch.tensor( constants.electron_mass * constants.speed_of_light**2 @@ -33,14 +33,14 @@ def test_cold_uniform_beam_expansion(): incoming = cheetah.ParticleBeam.uniform_3d_ellipsoid( num_particles=torch.tensor(10_000), - total_charge=torch.tensor([1e-9]), + total_charge=torch.tensor(1e-9), energy=energy, radius_x=R0, radius_y=R0, radius_tau=R0 / gamma, # Radius of the beam in s direction in the lab frame - sigma_px=torch.tensor([1e-15]), - sigma_py=torch.tensor([1e-15]), - sigma_p=torch.tensor([1e-15]), + sigma_px=torch.tensor(1e-15), + sigma_py=torch.tensor(1e-15), + sigma_p=torch.tensor(1e-15), ) # Compute section length @@ -74,9 +74,9 @@ def test_vectorized(): """ # Simulation parameters - section_length = torch.tensor([0.42]) - R0 = torch.tensor([0.001]) - energy = torch.tensor([2.5e8]) + section_length = torch.tensor(0.42) + R0 = torch.tensor(0.001) + energy = torch.tensor(2.5e8) rest_energy = torch.tensor( constants.electron_mass * constants.speed_of_light**2 @@ -87,15 +87,14 @@ def test_vectorized(): incoming = cheetah.ParticleBeam.uniform_3d_ellipsoid( num_particles=torch.tensor(10_000), total_charge=torch.tensor([[1e-9, 2e-9], [3e-9, 4e-9], [5e-9, 6e-9]]), - energy=energy.repeat(3, 2), - radius_x=R0.repeat(3, 2), - radius_y=R0.repeat(3, 2), - radius_tau=(R0 / gamma).repeat( - 3, 2 - ), # Radius of the beam in s direction in the lab frame - sigma_px=torch.tensor([1e-15]).repeat(3, 2), - sigma_py=torch.tensor([1e-15]).repeat(3, 2), - sigma_p=torch.tensor([1e-15]).repeat(3, 2), + energy=energy.expand([3, 2]), + radius_x=R0.expand([3, 2]), + radius_y=R0.expand([3, 2]), + radius_tau=R0.expand([3, 2]) / gamma, + # Radius of the beam in s direction in the lab frame + sigma_px=torch.tensor(1e-15).expand([3, 2]), + sigma_py=torch.tensor(1e-15).expand([3, 2]), + sigma_p=torch.tensor(1e-15).expand([3, 2]), ) segment = cheetah.Segment( @@ -108,7 +107,7 @@ def test_vectorized(): cheetah.SpaceChargeKick(section_length / 3), cheetah.Drift(section_length / 6), ] - ).broadcast(shape=(3, 2)) + ) outgoing = segment.track(incoming) @@ -125,8 +124,8 @@ def test_vectorized_cold_uniform_beam_expansion(): torch.manual_seed(42) # Simulation parameters - R0 = torch.tensor([0.001]) - energy = torch.tensor([2.5e8]) + R0 = torch.tensor(0.001) + energy = torch.tensor(2.5e8) rest_energy = torch.tensor( constants.electron_mass * constants.speed_of_light**2 @@ -139,15 +138,15 @@ def test_vectorized_cold_uniform_beam_expansion(): incoming = cheetah.ParticleBeam.uniform_3d_ellipsoid( num_particles=torch.tensor(10_000), - total_charge=torch.tensor([1e-9]), + total_charge=torch.tensor(1e-9), energy=energy, radius_x=R0, radius_y=R0, radius_tau=R0 / gamma, # Radius of the beam in s direction in the lab frame - sigma_px=torch.tensor([1e-15]), - sigma_py=torch.tensor([1e-15]), - sigma_p=torch.tensor([1e-15]), - ).broadcast(shape=(2, 3)) + sigma_px=torch.tensor(1e-15), + sigma_py=torch.tensor(1e-15), + sigma_p=torch.tensor(1e-15), + ) # Compute section length kappa = 1 + (torch.sqrt(torch.tensor(2)) / 4) * torch.log( @@ -184,13 +183,13 @@ def test_incoming_beam_not_modified(): incoming_beam = cheetah.ParticleBeam.from_parameters( num_particles=torch.tensor(10_000), - sigma_px=torch.tensor([2e-7]), - sigma_py=torch.tensor([2e-7]), + sigma_px=torch.tensor(2e-7), + sigma_py=torch.tensor(2e-7), ) # Initial beam properties incoming_beam_before = incoming_beam.particles - section_length = torch.tensor([1.0]) + section_length = torch.tensor(1.0) segment_space_charge = cheetah.Segment( elements=[ cheetah.Drift(section_length / 6), @@ -216,12 +215,12 @@ def test_gradient(): Tests that the gradient of the track method is computed withouth throwing an error. """ incoming_beam = cheetah.ParticleBeam.from_parameters( - num_particles=torch.tensor([10_000]), - sigma_px=torch.tensor([2e-7]), - sigma_py=torch.tensor([2e-7]), + num_particles=torch.tensor(10_000), + sigma_px=torch.tensor(2e-7), + sigma_py=torch.tensor(2e-7), ) - segment_length = nn.Parameter(torch.tensor([1.0])) + segment_length = nn.Parameter(torch.tensor(1.0)) segment = cheetah.Segment( elements=[ cheetah.Drift(segment_length / 6), @@ -246,7 +245,7 @@ def test_does_not_break_segment_length(): Test that the computation of a `Segment`'s length does not break when `SpaceChargeKick` is used. """ - section_length = torch.tensor([1.0]) + section_length = torch.tensor(1.0) segment = cheetah.Segment( elements=[ cheetah.Drift(section_length / 6), @@ -257,7 +256,21 @@ def test_does_not_break_segment_length(): cheetah.SpaceChargeKick(section_length / 3), cheetah.Drift(section_length / 6), ] - ).broadcast(shape=(3, 2)) + ) + + assert segment.length.shape == torch.Size([]) + assert torch.allclose(segment.length, torch.tensor(1.0)) + + +def test_space_charge_with_ares_astra_beam(): + """ + Tests running space charge through a 1m drift with an Astra beam from the ARES + linac. This test is added because running this code would throw an error: + `IndexError: index -38 is out of bounds for dimension 3 with size 32`. + """ + segment = cheetah.Segment( + [cheetah.Drift(length=1.0), cheetah.SpaceChargeKick(effect_length=1.0)] + ) + beam = cheetah.ParticleBeam.from_astra("tests/resources/ACHIP_EA1_2021.1351.001") - assert segment.length.shape == (3, 2) - assert torch.allclose(segment.length, torch.tensor([1.0]).repeat(3, 2)) + _ = segment.track(beam) diff --git a/tests/test_speed.py b/tests/test_speed.py index c78a9dee..917f9550 100644 --- a/tests/test_speed.py +++ b/tests/test_speed.py @@ -18,8 +18,8 @@ def test_tracking_speed(): particles = cheetah.ParticleBeam.from_parameters( num_particles=torch.tensor(int(1e5)), - sigma_x=torch.tensor([175e-6]), - sigma_y=torch.tensor([175e-6]), + sigma_x=torch.tensor(175e-6), + sigma_y=torch.tensor(175e-6), ) t1 = time.time() diff --git a/tests/test_speed_optimizations.py b/tests/test_speed_optimizations.py index 25b1c0b4..a6d96edc 100644 --- a/tests/test_speed_optimizations.py +++ b/tests/test_speed_optimizations.py @@ -15,11 +15,11 @@ def test_merged_transfer_maps_tracking(): original_segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([0.6])), - cheetah.Quadrupole(length=torch.tensor([0.2]), k1=torch.tensor([4.2])), - cheetah.Drift(length=torch.tensor([0.4])), + cheetah.Drift(length=torch.tensor(0.6)), + cheetah.Quadrupole(length=torch.tensor(0.2), k1=torch.tensor(4.2)), + cheetah.Drift(length=torch.tensor(0.4)), cheetah.HorizontalCorrector( - length=torch.tensor([0.1]), angle=torch.tensor([1e-4]) + length=torch.tensor(0.1), angle=torch.tensor(1e-4) ), ] ) @@ -49,18 +49,18 @@ def test_merged_transfer_maps_tracking_vectorized(): """ incoming_beam = cheetah.ParameterBeam.from_astra( "tests/resources/ACHIP_EA1_2021.1351.001" - ).broadcast((10,)) + ) original_segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([0.6])), - cheetah.Quadrupole(length=torch.tensor([0.2]), k1=torch.tensor([4.2])), - cheetah.Drift(length=torch.tensor([0.4])), + cheetah.Drift(length=torch.tensor(0.6)), + cheetah.Quadrupole(length=torch.tensor(0.2), k1=torch.tensor(4.2)), + cheetah.Drift(length=torch.linspace(0.3, 0.5, 10)), cheetah.HorizontalCorrector( - length=torch.tensor([0.1]), angle=torch.tensor([1e-4]) + length=torch.tensor(0.1), angle=torch.tensor(1e-4) ), ] - ).broadcast((10,)) + ) merged_segment = original_segment.transfer_maps_merged(incoming_beam=incoming_beam) original_beam = original_segment.track(incoming_beam) @@ -90,11 +90,11 @@ def test_merged_transfer_maps_num_elements(): original_segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([0.6])), - cheetah.Quadrupole(length=torch.tensor([0.2]), k1=torch.tensor([4.2])), - cheetah.Drift(length=torch.tensor([0.4])), + cheetah.Drift(length=torch.tensor(0.6)), + cheetah.Quadrupole(length=torch.tensor(0.2), k1=torch.tensor(4.2)), + cheetah.Drift(length=torch.tensor(0.4)), cheetah.HorizontalCorrector( - length=torch.tensor([0.1]), angle=torch.tensor([1e-4]) + length=torch.tensor(0.1), angle=torch.tensor(1e-4) ), ] ) @@ -110,12 +110,12 @@ def test_no_markers_left_after_removal(): """ segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([0.6])), - cheetah.Quadrupole(length=torch.tensor([0.2]), k1=torch.tensor([4.2])), + cheetah.Drift(length=torch.tensor(0.6)), + cheetah.Quadrupole(length=torch.tensor(0.2), k1=torch.tensor(4.2)), cheetah.Marker(), - cheetah.Drift(length=torch.tensor([0.4])), + cheetah.Drift(length=torch.tensor(0.4)), cheetah.HorizontalCorrector( - length=torch.tensor([0.1]), angle=torch.tensor([1e-4]) + length=torch.tensor(0.1), angle=torch.tensor(1e-4) ), cheetah.Marker(), ] @@ -133,9 +133,9 @@ def test_inactive_magnet_is_replaced_by_drift(): """ segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([0.6])), - cheetah.Quadrupole(length=torch.tensor([0.2]), k1=torch.tensor([0.0])), - cheetah.Drift(length=torch.tensor([0.4])), + cheetah.Drift(length=torch.tensor(0.6)), + cheetah.Quadrupole(length=torch.tensor(0.2), k1=torch.tensor(0.0)), + cheetah.Drift(length=torch.tensor(0.4)), ] ) @@ -152,9 +152,9 @@ def test_active_elements_not_replaced_by_drift(): """ segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([0.6])), - cheetah.Quadrupole(length=torch.tensor([0.2]), k1=torch.tensor([4.2])), - cheetah.Drift(length=torch.tensor([0.4])), + cheetah.Drift(length=torch.tensor(0.6)), + cheetah.Quadrupole(length=torch.tensor(0.2), k1=torch.tensor(4.2)), + cheetah.Drift(length=torch.tensor(0.4)), ] ) @@ -171,11 +171,11 @@ def test_inactive_magnet_drift_replacement_dtype(dtype: torch.dtype): """ segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([0.6]), dtype=dtype), + cheetah.Drift(length=torch.tensor(0.6), dtype=dtype), cheetah.Quadrupole( - length=torch.tensor([0.2]), k1=torch.tensor([0.0]), dtype=dtype + length=torch.tensor(0.2), k1=torch.tensor(0.0), dtype=dtype ), - cheetah.Drift(length=torch.tensor([0.4]), dtype=dtype), + cheetah.Drift(length=torch.tensor(0.4), dtype=dtype), ] ) @@ -194,15 +194,15 @@ def test_skippable_elements_reset(): ) original_segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([0.6])), + cheetah.Drift(length=torch.tensor(0.6)), cheetah.Quadrupole( - length=torch.tensor([0.2]), k1=torch.tensor([4.2]), name="Q1" + length=torch.tensor(0.2), k1=torch.tensor(4.2), name="Q1" ), - cheetah.Drift(length=torch.tensor([0.4])), + cheetah.Drift(length=torch.tensor(0.4)), cheetah.HorizontalCorrector( - length=torch.tensor([0.1]), angle=torch.tensor([1e-4]), name="HCOR_1" + length=torch.tensor(0.1), angle=torch.tensor(1e-4), name="HCOR_1" ), - cheetah.Drift(length=torch.tensor([0.4])), + cheetah.Drift(length=torch.tensor(0.4)), ] ) diff --git a/tests/test_split.py b/tests/test_split.py index c5ee1aed..e2d20e96 100644 --- a/tests/test_split.py +++ b/tests/test_split.py @@ -14,7 +14,7 @@ def test_drift_end(): incoming_beam = cheetah.ParticleBeam.from_astra( "tests/resources/ACHIP_EA1_2021.1351.001" - ).broadcast((2,)) + ) outgoing_beam_original = original_drift.track(incoming_beam) outgoing_beam_split = split_drift.track(incoming_beam) @@ -38,7 +38,7 @@ def test_quadrupole_end(): incoming_beam = cheetah.ParticleBeam.from_astra( "tests/resources/ACHIP_EA1_2021.1351.001" - ).broadcast((2,)) + ) outgoing_beam_original = original_quadrupole.track(incoming_beam) outgoing_beam_split = split_quadrupole.track(incoming_beam) @@ -63,7 +63,7 @@ def test_cavity_end(): incoming_beam = cheetah.ParticleBeam.from_astra( "tests/resources/ACHIP_EA1_2021.1351.001" - ).broadcast((2,)) + ) outgoing_beam_original = original_cavity.track(incoming_beam) outgoing_beam_split = split_cavity.track(incoming_beam) @@ -87,7 +87,7 @@ def test_solenoid_end(): incoming_beam = cheetah.ParticleBeam.from_astra( "tests/resources/ACHIP_EA1_2021.1351.001" - ).broadcast((2,)) + ) outgoing_beam_original = original_solenoid.track(incoming_beam) outgoing_beam_split = split_solenoid.track(incoming_beam) @@ -102,6 +102,7 @@ def test_dipole_end(): Test that at the end of a split dipole the result is the same as at the end of the original dipole. """ + original_dipole = cheetah.Dipole( length=torch.tensor([0.2, 0.3]), angle=torch.tensor([4.2, 3.6]) ) @@ -109,7 +110,7 @@ def test_dipole_end(): incoming_beam = cheetah.ParticleBeam.from_astra( "tests/resources/ACHIP_EA1_2021.1351.001" - ).broadcast((2,)) + ) outgoing_beam_original = original_dipole.track(incoming_beam) outgoing_beam_split = split_dipole.track(incoming_beam) @@ -131,7 +132,7 @@ def test_undulator_end(): incoming_beam = cheetah.ParticleBeam.from_astra( "tests/resources/ACHIP_EA1_2021.1351.001" - ).broadcast((2,)) + ) outgoing_beam_original = original_undulator.track(incoming_beam) outgoing_beam_split = split_undulator.track(incoming_beam) @@ -156,7 +157,7 @@ def test_horizontal_corrector_end(): incoming_beam = cheetah.ParticleBeam.from_astra( "tests/resources/ACHIP_EA1_2021.1351.001" - ).broadcast((2,)) + ) outgoing_beam_original = original_horizontal_corrector.track(incoming_beam) outgoing_beam_split = split_horizontal_corrector.track(incoming_beam) @@ -181,7 +182,7 @@ def test_vertical_corrector_end(): incoming_beam = cheetah.ParticleBeam.from_astra( "tests/resources/ACHIP_EA1_2021.1351.001" - ).broadcast((2,)) + ) outgoing_beam_original = original_vertical_corrector.track(incoming_beam) outgoing_beam_split = split_vertical_corrector.track(incoming_beam) @@ -209,7 +210,7 @@ def test_split_preserves_dtype(ElementType): """ Test that the dtype of a drift section's splits is the same as the original drift. """ - original = ElementType(length=torch.tensor([2.0]), dtype=torch.float64) + original = ElementType(length=torch.tensor(2.0), dtype=torch.float64) splits = original.split(resolution=torch.tensor(0.1)) for split in splits: diff --git a/tests/test_tracking_lengthless_elements.py b/tests/test_tracking_lengthless_elements.py index f4bde4fb..d5a738b7 100644 --- a/tests/test_tracking_lengthless_elements.py +++ b/tests/test_tracking_lengthless_elements.py @@ -19,11 +19,11 @@ def test_tracking_lengthless_elements(): segment = cheetah.Segment( [ cheetah.Cavity( - length=torch.tensor([0.1]), voltage=torch.tensor([1e6]), name="C2" + length=torch.tensor(0.1), voltage=torch.tensor(1e6), name="C2" ), cheetah.Marker(name="start"), cheetah.Cavity( - length=torch.tensor([0.1]), voltage=torch.tensor([1e6]), name="C1" + length=torch.tensor(0.1), voltage=torch.tensor(1e6), name="C1" ), ] ) diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index 1b525507..87ebd03e 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -1,3 +1,4 @@ +import pytest import torch import cheetah @@ -22,7 +23,7 @@ def test_segment_length_shape(): def test_segment_length_shape_2d(): """ - Test that the shape of a segment's length matches the input for a batch with + Test that the shape of a segment's length matches the input for a vectorisation with multiple dimensions. """ segment = cheetah.Segment( @@ -39,21 +40,21 @@ def test_segment_length_shape_2d(): assert segment.length.shape == (3, 2) -def test_track_particle_single_element_shape(): +@pytest.mark.parametrize("BeamClass", [cheetah.ParticleBeam, cheetah.ParameterBeam]) +def test_track_quadrupole_shape(BeamClass): """ - Test that the shape of a beam tracked through a single element matches the input. + Test that the shape of a beam tracked through a single quadrupole element matches + the input. """ quadrupole = cheetah.Quadrupole( length=torch.tensor([0.2, 0.25]), k1=torch.tensor([4.2, 4.2]) ) - incoming = cheetah.ParticleBeam.from_parameters( - num_particles=100_000, sigma_x=torch.tensor([1e-5, 2e-5]) - ) + incoming = BeamClass.from_parameters(sigma_x=torch.tensor([1e-5, 2e-5])) outgoing = quadrupole.track(incoming) - assert outgoing.particles.shape == incoming.particles.shape - assert outgoing.particles.shape == (2, 100_000, 7) + if BeamClass == cheetah.ParticleBeam: + assert outgoing.particles.shape == (2, 100_000, 7) assert outgoing.mu_x.shape == (2,) assert outgoing.mu_px.shape == (2,) assert outgoing.mu_y.shape == (2,) @@ -64,30 +65,30 @@ def test_track_particle_single_element_shape(): assert outgoing.sigma_py.shape == (2,) assert outgoing.sigma_tau.shape == (2,) assert outgoing.sigma_p.shape == (2,) - assert outgoing.energy.shape == (2,) - assert outgoing.total_charge.shape == (2,) - assert outgoing.particle_charges.shape == (2, 100_000) - assert isinstance(outgoing.num_particles, int) + assert outgoing.energy.shape == torch.Size([]) + assert outgoing.total_charge.shape == torch.Size([]) + if BeamClass == cheetah.ParticleBeam: + assert outgoing.particle_charges.shape == (100_000,) -def test_track_particle_single_element_shape_2d(): +@pytest.mark.parametrize("BeamClass", [cheetah.ParticleBeam, cheetah.ParameterBeam]) +def test_track_quadrupole_shape_2d(BeamClass): """ - Test that the shape of a beam tracked through a single element matches the input for - an n-dimensional batch. + Test that the shape of a beam tracked through a single quadrupole element matches + the input for an n-dimensional batch. """ quadrupole = cheetah.Quadrupole( length=torch.tensor([[0.2, 0.25], [0.3, 0.35], [0.4, 0.45]]), k1=torch.tensor([[4.2, 4.2], [4.3, 4.3], [4.4, 4.4]]), ) - incoming = cheetah.ParticleBeam.from_parameters( - num_particles=100_000, - sigma_x=torch.tensor([[1e-5, 2e-5], [2e-5, 3e-5], [3e-5, 4e-5]]), + incoming = BeamClass.from_parameters( + sigma_x=torch.tensor([[1e-5, 2e-5], [2e-5, 3e-5], [3e-5, 4e-5]]) ) outgoing = quadrupole.track(incoming) - assert outgoing.particles.shape == incoming.particles.shape - assert outgoing.particles.shape == (3, 2, 100_000, 7) + if BeamClass == cheetah.ParticleBeam: + assert outgoing.particles.shape == (3, 2, 100_000, 7) assert outgoing.mu_x.shape == (3, 2) assert outgoing.mu_px.shape == (3, 2) assert outgoing.mu_y.shape == (3, 2) @@ -98,13 +99,14 @@ def test_track_particle_single_element_shape_2d(): assert outgoing.sigma_py.shape == (3, 2) assert outgoing.sigma_tau.shape == (3, 2) assert outgoing.sigma_p.shape == (3, 2) - assert outgoing.energy.shape == (3, 2) - assert outgoing.total_charge.shape == (3, 2) - assert outgoing.particle_charges.shape == (3, 2, 100_000) - assert isinstance(outgoing.num_particles, int) + assert outgoing.energy.shape == torch.Size([]) + assert outgoing.total_charge.shape == torch.Size([]) + if BeamClass == cheetah.ParticleBeam: + assert outgoing.particle_charges.shape == (100_000,) -def test_track_particle_segment_shape(): +@pytest.mark.parametrize("BeamClass", [cheetah.ParticleBeam, cheetah.ParameterBeam]) +def test_track_segment_shape(BeamClass): """ Test that the shape of a beam tracked through a segment matches the input. """ @@ -117,14 +119,12 @@ def test_track_particle_segment_shape(): cheetah.Drift(length=torch.tensor([0.4, 0.3])), ] ) - incoming = cheetah.ParticleBeam.from_parameters( - num_particles=100_000, sigma_x=torch.tensor([1e-5, 2e-5]) - ) + incoming = BeamClass.from_parameters(sigma_x=torch.tensor([1e-5, 2e-5])) outgoing = segment.track(incoming) - assert outgoing.particles.shape == incoming.particles.shape - assert outgoing.particles.shape == (2, 100_000, 7) + if BeamClass == cheetah.ParticleBeam: + assert outgoing.particles.shape == (2, 100_000, 7) assert outgoing.mu_x.shape == (2,) assert outgoing.mu_px.shape == (2,) assert outgoing.mu_y.shape == (2,) @@ -135,16 +135,17 @@ def test_track_particle_segment_shape(): assert outgoing.sigma_py.shape == (2,) assert outgoing.sigma_tau.shape == (2,) assert outgoing.sigma_p.shape == (2,) - assert outgoing.energy.shape == (2,) - assert outgoing.total_charge.shape == (2,) - assert outgoing.particle_charges.shape == (2, 100_000) - assert isinstance(outgoing.num_particles, int) + assert outgoing.energy.shape == torch.Size([]) + assert outgoing.total_charge.shape == torch.Size([]) + if BeamClass == cheetah.ParticleBeam: + assert outgoing.particle_charges.shape == (100_000,) -def test_track_particle_segment_shape_2d(): +@pytest.mark.parametrize("BeamClass", [cheetah.ParticleBeam, cheetah.ParameterBeam]) +def test_track_particle_segment_shape_2d(BeamClass): """ - Test that the shape of a beam tracked through a segment matches the input for the - case of a multi-dimensional batch. + Test that the shape of a particle beam tracked through a segment matches the input + for the case of a multi-dimensional batch. """ segment = cheetah.Segment( elements=[ @@ -156,15 +157,14 @@ def test_track_particle_segment_shape_2d(): cheetah.Drift(length=torch.tensor([[0.4, 0.3], [0.6, 0.5], [0.8, 0.7]])), ] ) - incoming = cheetah.ParticleBeam.from_parameters( - num_particles=100_000, - sigma_x=torch.tensor([[1e-5, 2e-5], [2e-5, 3e-5], [3e-5, 4e-5]]), + incoming = BeamClass.from_parameters( + sigma_x=torch.tensor([[1e-5, 2e-5], [2e-5, 3e-5], [3e-5, 4e-5]]) ) outgoing = segment.track(incoming) - assert outgoing.particles.shape == incoming.particles.shape - assert outgoing.particles.shape == (3, 2, 100_000, 7) + if BeamClass == cheetah.ParticleBeam: + assert outgoing.particles.shape == (3, 2, 100_000, 7) assert outgoing.mu_x.shape == (3, 2) assert outgoing.mu_px.shape == (3, 2) assert outgoing.mu_y.shape == (3, 2) @@ -175,23 +175,85 @@ def test_track_particle_segment_shape_2d(): assert outgoing.sigma_py.shape == (3, 2) assert outgoing.sigma_tau.shape == (3, 2) assert outgoing.sigma_p.shape == (3, 2) - assert outgoing.energy.shape == (3, 2) - assert outgoing.total_charge.shape == (3, 2) - assert outgoing.particle_charges.shape == (3, 2, 100_000) - assert isinstance(outgoing.num_particles, int) + assert outgoing.energy.shape == torch.Size([]) + assert outgoing.total_charge.shape == torch.Size([]) + if BeamClass == cheetah.ParticleBeam: + assert outgoing.particle_charges.shape == (100_000,) -def test_track_parameter_single_element_shape(): +def test_enormous_through_ares_ea(): """ - Test that the shape of a beam tracked through a single element matches the input. + Test ARES EA with a huge number of settings. This is a stress test and only run + for `ParameterBeam` because `ParticleBeam` would require a lot of memory. """ - quadrupole = cheetah.Quadrupole( - length=torch.tensor([0.2, 0.25]), k1=torch.tensor([4.2, 4.2]) + segment = cheetah.Segment.from_ocelot(ares.cell).subcell("AREASOLA1", "AREABSCR1") + incoming = cheetah.ParameterBeam.from_astra( + "tests/resources/ACHIP_EA1_2021.1351.001" ) - incoming = cheetah.ParameterBeam.from_parameters(sigma_x=torch.tensor([1e-5, 2e-5])) - outgoing = quadrupole.track(incoming) + segment.AREAMQZM1.k1 = torch.linspace(-30.0, 30.0, 200_000).repeat(3, 1) + outgoing = segment.track(incoming) + + assert outgoing.mu_x.shape == (3, 200_000) + assert outgoing.mu_px.shape == (3, 200_000) + assert outgoing.mu_y.shape == (3, 200_000) + assert outgoing.mu_py.shape == (3, 200_000) + assert outgoing.sigma_x.shape == (3, 200_000) + assert outgoing.sigma_px.shape == (3, 200_000) + assert outgoing.sigma_y.shape == (3, 200_000) + assert outgoing.sigma_py.shape == (3, 200_000) + assert outgoing.sigma_tau.shape == (3, 200_000) + assert outgoing.sigma_p.shape == (3, 200_000) + assert outgoing.energy.shape == torch.Size([]) + assert outgoing.total_charge.shape == torch.Size([]) + + +@pytest.mark.parametrize("BeamClass", [cheetah.ParticleBeam, cheetah.ParameterBeam]) +def test_cavity_with_zero_and_non_zero_voltage(BeamClass): + """ + Tests that if zero and non-zero voltages are passed to a cavity in a single batch, + there are no errors. This test does NOT check physical correctness. + """ + cavity = cheetah.Cavity( + length=torch.tensor(3.0441), + voltage=torch.tensor([0.0, 48_198_468.0, 0.0]), + phase=torch.tensor(48198468.0), + frequency=torch.tensor(2.8560e09), + name="my_test_cavity", + ) + incoming = BeamClass.from_parameters(sigma_x=torch.tensor(1e-5)) + + outgoing = cavity.track(incoming) + + if BeamClass == cheetah.ParticleBeam: + assert outgoing.particles.shape == (3, 100_000, 7) + assert outgoing.mu_x.shape == (3,) + assert outgoing.mu_px.shape == (3,) + assert outgoing.mu_y.shape == (3,) + assert outgoing.mu_py.shape == (3,) + assert outgoing.sigma_x.shape == (3,) + assert outgoing.sigma_px.shape == (3,) + assert outgoing.sigma_y.shape == (3,) + assert outgoing.sigma_py.shape == (3,) + assert outgoing.sigma_tau.shape == (3,) + assert outgoing.sigma_p.shape == (3,) + assert outgoing.energy.shape == (3,) + assert outgoing.total_charge.shape == torch.Size([]) + if BeamClass == cheetah.ParticleBeam: + assert outgoing.particle_charges.shape == (100_000,) + + +@pytest.mark.parametrize("BeamClass", [cheetah.ParticleBeam, cheetah.ParameterBeam]) +def test_vectorized_undulator(BeamClass): + """Test that a vectorized `Undulator` is able to track a particle beam.""" + element = cheetah.Undulator(length=torch.tensor([0.4, 0.7])) + incoming = BeamClass.from_parameters(sigma_x=torch.tensor(1e-5)) + + outgoing = element.track(incoming) + + if BeamClass == cheetah.ParticleBeam: + assert outgoing.particles.shape == (2, 100_000, 7) assert outgoing.mu_x.shape == (2,) assert outgoing.mu_px.shape == (2,) assert outgoing.mu_y.shape == (2,) @@ -202,56 +264,24 @@ def test_track_parameter_single_element_shape(): assert outgoing.sigma_py.shape == (2,) assert outgoing.sigma_tau.shape == (2,) assert outgoing.sigma_p.shape == (2,) - assert outgoing.energy.shape == (2,) - assert outgoing.total_charge.shape == (2,) - - -def test_track_parameter_single_element_shape_2d(): - """ - Test that the shape of a beam tracked through a single element matches the input for - an n-dimensional batch. - """ - quadrupole = cheetah.Quadrupole( - length=torch.tensor([[0.2, 0.25], [0.3, 0.35], [0.4, 0.45]]), - k1=torch.tensor([[4.2, 4.2], [4.3, 4.3], [4.4, 4.4]]), - ) - incoming = cheetah.ParameterBeam.from_parameters( - sigma_x=torch.tensor([[1e-5, 2e-5], [2e-5, 3e-5], [3e-5, 4e-5]]) - ) + assert outgoing.energy.shape == torch.Size([]) + assert outgoing.total_charge.shape == torch.Size([]) + if BeamClass == cheetah.ParticleBeam: + assert outgoing.particle_charges.shape == (100_000,) - outgoing = quadrupole.track(incoming) - assert outgoing.mu_x.shape == (3, 2) - assert outgoing.mu_px.shape == (3, 2) - assert outgoing.mu_y.shape == (3, 2) - assert outgoing.mu_py.shape == (3, 2) - assert outgoing.sigma_x.shape == (3, 2) - assert outgoing.sigma_px.shape == (3, 2) - assert outgoing.sigma_y.shape == (3, 2) - assert outgoing.sigma_py.shape == (3, 2) - assert outgoing.sigma_tau.shape == (3, 2) - assert outgoing.sigma_p.shape == (3, 2) - assert outgoing.energy.shape == (3, 2) - assert outgoing.total_charge.shape == (3, 2) - - -def test_track_parameter_segment_shape(): - """ - Test that the shape of a beam tracked through a segment matches the input. - """ - segment = cheetah.Segment( - elements=[ - cheetah.Drift(length=torch.tensor([0.6, 0.5])), - cheetah.Quadrupole( - length=torch.tensor([0.2, 0.25]), k1=torch.tensor([4.2, 4.2]) - ), - cheetah.Drift(length=torch.tensor([0.4, 0.3])), - ] +@pytest.mark.parametrize("BeamClass", [cheetah.ParticleBeam, cheetah.ParameterBeam]) +def test_vectorized_solenoid(BeamClass): + """Test that a vectorized `Solenoid` is able to track a particle beam.""" + element = cheetah.Solenoid( + length=torch.tensor([0.4, 0.7]), k=torch.tensor([4.2, 3.1]) ) - incoming = cheetah.ParameterBeam.from_parameters(sigma_x=torch.tensor([1e-5, 2e-5])) + incoming = BeamClass.from_parameters(sigma_x=torch.tensor(1e-5)) - outgoing = segment.track(incoming) + outgoing = element.track(incoming) + if BeamClass == cheetah.ParticleBeam: + assert outgoing.particles.shape == (2, 100_000, 7) assert outgoing.mu_x.shape == (2,) assert outgoing.mu_px.shape == (2,) assert outgoing.mu_y.shape == (2,) @@ -262,252 +292,116 @@ def test_track_parameter_segment_shape(): assert outgoing.sigma_py.shape == (2,) assert outgoing.sigma_tau.shape == (2,) assert outgoing.sigma_p.shape == (2,) - assert outgoing.energy.shape == (2,) - assert outgoing.total_charge.shape == (2,) + assert outgoing.energy.shape == torch.Size([]) + assert outgoing.total_charge.shape == torch.Size([]) + if BeamClass == cheetah.ParticleBeam: + assert outgoing.particle_charges.shape == (100_000,) -def test_track_parameter_segment_shape_2d(): +@pytest.mark.parametrize("BeamClass", [cheetah.ParticleBeam]) +@pytest.mark.parametrize("method", ["kde"]) # Currently only KDE supports vectorisation +def test_vectorized_screen_2d(BeamClass, method): """ - Test that the shape of a beam tracked through a segment matches the input for the - case of a multi-dimensional batch. + Test that a vectorized `Screen` is able to track a particle beam and produce a + reading with 2 vector dimensions. """ segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]])), - cheetah.Quadrupole( - length=torch.tensor([[0.2, 0.25], [0.3, 0.35], [0.4, 0.45]]), - k1=torch.tensor([[4.2, 4.2], [4.3, 4.3], [4.4, 4.4]]), + cheetah.Drift(length=torch.tensor(1.0)), + cheetah.Screen( + resolution=torch.tensor((100, 100)), + pixel_size=torch.tensor((1e-5, 1e-5)), + misalignment=torch.tensor( + [ + [[1e-4, 2e-4], [3e-4, 4e-4], [5e-4, 6e-4]], + [[-1e-4, -2e-4], [-3e-4, -4e-4], [-5e-4, -6e-4]], + ] + ), + is_active=True, + method=method, + name="my_screen", ), - cheetah.Drift(length=torch.tensor([[0.4, 0.3], [0.6, 0.5], [0.8, 0.7]])), - ] - ) - incoming = cheetah.ParameterBeam.from_parameters( - sigma_x=torch.tensor([[1e-5, 2e-5], [2e-5, 3e-5], [3e-5, 4e-5]]) + ], + name="my_segment", ) - - outgoing = segment.track(incoming) - - assert outgoing.mu_x.shape == (3, 2) - assert outgoing.mu_px.shape == (3, 2) - assert outgoing.mu_y.shape == (3, 2) - assert outgoing.mu_py.shape == (3, 2) - assert outgoing.sigma_x.shape == (3, 2) - assert outgoing.sigma_px.shape == (3, 2) - assert outgoing.sigma_y.shape == (3, 2) - assert outgoing.sigma_py.shape == (3, 2) - assert outgoing.sigma_tau.shape == (3, 2) - assert outgoing.sigma_p.shape == (3, 2) - assert outgoing.energy.shape == (3, 2) - assert outgoing.total_charge.shape == (3, 2) - - -def test_enormous_through_ares(): - """Test ARES EA with a huge number of settings.""" - segment = cheetah.Segment.from_ocelot(ares.cell).subcell("AREASOLA1", "AREABSCR1") - incoming = cheetah.ParameterBeam.from_astra( - "tests/resources/ACHIP_EA1_2021.1351.001" - ) - - segment_broadcast = segment.broadcast((3, 100_000)) - incoming_broadcast = incoming.broadcast((3, 100_000)) - - segment_broadcast.AREAMQZM1.k1 = torch.linspace(-30.0, 30.0, 100_000).repeat(3, 1) - - outgoing = segment_broadcast.track(incoming_broadcast) - - assert outgoing.mu_x.shape == (3, 100_000) - assert outgoing.mu_px.shape == (3, 100_000) - assert outgoing.mu_y.shape == (3, 100_000) - assert outgoing.mu_py.shape == (3, 100_000) - assert outgoing.sigma_x.shape == (3, 100_000) - assert outgoing.sigma_px.shape == (3, 100_000) - assert outgoing.sigma_y.shape == (3, 100_000) - assert outgoing.sigma_py.shape == (3, 100_000) - assert outgoing.sigma_tau.shape == (3, 100_000) - assert outgoing.sigma_p.shape == (3, 100_000) - assert outgoing.energy.shape == (3, 100_000) - assert outgoing.total_charge.shape == (3, 100_000) - - -def test_before_after_broadcast_tracking_equal_cavity(): + incoming = BeamClass.from_parameters(sigma_x=torch.tensor(1e-5)) + + _ = segment.track(incoming) + + # Check the reading + assert segment.my_screen.reading.shape == (2, 3, 100, 100) + + +@pytest.mark.parametrize( + "ElementClass", + [ + cheetah.Cavity, + cheetah.Dipole, + cheetah.Drift, + cheetah.HorizontalCorrector, + cheetah.Quadrupole, + cheetah.RBend, + cheetah.Solenoid, + cheetah.TransverseDeflectingCavity, + cheetah.Undulator, + cheetah.VerticalCorrector, + ], +) +def test_drift_broadcasting_two_different_inputs(ElementClass): """ - Test that when tracking through a segment after broadcasting, the resulting beam is - the same as in the segment before broadcasting. A cavity is used as a reference. + Test that broadcasting rules are correctly applied to a elements with two different + input shapes for elements that have a `length` attribute. """ - cavity = cheetah.Cavity( - length=torch.tensor([3.0441]), - voltage=torch.tensor([48198468.0]), - phase=torch.tensor([-0.0]), - frequency=torch.tensor([2.8560e09]), - name="k26_2d", - ) - incoming = cheetah.ParameterBeam.from_astra( - "tests/resources/ACHIP_EA1_2021.1351.001" + incoming = cheetah.ParticleBeam.from_parameters( + num_particles=100_000, energy=torch.tensor([154e6, 14e9]) ) - outgoing = cavity.track(incoming) + element = ElementClass(length=torch.tensor([[0.6], [0.5], [0.4]])) - broadcast_cavity = cavity.broadcast((3, 10)) - broadcast_incoming = incoming.broadcast((3, 10)) - broadcast_outgoing = broadcast_cavity.track(broadcast_incoming) + outgoing = element.track(incoming) - for i in range(3): - for j in range(10): - assert torch.all(broadcast_outgoing._mu[i, j] == outgoing._mu[0]) - assert torch.all(broadcast_outgoing._cov[i, j] == outgoing._cov[0]) + assert outgoing.particles.shape == (3, 2, 100_000, 7) + assert outgoing.particle_charges.shape == (100_000,) + assert outgoing.energy.shape == (2,) -def test_before_after_broadcast_tracking_equal_ares_ea(): +@pytest.mark.parametrize( + "ElementClass", + [ + cheetah.Dipole, + cheetah.Drift, + cheetah.Quadrupole, + cheetah.TransverseDeflectingCavity, + ], +) +def test_drift_broadcasting_two_different_inputs_bmadx(ElementClass): """ - Test that when tracking through a segment after broadcasting, the resulting beam is - the same as in the segment before broadcasting. The ARES EA is used as a reference. + Test that broadcasting rules are correctly applied to a elements with two different + input shapes for elements that have a `"bmadx"` tracking method. """ - segment = cheetah.Segment.from_ocelot(ares.cell).subcell("AREASOLA1", "AREABSCR1") - incoming = cheetah.ParameterBeam.from_astra( - "tests/resources/ACHIP_EA1_2021.1351.001" - ) - segment.AREAMQZM1.k1 = torch.tensor([4.2]) - outgoing = segment.track(incoming) - - broadcast_segment = segment.broadcast((3, 10)) - broadcast_incoming = incoming.broadcast((3, 10)) - broadcast_outgoing = broadcast_segment.track(broadcast_incoming) - - for i in range(3): - for j in range(10): - assert torch.allclose(broadcast_outgoing._mu[i, j], outgoing._mu[0]) - assert torch.allclose(broadcast_outgoing._cov[i, j], outgoing._cov[0]) - - -def test_broadcast_customtransfermap(): - """Test that broadcasting a `CustomTransferMap` element gives the correct result.""" - tm = torch.tensor( - [ - [ - [1.0, 4.0e-02, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0e-05], - [0.0, 0.0, 1.0, 4.0e-02, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0, -4.6422e-07, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], - ] - ] - ) - - element = cheetah.CustomTransferMap(length=torch.tensor([0.4]), transfer_map=tm) - broadcast_element = element.broadcast((3, 10)) - - assert broadcast_element.length.shape == (3, 10) - assert broadcast_element._transfer_map.shape == (3, 10, 7, 7) - for i in range(3): - for j in range(10): - assert torch.all(broadcast_element._transfer_map[i, j] == tm[0]) - - -def test_broadcast_element_keeps_dtype(): - """Test that broadcasting an element keeps the same dtype.""" - element = cheetah.Drift(length=torch.tensor([0.4]), dtype=torch.float64) - broadcast_element = element.broadcast((3, 10)) - - assert broadcast_element.length.dtype == torch.float64 - - -def test_broadcast_beam_keeps_dtype(): - """Test that broadcasting a beam keeps the same dtype.""" - beam = cheetah.ParticleBeam.from_parameters( - num_particles=100_000, sigma_x=torch.tensor([1e-5]), dtype=torch.float64 + incoming = cheetah.ParticleBeam.from_parameters( + num_particles=100_000, energy=torch.tensor([154e6, 14e9]) ) - broadcast_beam = beam.broadcast((2,)) - drift = cheetah.Drift(length=torch.tensor([0.4, 0.4]), dtype=torch.float64) - - assert broadcast_beam.particles.dtype == torch.float64 - - # This should not raise an error - _ = drift(broadcast_beam) - - -def test_broadcast_drift(): - """Test that broadcasting a `Drift` element gives the correct result.""" - element = cheetah.Drift(length=torch.tensor([0.4])) - broadcast_element = element.broadcast((3, 10)) - - assert broadcast_element.length.shape == (3, 10) - for i in range(3): - for j in range(10): - assert broadcast_element.length[i, j] == 0.4 - - -def test_broadcast_quadrupole(): - """Test that broadcasting a `Quadrupole` element gives the correct result.""" - - # TODO Add misalignment to the test - # TODO Add tilt to the test - - element = cheetah.Quadrupole(length=torch.tensor([0.4]), k1=torch.tensor([4.2])) - broadcast_element = element.broadcast((3, 10)) - - assert broadcast_element.length.shape == (3, 10) - assert broadcast_element.k1.shape == (3, 10) - for i in range(3): - for j in range(10): - assert broadcast_element.length[i, j] == 0.4 - assert broadcast_element.k1[i, j] == 4.2 - - -def test_cavity_with_zero_and_non_zero_voltage(): - """ - Tests that if zero and non-zero voltages are passed to a cavity in a single batch, - there are no errors. This test does NOT check physical correctness. - """ - cavity = cheetah.Cavity( - length=torch.tensor([3.0441, 3.0441, 3.0441]), - voltage=torch.tensor([0.0, 48198468.0, 0.0]), - phase=torch.tensor([48198468.0, 48198468.0, 48198468.0]), - frequency=torch.tensor([2.8560e09, 2.8560e09, 2.8560e09]), - name="my_test_cavity", + element = ElementClass( + tracking_method="bmadx", length=torch.tensor([[0.6], [0.5], [0.4]]) ) - beam = cheetah.ParticleBeam.from_parameters( - num_particles=100_000, sigma_x=torch.tensor([1e-5]) - ).broadcast((3,)) - _ = cavity.track(beam) + outgoing = element.track(incoming) - -def test_screen_length_shape(): - """ - Test that the shape of a screen's length matches the shape of its misalignment. - """ - screen = cheetah.Screen(misalignment=torch.tensor([[0.1, 0.2], [0.3, 0.4]])) - assert screen.length.shape == screen.misalignment.shape[:-1] + assert outgoing.particles.shape == (3, 2, 100_000, 7) + assert outgoing.particle_charges.shape == (100_000,) + assert outgoing.energy.shape == (2,) -def test_screen_length_broadcast_shape(): +def test_vectorized_parameter_beam_creation(): """ - Test that the shape of a screen's length matches the shape of its misalignment - after broadcasting. + Tests that creating a parameter beam with a few vectorised parameters works as + expected. """ - screen = cheetah.Screen(misalignment=torch.tensor([[0.1, 0.2]])) - broadcast_screen = screen.broadcast((3, 10)) - assert broadcast_screen.length.shape == broadcast_screen.misalignment.shape[:-1] - - -def test_vectorized_undulator(): - """Test that a vectorized `Undulator` is able to track a particle beam.""" - element = cheetah.Undulator(length=torch.tensor([0.4, 0.7])) - beam = cheetah.ParticleBeam.from_parameters( - num_particles=100_000, sigma_x=torch.tensor([1e-5]) - ).broadcast((2,)) - - _ = element.track(beam) - - -def test_vectorized_solenoid(): - """Test that a vectorized `Solenoid` is able to track a particle beam.""" - element = cheetah.Solenoid( - length=torch.tensor([0.4, 0.7]), k=torch.tensor([4.2, 3.1]) + beam = cheetah.ParameterBeam.from_parameters( + mu_x=torch.tensor([2e-4, 3e-4]), sigma_x=torch.tensor([1e-5, 2e-5]) ) - beam = cheetah.ParticleBeam.from_parameters( - num_particles=100_000, sigma_x=torch.tensor([1e-5]) - ).broadcast((2,)) - _ = element.track(beam) + assert beam.mu_x.shape == (2,) + assert torch.allclose(beam.mu_x, torch.tensor([2e-4, 3e-4])) + assert beam.sigma_x.shape == (2,) + assert torch.allclose(beam.sigma_x, torch.tensor([1e-5, 2e-5]))