From 404456a7a1b654d958a5cb4ca5f8a056e0ef41d7 Mon Sep 17 00:00:00 2001 From: Ryan Roussel Date: Tue, 2 Jul 2024 11:53:45 -0500 Subject: [PATCH 001/111] update track methods to enable automatic broadcasting for drifts/quads --- cheetah/track_methods.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cheetah/track_methods.py b/cheetah/track_methods.py index f7166b70..98a06ca7 100644 --- a/cheetah/track_methods.py +++ b/cheetah/track_methods.py @@ -72,8 +72,8 @@ def base_rmatrix( ky = torch.sqrt(torch.complex(ky2, torch.tensor(0.0, device=device, dtype=dtype))) cx = torch.cos(kx * length).real cy = torch.cos(ky * length).real - sy = torch.clone(length) - sy[ky != 0] = (torch.sin(ky[ky != 0] * length[ky != 0]) / ky[ky != 0]).real + sy = torch.ones_like(cx) * length + sy[ky != 0] = (torch.sin(ky[ky != 0] * length) / ky[ky != 0]).real sx = (torch.sin(kx * length) / kx).real dx = hx / kx2 * (1.0 - cx) @@ -81,7 +81,7 @@ def base_rmatrix( r56 = r56 - length / beta**2 * igamma2 - R = torch.eye(7, dtype=dtype, device=device).repeat(*length.shape, 1, 1) + R = torch.eye(7, dtype=dtype, device=device).repeat(*cx.shape, 1, 1) R[..., 0, 0] = cx R[..., 0, 1] = sx R[..., 0, 5] = dx / beta From 849421c0fcdd1346e3115e2d606845a488f80495 Mon Sep 17 00:00:00 2001 From: Ryan Roussel Date: Tue, 2 Jul 2024 13:44:41 -0500 Subject: [PATCH 002/111] update vectorize calcs and tests --- .gitignore | 2 ++ cheetah/track_methods.py | 4 +--- tests/test_quadrupole.py | 22 ++++++++++++---------- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/.gitignore b/.gitignore index 414db649..4ad82c3d 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,5 @@ build distributions docs/_build + +.idea \ No newline at end of file diff --git a/cheetah/track_methods.py b/cheetah/track_methods.py index 98a06ca7..e1501660 100644 --- a/cheetah/track_methods.py +++ b/cheetah/track_methods.py @@ -72,9 +72,7 @@ def base_rmatrix( ky = torch.sqrt(torch.complex(ky2, torch.tensor(0.0, device=device, dtype=dtype))) cx = torch.cos(kx * length).real cy = torch.cos(ky * length).real - sy = torch.ones_like(cx) * length - sy[ky != 0] = (torch.sin(ky[ky != 0] * length) / ky[ky != 0]).real - + sy = (torch.sin(ky * length) / ky).real sx = (torch.sin(kx * length) / kx).real dx = hx / kx2 * (1.0 - cx) r56 = hx**2 * (length - sx) / kx2 / beta**2 diff --git a/tests/test_quadrupole.py b/tests/test_quadrupole.py index 9477f5b7..d9e7d3fc 100644 --- a/tests/test_quadrupole.py +++ b/tests/test_quadrupole.py @@ -52,19 +52,20 @@ def test_quadrupole_with_misalignments_multiple_batch_dimension(): """ Test that a quadrupole with misalignments with multiple batch dimension. """ - batch_shape = torch.Size([4, 3]) + + misalignments = torch.randn((4, 3, 2)) quad_with_misalignment = Quadrupole( length=torch.tensor([1.0]), k1=torch.tensor([1.0]), - misalignment=torch.tensor([[0.1, 0.1]]), - ).broadcast(batch_shape) + misalignment=misalignments, + ) quad_without_misalignment = Quadrupole( length=torch.tensor([1.0]), k1=torch.tensor([1.0]) - ).broadcast(batch_shape) + ) incoming_beam = ParameterBeam.from_parameters( sigma_xp=torch.tensor([2e-7]), sigma_yp=torch.tensor([2e-7]) - ).broadcast(batch_shape) + ) outbeam_quad_with_misalignment = quad_with_misalignment(incoming_beam) outbeam_quad_without_misalignment = quad_without_misalignment(incoming_beam) @@ -75,7 +76,7 @@ def test_quadrupole_with_misalignments_multiple_batch_dimension(): ) # Check that the output shape is correct - assert outbeam_quad_with_misalignment.mu_x.shape == batch_shape + assert outbeam_quad_with_misalignment.mu_x.shape == misalignments.shape[:-1] def test_tilted_quadrupole_batch(): @@ -105,12 +106,11 @@ def test_tilted_quadrupole_batch(): def test_tilted_quadrupole_multiple_batch_dimension(): - batch_shape = torch.Size([3, 2]) incoming = ParticleBeam.from_parameters( num_particles=torch.tensor(10000), energy=torch.tensor([1e9]), mu_x=torch.tensor([1e-5]), - ).broadcast(batch_shape) + ) segment = Segment( [ Quadrupole( @@ -120,7 +120,9 @@ def test_tilted_quadrupole_multiple_batch_dimension(): ), Drift(length=torch.tensor([0.5])), ] - ).broadcast(batch_shape) + ) outgoing = segment(incoming) - assert torch.allclose(outgoing.particles[0, 0], outgoing.particles[0, 1]) + assert torch.allclose( + outgoing.particles[0, 0], outgoing.particles[0, 1], rtol=1e-1, atol=1e-5 + ) From 078ebb47f78031997a12a410bffd58bb74512546 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Tue, 9 Jul 2024 12:35:07 +0200 Subject: [PATCH 003/111] Add test that breaks current automatic broadcasting idea --- tests/test_quadrupole.py | 64 ++++++++++++++++++++++++++++++++-------- 1 file changed, 51 insertions(+), 13 deletions(-) diff --git a/tests/test_quadrupole.py b/tests/test_quadrupole.py index d9e7d3fc..a36d1c5b 100644 --- a/tests/test_quadrupole.py +++ b/tests/test_quadrupole.py @@ -48,9 +48,10 @@ def test_quadrupole_with_misalignments_batched(): ) -def test_quadrupole_with_misalignments_multiple_batch_dimension(): +def test_quadrupole_with_misalignments_multiple_batch_dimensions(): """ - Test that a quadrupole with misalignments with multiple batch dimension. + Test that a quadrupole with misalignments that have multiple batch dimensions does + not raise an error and behaves as expected. """ misalignments = torch.randn((4, 3, 2)) @@ -80,9 +81,12 @@ def test_quadrupole_with_misalignments_multiple_batch_dimension(): def test_tilted_quadrupole_batch(): + """ + Test that a quadrupole with a multiple tilts behaves as expected. + """ batch_shape = torch.Size([3]) incoming = ParticleBeam.from_parameters( - num_particles=torch.tensor(1000000), + num_particles=torch.tensor(1_000_000), energy=torch.tensor([1e9]), mu_x=torch.tensor([1e-5]), ).broadcast(batch_shape) @@ -105,24 +109,58 @@ def test_tilted_quadrupole_batch(): assert not torch.allclose(outgoing.particles[0], outgoing.particles[1]) -def test_tilted_quadrupole_multiple_batch_dimension(): - incoming = ParticleBeam.from_parameters( - num_particles=torch.tensor(10000), - energy=torch.tensor([1e9]), - mu_x=torch.tensor([1e-5]), +# TODO Change batched to vectorised +def test_tilted_quadrupole_multiple_batch_dimensions(): + """ + Test that a quadrupole with tilts that have multiple vectorisation dimensions does + not raise an error and behaves as expected. + """ + tilts = torch.tensor( + [ + [torch.pi / 4, torch.pi / 2, torch.pi * 5 / 4], + [torch.pi * 5 / 4, torch.pi / 2, torch.pi / 4], + ] ) segment = Segment( [ - Quadrupole( - length=torch.tensor([0.5]), - k1=torch.tensor([1.0]), - tilt=torch.tensor([torch.pi / 4]), - ), + Quadrupole(length=torch.tensor([0.5]), k1=torch.tensor([1.0]), tilt=tilts), Drift(length=torch.tensor([0.5])), ] ) + + incoming = ParticleBeam.from_parameters( + num_particles=torch.tensor(10_000), + energy=torch.tensor([1e9]), + mu_x=torch.tensor([1e-5]), + ) + outgoing = segment(incoming) assert torch.allclose( outgoing.particles[0, 0], outgoing.particles[0, 1], rtol=1e-1, atol=1e-5 ) + assert outgoing.particles.shape == (2, 3, 10_000, 7) + + +def test_quadrupole_length_multiple_batch_dimensions(): + """ + Test that a quadrupole with lengths that have multiple vectorisation dimensions does + not raise an error and behaves as expected. + """ + lengths = torch.tensor([[0.2, 0.3, 0.4], [0.5, 0.6, 0.7]]) + segment = Segment( + [ + Quadrupole(length=lengths, k1=torch.tensor([4.2])), + Drift(length=lengths * 2), + ] + ) + + incoming = ParticleBeam.from_parameters( + num_particles=torch.tensor(10_000), + energy=torch.tensor([1e9]), + mu_x=torch.tensor([1e-5]), + ) + + outgoing = segment(incoming) + + assert outgoing.particles.shape == (2, 3, 10_000, 7) From 176611859507fc6be33b555c38d16aff652f6538 Mon Sep 17 00:00:00 2001 From: Ryan Roussel Date: Thu, 11 Jul 2024 11:16:00 -0500 Subject: [PATCH 004/111] remove drift broadcast method, add utility function to calculate inv gamma2 --- cheetah/accelerator/drift.py | 13 +++------ cheetah/track_methods.py | 11 +++----- cheetah/utils/physics.py | 12 +++++++++ tests/test_dipole.py | 51 +++++++++++++++++++++++++++++------- 4 files changed, 59 insertions(+), 28 deletions(-) create mode 100644 cheetah/utils/physics.py diff --git a/cheetah/accelerator/drift.py b/cheetah/accelerator/drift.py index 73bfbeda..42b6ab1c 100644 --- a/cheetah/accelerator/drift.py +++ b/cheetah/accelerator/drift.py @@ -3,10 +3,11 @@ import matplotlib.pyplot as plt import torch from scipy.constants import physical_constants -from torch import Size, nn +from torch import nn from cheetah.utils import UniqueNameGenerator +from ..utils.physics import calculate_inverse_gamma_squared from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -40,16 +41,11 @@ def __init__( self.register_buffer("length", torch.as_tensor(length, **factory_kwargs)) def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: - assert ( - energy.shape == self.length.shape - ), f"Beam shape {energy.shape} does not match element shape {self.length.shape}" device = self.length.device dtype = self.length.dtype - gamma = energy / electron_mass_eV.to(device=device, dtype=dtype) - igamma2 = torch.zeros_like(gamma) # TODO: Effect on gradients? - igamma2[gamma != 0] = 1 / gamma[gamma != 0] ** 2 + igamma2 = calculate_inverse_gamma_squared(energy) beta = torch.sqrt(1 - igamma2) tm = torch.eye(7, device=device, dtype=dtype).repeat((*self.length.shape, 1, 1)) @@ -59,9 +55,6 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: return tm - def broadcast(self, shape: Size) -> Element: - return self.__class__(length=self.length.repeat(shape), name=self.name) - @property def is_skippable(self) -> bool: return True diff --git a/cheetah/track_methods.py b/cheetah/track_methods.py index e1501660..a8d93783 100644 --- a/cheetah/track_methods.py +++ b/cheetah/track_methods.py @@ -3,11 +3,8 @@ from typing import Optional import torch -from scipy.constants import physical_constants -electron_mass_eV = torch.tensor( - physical_constants["electron mass energy equivalent in MeV"][0] * 1e6 -) +from cheetah.utils.physics import calculate_inverse_gamma_squared def rotation_matrix(angle: torch.Tensor) -> torch.Tensor: @@ -54,11 +51,9 @@ def base_rmatrix( dtype = length.dtype tilt = tilt if tilt is not None else torch.zeros_like(length) - energy = energy if energy is not None else torch.zeros_like(length) + energy = energy if energy is not None else torch.zeros(1) - gamma = energy / electron_mass_eV.to(device=device, dtype=dtype) - igamma2 = torch.ones_like(length) - igamma2[gamma != 0] = 1 / gamma[gamma != 0] ** 2 + igamma2 = calculate_inverse_gamma_squared(energy) beta = torch.sqrt(1 - igamma2) diff --git a/cheetah/utils/physics.py b/cheetah/utils/physics.py new file mode 100644 index 00000000..535f077d --- /dev/null +++ b/cheetah/utils/physics.py @@ -0,0 +1,12 @@ +import torch +from scipy.constants import physical_constants + +electron_mass_eV = torch.tensor( + physical_constants["electron mass energy equivalent in MeV"][0] * 1e6 +) + + +def calculate_inverse_gamma_squared(energy): + gamma = energy / electron_mass_eV.to(energy) + igamma2 = torch.where(gamma == 0.0, 0.0, 1 / gamma**2) + return igamma2 diff --git a/tests/test_dipole.py b/tests/test_dipole.py index c6eeae46..53a80d4c 100644 --- a/tests/test_dipole.py +++ b/tests/test_dipole.py @@ -1,3 +1,4 @@ +import pytest import torch from cheetah import Dipole, Drift, ParameterBeam, ParticleBeam, Segment @@ -7,15 +8,15 @@ def test_dipole_off(): """ Test that a dipole with angle=0 behaves still like a drift. """ - dipole = Dipole(length=torch.tensor([1.0]), angle=torch.tensor([0.0])) - drift = Drift(length=torch.tensor([1.0])) + dipole = Dipole(length=torch.tensor(1.0), angle=torch.tensor(0.0)) + drift = Drift(length=torch.tensor(1.0)) incoming_beam = ParameterBeam.from_parameters( - sigma_xp=torch.tensor([2e-7]), sigma_yp=torch.tensor([2e-7]) + sigma_xp=torch.tensor(2e-7), sigma_yp=torch.tensor(2e-7) ) outbeam_dipole_off = dipole(incoming_beam) outbeam_drift = drift(incoming_beam) - dipole.angle = torch.tensor([1.0], device=dipole.angle.device) + dipole.angle = torch.tensor(1.0, device=dipole.angle.device) outbeam_dipole_on = dipole(incoming_beam) assert dipole.name is not None @@ -27,25 +28,55 @@ def test_dipole_batched_execution(): """ Test that a dipole with batch dimensions behaves as expected. """ - batch_shape = torch.Size([3]) incoming = ParticleBeam.from_parameters( - num_particles=torch.tensor(1000000), - energy=torch.tensor([1e9]), - mu_x=torch.tensor([1e-5]), - ).broadcast(batch_shape) + num_particles=torch.tensor(100), + energy=torch.tensor(1e9), + mu_x=torch.tensor(1e-5), + ) + + # test batching to generate 3 beam lines segment = Segment( [ Dipole( length=torch.tensor([0.5, 0.5, 0.5]), angle=torch.tensor([0.1, 0.2, 0.1]), ), - Drift(length=torch.tensor([0.5])).broadcast(batch_shape), + Drift(length=torch.tensor(0.5)), ] ) outgoing = segment(incoming) + assert outgoing.particles.shape == torch.Size([3, 100, 7]) + assert outgoing.mu_x.shape == torch.Size([3]) + # Check that dipole with same bend angle produce same output assert torch.allclose(outgoing.particles[0], outgoing.particles[2]) # Check different angles do make a difference assert not torch.allclose(outgoing.particles[0], outgoing.particles[1]) + + # test batching to generate 18 beamlines + segment = Segment( + [ + Dipole( + length=torch.tensor([0.5, 0.5, 0.5]).reshape(3, 1), + angle=torch.tensor([0.1, 0.2, 0.1]).reshape(1, 3), + ), + Drift(length=torch.tensor([0.5, 1.0]).reshape(2, 1, 1)), + ] + ) + outgoing = segment(incoming) + assert outgoing.particles.shape == torch.Size([2, 3, 3, 100, 7]) + + # test improper batching -- this does not obey torch broadcasting rules + segment = Segment( + [ + Dipole( + length=torch.tensor([0.5, 0.5, 0.5]).reshape(3, 1), + angle=torch.tensor([0.1, 0.2, 0.1]).reshape(1, 3), + ), + Drift(length=torch.tensor([0.5, 1.0]).reshape(2, 1)), + ] + ) + with pytest.raises(RuntimeError): + segment(incoming) From 44e589fd0de5349ac5a0509c751d933ebe5eed30 Mon Sep 17 00:00:00 2001 From: Ryan Roussel Date: Thu, 11 Jul 2024 11:28:59 -0500 Subject: [PATCH 005/111] update relativistic factor calc util, cavity --- cheetah/accelerator/cavity.py | 37 +++++------------------------------ cheetah/accelerator/drift.py | 5 ++--- cheetah/track_methods.py | 6 ++---- cheetah/utils/physics.py | 12 ++++++++++-- tests/test_cavity.py | 2 +- 5 files changed, 20 insertions(+), 42 deletions(-) diff --git a/cheetah/accelerator/cavity.py b/cheetah/accelerator/cavity.py index 5310c721..051d4fb0 100644 --- a/cheetah/accelerator/cavity.py +++ b/cheetah/accelerator/cavity.py @@ -4,20 +4,15 @@ import torch from matplotlib.patches import Rectangle from scipy import constants -from scipy.constants import physical_constants -from torch import Size, nn +from torch import nn from cheetah.particles import Beam, ParameterBeam, ParticleBeam from cheetah.utils import UniqueNameGenerator - from .element import Element +from ..utils.physics import calculate_relativistic_factors, electron_mass_eV generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") -electron_mass_eV = torch.tensor( - physical_constants["electron mass energy equivalent in MeV"][0] * 1e6 -) - class Cavity(Element): """ @@ -103,19 +98,7 @@ def track(self, incoming: Beam) -> Beam: raise TypeError(f"Parameter incoming is of invalid type {type(incoming)}") def _track_beam(self, incoming: Beam) -> Beam: - device = self.length.device - dtype = self.length.dtype - - beta0 = torch.full_like(self.length, 1.0) - igamma2 = torch.full_like(self.length, 0.0) - g0 = torch.full_like(self.length, 1e10) - - mask = incoming.energy != 0 - g0[mask] = incoming.energy[mask] / electron_mass_eV.to( - device=device, dtype=dtype - ) - igamma2[mask] = 1 / g0[mask] ** 2 - beta0[mask] = torch.sqrt(1 - igamma2[mask]) + g0, igamma2, beta0 = calculate_relativistic_factors(incoming.energy) phi = torch.deg2rad(self.phase) @@ -136,8 +119,7 @@ def _track_beam(self, incoming: Beam) -> Beam: if torch.any(incoming.energy + delta_energy > 0): k = 2 * torch.pi * self.frequency / constants.speed_of_light outgoing_energy = incoming.energy + delta_energy - g1 = outgoing_energy / electron_mass_eV - beta1 = torch.sqrt(1 - 1 / g1**2) + g1, ig1, beta1 = calculate_relativistic_factors(outgoing_energy) if isinstance(incoming, ParameterBeam): outgoing_mu[..., 5] = incoming._mu[..., 5] * incoming.energy * beta0 / ( @@ -168,7 +150,7 @@ def _track_beam(self, incoming: Beam) -> Beam: - torch.cos(phi).unsqueeze(-1) ) - dgamma = self.voltage / electron_mass_eV + dgamma = self.voltage / electron_mass_eV.to(self.voltage) if torch.any(delta_energy > 0): T566 = ( self.length @@ -338,15 +320,6 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor: return R - def broadcast(self, shape: Size) -> Element: - return self.__class__( - length=self.length.repeat(shape), - voltage=self.voltage.repeat(shape), - phase=self.phase.repeat(shape), - frequency=self.frequency.repeat(shape), - name=self.name, - ) - def split(self, resolution: torch.Tensor) -> list[Element]: # TODO: Implement splitting for cavity properly, for now just returns the # element itself diff --git a/cheetah/accelerator/drift.py b/cheetah/accelerator/drift.py index 42b6ab1c..f016d0dc 100644 --- a/cheetah/accelerator/drift.py +++ b/cheetah/accelerator/drift.py @@ -7,7 +7,7 @@ from cheetah.utils import UniqueNameGenerator -from ..utils.physics import calculate_inverse_gamma_squared +from ..utils.physics import calculate_relativistic_factors from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -45,8 +45,7 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: device = self.length.device dtype = self.length.dtype - igamma2 = calculate_inverse_gamma_squared(energy) - beta = torch.sqrt(1 - igamma2) + _, igamma2, beta = calculate_relativistic_factors(energy) tm = torch.eye(7, device=device, dtype=dtype).repeat((*self.length.shape, 1, 1)) tm[..., 0, 1] = self.length diff --git a/cheetah/track_methods.py b/cheetah/track_methods.py index a8d93783..6c786362 100644 --- a/cheetah/track_methods.py +++ b/cheetah/track_methods.py @@ -4,7 +4,7 @@ import torch -from cheetah.utils.physics import calculate_inverse_gamma_squared +from cheetah.utils.physics import calculate_relativistic_factors def rotation_matrix(angle: torch.Tensor) -> torch.Tensor: @@ -53,9 +53,7 @@ def base_rmatrix( tilt = tilt if tilt is not None else torch.zeros_like(length) energy = energy if energy is not None else torch.zeros(1) - igamma2 = calculate_inverse_gamma_squared(energy) - - beta = torch.sqrt(1 - igamma2) + _, igamma2, beta = calculate_relativistic_factors(energy) # Avoid division by zero k1 = k1.clone() diff --git a/cheetah/utils/physics.py b/cheetah/utils/physics.py index 535f077d..0a0d1ba2 100644 --- a/cheetah/utils/physics.py +++ b/cheetah/utils/physics.py @@ -6,7 +6,15 @@ ) -def calculate_inverse_gamma_squared(energy): +def calculate_relativistic_factors(energy): + """ + calculates relativistic factors gamma, inverse gamma squared, beta + for electrons + + :param energy: Energy in eV + :return: gamma, igamma2, beta + """ gamma = energy / electron_mass_eV.to(energy) igamma2 = torch.where(gamma == 0.0, 0.0, 1 / gamma**2) - return igamma2 + beta = torch.sqrt(1 - igamma2) + return gamma, igamma2, beta diff --git a/tests/test_cavity.py b/tests/test_cavity.py index 0b6c47a1..2e6dd1e3 100644 --- a/tests/test_cavity.py +++ b/tests/test_cavity.py @@ -26,6 +26,6 @@ def test_assert_ei_greater_zero(): ) beam = cheetah.ParticleBeam.from_parameters( num_particles=100_000, sigma_x=torch.tensor([1e-5]) - ).broadcast((3,)) + ) _ = cavity.track(beam) From 7edad77e821f74adb979ee6be2b57bc8895a0c5e Mon Sep 17 00:00:00 2001 From: Ryan Roussel Date: Thu, 11 Jul 2024 11:29:11 -0500 Subject: [PATCH 006/111] Update cavity.py --- cheetah/accelerator/cavity.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cheetah/accelerator/cavity.py b/cheetah/accelerator/cavity.py index 051d4fb0..f47319cf 100644 --- a/cheetah/accelerator/cavity.py +++ b/cheetah/accelerator/cavity.py @@ -8,8 +8,9 @@ from cheetah.particles import Beam, ParameterBeam, ParticleBeam from cheetah.utils import UniqueNameGenerator -from .element import Element + from ..utils.physics import calculate_relativistic_factors, electron_mass_eV +from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") From 7150248d70c79033b520a0e8c9fdc07df6108327 Mon Sep 17 00:00:00 2001 From: Ryan Roussel Date: Thu, 11 Jul 2024 11:32:01 -0500 Subject: [PATCH 007/111] update quadrupole and segment for batching --- cheetah/accelerator/quadrupole.py | 12 +----------- cheetah/accelerator/segment.py | 7 ++----- tests/test_quadrupole.py | 5 ++--- 3 files changed, 5 insertions(+), 19 deletions(-) diff --git a/cheetah/accelerator/quadrupole.py b/cheetah/accelerator/quadrupole.py index c67c129e..f1bf73e5 100644 --- a/cheetah/accelerator/quadrupole.py +++ b/cheetah/accelerator/quadrupole.py @@ -4,11 +4,10 @@ import numpy as np import torch from matplotlib.patches import Rectangle -from torch import Size, nn +from torch import nn from cheetah.track_methods import base_rmatrix, misalignment_matrix from cheetah.utils import UniqueNameGenerator - from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -81,15 +80,6 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: R = torch.einsum("...ij,...jk,...kl->...il", R_exit, R, R_entry) return R - def broadcast(self, shape: Size) -> Element: - return self.__class__( - length=self.length.repeat(shape), - k1=self.k1.repeat(shape), - misalignment=self.misalignment.repeat((*shape, 1)), - tilt=self.tilt.repeat(shape), - name=self.name, - ) - @property def is_skippable(self) -> bool: return True diff --git a/cheetah/accelerator/segment.py b/cheetah/accelerator/segment.py index aad08357..2dfd88b5 100644 --- a/cheetah/accelerator/segment.py +++ b/cheetah/accelerator/segment.py @@ -324,11 +324,8 @@ def is_skippable(self) -> bool: @property def length(self) -> torch.Tensor: - lengths = torch.stack( - [element.length for element in self.elements], - dim=1, - ) - return torch.sum(lengths, dim=1) + lengths = [element.length for element in self.elements] + return torch.add(*lengths) def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: if self.is_skippable: diff --git a/tests/test_quadrupole.py b/tests/test_quadrupole.py index a36d1c5b..50d80071 100644 --- a/tests/test_quadrupole.py +++ b/tests/test_quadrupole.py @@ -84,12 +84,11 @@ def test_tilted_quadrupole_batch(): """ Test that a quadrupole with a multiple tilts behaves as expected. """ - batch_shape = torch.Size([3]) incoming = ParticleBeam.from_parameters( num_particles=torch.tensor(1_000_000), energy=torch.tensor([1e9]), mu_x=torch.tensor([1e-5]), - ).broadcast(batch_shape) + ) segment = Segment( [ Quadrupole( @@ -97,7 +96,7 @@ def test_tilted_quadrupole_batch(): k1=torch.tensor([1.0, 1.0, 1.0]), tilt=torch.tensor([torch.pi / 4, torch.pi / 2, torch.pi * 5 / 4]), ), - Drift(length=torch.tensor([0.5])).broadcast(batch_shape), + Drift(length=torch.tensor([0.5])), ] ) outgoing = segment(incoming) From a0aef4f15f7a4da9b6f38a4fda958227ed0cb47d Mon Sep 17 00:00:00 2001 From: Ryan Roussel Date: Thu, 11 Jul 2024 11:32:08 -0500 Subject: [PATCH 008/111] Update quadrupole.py --- cheetah/accelerator/quadrupole.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cheetah/accelerator/quadrupole.py b/cheetah/accelerator/quadrupole.py index f1bf73e5..205abd76 100644 --- a/cheetah/accelerator/quadrupole.py +++ b/cheetah/accelerator/quadrupole.py @@ -8,6 +8,7 @@ from cheetah.track_methods import base_rmatrix, misalignment_matrix from cheetah.utils import UniqueNameGenerator + from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") From b37b3cf00bcdee7eb1e34b542694b7676f3a1def Mon Sep 17 00:00:00 2001 From: Ryan Roussel Date: Thu, 11 Jul 2024 11:46:50 -0500 Subject: [PATCH 009/111] fix test errors and bugs --- cheetah/accelerator/dipole.py | 6 +----- cheetah/accelerator/drift.py | 2 +- cheetah/accelerator/element.py | 4 ---- cheetah/accelerator/horizontal_corrector.py | 18 +++--------------- cheetah/accelerator/segment.py | 6 +++++- tests/test_reading_nx_tables.py | 4 +++- tests/test_space_charge_kick.py | 2 +- tests/test_vectorized.py | 11 +++-------- 8 files changed, 17 insertions(+), 36 deletions(-) diff --git a/cheetah/accelerator/dipole.py b/cheetah/accelerator/dipole.py index a59eab10..e328bd25 100644 --- a/cheetah/accelerator/dipole.py +++ b/cheetah/accelerator/dipole.py @@ -108,11 +108,7 @@ def __init__( @property def hx(self) -> torch.Tensor: - value = torch.zeros_like(self.length) - value[self.length != 0] = ( - self.angle[self.length != 0] / self.length[self.length != 0] - ) - return value + return torch.where(self.length == 0.0, 0.0, self.angle / self.length) @property def is_skippable(self) -> bool: diff --git a/cheetah/accelerator/drift.py b/cheetah/accelerator/drift.py index f016d0dc..74c3f081 100644 --- a/cheetah/accelerator/drift.py +++ b/cheetah/accelerator/drift.py @@ -47,7 +47,7 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: _, igamma2, beta = calculate_relativistic_factors(energy) - tm = torch.eye(7, device=device, dtype=dtype).repeat((*self.length.shape, 1, 1)) + tm = torch.eye(7, device=device, dtype=dtype).repeat((*energy.shape, 1, 1)) tm[..., 0, 1] = self.length tm[..., 2, 3] = self.length tm[..., 4, 5] = -self.length / beta**2 * igamma2 diff --git a/cheetah/accelerator/element.py b/cheetah/accelerator/element.py index 44abab95..6db1c625 100644 --- a/cheetah/accelerator/element.py +++ b/cheetah/accelerator/element.py @@ -87,10 +87,6 @@ def forward(self, incoming: Beam) -> Beam: """Forward function required by `torch.nn.Module`. Simply calls `track`.""" return self.track(incoming) - def broadcast(self, shape: torch.Size) -> "Element": - """Broadcast the element to higher batch dimensions.""" - raise NotImplementedError - @property @abstractmethod def is_skippable(self) -> bool: diff --git a/cheetah/accelerator/horizontal_corrector.py b/cheetah/accelerator/horizontal_corrector.py index fc0ef535..8f5bad74 100644 --- a/cheetah/accelerator/horizontal_corrector.py +++ b/cheetah/accelerator/horizontal_corrector.py @@ -4,19 +4,15 @@ import numpy as np import torch from matplotlib.patches import Rectangle -from scipy.constants import physical_constants -from torch import Size, nn +from torch import nn from cheetah.utils import UniqueNameGenerator +from ..utils.physics import calculate_relativistic_factors from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") -electron_mass_eV = torch.tensor( - physical_constants["electron mass energy equivalent in MeV"][0] * 1e6 -) - class HorizontalCorrector(Element): """ @@ -54,10 +50,7 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: device = self.length.device dtype = self.length.dtype - gamma = energy / electron_mass_eV.to(device=device, dtype=dtype) - igamma2 = torch.zeros_like(gamma) # TODO: Effect on gradients? - igamma2[gamma != 0] = 1 / gamma[gamma != 0] ** 2 - beta = torch.sqrt(1 - igamma2) + _, igamma2, beta = calculate_relativistic_factors(energy) tm = torch.eye(7, device=device, dtype=dtype).repeat((*self.length.shape, 1, 1)) tm[..., 0, 1] = self.length @@ -67,11 +60,6 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: return tm - def broadcast(self, shape: Size) -> Element: - return self.__class__( - length=self.length.repeat(shape), angle=self.angle, name=self.name - ) - @property def is_skippable(self) -> bool: return True diff --git a/cheetah/accelerator/segment.py b/cheetah/accelerator/segment.py index 2dfd88b5..3b0d8ad6 100644 --- a/cheetah/accelerator/segment.py +++ b/cheetah/accelerator/segment.py @@ -1,4 +1,5 @@ from copy import deepcopy +from functools import reduce from pathlib import Path from typing import Any, Optional, Union @@ -324,8 +325,11 @@ def is_skippable(self) -> bool: @property def length(self) -> torch.Tensor: + if len(self.elements) == 1: + return self.elements[0].length + lengths = [element.length for element in self.elements] - return torch.add(*lengths) + return reduce(torch.add, lengths) def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: if self.is_skippable: diff --git a/tests/test_reading_nx_tables.py b/tests/test_reading_nx_tables.py index 5a119610..f3bdee7c 100644 --- a/tests/test_reading_nx_tables.py +++ b/tests/test_reading_nx_tables.py @@ -1,3 +1,5 @@ +import torch + import cheetah @@ -22,4 +24,4 @@ def test_length(): """ segment = cheetah.Segment.from_nx_tables("tests/resources/Stage4v3_9.txt") - assert segment.length == 44.2215 + assert segment.length == torch.tensor(44.2215) diff --git a/tests/test_space_charge_kick.py b/tests/test_space_charge_kick.py index d487cd8d..d7c74b1d 100644 --- a/tests/test_space_charge_kick.py +++ b/tests/test_space_charge_kick.py @@ -108,7 +108,7 @@ def test_vectorized(): cheetah.SpaceChargeKick(section_length / 3), cheetah.Drift(section_length / 6), ] - ).broadcast(shape=(3, 2)) + ) outgoing = segment.track(incoming) diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index 9d8ead39..4bf601ee 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -421,15 +421,10 @@ def test_broadcast_quadrupole(): # TODO Add misalignment to the test # TODO Add tilt to the test - element = cheetah.Quadrupole(length=torch.tensor([0.4]), k1=torch.tensor([4.2])) - broadcast_element = element.broadcast((3, 10)) + element = cheetah.Quadrupole(length=torch.randn((3, 10)), k1=torch.tensor([4.2])) - assert broadcast_element.length.shape == (3, 10) - assert broadcast_element.k1.shape == (3, 10) - for i in range(3): - for j in range(10): - assert broadcast_element.length[i, j] == 0.4 - assert broadcast_element.k1[i, j] == 4.2 + assert element.length.shape == (3, 10) + assert element.k1.shape == (1,) def test_cavity_with_zero_and_non_zero_voltage(): From a17dd65530a03408f00f8cf0a55be5ec9e4f055a Mon Sep 17 00:00:00 2001 From: Ryan Roussel Date: Thu, 11 Jul 2024 12:14:41 -0500 Subject: [PATCH 010/111] implement batch calculation utility and fix vectorize tests --- cheetah/accelerator/drift.py | 4 +- cheetah/accelerator/horizontal_corrector.py | 4 +- cheetah/track_methods.py | 4 +- cheetah/utils/batching.py | 6 +++ tests/test_reading_nx_tables.py | 2 +- tests/test_space_charge_kick.py | 4 +- tests/test_speed_optimizations.py | 4 +- tests/test_vectorized.py | 54 ++++++++++----------- 8 files changed, 47 insertions(+), 35 deletions(-) create mode 100644 cheetah/utils/batching.py diff --git a/cheetah/accelerator/drift.py b/cheetah/accelerator/drift.py index 74c3f081..60af6c00 100644 --- a/cheetah/accelerator/drift.py +++ b/cheetah/accelerator/drift.py @@ -47,7 +47,9 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: _, igamma2, beta = calculate_relativistic_factors(energy) - tm = torch.eye(7, device=device, dtype=dtype).repeat((*energy.shape, 1, 1)) + tm = torch.eye(7, device=device, dtype=dtype).repeat(( + *(self.length*igamma2).shape, 1, 1) + ) tm[..., 0, 1] = self.length tm[..., 2, 3] = self.length tm[..., 4, 5] = -self.length / beta**2 * igamma2 diff --git a/cheetah/accelerator/horizontal_corrector.py b/cheetah/accelerator/horizontal_corrector.py index 8f5bad74..d02e94ad 100644 --- a/cheetah/accelerator/horizontal_corrector.py +++ b/cheetah/accelerator/horizontal_corrector.py @@ -7,6 +7,7 @@ from torch import nn from cheetah.utils import UniqueNameGenerator +from ..utils.batching import get_batch_shape from ..utils.physics import calculate_relativistic_factors from .element import Element @@ -52,7 +53,8 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: _, igamma2, beta = calculate_relativistic_factors(energy) - tm = torch.eye(7, device=device, dtype=dtype).repeat((*self.length.shape, 1, 1)) + batch_shape = get_batch_shape(self.length, self.angle, beta) + tm = torch.eye(7, device=device, dtype=dtype).repeat((*batch_shape, 1, 1)) tm[..., 0, 1] = self.length tm[..., 1, 6] = self.angle tm[..., 2, 3] = self.length diff --git a/cheetah/track_methods.py b/cheetah/track_methods.py index 6c786362..91d7dde8 100644 --- a/cheetah/track_methods.py +++ b/cheetah/track_methods.py @@ -4,6 +4,7 @@ import torch +from cheetah.utils.batching import get_batch_shape from cheetah.utils.physics import calculate_relativistic_factors @@ -72,7 +73,8 @@ def base_rmatrix( r56 = r56 - length / beta**2 * igamma2 - R = torch.eye(7, dtype=dtype, device=device).repeat(*cx.shape, 1, 1) + batch_shape = get_batch_shape(dx, sx, beta, cx) + R = torch.eye(7, dtype=dtype, device=device).repeat(*batch_shape, 1, 1) R[..., 0, 0] = cx R[..., 0, 1] = sx R[..., 0, 5] = dx / beta diff --git a/cheetah/utils/batching.py b/cheetah/utils/batching.py new file mode 100644 index 00000000..d0372d9d --- /dev/null +++ b/cheetah/utils/batching.py @@ -0,0 +1,6 @@ +import torch + + +def get_batch_shape(*args): + result = torch.broadcast_tensors(*args) + return result[0].shape diff --git a/tests/test_reading_nx_tables.py b/tests/test_reading_nx_tables.py index f3bdee7c..f1fb7792 100644 --- a/tests/test_reading_nx_tables.py +++ b/tests/test_reading_nx_tables.py @@ -24,4 +24,4 @@ def test_length(): """ segment = cheetah.Segment.from_nx_tables("tests/resources/Stage4v3_9.txt") - assert segment.length == torch.tensor(44.2215) + assert torch.allclose(segment.length, torch.tensor([44.2215])) diff --git a/tests/test_space_charge_kick.py b/tests/test_space_charge_kick.py index d7c74b1d..d9fcd0a3 100644 --- a/tests/test_space_charge_kick.py +++ b/tests/test_space_charge_kick.py @@ -246,7 +246,7 @@ def test_does_not_break_segment_length(): Test that the computation of a `Segment`'s length does not break when `SpaceChargeKick` is used. """ - section_length = torch.tensor([1.0]) + section_length = torch.tensor([1.0]).repeat((3,2)) segment = cheetah.Segment( elements=[ cheetah.Drift(section_length / 6), @@ -257,7 +257,7 @@ def test_does_not_break_segment_length(): cheetah.SpaceChargeKick(section_length / 3), cheetah.Drift(section_length / 6), ] - ).broadcast(shape=(3, 2)) + ) assert segment.length.shape == (3, 2) assert torch.allclose(segment.length, torch.tensor([1.0]).repeat(3, 2)) diff --git a/tests/test_speed_optimizations.py b/tests/test_speed_optimizations.py index 6ed8d5a6..722d9eeb 100644 --- a/tests/test_speed_optimizations.py +++ b/tests/test_speed_optimizations.py @@ -49,7 +49,7 @@ def test_merged_transfer_maps_tracking_vectorized(): """ incoming_beam = cheetah.ParameterBeam.from_astra( "tests/resources/ACHIP_EA1_2021.1351.001" - ).broadcast((10,)) + ).broadcast(torch.Size([10])) original_segment = cheetah.Segment( elements=[ @@ -60,7 +60,7 @@ def test_merged_transfer_maps_tracking_vectorized(): length=torch.tensor([0.1]), angle=torch.tensor([1e-4]) ), ] - ).broadcast((10,)) + ) merged_segment = original_segment.transfer_maps_merged(incoming_beam=incoming_beam) original_beam = original_segment.track(incoming_beam) diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index 4bf601ee..f81e47aa 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -308,25 +308,24 @@ def test_enormous_through_ares(): "tests/resources/ACHIP_EA1_2021.1351.001" ) - segment_broadcast = segment.broadcast((3, 100_000)) - incoming_broadcast = incoming.broadcast((3, 100_000)) + segment_broadcast = segment + incoming_broadcast = incoming - segment_broadcast.AREAMQZM1.k1 = torch.linspace(-30.0, 30.0, 100_000).repeat(3, 1) + segment_broadcast.AREAMQZM1.k1 = torch.linspace(-30.0, 30.0, 100).repeat(3, 1) outgoing = segment_broadcast.track(incoming_broadcast) - assert outgoing.mu_x.shape == (3, 100_000) - assert outgoing.mu_xp.shape == (3, 100_000) - assert outgoing.mu_y.shape == (3, 100_000) - assert outgoing.mu_yp.shape == (3, 100_000) - assert outgoing.sigma_x.shape == (3, 100_000) - assert outgoing.sigma_xp.shape == (3, 100_000) - assert outgoing.sigma_y.shape == (3, 100_000) - assert outgoing.sigma_yp.shape == (3, 100_000) - assert outgoing.sigma_s.shape == (3, 100_000) - assert outgoing.sigma_p.shape == (3, 100_000) - assert outgoing.energy.shape == (3, 100_000) - assert outgoing.total_charge.shape == (3, 100_000) + assert outgoing.mu_x.shape == (3, 100) + assert outgoing.mu_xp.shape == (3, 100) + assert outgoing.mu_y.shape == (3, 100) + assert outgoing.mu_yp.shape == (3, 100) + assert outgoing.sigma_x.shape == (3, 100) + assert outgoing.sigma_xp.shape == (3, 100) + assert outgoing.sigma_y.shape == (3, 100) + assert outgoing.sigma_yp.shape == (3, 100) + assert outgoing.sigma_s.shape == (3, 100) + assert outgoing.sigma_p.shape == (3, 100) + assert outgoing.total_charge.shape == torch.Size([1]) def test_before_after_broadcast_tracking_equal_cavity(): @@ -335,7 +334,7 @@ def test_before_after_broadcast_tracking_equal_cavity(): the same as in the segment before broadcasting. A cavity is used as a reference. """ cavity = cheetah.Cavity( - length=torch.tensor([3.0441]), + length=torch.tensor([3.0441]).repeat((3, 10)), voltage=torch.tensor([48198468.0]), phase=torch.tensor([-0.0]), frequency=torch.tensor([2.8560e09]), @@ -346,8 +345,8 @@ def test_before_after_broadcast_tracking_equal_cavity(): ) outgoing = cavity.track(incoming) - broadcast_cavity = cavity.broadcast((3, 10)) - broadcast_incoming = incoming.broadcast((3, 10)) + broadcast_cavity = cavity + broadcast_incoming = incoming broadcast_outgoing = broadcast_cavity.track(broadcast_incoming) for i in range(3): @@ -365,11 +364,11 @@ def test_before_after_broadcast_tracking_equal_ares_ea(): incoming = cheetah.ParameterBeam.from_astra( "tests/resources/ACHIP_EA1_2021.1351.001" ) - segment.AREAMQZM1.k1 = torch.tensor([4.2]) + segment.AREAMQZM1.k1 = torch.tensor([4.2]).repeat((3,10)) outgoing = segment.track(incoming) - broadcast_segment = segment.broadcast((3, 10)) - broadcast_incoming = incoming.broadcast((3, 10)) + broadcast_segment = segment + broadcast_incoming = incoming broadcast_outgoing = broadcast_segment.track(broadcast_incoming) for i in range(3): @@ -394,8 +393,10 @@ def test_broadcast_customtransfermap(): ] ) - element = cheetah.CustomTransferMap(length=torch.tensor([0.4]), transfer_map=tm) - broadcast_element = element.broadcast((3, 10)) + element = cheetah.CustomTransferMap( + length=torch.tensor([0.4]), transfer_map=tm.repeat((3,10)) + ) + broadcast_element = element assert broadcast_element.length.shape == (3, 10) assert broadcast_element._transfer_map.shape == (3, 10, 7, 7) @@ -406,13 +407,12 @@ def test_broadcast_customtransfermap(): def test_broadcast_drift(): """Test that broadcasting a `Drift` element gives the correct result.""" - element = cheetah.Drift(length=torch.tensor([0.4])) - broadcast_element = element.broadcast((3, 10)) + element = cheetah.Drift(length=torch.tensor([0.4]).repeat((3,10))) - assert broadcast_element.length.shape == (3, 10) + assert element.length.shape == (3, 10) for i in range(3): for j in range(10): - assert broadcast_element.length[i, j] == 0.4 + assert element.length[i, j] == 0.4 def test_broadcast_quadrupole(): From 35ebe233eca43f589d51ead36c3eb2fb8df58691 Mon Sep 17 00:00:00 2001 From: Ryan Roussel Date: Thu, 11 Jul 2024 12:16:47 -0500 Subject: [PATCH 011/111] ufmt formatting --- cheetah/accelerator/drift.py | 5 ++--- cheetah/accelerator/horizontal_corrector.py | 2 +- tests/test_space_charge_kick.py | 5 ++--- tests/test_vectorized.py | 9 ++++----- 4 files changed, 9 insertions(+), 12 deletions(-) diff --git a/cheetah/accelerator/drift.py b/cheetah/accelerator/drift.py index 60af6c00..4e949ecf 100644 --- a/cheetah/accelerator/drift.py +++ b/cheetah/accelerator/drift.py @@ -41,14 +41,13 @@ def __init__( self.register_buffer("length", torch.as_tensor(length, **factory_kwargs)) def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: - device = self.length.device dtype = self.length.dtype _, igamma2, beta = calculate_relativistic_factors(energy) - tm = torch.eye(7, device=device, dtype=dtype).repeat(( - *(self.length*igamma2).shape, 1, 1) + tm = torch.eye(7, device=device, dtype=dtype).repeat( + (*(self.length * igamma2).shape, 1, 1) ) tm[..., 0, 1] = self.length tm[..., 2, 3] = self.length diff --git a/cheetah/accelerator/horizontal_corrector.py b/cheetah/accelerator/horizontal_corrector.py index d02e94ad..48682dce 100644 --- a/cheetah/accelerator/horizontal_corrector.py +++ b/cheetah/accelerator/horizontal_corrector.py @@ -7,8 +7,8 @@ from torch import nn from cheetah.utils import UniqueNameGenerator -from ..utils.batching import get_batch_shape +from ..utils.batching import get_batch_shape from ..utils.physics import calculate_relativistic_factors from .element import Element diff --git a/tests/test_space_charge_kick.py b/tests/test_space_charge_kick.py index d9fcd0a3..7f2c6511 100644 --- a/tests/test_space_charge_kick.py +++ b/tests/test_space_charge_kick.py @@ -1,10 +1,9 @@ +import cheetah import torch from scipy import constants from scipy.constants import physical_constants from torch import nn -import cheetah - def test_cold_uniform_beam_expansion(): """ @@ -246,7 +245,7 @@ def test_does_not_break_segment_length(): Test that the computation of a `Segment`'s length does not break when `SpaceChargeKick` is used. """ - section_length = torch.tensor([1.0]).repeat((3,2)) + section_length = torch.tensor([1.0]).repeat((3, 2)) segment = cheetah.Segment( elements=[ cheetah.Drift(section_length / 6), diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index f81e47aa..6297c970 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -1,6 +1,5 @@ -import torch - import cheetah +import torch from .resources import ARESlatticeStage3v1_9 as ares @@ -364,7 +363,7 @@ def test_before_after_broadcast_tracking_equal_ares_ea(): incoming = cheetah.ParameterBeam.from_astra( "tests/resources/ACHIP_EA1_2021.1351.001" ) - segment.AREAMQZM1.k1 = torch.tensor([4.2]).repeat((3,10)) + segment.AREAMQZM1.k1 = torch.tensor([4.2]).repeat((3, 10)) outgoing = segment.track(incoming) broadcast_segment = segment @@ -394,7 +393,7 @@ def test_broadcast_customtransfermap(): ) element = cheetah.CustomTransferMap( - length=torch.tensor([0.4]), transfer_map=tm.repeat((3,10)) + length=torch.tensor([0.4]), transfer_map=tm.repeat((3, 10)) ) broadcast_element = element @@ -407,7 +406,7 @@ def test_broadcast_customtransfermap(): def test_broadcast_drift(): """Test that broadcasting a `Drift` element gives the correct result.""" - element = cheetah.Drift(length=torch.tensor([0.4]).repeat((3,10))) + element = cheetah.Drift(length=torch.tensor([0.4]).repeat((3, 10))) assert element.length.shape == (3, 10) for i in range(3): From 42fcd83cd096374660353ff89ee4b358e03da3e8 Mon Sep 17 00:00:00 2001 From: Ryan Roussel Date: Thu, 11 Jul 2024 12:22:21 -0500 Subject: [PATCH 012/111] remove broadcasting methods, fix vectorized test --- cheetah/accelerator/bpm.py | 7 ------- cheetah/accelerator/custom_transfer_map.py | 10 +--------- cheetah/accelerator/undulator.py | 10 +--------- tests/test_vectorized.py | 6 ++++-- 4 files changed, 6 insertions(+), 27 deletions(-) diff --git a/cheetah/accelerator/bpm.py b/cheetah/accelerator/bpm.py index 5aede5ef..a11fd635 100644 --- a/cheetah/accelerator/bpm.py +++ b/cheetah/accelerator/bpm.py @@ -4,11 +4,9 @@ import matplotlib.pyplot as plt import torch from matplotlib.patches import Rectangle -from torch import Size from cheetah.particles import Beam, ParameterBeam, ParticleBeam from cheetah.utils import UniqueNameGenerator - from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -50,11 +48,6 @@ def track(self, incoming: Beam) -> Beam: return deepcopy(incoming) - def broadcast(self, shape: Size) -> Element: - new_bpm = self.__class__(is_active=self.is_active, name=self.name) - new_bpm.length = self.length.repeat(shape) - return new_bpm - def split(self, resolution: torch.Tensor) -> list[Element]: return [self] diff --git a/cheetah/accelerator/custom_transfer_map.py b/cheetah/accelerator/custom_transfer_map.py index ef6297b7..c13aca6a 100644 --- a/cheetah/accelerator/custom_transfer_map.py +++ b/cheetah/accelerator/custom_transfer_map.py @@ -2,11 +2,10 @@ import matplotlib.pyplot as plt import torch -from torch import Size, nn +from torch import nn from cheetah.particles import Beam from cheetah.utils import UniqueNameGenerator - from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -85,13 +84,6 @@ def from_merging_elements( def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: return self._transfer_map - def broadcast(self, shape: Size) -> Element: - return self.__class__( - self._transfer_map.repeat((*shape, 1, 1)), - length=self.length.repeat(shape), - name=self.name, - ) - @property def is_skippable(self) -> bool: return True diff --git a/cheetah/accelerator/undulator.py b/cheetah/accelerator/undulator.py index 8a360e5c..9227d4b2 100644 --- a/cheetah/accelerator/undulator.py +++ b/cheetah/accelerator/undulator.py @@ -4,10 +4,9 @@ import torch from matplotlib.patches import Rectangle from scipy.constants import physical_constants -from torch import Size, nn +from torch import nn from cheetah.utils import UniqueNameGenerator - from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -61,13 +60,6 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: return tm - def broadcast(self, shape: Size) -> Element: - return self.__class__( - length=self.length.repeat(shape), - is_active=self.is_active, - name=self.name, - ) - @property def is_skippable(self) -> bool: return True diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index 6297c970..3bdc213e 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -1,6 +1,7 @@ -import cheetah import torch +import cheetah + from .resources import ARESlatticeStage3v1_9 as ares @@ -393,7 +394,8 @@ def test_broadcast_customtransfermap(): ) element = cheetah.CustomTransferMap( - length=torch.tensor([0.4]), transfer_map=tm.repeat((3, 10)) + length=torch.tensor([0.4]).repeat((3, 10)), + transfer_map=tm.repeat((3, 10, 1, 1)), ) broadcast_element = element From 8ed54b7daf8f53b88ef14ec8ec67f872fe23011a Mon Sep 17 00:00:00 2001 From: Ryan Roussel Date: Thu, 11 Jul 2024 12:26:24 -0500 Subject: [PATCH 013/111] remove broadcast methods --- cheetah/accelerator/aperture.py | 14 +------------- cheetah/accelerator/dipole.py | 16 +--------------- cheetah/accelerator/marker.py | 7 ------- cheetah/accelerator/screen.py | 15 +-------------- cheetah/accelerator/segment.py | 9 +-------- tests/test_vectorized.py | 4 +--- 6 files changed, 5 insertions(+), 60 deletions(-) diff --git a/cheetah/accelerator/aperture.py b/cheetah/accelerator/aperture.py index 37b9ee6d..77dd6d2f 100644 --- a/cheetah/accelerator/aperture.py +++ b/cheetah/accelerator/aperture.py @@ -3,11 +3,10 @@ import matplotlib.pyplot as plt import torch from matplotlib.patches import Rectangle -from torch import Size, nn +from torch import nn from cheetah.particles import Beam, ParticleBeam from cheetah.utils import UniqueNameGenerator - from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -110,17 +109,6 @@ def track(self, incoming: Beam) -> Beam: else ParticleBeam.empty ) - def broadcast(self, shape: Size) -> Element: - new_aperture = self.__class__( - x_max=self.x_max.repeat(shape), - y_max=self.y_max.repeat(shape), - shape=self.shape, - is_active=self.is_active, - name=self.name, - ) - new_aperture.length = self.length.repeat(shape) - return new_aperture - def split(self, resolution: torch.Tensor) -> list[Element]: # TODO: Implement splitting for aperture properly, for now just return self return [self] diff --git a/cheetah/accelerator/dipole.py b/cheetah/accelerator/dipole.py index e328bd25..6bfdec0a 100644 --- a/cheetah/accelerator/dipole.py +++ b/cheetah/accelerator/dipole.py @@ -4,11 +4,10 @@ import numpy as np import torch from matplotlib.patches import Rectangle -from torch import Size, nn +from torch import nn from cheetah.track_methods import base_rmatrix, rotation_matrix from cheetah.utils import UniqueNameGenerator - from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -189,19 +188,6 @@ def _transfer_map_exit(self) -> torch.Tensor: return tm - def broadcast(self, shape: Size) -> Element: - return self.__class__( - length=self.length.repeat(shape), - angle=self.angle.repeat(shape), - e1=self.e1.repeat(shape), - e2=self.e2.repeat(shape), - tilt=self.tilt.repeat(shape), - fringe_integral=self.fringe_integral.repeat(shape), - fringe_integral_exit=self.fringe_integral_exit.repeat(shape), - gap=self.gap.repeat(shape), - name=self.name, - ) - def split(self, resolution: torch.Tensor) -> list[Element]: # TODO: Implement splitting for dipole properly, for now just returns the # element itself diff --git a/cheetah/accelerator/marker.py b/cheetah/accelerator/marker.py index 605c81df..1375155c 100644 --- a/cheetah/accelerator/marker.py +++ b/cheetah/accelerator/marker.py @@ -2,11 +2,9 @@ import matplotlib.pyplot as plt import torch -from torch import Size from cheetah.particles import Beam from cheetah.utils import UniqueNameGenerator - from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -32,11 +30,6 @@ def track(self, incoming: Beam) -> Beam: # Markers would be able to record the beam tracked through them. return incoming - def broadcast(self, shape: Size) -> Element: - new_marker = self.__class__(name=self.name) - new_marker.length = self.length.repeat(shape) - return new_marker - @property def is_skippable(self) -> bool: return True diff --git a/cheetah/accelerator/screen.py b/cheetah/accelerator/screen.py index 196e4e7e..5fc80af1 100644 --- a/cheetah/accelerator/screen.py +++ b/cheetah/accelerator/screen.py @@ -4,12 +4,11 @@ import matplotlib.pyplot as plt import torch from matplotlib.patches import Rectangle -from torch import Size, nn +from torch import nn from torch.distributions import MultivariateNormal from cheetah.particles import Beam, ParameterBeam, ParticleBeam from cheetah.utils import UniqueNameGenerator, kde_histogram_2d - from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -281,18 +280,6 @@ def set_read_beam(self, value: Beam) -> None: self._read_beam = [value] self.cached_reading = None - def broadcast(self, shape: Size) -> Element: - new_screen = self.__class__( - resolution=self.resolution, - pixel_size=self.pixel_size, - binning=self.binning, - misalignment=self.misalignment.repeat((*shape, 1)), - is_active=self.is_active, - name=self.name, - ) - new_screen.length = self.length.repeat(shape) - return new_screen - def split(self, resolution: torch.Tensor) -> list[Element]: return [self] diff --git a/cheetah/accelerator/segment.py b/cheetah/accelerator/segment.py index 3b0d8ad6..c3f2137c 100644 --- a/cheetah/accelerator/segment.py +++ b/cheetah/accelerator/segment.py @@ -6,14 +6,13 @@ import matplotlib import matplotlib.pyplot as plt import torch -from torch import Size, nn +from torch import nn from cheetah.converters.bmad import convert_bmad_lattice from cheetah.converters.nxtables import read_nx_tables from cheetah.latticejson import load_cheetah_model, save_cheetah_model from cheetah.particles import Beam, ParticleBeam from cheetah.utils import UniqueNameGenerator - from .custom_transfer_map import CustomTransferMap from .drift import Drift from .element import Element @@ -360,12 +359,6 @@ def track(self, incoming: Beam) -> Beam: return incoming - def broadcast(self, shape: Size) -> Element: - return self.__class__( - elements=[element.broadcast(shape) for element in self.elements], - name=self.name, - ) - def split(self, resolution: torch.Tensor) -> list[Element]: return [ split_element diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index 3bdc213e..7da0c62f 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -460,6 +460,4 @@ def test_screen_length_broadcast_shape(): Test that the shape of a screen's length matches the shape of its misalignment after broadcasting. """ - screen = cheetah.Screen(misalignment=torch.tensor([[0.1, 0.2]])) - broadcast_screen = screen.broadcast((3, 10)) - assert broadcast_screen.length.shape == broadcast_screen.misalignment.shape[:-1] + cheetah.Screen(misalignment=torch.tensor([[0.1, 0.2]]).repeat(3, 10, 1)) From 521136911831d1e9cb51c515ab5e58cf71c396c7 Mon Sep 17 00:00:00 2001 From: Ryan Roussel Date: Thu, 11 Jul 2024 12:28:37 -0500 Subject: [PATCH 014/111] remove broadcast methods --- cheetah/accelerator/solenoid.py | 11 +---------- cheetah/accelerator/space_charge_kick.py | 20 -------------------- cheetah/accelerator/vertical_corrector.py | 8 +------- 3 files changed, 2 insertions(+), 37 deletions(-) diff --git a/cheetah/accelerator/solenoid.py b/cheetah/accelerator/solenoid.py index 9c48bc60..36156187 100644 --- a/cheetah/accelerator/solenoid.py +++ b/cheetah/accelerator/solenoid.py @@ -4,11 +4,10 @@ import torch from matplotlib.patches import Rectangle from scipy.constants import physical_constants -from torch import Size, nn +from torch import nn from cheetah.track_methods import misalignment_matrix from cheetah.utils import UniqueNameGenerator - from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -108,14 +107,6 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: R = torch.einsum("...ij,...jk,...kl->...il", R_exit, R, R_entry) return R - def broadcast(self, shape: Size) -> Element: - return self.__class__( - length=self.length.repeat(shape), - k=self.k.repeat(shape), - misalignment=self.misalignment.repeat(shape), - name=self.name, - ) - @property def is_active(self) -> bool: return any(self.k != 0) diff --git a/cheetah/accelerator/space_charge_kick.py b/cheetah/accelerator/space_charge_kick.py index 1b14cf3d..b446810d 100644 --- a/cheetah/accelerator/space_charge_kick.py +++ b/cheetah/accelerator/space_charge_kick.py @@ -608,26 +608,6 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: else: raise TypeError(f"Parameter incoming is of invalid type {type(incoming)}") - def broadcast(self, shape: torch.Size) -> "SpaceChargeKick": - """ - Broadcast the element to higher batch dimensions. - - :param shape: Shape to broadcast the element to. - :returns: Broadcasted element. - """ - new_space_charge_kick = self.__class__( - effect_length=self.effect_length, - num_grid_points_x=self.grid_shape[0], - num_grid_points_y=self.grid_shape[1], - num_grid_points_s=self.grid_shape[2], - grid_extend_x=self.grid_extend_x, - grid_extend_y=self.grid_extend_y, - grid_extend_s=self.grid_extend_s, - name=self.name, - ) - new_space_charge_kick.length = self.length.repeat(shape) - return new_space_charge_kick - def split(self, resolution: torch.Tensor) -> list[Element]: # TODO: Implement splitting for SpaceCharge properly, for now just returns the # element itself diff --git a/cheetah/accelerator/vertical_corrector.py b/cheetah/accelerator/vertical_corrector.py index 3cef94f3..dc354d1d 100644 --- a/cheetah/accelerator/vertical_corrector.py +++ b/cheetah/accelerator/vertical_corrector.py @@ -5,10 +5,9 @@ import torch from matplotlib.patches import Rectangle from scipy.constants import physical_constants -from torch import Size, nn +from torch import nn from cheetah.utils import UniqueNameGenerator - from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -66,11 +65,6 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: tm[..., 4, 5] = -self.length / beta**2 * igamma2 return tm - def broadcast(self, shape: Size) -> Element: - return self.__class__( - length=self.length.repeat(shape), angle=self.angle, name=self.name - ) - @property def is_skippable(self) -> bool: return True From bfddf8d94ef0f1329e31556562a8804b1fdcd680 Mon Sep 17 00:00:00 2001 From: Ryan Roussel Date: Thu, 11 Jul 2024 12:41:14 -0500 Subject: [PATCH 015/111] remove utility function, add element property --- cheetah/accelerator/element.py | 11 +++++++++++ cheetah/accelerator/horizontal_corrector.py | 3 +-- cheetah/track_methods.py | 3 +-- cheetah/utils/batching.py | 6 ------ tests/test_quadrupole.py | 2 ++ 5 files changed, 15 insertions(+), 10 deletions(-) delete mode 100644 cheetah/utils/batching.py diff --git a/cheetah/accelerator/element.py b/cheetah/accelerator/element.py index 6db1c625..5eb83ee8 100644 --- a/cheetah/accelerator/element.py +++ b/cheetah/accelerator/element.py @@ -87,6 +87,17 @@ def forward(self, incoming: Beam) -> Beam: """Forward function required by `torch.nn.Module`. Simply calls `track`.""" return self.track(incoming) + @property + def batch_shape(self) -> torch.Size: + tensors = [] + # Get all parameters + for param in self.parameters(): + tensors.append(param) + # Get all buffers + for buffer in self.buffers(): + tensors.append(buffer) + return torch.broadcast_tensors(*tensors)[0].shape + @property @abstractmethod def is_skippable(self) -> bool: diff --git a/cheetah/accelerator/horizontal_corrector.py b/cheetah/accelerator/horizontal_corrector.py index 48682dce..de01daea 100644 --- a/cheetah/accelerator/horizontal_corrector.py +++ b/cheetah/accelerator/horizontal_corrector.py @@ -8,7 +8,6 @@ from cheetah.utils import UniqueNameGenerator -from ..utils.batching import get_batch_shape from ..utils.physics import calculate_relativistic_factors from .element import Element @@ -53,7 +52,7 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: _, igamma2, beta = calculate_relativistic_factors(energy) - batch_shape = get_batch_shape(self.length, self.angle, beta) + batch_shape = torch.broadcast_tensors(self.length, self.angle, beta)[0].shape tm = torch.eye(7, device=device, dtype=dtype).repeat((*batch_shape, 1, 1)) tm[..., 0, 1] = self.length tm[..., 1, 6] = self.angle diff --git a/cheetah/track_methods.py b/cheetah/track_methods.py index 91d7dde8..4c6a0130 100644 --- a/cheetah/track_methods.py +++ b/cheetah/track_methods.py @@ -4,7 +4,6 @@ import torch -from cheetah.utils.batching import get_batch_shape from cheetah.utils.physics import calculate_relativistic_factors @@ -73,7 +72,7 @@ def base_rmatrix( r56 = r56 - length / beta**2 * igamma2 - batch_shape = get_batch_shape(dx, sx, beta, cx) + batch_shape = torch.broadcast_tensors(length, k1, hx, tilt, energy)[0].shape R = torch.eye(7, dtype=dtype, device=device).repeat(*batch_shape, 1, 1) R[..., 0, 0] = cx R[..., 0, 1] = sx diff --git a/cheetah/utils/batching.py b/cheetah/utils/batching.py deleted file mode 100644 index d0372d9d..00000000 --- a/cheetah/utils/batching.py +++ /dev/null @@ -1,6 +0,0 @@ -import torch - - -def get_batch_shape(*args): - result = torch.broadcast_tensors(*args) - return result[0].shape diff --git a/tests/test_quadrupole.py b/tests/test_quadrupole.py index 50d80071..c63a6b51 100644 --- a/tests/test_quadrupole.py +++ b/tests/test_quadrupole.py @@ -33,6 +33,8 @@ def test_quadrupole_with_misalignments_batched(): misalignment=torch.tensor([[0.1, 0.1]]), ) + assert quad_with_misalignment.batch_shape == torch.Size([1, 2]) + quad_without_misalignment = Quadrupole( length=torch.tensor([1.0]), k1=torch.tensor([1.0]) ) From 230a0cf8a11efaac3f016f20cf47d2a5d6102502 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 15 Jul 2024 21:32:52 +0200 Subject: [PATCH 016/111] Fix vectorisation tests (not code that causes one to fail) --- tests/test_vectorized.py | 37 ++++++++++++++++--------------------- 1 file changed, 16 insertions(+), 21 deletions(-) diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index 7e7ee1d2..2997c45e 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -308,12 +308,15 @@ def test_enormous_through_ares(): "tests/resources/ACHIP_EA1_2021.1351.001" ) - segment_broadcast = segment - incoming_broadcast = incoming + segment.AREAMQZM1.k1 = torch.linspace(-30.0, 30.0, 100_000).repeat(3, 1) - segment_broadcast.AREAMQZM1.k1 = torch.linspace(-30.0, 30.0, 100).repeat(3, 1) + from icecream import ic - outgoing = segment_broadcast.track(incoming_broadcast) + ic(segment.AREAMQZM1.k1.shape, incoming.sigma_x.shape) + + outgoing = segment.track(incoming) + + ic(outgoing.sigma_x.shape, outgoing.energy.shape) assert outgoing.mu_x.shape == (3, 100_000) assert outgoing.mu_px.shape == (3, 100_000) @@ -335,10 +338,10 @@ def test_before_after_broadcast_tracking_equal_cavity(): the same as in the segment before broadcasting. A cavity is used as a reference. """ cavity = cheetah.Cavity( - length=torch.tensor([3.0441]).repeat((3, 10)), - voltage=torch.tensor([48198468.0]), - phase=torch.tensor([-0.0]), - frequency=torch.tensor([2.8560e09]), + length=torch.tensor(3.0441).repeat((3, 10)), + voltage=torch.tensor(48198468.0), + phase=torch.tensor(-0.0), + frequency=torch.tensor(2.8560e09), name="k26_2d", ) incoming = cheetah.ParameterBeam.from_astra( @@ -365,7 +368,7 @@ def test_before_after_broadcast_tracking_equal_ares_ea(): incoming = cheetah.ParameterBeam.from_astra( "tests/resources/ACHIP_EA1_2021.1351.001" ) - segment.AREAMQZM1.k1 = torch.tensor([4.2]).repeat((3, 10)) + segment.AREAMQZM1.k1 = torch.tensor(4.2).repeat((3, 10)) outgoing = segment.track(incoming) broadcast_segment = segment @@ -395,7 +398,7 @@ def test_broadcast_customtransfermap(): ) element = cheetah.CustomTransferMap( - length=torch.tensor([0.4]).repeat((3, 10)), + length=torch.tensor(0.4).repeat((3, 10)), transfer_map=tm.repeat((3, 10, 1, 1)), ) broadcast_element = element @@ -407,18 +410,10 @@ def test_broadcast_customtransfermap(): assert torch.all(broadcast_element._transfer_map[i, j] == tm[0]) -def test_broadcast_element_keeps_dtype(): - """Test that broadcasting an element keeps the same dtype.""" - element = cheetah.Drift(length=torch.tensor([0.4]), dtype=torch.float64) - broadcast_element = element.broadcast((3, 10)) - - assert broadcast_element.length.dtype == torch.float64 - - def test_broadcast_beam_keeps_dtype(): """Test that broadcasting a beam keeps the same dtype.""" beam = cheetah.ParticleBeam.from_parameters( - num_particles=100_000, sigma_x=torch.tensor([1e-5]), dtype=torch.float64 + num_particles=100_000, sigma_x=torch.tensor(1e-5), dtype=torch.float64 ) broadcast_beam = beam.broadcast((2,)) drift = cheetah.Drift(length=torch.tensor([0.4, 0.4]), dtype=torch.float64) @@ -431,7 +426,7 @@ def test_broadcast_beam_keeps_dtype(): def test_broadcast_drift(): """Test that broadcasting a `Drift` element gives the correct result.""" - element = cheetah.Drift(length=torch.tensor([0.4]).repeat((3, 10))) + element = cheetah.Drift(length=torch.tensor(0.4).repeat((3, 10))) assert element.length.shape == (3, 10) for i in range(3): @@ -464,7 +459,7 @@ def test_cavity_with_zero_and_non_zero_voltage(): name="my_test_cavity", ) beam = cheetah.ParticleBeam.from_parameters( - num_particles=100_000, sigma_x=torch.tensor([1e-5]) + num_particles=100_000, sigma_x=torch.tensor(1e-5) ).broadcast((3,)) _ = cavity.track(beam) From dfcd73de0d1c582bf862e7ff284bfe8847dc6e39 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 15 Jul 2024 21:42:41 +0200 Subject: [PATCH 017/111] Remove vectorisation tests that no longer make sense --- tests/test_vectorized.py | 136 --------------------------------------- 1 file changed, 136 deletions(-) diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index 2997c45e..5b9d9175 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -310,14 +310,8 @@ def test_enormous_through_ares(): segment.AREAMQZM1.k1 = torch.linspace(-30.0, 30.0, 100_000).repeat(3, 1) - from icecream import ic - - ic(segment.AREAMQZM1.k1.shape, incoming.sigma_x.shape) - outgoing = segment.track(incoming) - ic(outgoing.sigma_x.shape, outgoing.energy.shape) - assert outgoing.mu_x.shape == (3, 100_000) assert outgoing.mu_px.shape == (3, 100_000) assert outgoing.mu_y.shape == (3, 100_000) @@ -332,120 +326,6 @@ def test_enormous_through_ares(): assert outgoing.total_charge.shape == (3, 100_000) -def test_before_after_broadcast_tracking_equal_cavity(): - """ - Test that when tracking through a segment after broadcasting, the resulting beam is - the same as in the segment before broadcasting. A cavity is used as a reference. - """ - cavity = cheetah.Cavity( - length=torch.tensor(3.0441).repeat((3, 10)), - voltage=torch.tensor(48198468.0), - phase=torch.tensor(-0.0), - frequency=torch.tensor(2.8560e09), - name="k26_2d", - ) - incoming = cheetah.ParameterBeam.from_astra( - "tests/resources/ACHIP_EA1_2021.1351.001" - ) - outgoing = cavity.track(incoming) - - broadcast_cavity = cavity - broadcast_incoming = incoming - broadcast_outgoing = broadcast_cavity.track(broadcast_incoming) - - for i in range(3): - for j in range(10): - assert torch.all(broadcast_outgoing._mu[i, j] == outgoing._mu[0]) - assert torch.all(broadcast_outgoing._cov[i, j] == outgoing._cov[0]) - - -def test_before_after_broadcast_tracking_equal_ares_ea(): - """ - Test that when tracking through a segment after broadcasting, the resulting beam is - the same as in the segment before broadcasting. The ARES EA is used as a reference. - """ - segment = cheetah.Segment.from_ocelot(ares.cell).subcell("AREASOLA1", "AREABSCR1") - incoming = cheetah.ParameterBeam.from_astra( - "tests/resources/ACHIP_EA1_2021.1351.001" - ) - segment.AREAMQZM1.k1 = torch.tensor(4.2).repeat((3, 10)) - outgoing = segment.track(incoming) - - broadcast_segment = segment - broadcast_incoming = incoming - broadcast_outgoing = broadcast_segment.track(broadcast_incoming) - - for i in range(3): - for j in range(10): - assert torch.allclose(broadcast_outgoing._mu[i, j], outgoing._mu[0]) - assert torch.allclose(broadcast_outgoing._cov[i, j], outgoing._cov[0]) - - -def test_broadcast_customtransfermap(): - """Test that broadcasting a `CustomTransferMap` element gives the correct result.""" - tm = torch.tensor( - [ - [ - [1.0, 4.0e-02, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0e-05], - [0.0, 0.0, 1.0, 4.0e-02, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0, -4.6422e-07, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], - ] - ] - ) - - element = cheetah.CustomTransferMap( - length=torch.tensor(0.4).repeat((3, 10)), - transfer_map=tm.repeat((3, 10, 1, 1)), - ) - broadcast_element = element - - assert broadcast_element.length.shape == (3, 10) - assert broadcast_element._transfer_map.shape == (3, 10, 7, 7) - for i in range(3): - for j in range(10): - assert torch.all(broadcast_element._transfer_map[i, j] == tm[0]) - - -def test_broadcast_beam_keeps_dtype(): - """Test that broadcasting a beam keeps the same dtype.""" - beam = cheetah.ParticleBeam.from_parameters( - num_particles=100_000, sigma_x=torch.tensor(1e-5), dtype=torch.float64 - ) - broadcast_beam = beam.broadcast((2,)) - drift = cheetah.Drift(length=torch.tensor([0.4, 0.4]), dtype=torch.float64) - - assert broadcast_beam.particles.dtype == torch.float64 - - # This should not raise an error - _ = drift(broadcast_beam) - - -def test_broadcast_drift(): - """Test that broadcasting a `Drift` element gives the correct result.""" - element = cheetah.Drift(length=torch.tensor(0.4).repeat((3, 10))) - - assert element.length.shape == (3, 10) - for i in range(3): - for j in range(10): - assert element.length[i, j] == 0.4 - - -def test_broadcast_quadrupole(): - """Test that broadcasting a `Quadrupole` element gives the correct result.""" - - # TODO Add misalignment to the test - # TODO Add tilt to the test - - element = cheetah.Quadrupole(length=torch.randn((3, 10)), k1=torch.tensor([4.2])) - - assert element.length.shape == (3, 10) - assert element.k1.shape == (1,) - - def test_cavity_with_zero_and_non_zero_voltage(): """ Tests that if zero and non-zero voltages are passed to a cavity in a single batch, @@ -463,19 +343,3 @@ def test_cavity_with_zero_and_non_zero_voltage(): ).broadcast((3,)) _ = cavity.track(beam) - - -def test_screen_length_shape(): - """ - Test that the shape of a screen's length matches the shape of its misalignment. - """ - screen = cheetah.Screen(misalignment=torch.tensor([[0.1, 0.2], [0.3, 0.4]])) - assert screen.length.shape == screen.misalignment.shape[:-1] - - -def test_screen_length_broadcast_shape(): - """ - Test that the shape of a screen's length matches the shape of its misalignment - after broadcasting. - """ - cheetah.Screen(misalignment=torch.tensor([[0.1, 0.2]]).repeat(3, 10, 1)) From c54d04ff1cd842405b472c0f4eb23d928a5c89e0 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 15 Jul 2024 21:45:30 +0200 Subject: [PATCH 018/111] Fix formating --- cheetah/accelerator/bpm.py | 1 + cheetah/accelerator/marker.py | 1 + cheetah/accelerator/segment.py | 1 + tests/test_space_charge_kick.py | 3 ++- 4 files changed, 5 insertions(+), 1 deletion(-) diff --git a/cheetah/accelerator/bpm.py b/cheetah/accelerator/bpm.py index a11fd635..fc3126b7 100644 --- a/cheetah/accelerator/bpm.py +++ b/cheetah/accelerator/bpm.py @@ -7,6 +7,7 @@ from cheetah.particles import Beam, ParameterBeam, ParticleBeam from cheetah.utils import UniqueNameGenerator + from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") diff --git a/cheetah/accelerator/marker.py b/cheetah/accelerator/marker.py index 1375155c..c8f130af 100644 --- a/cheetah/accelerator/marker.py +++ b/cheetah/accelerator/marker.py @@ -5,6 +5,7 @@ from cheetah.particles import Beam from cheetah.utils import UniqueNameGenerator + from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") diff --git a/cheetah/accelerator/segment.py b/cheetah/accelerator/segment.py index 246c0dc4..7dae4495 100644 --- a/cheetah/accelerator/segment.py +++ b/cheetah/accelerator/segment.py @@ -13,6 +13,7 @@ from cheetah.latticejson import load_cheetah_model, save_cheetah_model from cheetah.particles import Beam, ParticleBeam from cheetah.utils import UniqueNameGenerator + from .custom_transfer_map import CustomTransferMap from .drift import Drift from .element import Element diff --git a/tests/test_space_charge_kick.py b/tests/test_space_charge_kick.py index 21480ac8..180071a4 100644 --- a/tests/test_space_charge_kick.py +++ b/tests/test_space_charge_kick.py @@ -1,9 +1,10 @@ -import cheetah import torch from scipy import constants from scipy.constants import physical_constants from torch import nn +import cheetah + def test_cold_uniform_beam_expansion(): """ From a9670100a44a8452daf18db8501fb920b2f56e8c Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 15 Jul 2024 21:53:10 +0200 Subject: [PATCH 019/111] Remove `broadcast` method from all beams --- cheetah/particles/beam.py | 4 ---- cheetah/particles/parameter_beam.py | 10 ---------- cheetah/particles/particle_beam.py | 9 --------- tests/test_cavity.py | 2 +- tests/test_space_charge_kick.py | 2 +- tests/test_speed_optimizations.py | 2 +- tests/test_vectorized.py | 2 +- 7 files changed, 4 insertions(+), 27 deletions(-) diff --git a/cheetah/particles/beam.py b/cheetah/particles/beam.py index a8755e3e..c170e2e1 100644 --- a/cheetah/particles/beam.py +++ b/cheetah/particles/beam.py @@ -359,10 +359,6 @@ def alpha_y(self) -> torch.Tensor: """Alpha function in y direction, dimensionless.""" return -self.sigma_ypy / self.emittance_y - def broadcast(self, shape: torch.Size) -> "Beam": - """Broadcast beam to new shape.""" - raise NotImplementedError - def __repr__(self) -> str: return ( f"{self.__class__.__name__}(mu_x={self.mu_x}, mu_px={self.mu_px}," diff --git a/cheetah/particles/parameter_beam.py b/cheetah/particles/parameter_beam.py index 336e43e9..807da959 100644 --- a/cheetah/particles/parameter_beam.py +++ b/cheetah/particles/parameter_beam.py @@ -430,16 +430,6 @@ def sigma_xpx(self) -> torch.Tensor: def sigma_ypy(self) -> torch.Tensor: return self._cov[..., 2, 3] - def broadcast(self, shape: torch.Size) -> "ParameterBeam": - return self.__class__( - mu=self._mu.repeat((*shape, 1)), - cov=self._cov.repeat((*shape, 1, 1)), - energy=self.energy.repeat(shape), - total_charge=self.total_charge.repeat(shape), - device=self._mu.device, - dtype=self._mu.dtype, - ) - def __repr__(self) -> str: return ( f"{self.__class__.__name__}(mu_x={repr(self.mu_x)}," diff --git a/cheetah/particles/particle_beam.py b/cheetah/particles/particle_beam.py index 28cf8bde..cf91c25a 100644 --- a/cheetah/particles/particle_beam.py +++ b/cheetah/particles/particle_beam.py @@ -944,15 +944,6 @@ def momenta(self) -> torch.Tensor: """Momenta of the individual particles.""" return torch.sqrt(self.energies**2 - electron_mass_eV**2) - def broadcast(self, shape: torch.Size) -> "ParticleBeam": - return self.__class__( - particles=self.particles.repeat((*shape, 1, 1)), - energy=self.energy.repeat(shape), - particle_charges=self.particle_charges.repeat((*shape, 1)), - device=self.particles.device, - dtype=self.particles.dtype, - ) - def __repr__(self) -> str: return ( f"{self.__class__.__name__}(n={repr(self.num_particles)}," diff --git a/tests/test_cavity.py b/tests/test_cavity.py index ab02682b..a461d2c5 100644 --- a/tests/test_cavity.py +++ b/tests/test_cavity.py @@ -68,7 +68,7 @@ def test_vectorized_cavity_zero_voltage(voltage): energy=torch.tensor([8.0000e09]), total_charge=torch.tensor([0.0]), dtype=torch.float64, - ).broadcast((2,)) + ) outgoing = cavity.track(incoming) diff --git a/tests/test_space_charge_kick.py b/tests/test_space_charge_kick.py index 180071a4..0bb942d1 100644 --- a/tests/test_space_charge_kick.py +++ b/tests/test_space_charge_kick.py @@ -147,7 +147,7 @@ def test_vectorized_cold_uniform_beam_expansion(): sigma_px=torch.tensor([1e-15]), sigma_py=torch.tensor([1e-15]), sigma_p=torch.tensor([1e-15]), - ).broadcast(shape=(2, 3)) + ) # Compute section length kappa = 1 + (torch.sqrt(torch.tensor(2)) / 4) * torch.log( diff --git a/tests/test_speed_optimizations.py b/tests/test_speed_optimizations.py index 5d36f8bc..9946d513 100644 --- a/tests/test_speed_optimizations.py +++ b/tests/test_speed_optimizations.py @@ -49,7 +49,7 @@ def test_merged_transfer_maps_tracking_vectorized(): """ incoming_beam = cheetah.ParameterBeam.from_astra( "tests/resources/ACHIP_EA1_2021.1351.001" - ).broadcast(torch.Size([10])) + ) original_segment = cheetah.Segment( elements=[ diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index 5b9d9175..f32a6dbc 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -340,6 +340,6 @@ def test_cavity_with_zero_and_non_zero_voltage(): ) beam = cheetah.ParticleBeam.from_parameters( num_particles=100_000, sigma_x=torch.tensor(1e-5) - ).broadcast((3,)) + ) _ = cavity.track(beam) From a67310f080fa3ee53ea5c42b647f83e9f73fef8d Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 15 Jul 2024 21:56:02 +0200 Subject: [PATCH 020/111] Fix remaining `xp`s and `yp`s --- tests/test_dipole.py | 2 +- tests/test_quadrupole.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_dipole.py b/tests/test_dipole.py index 53a80d4c..098b4410 100644 --- a/tests/test_dipole.py +++ b/tests/test_dipole.py @@ -11,7 +11,7 @@ def test_dipole_off(): dipole = Dipole(length=torch.tensor(1.0), angle=torch.tensor(0.0)) drift = Drift(length=torch.tensor(1.0)) incoming_beam = ParameterBeam.from_parameters( - sigma_xp=torch.tensor(2e-7), sigma_yp=torch.tensor(2e-7) + sigma_px=torch.tensor(2e-7), sigma_py=torch.tensor(2e-7) ) outbeam_dipole_off = dipole(incoming_beam) outbeam_drift = drift(incoming_beam) diff --git a/tests/test_quadrupole.py b/tests/test_quadrupole.py index 8a604829..8fea4c60 100644 --- a/tests/test_quadrupole.py +++ b/tests/test_quadrupole.py @@ -67,7 +67,7 @@ def test_quadrupole_with_misalignments_multiple_batch_dimensions(): length=torch.tensor([1.0]), k1=torch.tensor([1.0]) ) incoming_beam = ParameterBeam.from_parameters( - sigma_xp=torch.tensor([2e-7]), sigma_yp=torch.tensor([2e-7]) + sigma_px=torch.tensor([2e-7]), sigma_py=torch.tensor([2e-7]) ) outbeam_quad_with_misalignment = quad_with_misalignment(incoming_beam) outbeam_quad_without_misalignment = quad_without_misalignment(incoming_beam) From d39d0087a190a038974f92e7eb7d6a4b16b3490c Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 15 Jul 2024 22:25:35 +0200 Subject: [PATCH 021/111] Remove unnecesary dimensions from elements (not yet beams) in tests --- tests/test_bpm.py | 4 +- tests/test_cavity.py | 26 +++--- tests/test_compare_ocelot.py | 94 ++++++++++------------ tests/test_device_dtype.py | 4 +- tests/test_drift.py | 14 ++-- tests/test_kde.py | 4 +- tests/test_quadrupole.py | 28 +++---- tests/test_reading_nx_tables.py | 2 +- tests/test_screen.py | 6 +- tests/test_space_charge_kick.py | 22 ++--- tests/test_speed_optimizations.py | 60 +++++++------- tests/test_split.py | 28 +++---- tests/test_tracking_lengthless_elements.py | 4 +- 13 files changed, 140 insertions(+), 156 deletions(-) diff --git a/tests/test_bpm.py b/tests/test_bpm.py index 4b53897a..6d58ead5 100644 --- a/tests/test_bpm.py +++ b/tests/test_bpm.py @@ -10,9 +10,9 @@ def test_no_tracking_error(is_bpm_active, beam_class): """Test that tracking a beam through an inactive BPM does not raise an error.""" segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([1.0])), + cheetah.Drift(length=torch.tensor(1.0)), cheetah.BPM(name="my_bpm"), - cheetah.Drift(length=torch.tensor([1.0])), + cheetah.Drift(length=torch.tensor(1.0)), ], ) beam = beam_class.from_astra("tests/resources/ACHIP_EA1_2021.1351.001") diff --git a/tests/test_cavity.py b/tests/test_cavity.py index a461d2c5..7129cbc0 100644 --- a/tests/test_cavity.py +++ b/tests/test_cavity.py @@ -26,7 +26,7 @@ def test_assert_ei_greater_zero(): name="k26_2a", ) beam = cheetah.ParticleBeam.from_parameters( - num_particles=100_000, sigma_x=torch.tensor([1e-5]) + num_particles=100_000, sigma_x=torch.tensor(1e-5) ) _ = cavity.track(beam) @@ -55,18 +55,18 @@ def test_vectorized_cavity_zero_voltage(voltage): dtype=torch.float64, ) incoming = cheetah.ParameterBeam.from_parameters( - mu_x=torch.tensor([0.0]), - mu_px=torch.tensor([0.0]), - mu_y=torch.tensor([0.0]), - mu_py=torch.tensor([0.0]), - sigma_x=torch.tensor([4.8492e-06]), - sigma_px=torch.tensor([1.5603e-07]), - sigma_y=torch.tensor([4.1209e-07]), - sigma_py=torch.tensor([1.1035e-08]), - sigma_tau=torch.tensor([1.0000e-10]), - sigma_p=torch.tensor([1.0000e-06]), - energy=torch.tensor([8.0000e09]), - total_charge=torch.tensor([0.0]), + mu_x=torch.tensor(0.0), + mu_px=torch.tensor(0.0), + mu_y=torch.tensor(0.0), + mu_py=torch.tensor(0.0), + sigma_x=torch.tensor(4.8492e-06), + sigma_px=torch.tensor(1.5603e-07), + sigma_y=torch.tensor(4.1209e-07), + sigma_py=torch.tensor(1.1035e-08), + sigma_tau=torch.tensor(1.0000e-10), + sigma_p=torch.tensor(1.0000e-06), + energy=torch.tensor(8.0000e09), + total_charge=torch.tensor(0.0), dtype=torch.float64, ) diff --git a/tests/test_compare_ocelot.py b/tests/test_compare_ocelot.py index d4e0c0bc..d9586e6b 100644 --- a/tests/test_compare_ocelot.py +++ b/tests/test_compare_ocelot.py @@ -18,9 +18,7 @@ def test_dipole(): incoming_beam = cheetah.ParticleBeam.from_astra( "tests/resources/ACHIP_EA1_2021.1351.001" ) - cheetah_dipole = cheetah.Dipole( - length=torch.tensor([0.1]), angle=torch.tensor([0.1]) - ) + cheetah_dipole = cheetah.Dipole(length=torch.tensor(0.1), angle=torch.tensor(0.1)) outgoing_beam = cheetah_dipole.track(incoming_beam) # Ocelot @@ -48,9 +46,7 @@ def test_dipole_with_float64(): "tests/resources/ACHIP_EA1_2021.1351.001", dtype=torch.float64 ) cheetah_dipole = cheetah.Dipole( - length=torch.tensor([0.1]), - angle=torch.tensor([0.1]), - dtype=torch.float64, + length=torch.tensor(0.1), angle=torch.tensor(0.1), dtype=torch.float64 ) outgoing_beam = cheetah_dipole.track(incoming_beam) @@ -79,10 +75,10 @@ def test_dipole_with_fringe_field(): "tests/resources/ACHIP_EA1_2021.1351.001" ) cheetah_dipole = cheetah.Dipole( - length=torch.tensor([0.1]), - angle=torch.tensor([0.1]), - fringe_integral=torch.tensor([0.1]), - gap=torch.tensor([0.2]), + length=torch.tensor(0.1), + angle=torch.tensor(0.1), + fringe_integral=torch.tensor(0.1), + gap=torch.tensor(0.2), ) outgoing_beam = cheetah_dipole.track(incoming_beam) @@ -114,13 +110,13 @@ def test_dipole_with_fringe_field_and_tilt(): "tests/resources/ACHIP_EA1_2021.1351.001" ) cheetah_dipole = cheetah.Dipole( - length=torch.tensor([1.0]), - angle=torch.tensor([bend_angle]), - fringe_integral=torch.tensor([0.1]), - gap=torch.tensor([0.2]), - tilt=torch.tensor([tilt_angle]), - e1=torch.tensor([bend_angle / 2]), - e2=torch.tensor([bend_angle / 2]), + length=torch.tensor(1.0), + angle=torch.tensor(bend_angle), + fringe_integral=torch.tensor(0.1), + gap=torch.tensor(0.2), + tilt=torch.tensor(tilt_angle), + e1=torch.tensor(bend_angle / 2), + e2=torch.tensor(bend_angle / 2), ) outgoing_beam = cheetah_dipole(incoming_beam) @@ -159,13 +155,13 @@ def test_aperture(): cheetah_segment = cheetah.Segment( [ cheetah.Aperture( - x_max=torch.tensor([2e-4]), - y_max=torch.tensor([2e-4]), + x_max=torch.tensor(2e-4), + y_max=torch.tensor(2e-4), shape="rectangular", name="aperture", is_active=True, ), - cheetah.Drift(length=torch.tensor([0.1])), + cheetah.Drift(length=torch.tensor(0.1)), ] ) outgoing_beam = cheetah_segment.track(incoming_beam) @@ -195,13 +191,13 @@ def test_aperture_elliptical(): cheetah_segment = cheetah.Segment( [ cheetah.Aperture( - x_max=torch.tensor([2e-4]), - y_max=torch.tensor([2e-4]), + x_max=torch.tensor(2e-4), + y_max=torch.tensor(2e-4), shape="elliptical", name="aperture", is_active=True, ), - cheetah.Drift(length=torch.tensor([0.1])), + cheetah.Drift(length=torch.tensor(0.1)), ] ) outgoing_beam = cheetah_segment.track(incoming_beam) @@ -231,9 +227,7 @@ def test_solenoid(): incoming_beam = cheetah.ParticleBeam.from_astra( "tests/resources/ACHIP_EA1_2021.1351.001" ) - cheetah_solenoid = cheetah.Solenoid( - length=torch.tensor([0.5]), k=torch.tensor([5.0]) - ) + cheetah_solenoid = cheetah.Solenoid(length=torch.tensor(0.5), k=torch.tensor(5.0)) outgoing_beam = cheetah_solenoid.track(incoming_beam) # Ocelot @@ -399,13 +393,13 @@ def test_quadrupole(): "tests/resources/ACHIP_EA1_2021.1351.001" ) cheetah_quadrupole = cheetah.Quadrupole( - length=torch.tensor([0.23]), k1=torch.tensor([5.0]) + length=torch.tensor(0.23), k1=torch.tensor(5.0) ) cheetah_segment = cheetah.Segment( [ - cheetah.Drift(length=torch.tensor([0.1])), + cheetah.Drift(length=torch.tensor(0.1)), cheetah_quadrupole, - cheetah.Drift(length=torch.tensor([0.1])), + cheetah.Drift(length=torch.tensor(0.1)), ] ) outgoing_beam = cheetah_segment.track(incoming_beam) @@ -441,13 +435,13 @@ def test_tilted_quadrupole(): "tests/resources/ACHIP_EA1_2021.1351.001" ) cheetah_quadrupole = cheetah.Quadrupole( - length=torch.tensor([0.23]), k1=torch.tensor([5.0]), tilt=torch.tensor([0.79]) + length=torch.tensor(0.23), k1=torch.tensor(5.0), tilt=torch.tensor(0.79) ) cheetah_segment = cheetah.Segment( [ - cheetah.Drift(length=torch.tensor([0.1])), + cheetah.Drift(length=torch.tensor(0.1)), cheetah_quadrupole, - cheetah.Drift(length=torch.tensor([0.1])), + cheetah.Drift(length=torch.tensor(0.1)), ] ) outgoing_beam = cheetah_segment.track(incoming_beam) @@ -481,14 +475,12 @@ def test_sbend(): incoming_beam = cheetah.ParticleBeam.from_astra( "tests/resources/ACHIP_EA1_2021.1351.001" ) - cheetah_dipole = cheetah.Dipole( - length=torch.tensor([0.1]), angle=torch.tensor([0.2]) - ) + cheetah_dipole = cheetah.Dipole(length=torch.tensor(0.1), angle=torch.tensor(0.2)) cheetah_segment = cheetah.Segment( [ - cheetah.Drift(length=torch.tensor([0.1])), + cheetah.Drift(length=torch.tensor(0.1)), cheetah_dipole, - cheetah.Drift(length=torch.tensor([0.1])), + cheetah.Drift(length=torch.tensor(0.1)), ] ) outgoing_beam = cheetah_segment.track(incoming_beam) @@ -525,16 +517,16 @@ def test_rbend(): "tests/resources/ACHIP_EA1_2021.1351.001" ) cheetah_dipole = cheetah.RBend( - length=torch.tensor([0.1]), - angle=torch.tensor([0.2]), - fringe_integral=torch.tensor([0.1]), - gap=torch.tensor([0.2]), + length=torch.tensor(0.1), + angle=torch.tensor(0.2), + fringe_integral=torch.tensor(0.1), + gap=torch.tensor(0.2), ) cheetah_segment = cheetah.Segment( [ - cheetah.Drift(length=torch.tensor([0.1])), + cheetah.Drift(length=torch.tensor(0.1)), cheetah_dipole, - cheetah.Drift(length=torch.tensor([0.1])), + cheetah.Drift(length=torch.tensor(0.1)), ] ) outgoing_beam = cheetah_segment.track(incoming_beam) @@ -691,10 +683,10 @@ def test_cavity(): parray=p_array, dtype=torch.float64 ) cheetah_cavity = cheetah.Cavity( - length=torch.tensor([1.0377]), - voltage=torch.tensor([0.01815975e9]), - frequency=torch.tensor([1.3e9]), - phase=torch.tensor([0.0]), + length=torch.tensor(1.0377), + voltage=torch.tensor(0.01815975e9), + frequency=torch.tensor(1.3e9), + phase=torch.tensor(0.0), dtype=torch.float64, ) outgoing_beam = cheetah_cavity.track(incoming_beam) @@ -745,10 +737,10 @@ def test_cavity_non_zero_phase(): parray=p_array, dtype=torch.float64 ) cheetah_cavity = cheetah.Cavity( - length=torch.tensor([1.0377]), - voltage=torch.tensor([0.01815975e9]), - frequency=torch.tensor([1.3e9]), - phase=torch.tensor([30.0]), + length=torch.tensor(1.0377), + voltage=torch.tensor(0.01815975e9), + frequency=torch.tensor(1.3e9), + phase=torch.tensor(30.0), dtype=torch.float64, ) outgoing_beam = cheetah_cavity.track(incoming_beam) diff --git a/tests/test_device_dtype.py b/tests/test_device_dtype.py index bd7e13a8..0d69968d 100644 --- a/tests/test_device_dtype.py +++ b/tests/test_device_dtype.py @@ -24,7 +24,7 @@ def test_move_quadrupole_to_device(target_device: torch.device): """Test that a quadrupole magnet can be successfully moved to a different device.""" quad = cheetah.Quadrupole( - length=torch.tensor([0.2]), k1=torch.tensor([4.2]), name="my_quad" + length=torch.tensor(0.2), k1=torch.tensor(4.2), name="my_quad" ) # Test that by default the quadrupole is on the CPU @@ -48,7 +48,7 @@ def test_change_quadrupole_dtype(): Test that a quadrupole magnet can be successfully changed to a different dtype. """ quad = cheetah.Quadrupole( - length=torch.tensor([0.2]), k1=torch.tensor([4.2]), name="my_quad" + length=torch.tensor(0.2), k1=torch.tensor(4.2), name="my_quad" ) # Test that by default the quadrupole is of dtype float32 diff --git a/tests/test_drift.py b/tests/test_drift.py index 2cd3643b..fd4ee31a 100644 --- a/tests/test_drift.py +++ b/tests/test_drift.py @@ -9,9 +9,9 @@ def test_diverging_parameter_beam(): Test that that a parameter beam with sigma_px > 0 and sigma_py > 0 increases in size in both dimensions when travelling through a drift section. """ - drift = cheetah.Drift(length=torch.tensor([1.0])) + drift = cheetah.Drift(length=torch.tensor(1.0)) incoming_beam = cheetah.ParameterBeam.from_parameters( - sigma_px=torch.tensor([2e-7]), sigma_py=torch.tensor([2e-7]) + sigma_px=torch.tensor(2e-7), sigma_py=torch.tensor(2e-7) ) outgoing_beam = drift.track(incoming_beam) @@ -25,11 +25,11 @@ def test_diverging_particle_beam(): Test that that a particle beam with sigma_px > 0 and sigma_py > 0 increases in size in both dimensions when travelling through a drift section. """ - drift = cheetah.Drift(length=torch.tensor([1.0])) + drift = cheetah.Drift(length=torch.tensor(1.0)) incoming_beam = cheetah.ParticleBeam.from_parameters( - num_particles=torch.tensor(1000), - sigma_px=torch.tensor([2e-7]), - sigma_py=torch.tensor([2e-7]), + num_particles=torch.tensor(1_000), + sigma_px=torch.tensor(2e-7), + sigma_py=torch.tensor(2e-7), ) outgoing_beam = drift.track(incoming_beam) @@ -52,7 +52,7 @@ def test_device_like_torch_module(): if not torch.cuda.is_available(): return - element = cheetah.Drift(length=torch.tensor([0.2]), device="cuda") + element = cheetah.Drift(length=torch.tensor(0.2), device="cuda") assert element.length.device.type == "cuda" diff --git a/tests/test_kde.py b/tests/test_kde.py index 12ad6a64..f3110663 100644 --- a/tests/test_kde.py +++ b/tests/test_kde.py @@ -56,8 +56,6 @@ def test_weighted_samples_2d(): hist_neglect_weights = kde_histogram_2d( x_weighted[:, 0], x_weighted[:, 1], bins1, bins2, sigma ) - print(hist_unweighted[5]) - print(hist_weighted[5]) - print(hist_neglect_weights[5]) + assert torch.allclose(hist_unweighted, hist_weighted) assert not torch.allclose(hist_weighted, hist_neglect_weights) diff --git a/tests/test_quadrupole.py b/tests/test_quadrupole.py index 8fea4c60..e3c0682a 100644 --- a/tests/test_quadrupole.py +++ b/tests/test_quadrupole.py @@ -58,16 +58,14 @@ def test_quadrupole_with_misalignments_multiple_batch_dimensions(): misalignments = torch.randn((4, 3, 2)) quad_with_misalignment = Quadrupole( - length=torch.tensor([1.0]), - k1=torch.tensor([1.0]), - misalignment=misalignments, + length=torch.tensor(1.0), k1=torch.tensor(1.0), misalignment=misalignments ) quad_without_misalignment = Quadrupole( - length=torch.tensor([1.0]), k1=torch.tensor([1.0]) + length=torch.tensor(1.0), k1=torch.tensor(1.0) ) incoming_beam = ParameterBeam.from_parameters( - sigma_px=torch.tensor([2e-7]), sigma_py=torch.tensor([2e-7]) + sigma_px=torch.tensor(2e-7), sigma_py=torch.tensor(2e-7) ) outbeam_quad_with_misalignment = quad_with_misalignment(incoming_beam) outbeam_quad_without_misalignment = quad_without_misalignment(incoming_beam) @@ -88,8 +86,8 @@ def test_tilted_quadrupole_batch(): """ incoming = ParticleBeam.from_parameters( num_particles=torch.tensor(1_000_000), - energy=torch.tensor([1e9]), - mu_x=torch.tensor([1e-5]), + energy=torch.tensor(1e9), + mu_x=torch.tensor(1e-5), ) segment = Segment( [ @@ -98,7 +96,7 @@ def test_tilted_quadrupole_batch(): k1=torch.tensor([1.0, 1.0, 1.0]), tilt=torch.tensor([torch.pi / 4, torch.pi / 2, torch.pi * 5 / 4]), ), - Drift(length=torch.tensor([0.5])), + Drift(length=torch.tensor(0.5)), ] ) outgoing = segment(incoming) @@ -124,15 +122,15 @@ def test_tilted_quadrupole_multiple_batch_dimensions(): ) segment = Segment( [ - Quadrupole(length=torch.tensor([0.5]), k1=torch.tensor([1.0]), tilt=tilts), - Drift(length=torch.tensor([0.5])), + Quadrupole(length=torch.tensor(0.5), k1=torch.tensor(1.0), tilt=tilts), + Drift(length=torch.tensor(0.5)), ] ) incoming = ParticleBeam.from_parameters( num_particles=torch.tensor(10_000), - energy=torch.tensor([1e9]), - mu_x=torch.tensor([1e-5]), + energy=torch.tensor(1e9), + mu_x=torch.tensor(1e-5), ) outgoing = segment(incoming) @@ -151,15 +149,15 @@ def test_quadrupole_length_multiple_batch_dimensions(): lengths = torch.tensor([[0.2, 0.3, 0.4], [0.5, 0.6, 0.7]]) segment = Segment( [ - Quadrupole(length=lengths, k1=torch.tensor([4.2])), + Quadrupole(length=lengths, k1=torch.tensor(4.2)), Drift(length=lengths * 2), ] ) incoming = ParticleBeam.from_parameters( num_particles=torch.tensor(10_000), - energy=torch.tensor([1e9]), - mu_x=torch.tensor([1e-5]), + energy=torch.tensor(1e9), + mu_x=torch.tensor(1e-5), ) outgoing = segment(incoming) diff --git a/tests/test_reading_nx_tables.py b/tests/test_reading_nx_tables.py index f1fb7792..3689f65b 100644 --- a/tests/test_reading_nx_tables.py +++ b/tests/test_reading_nx_tables.py @@ -24,4 +24,4 @@ def test_length(): """ segment = cheetah.Segment.from_nx_tables("tests/resources/Stage4v3_9.txt") - assert torch.allclose(segment.length, torch.tensor([44.2215])) + assert torch.allclose(segment.length, torch.tensor(44.2215)) diff --git a/tests/test_screen.py b/tests/test_screen.py index 85383a8f..c9cdfbc5 100644 --- a/tests/test_screen.py +++ b/tests/test_screen.py @@ -14,7 +14,7 @@ def test_reading_shows_beam_particle(screen_method): """ segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([1.0])), + cheetah.Drift(length=torch.tensor(1.0)), cheetah.Screen( resolution=torch.tensor((100, 100)), pixel_size=torch.tensor((1e-5, 1e-5)), @@ -44,7 +44,7 @@ def test_screen_kde_bandwidth(kde_bandwidth): segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([1.0])), + cheetah.Drift(length=torch.tensor(1.0)), cheetah.Screen( resolution=torch.tensor((100, 100)), pixel_size=torch.tensor((1e-5, 1e-5)), @@ -76,7 +76,7 @@ def test_reading_shows_beam_parameter(screen_method): """ segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([1.0])), + cheetah.Drift(length=torch.tensor(1.0)), cheetah.Screen( resolution=torch.tensor((100, 100)), pixel_size=torch.tensor((1e-5, 1e-5)), diff --git a/tests/test_space_charge_kick.py b/tests/test_space_charge_kick.py index 0bb942d1..ed9aeafd 100644 --- a/tests/test_space_charge_kick.py +++ b/tests/test_space_charge_kick.py @@ -19,8 +19,8 @@ def test_cold_uniform_beam_expansion(): torch.manual_seed(42) # Simulation parameters - R0 = torch.tensor([0.001]) - energy = torch.tensor([2.5e8]) + R0 = torch.tensor(0.001) + energy = torch.tensor(2.5e8) rest_energy = torch.tensor( constants.electron_mass * constants.speed_of_light**2 @@ -33,14 +33,14 @@ def test_cold_uniform_beam_expansion(): incoming = cheetah.ParticleBeam.uniform_3d_ellipsoid( num_particles=torch.tensor(10_000), - total_charge=torch.tensor([1e-9]), + total_charge=torch.tensor(1e-9), energy=energy, radius_x=R0, radius_y=R0, radius_tau=R0 / gamma, # Radius of the beam in s direction in the lab frame - sigma_px=torch.tensor([1e-15]), - sigma_py=torch.tensor([1e-15]), - sigma_p=torch.tensor([1e-15]), + sigma_px=torch.tensor(1e-15), + sigma_py=torch.tensor(1e-15), + sigma_p=torch.tensor(1e-15), ) # Compute section length @@ -74,9 +74,9 @@ def test_vectorized(): """ # Simulation parameters - section_length = torch.tensor([0.42]) - R0 = torch.tensor([0.001]) - energy = torch.tensor([2.5e8]) + section_length = torch.tensor(0.42) + R0 = torch.tensor(0.001) + energy = torch.tensor(2.5e8) rest_energy = torch.tensor( constants.electron_mass * constants.speed_of_light**2 @@ -246,7 +246,7 @@ def test_does_not_break_segment_length(): Test that the computation of a `Segment`'s length does not break when `SpaceChargeKick` is used. """ - section_length = torch.tensor([1.0]).repeat((3, 2)) + section_length = torch.tensor(1.0).repeat((3, 2)) segment = cheetah.Segment( elements=[ cheetah.Drift(section_length / 6), @@ -260,4 +260,4 @@ def test_does_not_break_segment_length(): ) assert segment.length.shape == (3, 2) - assert torch.allclose(segment.length, torch.tensor([1.0]).repeat(3, 2)) + assert torch.allclose(segment.length, torch.tensor(1.0).repeat(3, 2)) diff --git a/tests/test_speed_optimizations.py b/tests/test_speed_optimizations.py index 9946d513..8d760f78 100644 --- a/tests/test_speed_optimizations.py +++ b/tests/test_speed_optimizations.py @@ -15,11 +15,11 @@ def test_merged_transfer_maps_tracking(): original_segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([0.6])), - cheetah.Quadrupole(length=torch.tensor([0.2]), k1=torch.tensor([4.2])), - cheetah.Drift(length=torch.tensor([0.4])), + cheetah.Drift(length=torch.tensor(0.6)), + cheetah.Quadrupole(length=torch.tensor(0.2), k1=torch.tensor(4.2)), + cheetah.Drift(length=torch.tensor(0.4)), cheetah.HorizontalCorrector( - length=torch.tensor([0.1]), angle=torch.tensor([1e-4]) + length=torch.tensor(0.1), angle=torch.tensor(1e-4) ), ] ) @@ -53,11 +53,11 @@ def test_merged_transfer_maps_tracking_vectorized(): original_segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([0.6])), - cheetah.Quadrupole(length=torch.tensor([0.2]), k1=torch.tensor([4.2])), - cheetah.Drift(length=torch.tensor([0.4])), + cheetah.Drift(length=torch.tensor(0.6)), + cheetah.Quadrupole(length=torch.tensor(0.2), k1=torch.tensor(4.2)), + cheetah.Drift(length=torch.tensor(0.4)), cheetah.HorizontalCorrector( - length=torch.tensor([0.1]), angle=torch.tensor([1e-4]) + length=torch.tensor(0.1), angle=torch.tensor(1e-4) ), ] ) @@ -90,11 +90,11 @@ def test_merged_transfer_maps_num_elements(): original_segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([0.6])), - cheetah.Quadrupole(length=torch.tensor([0.2]), k1=torch.tensor([4.2])), - cheetah.Drift(length=torch.tensor([0.4])), + cheetah.Drift(length=torch.tensor(0.6)), + cheetah.Quadrupole(length=torch.tensor(0.2), k1=torch.tensor(4.2)), + cheetah.Drift(length=torch.tensor(0.4)), cheetah.HorizontalCorrector( - length=torch.tensor([0.1]), angle=torch.tensor([1e-4]) + length=torch.tensor(0.1), angle=torch.tensor(1e-4) ), ] ) @@ -110,12 +110,12 @@ def test_no_markers_left_after_removal(): """ segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([0.6])), - cheetah.Quadrupole(length=torch.tensor([0.2]), k1=torch.tensor([4.2])), + cheetah.Drift(length=torch.tensor(0.6)), + cheetah.Quadrupole(length=torch.tensor(0.2), k1=torch.tensor(4.2)), cheetah.Marker(), - cheetah.Drift(length=torch.tensor([0.4])), + cheetah.Drift(length=torch.tensor(0.4)), cheetah.HorizontalCorrector( - length=torch.tensor([0.1]), angle=torch.tensor([1e-4]) + length=torch.tensor(0.1), angle=torch.tensor(1e-4) ), cheetah.Marker(), ] @@ -133,9 +133,9 @@ def test_inactive_magnet_is_replaced_by_drift(): """ segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([0.6])), - cheetah.Quadrupole(length=torch.tensor([0.2]), k1=torch.tensor([0.0])), - cheetah.Drift(length=torch.tensor([0.4])), + cheetah.Drift(length=torch.tensor(0.6)), + cheetah.Quadrupole(length=torch.tensor(0.2), k1=torch.tensor(0.0)), + cheetah.Drift(length=torch.tensor(0.4)), ] ) @@ -152,9 +152,9 @@ def test_active_elements_not_replaced_by_drift(): """ segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([0.6])), - cheetah.Quadrupole(length=torch.tensor([0.2]), k1=torch.tensor([4.2])), - cheetah.Drift(length=torch.tensor([0.4])), + cheetah.Drift(length=torch.tensor(0.6)), + cheetah.Quadrupole(length=torch.tensor(0.2), k1=torch.tensor(4.2)), + cheetah.Drift(length=torch.tensor(0.4)), ] ) @@ -171,11 +171,11 @@ def test_inactive_magnet_drift_replacement_dtype(dtype: torch.dtype): """ segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([0.6]), dtype=dtype), + cheetah.Drift(length=torch.tensor(0.6), dtype=dtype), cheetah.Quadrupole( - length=torch.tensor([0.2]), k1=torch.tensor([0.0]), dtype=dtype + length=torch.tensor(0.2), k1=torch.tensor(0.0), dtype=dtype ), - cheetah.Drift(length=torch.tensor([0.4]), dtype=dtype), + cheetah.Drift(length=torch.tensor(0.4), dtype=dtype), ] ) @@ -194,15 +194,15 @@ def test_skippable_elements_reset(): ) original_segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([0.6])), + cheetah.Drift(length=torch.tensor(0.6)), cheetah.Quadrupole( - length=torch.tensor([0.2]), k1=torch.tensor([4.2]), name="Q1" + length=torch.tensor(0.2), k1=torch.tensor(4.2), name="Q1" ), - cheetah.Drift(length=torch.tensor([0.4])), + cheetah.Drift(length=torch.tensor(0.4)), cheetah.HorizontalCorrector( - length=torch.tensor([0.1]), angle=torch.tensor([1e-4]), name="HCOR_1" + length=torch.tensor(0.1), angle=torch.tensor(1e-4), name="HCOR_1" ), - cheetah.Drift(length=torch.tensor([0.4])), + cheetah.Drift(length=torch.tensor(0.4)), ] ) diff --git a/tests/test_split.py b/tests/test_split.py index f80cf398..85995004 100644 --- a/tests/test_split.py +++ b/tests/test_split.py @@ -10,7 +10,7 @@ def test_drift_end(): Test that at the end of a split drift the result is the same as at the end of the original drift. """ - original_drift = cheetah.Drift(length=torch.tensor([2.0])) + original_drift = cheetah.Drift(length=torch.tensor(2.0)) split_drift = cheetah.Segment(original_drift.split(resolution=torch.tensor(0.1))) incoming_beam = cheetah.ParticleBeam.from_astra( @@ -32,7 +32,7 @@ def test_quadrupole_end(): the original quadrupole. """ original_quadrupole = cheetah.Quadrupole( - length=torch.tensor([0.2]), k1=torch.tensor([4.2]) + length=torch.tensor(0.2), k1=torch.tensor(4.2) ) split_quadrupole = cheetah.Segment( original_quadrupole.split(resolution=torch.tensor(0.01)) @@ -56,10 +56,10 @@ def test_cavity_end(): the original cavity. """ original_cavity = cheetah.Cavity( - length=torch.tensor([1.0377]), - voltage=torch.tensor([0.01815975e9]), - frequency=torch.tensor([1.3e9]), - phase=torch.tensor([0.0]), + length=torch.tensor(1.0377), + voltage=torch.tensor(0.01815975e9), + frequency=torch.tensor(1.3e9), + phase=torch.tensor(0.0), ) split_cavity = cheetah.Segment(original_cavity.split(resolution=torch.tensor(0.1))) @@ -80,9 +80,7 @@ def test_solenoid_end(): Test that at the end of a split solenoid the result is the same as at the end of the original solenoid. """ - original_solenoid = cheetah.Solenoid( - length=torch.tensor([0.2]), k=torch.tensor([4.2]) - ) + original_solenoid = cheetah.Solenoid(length=torch.tensor(0.2), k=torch.tensor(4.2)) split_solenoid = cheetah.Segment( original_solenoid.split(resolution=torch.tensor(0.01)) ) @@ -104,9 +102,7 @@ def test_dipole_end(): Test that at the end of a split dipole the result is the same as at the end of the original dipole. """ - original_dipole = cheetah.Dipole( - length=torch.tensor([0.2]), angle=torch.tensor([4.2]) - ) + original_dipole = cheetah.Dipole(length=torch.tensor(0.2), angle=torch.tensor(4.2)) split_dipole = cheetah.Segment(original_dipole.split(resolution=torch.tensor(0.01))) incoming_beam = cheetah.ParticleBeam.from_astra( @@ -126,7 +122,7 @@ def test_undulator_end(): Test that at the end of a split undulator the result is the same as at the end of the original undulator. """ - original_undulator = cheetah.Undulator(length=torch.tensor([3.142])) + original_undulator = cheetah.Undulator(length=torch.tensor(3.142)) split_undulator = cheetah.Segment( original_undulator.split(resolution=torch.tensor(0.1)) ) @@ -150,7 +146,7 @@ def test_horizontal_corrector_end(): the end of the original horizontal corrector. """ original_horizontal_corrector = cheetah.HorizontalCorrector( - length=torch.tensor([0.2]), angle=torch.tensor([4.2]) + length=torch.tensor(0.2), angle=torch.tensor(4.2) ) split_horizontal_corrector = cheetah.Segment( original_horizontal_corrector.split(resolution=torch.tensor(0.01)) @@ -175,7 +171,7 @@ def test_vertical_corrector_end(): the end of the original vertical corrector. """ original_vertical_corrector = cheetah.VerticalCorrector( - length=torch.tensor([0.2]), angle=torch.tensor([4.2]) + length=torch.tensor(0.2), angle=torch.tensor(4.2) ) split_vertical_corrector = cheetah.Segment( original_vertical_corrector.split(resolution=torch.tensor(0.01)) @@ -211,7 +207,7 @@ def test_split_preserves_dtype(ElementType): """ Test that the dtype of a drift section's splits is the same as the original drift. """ - original = ElementType(length=torch.tensor([2.0]), dtype=torch.float64) + original = ElementType(length=torch.tensor(2.0), dtype=torch.float64) splits = original.split(resolution=torch.tensor(0.1)) for split in splits: diff --git a/tests/test_tracking_lengthless_elements.py b/tests/test_tracking_lengthless_elements.py index f4bde4fb..d5a738b7 100644 --- a/tests/test_tracking_lengthless_elements.py +++ b/tests/test_tracking_lengthless_elements.py @@ -19,11 +19,11 @@ def test_tracking_lengthless_elements(): segment = cheetah.Segment( [ cheetah.Cavity( - length=torch.tensor([0.1]), voltage=torch.tensor([1e6]), name="C2" + length=torch.tensor(0.1), voltage=torch.tensor(1e6), name="C2" ), cheetah.Marker(name="start"), cheetah.Cavity( - length=torch.tensor([0.1]), voltage=torch.tensor([1e6]), name="C1" + length=torch.tensor(0.1), voltage=torch.tensor(1e6), name="C1" ), ] ) From 247919c852f2f1e2bfed829ef2675c33cb71c888 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 15 Jul 2024 22:35:52 +0200 Subject: [PATCH 022/111] Cleanup imports --- cheetah/__init__.py | 6 +++--- cheetah/accelerator/aperture.py | 5 ++--- cheetah/accelerator/bpm.py | 5 ++--- cheetah/accelerator/cavity.py | 7 +++---- cheetah/accelerator/custom_transfer_map.py | 5 ++--- cheetah/accelerator/dipole.py | 5 ++--- cheetah/accelerator/drift.py | 3 +-- cheetah/accelerator/element.py | 4 ++-- cheetah/accelerator/horizontal_corrector.py | 3 +-- cheetah/accelerator/marker.py | 5 ++--- cheetah/accelerator/quadrupole.py | 5 ++--- cheetah/accelerator/rbend.py | 3 +-- cheetah/accelerator/screen.py | 5 ++--- cheetah/accelerator/segment.py | 11 +++++------ cheetah/accelerator/solenoid.py | 5 ++--- cheetah/accelerator/space_charge_kick.py | 3 +-- cheetah/accelerator/undulator.py | 3 +-- cheetah/accelerator/vertical_corrector.py | 3 +-- cheetah/converters/__init__.py | 3 +-- cheetah/track_methods.py | 4 ++-- 20 files changed, 38 insertions(+), 55 deletions(-) diff --git a/cheetah/__init__.py b/cheetah/__init__.py index 37a36995..91c13335 100644 --- a/cheetah/__init__.py +++ b/cheetah/__init__.py @@ -1,5 +1,5 @@ -import cheetah.converters # noqa: F401 -from cheetah.accelerator import ( # noqa: F401 +from . import converters # noqa: F401 +from .accelerator import ( # noqa: F401 BPM, Aperture, Cavity, @@ -17,4 +17,4 @@ Undulator, VerticalCorrector, ) -from cheetah.particles import ParameterBeam, ParticleBeam # noqa: F401 +from .particles import ParameterBeam, ParticleBeam # noqa: F401 diff --git a/cheetah/accelerator/aperture.py b/cheetah/accelerator/aperture.py index e2a78ad2..2d1d9ad0 100644 --- a/cheetah/accelerator/aperture.py +++ b/cheetah/accelerator/aperture.py @@ -5,9 +5,8 @@ from matplotlib.patches import Rectangle from torch import nn -from cheetah.particles import Beam, ParticleBeam -from cheetah.utils import UniqueNameGenerator - +from ..particles import Beam, ParticleBeam +from ..utils import UniqueNameGenerator from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") diff --git a/cheetah/accelerator/bpm.py b/cheetah/accelerator/bpm.py index fc3126b7..945c9d38 100644 --- a/cheetah/accelerator/bpm.py +++ b/cheetah/accelerator/bpm.py @@ -5,9 +5,8 @@ import torch from matplotlib.patches import Rectangle -from cheetah.particles import Beam, ParameterBeam, ParticleBeam -from cheetah.utils import UniqueNameGenerator - +from ..particles import Beam, ParameterBeam, ParticleBeam +from ..utils import UniqueNameGenerator from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") diff --git a/cheetah/accelerator/cavity.py b/cheetah/accelerator/cavity.py index e22ef229..f9b0ee8b 100644 --- a/cheetah/accelerator/cavity.py +++ b/cheetah/accelerator/cavity.py @@ -6,10 +6,9 @@ from scipy import constants from torch import nn -from cheetah.particles import Beam, ParameterBeam, ParticleBeam -from cheetah.track_methods import base_rmatrix -from cheetah.utils import UniqueNameGenerator - +from ..particles import Beam, ParameterBeam, ParticleBeam +from ..track_methods import base_rmatrix +from ..utils import UniqueNameGenerator from ..utils.physics import calculate_relativistic_factors, electron_mass_eV from .element import Element diff --git a/cheetah/accelerator/custom_transfer_map.py b/cheetah/accelerator/custom_transfer_map.py index 7318c7a2..d5bac2f5 100644 --- a/cheetah/accelerator/custom_transfer_map.py +++ b/cheetah/accelerator/custom_transfer_map.py @@ -4,9 +4,8 @@ import torch from torch import nn -from cheetah.particles import Beam -from cheetah.utils import UniqueNameGenerator - +from ..particles import Beam +from ..utils import UniqueNameGenerator from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") diff --git a/cheetah/accelerator/dipole.py b/cheetah/accelerator/dipole.py index f7059a91..7ece38de 100644 --- a/cheetah/accelerator/dipole.py +++ b/cheetah/accelerator/dipole.py @@ -6,9 +6,8 @@ from matplotlib.patches import Rectangle from torch import nn -from cheetah.track_methods import base_rmatrix, rotation_matrix -from cheetah.utils import UniqueNameGenerator - +from ..track_methods import base_rmatrix, rotation_matrix +from ..utils import UniqueNameGenerator from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") diff --git a/cheetah/accelerator/drift.py b/cheetah/accelerator/drift.py index 4e949ecf..0a526c8d 100644 --- a/cheetah/accelerator/drift.py +++ b/cheetah/accelerator/drift.py @@ -5,8 +5,7 @@ from scipy.constants import physical_constants from torch import nn -from cheetah.utils import UniqueNameGenerator - +from ..utils import UniqueNameGenerator from ..utils.physics import calculate_relativistic_factors from .element import Element diff --git a/cheetah/accelerator/element.py b/cheetah/accelerator/element.py index afddbcc8..560fba23 100644 --- a/cheetah/accelerator/element.py +++ b/cheetah/accelerator/element.py @@ -5,8 +5,8 @@ import torch from torch import nn -from cheetah.particles import Beam, ParameterBeam, ParticleBeam -from cheetah.utils import UniqueNameGenerator +from ..particles import Beam, ParameterBeam, ParticleBeam +from ..utils import UniqueNameGenerator generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") diff --git a/cheetah/accelerator/horizontal_corrector.py b/cheetah/accelerator/horizontal_corrector.py index de01daea..e64fcc55 100644 --- a/cheetah/accelerator/horizontal_corrector.py +++ b/cheetah/accelerator/horizontal_corrector.py @@ -6,8 +6,7 @@ from matplotlib.patches import Rectangle from torch import nn -from cheetah.utils import UniqueNameGenerator - +from ..utils import UniqueNameGenerator from ..utils.physics import calculate_relativistic_factors from .element import Element diff --git a/cheetah/accelerator/marker.py b/cheetah/accelerator/marker.py index c8f130af..643d4f46 100644 --- a/cheetah/accelerator/marker.py +++ b/cheetah/accelerator/marker.py @@ -3,9 +3,8 @@ import matplotlib.pyplot as plt import torch -from cheetah.particles import Beam -from cheetah.utils import UniqueNameGenerator - +from ..particles import Beam +from ..utils import UniqueNameGenerator from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") diff --git a/cheetah/accelerator/quadrupole.py b/cheetah/accelerator/quadrupole.py index 205abd76..f412e1a7 100644 --- a/cheetah/accelerator/quadrupole.py +++ b/cheetah/accelerator/quadrupole.py @@ -6,9 +6,8 @@ from matplotlib.patches import Rectangle from torch import nn -from cheetah.track_methods import base_rmatrix, misalignment_matrix -from cheetah.utils import UniqueNameGenerator - +from ..track_methods import base_rmatrix, misalignment_matrix +from ..utils import UniqueNameGenerator from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") diff --git a/cheetah/accelerator/rbend.py b/cheetah/accelerator/rbend.py index 39192ecd..b15ea26c 100644 --- a/cheetah/accelerator/rbend.py +++ b/cheetah/accelerator/rbend.py @@ -3,8 +3,7 @@ import torch from torch import nn -from cheetah.utils import UniqueNameGenerator - +from ..utils import UniqueNameGenerator from .dipole import Dipole generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") diff --git a/cheetah/accelerator/screen.py b/cheetah/accelerator/screen.py index a25df9b0..35b96a44 100644 --- a/cheetah/accelerator/screen.py +++ b/cheetah/accelerator/screen.py @@ -7,9 +7,8 @@ from torch import nn from torch.distributions import MultivariateNormal -from cheetah.particles import Beam, ParameterBeam, ParticleBeam -from cheetah.utils import UniqueNameGenerator, kde_histogram_2d - +from ..particles import Beam, ParameterBeam, ParticleBeam +from ..utils import UniqueNameGenerator, kde_histogram_2d from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") diff --git a/cheetah/accelerator/segment.py b/cheetah/accelerator/segment.py index 7dae4495..1274bfad 100644 --- a/cheetah/accelerator/segment.py +++ b/cheetah/accelerator/segment.py @@ -8,12 +8,11 @@ import torch from torch import nn -from cheetah.converters.bmad import convert_bmad_lattice -from cheetah.converters.nxtables import read_nx_tables -from cheetah.latticejson import load_cheetah_model, save_cheetah_model -from cheetah.particles import Beam, ParticleBeam -from cheetah.utils import UniqueNameGenerator - +from ..converters.bmad import convert_bmad_lattice +from ..converters.nxtables import read_nx_tables +from ..latticejson import load_cheetah_model, save_cheetah_model +from ..particles import Beam, ParticleBeam +from ..utils import UniqueNameGenerator from .custom_transfer_map import CustomTransferMap from .drift import Drift from .element import Element diff --git a/cheetah/accelerator/solenoid.py b/cheetah/accelerator/solenoid.py index 38852474..02b5338f 100644 --- a/cheetah/accelerator/solenoid.py +++ b/cheetah/accelerator/solenoid.py @@ -6,9 +6,8 @@ from scipy.constants import physical_constants from torch import nn -from cheetah.track_methods import misalignment_matrix -from cheetah.utils import UniqueNameGenerator - +from ..track_methods import misalignment_matrix +from ..utils import UniqueNameGenerator from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") diff --git a/cheetah/accelerator/space_charge_kick.py b/cheetah/accelerator/space_charge_kick.py index 273e98bd..8ca20684 100644 --- a/cheetah/accelerator/space_charge_kick.py +++ b/cheetah/accelerator/space_charge_kick.py @@ -5,8 +5,7 @@ from scipy import constants from torch import nn -from cheetah.particles import Beam, ParticleBeam - +from ..particles import Beam, ParticleBeam from .element import Element # Constants diff --git a/cheetah/accelerator/undulator.py b/cheetah/accelerator/undulator.py index e6c29da1..f29b5d68 100644 --- a/cheetah/accelerator/undulator.py +++ b/cheetah/accelerator/undulator.py @@ -6,8 +6,7 @@ from scipy.constants import physical_constants from torch import nn -from cheetah.utils import UniqueNameGenerator - +from ..utils import UniqueNameGenerator from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") diff --git a/cheetah/accelerator/vertical_corrector.py b/cheetah/accelerator/vertical_corrector.py index 46ab9103..3b374afa 100644 --- a/cheetah/accelerator/vertical_corrector.py +++ b/cheetah/accelerator/vertical_corrector.py @@ -7,8 +7,7 @@ from scipy.constants import physical_constants from torch import nn -from cheetah.utils import UniqueNameGenerator - +from ..utils import UniqueNameGenerator from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") diff --git a/cheetah/converters/__init__.py b/cheetah/converters/__init__.py index cda1cdc2..8eefee19 100644 --- a/cheetah/converters/__init__.py +++ b/cheetah/converters/__init__.py @@ -1,2 +1 @@ -# flake8: noqa -from cheetah.converters import astra, bmad, nxtables, ocelot +from . import astra, bmad, nxtables, ocelot # noqa: F401 diff --git a/cheetah/track_methods.py b/cheetah/track_methods.py index 4c6a0130..08ed8d1e 100644 --- a/cheetah/track_methods.py +++ b/cheetah/track_methods.py @@ -1,10 +1,10 @@ -"""Utility functions for creating transfer maps for the elements.""" +"""Utility functions for creating transfer maps for elements.""" from typing import Optional import torch -from cheetah.utils.physics import calculate_relativistic_factors +from .utils.physics import calculate_relativistic_factors def rotation_matrix(angle: torch.Tensor) -> torch.Tensor: From da1bd2bfa2f956d9f863565de3808a08c00b32ec Mon Sep 17 00:00:00 2001 From: Ryan Roussel Date: Tue, 16 Jul 2024 14:27:51 -0500 Subject: [PATCH 023/111] fix some tests + solenoid fixes --- cheetah/accelerator/quadrupole.py | 2 +- cheetah/accelerator/solenoid.py | 49 ++++++++++++++++--------------- tests/test_vectorized.py | 4 +-- 3 files changed, 29 insertions(+), 26 deletions(-) diff --git a/cheetah/accelerator/quadrupole.py b/cheetah/accelerator/quadrupole.py index f412e1a7..b6934128 100644 --- a/cheetah/accelerator/quadrupole.py +++ b/cheetah/accelerator/quadrupole.py @@ -86,7 +86,7 @@ def is_skippable(self) -> bool: @property def is_active(self) -> bool: - return any(self.k1 != 0) + return bool(torch.any(self.k1 != 0)) def split(self, resolution: torch.Tensor) -> list[Element]: split_elements = [] diff --git a/cheetah/accelerator/solenoid.py b/cheetah/accelerator/solenoid.py index 02b5338f..33c0c7ee 100644 --- a/cheetah/accelerator/solenoid.py +++ b/cheetah/accelerator/solenoid.py @@ -9,6 +9,7 @@ from ..track_methods import misalignment_matrix from ..utils import UniqueNameGenerator from .element import Element +from ..utils.physics import calculate_relativistic_factors generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -32,13 +33,13 @@ class Solenoid(Element): """ def __init__( - self, - length: Union[torch.Tensor, nn.Parameter] = None, - k: Optional[Union[torch.Tensor, nn.Parameter]] = None, - misalignment: Optional[Union[torch.Tensor, nn.Parameter]] = None, - name: Optional[str] = None, - device=None, - dtype=torch.float32, + self, + length: Union[torch.Tensor, nn.Parameter] = None, + k: Optional[Union[torch.Tensor, nn.Parameter]] = None, + misalignment: Optional[Union[torch.Tensor, nn.Parameter]] = None, + name: Optional[str] = None, + device=None, + dtype=torch.float32, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) @@ -65,37 +66,39 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: device = self.length.device dtype = self.length.dtype - gamma = energy / electron_mass_eV.to(device=device, dtype=dtype) + gamma, _, _ = calculate_relativistic_factors(energy) c = torch.cos(self.length * self.k) s = torch.sin(self.length * self.k) - s_k = torch.empty_like(self.length) - s_k[self.k == 0] = self.length[self.k == 0] - s_k[self.k != 0] = s[self.k != 0] / self.k[self.k != 0] + s_k = torch.where(self.k == 0.0, self.length, s / self.k) - r56 = torch.zeros_like(self.length) + batch_shape = torch.broadcast_tensors( + *self.parameters(), *self.buffers(), energy + )[0].shape + + r56 = torch.zeros(batch_shape) if gamma != 0: gamma2 = gamma * gamma beta = torch.sqrt(1.0 - 1.0 / gamma2) r56 -= self.length / (beta * beta * gamma2) - R = torch.eye(7, device=device, dtype=dtype).repeat((*self.length.shape, 1, 1)) - R[..., 0, 0] = c**2 + R = torch.eye(7, device=device, dtype=dtype).repeat((*batch_shape, 1, 1)) + R[..., 0, 0] = c ** 2 R[..., 0, 1] = c * s_k R[..., 0, 2] = s * c R[..., 0, 3] = s * s_k R[..., 1, 0] = -self.k * s * c - R[..., 1, 1] = c**2 - R[..., 1, 2] = -self.k * s**2 + R[..., 1, 1] = c ** 2 + R[..., 1, 2] = -self.k * s ** 2 R[..., 1, 3] = s * c R[..., 2, 0] = -s * c R[..., 2, 1] = -s * s_k - R[..., 2, 2] = c**2 + R[..., 2, 2] = c ** 2 R[..., 2, 3] = c * s_k - R[..., 3, 0] = self.k * s**2 + R[..., 3, 0] = self.k * s ** 2 R[..., 3, 1] = -s * c R[..., 3, 2] = -self.k * s * c - R[..., 3, 3] = c**2 + R[..., 3, 3] = c ** 2 R[..., 4, 5] = r56 R = R.real @@ -133,8 +136,8 @@ def defining_features(self) -> list[str]: def __repr__(self) -> str: return ( - f"{self.__class__.__name__}(length={repr(self.length)}, " - + f"k={repr(self.k)}, " - + f"misalignment={repr(self.misalignment)}, " - + f"name={repr(self.name)})" + f"{self.__class__.__name__}(length={repr(self.length)}, " + + f"k={repr(self.k)}, " + + f"misalignment={repr(self.misalignment)}, " + + f"name={repr(self.name)})" ) diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index f32a6dbc..42757662 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -322,8 +322,8 @@ def test_enormous_through_ares(): assert outgoing.sigma_py.shape == (3, 100_000) assert outgoing.sigma_tau.shape == (3, 100_000) assert outgoing.sigma_p.shape == (3, 100_000) - assert outgoing.energy.shape == (3, 100_000) - assert outgoing.total_charge.shape == (3, 100_000) + assert outgoing.energy.shape == torch.Size([1]) + assert outgoing.total_charge.shape == torch.Size([1]) def test_cavity_with_zero_and_non_zero_voltage(): From 374b55b61d0e0e23f139f28c5936b719e60e1f10 Mon Sep 17 00:00:00 2001 From: Ryan Roussel Date: Tue, 16 Jul 2024 14:28:02 -0500 Subject: [PATCH 024/111] Update solenoid.py --- cheetah/accelerator/solenoid.py | 36 ++++++++++++++++----------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/cheetah/accelerator/solenoid.py b/cheetah/accelerator/solenoid.py index 33c0c7ee..61646f9f 100644 --- a/cheetah/accelerator/solenoid.py +++ b/cheetah/accelerator/solenoid.py @@ -8,8 +8,8 @@ from ..track_methods import misalignment_matrix from ..utils import UniqueNameGenerator -from .element import Element from ..utils.physics import calculate_relativistic_factors +from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -33,13 +33,13 @@ class Solenoid(Element): """ def __init__( - self, - length: Union[torch.Tensor, nn.Parameter] = None, - k: Optional[Union[torch.Tensor, nn.Parameter]] = None, - misalignment: Optional[Union[torch.Tensor, nn.Parameter]] = None, - name: Optional[str] = None, - device=None, - dtype=torch.float32, + self, + length: Union[torch.Tensor, nn.Parameter] = None, + k: Optional[Union[torch.Tensor, nn.Parameter]] = None, + misalignment: Optional[Union[torch.Tensor, nn.Parameter]] = None, + name: Optional[str] = None, + device=None, + dtype=torch.float32, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) @@ -83,22 +83,22 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: r56 -= self.length / (beta * beta * gamma2) R = torch.eye(7, device=device, dtype=dtype).repeat((*batch_shape, 1, 1)) - R[..., 0, 0] = c ** 2 + R[..., 0, 0] = c**2 R[..., 0, 1] = c * s_k R[..., 0, 2] = s * c R[..., 0, 3] = s * s_k R[..., 1, 0] = -self.k * s * c - R[..., 1, 1] = c ** 2 - R[..., 1, 2] = -self.k * s ** 2 + R[..., 1, 1] = c**2 + R[..., 1, 2] = -self.k * s**2 R[..., 1, 3] = s * c R[..., 2, 0] = -s * c R[..., 2, 1] = -s * s_k - R[..., 2, 2] = c ** 2 + R[..., 2, 2] = c**2 R[..., 2, 3] = c * s_k - R[..., 3, 0] = self.k * s ** 2 + R[..., 3, 0] = self.k * s**2 R[..., 3, 1] = -s * c R[..., 3, 2] = -self.k * s * c - R[..., 3, 3] = c ** 2 + R[..., 3, 3] = c**2 R[..., 4, 5] = r56 R = R.real @@ -136,8 +136,8 @@ def defining_features(self) -> list[str]: def __repr__(self) -> str: return ( - f"{self.__class__.__name__}(length={repr(self.length)}, " - + f"k={repr(self.k)}, " - + f"misalignment={repr(self.misalignment)}, " - + f"name={repr(self.name)})" + f"{self.__class__.__name__}(length={repr(self.length)}, " + + f"k={repr(self.k)}, " + + f"misalignment={repr(self.misalignment)}, " + + f"name={repr(self.name)})" ) From 9945b50d51585a39a146d013d104e8e37c85266c Mon Sep 17 00:00:00 2001 From: Chenran Xu Date: Wed, 17 Jul 2024 10:23:08 +0200 Subject: [PATCH 025/111] Fix non-batched error in space charge and particle beam creation --- cheetah/accelerator/space_charge_kick.py | 37 +++++++++++++++++++----- cheetah/particles/particle_beam.py | 29 ++++++++++++++----- 2 files changed, 51 insertions(+), 15 deletions(-) diff --git a/cheetah/accelerator/space_charge_kick.py b/cheetah/accelerator/space_charge_kick.py index 8ca20684..be507e9c 100644 --- a/cheetah/accelerator/space_charge_kick.py +++ b/cheetah/accelerator/space_charge_kick.py @@ -557,6 +557,15 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: elif isinstance(incoming, ParticleBeam): # This flattening is a hack to only think about one vector dimension in the # following code. It is reversed at the end of the function. + + # Make sure that the incoming beam has at least one batch dimension + incoming_batched = True + if len(incoming.particles.shape) == 2: + incoming_batched = False + incoming.particles = incoming.particles.unsqueeze(0) + incoming.energy = incoming.energy.unsqueeze(0) + incoming.particle_charges = incoming.particle_charges.unsqueeze(0) + flattened_incoming = ParticleBeam( particles=incoming.particles.flatten(end_dim=-3), energy=incoming.energy.flatten(end_dim=-1), @@ -595,14 +604,26 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: ..., 2 ] * dt.unsqueeze(-1) - outgoing = ParticleBeam.from_xyz_pxpypz( - xp_coordinates.unflatten(dim=0, sizes=incoming.particles.shape[:-2]), - incoming.energy, - incoming.particle_charges, - incoming.particles.device, - incoming.particles.dtype, - ) - + if not incoming_batched: + # Reshape to the original shape + outgoing = ParticleBeam.from_xyz_pxpypz( + xp_coordinates.squeeze(0), + incoming.energy.squeeze(0), + incoming.particle_charges.squeeze(0), + incoming.particles.device, + incoming.particles.dtype, + ) + else: + # Reverse the flattening + outgoing = ParticleBeam.from_xyz_pxpypz( + xp_coordinates.unflatten( + dim=0, sizes=incoming.particles.shape[:-2] + ), + incoming.energy, + incoming.particle_charges, + incoming.particles.device, + incoming.particles.dtype, + ) return outgoing else: raise TypeError(f"Parameter incoming is of invalid type {type(incoming)}") diff --git a/cheetah/particles/particle_beam.py b/cheetah/particles/particle_beam.py index cf91c25a..0369aea3 100644 --- a/cheetah/particles/particle_beam.py +++ b/cheetah/particles/particle_beam.py @@ -336,26 +336,41 @@ def uniform_3d_ellipsoid( ] if argument is not None ] - shape = not_nones[0].shape if len(not_nones) > 0 else torch.Size([1]) + shape = not_nones[0].shape if len(not_nones) > 0 else torch.Size([]) if len(not_nones) > 1: assert all( argument.shape == shape for argument in not_nones ), "Arguments must have the same shape." + # Expand to batched version for beam creation + batch_shape = shape if len(shape) > 0 else torch.Size([1]) + # Set default values without function call in function signature # NOTE that this does not need to be done for values that are passed to the # Gaussian beam generation. num_particles = ( num_particles if num_particles is not None else torch.tensor(1_000_000) ) - radius_x = radius_x if radius_x is not None else torch.full(shape, 1e-3) - radius_y = radius_y if radius_y is not None else torch.full(shape, 1e-3) - radius_tau = radius_tau if radius_tau is not None else torch.full(shape, 1e-3) + radius_x = ( + radius_x.expand(batch_shape) + if radius_x is not None + else torch.full(batch_shape, 1e-3) + ) + radius_y = ( + radius_y.expand(batch_shape) + if radius_y is not None + else torch.full(batch_shape, 1e-3) + ) + radius_tau = ( + radius_tau.expand(batch_shape) + if radius_tau is not None + else torch.full(batch_shape, 1e-3) + ) # Generate x, y and ss within the ellipsoid - flattened_x = torch.empty(*shape, num_particles).flatten(end_dim=-2) - flattened_y = torch.empty(*shape, num_particles).flatten(end_dim=-2) - flattened_tau = torch.empty(*shape, num_particles).flatten(end_dim=-2) + flattened_x = torch.empty(*batch_shape, num_particles).flatten(end_dim=-2) + flattened_y = torch.empty(*batch_shape, num_particles).flatten(end_dim=-2) + flattened_tau = torch.empty(*batch_shape, num_particles).flatten(end_dim=-2) for i, (r_x, r_y, r_tau) in enumerate( zip(radius_x.flatten(), radius_y.flatten(), radius_tau.flatten()) ): From 5d8be5fab969b98a2036d21a9dd669c9da473e99 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 17 Jul 2024 15:07:17 +0200 Subject: [PATCH 026/111] Procastinate by removing some more brackets than are no longer needed --- tests/test_compare_beam_type.py | 76 ++++++++++++++++----------------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/tests/test_compare_beam_type.py b/tests/test_compare_beam_type.py index 24c747a8..6a6d3e28 100644 --- a/tests/test_compare_beam_type.py +++ b/tests/test_compare_beam_type.py @@ -12,25 +12,25 @@ def test_from_twiss(): Test that a beams created from Twiss parameters have the same properties. """ parameter_beam = cheetah.ParameterBeam.from_twiss( - beta_x=torch.tensor([5.91253676811640894]), - alpha_x=torch.tensor([3.55631307633660354]), - emittance_x=torch.tensor([3.494768647122823e-09]), - beta_y=torch.tensor([5.91253676811640982]), - alpha_y=torch.tensor([2e-7]), - emittance_y=torch.tensor([3.497810737006068e-09]), - energy=torch.tensor([6e6]), + beta_x=torch.tensor(5.91253676811640894), + alpha_x=torch.tensor(3.55631307633660354), + emittance_x=torch.tensor(3.494768647122823e-09), + beta_y=torch.tensor(5.91253676811640982), + alpha_y=torch.tensor(2e-7), + emittance_y=torch.tensor(3.497810737006068e-09), + energy=torch.tensor(6e6), ) particle_beam = cheetah.ParticleBeam.from_twiss( num_particles=torch.tensor( [10_000_000] ), # Large number of particles reduces noise - beta_x=torch.tensor([5.91253676811640894]), - alpha_x=torch.tensor([3.55631307633660354]), - emittance_x=torch.tensor([3.494768647122823e-09]), - beta_y=torch.tensor([5.91253676811640982]), - alpha_y=torch.tensor([2e-7]), - emittance_y=torch.tensor([3.497810737006068e-09]), - energy=torch.tensor([6e6]), + beta_x=torch.tensor(5.91253676811640894), + alpha_x=torch.tensor(3.55631307633660354), + emittance_x=torch.tensor(3.494768647122823e-09), + beta_y=torch.tensor(5.91253676811640982), + alpha_y=torch.tensor(2e-7), + emittance_y=torch.tensor(3.497810737006068e-09), + energy=torch.tensor(6e6), ) assert torch.isclose(parameter_beam.mu_x, particle_beam.mu_x, atol=1e-6) @@ -51,7 +51,7 @@ def test_drift(): """Test that the drift output for both beam types is roughly the same.""" # Set up lattice - cheetah_drift = cheetah.Drift(length=torch.tensor([1.0])) + cheetah_drift = cheetah.Drift(length=torch.tensor(1.0)) # Parameter beam incoming_parameter_beam = cheetah.ParameterBeam.from_astra( @@ -98,7 +98,7 @@ def test_quadrupole(): # Set up lattice cheetah_quadrupole = cheetah.Quadrupole( - length=torch.tensor([0.15]), k1=torch.tensor([4.2]) + length=torch.tensor(0.15), k1=torch.tensor(4.2) ) # Parameter beam @@ -149,10 +149,10 @@ def test_cavity_from_astra(): # Set up lattice cheetah_cavity = cheetah.Cavity( - length=torch.tensor([1.0377]), - voltage=torch.tensor([0.01815975e9]), - frequency=torch.tensor([1.3e9]), - phase=torch.tensor([0.0]), + length=torch.tensor(1.0377), + voltage=torch.tensor(0.01815975e9), + frequency=torch.tensor(1.3e9), + phase=torch.tensor(0.0), ) # Parameter beam @@ -221,33 +221,33 @@ def test_cavity_from_twiss(): # Set up lattice cheetah_cavity = cheetah.Cavity( - length=torch.tensor([1.0377]), - voltage=torch.tensor([0.01815975e9]), - frequency=torch.tensor([1.3e9]), - phase=torch.tensor([0.0]), + length=torch.tensor(1.0377), + voltage=torch.tensor(0.01815975e9), + frequency=torch.tensor(1.3e9), + phase=torch.tensor(0.0), ) # Parameter beam incoming_parameter_beam = cheetah.ParameterBeam.from_twiss( - beta_x=torch.tensor([5.91253677]), - alpha_x=torch.tensor([3.55631308]), - beta_y=torch.tensor([5.91253677]), - alpha_y=torch.tensor([3.55631308]), - emittance_x=torch.tensor([3.494768647122823e-09]), - emittance_y=torch.tensor([3.497810737006068e-09]), - energy=torch.tensor([6e6]), + beta_x=torch.tensor(5.91253677), + alpha_x=torch.tensor(3.55631308), + beta_y=torch.tensor(5.91253677), + alpha_y=torch.tensor(3.55631308), + emittance_x=torch.tensor(3.494768647122823e-09), + emittance_y=torch.tensor(3.497810737006068e-09), + energy=torch.tensor(6e6), ) outgoing_parameter_beam = cheetah_cavity.track(incoming_parameter_beam) # Particle beam incoming_particle_beam = cheetah.ParticleBeam.from_twiss( - beta_x=torch.tensor([5.91253677]), - alpha_x=torch.tensor([3.55631308]), - beta_y=torch.tensor([5.91253677]), - alpha_y=torch.tensor([3.55631308]), - emittance_x=torch.tensor([3.494768647122823e-09]), - emittance_y=torch.tensor([3.497810737006068e-09]), - energy=torch.tensor([6e6]), + beta_x=torch.tensor(5.91253677), + alpha_x=torch.tensor(3.55631308), + beta_y=torch.tensor(5.91253677), + alpha_y=torch.tensor(3.55631308), + emittance_x=torch.tensor(3.494768647122823e-09), + emittance_y=torch.tensor(3.497810737006068e-09), + energy=torch.tensor(6e6), ) outgoing_particle_beam = cheetah_cavity.track(incoming_particle_beam) From 9b1db538c1644e47a268dcb6010d047f116f3105 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 24 Jul 2024 15:03:51 +0200 Subject: [PATCH 027/111] Fix expected values for `energy` and `total_charge` shapes --- tests/test_vectorized.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index 42757662..f32a6dbc 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -322,8 +322,8 @@ def test_enormous_through_ares(): assert outgoing.sigma_py.shape == (3, 100_000) assert outgoing.sigma_tau.shape == (3, 100_000) assert outgoing.sigma_p.shape == (3, 100_000) - assert outgoing.energy.shape == torch.Size([1]) - assert outgoing.total_charge.shape == torch.Size([1]) + assert outgoing.energy.shape == (3, 100_000) + assert outgoing.total_charge.shape == (3, 100_000) def test_cavity_with_zero_and_non_zero_voltage(): From 039cd16cad9db8246efb84b200cd8caec6dc0ba6 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 24 Jul 2024 15:19:53 +0200 Subject: [PATCH 028/111] Fix `total_charge` and `energy` broacast issue --- cheetah/accelerator/element.py | 10 ++-- cheetah/converters/ocelot.py | 88 +++++++++++++++++----------------- 2 files changed, 50 insertions(+), 48 deletions(-) diff --git a/cheetah/accelerator/element.py b/cheetah/accelerator/element.py index 560fba23..b40acac9 100644 --- a/cheetah/accelerator/element.py +++ b/cheetah/accelerator/element.py @@ -69,8 +69,8 @@ def track(self, incoming: Beam) -> Beam: return ParameterBeam( mu, cov, - incoming.energy, - total_charge=incoming.total_charge, + incoming.energy.expand(mu.shape[:-1]), + total_charge=incoming.total_charge.expand(mu.shape[:-1]), device=mu.device, dtype=mu.dtype, ) @@ -79,8 +79,10 @@ def track(self, incoming: Beam) -> Beam: new_particles = torch.matmul(incoming.particles, tm.transpose(-2, -1)) return ParticleBeam( new_particles, - incoming.energy, - particle_charges=incoming.particle_charges, + incoming.energy.expand(new_particles.shape[:-2]), + particle_charges=incoming.particle_charges.expand( + new_particles.shape[:-1] + ), device=new_particles.device, dtype=new_particles.dtype, ) diff --git a/cheetah/converters/ocelot.py b/cheetah/converters/ocelot.py index 3b8f4236..6b5f9385 100644 --- a/cheetah/converters/ocelot.py +++ b/cheetah/converters/ocelot.py @@ -31,91 +31,91 @@ def ocelot2cheetah( if isinstance(element, ocelot.Drift): return cheetah.Drift( - length=torch.tensor([element.l], dtype=torch.float32), + length=torch.tensor(element.l, dtype=torch.float32), name=element.id, device=device, dtype=dtype, ) elif isinstance(element, ocelot.Quadrupole): return cheetah.Quadrupole( - length=torch.tensor([element.l], dtype=torch.float32), - k1=torch.tensor([element.k1], dtype=torch.float32), + length=torch.tensor(element.l, dtype=torch.float32), + k1=torch.tensor(element.k1, dtype=torch.float32), name=element.id, device=device, dtype=dtype, ) elif isinstance(element, ocelot.Solenoid): return cheetah.Solenoid( - length=torch.tensor([element.l], dtype=torch.float32), - k=torch.tensor([element.k], dtype=torch.float32), + length=torch.tensor(element.l, dtype=torch.float32), + k=torch.tensor(element.k, dtype=torch.float32), name=element.id, device=device, dtype=dtype, ) elif isinstance(element, ocelot.Hcor): return cheetah.HorizontalCorrector( - length=torch.tensor([element.l], dtype=torch.float32), - angle=torch.tensor([element.angle], dtype=torch.float32), + length=torch.tensor(element.l, dtype=torch.float32), + angle=torch.tensor(element.angle, dtype=torch.float32), name=element.id, device=device, dtype=dtype, ) elif isinstance(element, ocelot.Vcor): return cheetah.VerticalCorrector( - length=torch.tensor([element.l], dtype=torch.float32), - angle=torch.tensor([element.angle], dtype=torch.float32), + length=torch.tensor(element.l, dtype=torch.float32), + angle=torch.tensor(element.angle, dtype=torch.float32), name=element.id, device=device, dtype=dtype, ) elif isinstance(element, ocelot.Bend): return cheetah.Dipole( - length=torch.tensor([element.l], dtype=torch.float32), - angle=torch.tensor([element.angle], dtype=torch.float32), - e1=torch.tensor([element.e1], dtype=torch.float32), - e2=torch.tensor([element.e2], dtype=torch.float32), - tilt=torch.tensor([element.tilt], dtype=torch.float32), - fringe_integral=torch.tensor([element.fint], dtype=torch.float32), - fringe_integral_exit=torch.tensor([element.fintx], dtype=torch.float32), - gap=torch.tensor([element.gap], dtype=torch.float32), + length=torch.tensor(element.l, dtype=torch.float32), + angle=torch.tensor(element.angle, dtype=torch.float32), + e1=torch.tensor(element.e1, dtype=torch.float32), + e2=torch.tensor(element.e2, dtype=torch.float32), + tilt=torch.tensor(element.tilt, dtype=torch.float32), + fringe_integral=torch.tensor(element.fint, dtype=torch.float32), + fringe_integral_exit=torch.tensor(element.fintx, dtype=torch.float32), + gap=torch.tensor(element.gap, dtype=torch.float32), name=element.id, device=device, dtype=dtype, ) elif isinstance(element, ocelot.SBend): return cheetah.Dipole( - length=torch.tensor([element.l], dtype=torch.float32), - angle=torch.tensor([element.angle], dtype=torch.float32), - e1=torch.tensor([element.e1], dtype=torch.float32), - e2=torch.tensor([element.e2], dtype=torch.float32), - tilt=torch.tensor([element.tilt], dtype=torch.float32), - fringe_integral=torch.tensor([element.fint], dtype=torch.float32), - fringe_integral_exit=torch.tensor([element.fintx], dtype=torch.float32), - gap=torch.tensor([element.gap], dtype=torch.float32), + length=torch.tensor(element.l, dtype=torch.float32), + angle=torch.tensor(element.angle, dtype=torch.float32), + e1=torch.tensor(element.e1, dtype=torch.float32), + e2=torch.tensor(element.e2, dtype=torch.float32), + tilt=torch.tensor(element.tilt, dtype=torch.float32), + fringe_integral=torch.tensor(element.fint, dtype=torch.float32), + fringe_integral_exit=torch.tensor(element.fintx, dtype=torch.float32), + gap=torch.tensor(element.gap, dtype=torch.float32), name=element.id, device=device, dtype=dtype, ) elif isinstance(element, ocelot.RBend): return cheetah.RBend( - length=torch.tensor([element.l], dtype=torch.float32), - angle=torch.tensor([element.angle], dtype=torch.float32), - e1=torch.tensor([element.e1], dtype=torch.float32) - element.angle / 2, - e2=torch.tensor([element.e2], dtype=torch.float32) - element.angle / 2, - tilt=torch.tensor([element.tilt], dtype=torch.float32), - fringe_integral=torch.tensor([element.fint], dtype=torch.float32), - fringe_integral_exit=torch.tensor([element.fintx], dtype=torch.float32), - gap=torch.tensor([element.gap], dtype=torch.float32), + length=torch.tensor(element.l, dtype=torch.float32), + angle=torch.tensor(element.angle, dtype=torch.float32), + e1=torch.tensor(element.e1, dtype=torch.float32) - element.angle / 2, + e2=torch.tensor(element.e2, dtype=torch.float32) - element.angle / 2, + tilt=torch.tensor(element.tilt, dtype=torch.float32), + fringe_integral=torch.tensor(element.fint, dtype=torch.float32), + fringe_integral_exit=torch.tensor(element.fintx, dtype=torch.float32), + gap=torch.tensor(element.gap, dtype=torch.float32), name=element.id, device=device, dtype=dtype, ) elif isinstance(element, ocelot.Cavity): return cheetah.Cavity( - length=torch.tensor([element.l], dtype=torch.float32), - voltage=torch.tensor([element.v], dtype=torch.float32) * 1e9, - frequency=torch.tensor([element.freq], dtype=torch.float32), - phase=torch.tensor([element.phi], dtype=torch.float32), + length=torch.tensor(element.l, dtype=torch.float32), + voltage=torch.tensor(element.v, dtype=torch.float32) * 1e9, + frequency=torch.tensor(element.freq, dtype=torch.float32), + phase=torch.tensor(element.phi, dtype=torch.float32), name=element.id, device=device, dtype=dtype, @@ -123,10 +123,10 @@ def ocelot2cheetah( elif isinstance(element, ocelot.TDCavity): # TODO: Better replacement at some point? return cheetah.Cavity( - length=torch.tensor([element.l], dtype=torch.float32), - voltage=torch.tensor([element.v], dtype=torch.float32) * 1e9, - frequency=torch.tensor([element.freq], dtype=torch.float32), - phase=torch.tensor([element.phi], dtype=torch.float32), + length=torch.tensor(element.l, dtype=torch.float32), + voltage=torch.tensor(element.v, dtype=torch.float32) * 1e9, + frequency=torch.tensor(element.freq, dtype=torch.float32), + phase=torch.tensor(element.phi, dtype=torch.float32), name=element.id, device=device, dtype=dtype, @@ -162,8 +162,8 @@ def ocelot2cheetah( elif isinstance(element, ocelot.Aperture): shape_translation = {"rect": "rectangular", "elip": "elliptical"} return cheetah.Aperture( - x_max=torch.tensor([element.xmax], dtype=torch.float32), - y_max=torch.tensor([element.ymax], dtype=torch.float32), + x_max=torch.tensor(element.xmax, dtype=torch.float32), + y_max=torch.tensor(element.ymax, dtype=torch.float32), shape=shape_translation[element.type], is_active=True, name=element.id, @@ -177,7 +177,7 @@ def ocelot2cheetah( " replacing with drift section." ) return cheetah.Drift( - length=torch.tensor([element.l], dtype=torch.float32), + length=torch.tensor(element.l, dtype=torch.float32), name=element.id, device=device, dtype=dtype, From a97fcf23e76364a5eaa775b05579de913130a2d0 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 24 Jul 2024 15:21:03 +0200 Subject: [PATCH 029/111] Fix remaining failing tests --- cheetah/particles/parameter_beam.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cheetah/particles/parameter_beam.py b/cheetah/particles/parameter_beam.py index 807da959..3b8e55a1 100644 --- a/cheetah/particles/parameter_beam.py +++ b/cheetah/particles/parameter_beam.py @@ -206,10 +206,10 @@ def from_twiss( total_charge if total_charge is not None else torch.full(shape, 0.0) ) - assert all( + assert torch.all( beta_x > 0 ), "Beta function in x direction must be larger than 0 everywhere." - assert all( + assert torch.all( beta_y > 0 ), "Beta function in y direction must be larger than 0 everywhere." From 3aaae6f0dba12235b2646105bbf230c788105c47 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Thu, 25 Jul 2024 08:59:03 +0200 Subject: [PATCH 030/111] Fix `any`s to `torch.any` --- cheetah/accelerator/horizontal_corrector.py | 2 +- cheetah/accelerator/solenoid.py | 2 +- cheetah/accelerator/vertical_corrector.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cheetah/accelerator/horizontal_corrector.py b/cheetah/accelerator/horizontal_corrector.py index e64fcc55..5aaddab1 100644 --- a/cheetah/accelerator/horizontal_corrector.py +++ b/cheetah/accelerator/horizontal_corrector.py @@ -66,7 +66,7 @@ def is_skippable(self) -> bool: @property def is_active(self) -> bool: - return any(self.angle != 0) + return torch.any(self.angle != 0) def split(self, resolution: torch.Tensor) -> list[Element]: split_elements = [] diff --git a/cheetah/accelerator/solenoid.py b/cheetah/accelerator/solenoid.py index 61646f9f..0a3fccbb 100644 --- a/cheetah/accelerator/solenoid.py +++ b/cheetah/accelerator/solenoid.py @@ -112,7 +112,7 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: @property def is_active(self) -> bool: - return any(self.k != 0) + return torch.any(self.k != 0) def is_skippable(self) -> bool: return True diff --git a/cheetah/accelerator/vertical_corrector.py b/cheetah/accelerator/vertical_corrector.py index 3b374afa..da801df5 100644 --- a/cheetah/accelerator/vertical_corrector.py +++ b/cheetah/accelerator/vertical_corrector.py @@ -71,7 +71,7 @@ def is_skippable(self) -> bool: @property def is_active(self) -> bool: - return any(self.angle != 0) + return torch.any(self.angle != 0) def split(self, resolution: torch.Tensor) -> list[Element]: split_elements = [] From 604fa81db1e79144e15eff8e38baac9c29648276 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Thu, 25 Jul 2024 09:20:51 +0200 Subject: [PATCH 031/111] Remove no longer needed dimensions --- cheetah/converters/bmad.py | 64 +++++++++++++------------- cheetah/converters/nxtables.py | 46 +++++++++---------- tests/test_differentiable.py | 8 ++-- tests/test_parameter_beam.py | 80 ++++++++++++++++----------------- tests/test_particle_beam.py | 72 ++++++++++++++--------------- tests/test_quadrupole.py | 18 ++++---- tests/test_space_charge_kick.py | 44 +++++++++--------- tests/test_speed.py | 4 +- 8 files changed, 167 insertions(+), 169 deletions(-) diff --git a/cheetah/converters/bmad.py b/cheetah/converters/bmad.py index 7cfd41c7..54914d43 100644 --- a/cheetah/converters/bmad.py +++ b/cheetah/converters/bmad.py @@ -476,7 +476,7 @@ def convert_element( ) if "l" in bmad_parsed: return cheetah.Drift( - length=torch.tensor([bmad_parsed["l"]]), + length=torch.tensor(bmad_parsed["l"]), name=name, device=device, dtype=dtype, @@ -489,7 +489,7 @@ def convert_element( ) if "l" in bmad_parsed: return cheetah.Drift( - length=torch.tensor([bmad_parsed["l"]]), + length=torch.tensor(bmad_parsed["l"]), name=name, device=device, dtype=dtype, @@ -501,7 +501,7 @@ def convert_element( ["element_type", "alias", "type", "l", "descrip"], bmad_parsed ) return cheetah.Drift( - length=torch.tensor([bmad_parsed["l"]]), + length=torch.tensor(bmad_parsed["l"]), name=name, device=device, dtype=dtype, @@ -511,7 +511,7 @@ def convert_element( ["element_type", "l", "type", "descrip"], bmad_parsed ) return cheetah.Drift( - length=torch.tensor([bmad_parsed["l"]]), + length=torch.tensor(bmad_parsed["l"]), name=name, device=device, dtype=dtype, @@ -521,8 +521,8 @@ def convert_element( ["element_type", "type", "alias"], bmad_parsed ) return cheetah.HorizontalCorrector( - length=torch.tensor([bmad_parsed.get("l", 0.0)]), - angle=torch.tensor([bmad_parsed.get("kick", 0.0)]), + length=torch.tensor(bmad_parsed.get("l", 0.0)), + angle=torch.tensor(bmad_parsed.get("kick", 0.0)), name=name, device=device, dtype=dtype, @@ -532,8 +532,8 @@ def convert_element( ["element_type", "type", "alias"], bmad_parsed ) return cheetah.VerticalCorrector( - length=torch.tensor([bmad_parsed.get("l", 0.0)]), - angle=torch.tensor([bmad_parsed.get("kick", 0.0)]), + length=torch.tensor(bmad_parsed.get("l", 0.0)), + angle=torch.tensor(bmad_parsed.get("kick", 0.0)), name=name, device=device, dtype=dtype, @@ -559,15 +559,15 @@ def convert_element( bmad_parsed, ) return cheetah.Dipole( - length=torch.tensor([bmad_parsed["l"]]), - gap=torch.tensor([bmad_parsed.get("hgap", 0.0)]), - angle=torch.tensor([bmad_parsed.get("angle", 0.0)]), - e1=torch.tensor([bmad_parsed["e1"]]), - e2=torch.tensor([bmad_parsed.get("e2", 0.0)]), - tilt=torch.tensor([bmad_parsed.get("ref_tilt", 0.0)]), - fringe_integral=torch.tensor([bmad_parsed.get("fint", 0.0)]), + length=torch.tensor(bmad_parsed["l"]), + gap=torch.tensor(bmad_parsed.get("hgap", 0.0)), + angle=torch.tensor(bmad_parsed.get("angle", 0.0)), + e1=torch.tensor(bmad_parsed["e1"]), + e2=torch.tensor(bmad_parsed.get("e2", 0.0)), + tilt=torch.tensor(bmad_parsed.get("ref_tilt", 0.0)), + fringe_integral=torch.tensor(bmad_parsed.get("fint", 0.0)), fringe_integral_exit=( - torch.tensor([bmad_parsed["fintx"]]) + torch.tensor(bmad_parsed["fintx"]) if "fintx" in bmad_parsed else None ), @@ -582,9 +582,9 @@ def convert_element( bmad_parsed, ) return cheetah.Quadrupole( - length=torch.tensor([bmad_parsed["l"]]), - k1=torch.tensor([bmad_parsed["k1"]]), - tilt=torch.tensor([bmad_parsed.get("tilt", 0.0)]), + length=torch.tensor(bmad_parsed["l"]), + k1=torch.tensor(bmad_parsed["k1"]), + tilt=torch.tensor(bmad_parsed.get("tilt", 0.0)), name=name, device=device, dtype=dtype, @@ -594,8 +594,8 @@ def convert_element( ["element_type", "l", "ks", "alias"], bmad_parsed ) return cheetah.Solenoid( - length=torch.tensor([bmad_parsed["l"]]), - k=torch.tensor([bmad_parsed["ks"]]), + length=torch.tensor(bmad_parsed["l"]), + k=torch.tensor(bmad_parsed["ks"]), name=name, device=device, dtype=dtype, @@ -616,12 +616,12 @@ def convert_element( bmad_parsed, ) return cheetah.Cavity( - length=torch.tensor([bmad_parsed["l"]]), - voltage=torch.tensor([bmad_parsed.get("voltage", 0.0)]), + length=torch.tensor(bmad_parsed["l"]), + voltage=torch.tensor(bmad_parsed.get("voltage", 0.0)), phase=torch.tensor( - [-np.degrees(bmad_parsed.get("phi0", 0.0) * 2 * np.pi)] + -np.degrees(bmad_parsed.get("phi0", 0.0) * 2 * np.pi) ), - frequency=torch.tensor([bmad_parsed["rf_frequency"]]), + frequency=torch.tensor(bmad_parsed["rf_frequency"]), name=name, device=device, dtype=dtype, @@ -632,8 +632,8 @@ def convert_element( bmad_parsed, ) return cheetah.Aperture( - x_max=torch.tensor([bmad_parsed.get("x_limit", np.inf)]), - y_max=torch.tensor([bmad_parsed.get("y_limit", np.inf)]), + x_max=torch.tensor(bmad_parsed.get("x_limit", np.inf)), + y_max=torch.tensor(bmad_parsed.get("y_limit", np.inf)), shape="rectangular", name=name, device=device, @@ -645,8 +645,8 @@ def convert_element( bmad_parsed, ) return cheetah.Aperture( - x_max=torch.tensor([bmad_parsed.get("x_limit", np.inf)]), - y_max=torch.tensor([bmad_parsed.get("y_limit", np.inf)]), + x_max=torch.tensor(bmad_parsed.get("x_limit", np.inf)), + y_max=torch.tensor(bmad_parsed.get("y_limit", np.inf)), shape="elliptical", name=name, device=device, @@ -668,7 +668,7 @@ def convert_element( bmad_parsed, ) return cheetah.Undulator( - length=torch.tensor([bmad_parsed["l"]]), + length=torch.tensor(bmad_parsed["l"]), name=name, device=device, dtype=dtype, @@ -677,7 +677,7 @@ def convert_element( # TODO: Does this need to be implemented in Cheetah in a more proper way? validate_understood_properties(["element_type", "tilt"], bmad_parsed) return cheetah.Drift( - length=torch.tensor([bmad_parsed.get("l", 0.0)]), + length=torch.tensor(bmad_parsed.get("l", 0.0)), name=name, device=device, dtype=dtype, @@ -690,7 +690,7 @@ def convert_element( # TODO: Remove the length if by adding markers to Cheeath return cheetah.Drift( name=name, - length=torch.tensor([bmad_parsed.get("l", 0.0)]), + length=torch.tensor(bmad_parsed.get("l", 0.0)), device=device, dtype=dtype, ) diff --git a/cheetah/converters/nxtables.py b/cheetah/converters/nxtables.py index 601d1552..bbda52e3 100644 --- a/cheetah/converters/nxtables.py +++ b/cheetah/converters/nxtables.py @@ -53,10 +53,10 @@ def translate_element(row: list[str], header: list[str]) -> Optional[Dict]: elif class_name == "MCXG": # TODO: Check length with Willi assert name[6] == "X" horizontal_coil = cheetah.HorizontalCorrector( - name=name[:6] + "H" + name[6 + 1 :], length=torch.tensor([5e-05]) + name=name[:6] + "H" + name[6 + 1 :], length=torch.tensor(5e-05) ) vertical_coil = cheetah.VerticalCorrector( - name=name[:6] + "V" + name[6 + 1 :], length=torch.tensor([5e-05]) + name=name[:6] + "V" + name[6 + 1 :], length=torch.tensor(5e-05) ) element = cheetah.Segment(elements=[horizontal_coil, vertical_coil], name=name) elif class_name == "BSCX": @@ -115,57 +115,57 @@ def translate_element(row: list[str], header: list[str]) -> Optional[Dict]: elif class_name == "SLHG": element = cheetah.Aperture( # TODO: Ask for actual size and shape name=name, - x_max=torch.tensor([float("inf")]), - y_max=torch.tensor([float("inf")]), + x_max=torch.tensor(float("inf")), + y_max=torch.tensor(float("inf")), shape="elliptical", ) elif class_name == "SLHB": element = cheetah.Aperture( # TODO: Ask for actual size and shape name=name, - x_max=torch.tensor([float("inf")]), - y_max=torch.tensor([float("inf")]), + x_max=torch.tensor(float("inf")), + y_max=torch.tensor(float("inf")), shape="rectangular", ) elif class_name == "SLHS": element = cheetah.Aperture( # TODO: Ask for actual size and shape name=name, - x_max=torch.tensor([float("inf")]), - y_max=torch.tensor([float("inf")]), + x_max=torch.tensor(float("inf")), + y_max=torch.tensor(float("inf")), shape="rectangular", ) elif class_name == "MCHM": - element = cheetah.HorizontalCorrector(name=name, length=torch.tensor([0.02])) + element = cheetah.HorizontalCorrector(name=name, length=torch.tensor(0.02)) elif class_name == "MCVM": - element = cheetah.VerticalCorrector(name=name, length=torch.tensor([0.02])) + element = cheetah.VerticalCorrector(name=name, length=torch.tensor(0.02)) elif class_name == "MBHL": - element = cheetah.Dipole(name=name, length=torch.tensor([0.322])) + element = cheetah.Dipole(name=name, length=torch.tensor(0.322)) elif class_name == "MBHB": - element = cheetah.Dipole(name=name, length=torch.tensor([0.22])) + element = cheetah.Dipole(name=name, length=torch.tensor(0.22)) elif class_name == "MBHO": element = cheetah.Dipole( name=name, - length=torch.tensor([0.43852543421396856]), - angle=torch.tensor([0.8203047484373349]), - e2=torch.tensor([-0.7504915783575616]), + length=torch.tensor(0.43852543421396856), + angle=torch.tensor(0.8203047484373349), + e2=torch.tensor(-0.7504915783575616), ) elif class_name == "MQZM": - element = cheetah.Quadrupole(name=name, length=torch.tensor([0.122])) + element = cheetah.Quadrupole(name=name, length=torch.tensor(0.122)) elif class_name == "RSBL": element = cheetah.Cavity( name=name, - length=torch.tensor([4.139]), - frequency=torch.tensor([2.998e9]), - voltage=torch.tensor([76e6]), + length=torch.tensor(4.139), + frequency=torch.tensor(2.998e9), + voltage=torch.tensor(76e6), ) elif class_name == "RXBD": element = cheetah.Cavity( # TODO: TD? and tilt? name=name, - length=torch.tensor([1.0]), - frequency=torch.tensor([11.9952e9]), - voltage=torch.tensor([0.0]), + length=torch.tensor(1.0), + frequency=torch.tensor(11.9952e9), + voltage=torch.tensor(0.0), ) elif class_name == "UNDA": # TODO: Figure out actual length - element = cheetah.Undulator(name=name, length=torch.tensor([0.25])) + element = cheetah.Undulator(name=name, length=torch.tensor(0.25)) elif class_name in [ "SOLG", "BCMG", diff --git a/tests/test_differentiable.py b/tests/test_differentiable.py index 1d2eb302..c60e07dd 100644 --- a/tests/test_differentiable.py +++ b/tests/test_differentiable.py @@ -13,13 +13,13 @@ def test_simple_quadrupole(): """ segment = cheetah.Segment( [ - cheetah.Drift(length=torch.tensor([1.0])), + cheetah.Drift(length=torch.tensor(1.0)), cheetah.Quadrupole( - length=torch.tensor([0.2]), - k1=nn.Parameter(torch.tensor([3.142])), + length=torch.tensor(0.2), + k1=nn.Parameter(torch.tensor(3.142)), name="my_quad", ), - cheetah.Drift(length=torch.tensor([1.0])), + cheetah.Drift(length=torch.tensor(1.0)), ] ) incoming_beam = cheetah.ParticleBeam.from_astra( diff --git a/tests/test_parameter_beam.py b/tests/test_parameter_beam.py index 475ada93..3f71f1f0 100644 --- a/tests/test_parameter_beam.py +++ b/tests/test_parameter_beam.py @@ -9,20 +9,20 @@ def test_create_from_parameters(): Test that a `ParameterBeam` created from parameters actually has those parameters. """ beam = ParameterBeam.from_parameters( - mu_x=torch.tensor([1e-5]), - mu_px=torch.tensor([1e-7]), - mu_y=torch.tensor([2e-5]), - mu_py=torch.tensor([2e-7]), - sigma_x=torch.tensor([1.75e-7]), - sigma_px=torch.tensor([2e-7]), - sigma_y=torch.tensor([1.75e-7]), - sigma_py=torch.tensor([2e-7]), - sigma_tau=torch.tensor([0.000001]), - sigma_p=torch.tensor([0.000001]), - cor_x=torch.tensor([0.0]), - cor_y=torch.tensor([0.0]), - cor_tau=torch.tensor([0.0]), - energy=torch.tensor([1e7]), + mu_x=torch.tensor(1e-5), + mu_px=torch.tensor(1e-7), + mu_y=torch.tensor(2e-5), + mu_py=torch.tensor(2e-7), + sigma_x=torch.tensor(1.75e-7), + sigma_px=torch.tensor(2e-7), + sigma_y=torch.tensor(1.75e-7), + sigma_py=torch.tensor(2e-7), + sigma_tau=torch.tensor(0.000001), + sigma_p=torch.tensor(0.000001), + cor_x=torch.tensor(0.0), + cor_y=torch.tensor(0.0), + cor_tau=torch.tensor(0.0), + energy=torch.tensor(1e7), ) assert np.isclose(beam.mu_x.cpu().numpy(), 1e-5) @@ -45,18 +45,18 @@ def test_transform_to(): """ original_beam = ParameterBeam.from_parameters() transformed_beam = original_beam.transformed_to( - mu_x=torch.tensor([1e-5]), - mu_px=torch.tensor([1e-7]), - mu_y=torch.tensor([2e-5]), - mu_py=torch.tensor([2e-7]), - sigma_x=torch.tensor([1.75e-7]), - sigma_px=torch.tensor([2e-7]), - sigma_y=torch.tensor([1.75e-7]), - sigma_py=torch.tensor([2e-7]), - sigma_tau=torch.tensor([0.000001]), - sigma_p=torch.tensor([0.000001]), - energy=torch.tensor([1e7]), - total_charge=torch.tensor([1e-9]), + mu_x=torch.tensor(1e-5), + mu_px=torch.tensor(1e-7), + mu_y=torch.tensor(2e-5), + mu_py=torch.tensor(2e-7), + sigma_x=torch.tensor(1.75e-7), + sigma_px=torch.tensor(2e-7), + sigma_y=torch.tensor(1.75e-7), + sigma_py=torch.tensor(2e-7), + sigma_tau=torch.tensor(0.000001), + sigma_p=torch.tensor(0.000001), + energy=torch.tensor(1e7), + total_charge=torch.tensor(1e-9), ) assert isinstance(transformed_beam, ParameterBeam) @@ -80,13 +80,13 @@ def test_from_twiss_to_twiss(): parameters. """ beam = ParameterBeam.from_twiss( - beta_x=torch.tensor([5.91253676811640894]), - alpha_x=torch.tensor([3.55631307633660354]), - emittance_x=torch.tensor([3.494768647122823e-09]), - beta_y=torch.tensor([5.91253676811640982]), - alpha_y=torch.tensor([2e-7]), - emittance_y=torch.tensor([3.497810737006068e-09]), - energy=torch.tensor([6e6]), + beta_x=torch.tensor(5.91253676811640894), + alpha_x=torch.tensor(3.55631307633660354), + emittance_x=torch.tensor(3.494768647122823e-09), + beta_y=torch.tensor(5.91253676811640982), + alpha_y=torch.tensor(2e-7), + emittance_y=torch.tensor(3.497810737006068e-09), + energy=torch.tensor(6e6), ) assert np.isclose(beam.beta_x.cpu().numpy(), 5.91253676811640894) @@ -103,13 +103,13 @@ def test_from_twiss_dtype(): Test that a `ParameterBeam` created from twiss parameters has the requested `dtype`. """ beam = ParameterBeam.from_twiss( - beta_x=torch.tensor([5.91253676811640894]), - alpha_x=torch.tensor([3.55631307633660354]), - emittance_x=torch.tensor([3.494768647122823e-09]), - beta_y=torch.tensor([5.91253676811640982]), - alpha_y=torch.tensor([2e-7]), - emittance_y=torch.tensor([3.497810737006068e-09]), - energy=torch.tensor([6e6]), + beta_x=torch.tensor(5.91253676811640894), + alpha_x=torch.tensor(3.55631307633660354), + emittance_x=torch.tensor(3.494768647122823e-09), + beta_y=torch.tensor(5.91253676811640982), + alpha_y=torch.tensor(2e-7), + emittance_y=torch.tensor(3.497810737006068e-09), + energy=torch.tensor(6e6), dtype=torch.float64, ) diff --git a/tests/test_particle_beam.py b/tests/test_particle_beam.py index f22b68c6..2f1af1b1 100644 --- a/tests/test_particle_beam.py +++ b/tests/test_particle_beam.py @@ -9,22 +9,22 @@ def test_create_from_parameters(): Test that a `ParticleBeam` created from parameters actually has those parameters. """ beam = ParticleBeam.from_parameters( - num_particles=torch.tensor([1_000_000]), - mu_x=torch.tensor([1e-5]), - mu_px=torch.tensor([1e-7]), - mu_y=torch.tensor([2e-5]), - mu_py=torch.tensor([2e-7]), - sigma_x=torch.tensor([1.75e-7]), - sigma_px=torch.tensor([2e-7]), - sigma_y=torch.tensor([1.75e-7]), - sigma_py=torch.tensor([2e-7]), - sigma_tau=torch.tensor([0.000001]), - sigma_p=torch.tensor([0.000001]), - cor_x=torch.tensor([0.0]), - cor_y=torch.tensor([0.0]), - cor_tau=torch.tensor([0.0]), - energy=torch.tensor([1e7]), - total_charge=torch.tensor([1e-9]), + num_particles=torch.tensor(1_000_000), + mu_x=torch.tensor(1e-5), + mu_px=torch.tensor(1e-7), + mu_y=torch.tensor(2e-5), + mu_py=torch.tensor(2e-7), + sigma_x=torch.tensor(1.75e-7), + sigma_px=torch.tensor(2e-7), + sigma_y=torch.tensor(1.75e-7), + sigma_py=torch.tensor(2e-7), + sigma_tau=torch.tensor(0.000001), + sigma_p=torch.tensor(0.000001), + cor_x=torch.tensor(0.0), + cor_y=torch.tensor(0.0), + cor_tau=torch.tensor(0.0), + energy=torch.tensor(1e7), + total_charge=torch.tensor(1e-9), ) assert beam.num_particles == 1_000_000 @@ -49,18 +49,18 @@ def test_transform_to(): """ original_beam = ParticleBeam.from_parameters() transformed_beam = original_beam.transformed_to( - mu_x=torch.tensor([1e-5]), - mu_px=torch.tensor([1e-7]), - mu_y=torch.tensor([2e-5]), - mu_py=torch.tensor([2e-7]), - sigma_x=torch.tensor([1.75e-7]), - sigma_px=torch.tensor([2e-7]), - sigma_y=torch.tensor([1.75e-7]), - sigma_py=torch.tensor([2e-7]), - sigma_tau=torch.tensor([0.000001]), - sigma_p=torch.tensor([0.000001]), - energy=torch.tensor([1e7]), - total_charge=torch.tensor([1e-9]), + mu_x=torch.tensor(1e-5), + mu_px=torch.tensor(1e-7), + mu_y=torch.tensor(2e-5), + mu_py=torch.tensor(2e-7), + sigma_x=torch.tensor(1.75e-7), + sigma_px=torch.tensor(2e-7), + sigma_y=torch.tensor(1.75e-7), + sigma_py=torch.tensor(2e-7), + sigma_tau=torch.tensor(0.000001), + sigma_p=torch.tensor(0.000001), + energy=torch.tensor(1e7), + total_charge=torch.tensor(1e-9), ) assert isinstance(transformed_beam, ParticleBeam) @@ -86,14 +86,14 @@ def test_from_twiss_to_twiss(): parameters. """ beam = ParticleBeam.from_twiss( - num_particles=torch.tensor([10_000_000]), - beta_x=torch.tensor([5.91253676811640894]), - alpha_x=torch.tensor([3.55631307633660354]), - emittance_x=torch.tensor([3.494768647122823e-09]), - beta_y=torch.tensor([5.91253676811640982]), - alpha_y=torch.tensor([1.0]), # TODO: set realistic value - emittance_y=torch.tensor([3.497810737006068e-09]), - energy=torch.tensor([6e6]), + num_particles=torch.tensor(10_000_000), + beta_x=torch.tensor(5.91253676811640894), + alpha_x=torch.tensor(3.55631307633660354), + emittance_x=torch.tensor(3.494768647122823e-09), + beta_y=torch.tensor(5.91253676811640982), + alpha_y=torch.tensor(1.0), # TODO: set realistic value + emittance_y=torch.tensor(3.497810737006068e-09), + energy=torch.tensor(6e6), ) # rather loose rtol is needed here due to the random sampling of the beam assert np.isclose(beam.beta_x.cpu().numpy(), 5.91253676811640894, rtol=1e-2) diff --git a/tests/test_quadrupole.py b/tests/test_quadrupole.py index e3c0682a..86020687 100644 --- a/tests/test_quadrupole.py +++ b/tests/test_quadrupole.py @@ -7,15 +7,15 @@ def test_quadrupole_off(): """ Test that a quadrupole with k1=0 behaves still like a drift. """ - quadrupole = Quadrupole(length=torch.tensor([1.0]), k1=torch.tensor([0.0])) - drift = Drift(length=torch.tensor([1.0])) + quadrupole = Quadrupole(length=torch.tensor(1.0), k1=torch.tensor(0.0)) + drift = Drift(length=torch.tensor(1.0)) incoming_beam = ParameterBeam.from_parameters( - sigma_px=torch.tensor([2e-7]), sigma_py=torch.tensor([2e-7]) + sigma_px=torch.tensor(2e-7), sigma_py=torch.tensor(2e-7) ) outbeam_quad = quadrupole(incoming_beam) outbeam_drift = drift(incoming_beam) - quadrupole.k1 = torch.tensor([1.0], device=quadrupole.k1.device) + quadrupole.k1 = torch.tensor(1.0, device=quadrupole.k1.device) outbeam_quad_on = quadrupole(incoming_beam) assert torch.allclose(outbeam_quad.sigma_x, outbeam_drift.sigma_x) @@ -28,18 +28,18 @@ def test_quadrupole_with_misalignments_batched(): """ quad_with_misalignment = Quadrupole( - length=torch.tensor([1.0]), - k1=torch.tensor([1.0]), - misalignment=torch.tensor([[0.1, 0.1]]), + length=torch.tensor(1.0), + k1=torch.tensor(1.0), + misalignment=torch.tensor([0.1, 0.1]), ) assert quad_with_misalignment.batch_shape == torch.Size([1, 2]) quad_without_misalignment = Quadrupole( - length=torch.tensor([1.0]), k1=torch.tensor([1.0]) + length=torch.tensor(1.0), k1=torch.tensor(1.0) ) incoming_beam = ParameterBeam.from_parameters( - sigma_px=torch.tensor([2e-7]), sigma_py=torch.tensor([2e-7]) + sigma_px=torch.tensor(2e-7), sigma_py=torch.tensor(2e-7) ) outbeam_quad_with_misalignment = quad_with_misalignment(incoming_beam) outbeam_quad_without_misalignment = quad_without_misalignment(incoming_beam) diff --git a/tests/test_space_charge_kick.py b/tests/test_space_charge_kick.py index ed9aeafd..6e52c773 100644 --- a/tests/test_space_charge_kick.py +++ b/tests/test_space_charge_kick.py @@ -87,15 +87,13 @@ def test_vectorized(): incoming = cheetah.ParticleBeam.uniform_3d_ellipsoid( num_particles=torch.tensor(10_000), total_charge=torch.tensor([[1e-9, 2e-9], [3e-9, 4e-9], [5e-9, 6e-9]]), - energy=energy.repeat(3, 2), - radius_x=R0.repeat(3, 2), - radius_y=R0.repeat(3, 2), - radius_tau=(R0 / gamma).repeat( - 3, 2 - ), # Radius of the beam in s direction in the lab frame - sigma_px=torch.tensor([1e-15]).repeat(3, 2), - sigma_py=torch.tensor([1e-15]).repeat(3, 2), - sigma_p=torch.tensor([1e-15]).repeat(3, 2), + energy=energy, + radius_x=R0, + radius_y=R0, + radius_tau=R0 / gamma, # Radius of the beam in s direction in the lab frame + sigma_px=torch.tensor(1e-15), + sigma_py=torch.tensor(1e-15), + sigma_p=torch.tensor(1e-15), ) segment = cheetah.Segment( @@ -125,8 +123,8 @@ def test_vectorized_cold_uniform_beam_expansion(): torch.manual_seed(42) # Simulation parameters - R0 = torch.tensor([0.001]) - energy = torch.tensor([2.5e8]) + R0 = torch.tensor(0.001) + energy = torch.tensor(2.5e8) rest_energy = torch.tensor( constants.electron_mass * constants.speed_of_light**2 @@ -139,14 +137,14 @@ def test_vectorized_cold_uniform_beam_expansion(): incoming = cheetah.ParticleBeam.uniform_3d_ellipsoid( num_particles=torch.tensor(10_000), - total_charge=torch.tensor([1e-9]), + total_charge=torch.tensor(1e-9), energy=energy, radius_x=R0, radius_y=R0, radius_tau=R0 / gamma, # Radius of the beam in s direction in the lab frame - sigma_px=torch.tensor([1e-15]), - sigma_py=torch.tensor([1e-15]), - sigma_p=torch.tensor([1e-15]), + sigma_px=torch.tensor(1e-15), + sigma_py=torch.tensor(1e-15), + sigma_p=torch.tensor(1e-15), ) # Compute section length @@ -184,13 +182,13 @@ def test_incoming_beam_not_modified(): incoming_beam = cheetah.ParticleBeam.from_parameters( num_particles=torch.tensor(10_000), - sigma_px=torch.tensor([2e-7]), - sigma_py=torch.tensor([2e-7]), + sigma_px=torch.tensor(2e-7), + sigma_py=torch.tensor(2e-7), ) # Initial beam properties incoming_beam_before = incoming_beam.particles - section_length = torch.tensor([1.0]) + section_length = torch.tensor(1.0) segment_space_charge = cheetah.Segment( elements=[ cheetah.Drift(section_length / 6), @@ -216,9 +214,9 @@ def test_gradient(): Tests that the gradient of the track method is computed withouth throwing an error. """ incoming_beam = cheetah.ParticleBeam.from_parameters( - num_particles=torch.tensor([10_000]), - sigma_px=torch.tensor([2e-7]), - sigma_py=torch.tensor([2e-7]), + num_particles=torch.tensor(10_000), + sigma_px=torch.tensor(2e-7), + sigma_py=torch.tensor(2e-7), ) segment_length = nn.Parameter(torch.tensor([1.0])) @@ -246,7 +244,7 @@ def test_does_not_break_segment_length(): Test that the computation of a `Segment`'s length does not break when `SpaceChargeKick` is used. """ - section_length = torch.tensor(1.0).repeat((3, 2)) + section_length = torch.tensor(1.0) segment = cheetah.Segment( elements=[ cheetah.Drift(section_length / 6), @@ -260,4 +258,4 @@ def test_does_not_break_segment_length(): ) assert segment.length.shape == (3, 2) - assert torch.allclose(segment.length, torch.tensor(1.0).repeat(3, 2)) + assert torch.allclose(segment.length, torch.tensor(1.0)) diff --git a/tests/test_speed.py b/tests/test_speed.py index c78a9dee..917f9550 100644 --- a/tests/test_speed.py +++ b/tests/test_speed.py @@ -18,8 +18,8 @@ def test_tracking_speed(): particles = cheetah.ParticleBeam.from_parameters( num_particles=torch.tensor(int(1e5)), - sigma_x=torch.tensor([175e-6]), - sigma_y=torch.tensor([175e-6]), + sigma_x=torch.tensor(175e-6), + sigma_y=torch.tensor(175e-6), ) t1 = time.time() From e1aa83c0a90974365f11186c90e75e9654bc5854 Mon Sep 17 00:00:00 2001 From: Ryan Roussel Date: Tue, 6 Aug 2024 14:37:33 -0500 Subject: [PATCH 032/111] test fixes/improvements --- tests/test_quadrupole.py | 6 +++--- tests/test_vectorized.py | 26 +++++++++++++------------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/tests/test_quadrupole.py b/tests/test_quadrupole.py index ba08065d..ae85b4b4 100644 --- a/tests/test_quadrupole.py +++ b/tests/test_quadrupole.py @@ -31,7 +31,7 @@ def test_quadrupole_with_misalignments_batched(): quad_with_misalignment = Quadrupole( length=torch.tensor(1.0), k1=torch.tensor(1.0), - misalignment=torch.tensor([0.1, 0.1]), + misalignment=torch.tensor([0.1, 0.1]).unsqueeze(0), ) assert quad_with_misalignment.batch_shape == torch.Size([1, 2]) @@ -147,7 +147,7 @@ def test_quadrupole_length_multiple_batch_dimensions(): Test that a quadrupole with lengths that have multiple vectorisation dimensions does not raise an error and behaves as expected. """ - lengths = torch.tensor([[0.2, 0.3, 0.4], [0.5, 0.6, 0.7]]) + lengths = torch.tensor([[0.2, 0.3, 0.4], [0.5, 0.4, 0.7]]) segment = Segment( [ Quadrupole(length=lengths, k1=torch.tensor(4.2)), @@ -164,7 +164,7 @@ def test_quadrupole_length_multiple_batch_dimensions(): outgoing = segment(incoming) assert outgoing.particles.shape == (2, 3, 10_000, 7) - assert torch.allclose(outgoing.particles[0, 0], outgoing.particles[0, 1]) + assert torch.allclose(outgoing.particles[0, -1], outgoing.particles[1, -2]) def test_quadrupole_bmadx_tracking(): diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index f32a6dbc..bd9d5805 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -308,22 +308,22 @@ def test_enormous_through_ares(): "tests/resources/ACHIP_EA1_2021.1351.001" ) - segment.AREAMQZM1.k1 = torch.linspace(-30.0, 30.0, 100_000).repeat(3, 1) + segment.AREAMQZM1.k1 = torch.linspace(-30.0, 30.0, 10).repeat(3, 1) outgoing = segment.track(incoming) - assert outgoing.mu_x.shape == (3, 100_000) - assert outgoing.mu_px.shape == (3, 100_000) - assert outgoing.mu_y.shape == (3, 100_000) - assert outgoing.mu_py.shape == (3, 100_000) - assert outgoing.sigma_x.shape == (3, 100_000) - assert outgoing.sigma_px.shape == (3, 100_000) - assert outgoing.sigma_y.shape == (3, 100_000) - assert outgoing.sigma_py.shape == (3, 100_000) - assert outgoing.sigma_tau.shape == (3, 100_000) - assert outgoing.sigma_p.shape == (3, 100_000) - assert outgoing.energy.shape == (3, 100_000) - assert outgoing.total_charge.shape == (3, 100_000) + assert outgoing.mu_x.shape == (3, 10) + assert outgoing.mu_px.shape == (3, 10) + assert outgoing.mu_y.shape == (3, 10) + assert outgoing.mu_py.shape == (3, 10) + assert outgoing.sigma_x.shape == (3, 10) + assert outgoing.sigma_px.shape == (3, 10) + assert outgoing.sigma_y.shape == (3, 10) + assert outgoing.sigma_py.shape == (3, 10) + assert outgoing.sigma_tau.shape == (3, 10) + assert outgoing.sigma_p.shape == (3, 10) + assert outgoing.energy.shape == (3, 10) + assert outgoing.total_charge.shape == (3, 10) def test_cavity_with_zero_and_non_zero_voltage(): From 12cfbb47a70f93514d1243c3e0a6fbab1d846e60 Mon Sep 17 00:00:00 2001 From: Ryan Roussel Date: Thu, 8 Aug 2024 11:38:36 -0500 Subject: [PATCH 033/111] fix test to require matching batch sizes for particle creation --- tests/test_space_charge_kick.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/test_space_charge_kick.py b/tests/test_space_charge_kick.py index 6e52c773..82130c93 100644 --- a/tests/test_space_charge_kick.py +++ b/tests/test_space_charge_kick.py @@ -87,13 +87,14 @@ def test_vectorized(): incoming = cheetah.ParticleBeam.uniform_3d_ellipsoid( num_particles=torch.tensor(10_000), total_charge=torch.tensor([[1e-9, 2e-9], [3e-9, 4e-9], [5e-9, 6e-9]]), - energy=energy, - radius_x=R0, - radius_y=R0, - radius_tau=R0 / gamma, # Radius of the beam in s direction in the lab frame - sigma_px=torch.tensor(1e-15), - sigma_py=torch.tensor(1e-15), - sigma_p=torch.tensor(1e-15), + energy=energy.expand([3, 2]), + radius_x=R0.expand([3, 2]), + radius_y=R0.expand([3, 2]), + radius_tau=R0.expand([3, 2]) / gamma, + # Radius of the beam in s direction in the lab frame + sigma_px=torch.tensor(1e-15).expand([3, 2]), + sigma_py=torch.tensor(1e-15).expand([3, 2]), + sigma_p=torch.tensor(1e-15).expand([3, 2]), ) segment = cheetah.Segment( @@ -257,5 +258,5 @@ def test_does_not_break_segment_length(): ] ) - assert segment.length.shape == (3, 2) + assert segment.length.shape == torch.Size([1]) assert torch.allclose(segment.length, torch.tensor(1.0)) From df3d9254c4ac68db645a90a6683c01347830a283 Mon Sep 17 00:00:00 2001 From: Ryan Roussel Date: Thu, 8 Aug 2024 11:46:02 -0500 Subject: [PATCH 034/111] require that transform_to method takes scalar arguments --- cheetah/particles/particle_beam.py | 49 +++++++++++++++++++----------- tests/test_particle_beam.py | 6 ++++ 2 files changed, 38 insertions(+), 17 deletions(-) diff --git a/cheetah/particles/particle_beam.py b/cheetah/particles/particle_beam.py index 0369aea3..3941c072 100644 --- a/cheetah/particles/particle_beam.py +++ b/cheetah/particles/particle_beam.py @@ -656,7 +656,7 @@ def transformed_to( # Figure out batch size of the original beam and check that passed arguments # have the same batch size - shape = self.mu_x.shape + shape = mu_x.shape not_nones = [ argument for argument in [ @@ -676,9 +676,12 @@ def transformed_to( if argument is not None ] if len(not_nones) > 0: - assert all( - argument.shape == shape for argument in not_nones - ), "Arguments must have the same shape." + if not all( + argument.shape == torch.Size([]) for argument in not_nones + ): + raise NotImplementedError( + "Batching not implemented yet. Arguments must have shape []." + ) mu_x = mu_x if mu_x is not None else self.mu_x mu_y = mu_y if mu_y is not None else self.mu_y @@ -706,23 +709,29 @@ def transformed_to( ) new_mu = torch.stack( - [mu_x, mu_px, mu_y, mu_py, torch.full(shape, 0.0), torch.full(shape, 0.0)], - dim=1, + [ + mu_x.squeeze(), + mu_px.squeeze(), + mu_y.squeeze(), + mu_py.squeeze(), + torch.full(shape, 0.0), + torch.full(shape, 0.0)], + dim=0, ) new_sigma = torch.stack( - [sigma_x, sigma_px, sigma_y, sigma_py, sigma_tau, sigma_p], dim=1 + [sigma_x, sigma_px, sigma_y, sigma_py, sigma_tau, sigma_p], dim=0 ) old_mu = torch.stack( [ - self.mu_x, - self.mu_px, - self.mu_y, - self.mu_py, + self.mu_x.squeeze(), + self.mu_px.squeeze(), + self.mu_y.squeeze(), + self.mu_py.squeeze(), torch.full(shape, 0.0), torch.full(shape, 0.0), ], - dim=1, + dim=0, ) old_sigma = torch.stack( [ @@ -733,13 +742,19 @@ def transformed_to( self.sigma_tau, self.sigma_p, ], - dim=1, - ) + dim=0, + ).squeeze() phase_space = self.particles[:, :, :6] - phase_space = (phase_space - old_mu.unsqueeze(1)) / old_sigma.unsqueeze( - 1 - ) * new_sigma.unsqueeze(1) + new_mu.unsqueeze(1) + phase_space = (phase_space - old_mu.expand( + *phase_space.shape + )) / old_sigma.expand( + *phase_space.shape + ) * new_sigma.expand( + *phase_space.shape + ) + new_mu.expand( + *phase_space.shape + ) particles = torch.ones_like(self.particles) particles[:, :, :6] = phase_space diff --git a/tests/test_particle_beam.py b/tests/test_particle_beam.py index 2f1af1b1..f17dd053 100644 --- a/tests/test_particle_beam.py +++ b/tests/test_particle_beam.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import torch from cheetah import ParticleBeam @@ -79,6 +80,11 @@ def test_transform_to(): assert np.isclose(transformed_beam.energy.cpu().numpy(), 1e7) assert np.isclose(transformed_beam.total_charge.cpu().numpy(), 1e-9) + with pytest.raises(NotImplementedError): + original_beam.transformed_to( + mu_x=torch.tensor(1e-5).expand([3,2]), + ) + def test_from_twiss_to_twiss(): """ From d482dd6ee283d0bc4a1446f6a33d98f0e5289534 Mon Sep 17 00:00:00 2001 From: Ryan Roussel Date: Thu, 8 Aug 2024 11:46:43 -0500 Subject: [PATCH 035/111] Update parameter_beam.py --- cheetah/particles/parameter_beam.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cheetah/particles/parameter_beam.py b/cheetah/particles/parameter_beam.py index 3b8e55a1..1288bda5 100644 --- a/cheetah/particles/parameter_beam.py +++ b/cheetah/particles/parameter_beam.py @@ -320,7 +320,7 @@ def transformed_to( # Figure out batch size of the original beam and check that passed arguments # have the same batch size - shape = self.mu_x.shape + shape = mu_x.shape not_nones = [ argument for argument in [ From de2679eaa0e99e4f387b021a948e9322559f1cca Mon Sep 17 00:00:00 2001 From: Ryan Roussel Date: Thu, 8 Aug 2024 11:50:32 -0500 Subject: [PATCH 036/111] Update test_particle_beam.py --- tests/test_particle_beam.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_particle_beam.py b/tests/test_particle_beam.py index f17dd053..e153ac71 100644 --- a/tests/test_particle_beam.py +++ b/tests/test_particle_beam.py @@ -82,7 +82,7 @@ def test_transform_to(): with pytest.raises(NotImplementedError): original_beam.transformed_to( - mu_x=torch.tensor(1e-5).expand([3,2]), + mu_x=torch.tensor(1e-5).expand([3, 2]), ) From 029187e8f4822dfdd95004650ca59458bb711575 Mon Sep 17 00:00:00 2001 From: Ryan Roussel Date: Tue, 13 Aug 2024 09:09:14 -0500 Subject: [PATCH 037/111] fix tests --- cheetah/accelerator/undulator.py | 3 ++- tests/test_split.py | 16 ++++++++-------- tests/test_vectorized.py | 8 ++++---- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/cheetah/accelerator/undulator.py b/cheetah/accelerator/undulator.py index 8b6a0f92..c2f06e86 100644 --- a/cheetah/accelerator/undulator.py +++ b/cheetah/accelerator/undulator.py @@ -49,7 +49,8 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: gamma = energy / electron_mass_eV.to(device=device, dtype=dtype) igamma2 = torch.where(gamma != 0, 1 / gamma**2, torch.zeros_like(gamma)) - tm = torch.eye(7, device=device, dtype=dtype).repeat((*energy.shape, 1, 1)) + batch_shape = torch.broadcast_tensors(self.length, energy)[0].shape + tm = torch.eye(7, device=device, dtype=dtype).repeat((*batch_shape, 1, 1)) tm[..., 0, 1] = self.length tm[..., 2, 3] = self.length tm[..., 4, 5] = self.length * igamma2 diff --git a/tests/test_split.py b/tests/test_split.py index fc5104d1..868c3f9c 100644 --- a/tests/test_split.py +++ b/tests/test_split.py @@ -15,7 +15,7 @@ def test_drift_end(): incoming_beam = cheetah.ParticleBeam.from_astra( "tests/resources/ACHIP_EA1_2021.1351.001" - ).broadcast((2,)) + ) outgoing_beam_original = original_drift.track(incoming_beam) outgoing_beam_split = split_drift.track(incoming_beam) @@ -39,7 +39,7 @@ def test_quadrupole_end(): incoming_beam = cheetah.ParticleBeam.from_astra( "tests/resources/ACHIP_EA1_2021.1351.001" - ).broadcast((2,)) + ) outgoing_beam_original = original_quadrupole.track(incoming_beam) outgoing_beam_split = split_quadrupole.track(incoming_beam) @@ -64,7 +64,7 @@ def test_cavity_end(): incoming_beam = cheetah.ParticleBeam.from_astra( "tests/resources/ACHIP_EA1_2021.1351.001" - ).broadcast((2,)) + ) outgoing_beam_original = original_cavity.track(incoming_beam) outgoing_beam_split = split_cavity.track(incoming_beam) @@ -88,7 +88,7 @@ def test_solenoid_end(): incoming_beam = cheetah.ParticleBeam.from_astra( "tests/resources/ACHIP_EA1_2021.1351.001" - ).broadcast((2,)) + ) outgoing_beam_original = original_solenoid.track(incoming_beam) outgoing_beam_split = split_solenoid.track(incoming_beam) @@ -111,7 +111,7 @@ def test_dipole_end(): incoming_beam = cheetah.ParticleBeam.from_astra( "tests/resources/ACHIP_EA1_2021.1351.001" - ).broadcast((2,)) + ) outgoing_beam_original = original_dipole.track(incoming_beam) outgoing_beam_split = split_dipole.track(incoming_beam) @@ -133,7 +133,7 @@ def test_undulator_end(): incoming_beam = cheetah.ParticleBeam.from_astra( "tests/resources/ACHIP_EA1_2021.1351.001" - ).broadcast((2,)) + ) outgoing_beam_original = original_undulator.track(incoming_beam) outgoing_beam_split = split_undulator.track(incoming_beam) @@ -158,7 +158,7 @@ def test_horizontal_corrector_end(): incoming_beam = cheetah.ParticleBeam.from_astra( "tests/resources/ACHIP_EA1_2021.1351.001" - ).broadcast((2,)) + ) outgoing_beam_original = original_horizontal_corrector.track(incoming_beam) outgoing_beam_split = split_horizontal_corrector.track(incoming_beam) @@ -183,7 +183,7 @@ def test_vertical_corrector_end(): incoming_beam = cheetah.ParticleBeam.from_astra( "tests/resources/ACHIP_EA1_2021.1351.001" - ).broadcast((2,)) + ) outgoing_beam_original = original_vertical_corrector.track(incoming_beam) outgoing_beam_split = split_vertical_corrector.track(incoming_beam) diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index 601b85f1..26f621ce 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -344,6 +344,7 @@ def test_cavity_with_zero_and_non_zero_voltage(): _ = cavity.track(beam) + def test_screen_length_shape(): """ Test that the shape of a screen's length matches the shape of its misalignment. @@ -358,8 +359,7 @@ def test_screen_length_broadcast_shape(): after broadcasting. """ screen = cheetah.Screen(misalignment=torch.tensor([[0.1, 0.2]])) - broadcast_screen = screen.broadcast((3, 10)) - assert broadcast_screen.length.shape == broadcast_screen.misalignment.shape[:-1] + assert screen.length.shape == screen.misalignment.shape[:-1] def test_vectorized_undulator(): @@ -367,7 +367,7 @@ def test_vectorized_undulator(): element = cheetah.Undulator(length=torch.tensor([0.4, 0.7])) beam = cheetah.ParticleBeam.from_parameters( num_particles=100_000, sigma_x=torch.tensor([1e-5]) - ).broadcast((2,)) + ) _ = element.track(beam) @@ -379,6 +379,6 @@ def test_vectorized_solenoid(): ) beam = cheetah.ParticleBeam.from_parameters( num_particles=100_000, sigma_x=torch.tensor([1e-5]) - ).broadcast((2,)) + ) _ = element.track(beam) From 4761e93312f2776c7fd395131db9985f4f53836e Mon Sep 17 00:00:00 2001 From: Ryan Roussel Date: Tue, 13 Aug 2024 12:55:16 -0500 Subject: [PATCH 038/111] apply formatting --- cheetah/__init__.py | 2 +- cheetah/accelerator/quadrupole.py | 4 ++-- cheetah/accelerator/screen.py | 3 +-- cheetah/particles/particle_beam.py | 19 +++++++++---------- tests/test_astra_import.py | 3 +-- tests/test_bmad_conversion.py | 3 +-- tests/test_bpm.py | 3 +-- tests/test_cavity.py | 3 +-- tests/test_compare_beam_type.py | 3 +-- tests/test_compare_ocelot.py | 4 ++-- tests/test_device_dtype.py | 3 +-- tests/test_differentiable.py | 3 +-- tests/test_drift.py | 3 +-- tests/test_ocelot_import.py | 3 +-- tests/test_reading_nx_tables.py | 3 +-- tests/test_screen.py | 3 +-- tests/test_space_charge_kick.py | 3 +-- tests/test_speed.py | 4 ++-- tests/test_speed_optimizations.py | 3 +-- tests/test_split.py | 3 +-- tests/test_tracking_lengthless_elements.py | 3 +-- tests/test_vectorized.py | 3 +-- 22 files changed, 33 insertions(+), 51 deletions(-) diff --git a/cheetah/__init__.py b/cheetah/__init__.py index 91c13335..bf5f5e3b 100644 --- a/cheetah/__init__.py +++ b/cheetah/__init__.py @@ -1,7 +1,7 @@ from . import converters # noqa: F401 from .accelerator import ( # noqa: F401 - BPM, Aperture, + BPM, Cavity, CustomTransferMap, Dipole, diff --git a/cheetah/accelerator/quadrupole.py b/cheetah/accelerator/quadrupole.py index 841c746d..63247770 100644 --- a/cheetah/accelerator/quadrupole.py +++ b/cheetah/accelerator/quadrupole.py @@ -5,11 +5,11 @@ import torch from matplotlib.patches import Rectangle from scipy.constants import physical_constants -from torch import Size, nn +from torch import nn, Size from ..particles import Beam, ParticleBeam from ..track_methods import base_rmatrix, misalignment_matrix -from ..utils import UniqueNameGenerator, bmadx +from ..utils import bmadx, UniqueNameGenerator from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") diff --git a/cheetah/accelerator/screen.py b/cheetah/accelerator/screen.py index 35b96a44..94bf2b67 100644 --- a/cheetah/accelerator/screen.py +++ b/cheetah/accelerator/screen.py @@ -8,7 +8,7 @@ from torch.distributions import MultivariateNormal from ..particles import Beam, ParameterBeam, ParticleBeam -from ..utils import UniqueNameGenerator, kde_histogram_2d +from ..utils import kde_histogram_2d, UniqueNameGenerator from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -229,7 +229,6 @@ def reading(self) -> torch.Tensor: ) image = torch.flip(image, dims=[1]) elif isinstance(read_beam, ParticleBeam): - if self.method == "histogram": image = torch.zeros( ( diff --git a/cheetah/particles/particle_beam.py b/cheetah/particles/particle_beam.py index 3941c072..a7b75534 100644 --- a/cheetah/particles/particle_beam.py +++ b/cheetah/particles/particle_beam.py @@ -380,7 +380,9 @@ def uniform_3d_ellipsoid( y = (torch.rand(num_particles) - 0.5) * 2 * r_y tau = (torch.rand(num_particles) - 0.5) * 2 * r_tau - is_in_ellipsoid = x**2 / r_x**2 + y**2 / r_y**2 + tau**2 / r_tau**2 < 1 + is_in_ellipsoid = ( + x**2 / r_x**2 + y**2 / r_y**2 + tau**2 / r_tau**2 < 1 + ) num_to_add = min(num_particles - num_successful, is_in_ellipsoid.sum()) flattened_x[i, num_successful : num_successful + num_to_add] = x[ @@ -676,9 +678,7 @@ def transformed_to( if argument is not None ] if len(not_nones) > 0: - if not all( - argument.shape == torch.Size([]) for argument in not_nones - ): + if not all(argument.shape == torch.Size([]) for argument in not_nones): raise NotImplementedError( "Batching not implemented yet. Arguments must have shape []." ) @@ -715,7 +715,8 @@ def transformed_to( mu_y.squeeze(), mu_py.squeeze(), torch.full(shape, 0.0), - torch.full(shape, 0.0)], + torch.full(shape, 0.0), + ], dim=0, ) new_sigma = torch.stack( @@ -746,11 +747,9 @@ def transformed_to( ).squeeze() phase_space = self.particles[:, :, :6] - phase_space = (phase_space - old_mu.expand( - *phase_space.shape - )) / old_sigma.expand( - *phase_space.shape - ) * new_sigma.expand( + phase_space = ( + phase_space - old_mu.expand(*phase_space.shape) + ) / old_sigma.expand(*phase_space.shape) * new_sigma.expand( *phase_space.shape ) + new_mu.expand( *phase_space.shape diff --git a/tests/test_astra_import.py b/tests/test_astra_import.py index 4e361cce..7bda194e 100644 --- a/tests/test_astra_import.py +++ b/tests/test_astra_import.py @@ -1,8 +1,7 @@ +import cheetah import numpy as np import torch -import cheetah - def test_astra_to_parameter_beam(): """Test that Astra beams are correctly loaded into parameter beams.""" diff --git a/tests/test_bmad_conversion.py b/tests/test_bmad_conversion.py index 9dc78849..a0e76961 100644 --- a/tests/test_bmad_conversion.py +++ b/tests/test_bmad_conversion.py @@ -1,7 +1,6 @@ +import cheetah import pytest import torch - -import cheetah from cheetah.utils import is_mps_available_and_functional diff --git a/tests/test_bpm.py b/tests/test_bpm.py index 6d58ead5..d6c0ad1a 100644 --- a/tests/test_bpm.py +++ b/tests/test_bpm.py @@ -1,8 +1,7 @@ +import cheetah import pytest import torch -import cheetah - @pytest.mark.parametrize("is_bpm_active", [True, False]) @pytest.mark.parametrize("beam_class", [cheetah.ParticleBeam, cheetah.ParameterBeam]) diff --git a/tests/test_cavity.py b/tests/test_cavity.py index 7129cbc0..67f38439 100644 --- a/tests/test_cavity.py +++ b/tests/test_cavity.py @@ -1,8 +1,7 @@ +import cheetah import pytest import torch -import cheetah - def test_assert_ei_greater_zero(): """ diff --git a/tests/test_compare_beam_type.py b/tests/test_compare_beam_type.py index 6a6d3e28..6ddb22bc 100644 --- a/tests/test_compare_beam_type.py +++ b/tests/test_compare_beam_type.py @@ -2,9 +2,8 @@ Tests that ensure that both beam types produce (roughly) the same results. """ -import torch - import cheetah +import torch def test_from_twiss(): diff --git a/tests/test_compare_ocelot.py b/tests/test_compare_ocelot.py index d9586e6b..6aab1f63 100644 --- a/tests/test_compare_ocelot.py +++ b/tests/test_compare_ocelot.py @@ -1,11 +1,11 @@ from copy import deepcopy +import cheetah + import numpy as np import ocelot import torch -import cheetah - from .resources import ARESlatticeStage3v1_9 as ares diff --git a/tests/test_device_dtype.py b/tests/test_device_dtype.py index fa3e0a45..cdb3b522 100644 --- a/tests/test_device_dtype.py +++ b/tests/test_device_dtype.py @@ -1,7 +1,6 @@ +import cheetah import pytest import torch - -import cheetah from cheetah.utils import is_mps_available_and_functional diff --git a/tests/test_differentiable.py b/tests/test_differentiable.py index c60e07dd..47c1493e 100644 --- a/tests/test_differentiable.py +++ b/tests/test_differentiable.py @@ -1,8 +1,7 @@ +import cheetah import torch from torch import nn -import cheetah - from .resources import ARESlatticeStage3v1_9 as ares diff --git a/tests/test_drift.py b/tests/test_drift.py index fd4ee31a..00b40a7b 100644 --- a/tests/test_drift.py +++ b/tests/test_drift.py @@ -1,8 +1,7 @@ +import cheetah import pytest import torch -import cheetah - def test_diverging_parameter_beam(): """ diff --git a/tests/test_ocelot_import.py b/tests/test_ocelot_import.py index cb2ee660..27f97d92 100644 --- a/tests/test_ocelot_import.py +++ b/tests/test_ocelot_import.py @@ -1,9 +1,8 @@ +import cheetah import numpy as np import ocelot import pytest -import cheetah - from .resources import ARESlatticeStage3v1_9 as ares diff --git a/tests/test_reading_nx_tables.py b/tests/test_reading_nx_tables.py index 3689f65b..8802b58d 100644 --- a/tests/test_reading_nx_tables.py +++ b/tests/test_reading_nx_tables.py @@ -1,6 +1,5 @@ -import torch - import cheetah +import torch def test_no_error(): diff --git a/tests/test_screen.py b/tests/test_screen.py index c9cdfbc5..8cf5d745 100644 --- a/tests/test_screen.py +++ b/tests/test_screen.py @@ -1,9 +1,8 @@ +import cheetah import numpy as np import pytest import torch -import cheetah - from .resources import ARESlatticeStage3v1_9 as ocelot_lattice diff --git a/tests/test_space_charge_kick.py b/tests/test_space_charge_kick.py index 82130c93..31aa3b28 100644 --- a/tests/test_space_charge_kick.py +++ b/tests/test_space_charge_kick.py @@ -1,10 +1,9 @@ +import cheetah import torch from scipy import constants from scipy.constants import physical_constants from torch import nn -import cheetah - def test_cold_uniform_beam_expansion(): """ diff --git a/tests/test_speed.py b/tests/test_speed.py index 917f9550..f8325822 100644 --- a/tests/test_speed.py +++ b/tests/test_speed.py @@ -1,9 +1,9 @@ import time -import torch - import cheetah +import torch + from .resources import ARESlatticeStage3v1_9 as ares diff --git a/tests/test_speed_optimizations.py b/tests/test_speed_optimizations.py index 8d760f78..0e72ed1e 100644 --- a/tests/test_speed_optimizations.py +++ b/tests/test_speed_optimizations.py @@ -1,8 +1,7 @@ +import cheetah import pytest import torch -import cheetah - def test_merged_transfer_maps_tracking(): """ diff --git a/tests/test_split.py b/tests/test_split.py index 868c3f9c..16769f55 100644 --- a/tests/test_split.py +++ b/tests/test_split.py @@ -1,8 +1,7 @@ +import cheetah import pytest import torch -import cheetah - def test_drift_end(): """ diff --git a/tests/test_tracking_lengthless_elements.py b/tests/test_tracking_lengthless_elements.py index d5a738b7..07e6c4ac 100644 --- a/tests/test_tracking_lengthless_elements.py +++ b/tests/test_tracking_lengthless_elements.py @@ -1,6 +1,5 @@ -import torch - import cheetah +import torch beam_in = cheetah.ParticleBeam.from_parameters(num_particles=100) diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index 26f621ce..e066b9eb 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -1,6 +1,5 @@ -import torch - import cheetah +import torch from .resources import ARESlatticeStage3v1_9 as ares From 9be43e17a9c18be35e32ae3d8a0d0db07a01849f Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 2 Sep 2024 13:39:26 +0200 Subject: [PATCH 039/111] Clean up tests --- cheetah/accelerator/cavity.py | 17 +- tests/test_vectorized.py | 325 +++++++++++++--------------------- 2 files changed, 132 insertions(+), 210 deletions(-) diff --git a/cheetah/accelerator/cavity.py b/cheetah/accelerator/cavity.py index f9b0ee8b..793c8155 100644 --- a/cheetah/accelerator/cavity.py +++ b/cheetah/accelerator/cavity.py @@ -224,8 +224,8 @@ def _track_beam(self, incoming: Beam) -> Beam: outgoing = ParameterBeam( outgoing_mu, outgoing_cov, - outgoing_energy, - total_charge=incoming.total_charge, + outgoing_energy.expand(outgoing_mu.shape[:-1]), + total_charge=incoming.total_charge.expand(outgoing_mu.shape[:-1]), device=outgoing_mu.device, dtype=outgoing_mu.dtype, ) @@ -234,7 +234,9 @@ def _track_beam(self, incoming: Beam) -> Beam: outgoing = ParticleBeam( outgoing_particles, outgoing_energy, - particle_charges=incoming.particle_charges, + particle_charges=incoming.particle_charges.expand( + outgoing_particles.shape[:-1] + ), device=outgoing_particles.device, dtype=outgoing_particles.dtype, ) @@ -287,7 +289,7 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor: beta1 = torch.tensor(1.0) k = 2 * torch.pi * self.frequency / torch.tensor(constants.speed_of_light) - r55_cor = 0.0 + r55_cor = torch.tensor(0.0) if torch.any((self.voltage != 0) & (energy != 0)): # TODO: Do we need this if? beta0 = torch.sqrt(1 - 1 / Ei**2) beta1 = torch.sqrt(1 - 1 / Ef**2) @@ -309,7 +311,12 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor: r66 = Ei / Ef * beta0 / beta1 r65 = k * torch.sin(phi) * self.voltage / (Ef * beta1 * electron_mass_eV) - R = torch.eye(7, device=device, dtype=dtype).repeat((*self.length.shape, 1, 1)) + # Check that all matrix elements have the same shape + r11, r12, r21, r22, r55_cor, r56, r65, r66 = torch.broadcast_tensors( + r11, r12, r21, r22, r55_cor, r56, r65, r66 + ) + + R = torch.eye(7, device=device, dtype=dtype).repeat((*r11.shape, 1, 1)) R[..., 0, 0] = r11 R[..., 0, 1] = r12 R[..., 1, 0] = r21 diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index e066b9eb..cae46765 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -1,6 +1,8 @@ -import cheetah +import pytest import torch +import cheetah + from .resources import ARESlatticeStage3v1_9 as ares @@ -38,21 +40,21 @@ def test_segment_length_shape_2d(): assert segment.length.shape == (3, 2) -def test_track_particle_single_element_shape(): +@pytest.mark.parametrize("BeamClass", [cheetah.ParticleBeam, cheetah.ParameterBeam]) +def test_track_quadrupole_shape(BeamClass): """ - Test that the shape of a beam tracked through a single element matches the input. + Test that the shape of a beam tracked through a single quadrupole element matches + the input. """ quadrupole = cheetah.Quadrupole( length=torch.tensor([0.2, 0.25]), k1=torch.tensor([4.2, 4.2]) ) - incoming = cheetah.ParticleBeam.from_parameters( - num_particles=100_000, sigma_x=torch.tensor([1e-5, 2e-5]) - ) + incoming = BeamClass.from_parameters(sigma_x=torch.tensor([1e-5, 2e-5])) outgoing = quadrupole.track(incoming) - assert outgoing.particles.shape == incoming.particles.shape - assert outgoing.particles.shape == (2, 100_000, 7) + if BeamClass == cheetah.ParticleBeam: + assert outgoing.particles.shape == (2, 100_000, 7) assert outgoing.mu_x.shape == (2,) assert outgoing.mu_px.shape == (2,) assert outgoing.mu_y.shape == (2,) @@ -65,28 +67,28 @@ def test_track_particle_single_element_shape(): assert outgoing.sigma_p.shape == (2,) assert outgoing.energy.shape == (2,) assert outgoing.total_charge.shape == (2,) - assert outgoing.particle_charges.shape == (2, 100_000) - assert isinstance(outgoing.num_particles, int) + if BeamClass == cheetah.ParticleBeam: + assert outgoing.particle_charges.shape == (2, 100_000) -def test_track_particle_single_element_shape_2d(): +@pytest.mark.parametrize("BeamClass", [cheetah.ParticleBeam, cheetah.ParameterBeam]) +def test_track_quadrupole_shape_2d(BeamClass): """ - Test that the shape of a beam tracked through a single element matches the input for - an n-dimensional batch. + Test that the shape of a beam tracked through a single quadrupole element matches + the input for an n-dimensional batch. """ quadrupole = cheetah.Quadrupole( length=torch.tensor([[0.2, 0.25], [0.3, 0.35], [0.4, 0.45]]), k1=torch.tensor([[4.2, 4.2], [4.3, 4.3], [4.4, 4.4]]), ) - incoming = cheetah.ParticleBeam.from_parameters( - num_particles=100_000, - sigma_x=torch.tensor([[1e-5, 2e-5], [2e-5, 3e-5], [3e-5, 4e-5]]), + incoming = BeamClass.from_parameters( + sigma_x=torch.tensor([[1e-5, 2e-5], [2e-5, 3e-5], [3e-5, 4e-5]]) ) outgoing = quadrupole.track(incoming) - assert outgoing.particles.shape == incoming.particles.shape - assert outgoing.particles.shape == (3, 2, 100_000, 7) + if BeamClass == cheetah.ParticleBeam: + assert outgoing.particles.shape == (3, 2, 100_000, 7) assert outgoing.mu_x.shape == (3, 2) assert outgoing.mu_px.shape == (3, 2) assert outgoing.mu_y.shape == (3, 2) @@ -99,11 +101,12 @@ def test_track_particle_single_element_shape_2d(): assert outgoing.sigma_p.shape == (3, 2) assert outgoing.energy.shape == (3, 2) assert outgoing.total_charge.shape == (3, 2) - assert outgoing.particle_charges.shape == (3, 2, 100_000) - assert isinstance(outgoing.num_particles, int) + if BeamClass == cheetah.ParticleBeam: + assert outgoing.particle_charges.shape == (3, 2, 100_000) -def test_track_particle_segment_shape(): +@pytest.mark.parametrize("BeamClass", [cheetah.ParticleBeam, cheetah.ParameterBeam]) +def test_track_segment_shape(BeamClass): """ Test that the shape of a beam tracked through a segment matches the input. """ @@ -116,14 +119,12 @@ def test_track_particle_segment_shape(): cheetah.Drift(length=torch.tensor([0.4, 0.3])), ] ) - incoming = cheetah.ParticleBeam.from_parameters( - num_particles=100_000, sigma_x=torch.tensor([1e-5, 2e-5]) - ) + incoming = BeamClass.from_parameters(sigma_x=torch.tensor([1e-5, 2e-5])) outgoing = segment.track(incoming) - assert outgoing.particles.shape == incoming.particles.shape - assert outgoing.particles.shape == (2, 100_000, 7) + if BeamClass == cheetah.ParticleBeam: + assert outgoing.particles.shape == (2, 100_000, 7) assert outgoing.mu_x.shape == (2,) assert outgoing.mu_px.shape == (2,) assert outgoing.mu_y.shape == (2,) @@ -136,14 +137,15 @@ def test_track_particle_segment_shape(): assert outgoing.sigma_p.shape == (2,) assert outgoing.energy.shape == (2,) assert outgoing.total_charge.shape == (2,) - assert outgoing.particle_charges.shape == (2, 100_000) - assert isinstance(outgoing.num_particles, int) + if BeamClass == cheetah.ParticleBeam: + assert outgoing.particle_charges.shape == (2, 100_000) -def test_track_particle_segment_shape_2d(): +@pytest.mark.parametrize("BeamClass", [cheetah.ParticleBeam, cheetah.ParameterBeam]) +def test_track_particle_segment_shape_2d(BeamClass): """ - Test that the shape of a beam tracked through a segment matches the input for the - case of a multi-dimensional batch. + Test that the shape of a particle beam tracked through a segment matches the input + for the case of a multi-dimensional batch. """ segment = cheetah.Segment( elements=[ @@ -155,15 +157,14 @@ def test_track_particle_segment_shape_2d(): cheetah.Drift(length=torch.tensor([[0.4, 0.3], [0.6, 0.5], [0.8, 0.7]])), ] ) - incoming = cheetah.ParticleBeam.from_parameters( - num_particles=100_000, - sigma_x=torch.tensor([[1e-5, 2e-5], [2e-5, 3e-5], [3e-5, 4e-5]]), + incoming = BeamClass.from_parameters( + sigma_x=torch.tensor([[1e-5, 2e-5], [2e-5, 3e-5], [3e-5, 4e-5]]) ) outgoing = segment.track(incoming) - assert outgoing.particles.shape == incoming.particles.shape - assert outgoing.particles.shape == (3, 2, 100_000, 7) + if BeamClass == cheetah.ParticleBeam: + assert outgoing.particles.shape == (3, 2, 100_000, 7) assert outgoing.mu_x.shape == (3, 2) assert outgoing.mu_px.shape == (3, 2) assert outgoing.mu_y.shape == (3, 2) @@ -176,21 +177,83 @@ def test_track_particle_segment_shape_2d(): assert outgoing.sigma_p.shape == (3, 2) assert outgoing.energy.shape == (3, 2) assert outgoing.total_charge.shape == (3, 2) - assert outgoing.particle_charges.shape == (3, 2, 100_000) - assert isinstance(outgoing.num_particles, int) + if BeamClass == cheetah.ParticleBeam: + assert outgoing.particle_charges.shape == (3, 2, 100_000) -def test_track_parameter_single_element_shape(): +def test_enormous_through_ares(): """ - Test that the shape of a beam tracked through a single element matches the input. + Test ARES EA with a huge number of settings. This is a stress test and only run + for `ParameterBeam` because `ParticleBeam` would require a lot of memory. """ - quadrupole = cheetah.Quadrupole( - length=torch.tensor([0.2, 0.25]), k1=torch.tensor([4.2, 4.2]) + segment = cheetah.Segment.from_ocelot(ares.cell).subcell("AREASOLA1", "AREABSCR1") + incoming = cheetah.ParameterBeam.from_astra( + "tests/resources/ACHIP_EA1_2021.1351.001" ) - incoming = cheetah.ParameterBeam.from_parameters(sigma_x=torch.tensor([1e-5, 2e-5])) - outgoing = quadrupole.track(incoming) + segment.AREAMQZM1.k1 = torch.linspace(-30.0, 30.0, 200_000).repeat(3, 1) + + outgoing = segment.track(incoming) + + assert outgoing.mu_x.shape == (3, 200_000) + assert outgoing.mu_px.shape == (3, 200_000) + assert outgoing.mu_y.shape == (3, 200_000) + assert outgoing.mu_py.shape == (3, 200_000) + assert outgoing.sigma_x.shape == (3, 200_000) + assert outgoing.sigma_px.shape == (3, 200_000) + assert outgoing.sigma_y.shape == (3, 200_000) + assert outgoing.sigma_py.shape == (3, 200_000) + assert outgoing.sigma_tau.shape == (3, 200_000) + assert outgoing.sigma_p.shape == (3, 200_000) + assert outgoing.energy.shape == (3, 200_000) + assert outgoing.total_charge.shape == (3, 200_000) + + +@pytest.mark.parametrize("BeamClass", [cheetah.ParticleBeam, cheetah.ParameterBeam]) +def test_cavity_with_zero_and_non_zero_voltage(BeamClass): + """ + Tests that if zero and non-zero voltages are passed to a cavity in a single batch, + there are no errors. This test does NOT check physical correctness. + """ + cavity = cheetah.Cavity( + length=torch.tensor(3.0441), + voltage=torch.tensor([0.0, 48198468.0, 0.0]), + phase=torch.tensor(48198468.0), + frequency=torch.tensor(2.8560e09), + name="my_test_cavity", + ) + incoming = BeamClass.from_parameters(sigma_x=torch.tensor(1e-5)) + + outgoing = cavity.track(incoming) + + if BeamClass == cheetah.ParticleBeam: + assert outgoing.particles.shape == (3, 100_000, 7) + assert outgoing.mu_x.shape == (3,) + assert outgoing.mu_px.shape == (3,) + assert outgoing.mu_y.shape == (3,) + assert outgoing.mu_py.shape == (3,) + assert outgoing.sigma_x.shape == (3,) + assert outgoing.sigma_px.shape == (3,) + assert outgoing.sigma_y.shape == (3,) + assert outgoing.sigma_py.shape == (3,) + assert outgoing.sigma_tau.shape == (3,) + assert outgoing.sigma_p.shape == (3,) + assert outgoing.energy.shape == (3,) + assert outgoing.total_charge.shape == (3,) + if BeamClass == cheetah.ParticleBeam: + assert outgoing.particle_charges.shape == (3, 100_000) + + +@pytest.mark.parametrize("BeamClass", [cheetah.ParticleBeam, cheetah.ParameterBeam]) +def test_vectorized_undulator(BeamClass): + """Test that a vectorized `Undulator` is able to track a particle beam.""" + element = cheetah.Undulator(length=torch.tensor([0.4, 0.7])) + incoming = BeamClass.from_parameters(sigma_x=torch.tensor(1e-5)) + outgoing = element.track(incoming) + + if BeamClass == cheetah.ParticleBeam: + assert outgoing.particles.shape == (2, 100_000, 7) assert outgoing.mu_x.shape == (2,) assert outgoing.mu_px.shape == (2,) assert outgoing.mu_y.shape == (2,) @@ -203,54 +266,22 @@ def test_track_parameter_single_element_shape(): assert outgoing.sigma_p.shape == (2,) assert outgoing.energy.shape == (2,) assert outgoing.total_charge.shape == (2,) + if BeamClass == cheetah.ParticleBeam: + assert outgoing.particle_charges.shape == (2, 100_000) -def test_track_parameter_single_element_shape_2d(): - """ - Test that the shape of a beam tracked through a single element matches the input for - an n-dimensional batch. - """ - quadrupole = cheetah.Quadrupole( - length=torch.tensor([[0.2, 0.25], [0.3, 0.35], [0.4, 0.45]]), - k1=torch.tensor([[4.2, 4.2], [4.3, 4.3], [4.4, 4.4]]), - ) - incoming = cheetah.ParameterBeam.from_parameters( - sigma_x=torch.tensor([[1e-5, 2e-5], [2e-5, 3e-5], [3e-5, 4e-5]]) - ) - - outgoing = quadrupole.track(incoming) - - assert outgoing.mu_x.shape == (3, 2) - assert outgoing.mu_px.shape == (3, 2) - assert outgoing.mu_y.shape == (3, 2) - assert outgoing.mu_py.shape == (3, 2) - assert outgoing.sigma_x.shape == (3, 2) - assert outgoing.sigma_px.shape == (3, 2) - assert outgoing.sigma_y.shape == (3, 2) - assert outgoing.sigma_py.shape == (3, 2) - assert outgoing.sigma_tau.shape == (3, 2) - assert outgoing.sigma_p.shape == (3, 2) - assert outgoing.energy.shape == (3, 2) - assert outgoing.total_charge.shape == (3, 2) - - -def test_track_parameter_segment_shape(): - """ - Test that the shape of a beam tracked through a segment matches the input. - """ - segment = cheetah.Segment( - elements=[ - cheetah.Drift(length=torch.tensor([0.6, 0.5])), - cheetah.Quadrupole( - length=torch.tensor([0.2, 0.25]), k1=torch.tensor([4.2, 4.2]) - ), - cheetah.Drift(length=torch.tensor([0.4, 0.3])), - ] +@pytest.mark.parametrize("BeamClass", [cheetah.ParticleBeam, cheetah.ParameterBeam]) +def test_vectorized_solenoid(BeamClass): + """Test that a vectorized `Solenoid` is able to track a particle beam.""" + element = cheetah.Solenoid( + length=torch.tensor([0.4, 0.7]), k=torch.tensor([4.2, 3.1]) ) - incoming = cheetah.ParameterBeam.from_parameters(sigma_x=torch.tensor([1e-5, 2e-5])) + incoming = BeamClass.from_parameters(sigma_x=torch.tensor(1e-5)) - outgoing = segment.track(incoming) + outgoing = element.track(incoming) + if BeamClass == cheetah.ParticleBeam: + assert outgoing.particles.shape == (2, 100_000, 7) assert outgoing.mu_x.shape == (2,) assert outgoing.mu_px.shape == (2,) assert outgoing.mu_y.shape == (2,) @@ -263,121 +294,5 @@ def test_track_parameter_segment_shape(): assert outgoing.sigma_p.shape == (2,) assert outgoing.energy.shape == (2,) assert outgoing.total_charge.shape == (2,) - - -def test_track_parameter_segment_shape_2d(): - """ - Test that the shape of a beam tracked through a segment matches the input for the - case of a multi-dimensional batch. - """ - segment = cheetah.Segment( - elements=[ - cheetah.Drift(length=torch.tensor([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]])), - cheetah.Quadrupole( - length=torch.tensor([[0.2, 0.25], [0.3, 0.35], [0.4, 0.45]]), - k1=torch.tensor([[4.2, 4.2], [4.3, 4.3], [4.4, 4.4]]), - ), - cheetah.Drift(length=torch.tensor([[0.4, 0.3], [0.6, 0.5], [0.8, 0.7]])), - ] - ) - incoming = cheetah.ParameterBeam.from_parameters( - sigma_x=torch.tensor([[1e-5, 2e-5], [2e-5, 3e-5], [3e-5, 4e-5]]) - ) - - outgoing = segment.track(incoming) - - assert outgoing.mu_x.shape == (3, 2) - assert outgoing.mu_px.shape == (3, 2) - assert outgoing.mu_y.shape == (3, 2) - assert outgoing.mu_py.shape == (3, 2) - assert outgoing.sigma_x.shape == (3, 2) - assert outgoing.sigma_px.shape == (3, 2) - assert outgoing.sigma_y.shape == (3, 2) - assert outgoing.sigma_py.shape == (3, 2) - assert outgoing.sigma_tau.shape == (3, 2) - assert outgoing.sigma_p.shape == (3, 2) - assert outgoing.energy.shape == (3, 2) - assert outgoing.total_charge.shape == (3, 2) - - -def test_enormous_through_ares(): - """Test ARES EA with a huge number of settings.""" - segment = cheetah.Segment.from_ocelot(ares.cell).subcell("AREASOLA1", "AREABSCR1") - incoming = cheetah.ParameterBeam.from_astra( - "tests/resources/ACHIP_EA1_2021.1351.001" - ) - - segment.AREAMQZM1.k1 = torch.linspace(-30.0, 30.0, 10).repeat(3, 1) - - outgoing = segment.track(incoming) - - assert outgoing.mu_x.shape == (3, 10) - assert outgoing.mu_px.shape == (3, 10) - assert outgoing.mu_y.shape == (3, 10) - assert outgoing.mu_py.shape == (3, 10) - assert outgoing.sigma_x.shape == (3, 10) - assert outgoing.sigma_px.shape == (3, 10) - assert outgoing.sigma_y.shape == (3, 10) - assert outgoing.sigma_py.shape == (3, 10) - assert outgoing.sigma_tau.shape == (3, 10) - assert outgoing.sigma_p.shape == (3, 10) - assert outgoing.energy.shape == (3, 10) - assert outgoing.total_charge.shape == (3, 10) - - -def test_cavity_with_zero_and_non_zero_voltage(): - """ - Tests that if zero and non-zero voltages are passed to a cavity in a single batch, - there are no errors. This test does NOT check physical correctness. - """ - cavity = cheetah.Cavity( - length=torch.tensor([3.0441, 3.0441, 3.0441]), - voltage=torch.tensor([0.0, 48198468.0, 0.0]), - phase=torch.tensor([48198468.0, 48198468.0, 48198468.0]), - frequency=torch.tensor([2.8560e09, 2.8560e09, 2.8560e09]), - name="my_test_cavity", - ) - beam = cheetah.ParticleBeam.from_parameters( - num_particles=100_000, sigma_x=torch.tensor(1e-5) - ) - - _ = cavity.track(beam) - - -def test_screen_length_shape(): - """ - Test that the shape of a screen's length matches the shape of its misalignment. - """ - screen = cheetah.Screen(misalignment=torch.tensor([[0.1, 0.2], [0.3, 0.4]])) - assert screen.length.shape == screen.misalignment.shape[:-1] - - -def test_screen_length_broadcast_shape(): - """ - Test that the shape of a screen's length matches the shape of its misalignment - after broadcasting. - """ - screen = cheetah.Screen(misalignment=torch.tensor([[0.1, 0.2]])) - assert screen.length.shape == screen.misalignment.shape[:-1] - - -def test_vectorized_undulator(): - """Test that a vectorized `Undulator` is able to track a particle beam.""" - element = cheetah.Undulator(length=torch.tensor([0.4, 0.7])) - beam = cheetah.ParticleBeam.from_parameters( - num_particles=100_000, sigma_x=torch.tensor([1e-5]) - ) - - _ = element.track(beam) - - -def test_vectorized_solenoid(): - """Test that a vectorized `Solenoid` is able to track a particle beam.""" - element = cheetah.Solenoid( - length=torch.tensor([0.4, 0.7]), k=torch.tensor([4.2, 3.1]) - ) - beam = cheetah.ParticleBeam.from_parameters( - num_particles=100_000, sigma_x=torch.tensor([1e-5]) - ) - - _ = element.track(beam) + if BeamClass == cheetah.ParticleBeam: + assert outgoing.particle_charges.shape == (2, 100_000) From 4606ccf7efe23a3cdde9fd62062a027f13949095 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 2 Sep 2024 14:44:08 +0200 Subject: [PATCH 040/111] Minor fixes --- cheetah/particles/parameter_beam.py | 26 ----- cheetah/particles/particle_beam.py | 160 +++++++++++----------------- 2 files changed, 60 insertions(+), 126 deletions(-) diff --git a/cheetah/particles/parameter_beam.py b/cheetah/particles/parameter_beam.py index 1288bda5..c5da1824 100644 --- a/cheetah/particles/parameter_beam.py +++ b/cheetah/particles/parameter_beam.py @@ -318,32 +318,6 @@ def transformed_to( device = device if device is not None else self.mu_x.device dtype = dtype if dtype is not None else self.mu_x.dtype - # Figure out batch size of the original beam and check that passed arguments - # have the same batch size - shape = mu_x.shape - not_nones = [ - argument - for argument in [ - mu_x, - mu_px, - mu_y, - mu_py, - sigma_x, - sigma_px, - sigma_y, - sigma_py, - sigma_tau, - sigma_p, - energy, - total_charge, - ] - if argument is not None - ] - if len(not_nones) > 0: - assert all( - argument.shape == shape for argument in not_nones - ), "Arguments must have the same shape." - mu_x = mu_x if mu_x is not None else self.mu_x mu_px = mu_px if mu_px is not None else self.mu_px mu_y = mu_y if mu_y is not None else self.mu_y diff --git a/cheetah/particles/particle_beam.py b/cheetah/particles/particle_beam.py index a7b75534..1a722460 100644 --- a/cheetah/particles/particle_beam.py +++ b/cheetah/particles/particle_beam.py @@ -1,6 +1,7 @@ from typing import Optional import torch +from icecream import ic from scipy import constants from torch.distributions import MultivariateNormal @@ -98,67 +99,55 @@ def from_parameters( :param device: Device to move the beam's particle array to. If set to `"auto"` a CUDA GPU is selected if available. The CPU is used otherwise. """ - # Figure out if arguments were passed, figure out their shape - not_nones = [ - argument - for argument in [ - mu_x, - mu_px, - mu_y, - mu_py, - sigma_x, - sigma_px, - sigma_y, - sigma_py, - sigma_tau, - sigma_p, - cor_x, - cor_y, - cor_tau, - energy, - total_charge, - ] - if argument is not None - ] - shape = not_nones[0].shape if len(not_nones) > 0 else torch.Size([1]) - if len(not_nones) > 1: - assert all( - argument.shape == shape for argument in not_nones - ), "Arguments must have the same shape." # Set default values without function call in function signature num_particles = ( num_particles if num_particles is not None else torch.tensor(100_000) ) - mu_x = mu_x if mu_x is not None else torch.full(shape, 0.0) - mu_px = mu_px if mu_px is not None else torch.full(shape, 0.0) - mu_y = mu_y if mu_y is not None else torch.full(shape, 0.0) - mu_py = mu_py if mu_py is not None else torch.full(shape, 0.0) - sigma_x = sigma_x if sigma_x is not None else torch.full(shape, 175e-9) - sigma_px = sigma_px if sigma_px is not None else torch.full(shape, 2e-7) - sigma_y = sigma_y if sigma_y is not None else torch.full(shape, 175e-9) - sigma_py = sigma_py if sigma_py is not None else torch.full(shape, 2e-7) - sigma_tau = sigma_tau if sigma_tau is not None else torch.full(shape, 1e-6) - sigma_p = sigma_p if sigma_p is not None else torch.full(shape, 1e-6) - cor_x = cor_x if cor_x is not None else torch.full(shape, 0.0) - cor_y = cor_y if cor_y is not None else torch.full(shape, 0.0) - cor_tau = cor_tau if cor_tau is not None else torch.full(shape, 0.0) - energy = energy if energy is not None else torch.full(shape, 1e8) - total_charge = ( - total_charge if total_charge is not None else torch.full(shape, 0.0) - ) - particle_charges = ( - torch.ones((*shape, num_particles), device=device, dtype=dtype) - * total_charge.unsqueeze(-1) - / num_particles - ) + mu_x = mu_x if mu_x is not None else torch.tensor(0.0) + mu_px = mu_px if mu_px is not None else torch.tensor(0.0) + mu_y = mu_y if mu_y is not None else torch.tensor(0.0) + mu_py = mu_py if mu_py is not None else torch.tensor(0.0) + sigma_x = sigma_x if sigma_x is not None else torch.tensor(175e-9) + sigma_px = sigma_px if sigma_px is not None else torch.tensor(2e-7) + sigma_y = sigma_y if sigma_y is not None else torch.tensor(175e-9) + sigma_py = sigma_py if sigma_py is not None else torch.tensor(2e-7) + sigma_tau = sigma_tau if sigma_tau is not None else torch.tensor(1e-6) + sigma_p = sigma_p if sigma_p is not None else torch.tensor(1e-6) + cor_x = cor_x if cor_x is not None else torch.tensor(0.0) + cor_y = cor_y if cor_y is not None else torch.tensor(0.0) + cor_tau = cor_tau if cor_tau is not None else torch.tensor(0.0) + energy = energy if energy is not None else torch.tensor(1e8) + total_charge = total_charge if total_charge is not None else torch.tensor(0.0) + particle_charges = torch.ones(num_particles) * total_charge / num_particles mean = torch.stack( - [mu_x, mu_px, mu_y, mu_py, torch.full(shape, 0.0), torch.full(shape, 0.0)], + [mu_x, mu_px, mu_y, mu_py, torch.zeros_like(mu_x), torch.zeros_like(mu_x)], dim=-1, ) - cov = torch.zeros(*shape, 6, 6) + ( + sigma_x, + cor_x, + sigma_px, + sigma_y, + cor_y, + sigma_py, + sigma_tau, + cor_tau, + sigma_p, + ) = torch.broadcast_tensors( + sigma_x, + cor_x, + sigma_px, + sigma_y, + cor_y, + sigma_py, + sigma_tau, + cor_tau, + sigma_p, + ) + cov = torch.zeros(*sigma_x.shape, 6, 6) cov[..., 0, 0] = sigma_x**2 cov[..., 0, 1] = cor_x cov[..., 1, 0] = cor_x @@ -172,7 +161,7 @@ def from_parameters( cov[..., 5, 4] = cor_tau cov[..., 5, 5] = sigma_p**2 - particles = torch.ones((*shape, num_particles, 7)) + particles = torch.ones((*mean.shape[:-1], num_particles, 7)) distributions = [ MultivariateNormal(sample_mean, covariance_matrix=sample_cov) for sample_mean, sample_cov in zip(mean.view(-1, 6), cov.view(-1, 6, 6)) @@ -180,7 +169,7 @@ def from_parameters( particles[..., :6] = torch.stack( [distribution.sample((num_particles,)) for distribution in distributions], dim=0, - ).view(*shape, num_particles, 6) + ).view(*particles.shape[:-2], num_particles, 6) return cls( particles, @@ -380,9 +369,7 @@ def uniform_3d_ellipsoid( y = (torch.rand(num_particles) - 0.5) * 2 * r_y tau = (torch.rand(num_particles) - 0.5) * 2 * r_tau - is_in_ellipsoid = ( - x**2 / r_x**2 + y**2 / r_y**2 + tau**2 / r_tau**2 < 1 - ) + is_in_ellipsoid = x**2 / r_x**2 + y**2 / r_y**2 + tau**2 / r_tau**2 < 1 num_to_add = min(num_particles - num_successful, is_in_ellipsoid.sum()) flattened_x[i, num_successful : num_successful + num_to_add] = x[ @@ -656,33 +643,6 @@ def transformed_to( device = device if device is not None else self.mu_x.device dtype = dtype if dtype is not None else self.mu_x.dtype - # Figure out batch size of the original beam and check that passed arguments - # have the same batch size - shape = mu_x.shape - not_nones = [ - argument - for argument in [ - mu_x, - mu_px, - mu_y, - mu_py, - sigma_x, - sigma_px, - sigma_y, - sigma_py, - sigma_tau, - sigma_p, - energy, - total_charge, - ] - if argument is not None - ] - if len(not_nones) > 0: - if not all(argument.shape == torch.Size([]) for argument in not_nones): - raise NotImplementedError( - "Batching not implemented yet. Arguments must have shape []." - ) - mu_x = mu_x if mu_x is not None else self.mu_x mu_y = mu_y if mu_y is not None else self.mu_y mu_px = mu_px if mu_px is not None else self.mu_px @@ -710,29 +670,29 @@ def transformed_to( new_mu = torch.stack( [ - mu_x.squeeze(), - mu_px.squeeze(), - mu_y.squeeze(), - mu_py.squeeze(), - torch.full(shape, 0.0), - torch.full(shape, 0.0), + mu_x, + mu_px, + mu_y, + mu_py, + torch.full_like(mu_x, 0.0), + torch.full_like(mu_x, 0.0), ], - dim=0, + dim=-1, ) new_sigma = torch.stack( - [sigma_x, sigma_px, sigma_y, sigma_py, sigma_tau, sigma_p], dim=0 + [sigma_x, sigma_px, sigma_y, sigma_py, sigma_tau, sigma_p], dim=-1 ) old_mu = torch.stack( [ - self.mu_x.squeeze(), - self.mu_px.squeeze(), - self.mu_y.squeeze(), - self.mu_py.squeeze(), - torch.full(shape, 0.0), - torch.full(shape, 0.0), + self.mu_x, + self.mu_px, + self.mu_y, + self.mu_py, + torch.full_like(self.mu_x, 0.0), + torch.full_like(self.mu_x, 0.0), ], - dim=0, + dim=-1, ) old_sigma = torch.stack( [ @@ -743,8 +703,8 @@ def transformed_to( self.sigma_tau, self.sigma_p, ], - dim=0, - ).squeeze() + dim=-1, + ) phase_space = self.particles[:, :, :6] phase_space = ( From 511fe7f6e2dd701c9643c156b7424fce01169e1a Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 2 Sep 2024 14:45:06 +0200 Subject: [PATCH 041/111] Fix flake8 warning --- cheetah/particles/particle_beam.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cheetah/particles/particle_beam.py b/cheetah/particles/particle_beam.py index 1a722460..86722fec 100644 --- a/cheetah/particles/particle_beam.py +++ b/cheetah/particles/particle_beam.py @@ -1,7 +1,6 @@ from typing import Optional import torch -from icecream import ic from scipy import constants from torch.distributions import MultivariateNormal From 0703c134e97f91a2edca0be37993a2d21bfc5a4c Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 2 Sep 2024 14:53:23 +0200 Subject: [PATCH 042/111] Fix messed up formatting --- cheetah/__init__.py | 2 +- cheetah/accelerator/quadrupole.py | 4 ++-- cheetah/accelerator/screen.py | 2 +- tests/test_astra_import.py | 3 ++- tests/test_bmad_conversion.py | 3 ++- tests/test_bpm.py | 3 ++- tests/test_cavity.py | 3 ++- tests/test_compare_beam_type.py | 3 ++- tests/test_compare_ocelot.py | 4 ++-- tests/test_device_dtype.py | 3 ++- tests/test_differentiable.py | 3 ++- tests/test_drift.py | 3 ++- tests/test_ocelot_import.py | 3 ++- tests/test_reading_nx_tables.py | 3 ++- tests/test_screen.py | 3 ++- tests/test_space_charge_kick.py | 3 ++- tests/test_speed.py | 4 ++-- tests/test_speed_optimizations.py | 3 ++- tests/test_tracking_lengthless_elements.py | 3 ++- 19 files changed, 36 insertions(+), 22 deletions(-) diff --git a/cheetah/__init__.py b/cheetah/__init__.py index bf5f5e3b..91c13335 100644 --- a/cheetah/__init__.py +++ b/cheetah/__init__.py @@ -1,7 +1,7 @@ from . import converters # noqa: F401 from .accelerator import ( # noqa: F401 - Aperture, BPM, + Aperture, Cavity, CustomTransferMap, Dipole, diff --git a/cheetah/accelerator/quadrupole.py b/cheetah/accelerator/quadrupole.py index 63247770..841c746d 100644 --- a/cheetah/accelerator/quadrupole.py +++ b/cheetah/accelerator/quadrupole.py @@ -5,11 +5,11 @@ import torch from matplotlib.patches import Rectangle from scipy.constants import physical_constants -from torch import nn, Size +from torch import Size, nn from ..particles import Beam, ParticleBeam from ..track_methods import base_rmatrix, misalignment_matrix -from ..utils import bmadx, UniqueNameGenerator +from ..utils import UniqueNameGenerator, bmadx from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") diff --git a/cheetah/accelerator/screen.py b/cheetah/accelerator/screen.py index 94bf2b67..b8804770 100644 --- a/cheetah/accelerator/screen.py +++ b/cheetah/accelerator/screen.py @@ -8,7 +8,7 @@ from torch.distributions import MultivariateNormal from ..particles import Beam, ParameterBeam, ParticleBeam -from ..utils import kde_histogram_2d, UniqueNameGenerator +from ..utils import UniqueNameGenerator, kde_histogram_2d from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") diff --git a/tests/test_astra_import.py b/tests/test_astra_import.py index 7bda194e..4e361cce 100644 --- a/tests/test_astra_import.py +++ b/tests/test_astra_import.py @@ -1,7 +1,8 @@ -import cheetah import numpy as np import torch +import cheetah + def test_astra_to_parameter_beam(): """Test that Astra beams are correctly loaded into parameter beams.""" diff --git a/tests/test_bmad_conversion.py b/tests/test_bmad_conversion.py index a0e76961..9dc78849 100644 --- a/tests/test_bmad_conversion.py +++ b/tests/test_bmad_conversion.py @@ -1,6 +1,7 @@ -import cheetah import pytest import torch + +import cheetah from cheetah.utils import is_mps_available_and_functional diff --git a/tests/test_bpm.py b/tests/test_bpm.py index d6c0ad1a..6d58ead5 100644 --- a/tests/test_bpm.py +++ b/tests/test_bpm.py @@ -1,7 +1,8 @@ -import cheetah import pytest import torch +import cheetah + @pytest.mark.parametrize("is_bpm_active", [True, False]) @pytest.mark.parametrize("beam_class", [cheetah.ParticleBeam, cheetah.ParameterBeam]) diff --git a/tests/test_cavity.py b/tests/test_cavity.py index 67f38439..7129cbc0 100644 --- a/tests/test_cavity.py +++ b/tests/test_cavity.py @@ -1,7 +1,8 @@ -import cheetah import pytest import torch +import cheetah + def test_assert_ei_greater_zero(): """ diff --git a/tests/test_compare_beam_type.py b/tests/test_compare_beam_type.py index 6ddb22bc..6a6d3e28 100644 --- a/tests/test_compare_beam_type.py +++ b/tests/test_compare_beam_type.py @@ -2,9 +2,10 @@ Tests that ensure that both beam types produce (roughly) the same results. """ -import cheetah import torch +import cheetah + def test_from_twiss(): """ diff --git a/tests/test_compare_ocelot.py b/tests/test_compare_ocelot.py index 6aab1f63..d9586e6b 100644 --- a/tests/test_compare_ocelot.py +++ b/tests/test_compare_ocelot.py @@ -1,11 +1,11 @@ from copy import deepcopy -import cheetah - import numpy as np import ocelot import torch +import cheetah + from .resources import ARESlatticeStage3v1_9 as ares diff --git a/tests/test_device_dtype.py b/tests/test_device_dtype.py index cdb3b522..fa3e0a45 100644 --- a/tests/test_device_dtype.py +++ b/tests/test_device_dtype.py @@ -1,6 +1,7 @@ -import cheetah import pytest import torch + +import cheetah from cheetah.utils import is_mps_available_and_functional diff --git a/tests/test_differentiable.py b/tests/test_differentiable.py index 47c1493e..c60e07dd 100644 --- a/tests/test_differentiable.py +++ b/tests/test_differentiable.py @@ -1,7 +1,8 @@ -import cheetah import torch from torch import nn +import cheetah + from .resources import ARESlatticeStage3v1_9 as ares diff --git a/tests/test_drift.py b/tests/test_drift.py index 00b40a7b..fd4ee31a 100644 --- a/tests/test_drift.py +++ b/tests/test_drift.py @@ -1,7 +1,8 @@ -import cheetah import pytest import torch +import cheetah + def test_diverging_parameter_beam(): """ diff --git a/tests/test_ocelot_import.py b/tests/test_ocelot_import.py index 27f97d92..cb2ee660 100644 --- a/tests/test_ocelot_import.py +++ b/tests/test_ocelot_import.py @@ -1,8 +1,9 @@ -import cheetah import numpy as np import ocelot import pytest +import cheetah + from .resources import ARESlatticeStage3v1_9 as ares diff --git a/tests/test_reading_nx_tables.py b/tests/test_reading_nx_tables.py index 8802b58d..3689f65b 100644 --- a/tests/test_reading_nx_tables.py +++ b/tests/test_reading_nx_tables.py @@ -1,6 +1,7 @@ -import cheetah import torch +import cheetah + def test_no_error(): """ diff --git a/tests/test_screen.py b/tests/test_screen.py index 8cf5d745..c9cdfbc5 100644 --- a/tests/test_screen.py +++ b/tests/test_screen.py @@ -1,8 +1,9 @@ -import cheetah import numpy as np import pytest import torch +import cheetah + from .resources import ARESlatticeStage3v1_9 as ocelot_lattice diff --git a/tests/test_space_charge_kick.py b/tests/test_space_charge_kick.py index 31aa3b28..82130c93 100644 --- a/tests/test_space_charge_kick.py +++ b/tests/test_space_charge_kick.py @@ -1,9 +1,10 @@ -import cheetah import torch from scipy import constants from scipy.constants import physical_constants from torch import nn +import cheetah + def test_cold_uniform_beam_expansion(): """ diff --git a/tests/test_speed.py b/tests/test_speed.py index f8325822..917f9550 100644 --- a/tests/test_speed.py +++ b/tests/test_speed.py @@ -1,9 +1,9 @@ import time -import cheetah - import torch +import cheetah + from .resources import ARESlatticeStage3v1_9 as ares diff --git a/tests/test_speed_optimizations.py b/tests/test_speed_optimizations.py index 0e72ed1e..8d760f78 100644 --- a/tests/test_speed_optimizations.py +++ b/tests/test_speed_optimizations.py @@ -1,7 +1,8 @@ -import cheetah import pytest import torch +import cheetah + def test_merged_transfer_maps_tracking(): """ diff --git a/tests/test_tracking_lengthless_elements.py b/tests/test_tracking_lengthless_elements.py index 07e6c4ac..d5a738b7 100644 --- a/tests/test_tracking_lengthless_elements.py +++ b/tests/test_tracking_lengthless_elements.py @@ -1,6 +1,7 @@ -import cheetah import torch +import cheetah + beam_in = cheetah.ParticleBeam.from_parameters(num_particles=100) From 862087e5cb58421b952c4c12955af57730c42edf Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 2 Sep 2024 14:54:31 +0200 Subject: [PATCH 043/111] Another reimaing formatting fix --- tests/test_split.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_split.py b/tests/test_split.py index 16769f55..868c3f9c 100644 --- a/tests/test_split.py +++ b/tests/test_split.py @@ -1,7 +1,8 @@ -import cheetah import pytest import torch +import cheetah + def test_drift_end(): """ From e2845037227bdc90dd89ec8abc277903ced3f0c6 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 2 Sep 2024 16:11:28 +0200 Subject: [PATCH 044/111] Fix all but screen and space charge tests --- cheetah/accelerator/cavity.py | 8 +-- cheetah/accelerator/element.py | 10 ++- cheetah/accelerator/quadrupole.py | 2 + cheetah/accelerator/screen.py | 2 +- cheetah/particles/parameter_beam.py | 97 +++++++++++++---------------- cheetah/particles/particle_beam.py | 40 ++++++------ tests/test_particle_beam.py | 7 +-- tests/test_quadrupole.py | 4 +- tests/test_screen.py | 16 ++--- tests/test_vectorized.py | 46 +++++++------- 10 files changed, 107 insertions(+), 125 deletions(-) diff --git a/cheetah/accelerator/cavity.py b/cheetah/accelerator/cavity.py index 793c8155..f303a972 100644 --- a/cheetah/accelerator/cavity.py +++ b/cheetah/accelerator/cavity.py @@ -224,8 +224,8 @@ def _track_beam(self, incoming: Beam) -> Beam: outgoing = ParameterBeam( outgoing_mu, outgoing_cov, - outgoing_energy.expand(outgoing_mu.shape[:-1]), - total_charge=incoming.total_charge.expand(outgoing_mu.shape[:-1]), + outgoing_energy, + total_charge=incoming.total_charge, device=outgoing_mu.device, dtype=outgoing_mu.dtype, ) @@ -234,9 +234,7 @@ def _track_beam(self, incoming: Beam) -> Beam: outgoing = ParticleBeam( outgoing_particles, outgoing_energy, - particle_charges=incoming.particle_charges.expand( - outgoing_particles.shape[:-1] - ), + particle_charges=incoming.particle_charges, device=outgoing_particles.device, dtype=outgoing_particles.dtype, ) diff --git a/cheetah/accelerator/element.py b/cheetah/accelerator/element.py index b40acac9..560fba23 100644 --- a/cheetah/accelerator/element.py +++ b/cheetah/accelerator/element.py @@ -69,8 +69,8 @@ def track(self, incoming: Beam) -> Beam: return ParameterBeam( mu, cov, - incoming.energy.expand(mu.shape[:-1]), - total_charge=incoming.total_charge.expand(mu.shape[:-1]), + incoming.energy, + total_charge=incoming.total_charge, device=mu.device, dtype=mu.dtype, ) @@ -79,10 +79,8 @@ def track(self, incoming: Beam) -> Beam: new_particles = torch.matmul(incoming.particles, tm.transpose(-2, -1)) return ParticleBeam( new_particles, - incoming.energy.expand(new_particles.shape[:-2]), - particle_charges=incoming.particle_charges.expand( - new_particles.shape[:-1] - ), + incoming.energy, + particle_charges=incoming.particle_charges, device=new_particles.device, dtype=new_particles.dtype, ) diff --git a/cheetah/accelerator/quadrupole.py b/cheetah/accelerator/quadrupole.py index 841c746d..343e482b 100644 --- a/cheetah/accelerator/quadrupole.py +++ b/cheetah/accelerator/quadrupole.py @@ -178,6 +178,8 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: ) # End of Bmad-X tracking + x, px, y, py, z, pz = torch.broadcast_tensors(x, px, y, py, z, pz) + bmad_coords = torch.empty((*x.shape, 6), device=x.device, dtype=x.dtype) bmad_coords[..., 0] = x bmad_coords[..., 1] = px bmad_coords[..., 2] = y diff --git a/cheetah/accelerator/screen.py b/cheetah/accelerator/screen.py index b8804770..33139290 100644 --- a/cheetah/accelerator/screen.py +++ b/cheetah/accelerator/screen.py @@ -81,7 +81,7 @@ def __init__( ( torch.as_tensor(misalignment, **factory_kwargs) if misalignment is not None - else torch.tensor([(0.0, 0.0)], **factory_kwargs) + else torch.tensor((0.0, 0.0), **factory_kwargs) ), ) self.register_buffer( diff --git a/cheetah/particles/parameter_beam.py b/cheetah/particles/parameter_beam.py index c5da1824..f5189b15 100644 --- a/cheetah/particles/parameter_beam.py +++ b/cheetah/particles/parameter_beam.py @@ -63,52 +63,22 @@ def from_parameters( device=None, dtype=torch.float32, ) -> "ParameterBeam": - # Figure out if arguments were passed, figure out their shape - not_nones = [ - argument - for argument in [ - mu_x, - mu_px, - mu_y, - mu_py, - sigma_x, - sigma_px, - sigma_y, - sigma_py, - sigma_tau, - sigma_p, - cor_x, - cor_y, - cor_tau, - energy, - total_charge, - ] - if argument is not None - ] - shape = not_nones[0].shape if len(not_nones) > 0 else torch.Size([1]) - if len(not_nones) > 1: - assert all( - argument.shape == shape for argument in not_nones - ), "Arguments must have the same shape." - # Set default values without function call in function signature - mu_x = mu_x if mu_x is not None else torch.full(shape, 0.0) - mu_px = mu_px if mu_px is not None else torch.full(shape, 0.0) - mu_y = mu_y if mu_y is not None else torch.full(shape, 0.0) - mu_py = mu_py if mu_py is not None else torch.full(shape, 0.0) - sigma_x = sigma_x if sigma_x is not None else torch.full(shape, 175e-9) - sigma_px = sigma_px if sigma_px is not None else torch.full(shape, 2e-7) - sigma_y = sigma_y if sigma_y is not None else torch.full(shape, 175e-9) - sigma_py = sigma_py if sigma_py is not None else torch.full(shape, 2e-7) - sigma_tau = sigma_tau if sigma_tau is not None else torch.full(shape, 1e-6) - sigma_p = sigma_p if sigma_p is not None else torch.full(shape, 1e-6) - cor_x = cor_x if cor_x is not None else torch.full(shape, 0.0) - cor_y = cor_y if cor_y is not None else torch.full(shape, 0.0) - cor_tau = cor_tau if cor_tau is not None else torch.full(shape, 0.0) - energy = energy if energy is not None else torch.full(shape, 1e8) - total_charge = ( - total_charge if total_charge is not None else torch.full(shape, 0.0) - ) + mu_x = mu_x if mu_x is not None else torch.tensor(0.0) + mu_px = mu_px if mu_px is not None else torch.tensor(0.0) + mu_y = mu_y if mu_y is not None else torch.tensor(0.0) + mu_py = mu_py if mu_py is not None else torch.tensor(0.0) + sigma_x = sigma_x if sigma_x is not None else torch.tensor(175e-9) + sigma_px = sigma_px if sigma_px is not None else torch.tensor(2e-7) + sigma_y = sigma_y if sigma_y is not None else torch.tensor(175e-9) + sigma_py = sigma_py if sigma_py is not None else torch.tensor(2e-7) + sigma_tau = sigma_tau if sigma_tau is not None else torch.tensor(1e-6) + sigma_p = sigma_p if sigma_p is not None else torch.tensor(1e-6) + cor_x = cor_x if cor_x is not None else torch.tensor(0.0) + cor_y = cor_y if cor_y is not None else torch.tensor(0.0) + cor_tau = cor_tau if cor_tau is not None else torch.tensor(0.0) + energy = energy if energy is not None else torch.tensor(1e8) + total_charge = total_charge if total_charge is not None else torch.tensor(0.0) mu = torch.stack( [ @@ -116,14 +86,35 @@ def from_parameters( mu_px, mu_y, mu_py, - torch.full(shape, 0.0), - torch.full(shape, 0.0), - torch.full(shape, 1.0), + torch.tensor(0.0), + torch.tensor(0.0), + torch.tensor(1.0), ], dim=-1, ) - cov = torch.zeros(*shape, 7, 7) + ( + sigma_x, + cor_x, + sigma_px, + sigma_y, + cor_y, + sigma_py, + sigma_tau, + cor_tau, + sigma_p, + ) = torch.broadcast_tensors( + sigma_x, + cor_x, + sigma_px, + sigma_y, + cor_y, + sigma_py, + sigma_tau, + cor_tau, + sigma_p, + ) + cov = torch.zeros(*sigma_x.shape, 7, 7) cov[..., 0, 0] = sigma_x**2 cov[..., 0, 1] = cor_x cov[..., 1, 0] = cor_x @@ -271,10 +262,10 @@ def from_astra(cls, path: str, device=None, dtype=torch.float32) -> "ParameterBe total_charge = torch.tensor(np.sum(particle_charges), dtype=torch.float32) return cls( - mu=mu.unsqueeze(0), - cov=cov.unsqueeze(0), - energy=torch.tensor(energy, dtype=torch.float32).unsqueeze(0), - total_charge=total_charge.unsqueeze(0), + mu=mu, + cov=cov, + energy=torch.tensor(energy, dtype=torch.float32), + total_charge=total_charge, device=device, dtype=dtype, ) diff --git a/cheetah/particles/particle_beam.py b/cheetah/particles/particle_beam.py index 86722fec..a6c0d790 100644 --- a/cheetah/particles/particle_beam.py +++ b/cheetah/particles/particle_beam.py @@ -118,8 +118,13 @@ def from_parameters( cor_tau = cor_tau if cor_tau is not None else torch.tensor(0.0) energy = energy if energy is not None else torch.tensor(1e8) total_charge = total_charge if total_charge is not None else torch.tensor(0.0) - particle_charges = torch.ones(num_particles) * total_charge / num_particles + particle_charges = ( + torch.ones((*total_charge.shape, num_particles)) + * total_charge.unsqueeze(-1) + / num_particles + ) + mu_x, mu_px, mu_y, mu_py = torch.broadcast_tensors(mu_x, mu_px, mu_y, mu_py) mean = torch.stack( [mu_x, mu_px, mu_y, mu_py, torch.zeros_like(mu_x), torch.zeros_like(mu_x)], dim=-1, @@ -667,15 +672,9 @@ def transformed_to( / self.particle_charges.shape[-1] ) + mu_x, mu_px, mu_y, mu_py = torch.broadcast_tensors(mu_x, mu_px, mu_y, mu_py) new_mu = torch.stack( - [ - mu_x, - mu_px, - mu_y, - mu_py, - torch.full_like(mu_x, 0.0), - torch.full_like(mu_x, 0.0), - ], + [mu_x, mu_px, mu_y, mu_py, torch.zeros_like(mu_x), torch.zeros_like(mu_x)], dim=-1, ) new_sigma = torch.stack( @@ -688,8 +687,8 @@ def transformed_to( self.mu_px, self.mu_y, self.mu_py, - torch.full_like(self.mu_x, 0.0), - torch.full_like(self.mu_x, 0.0), + torch.zeros_like(self.mu_x), + torch.zeros_like(self.mu_x), ], dim=-1, ) @@ -705,17 +704,16 @@ def transformed_to( dim=-1, ) - phase_space = self.particles[:, :, :6] + phase_space = self.particles[..., :6] phase_space = ( - phase_space - old_mu.expand(*phase_space.shape) - ) / old_sigma.expand(*phase_space.shape) * new_sigma.expand( - *phase_space.shape - ) + new_mu.expand( - *phase_space.shape - ) - - particles = torch.ones_like(self.particles) - particles[:, :, :6] = phase_space + (phase_space.transpose(-2, -1) - old_mu.unsqueeze(-1)) + / old_sigma.unsqueeze(-1) + * new_sigma.unsqueeze(-1) + + new_mu.unsqueeze(-1) + ).transpose(-2, -1) + + particles = torch.ones(*phase_space.shape[:-1], 7) + particles[..., :6] = phase_space return self.__class__( particles=particles, diff --git a/tests/test_particle_beam.py b/tests/test_particle_beam.py index e153ac71..ad257419 100644 --- a/tests/test_particle_beam.py +++ b/tests/test_particle_beam.py @@ -80,11 +80,6 @@ def test_transform_to(): assert np.isclose(transformed_beam.energy.cpu().numpy(), 1e7) assert np.isclose(transformed_beam.total_charge.cpu().numpy(), 1e-9) - with pytest.raises(NotImplementedError): - original_beam.transformed_to( - mu_x=torch.tensor(1e-5).expand([3, 2]), - ) - def test_from_twiss_to_twiss(): """ @@ -111,7 +106,7 @@ def test_from_twiss_to_twiss(): assert np.isclose(beam.energy.cpu().numpy(), 6e6) -def test_generate_uniform_ellipsoid_batched(): +def test_generate_uniform_ellipsoid_vectorized(): """ Test that a `ParticleBeam` generated from a uniform 3D ellipsoid has the correct parameters, i.e. the all particles are within the ellipsoid, and that the other diff --git a/tests/test_quadrupole.py b/tests/test_quadrupole.py index ae85b4b4..b08fe2c2 100644 --- a/tests/test_quadrupole.py +++ b/tests/test_quadrupole.py @@ -226,5 +226,5 @@ def test_tracking_method_vectorization(tracking_method): assert outgoing.sigma_py.shape == (3, 2) assert outgoing.sigma_tau.shape == (3, 2) assert outgoing.sigma_p.shape == (3, 2) - assert outgoing.energy.shape == (3, 2) - assert outgoing.total_charge.shape == (3, 2) + assert outgoing.energy.shape == torch.Size([]) + assert outgoing.total_charge.shape == torch.Size([]) diff --git a/tests/test_screen.py b/tests/test_screen.py index c9cdfbc5..0a1a43c3 100644 --- a/tests/test_screen.py +++ b/tests/test_screen.py @@ -27,13 +27,13 @@ def test_reading_shows_beam_particle(screen_method): beam = cheetah.ParticleBeam.from_astra("tests/resources/ACHIP_EA1_2021.1351.001") assert isinstance(segment.my_screen.reading, torch.Tensor) - assert segment.my_screen.reading.shape == (1, 100, 100) + assert segment.my_screen.reading.shape == (100, 100) assert np.allclose(segment.my_screen.reading, 0.0) _ = segment.track(beam) assert isinstance(segment.my_screen.reading, torch.Tensor) - assert segment.my_screen.reading.shape == (1, 100, 100) + assert segment.my_screen.reading.shape == (100, 100) assert torch.all(segment.my_screen.reading >= 0.0) assert torch.any(segment.my_screen.reading > 0.0) @@ -58,13 +58,13 @@ def test_screen_kde_bandwidth(kde_bandwidth): beam = cheetah.ParticleBeam.from_astra("tests/resources/ACHIP_EA1_2021.1351.001") assert isinstance(segment.my_screen.reading, torch.Tensor) - assert segment.my_screen.reading.shape == (1, 100, 100) + assert segment.my_screen.reading.shape == (100, 100) assert np.allclose(segment.my_screen.reading, 0.0) _ = segment.track(beam) assert isinstance(segment.my_screen.reading, torch.Tensor) - assert segment.my_screen.reading.shape == (1, 100, 100) + assert segment.my_screen.reading.shape == (100, 100) assert torch.all(segment.my_screen.reading >= 0.0) assert torch.any(segment.my_screen.reading > 0.0) @@ -90,13 +90,13 @@ def test_reading_shows_beam_parameter(screen_method): beam = cheetah.ParameterBeam.from_astra("tests/resources/ACHIP_EA1_2021.1351.001") assert isinstance(segment.my_screen.reading, torch.Tensor) - assert segment.my_screen.reading.shape == (1, 100, 100) + assert segment.my_screen.reading.shape == (100, 100) assert np.allclose(segment.my_screen.reading, 0.0) _ = segment.track(beam) assert isinstance(segment.my_screen.reading, torch.Tensor) - assert segment.my_screen.reading.shape == (1, 100, 100) + assert segment.my_screen.reading.shape == (100, 100) assert torch.all(segment.my_screen.reading >= 0.0) assert torch.any(segment.my_screen.reading > 0.0) @@ -125,12 +125,12 @@ def test_reading_shows_beam_ares(screen_method): segment.AREABSCR1.is_active = True assert isinstance(segment.AREABSCR1.reading, torch.Tensor) - assert segment.AREABSCR1.reading.shape == (1, 2040, 2448) + assert segment.AREABSCR1.reading.shape == (2040, 2448) assert np.allclose(segment.AREABSCR1.reading, 0.0) _ = segment.track(beam) assert isinstance(segment.AREABSCR1.reading, torch.Tensor) - assert segment.AREABSCR1.reading.shape == (1, 2040, 2448) + assert segment.AREABSCR1.reading.shape == (2040, 2448) assert torch.all(segment.AREABSCR1.reading >= 0.0) assert torch.any(segment.AREABSCR1.reading > 0.0) diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index cae46765..1ff1fc54 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -65,10 +65,10 @@ def test_track_quadrupole_shape(BeamClass): assert outgoing.sigma_py.shape == (2,) assert outgoing.sigma_tau.shape == (2,) assert outgoing.sigma_p.shape == (2,) - assert outgoing.energy.shape == (2,) - assert outgoing.total_charge.shape == (2,) + assert outgoing.energy.shape == torch.Size([]) + assert outgoing.total_charge.shape == torch.Size([]) if BeamClass == cheetah.ParticleBeam: - assert outgoing.particle_charges.shape == (2, 100_000) + assert outgoing.particle_charges.shape == (100_000,) @pytest.mark.parametrize("BeamClass", [cheetah.ParticleBeam, cheetah.ParameterBeam]) @@ -99,10 +99,10 @@ def test_track_quadrupole_shape_2d(BeamClass): assert outgoing.sigma_py.shape == (3, 2) assert outgoing.sigma_tau.shape == (3, 2) assert outgoing.sigma_p.shape == (3, 2) - assert outgoing.energy.shape == (3, 2) - assert outgoing.total_charge.shape == (3, 2) + assert outgoing.energy.shape == torch.Size([]) + assert outgoing.total_charge.shape == torch.Size([]) if BeamClass == cheetah.ParticleBeam: - assert outgoing.particle_charges.shape == (3, 2, 100_000) + assert outgoing.particle_charges.shape == (100_000,) @pytest.mark.parametrize("BeamClass", [cheetah.ParticleBeam, cheetah.ParameterBeam]) @@ -135,10 +135,10 @@ def test_track_segment_shape(BeamClass): assert outgoing.sigma_py.shape == (2,) assert outgoing.sigma_tau.shape == (2,) assert outgoing.sigma_p.shape == (2,) - assert outgoing.energy.shape == (2,) - assert outgoing.total_charge.shape == (2,) + assert outgoing.energy.shape == torch.Size([]) + assert outgoing.total_charge.shape == torch.Size([]) if BeamClass == cheetah.ParticleBeam: - assert outgoing.particle_charges.shape == (2, 100_000) + assert outgoing.particle_charges.shape == (100_000,) @pytest.mark.parametrize("BeamClass", [cheetah.ParticleBeam, cheetah.ParameterBeam]) @@ -175,13 +175,13 @@ def test_track_particle_segment_shape_2d(BeamClass): assert outgoing.sigma_py.shape == (3, 2) assert outgoing.sigma_tau.shape == (3, 2) assert outgoing.sigma_p.shape == (3, 2) - assert outgoing.energy.shape == (3, 2) - assert outgoing.total_charge.shape == (3, 2) + assert outgoing.energy.shape == torch.Size([]) + assert outgoing.total_charge.shape == torch.Size([]) if BeamClass == cheetah.ParticleBeam: - assert outgoing.particle_charges.shape == (3, 2, 100_000) + assert outgoing.particle_charges.shape == (100_000,) -def test_enormous_through_ares(): +def test_enormous_through_ares_ea(): """ Test ARES EA with a huge number of settings. This is a stress test and only run for `ParameterBeam` because `ParticleBeam` would require a lot of memory. @@ -205,8 +205,8 @@ def test_enormous_through_ares(): assert outgoing.sigma_py.shape == (3, 200_000) assert outgoing.sigma_tau.shape == (3, 200_000) assert outgoing.sigma_p.shape == (3, 200_000) - assert outgoing.energy.shape == (3, 200_000) - assert outgoing.total_charge.shape == (3, 200_000) + assert outgoing.energy.shape == torch.Size([]) + assert outgoing.total_charge.shape == torch.Size([]) @pytest.mark.parametrize("BeamClass", [cheetah.ParticleBeam, cheetah.ParameterBeam]) @@ -239,9 +239,9 @@ def test_cavity_with_zero_and_non_zero_voltage(BeamClass): assert outgoing.sigma_tau.shape == (3,) assert outgoing.sigma_p.shape == (3,) assert outgoing.energy.shape == (3,) - assert outgoing.total_charge.shape == (3,) + assert outgoing.total_charge.shape == torch.Size([]) if BeamClass == cheetah.ParticleBeam: - assert outgoing.particle_charges.shape == (3, 100_000) + assert outgoing.particle_charges.shape == (100_000,) @pytest.mark.parametrize("BeamClass", [cheetah.ParticleBeam, cheetah.ParameterBeam]) @@ -264,10 +264,10 @@ def test_vectorized_undulator(BeamClass): assert outgoing.sigma_py.shape == (2,) assert outgoing.sigma_tau.shape == (2,) assert outgoing.sigma_p.shape == (2,) - assert outgoing.energy.shape == (2,) - assert outgoing.total_charge.shape == (2,) + assert outgoing.energy.shape == torch.Size([]) + assert outgoing.total_charge.shape == torch.Size([]) if BeamClass == cheetah.ParticleBeam: - assert outgoing.particle_charges.shape == (2, 100_000) + assert outgoing.particle_charges.shape == (100_000,) @pytest.mark.parametrize("BeamClass", [cheetah.ParticleBeam, cheetah.ParameterBeam]) @@ -292,7 +292,7 @@ def test_vectorized_solenoid(BeamClass): assert outgoing.sigma_py.shape == (2,) assert outgoing.sigma_tau.shape == (2,) assert outgoing.sigma_p.shape == (2,) - assert outgoing.energy.shape == (2,) - assert outgoing.total_charge.shape == (2,) + assert outgoing.energy.shape == torch.Size([]) + assert outgoing.total_charge.shape == torch.Size([]) if BeamClass == cheetah.ParticleBeam: - assert outgoing.particle_charges.shape == (2, 100_000) + assert outgoing.particle_charges.shape == (100_000,) From 8edf7e4f9aeabde6d96255cfb0981e6dfba58a04 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 2 Sep 2024 16:14:33 +0200 Subject: [PATCH 045/111] Fix format --- tests/test_particle_beam.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_particle_beam.py b/tests/test_particle_beam.py index ad257419..3d3f448f 100644 --- a/tests/test_particle_beam.py +++ b/tests/test_particle_beam.py @@ -1,5 +1,4 @@ import numpy as np -import pytest import torch from cheetah import ParticleBeam From 35060cd274452b4c4e2c84a25d25a4fcb410a6c3 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 2 Sep 2024 16:27:33 +0200 Subject: [PATCH 046/111] Add more tests for Screen --- tests/test_vectorized.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index 1ff1fc54..7240fe07 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -296,3 +296,30 @@ def test_vectorized_solenoid(BeamClass): assert outgoing.total_charge.shape == torch.Size([]) if BeamClass == cheetah.ParticleBeam: assert outgoing.particle_charges.shape == (100_000,) + + +@pytest.mark.parametrize("BeamClass", [cheetah.ParticleBeam, cheetah.ParameterBeam]) +@pytest.mark.parametrize("method", ["histogram", "kde"]) +def test_vectorized_screen_2d(BeamClass, method): + """ + Test that a vectorized `Screen` is able to track a particle beam and produce a + reading with 2D vector dimensions. + """ + element = cheetah.Screen( + resolution=torch.tensor([100, 100]), + pixel_size=torch.tensor([1e-5, 1e-5]), + misalignment=torch.tensor([[1e-4, 2e-4, 3e-4], [4e-4, 5e-4, 6e-4]]), + is_active=True, + method=method, + name="my_screen", + ) + incoming = BeamClass.from_parameters(sigma_x=torch.tensor(1e-5)) + + _ = element.track(incoming) + + # Check some properties of the read beam + assert element._read_beam.mu_x.shape == (2, 3) + assert element._read_beam.sigma_x.shape == (2, 3) + + # Check the reading + assert element.reading.shape == (2, 3, 100, 100) From 914d9b1353f16fc7b0e5ee70db6caba0eb29c5ab Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 2 Sep 2024 16:50:48 +0200 Subject: [PATCH 047/111] A few initial fixes to the vectorisation of Screens --- cheetah/accelerator/screen.py | 10 ++++++++-- tests/test_vectorized.py | 9 +++++++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/cheetah/accelerator/screen.py b/cheetah/accelerator/screen.py index 33139290..f41310bf 100644 --- a/cheetah/accelerator/screen.py +++ b/cheetah/accelerator/screen.py @@ -162,6 +162,9 @@ def track(self, incoming: Beam) -> Beam: copy_of_incoming = deepcopy(incoming) if isinstance(incoming, ParameterBeam): + copy_of_incoming._mu = torch.broadcast_to( + copy_of_incoming._mu, (*self.misalignment.shape[:-1], 7) + ).clone() copy_of_incoming._mu[..., 0] -= self.misalignment[..., 0] copy_of_incoming._mu[..., 2] -= self.misalignment[..., 1] elif isinstance(incoming, ParticleBeam): @@ -203,6 +206,9 @@ def reading(self) -> torch.Tensor: ], dim=-1, ) + transverse_mu, transverse_cov = torch.broadcast_tensors( + transverse_mu, transverse_cov + ) dist = [ MultivariateNormal( loc=transverse_mu_sample, covariance_matrix=transverse_cov_sample @@ -269,13 +275,13 @@ def get_read_beam(self) -> Beam: # Using these get and set methods instead of Python's property decorator to # prevent `nn.Module` from intercepting the read beam, which is itself an # `nn.Module`, and registering it as a submodule of the screen. - return self._read_beam[0] if self._read_beam is not None else None + return self._read_beam if self._read_beam is not None else None def set_read_beam(self, value: Beam) -> None: # Using these get and set methods instead of Python's property decorator to # prevent `nn.Module` from intercepting the read beam, which is itself an # `nn.Module`, and registering it as a submodule of the screen. - self._read_beam = [value] + self._read_beam = value self.cached_reading = None def split(self, resolution: torch.Tensor) -> list[Element]: diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index 7240fe07..f118dce2 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -308,7 +308,12 @@ def test_vectorized_screen_2d(BeamClass, method): element = cheetah.Screen( resolution=torch.tensor([100, 100]), pixel_size=torch.tensor([1e-5, 1e-5]), - misalignment=torch.tensor([[1e-4, 2e-4, 3e-4], [4e-4, 5e-4, 6e-4]]), + misalignment=torch.tensor( + [ + [[1e-4, 2e-4], [3e-4, 4e-4], [5e-4, 6e-4]], + [[-1e-4, -2e-4], [-3e-4, -4e-4], [-5e-4, -6e-4]], + ] + ), is_active=True, method=method, name="my_screen", @@ -319,7 +324,7 @@ def test_vectorized_screen_2d(BeamClass, method): # Check some properties of the read beam assert element._read_beam.mu_x.shape == (2, 3) - assert element._read_beam.sigma_x.shape == (2, 3) + assert element._read_beam.sigma_x.shape == torch.Size([]) # Check the reading assert element.reading.shape == (2, 3, 100, 100) From 24cc2780b675caf2b51a1807105cb2a190509683 Mon Sep 17 00:00:00 2001 From: Ryan Roussel Date: Fri, 6 Sep 2024 09:49:59 -0500 Subject: [PATCH 048/111] fixes to functionality and tests for screen batching - removes misalignment functionality in screen object (should be re-introduced in a future PR) - allows beam to pass through screen for downstream tracking (useful functionality for beamline design, preserves consistency with other beam dynamics codes) - raises NotImplemented Error for batching calculations using normal histogramming (couldn't figure out an efficient batching method that works for arbitrary batch dims) - refactors from_astra particle imports to remove unnecessary leading batch dimension --- cheetah/accelerator/dipole.py | 2 +- cheetah/accelerator/drift.py | 2 +- cheetah/accelerator/quadrupole.py | 4 + cheetah/accelerator/screen.py | 119 +++++++++++------------------ cheetah/particles/particle_beam.py | 6 +- cheetah/utils/bmadx.py | 5 +- cheetah/utils/kde.py | 10 ++- tests/test_compare_ocelot.py | 22 +++--- tests/test_kde.py | 56 ++++++++++++++ tests/test_quadrupole.py | 4 +- tests/test_screen.py | 44 +++++++---- tests/test_space_charge_kick.py | 2 +- tests/test_vectorized.py | 36 ++++----- 13 files changed, 178 insertions(+), 134 deletions(-) diff --git a/cheetah/accelerator/dipole.py b/cheetah/accelerator/dipole.py index c9b5a91a..fe58c71a 100644 --- a/cheetah/accelerator/dipole.py +++ b/cheetah/accelerator/dipole.py @@ -7,7 +7,7 @@ from scipy.constants import physical_constants from torch import nn -from .. import Beam, ParticleBeam +from ..particles import Beam, ParticleBeam from ..track_methods import base_rmatrix, rotation_matrix from ..utils import UniqueNameGenerator, bmadx from .element import Element diff --git a/cheetah/accelerator/drift.py b/cheetah/accelerator/drift.py index 53bc0a90..bee747a5 100644 --- a/cheetah/accelerator/drift.py +++ b/cheetah/accelerator/drift.py @@ -5,7 +5,7 @@ from scipy.constants import physical_constants from torch import Size, nn -from .. import Beam, ParticleBeam +from ..particles import Beam, ParticleBeam from ..utils import UniqueNameGenerator, bmadx from ..utils.physics import calculate_relativistic_factors from .element import Element diff --git a/cheetah/accelerator/quadrupole.py b/cheetah/accelerator/quadrupole.py index bd990e3d..36fce7fc 100644 --- a/cheetah/accelerator/quadrupole.py +++ b/cheetah/accelerator/quadrupole.py @@ -174,8 +174,12 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: x, px, y, py = bmadx.offset_particle_unset( x_offset, y_offset, self.tilt, x, px, y, py ) + + # p_z is unaffected by tracking, need to match batch dimensions + pz = pz * torch.ones_like(x) # End of Bmad-X tracking + # Convert back to Cheetah coordinates tau, delta, ref_energy = bmadx.bmad_to_cheetah_z_pz( z, pz, p0c, electron_mass_eV diff --git a/cheetah/accelerator/screen.py b/cheetah/accelerator/screen.py index f41310bf..93a436d3 100644 --- a/cheetah/accelerator/screen.py +++ b/cheetah/accelerator/screen.py @@ -37,17 +37,17 @@ class Screen(Element): """ def __init__( - self, - resolution: Optional[Union[torch.Tensor, nn.Parameter]] = None, - pixel_size: Optional[Union[torch.Tensor, nn.Parameter]] = None, - binning: Optional[Union[torch.Tensor, nn.Parameter]] = None, - misalignment: Optional[Union[torch.Tensor, nn.Parameter]] = None, - kde_bandwidth: Optional[Union[torch.Tensor, nn.Parameter]] = None, - is_active: bool = False, - method: Literal["histogram", "kde"] = "histogram", - name: Optional[str] = None, - device=None, - dtype=torch.float32, + self, + resolution: Optional[Union[torch.Tensor, nn.Parameter]] = None, + pixel_size: Optional[Union[torch.Tensor, nn.Parameter]] = None, + binning: Optional[Union[torch.Tensor, nn.Parameter]] = None, + misalignment: Optional[Union[torch.Tensor, nn.Parameter]] = None, + kde_bandwidth: Optional[Union[torch.Tensor, nn.Parameter]] = None, + is_active: bool = False, + method: Literal["histogram", "kde"] = "histogram", + name: Optional[str] = None, + device=None, + dtype=torch.float32, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) @@ -76,14 +76,6 @@ def __init__( else torch.tensor(1, **factory_kwargs) ), ) - self.register_buffer( - "misalignment", - ( - torch.as_tensor(misalignment, **factory_kwargs) - if misalignment is not None - else torch.tensor((0.0, 0.0), **factory_kwargs) - ), - ) self.register_buffer( "length", torch.zeros(self.misalignment.shape[:-1], **factory_kwargs), @@ -161,37 +153,25 @@ def track(self, incoming: Beam) -> Beam: if self.is_active: copy_of_incoming = deepcopy(incoming) - if isinstance(incoming, ParameterBeam): - copy_of_incoming._mu = torch.broadcast_to( - copy_of_incoming._mu, (*self.misalignment.shape[:-1], 7) - ).clone() - copy_of_incoming._mu[..., 0] -= self.misalignment[..., 0] - copy_of_incoming._mu[..., 2] -= self.misalignment[..., 1] - elif isinstance(incoming, ParticleBeam): - copy_of_incoming.particles[..., :, 0] -= self.misalignment[..., 0] - copy_of_incoming.particles[..., :, 1] -= self.misalignment[..., 1] - self.set_read_beam(copy_of_incoming) - return Beam.empty - else: - return incoming + return incoming @property def reading(self) -> torch.Tensor: + image = None if self.cached_reading is not None: return self.cached_reading read_beam = self.get_read_beam() if read_beam is Beam.empty or read_beam is None: - image = torch.zeros( - ( - *self.misalignment.shape[:-1], - int(self.effective_resolution[1]), - int(self.effective_resolution[0]), - ) - ) + return torch.tensor([]) elif isinstance(read_beam, ParameterBeam): + if torch.numel(read_beam._mu[..., 0]) > 1: + raise NotImplementedError( + "cannot perform batch screen predictions with ParameterBeam" + ) + transverse_mu = torch.stack( [read_beam._mu[..., 0], read_beam._mu[..., 2]], dim=-1 ) @@ -206,17 +186,10 @@ def reading(self) -> torch.Tensor: ], dim=-1, ) - transverse_mu, transverse_cov = torch.broadcast_tensors( - transverse_mu, transverse_cov + dist = MultivariateNormal( + loc=transverse_mu, + covariance_matrix=transverse_cov ) - dist = [ - MultivariateNormal( - loc=transverse_mu_sample, covariance_matrix=transverse_cov_sample - ) - for transverse_mu_sample, transverse_cov_sample in zip( - transverse_mu.cpu(), transverse_cov.cpu() - ) - ] left = self.extent[0] right = self.extent[1] @@ -230,29 +203,27 @@ def reading(self) -> torch.Tensor: indexing="ij", ) pos = torch.dstack((x, y)) - image = torch.stack( - [dist_sample.log_prob(pos).exp() for dist_sample in dist] - ) + image = dist.log_prob(pos).exp() image = torch.flip(image, dims=[1]) + elif isinstance(read_beam, ParticleBeam): if self.method == "histogram": - image = torch.zeros( - ( - *self.misalignment.shape[:-1], - int(self.effective_resolution[1]), - int(self.effective_resolution[0]), + if len(read_beam.x.shape) > 1 or len(read_beam.y.shape) > 1: + raise NotImplementedError( + "Currently cannot handle x/y particle " + "batching using `histogram`. Use `kde` instead." ) + + image, _ = torch.histogramdd( + torch.stack(( + read_beam.x, + read_beam.y + )).T, + bins=self.pixel_bin_edges ) - for i, (x_sample, y_sample) in enumerate(zip(read_beam.x, read_beam.y)): - image_sample, _ = torch.histogramdd( - torch.stack((x_sample, y_sample)).T.cpu(), - bins=self.pixel_bin_edges, - ) - image_sample = torch.flipud(image_sample.T) - image_sample = image_sample.cpu() + image = torch.flipud(image.T) - image[i] = image_sample elif self.method == "kde": image = kde_histogram_2d( x1=read_beam.x, @@ -263,7 +234,7 @@ def reading(self) -> torch.Tensor: ) # Change the x, y positions image = torch.transpose(image, -2, -1) - # Flip up an down, now row 0 corresponds to the top + # Flip up and down, now row 0 corresponds to the top image = torch.flip(image, dims=[-2]) else: raise TypeError(f"Read beam is of invalid type {type(read_beam)}") @@ -308,12 +279,12 @@ def defining_features(self) -> list[str]: def __repr__(self) -> str: return ( - f"{self.__class__.__name__}(resolution={repr(self.resolution)}, " - + f"pixel_size={repr(self.pixel_size)}, " - + f"binning={repr(self.binning)}, " - + f"misalignment={repr(self.misalignment)}, " - + f"method={repr(self.method)}, " - + f"kde_bandwidth={repr(self.kde_bandwidth)}, " - + f"is_active={repr(self.is_active)}, " - + f"name={repr(self.name)})" + f"{self.__class__.__name__}(resolution={repr(self.resolution)}, " + + f"pixel_size={repr(self.pixel_size)}, " + + f"binning={repr(self.binning)}, " + + f"misalignment={repr(self.misalignment)}, " + + f"method={repr(self.method)}, " + + f"kde_bandwidth={repr(self.kde_bandwidth)}, " + + f"is_active={repr(self.is_active)}, " + + f"name={repr(self.name)})" ) diff --git a/cheetah/particles/particle_beam.py b/cheetah/particles/particle_beam.py index a6c0d790..10b1cd14 100644 --- a/cheetah/particles/particle_beam.py +++ b/cheetah/particles/particle_beam.py @@ -598,9 +598,9 @@ def from_astra(cls, path: str, device=None, dtype=torch.float32) -> "ParticleBea particles_7d[:, :6] = torch.from_numpy(particles) particle_charges = torch.from_numpy(particle_charges) return cls( - particles=particles_7d.unsqueeze(0), - energy=torch.tensor(energy).unsqueeze(0), - particle_charges=particle_charges.unsqueeze(0), + particles=particles_7d, + energy=torch.tensor(energy), + particle_charges=particle_charges, device=device, dtype=dtype, ) diff --git a/cheetah/utils/bmadx.py b/cheetah/utils/bmadx.py index 7508f0e5..65d2e4b8 100644 --- a/cheetah/utils/bmadx.py +++ b/cheetah/utils/bmadx.py @@ -1,5 +1,6 @@ import torch from scipy.constants import speed_of_light +from torch import Tensor double_precision_epsilon = torch.finfo(torch.float64).eps @@ -30,8 +31,8 @@ def cheetah_to_bmad_z_pz( def bmad_to_cheetah_z_pz( - z: torch.Tensor, pz: torch.Tensor, p0c: torch.Tensor, mc2: float -) -> torch.Tensor: + z: Tensor, pz: Tensor, p0c: Tensor, mc2: float +) -> (Tensor, Tensor, Tensor): """ Transforms Bmad longitudinal coordinates to Cheetah coordinates and computes reference energy. diff --git a/cheetah/utils/kde.py b/cheetah/utils/kde.py index 1aba41f2..834beeb5 100644 --- a/cheetah/utils/kde.py +++ b/cheetah/utils/kde.py @@ -15,7 +15,7 @@ def _kde_marginal_pdf( Calculate the 1D marginal probability distribution function of the input tensor based on the number of histogram bins. - :param values: Input tensor with shape :math:`(B, N, 1)`. `B` is the batch shape. + :param values: Input tensor with shape :math:`(B, N)`. `B` is the batch shape. :param bins: Positions of the bins where KDE is calculated. Shape :math:`(N_{bins})`. :param sigma: Gaussian smoothing factor with shape `(1,)`. @@ -46,6 +46,8 @@ def _kde_marginal_pdf( if not sigma.dim() == 0: raise ValueError(f"Input sigma must be a of the shape (1,). Got {sigma.shape}") + values = values.unsqueeze(-1) + if weights is None: weights = torch.ones_like(values) else: @@ -139,7 +141,7 @@ def kde_histogram_1d( """ pdf, _ = _kde_marginal_pdf( - values=x.unsqueeze(-1), + values=x, bins=bins, sigma=bandwidth, weights=weights, @@ -184,13 +186,13 @@ def kde_histogram_2d( """ _, kernel_values1 = _kde_marginal_pdf( - values=x1.unsqueeze(-1), + values=x1, bins=bins1, sigma=bandwidth, weights=weights, ) _, kernel_values2 = _kde_marginal_pdf( - values=x2.unsqueeze(-1), + values=x2, bins=bins2, sigma=bandwidth, weights=None, diff --git a/tests/test_compare_ocelot.py b/tests/test_compare_ocelot.py index d9586e6b..6665cb8d 100644 --- a/tests/test_compare_ocelot.py +++ b/tests/test_compare_ocelot.py @@ -31,7 +31,7 @@ def test_dipole(): _, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator) assert np.allclose( - outgoing_beam.particles[0, :, :6].cpu().numpy(), + outgoing_beam.particles[:, :6].cpu().numpy(), outgoing_p_array.rparticles.transpose(), ) @@ -60,7 +60,7 @@ def test_dipole_with_float64(): _, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator) assert np.allclose( - outgoing_beam.particles[0, :, :6].cpu().numpy(), + outgoing_beam.particles[:, :6].cpu().numpy(), outgoing_p_array.rparticles.transpose(), ) @@ -92,7 +92,7 @@ def test_dipole_with_fringe_field(): _, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator) assert np.allclose( - outgoing_beam.particles[0, :, :6].cpu().numpy(), + outgoing_beam.particles[:, :6].cpu().numpy(), outgoing_p_array.rparticles.transpose(), ) @@ -138,7 +138,7 @@ def test_dipole_with_fringe_field_and_tilt(): _, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator) assert np.allclose( - outgoing_beam.particles[0, :, :6].cpu().numpy(), + outgoing_beam.particles[:, :6].cpu().numpy(), outgoing_p_array.rparticles.transpose(), ) @@ -378,7 +378,7 @@ def test_astra_import(): p_array = ocelot.astraBeam2particleArray("tests/resources/ACHIP_EA1_2021.1351.001") assert np.allclose( - beam.particles[0, :, :6].cpu().numpy(), p_array.rparticles.transpose() + beam.particles[:, :6].cpu().numpy(), p_array.rparticles.transpose() ) assert np.isclose(beam.energy.cpu().numpy(), (p_array.E * 1e9)) @@ -417,7 +417,7 @@ def test_quadrupole(): # Split in order to allow for different tolerances for each particle dimension assert np.allclose( - outgoing_beam.particles[0, :, :6].cpu().numpy(), + outgoing_beam.particles[:, :6].cpu().numpy(), outgoing_p_array.rparticles.transpose(), ) assert np.allclose( @@ -458,7 +458,7 @@ def test_tilted_quadrupole(): _, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator) assert np.allclose( - outgoing_beam.particles[0, :, :6].cpu().numpy(), + outgoing_beam.particles[:, :6].cpu().numpy(), outgoing_p_array.rparticles.transpose(), ) assert np.allclose( @@ -499,7 +499,7 @@ def test_sbend(): ) assert np.allclose( - outgoing_beam.particles[0, :, :6].cpu().numpy(), + outgoing_beam.particles[:, :6].cpu().numpy(), outgoing_p_array.rparticles.transpose(), ) assert np.allclose( @@ -545,7 +545,7 @@ def test_rbend(): ) assert np.allclose( - outgoing_beam.particles[0, :, :6].cpu().numpy(), + outgoing_beam.particles[:, :6].cpu().numpy(), outgoing_p_array.rparticles.transpose(), ) assert np.allclose( @@ -583,7 +583,7 @@ def test_convert_rbend(): outgoing_beam = cheetah_segment.track(incoming_beam) assert np.allclose( - outgoing_beam.particles[0, :, :6].cpu().numpy(), + outgoing_beam.particles[:, :6].cpu().numpy(), outgoing_p_array.rparticles.transpose(), ) assert np.allclose( @@ -620,7 +620,7 @@ def test_asymmetric_bend(): outgoing_beam = cheetah_segment.track(incoming_beam) assert np.allclose( - outgoing_beam.particles[0, :, :6].cpu().numpy(), + outgoing_beam.particles[:, :6].cpu().numpy(), outgoing_p_array.rparticles.transpose(), ) assert np.allclose( diff --git a/tests/test_kde.py b/tests/test_kde.py index f3110663..caa1cf34 100644 --- a/tests/test_kde.py +++ b/tests/test_kde.py @@ -1,6 +1,9 @@ +import pytest import torch +from torch import Size from cheetah.utils import kde_histogram_1d, kde_histogram_2d +from cheetah.utils.kde import _kde_marginal_pdf def test_weighted_samples_1d(): @@ -27,6 +30,59 @@ def test_weighted_samples_1d(): assert not torch.allclose(hist_weighted, hist_neglect_weights) +def test_kde_1d(): + # test basic usage + data = torch.randn(100) # 5 beamline states, 100 particles in 1D + bins = torch.linspace(0, 1, 10) # a single histogram + sigma = torch.tensor(0.1) # a single bandwidth + + pdf = kde_histogram_1d(data, bins, sigma) + + assert pdf.shape == Size([10]) # 5 histograms at 10 points + + # test bad bins + with pytest.raises(ValueError): + _kde_marginal_pdf(data, bins, torch.rand(3) + 0.1) + + +def test_kde_1d_batched(): + # test basic usage + data = torch.randn((5, 100)) # 5 beamline states, 100 particles in 1D + bins = torch.linspace(0, 1, 10) # a single histogram + sigma = torch.tensor(0.1) # a single bandwidth + + pdf = kde_histogram_1d(data, bins, sigma) + + assert pdf.shape == Size([5, 10]) # 5 histograms at 10 points + + # test bad bins + with pytest.raises(ValueError): + _kde_marginal_pdf(data, bins, torch.rand(3) + 0.1) + + +def test_kde_2d_batched(): + data = torch.randn((3, 2, 100, 6)) + # 2 diagnostic paths, + # 3 states per diagnostic paths, + # 100 particles in 6D space + + # two different bins (1 per path) + n = 30 + bins_x = torch.linspace(-20, 20, n) + + sigma = torch.tensor(0.1) # a single bandwidth + + pdf = kde_histogram_2d( + data[..., 0], + data[..., 1], + bins_x, + bins_x, + sigma + ) + + assert pdf.shape == Size([3, 2, n, n]) + + def test_weighted_samples_2d(): """ Test that the 2d KDE histogram implementation correctly handles diff --git a/tests/test_quadrupole.py b/tests/test_quadrupole.py index 8cb0801b..cc0d66ab 100644 --- a/tests/test_quadrupole.py +++ b/tests/test_quadrupole.py @@ -198,8 +198,8 @@ def test_quadrupole_bmadx_tracking(dtype): assert torch.allclose( outgoing.particles, outgoing_bmadx.to(dtype), - atol=1e-14 if dtype == torch.float64 else 0.00001, - rtol=1e-14 if dtype == torch.float64 else 1e-6, + atol=1e-7 if dtype == torch.float64 else 0.00001, + rtol=1e-7 if dtype == torch.float64 else 1e-6, ) diff --git a/tests/test_screen.py b/tests/test_screen.py index 0a1a43c3..cf6153e3 100644 --- a/tests/test_screen.py +++ b/tests/test_screen.py @@ -1,6 +1,7 @@ import numpy as np import pytest import torch +from torch import Size import cheetah @@ -26,9 +27,8 @@ def test_reading_shows_beam_particle(screen_method): ) beam = cheetah.ParticleBeam.from_astra("tests/resources/ACHIP_EA1_2021.1351.001") - assert isinstance(segment.my_screen.reading, torch.Tensor) - assert segment.my_screen.reading.shape == (100, 100) - assert np.allclose(segment.my_screen.reading, 0.0) + # before tracking the reading should be an empty tensor + assert torch.numel(segment.my_screen.reading) == 0 _ = segment.track(beam) @@ -57,10 +57,6 @@ def test_screen_kde_bandwidth(kde_bandwidth): ) beam = cheetah.ParticleBeam.from_astra("tests/resources/ACHIP_EA1_2021.1351.001") - assert isinstance(segment.my_screen.reading, torch.Tensor) - assert segment.my_screen.reading.shape == (100, 100) - assert np.allclose(segment.my_screen.reading, 0.0) - _ = segment.track(beam) assert isinstance(segment.my_screen.reading, torch.Tensor) @@ -90,9 +86,6 @@ def test_reading_shows_beam_parameter(screen_method): beam = cheetah.ParameterBeam.from_astra("tests/resources/ACHIP_EA1_2021.1351.001") assert isinstance(segment.my_screen.reading, torch.Tensor) - assert segment.my_screen.reading.shape == (100, 100) - assert np.allclose(segment.my_screen.reading, 0.0) - _ = segment.track(beam) assert isinstance(segment.my_screen.reading, torch.Tensor) @@ -101,7 +94,32 @@ def test_reading_shows_beam_parameter(screen_method): assert torch.any(segment.my_screen.reading > 0.0) -@pytest.mark.parametrize("screen_method", ["histogram", "kde"]) +def test_reading_shows_beam_parameter_batched(): + """ + Test that a screen has a reading that shows some sign of the beam having hit it. + """ + segment = cheetah.Segment( + elements=[ + cheetah.Drift(length=torch.tensor((1.0,0.5))), + cheetah.Screen( + resolution=torch.tensor((100, 100)), + pixel_size=torch.tensor((1e-5, 1e-5)), + is_active=True, + name="my_screen", + ), + ], + name="my_segment", + ) + beam = cheetah.ParameterBeam.from_astra("tests/resources/ACHIP_EA1_2021.1351.001") + + assert isinstance(segment.my_screen.reading, torch.Tensor) + _ = segment.track(beam) + + with pytest.raises(NotImplementedError): + segment.my_screen.reading + + +@pytest.mark.parametrize("screen_method", ["kde"]) def test_reading_shows_beam_ares(screen_method): """ Test that a screen has a reading that shows some sign of the beam having hit it. @@ -124,10 +142,6 @@ def test_reading_shows_beam_ares(screen_method): segment.AREABSCR1.binning = torch.tensor(1, device=segment.AREABSCR1.binning.device) segment.AREABSCR1.is_active = True - assert isinstance(segment.AREABSCR1.reading, torch.Tensor) - assert segment.AREABSCR1.reading.shape == (2040, 2448) - assert np.allclose(segment.AREABSCR1.reading, 0.0) - _ = segment.track(beam) assert isinstance(segment.AREABSCR1.reading, torch.Tensor) diff --git a/tests/test_space_charge_kick.py b/tests/test_space_charge_kick.py index 82130c93..c2c66681 100644 --- a/tests/test_space_charge_kick.py +++ b/tests/test_space_charge_kick.py @@ -220,7 +220,7 @@ def test_gradient(): sigma_py=torch.tensor(2e-7), ) - segment_length = nn.Parameter(torch.tensor([1.0])) + segment_length = nn.Parameter(torch.tensor(1.0)) segment = cheetah.Segment( elements=[ cheetah.Drift(segment_length / 6), diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index f118dce2..9bded8e4 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -298,33 +298,29 @@ def test_vectorized_solenoid(BeamClass): assert outgoing.particle_charges.shape == (100_000,) -@pytest.mark.parametrize("BeamClass", [cheetah.ParticleBeam, cheetah.ParameterBeam]) -@pytest.mark.parametrize("method", ["histogram", "kde"]) +@pytest.mark.parametrize("BeamClass", [cheetah.ParticleBeam]) +@pytest.mark.parametrize("method", ["kde"]) def test_vectorized_screen_2d(BeamClass, method): """ Test that a vectorized `Screen` is able to track a particle beam and produce a reading with 2D vector dimensions. """ - element = cheetah.Screen( - resolution=torch.tensor([100, 100]), - pixel_size=torch.tensor([1e-5, 1e-5]), - misalignment=torch.tensor( - [ - [[1e-4, 2e-4], [3e-4, 4e-4], [5e-4, 6e-4]], - [[-1e-4, -2e-4], [-3e-4, -4e-4], [-5e-4, -6e-4]], - ] - ), - is_active=True, - method=method, - name="my_screen", + segment = cheetah.Segment( + elements=[ + cheetah.Drift(length=torch.tensor((1.0, 0.5))), + cheetah.Screen( + resolution=torch.tensor((100, 100)), + pixel_size=torch.tensor((1e-5, 1e-5)), + is_active=True, + method=method, + name="my_screen", + ), + ], + name="my_segment", ) incoming = BeamClass.from_parameters(sigma_x=torch.tensor(1e-5)) - _ = element.track(incoming) - - # Check some properties of the read beam - assert element._read_beam.mu_x.shape == (2, 3) - assert element._read_beam.sigma_x.shape == torch.Size([]) + _ = segment.track(incoming) # Check the reading - assert element.reading.shape == (2, 3, 100, 100) + assert segment.my_screen.reading.shape == (2, 100, 100) From b3209719ee0f3dd3c5bb2d30939ece7ccbd471aa Mon Sep 17 00:00:00 2001 From: Ryan Roussel Date: Mon, 9 Sep 2024 10:50:44 -0500 Subject: [PATCH 049/111] fix misalignment broadcasting issue for particle beam --- cheetah/accelerator/screen.py | 28 +++++++++++++++++++++++++--- tests/test_vectorized.py | 10 ++++++++-- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/cheetah/accelerator/screen.py b/cheetah/accelerator/screen.py index 93a436d3..fb04bd4c 100644 --- a/cheetah/accelerator/screen.py +++ b/cheetah/accelerator/screen.py @@ -68,6 +68,14 @@ def __init__( else torch.tensor((1e-3, 1e-3), **factory_kwargs) ), ) + self.register_buffer( + "misalignment", + ( + torch.as_tensor(misalignment, **factory_kwargs) + if misalignment is not None + else torch.tensor((0.0, 0.0), **factory_kwargs) + ), + ) self.register_buffer( "binning", ( @@ -151,9 +159,23 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: def track(self, incoming: Beam) -> Beam: if self.is_active: - copy_of_incoming = deepcopy(incoming) - - self.set_read_beam(copy_of_incoming) + cin = deepcopy(incoming) + + if isinstance(incoming, ParameterBeam): + cin._mu = torch.broadcast_to( + cin._mu, (*self.misalignment.shape[:-1], 7) + ).clone() + cin._mu[..., 0] -= self.misalignment[..., 0] + cin._mu[..., 2] -= self.misalignment[..., 1] + elif isinstance(incoming, ParticleBeam): + cin.particles = cin.particles.broadcast_to( + self.misalignment[..., 0].shape + cin.particles.shape + ).clone() + + cin.particles[..., 0] -= self.misalignment[..., 0].unsqueeze(-1) + cin.particles[..., 1] -= self.misalignment[..., 1].unsqueeze(-1) + + self.set_read_beam(cin) return incoming diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index 9bded8e4..31ed41d4 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -307,10 +307,16 @@ def test_vectorized_screen_2d(BeamClass, method): """ segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor((1.0, 0.5))), + cheetah.Drift(length=torch.tensor(1.0)), cheetah.Screen( resolution=torch.tensor((100, 100)), pixel_size=torch.tensor((1e-5, 1e-5)), + misalignment=torch.tensor( + [ + [[1e-4, 2e-4], [3e-4, 4e-4], [5e-4, 6e-4]], + [[-1e-4, -2e-4], [-3e-4, -4e-4], [-5e-4, -6e-4]], + ] + ), is_active=True, method=method, name="my_screen", @@ -323,4 +329,4 @@ def test_vectorized_screen_2d(BeamClass, method): _ = segment.track(incoming) # Check the reading - assert segment.my_screen.reading.shape == (2, 100, 100) + assert segment.my_screen.reading.shape == (2,3, 100, 100) From 78d81bff716ebac45a255428818c5152a55331de Mon Sep 17 00:00:00 2001 From: Ryan Roussel Date: Mon, 9 Sep 2024 11:25:30 -0500 Subject: [PATCH 050/111] fix misalignment broadcasting issue with parameter beam expands the shape of mu tensor to make sure it is at least 2d (adds a batch dimension to the tensor if 1d) --- cheetah/accelerator/screen.py | 12 ++++++----- cheetah/particles/parameter_beam.py | 4 +++- tests/test_screen.py | 32 +++++++++++++++++++++++++++-- tests/test_vectorized.py | 2 +- 4 files changed, 41 insertions(+), 9 deletions(-) diff --git a/cheetah/accelerator/screen.py b/cheetah/accelerator/screen.py index fb04bd4c..07adfaa5 100644 --- a/cheetah/accelerator/screen.py +++ b/cheetah/accelerator/screen.py @@ -162,14 +162,16 @@ def track(self, incoming: Beam) -> Beam: cin = deepcopy(incoming) if isinstance(incoming, ParameterBeam): - cin._mu = torch.broadcast_to( - cin._mu, (*self.misalignment.shape[:-1], 7) + cin._mu = cin._mu.broadcast_to( + self.misalignment.shape[:-1] + cin._mu.shape ).clone() - cin._mu[..., 0] -= self.misalignment[..., 0] - cin._mu[..., 2] -= self.misalignment[..., 1] + + cin._mu[..., 0] -= self.misalignment[..., 0].unsqueeze(-1) + cin._mu[..., 2] -= self.misalignment[..., 1].unsqueeze(-1) + elif isinstance(incoming, ParticleBeam): cin.particles = cin.particles.broadcast_to( - self.misalignment[..., 0].shape + cin.particles.shape + self.misalignment.shape[:-1] + cin.particles.shape ).clone() cin.particles[..., 0] -= self.misalignment[..., 0].unsqueeze(-1) diff --git a/cheetah/particles/parameter_beam.py b/cheetah/particles/parameter_beam.py index f5189b15..4be65029 100644 --- a/cheetah/particles/parameter_beam.py +++ b/cheetah/particles/parameter_beam.py @@ -30,7 +30,9 @@ def __init__( factory_kwargs = {"device": device, "dtype": dtype} super().__init__() - self.register_buffer("_mu", torch.as_tensor(mu, **factory_kwargs)) + self.register_buffer( + "_mu", torch.atleast_2d(torch.as_tensor(mu, **factory_kwargs)) + ) self.register_buffer("_cov", torch.as_tensor(cov, **factory_kwargs)) total_charge = ( total_charge diff --git a/tests/test_screen.py b/tests/test_screen.py index cf6153e3..b1fd0b89 100644 --- a/tests/test_screen.py +++ b/tests/test_screen.py @@ -65,6 +65,34 @@ def test_screen_kde_bandwidth(kde_bandwidth): assert torch.any(segment.my_screen.reading > 0.0) +@pytest.mark.parametrize("BeamClass", [cheetah.ParticleBeam]) +@pytest.mark.parametrize("method", ["kde"]) +def test__screen_2d(BeamClass, method): + """ + Test that a vectorized `Screen` is able to track a particle beam and produce a + reading with 2D vector dimensions. + """ + segment = cheetah.Segment( + elements=[ + cheetah.Drift(length=torch.tensor(1.0)), + cheetah.Screen( + resolution=torch.tensor((100, 100)), + pixel_size=torch.tensor((1e-5, 1e-5)), + is_active=True, + method=method, + name="my_screen", + ), + ], + name="my_segment", + ) + incoming = BeamClass.from_parameters(sigma_x=torch.tensor(1e-5)) + + _ = segment.track(incoming) + + # Check the reading + assert segment.my_screen.reading.shape == (100, 100) + + @pytest.mark.parametrize("screen_method", ["histogram", "kde"]) def test_reading_shows_beam_parameter(screen_method): """ @@ -100,7 +128,7 @@ def test_reading_shows_beam_parameter_batched(): """ segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor((1.0,0.5))), + cheetah.Drift(length=torch.tensor((1.0, 0.5))), cheetah.Screen( resolution=torch.tensor((100, 100)), pixel_size=torch.tensor((1e-5, 1e-5)), @@ -145,6 +173,6 @@ def test_reading_shows_beam_ares(screen_method): _ = segment.track(beam) assert isinstance(segment.AREABSCR1.reading, torch.Tensor) - assert segment.AREABSCR1.reading.shape == (2040, 2448) + assert segment.AREABSCR1.reading.shape == (1, 2040, 2448) assert torch.all(segment.AREABSCR1.reading >= 0.0) assert torch.any(segment.AREABSCR1.reading > 0.0) diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index 31ed41d4..bd16939f 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -329,4 +329,4 @@ def test_vectorized_screen_2d(BeamClass, method): _ = segment.track(incoming) # Check the reading - assert segment.my_screen.reading.shape == (2,3, 100, 100) + assert segment.my_screen.reading.shape == (2, 3, 100, 100) From 06435a94a9c3550613641771dc74df755913c05b Mon Sep 17 00:00:00 2001 From: Ryan Roussel Date: Mon, 9 Sep 2024 11:39:13 -0500 Subject: [PATCH 051/111] updated error message --- cheetah/accelerator/screen.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cheetah/accelerator/screen.py b/cheetah/accelerator/screen.py index 07adfaa5..b373b80e 100644 --- a/cheetah/accelerator/screen.py +++ b/cheetah/accelerator/screen.py @@ -234,8 +234,8 @@ def reading(self) -> torch.Tensor: if self.method == "histogram": if len(read_beam.x.shape) > 1 or len(read_beam.y.shape) > 1: raise NotImplementedError( - "Currently cannot handle x/y particle " - "batching using `histogram`. Use `kde` instead." + "Torch histogram does not support " + "batching. Use `kde` option instead." ) image, _ = torch.histogramdd( From 31528322c128faa7584832b66eb1c983fc20a2cd Mon Sep 17 00:00:00 2001 From: Ryan Roussel Date: Mon, 9 Sep 2024 11:47:47 -0500 Subject: [PATCH 052/111] remove reading from test due to the additional batch dimension form ocelot --- tests/test_speed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_speed.py b/tests/test_speed.py index 917f9550..81b1703e 100644 --- a/tests/test_speed.py +++ b/tests/test_speed.py @@ -25,7 +25,7 @@ def test_tracking_speed(): t1 = time.time() _ = segment.track(particles) - _ = segment.AREABSCR1.reading + # _ = segment.AREABSCR1.reading t2 = time.time() From ad4ad1c55dad45e7acbcdb34767fbfe8203f2a63 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Fri, 20 Sep 2024 10:16:50 +0200 Subject: [PATCH 053/111] Some cleanup --- cheetah/accelerator/screen.py | 77 ++++++++++++++--------------- cheetah/particles/parameter_beam.py | 4 +- tests/test_screen.py | 32 +----------- tests/test_speed.py | 2 +- tests/test_vectorized.py | 2 +- 5 files changed, 42 insertions(+), 75 deletions(-) diff --git a/cheetah/accelerator/screen.py b/cheetah/accelerator/screen.py index b373b80e..853db586 100644 --- a/cheetah/accelerator/screen.py +++ b/cheetah/accelerator/screen.py @@ -37,17 +37,17 @@ class Screen(Element): """ def __init__( - self, - resolution: Optional[Union[torch.Tensor, nn.Parameter]] = None, - pixel_size: Optional[Union[torch.Tensor, nn.Parameter]] = None, - binning: Optional[Union[torch.Tensor, nn.Parameter]] = None, - misalignment: Optional[Union[torch.Tensor, nn.Parameter]] = None, - kde_bandwidth: Optional[Union[torch.Tensor, nn.Parameter]] = None, - is_active: bool = False, - method: Literal["histogram", "kde"] = "histogram", - name: Optional[str] = None, - device=None, - dtype=torch.float32, + self, + resolution: Optional[Union[torch.Tensor, nn.Parameter]] = None, + pixel_size: Optional[Union[torch.Tensor, nn.Parameter]] = None, + binning: Optional[Union[torch.Tensor, nn.Parameter]] = None, + misalignment: Optional[Union[torch.Tensor, nn.Parameter]] = None, + kde_bandwidth: Optional[Union[torch.Tensor, nn.Parameter]] = None, + is_active: bool = False, + method: Literal["histogram", "kde"] = "histogram", + name: Optional[str] = None, + device=None, + dtype=torch.float32, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) @@ -159,25 +159,29 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: def track(self, incoming: Beam) -> Beam: if self.is_active: - cin = deepcopy(incoming) + copy_of_incoming = deepcopy(incoming) if isinstance(incoming, ParameterBeam): - cin._mu = cin._mu.broadcast_to( - self.misalignment.shape[:-1] + cin._mu.shape + copy_of_incoming._mu = copy_of_incoming._mu.broadcast_to( + self.misalignment.shape[:-1] + copy_of_incoming._mu.shape ).clone() - cin._mu[..., 0] -= self.misalignment[..., 0].unsqueeze(-1) - cin._mu[..., 2] -= self.misalignment[..., 1].unsqueeze(-1) + copy_of_incoming._mu[..., 0] -= self.misalignment[..., 0].unsqueeze(-1) + copy_of_incoming._mu[..., 2] -= self.misalignment[..., 1].unsqueeze(-1) elif isinstance(incoming, ParticleBeam): - cin.particles = cin.particles.broadcast_to( - self.misalignment.shape[:-1] + cin.particles.shape + copy_of_incoming.particles = copy_of_incoming.particles.broadcast_to( + self.misalignment.shape[:-1] + copy_of_incoming.particles.shape ).clone() - cin.particles[..., 0] -= self.misalignment[..., 0].unsqueeze(-1) - cin.particles[..., 1] -= self.misalignment[..., 1].unsqueeze(-1) + copy_of_incoming.particles[..., 0] -= self.misalignment[ + ..., 0 + ].unsqueeze(-1) + copy_of_incoming.particles[..., 1] -= self.misalignment[ + ..., 1 + ].unsqueeze(-1) - self.set_read_beam(cin) + self.set_read_beam(copy_of_incoming) return incoming @@ -193,7 +197,7 @@ def reading(self) -> torch.Tensor: elif isinstance(read_beam, ParameterBeam): if torch.numel(read_beam._mu[..., 0]) > 1: raise NotImplementedError( - "cannot perform batch screen predictions with ParameterBeam" + "Cannot perform batch screen predictions with ParameterBeam" ) transverse_mu = torch.stack( @@ -211,8 +215,7 @@ def reading(self) -> torch.Tensor: dim=-1, ) dist = MultivariateNormal( - loc=transverse_mu, - covariance_matrix=transverse_cov + loc=transverse_mu, covariance_matrix=transverse_cov ) left = self.extent[0] @@ -234,16 +237,12 @@ def reading(self) -> torch.Tensor: if self.method == "histogram": if len(read_beam.x.shape) > 1 or len(read_beam.y.shape) > 1: raise NotImplementedError( - "Torch histogram does not support " - "batching. Use `kde` option instead." + 'The `"histogram"` method of `Screen` does not support ' + 'vectorization. Use `"kde"` instead.' ) image, _ = torch.histogramdd( - torch.stack(( - read_beam.x, - read_beam.y - )).T, - bins=self.pixel_bin_edges + torch.stack((read_beam.x, read_beam.y)).T, bins=self.pixel_bin_edges ) image = torch.flipud(image.T) @@ -303,12 +302,12 @@ def defining_features(self) -> list[str]: def __repr__(self) -> str: return ( - f"{self.__class__.__name__}(resolution={repr(self.resolution)}, " - + f"pixel_size={repr(self.pixel_size)}, " - + f"binning={repr(self.binning)}, " - + f"misalignment={repr(self.misalignment)}, " - + f"method={repr(self.method)}, " - + f"kde_bandwidth={repr(self.kde_bandwidth)}, " - + f"is_active={repr(self.is_active)}, " - + f"name={repr(self.name)})" + f"{self.__class__.__name__}(resolution={repr(self.resolution)}, " + + f"pixel_size={repr(self.pixel_size)}, " + + f"binning={repr(self.binning)}, " + + f"misalignment={repr(self.misalignment)}, " + + f"method={repr(self.method)}, " + + f"kde_bandwidth={repr(self.kde_bandwidth)}, " + + f"is_active={repr(self.is_active)}, " + + f"name={repr(self.name)})" ) diff --git a/cheetah/particles/parameter_beam.py b/cheetah/particles/parameter_beam.py index 4be65029..f5189b15 100644 --- a/cheetah/particles/parameter_beam.py +++ b/cheetah/particles/parameter_beam.py @@ -30,9 +30,7 @@ def __init__( factory_kwargs = {"device": device, "dtype": dtype} super().__init__() - self.register_buffer( - "_mu", torch.atleast_2d(torch.as_tensor(mu, **factory_kwargs)) - ) + self.register_buffer("_mu", torch.as_tensor(mu, **factory_kwargs)) self.register_buffer("_cov", torch.as_tensor(cov, **factory_kwargs)) total_charge = ( total_charge diff --git a/tests/test_screen.py b/tests/test_screen.py index b1fd0b89..e8bcf7b3 100644 --- a/tests/test_screen.py +++ b/tests/test_screen.py @@ -1,7 +1,5 @@ -import numpy as np import pytest import torch -from torch import Size import cheetah @@ -65,34 +63,6 @@ def test_screen_kde_bandwidth(kde_bandwidth): assert torch.any(segment.my_screen.reading > 0.0) -@pytest.mark.parametrize("BeamClass", [cheetah.ParticleBeam]) -@pytest.mark.parametrize("method", ["kde"]) -def test__screen_2d(BeamClass, method): - """ - Test that a vectorized `Screen` is able to track a particle beam and produce a - reading with 2D vector dimensions. - """ - segment = cheetah.Segment( - elements=[ - cheetah.Drift(length=torch.tensor(1.0)), - cheetah.Screen( - resolution=torch.tensor((100, 100)), - pixel_size=torch.tensor((1e-5, 1e-5)), - is_active=True, - method=method, - name="my_screen", - ), - ], - name="my_segment", - ) - incoming = BeamClass.from_parameters(sigma_x=torch.tensor(1e-5)) - - _ = segment.track(incoming) - - # Check the reading - assert segment.my_screen.reading.shape == (100, 100) - - @pytest.mark.parametrize("screen_method", ["histogram", "kde"]) def test_reading_shows_beam_parameter(screen_method): """ @@ -173,6 +143,6 @@ def test_reading_shows_beam_ares(screen_method): _ = segment.track(beam) assert isinstance(segment.AREABSCR1.reading, torch.Tensor) - assert segment.AREABSCR1.reading.shape == (1, 2040, 2448) + assert segment.AREABSCR1.reading.shape == (2040, 2448) assert torch.all(segment.AREABSCR1.reading >= 0.0) assert torch.any(segment.AREABSCR1.reading > 0.0) diff --git a/tests/test_speed.py b/tests/test_speed.py index 81b1703e..917f9550 100644 --- a/tests/test_speed.py +++ b/tests/test_speed.py @@ -25,7 +25,7 @@ def test_tracking_speed(): t1 = time.time() _ = segment.track(particles) - # _ = segment.AREABSCR1.reading + _ = segment.AREABSCR1.reading t2 = time.time() diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index bd16939f..44ebb8eb 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -303,7 +303,7 @@ def test_vectorized_solenoid(BeamClass): def test_vectorized_screen_2d(BeamClass, method): """ Test that a vectorized `Screen` is able to track a particle beam and produce a - reading with 2D vector dimensions. + reading with 2 vector dimensions. """ segment = cheetah.Segment( elements=[ From a09c2bd2ba54feda46e4243b975d0e5f2a315e2b Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Fri, 20 Sep 2024 10:22:47 +0200 Subject: [PATCH 054/111] Fix format --- tests/test_kde.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/tests/test_kde.py b/tests/test_kde.py index caa1cf34..8da4177a 100644 --- a/tests/test_kde.py +++ b/tests/test_kde.py @@ -72,13 +72,7 @@ def test_kde_2d_batched(): sigma = torch.tensor(0.1) # a single bandwidth - pdf = kde_histogram_2d( - data[..., 0], - data[..., 1], - bins_x, - bins_x, - sigma - ) + pdf = kde_histogram_2d(data[..., 0], data[..., 1], bins_x, bins_x, sigma) assert pdf.shape == Size([3, 2, n, n]) From 2b350a339f71b2f1fe5ad68716dcf0ebb774eb59 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Fri, 20 Sep 2024 10:24:19 +0200 Subject: [PATCH 055/111] Another format fix --- cheetah/accelerator/quadrupole.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cheetah/accelerator/quadrupole.py b/cheetah/accelerator/quadrupole.py index 36fce7fc..816f05b5 100644 --- a/cheetah/accelerator/quadrupole.py +++ b/cheetah/accelerator/quadrupole.py @@ -179,7 +179,6 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: pz = pz * torch.ones_like(x) # End of Bmad-X tracking - # Convert back to Cheetah coordinates tau, delta, ref_energy = bmadx.bmad_to_cheetah_z_pz( z, pz, p0c, electron_mass_eV From d080311cab111872e5bec1e93ca3698745b2c2f9 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Sat, 21 Sep 2024 08:03:14 +0200 Subject: [PATCH 056/111] Fix bug that needlessly added batch dimension during tracking --- cheetah/accelerator/segment.py | 4 +--- tests/test_screen.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/cheetah/accelerator/segment.py b/cheetah/accelerator/segment.py index 175245a0..38e3d5b8 100644 --- a/cheetah/accelerator/segment.py +++ b/cheetah/accelerator/segment.py @@ -356,9 +356,7 @@ def length(self) -> torch.Tensor: def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: if self.is_skippable: - tm = torch.eye(7, device=energy.device, dtype=energy.dtype).repeat( - (*self.length.shape, 1, 1) - ) + tm = torch.eye(7, device=energy.device, dtype=energy.dtype) for element in self.elements: tm = torch.matmul(element.transfer_map(energy), tm) return tm diff --git a/tests/test_screen.py b/tests/test_screen.py index e8bcf7b3..98007b42 100644 --- a/tests/test_screen.py +++ b/tests/test_screen.py @@ -117,7 +117,7 @@ def test_reading_shows_beam_parameter_batched(): segment.my_screen.reading -@pytest.mark.parametrize("screen_method", ["kde"]) +@pytest.mark.parametrize("screen_method", ["histogram", "kde"]) def test_reading_shows_beam_ares(screen_method): """ Test that a screen has a reading that shows some sign of the beam having hit it. From dc92a36023831a43d582e1137fa02a1523498021 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Sat, 21 Sep 2024 13:36:58 +0200 Subject: [PATCH 057/111] Proper fix for remaining failing tests --- cheetah/accelerator/screen.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/cheetah/accelerator/screen.py b/cheetah/accelerator/screen.py index 853db586..24c9c02c 100644 --- a/cheetah/accelerator/screen.py +++ b/cheetah/accelerator/screen.py @@ -162,17 +162,19 @@ def track(self, incoming: Beam) -> Beam: copy_of_incoming = deepcopy(incoming) if isinstance(incoming, ParameterBeam): - copy_of_incoming._mu = copy_of_incoming._mu.broadcast_to( - self.misalignment.shape[:-1] + copy_of_incoming._mu.shape - ).clone() - - copy_of_incoming._mu[..., 0] -= self.misalignment[..., 0].unsqueeze(-1) - copy_of_incoming._mu[..., 2] -= self.misalignment[..., 1].unsqueeze(-1) + copy_of_incoming._mu, _ = torch.broadcast_tensors( + copy_of_incoming._mu, self.misalignment[..., 0] + ) + copy_of_incoming._mu = copy_of_incoming._mu.clone() + copy_of_incoming._mu[..., 0] -= self.misalignment[..., 0] + copy_of_incoming._mu[..., 2] -= self.misalignment[..., 1] elif isinstance(incoming, ParticleBeam): - copy_of_incoming.particles = copy_of_incoming.particles.broadcast_to( - self.misalignment.shape[:-1] + copy_of_incoming.particles.shape - ).clone() + copy_of_incoming.particles, _ = torch.broadcast_tensors( + copy_of_incoming.particles, + self.misalignment[..., 0].unsqueeze(-1).unsqueeze(-1), + ) + copy_of_incoming.particles = copy_of_incoming.particles.clone() copy_of_incoming.particles[..., 0] -= self.misalignment[ ..., 0 From cd71a056855caa29b43da3cb8d182431897bb434 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Sat, 21 Sep 2024 13:49:47 +0200 Subject: [PATCH 058/111] Clean up conditions for catching unsupported vectorisation with Screen histogram --- cheetah/accelerator/screen.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/cheetah/accelerator/screen.py b/cheetah/accelerator/screen.py index 24c9c02c..ad4f9e04 100644 --- a/cheetah/accelerator/screen.py +++ b/cheetah/accelerator/screen.py @@ -34,6 +34,9 @@ class Screen(Element): "histogram" or "kde", defaults to "histogram". KDE will be slower but allows backward differentiation. :param name: Unique identifier of the element. + + NOTE: `method='histogram'` currently does not support vectorisation. Please use + `method=`kde` instead. """ def __init__( @@ -237,10 +240,15 @@ def reading(self) -> torch.Tensor: elif isinstance(read_beam, ParticleBeam): if self.method == "histogram": - if len(read_beam.x.shape) > 1 or len(read_beam.y.shape) > 1: + # Catch vectorisation, which is currently not supported by "histogram" + if ( + len(read_beam.particles.shape) > 2 + or len(read_beam.particle_charges.shape) > 1 + or len(read_beam.energy.shape) > 0 + ): raise NotImplementedError( - 'The `"histogram"` method of `Screen` does not support ' - 'vectorization. Use `"kde"` instead.' + "The `'histogram'` method of `Screen` does not support " + "vectorization. Use `'kde'` instead." ) image, _ = torch.histogramdd( From 7a4c96501f924461596ae1c5839fa8ad944ad4a6 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Sat, 21 Sep 2024 13:52:42 +0200 Subject: [PATCH 059/111] Add please report bugs note to changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 96765227..40a2c80d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,8 @@ ## v0.7.0 [🚧 Work in Progress] +This is a major release with significant upgrades under the hood of Cheetah. Despite extensive testing, you might still encounter a few bugs. Please report them by opening an issue, so we can fix them as soon as possible and improve the experience for everyone. + ### 🚨 Breaking Changes - Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #215, #218, #229, #233) (@jank324, @cr-xu, @hespe) From ba36076d9d5acb8753d88b0844af87a867c5257b Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 23 Sep 2024 14:49:18 +0200 Subject: [PATCH 060/111] Some function name and docstring cleanup --- cheetah/accelerator/cavity.py | 10 +++++++--- cheetah/accelerator/drift.py | 4 ++-- cheetah/accelerator/horizontal_corrector.py | 4 ++-- cheetah/accelerator/solenoid.py | 4 ++-- cheetah/track_methods.py | 4 ++-- cheetah/utils/physics.py | 2 +- 6 files changed, 16 insertions(+), 12 deletions(-) diff --git a/cheetah/accelerator/cavity.py b/cheetah/accelerator/cavity.py index f303a972..078344ba 100644 --- a/cheetah/accelerator/cavity.py +++ b/cheetah/accelerator/cavity.py @@ -9,7 +9,7 @@ from ..particles import Beam, ParameterBeam, ParticleBeam from ..track_methods import base_rmatrix from ..utils import UniqueNameGenerator -from ..utils.physics import calculate_relativistic_factors, electron_mass_eV +from ..utils.physics import compute_relativistic_factors, electron_mass_eV from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -103,7 +103,11 @@ def track(self, incoming: Beam) -> Beam: raise TypeError(f"Parameter incoming is of invalid type {type(incoming)}") def _track_beam(self, incoming: Beam) -> Beam: - g0, igamma2, beta0 = calculate_relativistic_factors(incoming.energy) + """ + Track particles through the cavity. The input can be a `ParameterBeam` or a + `ParticleBeam`. + """ + g0, igamma2, beta0 = compute_relativistic_factors(incoming.energy) phi = torch.deg2rad(self.phase) @@ -124,7 +128,7 @@ def _track_beam(self, incoming: Beam) -> Beam: if torch.any(incoming.energy + delta_energy > 0): k = 2 * torch.pi * self.frequency / constants.speed_of_light outgoing_energy = incoming.energy + delta_energy - g1, ig1, beta1 = calculate_relativistic_factors(outgoing_energy) + g1, ig1, beta1 = compute_relativistic_factors(outgoing_energy) if isinstance(incoming, ParameterBeam): outgoing_mu[..., 5] = incoming._mu[..., 5] * incoming.energy * beta0 / ( diff --git a/cheetah/accelerator/drift.py b/cheetah/accelerator/drift.py index bee747a5..b25fc99d 100644 --- a/cheetah/accelerator/drift.py +++ b/cheetah/accelerator/drift.py @@ -7,7 +7,7 @@ from ..particles import Beam, ParticleBeam from ..utils import UniqueNameGenerator, bmadx -from ..utils.physics import calculate_relativistic_factors +from ..utils.physics import compute_relativistic_factors from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -45,7 +45,7 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: device = self.length.device dtype = self.length.dtype - _, igamma2, beta = calculate_relativistic_factors(energy) + _, igamma2, beta = compute_relativistic_factors(energy) tm = torch.eye(7, device=device, dtype=dtype).repeat( (*(self.length * igamma2).shape, 1, 1) diff --git a/cheetah/accelerator/horizontal_corrector.py b/cheetah/accelerator/horizontal_corrector.py index cdd77683..7e11d780 100644 --- a/cheetah/accelerator/horizontal_corrector.py +++ b/cheetah/accelerator/horizontal_corrector.py @@ -7,7 +7,7 @@ from torch import nn from ..utils import UniqueNameGenerator -from ..utils.physics import calculate_relativistic_factors +from ..utils.physics import compute_relativistic_factors from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -49,7 +49,7 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: device = self.length.device dtype = self.length.dtype - _, igamma2, beta = calculate_relativistic_factors(energy) + _, igamma2, beta = compute_relativistic_factors(energy) batch_shape = torch.broadcast_tensors(self.length, self.angle, beta)[0].shape tm = torch.eye(7, device=device, dtype=dtype).repeat((*batch_shape, 1, 1)) diff --git a/cheetah/accelerator/solenoid.py b/cheetah/accelerator/solenoid.py index c5ab80d6..7be99b3c 100644 --- a/cheetah/accelerator/solenoid.py +++ b/cheetah/accelerator/solenoid.py @@ -8,7 +8,7 @@ from ..track_methods import misalignment_matrix from ..utils import UniqueNameGenerator -from ..utils.physics import calculate_relativistic_factors +from ..utils.physics import compute_relativistic_factors from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -64,7 +64,7 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: device = self.length.device dtype = self.length.dtype - gamma, _, _ = calculate_relativistic_factors(energy) + gamma, _, _ = compute_relativistic_factors(energy) c = torch.cos(self.length * self.k) s = torch.sin(self.length * self.k) diff --git a/cheetah/track_methods.py b/cheetah/track_methods.py index 08ed8d1e..c369ba71 100644 --- a/cheetah/track_methods.py +++ b/cheetah/track_methods.py @@ -4,7 +4,7 @@ import torch -from .utils.physics import calculate_relativistic_factors +from .utils.physics import compute_relativistic_factors def rotation_matrix(angle: torch.Tensor) -> torch.Tensor: @@ -53,7 +53,7 @@ def base_rmatrix( tilt = tilt if tilt is not None else torch.zeros_like(length) energy = energy if energy is not None else torch.zeros(1) - _, igamma2, beta = calculate_relativistic_factors(energy) + _, igamma2, beta = compute_relativistic_factors(energy) # Avoid division by zero k1 = k1.clone() diff --git a/cheetah/utils/physics.py b/cheetah/utils/physics.py index 0a0d1ba2..ac6453a9 100644 --- a/cheetah/utils/physics.py +++ b/cheetah/utils/physics.py @@ -6,7 +6,7 @@ ) -def calculate_relativistic_factors(energy): +def compute_relativistic_factors(energy): """ calculates relativistic factors gamma, inverse gamma squared, beta for electrons From 1ddbbb6382a0efaaa0d5b1ea3d0b0cfaff822fdb Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 23 Sep 2024 15:01:42 +0200 Subject: [PATCH 061/111] Clean up needless conversion of constants to tensors --- cheetah/accelerator/cavity.py | 28 ++++++++++++++-------------- cheetah/particles/beam.py | 4 +--- cheetah/particles/particle_beam.py | 5 +++-- cheetah/utils/physics.py | 7 +++---- 4 files changed, 21 insertions(+), 23 deletions(-) diff --git a/cheetah/accelerator/cavity.py b/cheetah/accelerator/cavity.py index 078344ba..da8fc880 100644 --- a/cheetah/accelerator/cavity.py +++ b/cheetah/accelerator/cavity.py @@ -107,7 +107,7 @@ def _track_beam(self, incoming: Beam) -> Beam: Track particles through the cavity. The input can be a `ParameterBeam` or a `ParticleBeam`. """ - g0, igamma2, beta0 = compute_relativistic_factors(incoming.energy) + gamma0, igamma2, beta0 = compute_relativistic_factors(incoming.energy) phi = torch.deg2rad(self.phase) @@ -128,7 +128,7 @@ def _track_beam(self, incoming: Beam) -> Beam: if torch.any(incoming.energy + delta_energy > 0): k = 2 * torch.pi * self.frequency / constants.speed_of_light outgoing_energy = incoming.energy + delta_energy - g1, ig1, beta1 = compute_relativistic_factors(outgoing_energy) + gamma1, _, beta1 = compute_relativistic_factors(outgoing_energy) if isinstance(incoming, ParameterBeam): outgoing_mu[..., 5] = incoming._mu[..., 5] * incoming.energy * beta0 / ( @@ -159,22 +159,22 @@ def _track_beam(self, incoming: Beam) -> Beam: - torch.cos(phi).unsqueeze(-1) ) - dgamma = self.voltage / electron_mass_eV.to(self.voltage) + dgamma = self.voltage / electron_mass_eV if torch.any(delta_energy > 0): T566 = ( self.length - * (beta0**3 * g0**3 - beta1**3 * g1**3) - / (2 * beta0 * beta1**3 * g0 * (g0 - g1) * g1**3) + * (beta0**3 * gamma0**3 - beta1**3 * gamma1**3) + / (2 * beta0 * beta1**3 * gamma0 * (gamma0 - gamma1) * gamma1**3) ) T556 = ( beta0 * k * self.length * dgamma - * g0 - * (beta1**3 * g1**3 + beta0 * (g0 - g1**3)) + * gamma0 + * (beta1**3 * gamma1**3 + beta0 * (gamma0 - gamma1**3)) * torch.sin(phi) - / (beta1**3 * g1**3 * (g0 - g1) ** 2) + / (beta1**3 * gamma1**3 * (gamma0 - gamma1) ** 2) ) T555 = ( beta0**2 @@ -185,15 +185,15 @@ def _track_beam(self, incoming: Beam) -> Beam: * ( dgamma * ( - 2 * g0 * g1**3 * (beta0 * beta1**3 - 1) - + g0**2 - + 3 * g1**2 + 2 * gamma0 * gamma1**3 * (beta0 * beta1**3 - 1) + + gamma0**2 + + 3 * gamma1**2 - 2 ) - / (beta1**3 * g1**3 * (g0 - g1) ** 3) + / (beta1**3 * gamma1**3 * (gamma0 - gamma1) ** 3) * torch.sin(phi) ** 2 - - (g1 * g0 * (beta1 * beta0 - 1) + 1) - / (beta1 * g1 * (g0 - g1) ** 2) + - (gamma1 * gamma0 * (beta1 * beta0 - 1) + 1) + / (beta1 * gamma1 * (gamma0 - gamma1) ** 2) * torch.cos(phi) ) ) diff --git a/cheetah/particles/beam.py b/cheetah/particles/beam.py index c170e2e1..8b3d3be0 100644 --- a/cheetah/particles/beam.py +++ b/cheetah/particles/beam.py @@ -4,9 +4,7 @@ from scipy.constants import physical_constants from torch import nn -electron_mass_eV = torch.tensor( - physical_constants["electron mass energy equivalent in MeV"][0] * 1e6 -) +electron_mass_eV = physical_constants["electron mass energy equivalent in MeV"][0] * 1e6 class Beam(nn.Module): diff --git a/cheetah/particles/particle_beam.py b/cheetah/particles/particle_beam.py index 10b1cd14..b8dcbe6f 100644 --- a/cheetah/particles/particle_beam.py +++ b/cheetah/particles/particle_beam.py @@ -2,14 +2,15 @@ import torch from scipy import constants +from scipy.constants import physical_constants from torch.distributions import MultivariateNormal from .beam import Beam speed_of_light = torch.tensor(constants.speed_of_light) # In m/s electron_mass = torch.tensor(constants.electron_mass) # In kg -electron_mass_eV = torch.tensor( - constants.physical_constants["electron mass energy equivalent in MeV"][0] * 1e6 +electron_mass_eV = ( + physical_constants["electron mass energy equivalent in MeV"][0] * 1e6 ) # In eV diff --git a/cheetah/utils/physics.py b/cheetah/utils/physics.py index ac6453a9..57b6ce71 100644 --- a/cheetah/utils/physics.py +++ b/cheetah/utils/physics.py @@ -1,9 +1,7 @@ import torch from scipy.constants import physical_constants -electron_mass_eV = torch.tensor( - physical_constants["electron mass energy equivalent in MeV"][0] * 1e6 -) +electron_mass_eV = physical_constants["electron mass energy equivalent in MeV"][0] * 1e6 def compute_relativistic_factors(energy): @@ -14,7 +12,8 @@ def compute_relativistic_factors(energy): :param energy: Energy in eV :return: gamma, igamma2, beta """ - gamma = energy / electron_mass_eV.to(energy) + gamma = energy / electron_mass_eV igamma2 = torch.where(gamma == 0.0, 0.0, 1 / gamma**2) beta = torch.sqrt(1 - igamma2) + return gamma, igamma2, beta From 7daa7b17b9b29cf877c1ea6c56187b68eae25e32 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 23 Sep 2024 15:04:34 +0200 Subject: [PATCH 062/111] Clean up relativistic factors function signature --- cheetah/utils/physics.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/cheetah/utils/physics.py b/cheetah/utils/physics.py index 57b6ce71..0fd71aec 100644 --- a/cheetah/utils/physics.py +++ b/cheetah/utils/physics.py @@ -4,10 +4,12 @@ electron_mass_eV = physical_constants["electron mass energy equivalent in MeV"][0] * 1e6 -def compute_relativistic_factors(energy): +def compute_relativistic_factors( + energy: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ - calculates relativistic factors gamma, inverse gamma squared, beta - for electrons + Computes the relativistic factors gamma, inverse gamma squared and beta for + electrons. :param energy: Energy in eV :return: gamma, igamma2, beta From f252c8fde9953d5b43326fb856e36c4d7136c413 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 23 Sep 2024 15:09:47 +0200 Subject: [PATCH 063/111] Comment cleanup --- cheetah/accelerator/quadrupole.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cheetah/accelerator/quadrupole.py b/cheetah/accelerator/quadrupole.py index 816f05b5..f5d3345f 100644 --- a/cheetah/accelerator/quadrupole.py +++ b/cheetah/accelerator/quadrupole.py @@ -175,7 +175,7 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: x_offset, y_offset, self.tilt, x, px, y, py ) - # p_z is unaffected by tracking, need to match batch dimensions + # pz is unaffected by tracking, therefore needs to match batch dimensions pz = pz * torch.ones_like(x) # End of Bmad-X tracking From a87227ab8fb3a12332b9c6278df89c7fcc73fc6d Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 23 Sep 2024 15:11:53 +0200 Subject: [PATCH 064/111] Remove gradient removal --- cheetah/accelerator/quadrupole.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cheetah/accelerator/quadrupole.py b/cheetah/accelerator/quadrupole.py index f5d3345f..0232d5f2 100644 --- a/cheetah/accelerator/quadrupole.py +++ b/cheetah/accelerator/quadrupole.py @@ -211,7 +211,7 @@ def is_skippable(self) -> bool: @property def is_active(self) -> bool: - return bool(torch.any(self.k1 != 0)) + return torch.any(self.k1 != 0) def split(self, resolution: torch.Tensor) -> list[Element]: num_splits = torch.ceil(torch.max(self.length) / resolution).int() From e91b1797452dbbe9c8d7172de34aa9340f52b7f3 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 23 Sep 2024 15:14:07 +0200 Subject: [PATCH 065/111] Return `misalginment` and `binning` order --- cheetah/accelerator/screen.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/cheetah/accelerator/screen.py b/cheetah/accelerator/screen.py index ad4f9e04..5d7fd042 100644 --- a/cheetah/accelerator/screen.py +++ b/cheetah/accelerator/screen.py @@ -71,14 +71,6 @@ def __init__( else torch.tensor((1e-3, 1e-3), **factory_kwargs) ), ) - self.register_buffer( - "misalignment", - ( - torch.as_tensor(misalignment, **factory_kwargs) - if misalignment is not None - else torch.tensor((0.0, 0.0), **factory_kwargs) - ), - ) self.register_buffer( "binning", ( @@ -87,6 +79,14 @@ def __init__( else torch.tensor(1, **factory_kwargs) ), ) + self.register_buffer( + "misalignment", + ( + torch.as_tensor(misalignment, **factory_kwargs) + if misalignment is not None + else torch.tensor((0.0, 0.0), **factory_kwargs) + ), + ) self.register_buffer( "length", torch.zeros(self.misalignment.shape[:-1], **factory_kwargs), From f7b9c012ea38296ad8444b04f3678ccf5dbcc289 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 23 Sep 2024 16:41:08 +0200 Subject: [PATCH 066/111] Update CHANGELOG.md Co-authored-by: Chenran Xu --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 40a2c80d..d0c7c456 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,7 @@ This is a major release with significant upgrades under the hood of Cheetah. Des ### 🚨 Breaking Changes -- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #215, #218, #229, #233) (@jank324, @cr-xu, @hespe) +- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #208, #215, #218, #229, #233) (@jank324, @cr-xu, @hespe, @roussel-ryan) - The fifth particle coordinate `s` is renamed to `tau`. Now Cheetah uses the canonical variables in phase space $(x,px=\frac{P_x}{p_0},y,py, \tau=c\Delta t, \delta=\Delta E/{p_0 c})$. In addition, the trailing "s" was removed from some beam property names (e.g. `beam.xs` becomes `beam.x`). (see #163) (@cr-xu) ### 🚀 Features From d87844008c58809a719e0a40c10fbd822768a49a Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 23 Sep 2024 16:42:58 +0200 Subject: [PATCH 067/111] Fix vector vs. batch terminology --- cheetah/accelerator/horizontal_corrector.py | 4 ++-- cheetah/accelerator/quadrupole.py | 2 +- cheetah/accelerator/solenoid.py | 4 ++-- cheetah/accelerator/space_charge_kick.py | 12 +++++------ cheetah/accelerator/undulator.py | 4 ++-- cheetah/particles/beam.py | 4 ++-- cheetah/particles/particle_beam.py | 24 ++++++++++----------- cheetah/track_methods.py | 10 ++++----- cheetah/utils/kde.py | 2 +- tests/test_cavity.py | 2 +- tests/test_dipole.py | 10 ++++----- tests/test_kde.py | 4 ++-- tests/test_quadrupole.py | 13 ++++++----- tests/test_screen.py | 2 +- tests/test_vectorized.py | 2 +- 15 files changed, 49 insertions(+), 50 deletions(-) diff --git a/cheetah/accelerator/horizontal_corrector.py b/cheetah/accelerator/horizontal_corrector.py index 7e11d780..7980a32d 100644 --- a/cheetah/accelerator/horizontal_corrector.py +++ b/cheetah/accelerator/horizontal_corrector.py @@ -51,8 +51,8 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: _, igamma2, beta = compute_relativistic_factors(energy) - batch_shape = torch.broadcast_tensors(self.length, self.angle, beta)[0].shape - tm = torch.eye(7, device=device, dtype=dtype).repeat((*batch_shape, 1, 1)) + vector_shape = torch.broadcast_tensors(self.length, self.angle, beta)[0].shape + tm = torch.eye(7, device=device, dtype=dtype).repeat((*vector_shape, 1, 1)) tm[..., 0, 1] = self.length tm[..., 1, 6] = self.angle tm[..., 2, 3] = self.length diff --git a/cheetah/accelerator/quadrupole.py b/cheetah/accelerator/quadrupole.py index 0232d5f2..0ce069a1 100644 --- a/cheetah/accelerator/quadrupole.py +++ b/cheetah/accelerator/quadrupole.py @@ -175,7 +175,7 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: x_offset, y_offset, self.tilt, x, px, y, py ) - # pz is unaffected by tracking, therefore needs to match batch dimensions + # pz is unaffected by tracking, therefore needs to match vector dimensions pz = pz * torch.ones_like(x) # End of Bmad-X tracking diff --git a/cheetah/accelerator/solenoid.py b/cheetah/accelerator/solenoid.py index 7be99b3c..f9f7b047 100644 --- a/cheetah/accelerator/solenoid.py +++ b/cheetah/accelerator/solenoid.py @@ -70,7 +70,7 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: s_k = torch.where(self.k == 0.0, self.length, s / self.k) - batch_shape = torch.broadcast_tensors( + vector_shape = torch.broadcast_tensors( *self.parameters(), *self.buffers(), energy )[0].shape @@ -78,7 +78,7 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: gamma != 0, self.length / (1 - gamma**2), torch.zeros_like(self.length) ) - R = torch.eye(7, device=device, dtype=dtype).repeat((*batch_shape, 1, 1)) + R = torch.eye(7, device=device, dtype=dtype).repeat((*vector_shape, 1, 1)) R[..., 0, 0] = c**2 R[..., 0, 1] = c * s_k R[..., 0, 2] = s * c diff --git a/cheetah/accelerator/space_charge_kick.py b/cheetah/accelerator/space_charge_kick.py index b951d4e5..3ad54243 100644 --- a/cheetah/accelerator/space_charge_kick.py +++ b/cheetah/accelerator/space_charge_kick.py @@ -440,8 +440,8 @@ def _compute_forces( ) -> torch.Tensor: """ Interpolates the space charge force from the grid onto the macroparticles. - Reciprocal function of _deposit_charge_on_grid. - Beam needs to have a flattened batch shape. + Reciprocal function of _deposit_charge_on_grid. `beam` needs to have a flattened + vector shape. """ grad_x, grad_y, grad_z = self._E_plus_vB_field( beam, xp_coordinates, cell_size, grid_dimensions @@ -553,10 +553,10 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: # This flattening is a hack to only think about one vector dimension in the # following code. It is reversed at the end of the function. - # Make sure that the incoming beam has at least one batch dimension - incoming_batched = True + # Make sure that the incoming beam has at least one vector dimension + is_incoming_vectorized = True if len(incoming.particles.shape) == 2: - incoming_batched = False + is_incoming_vectorized = False incoming.particles = incoming.particles.unsqueeze(0) incoming.energy = incoming.energy.unsqueeze(0) incoming.particle_charges = incoming.particle_charges.unsqueeze(0) @@ -599,7 +599,7 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: ..., 2 ] * dt.unsqueeze(-1) - if not incoming_batched: + if not is_incoming_vectorized: # Reshape to the original shape outgoing = ParticleBeam.from_xyz_pxpypz( xp_coordinates.squeeze(0), diff --git a/cheetah/accelerator/undulator.py b/cheetah/accelerator/undulator.py index 3bf2a0f4..85360e0c 100644 --- a/cheetah/accelerator/undulator.py +++ b/cheetah/accelerator/undulator.py @@ -47,8 +47,8 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: gamma = energy / electron_mass_eV igamma2 = torch.where(gamma != 0, 1 / gamma**2, torch.zeros_like(gamma)) - batch_shape = torch.broadcast_tensors(self.length, energy)[0].shape - tm = torch.eye(7, device=device, dtype=dtype).repeat((*batch_shape, 1, 1)) + vector_shape = torch.broadcast_tensors(self.length, energy)[0].shape + tm = torch.eye(7, device=device, dtype=dtype).repeat((*vector_shape, 1, 1)) tm[..., 0, 1] = self.length tm[..., 2, 3] = self.length tm[..., 4, 5] = self.length * igamma2 diff --git a/cheetah/particles/beam.py b/cheetah/particles/beam.py index 8b3d3be0..3b1601b7 100644 --- a/cheetah/particles/beam.py +++ b/cheetah/particles/beam.py @@ -161,8 +161,8 @@ def transformed_to( :param energy: Reference energy of the beam in eV. :param total_charge: Total charge of the beam in C. """ - # Figure out batch size of the original beam and check that passed arguments - # have the same batch size + # Figure out vector dimensions of the original beam and check that passed + # arguments have the same vector dimensions. shape = self.mu_x.shape not_nones = [ argument diff --git a/cheetah/particles/particle_beam.py b/cheetah/particles/particle_beam.py index b8dcbe6f..20d20aa7 100644 --- a/cheetah/particles/particle_beam.py +++ b/cheetah/particles/particle_beam.py @@ -294,7 +294,7 @@ def uniform_3d_ellipsoid( Note that: - The generated particles do not have correlation in the momentum directions, and by default a cold beam with no divergence is generated. - - For batched generation, parameters that are not `None` must have the same + - For vectorised generation, parameters that are not `None` must have the same shape. :param num_particles: Number of particles to generate. @@ -336,8 +336,8 @@ def uniform_3d_ellipsoid( argument.shape == shape for argument in not_nones ), "Arguments must have the same shape." - # Expand to batched version for beam creation - batch_shape = shape if len(shape) > 0 else torch.Size([1]) + # Expand to vectorised version for beam creation + vector_shape = shape if len(shape) > 0 else torch.Size([1]) # Set default values without function call in function signature # NOTE that this does not need to be done for values that are passed to the @@ -346,25 +346,25 @@ def uniform_3d_ellipsoid( num_particles if num_particles is not None else torch.tensor(1_000_000) ) radius_x = ( - radius_x.expand(batch_shape) + radius_x.expand(vector_shape) if radius_x is not None - else torch.full(batch_shape, 1e-3) + else torch.full(vector_shape, 1e-3) ) radius_y = ( - radius_y.expand(batch_shape) + radius_y.expand(vector_shape) if radius_y is not None - else torch.full(batch_shape, 1e-3) + else torch.full(vector_shape, 1e-3) ) radius_tau = ( - radius_tau.expand(batch_shape) + radius_tau.expand(vector_shape) if radius_tau is not None - else torch.full(batch_shape, 1e-3) + else torch.full(vector_shape, 1e-3) ) # Generate x, y and ss within the ellipsoid - flattened_x = torch.empty(*batch_shape, num_particles).flatten(end_dim=-2) - flattened_y = torch.empty(*batch_shape, num_particles).flatten(end_dim=-2) - flattened_tau = torch.empty(*batch_shape, num_particles).flatten(end_dim=-2) + flattened_x = torch.empty(*vector_shape, num_particles).flatten(end_dim=-2) + flattened_y = torch.empty(*vector_shape, num_particles).flatten(end_dim=-2) + flattened_tau = torch.empty(*vector_shape, num_particles).flatten(end_dim=-2) for i, (r_x, r_y, r_tau) in enumerate( zip(radius_x.flatten(), radius_y.flatten(), radius_tau.flatten()) ): diff --git a/cheetah/track_methods.py b/cheetah/track_methods.py index c369ba71..10c1f4b1 100644 --- a/cheetah/track_methods.py +++ b/cheetah/track_methods.py @@ -72,8 +72,8 @@ def base_rmatrix( r56 = r56 - length / beta**2 * igamma2 - batch_shape = torch.broadcast_tensors(length, k1, hx, tilt, energy)[0].shape - R = torch.eye(7, dtype=dtype, device=device).repeat(*batch_shape, 1, 1) + vector_shape = torch.broadcast_tensors(length, k1, hx, tilt, energy)[0].shape + R = torch.eye(7, dtype=dtype, device=device).repeat(*vector_shape, 1, 1) R[..., 0, 0] = cx R[..., 0, 1] = sx R[..., 0, 5] = dx / beta @@ -102,13 +102,13 @@ def misalignment_matrix( """Shift the beam for tracking beam through misaligned elements""" device = misalignment.device dtype = misalignment.dtype - batch_shape = misalignment.shape[:-1] + vector_shape = misalignment.shape[:-1] - R_exit = torch.eye(7, device=device, dtype=dtype).repeat(*batch_shape, 1, 1) + R_exit = torch.eye(7, device=device, dtype=dtype).repeat(*vector_shape, 1, 1) R_exit[..., 0, 6] = misalignment[..., 0] R_exit[..., 2, 6] = misalignment[..., 1] - R_entry = torch.eye(7, device=device, dtype=dtype).repeat(*batch_shape, 1, 1) + R_entry = torch.eye(7, device=device, dtype=dtype).repeat(*vector_shape, 1, 1) R_entry[..., 0, 6] = -misalignment[..., 0] R_entry[..., 2, 6] = -misalignment[..., 1] diff --git a/cheetah/utils/kde.py b/cheetah/utils/kde.py index 834beeb5..3f873ef6 100644 --- a/cheetah/utils/kde.py +++ b/cheetah/utils/kde.py @@ -15,7 +15,7 @@ def _kde_marginal_pdf( Calculate the 1D marginal probability distribution function of the input tensor based on the number of histogram bins. - :param values: Input tensor with shape :math:`(B, N)`. `B` is the batch shape. + :param values: Input tensor with shape :math:`(B, N)`. `B` is the vector shape. :param bins: Positions of the bins where KDE is calculated. Shape :math:`(N_{bins})`. :param sigma: Gaussian smoothing factor with shape `(1,)`. diff --git a/tests/test_cavity.py b/tests/test_cavity.py index 7129cbc0..08652533 100644 --- a/tests/test_cavity.py +++ b/tests/test_cavity.py @@ -39,7 +39,7 @@ def test_assert_ei_greater_zero(): def test_vectorized_cavity_zero_voltage(voltage): """ Tests that a vectorised cavity with zero voltage does not produce NaNs and that - zero voltage can be batched with non-zero voltage. + zero voltage can be vectorised with non-zero voltage. This was a bug introduced during the vectorisation of Cheetah, when the special case of zero was removed and the `_cavity_rmatrix` method was also used in the case diff --git a/tests/test_dipole.py b/tests/test_dipole.py index ef23628e..08d016e4 100644 --- a/tests/test_dipole.py +++ b/tests/test_dipole.py @@ -53,9 +53,9 @@ def test_dipole_focussing(): @pytest.mark.parametrize("DipoleType", [Dipole, RBend]) -def test_dipole_batched_execution(DipoleType): +def test_dipole_vectorized_execution(DipoleType): """ - Test that a dipole with batch dimensions behaves as expected. + Test that a dipole with vector dimensions behaves as expected. """ incoming = ParticleBeam.from_parameters( num_particles=torch.tensor(100), @@ -63,7 +63,7 @@ def test_dipole_batched_execution(DipoleType): mu_x=torch.tensor(1e-5), ) - # Test batching to generate 3 beam lines + # Test vectorisation to generate 3 beam lines segment = Segment( [ DipoleType( @@ -84,7 +84,7 @@ def test_dipole_batched_execution(DipoleType): # Check different angles do make a difference assert not torch.allclose(outgoing.particles[0], outgoing.particles[1]) - # Test batching to generate 18 beamlines + # Test vectorisation to generate 18 beamlines segment = Segment( [ Dipole( @@ -97,7 +97,7 @@ def test_dipole_batched_execution(DipoleType): outgoing = segment(incoming) assert outgoing.particles.shape == torch.Size([2, 3, 3, 100, 7]) - # Test improper batching -- this does not obey torch broadcasting rules + # Test improper vectorisation -- this does not obey torch broadcasting rules segment = Segment( [ Dipole( diff --git a/tests/test_kde.py b/tests/test_kde.py index 8da4177a..5ba8a61d 100644 --- a/tests/test_kde.py +++ b/tests/test_kde.py @@ -45,7 +45,7 @@ def test_kde_1d(): _kde_marginal_pdf(data, bins, torch.rand(3) + 0.1) -def test_kde_1d_batched(): +def test_kde_1d_vectorized(): # test basic usage data = torch.randn((5, 100)) # 5 beamline states, 100 particles in 1D bins = torch.linspace(0, 1, 10) # a single histogram @@ -60,7 +60,7 @@ def test_kde_1d_batched(): _kde_marginal_pdf(data, bins, torch.rand(3) + 0.1) -def test_kde_2d_batched(): +def test_kde_2d_vectorized(): data = torch.randn((3, 2, 100, 6)) # 2 diagnostic paths, # 3 states per diagnostic paths, diff --git a/tests/test_quadrupole.py b/tests/test_quadrupole.py index cc0d66ab..a900dabf 100644 --- a/tests/test_quadrupole.py +++ b/tests/test_quadrupole.py @@ -23,7 +23,7 @@ def test_quadrupole_off(): assert not torch.allclose(outbeam_quad_on.sigma_x, outbeam_drift.sigma_x) -def test_quadrupole_with_misalignments_batched(): +def test_quadrupole_with_misalignments_vectorized(): """ Test that a quadrupole with misalignments behaves as expected. """ @@ -51,9 +51,9 @@ def test_quadrupole_with_misalignments_batched(): ) -def test_quadrupole_with_misalignments_multiple_batch_dimensions(): +def test_quadrupole_with_misalignments_multiple_vector_dimensions(): """ - Test that a quadrupole with misalignments that have multiple batch dimensions does + Test that a quadrupole with misalignments that have multiple vector dimensions does not raise an error and behaves as expected. """ @@ -81,7 +81,7 @@ def test_quadrupole_with_misalignments_multiple_batch_dimensions(): assert outbeam_quad_with_misalignment.mu_x.shape == misalignments.shape[:-1] -def test_tilted_quadrupole_batch(): +def test_tilted_quadrupole_vectorized(): """ Test that a quadrupole with a tilt behaves as expected in vectorised mode. """ @@ -109,8 +109,7 @@ def test_tilted_quadrupole_batch(): assert not torch.allclose(outgoing.particles[0], outgoing.particles[1]) -# TODO Change batched to vectorised -def test_tilted_quadrupole_multiple_batch_dimensions(): +def test_tilted_quadrupole_multiple_vector_dimensions(): """ Test that a quadrupole with tilts that have multiple vectorisation dimensions does not raise an error and behaves as expected. @@ -142,7 +141,7 @@ def test_tilted_quadrupole_multiple_batch_dimensions(): assert outgoing.particles.shape == (2, 3, 10_000, 7) -def test_quadrupole_length_multiple_batch_dimensions(): +def test_quadrupole_length_multiple_vector_dimensions(): """ Test that a quadrupole with lengths that have multiple vectorisation dimensions does not raise an error and behaves as expected. diff --git a/tests/test_screen.py b/tests/test_screen.py index 98007b42..b5ebe065 100644 --- a/tests/test_screen.py +++ b/tests/test_screen.py @@ -92,7 +92,7 @@ def test_reading_shows_beam_parameter(screen_method): assert torch.any(segment.my_screen.reading > 0.0) -def test_reading_shows_beam_parameter_batched(): +def test_reading_shows_beam_parameter_vectorized(): """ Test that a screen has a reading that shows some sign of the beam having hit it. """ diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index 44ebb8eb..95ca9454 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -23,7 +23,7 @@ def test_segment_length_shape(): def test_segment_length_shape_2d(): """ - Test that the shape of a segment's length matches the input for a batch with + Test that the shape of a segment's length matches the input for a vectorisation with multiple dimensions. """ segment = cheetah.Segment( From 0d15a7de5b88996602073299e91b043096fe8ec5 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 23 Sep 2024 16:45:35 +0200 Subject: [PATCH 068/111] Remove `batch_shape` property again --- cheetah/accelerator/element.py | 11 ----------- tests/test_quadrupole.py | 2 -- 2 files changed, 13 deletions(-) diff --git a/cheetah/accelerator/element.py b/cheetah/accelerator/element.py index 560fba23..0aba489f 100644 --- a/cheetah/accelerator/element.py +++ b/cheetah/accelerator/element.py @@ -91,17 +91,6 @@ def forward(self, incoming: Beam) -> Beam: """Forward function required by `torch.nn.Module`. Simply calls `track`.""" return self.track(incoming) - @property - def batch_shape(self) -> torch.Size: - tensors = [] - # Get all parameters - for param in self.parameters(): - tensors.append(param) - # Get all buffers - for buffer in self.buffers(): - tensors.append(buffer) - return torch.broadcast_tensors(*tensors)[0].shape - @property @abstractmethod def is_skippable(self) -> bool: diff --git a/tests/test_quadrupole.py b/tests/test_quadrupole.py index a900dabf..2b1b3e57 100644 --- a/tests/test_quadrupole.py +++ b/tests/test_quadrupole.py @@ -34,8 +34,6 @@ def test_quadrupole_with_misalignments_vectorized(): misalignment=torch.tensor([0.1, 0.1]).unsqueeze(0), ) - assert quad_with_misalignment.batch_shape == torch.Size([1, 2]) - quad_without_misalignment = Quadrupole( length=torch.tensor(1.0), k1=torch.tensor(1.0) ) From e8bd8c2882baac9e4b4661173dddfc3c642c47e7 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 23 Sep 2024 16:54:08 +0200 Subject: [PATCH 069/111] Clean up unnecessary `from torch import Tensor` --- cheetah/utils/bmadx.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/cheetah/utils/bmadx.py b/cheetah/utils/bmadx.py index 65d2e4b8..75d850be 100644 --- a/cheetah/utils/bmadx.py +++ b/cheetah/utils/bmadx.py @@ -1,6 +1,5 @@ import torch from scipy.constants import speed_of_light -from torch import Tensor double_precision_epsilon = torch.finfo(torch.float64).eps @@ -31,8 +30,8 @@ def cheetah_to_bmad_z_pz( def bmad_to_cheetah_z_pz( - z: Tensor, pz: Tensor, p0c: Tensor, mc2: float -) -> (Tensor, Tensor, Tensor): + z: torch.Tensor, pz: torch.Tensor, p0c: torch.Tensor, mc2: float +) -> tuple[torch.Tensor]: """ Transforms Bmad longitudinal coordinates to Cheetah coordinates and computes reference energy. From 82df3af8d63f6dcdc001b6b3bbc6a9ce70bbc266 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 23 Sep 2024 16:56:43 +0200 Subject: [PATCH 070/111] Cleanup docstring --- cheetah/utils/physics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cheetah/utils/physics.py b/cheetah/utils/physics.py index 0fd71aec..4c16cac7 100644 --- a/cheetah/utils/physics.py +++ b/cheetah/utils/physics.py @@ -11,8 +11,8 @@ def compute_relativistic_factors( Computes the relativistic factors gamma, inverse gamma squared and beta for electrons. - :param energy: Energy in eV - :return: gamma, igamma2, beta + :param energy: Energy in eV. + :return: gamma, igamma2, beta. """ gamma = energy / electron_mass_eV igamma2 = torch.where(gamma == 0.0, 0.0, 1 / gamma**2) From ca23f5c72113427e82509cccfd27d10d4e2552a5 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 23 Sep 2024 17:41:19 +0200 Subject: [PATCH 071/111] Add test that breaks with `repeat` used for example in `Drift.transfer_map` --- tests/test_vectorized.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index 95ca9454..c7068afc 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -330,3 +330,31 @@ def test_vectorized_screen_2d(BeamClass, method): # Check the reading assert segment.my_screen.reading.shape == (2, 3, 100, 100) + + +@pytest.mark.parametrize( + "ElementClass", + [ + cheetah.Drift, + cheetah.Quadrupole, + cheetah.Cavity, + cheetah.Undulator, + cheetah.Solenoid, + cheetah.HorizontalCorrector, + cheetah.VerticalCorrector, + cheetah.TransverseDeflectingCavity, + ], +) +def test_drift_broadcasting_two_different_inputs(ElementClass): + """ + Test that broadcasting rules are correctly applied to a elements with two different + input shapes. + """ + incoming = cheetah.ParticleBeam.from_parameters( + num_particles=100_000, sigma_x=torch.tensor([1e-5, 2e-5]) + ) + element = ElementClass(length=torch.tensor([[0.6], [0.5], [0.4]])) + + outgoing = element.track(incoming) + + assert outgoing.particles.shape == (3, 2, 100_000, 7) From 5d7c3631428c4cd671fcd5ae3dc7528cacac2fb7 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 23 Sep 2024 17:55:04 +0200 Subject: [PATCH 072/111] Fix test to ask for correct specification --- tests/test_vectorized.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index c7068afc..6a4661ae 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -335,23 +335,26 @@ def test_vectorized_screen_2d(BeamClass, method): @pytest.mark.parametrize( "ElementClass", [ + cheetah.Cavity, + cheetah.CustomTransferMap, + cheetah.Dipole, cheetah.Drift, + cheetah.HorizontalCorrector, cheetah.Quadrupole, - cheetah.Cavity, - cheetah.Undulator, + cheetah.RBend, cheetah.Solenoid, - cheetah.HorizontalCorrector, - cheetah.VerticalCorrector, cheetah.TransverseDeflectingCavity, + cheetah.Undulator, + cheetah.VerticalCorrector, ], ) def test_drift_broadcasting_two_different_inputs(ElementClass): """ Test that broadcasting rules are correctly applied to a elements with two different - input shapes. + input shapes for elements that have a `length` attribute. """ incoming = cheetah.ParticleBeam.from_parameters( - num_particles=100_000, sigma_x=torch.tensor([1e-5, 2e-5]) + num_particles=100_000, energy=torch.tensor([154e6, 14e9]) ) element = ElementClass(length=torch.tensor([[0.6], [0.5], [0.4]])) From 3036b0aa0e3bdd57fe1f3cbf292cfc840407496a Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 23 Sep 2024 17:55:19 +0200 Subject: [PATCH 073/111] Clean up `Drift.transfer_map` broadcasting --- cheetah/accelerator/drift.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cheetah/accelerator/drift.py b/cheetah/accelerator/drift.py index b25fc99d..e88f03ff 100644 --- a/cheetah/accelerator/drift.py +++ b/cheetah/accelerator/drift.py @@ -47,9 +47,9 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: _, igamma2, beta = compute_relativistic_factors(energy) - tm = torch.eye(7, device=device, dtype=dtype).repeat( - (*(self.length * igamma2).shape, 1, 1) - ) + vector_shape = torch.broadcast_shapes(self.length.shape, igamma2.shape) + + tm = torch.eye(7, device=device, dtype=dtype).repeat((*vector_shape, 1, 1)) tm[..., 0, 1] = self.length tm[..., 2, 3] = self.length tm[..., 4, 5] = -self.length / beta**2 * igamma2 From c15980e2190cb5d09e3e0c61553d6a21bc8f428b Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 23 Sep 2024 18:08:15 +0200 Subject: [PATCH 074/111] Fix one of the failing tests; remove old broadcast methods --- cheetah/accelerator/cavity.py | 6 ++++-- cheetah/accelerator/drift.py | 14 ++------------ cheetah/accelerator/horizontal_corrector.py | 6 +++--- cheetah/accelerator/quadrupole.py | 12 ------------ cheetah/accelerator/solenoid.py | 3 +-- .../accelerator/transverse_deflecting_cavity.py | 14 -------------- cheetah/accelerator/vertical_corrector.py | 12 ++++++------ cheetah/track_methods.py | 2 +- cheetah/utils/__init__.py | 1 + tests/test_vectorized.py | 1 - 10 files changed, 18 insertions(+), 53 deletions(-) diff --git a/cheetah/accelerator/cavity.py b/cheetah/accelerator/cavity.py index da8fc880..b692860e 100644 --- a/cheetah/accelerator/cavity.py +++ b/cheetah/accelerator/cavity.py @@ -4,16 +4,18 @@ import torch from matplotlib.patches import Rectangle from scipy import constants +from scipy.constants import physical_constants from torch import nn from ..particles import Beam, ParameterBeam, ParticleBeam from ..track_methods import base_rmatrix -from ..utils import UniqueNameGenerator -from ..utils.physics import compute_relativistic_factors, electron_mass_eV +from ..utils import UniqueNameGenerator, compute_relativistic_factors from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") +electron_mass_eV = physical_constants["electron mass energy equivalent in MeV"][0] * 1e6 + class Cavity(Element): """ diff --git a/cheetah/accelerator/drift.py b/cheetah/accelerator/drift.py index e88f03ff..270f923b 100644 --- a/cheetah/accelerator/drift.py +++ b/cheetah/accelerator/drift.py @@ -3,11 +3,10 @@ import matplotlib.pyplot as plt import torch from scipy.constants import physical_constants -from torch import Size, nn +from torch import nn from ..particles import Beam, ParticleBeam -from ..utils import UniqueNameGenerator, bmadx -from ..utils.physics import compute_relativistic_factors +from ..utils import UniqueNameGenerator, bmadx, compute_relativistic_factors from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -116,15 +115,6 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: ) return outgoing_beam - def broadcast(self, shape: Size) -> Element: - return self.__class__( - length=self.length.repeat(shape), - tracking_method=self.tracking_method, - name=self.name, - device=self.length.device, - dtype=self.length.dtype, - ) - @property def is_skippable(self) -> bool: return self.tracking_method == "cheetah" diff --git a/cheetah/accelerator/horizontal_corrector.py b/cheetah/accelerator/horizontal_corrector.py index 7980a32d..e17837d6 100644 --- a/cheetah/accelerator/horizontal_corrector.py +++ b/cheetah/accelerator/horizontal_corrector.py @@ -6,8 +6,7 @@ from matplotlib.patches import Rectangle from torch import nn -from ..utils import UniqueNameGenerator -from ..utils.physics import compute_relativistic_factors +from ..utils import UniqueNameGenerator, compute_relativistic_factors from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -51,7 +50,8 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: _, igamma2, beta = compute_relativistic_factors(energy) - vector_shape = torch.broadcast_tensors(self.length, self.angle, beta)[0].shape + vector_shape = torch.broadcast_shapes(self.length.shape, igamma2.shape) + tm = torch.eye(7, device=device, dtype=dtype).repeat((*vector_shape, 1, 1)) tm[..., 0, 1] = self.length tm[..., 1, 6] = self.angle diff --git a/cheetah/accelerator/quadrupole.py b/cheetah/accelerator/quadrupole.py index 0ce069a1..3386279e 100644 --- a/cheetah/accelerator/quadrupole.py +++ b/cheetah/accelerator/quadrupole.py @@ -193,18 +193,6 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: ) return outgoing_beam - def broadcast(self, shape: Size) -> Element: - return self.__class__( - length=self.length.repeat(shape), - k1=self.k1.repeat(shape), - misalignment=self.misalignment.repeat((*shape, 1)), - tilt=self.tilt.repeat(shape), - tracking_method=self.tracking_method, - name=self.name, - device=self.length.device, - dtype=self.length.dtype, - ) - @property def is_skippable(self) -> bool: return self.tracking_method == "cheetah" diff --git a/cheetah/accelerator/solenoid.py b/cheetah/accelerator/solenoid.py index f9f7b047..c0ebbc10 100644 --- a/cheetah/accelerator/solenoid.py +++ b/cheetah/accelerator/solenoid.py @@ -7,8 +7,7 @@ from torch import nn from ..track_methods import misalignment_matrix -from ..utils import UniqueNameGenerator -from ..utils.physics import compute_relativistic_factors +from ..utils import UniqueNameGenerator, compute_relativistic_factors from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") diff --git a/cheetah/accelerator/transverse_deflecting_cavity.py b/cheetah/accelerator/transverse_deflecting_cavity.py index a4542b2c..c2b2e30a 100644 --- a/cheetah/accelerator/transverse_deflecting_cavity.py +++ b/cheetah/accelerator/transverse_deflecting_cavity.py @@ -205,20 +205,6 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: ) return outgoing_beam - def broadcast(self, shape: Size) -> Element: - return self.__class__( - length=self.length.repeat(shape), - voltage=self.voltage.repeat(shape), - phase=self.phase.repeat(shape), - frequency=self.frequency.repeat(shape), - misalignment=self.misalignment.repeat((*shape, 1)), - tilt=self.tilt.repeat(shape), - tracking_method=self.tracking_method, - name=self.name, - device=self.length.device, - dtype=self.length.dtype, - ) - def split(self, resolution: torch.Tensor) -> list[Element]: # TODO: Implement splitting for cavity properly, for now just returns the # element itself diff --git a/cheetah/accelerator/vertical_corrector.py b/cheetah/accelerator/vertical_corrector.py index 154ec852..bd78e367 100644 --- a/cheetah/accelerator/vertical_corrector.py +++ b/cheetah/accelerator/vertical_corrector.py @@ -7,7 +7,7 @@ from scipy.constants import physical_constants from torch import nn -from ..utils import UniqueNameGenerator +from ..utils import UniqueNameGenerator, compute_relativistic_factors from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -51,16 +51,16 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: device = self.length.device dtype = self.length.dtype - gamma = energy / electron_mass_eV - igamma2 = torch.zeros_like(gamma) # TODO: Effect on gradients? - igamma2[gamma != 0] = 1 / gamma[gamma != 0] ** 2 - beta = torch.sqrt(1 - igamma2) + _, igamma2, beta = compute_relativistic_factors(energy) - tm = torch.eye(7, device=device, dtype=dtype).repeat((*self.length.shape, 1, 1)) + vector_shape = torch.broadcast_shapes(self.length.shape, igamma2.shape) + + tm = torch.eye(7, device=device, dtype=dtype).repeat((*vector_shape, 1, 1)) tm[..., 0, 1] = self.length tm[..., 2, 3] = self.length tm[..., 3, 6] = self.angle tm[..., 4, 5] = -self.length / beta**2 * igamma2 + return tm @property diff --git a/cheetah/track_methods.py b/cheetah/track_methods.py index 10c1f4b1..7b3ef696 100644 --- a/cheetah/track_methods.py +++ b/cheetah/track_methods.py @@ -4,7 +4,7 @@ import torch -from .utils.physics import compute_relativistic_factors +from .utils import compute_relativistic_factors def rotation_matrix(angle: torch.Tensor) -> torch.Tensor: diff --git a/cheetah/utils/__init__.py b/cheetah/utils/__init__.py index ce4a5002..eb472409 100644 --- a/cheetah/utils/__init__.py +++ b/cheetah/utils/__init__.py @@ -1,4 +1,5 @@ from . import bmadx # noqa: F401 from .device import is_mps_available_and_functional # noqa: F401 from .kde import kde_histogram_1d, kde_histogram_2d # noqa: F401 +from .physics import compute_relativistic_factors # noqa: F401 from .unique_name_generator import UniqueNameGenerator # noqa: F401 diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index 6a4661ae..79bd7027 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -336,7 +336,6 @@ def test_vectorized_screen_2d(BeamClass, method): "ElementClass", [ cheetah.Cavity, - cheetah.CustomTransferMap, cheetah.Dipole, cheetah.Drift, cheetah.HorizontalCorrector, From 3aef2f4c62de2968c499fc716ea33cfa4f4dcc07 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 23 Sep 2024 18:15:01 +0200 Subject: [PATCH 075/111] Clean up some more of the automatic broadcasting --- cheetah/accelerator/cavity.py | 2 +- cheetah/accelerator/solenoid.py | 6 +++--- cheetah/track_methods.py | 6 +++++- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/cheetah/accelerator/cavity.py b/cheetah/accelerator/cavity.py index b692860e..eaae0e87 100644 --- a/cheetah/accelerator/cavity.py +++ b/cheetah/accelerator/cavity.py @@ -315,7 +315,7 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor: r66 = Ei / Ef * beta0 / beta1 r65 = k * torch.sin(phi) * self.voltage / (Ef * beta1 * electron_mass_eV) - # Check that all matrix elements have the same shape + # Make sure that all matrix elements have the same shape r11, r12, r21, r22, r55_cor, r56, r65, r66 = torch.broadcast_tensors( r11, r12, r21, r22, r55_cor, r56, r65, r66 ) diff --git a/cheetah/accelerator/solenoid.py b/cheetah/accelerator/solenoid.py index c0ebbc10..2d89e208 100644 --- a/cheetah/accelerator/solenoid.py +++ b/cheetah/accelerator/solenoid.py @@ -69,9 +69,9 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: s_k = torch.where(self.k == 0.0, self.length, s / self.k) - vector_shape = torch.broadcast_tensors( - *self.parameters(), *self.buffers(), energy - )[0].shape + vector_shape = torch.broadcast_shapes( + self.length.shape, self.k.shape, energy.shape + ) r56 = torch.where( gamma != 0, self.length / (1 - gamma**2), torch.zeros_like(self.length) diff --git a/cheetah/track_methods.py b/cheetah/track_methods.py index 7b3ef696..ba2a3ed1 100644 --- a/cheetah/track_methods.py +++ b/cheetah/track_methods.py @@ -72,7 +72,10 @@ def base_rmatrix( r56 = r56 - length / beta**2 * igamma2 - vector_shape = torch.broadcast_tensors(length, k1, hx, tilt, energy)[0].shape + vector_shape = torch.broadcast_shapes( + length.shape, k1.shape, hx.shape, tilt.shape, energy.shape + ) + R = torch.eye(7, dtype=dtype, device=device).repeat(*vector_shape, 1, 1) R[..., 0, 0] = cx R[..., 0, 1] = sx @@ -102,6 +105,7 @@ def misalignment_matrix( """Shift the beam for tracking beam through misaligned elements""" device = misalignment.device dtype = misalignment.dtype + vector_shape = misalignment.shape[:-1] R_exit = torch.eye(7, device=device, dtype=dtype).repeat(*vector_shape, 1, 1) From 2d72fcfa43178fcbba2bbf86c260a4794b6d5b16 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 23 Sep 2024 18:20:53 +0200 Subject: [PATCH 076/111] Fix broken Ocelot comparison test --- tests/test_compare_ocelot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_compare_ocelot.py b/tests/test_compare_ocelot.py index 6665cb8d..1478454e 100644 --- a/tests/test_compare_ocelot.py +++ b/tests/test_compare_ocelot.py @@ -240,7 +240,7 @@ def test_solenoid(): _, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator) assert np.allclose( - outgoing_beam.particles[0, :, :6].cpu().numpy(), + outgoing_beam.particles[..., :6].cpu().numpy(), outgoing_p_array.rparticles.transpose(), ) From ded28e35cb00894c89023e6db344140dcf9ef4ee Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Tue, 24 Sep 2024 12:17:46 +0200 Subject: [PATCH 077/111] Complete expected vectorisation test results --- tests/test_vectorized.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index 79bd7027..fed2516f 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -360,3 +360,5 @@ def test_drift_broadcasting_two_different_inputs(ElementClass): outgoing = element.track(incoming) assert outgoing.particles.shape == (3, 2, 100_000, 7) + assert outgoing.particle_charges.shape == (100_000,) + assert outgoing.energy.shape == (3, 2) From 24e7d2ae26922319c1b53a9ed743a04a9de6c423 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Tue, 24 Sep 2024 14:55:09 +0200 Subject: [PATCH 078/111] Correct expected test results for vectorisation with different input shapes --- tests/test_vectorized.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index fed2516f..b6ea91ad 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -361,4 +361,4 @@ def test_drift_broadcasting_two_different_inputs(ElementClass): assert outgoing.particles.shape == (3, 2, 100_000, 7) assert outgoing.particle_charges.shape == (100_000,) - assert outgoing.energy.shape == (3, 2) + assert outgoing.energy.shape == (2,) From 2569502aead506f16d2569754b305fbb09b16342 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Tue, 24 Sep 2024 14:59:56 +0200 Subject: [PATCH 079/111] Add special test case for `Cavity` that affects outgoing beam energy --- tests/test_vectorized.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index b6ea91ad..941dc404 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -351,6 +351,9 @@ def test_drift_broadcasting_two_different_inputs(ElementClass): """ Test that broadcasting rules are correctly applied to a elements with two different input shapes for elements that have a `length` attribute. + + NOTE: Cavity is effectively off and therefore does not affect the outgoing beam + energy and the shape of the latter. """ incoming = cheetah.ParticleBeam.from_parameters( num_particles=100_000, energy=torch.tensor([154e6, 14e9]) @@ -362,3 +365,29 @@ def test_drift_broadcasting_two_different_inputs(ElementClass): assert outgoing.particles.shape == (3, 2, 100_000, 7) assert outgoing.particle_charges.shape == (100_000,) assert outgoing.energy.shape == (2,) + + +def test_drift_broadcasting_two_different_inputs_cavity_with_energy(): + """ + Test that broadcasting rules are correctly applied to a `Cavity` element with two + different input shapes, when it has an effect on the outgoing beam energy. + + NOTE: This is basically the same test as the `Cavity` case of + `test_drift_broadcasting_two_different_inputs` but with an energy change. + """ + incoming = cheetah.ParticleBeam.from_parameters( + num_particles=100_000, energy=torch.tensor([154e6, 14e9]) + ) + element = cheetah.Cavity( + length=torch.tensor([[0.6], [0.5], [0.4]]), + voltage=torch.tensor(48198468.0), + phase=torch.tensor(48198468.0), + frequency=torch.tensor(2.8560e09), + name="my_test_cavity", + ) + + outgoing = element.track(incoming) + + assert outgoing.particles.shape == (3, 2, 100_000, 7) + assert outgoing.particle_charges.shape == (100_000,) + assert outgoing.energy.shape == (3, 2) From 1a6df807bf7aea106c146f874376f7b04587234f Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Tue, 24 Sep 2024 15:21:22 +0200 Subject: [PATCH 080/111] Remove test for cavity affecting energy because it turns out it makes no difference to the shape --- cheetah/accelerator/cavity.py | 6 +++--- cheetah/track_methods.py | 4 ++-- tests/test_vectorized.py | 29 ----------------------------- 3 files changed, 5 insertions(+), 34 deletions(-) diff --git a/cheetah/accelerator/cavity.py b/cheetah/accelerator/cavity.py index eaae0e87..c1b53b17 100644 --- a/cheetah/accelerator/cavity.py +++ b/cheetah/accelerator/cavity.py @@ -228,9 +228,9 @@ def _track_beam(self, incoming: Beam) -> Beam: if isinstance(incoming, ParameterBeam): outgoing = ParameterBeam( - outgoing_mu, - outgoing_cov, - outgoing_energy, + mu=outgoing_mu, + cov=outgoing_cov, + energy=outgoing_energy, total_charge=incoming.total_charge, device=outgoing_mu.device, dtype=outgoing_mu.dtype, diff --git a/cheetah/track_methods.py b/cheetah/track_methods.py index ba2a3ed1..26e17804 100644 --- a/cheetah/track_methods.py +++ b/cheetah/track_methods.py @@ -8,7 +8,7 @@ def rotation_matrix(angle: torch.Tensor) -> torch.Tensor: - """Rotate the transfer map in x-y plane + """Rotate the transfer map in x-y plane. :param angle: Rotation angle in rad, for example `angle = np.pi/2` for vertical = dipole. @@ -102,7 +102,7 @@ def base_rmatrix( def misalignment_matrix( misalignment: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - """Shift the beam for tracking beam through misaligned elements""" + """Shift the beam for tracking beam through misaligned elements.""" device = misalignment.device dtype = misalignment.dtype diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index 941dc404..b6ea91ad 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -351,9 +351,6 @@ def test_drift_broadcasting_two_different_inputs(ElementClass): """ Test that broadcasting rules are correctly applied to a elements with two different input shapes for elements that have a `length` attribute. - - NOTE: Cavity is effectively off and therefore does not affect the outgoing beam - energy and the shape of the latter. """ incoming = cheetah.ParticleBeam.from_parameters( num_particles=100_000, energy=torch.tensor([154e6, 14e9]) @@ -365,29 +362,3 @@ def test_drift_broadcasting_two_different_inputs(ElementClass): assert outgoing.particles.shape == (3, 2, 100_000, 7) assert outgoing.particle_charges.shape == (100_000,) assert outgoing.energy.shape == (2,) - - -def test_drift_broadcasting_two_different_inputs_cavity_with_energy(): - """ - Test that broadcasting rules are correctly applied to a `Cavity` element with two - different input shapes, when it has an effect on the outgoing beam energy. - - NOTE: This is basically the same test as the `Cavity` case of - `test_drift_broadcasting_two_different_inputs` but with an energy change. - """ - incoming = cheetah.ParticleBeam.from_parameters( - num_particles=100_000, energy=torch.tensor([154e6, 14e9]) - ) - element = cheetah.Cavity( - length=torch.tensor([[0.6], [0.5], [0.4]]), - voltage=torch.tensor(48198468.0), - phase=torch.tensor(48198468.0), - frequency=torch.tensor(2.8560e09), - name="my_test_cavity", - ) - - outgoing = element.track(incoming) - - assert outgoing.particles.shape == (3, 2, 100_000, 7) - assert outgoing.particle_charges.shape == (100_000,) - assert outgoing.energy.shape == (3, 2) From 852aef005ea07bf3b931fb567296ee94dcce9794 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Tue, 24 Sep 2024 15:47:22 +0200 Subject: [PATCH 081/111] Fix broadcasting issue in TDC code --- .../transverse_deflecting_cavity.py | 20 +++++++++++++------ cheetah/utils/bmadx.py | 18 +++++++++++------ 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/cheetah/accelerator/transverse_deflecting_cavity.py b/cheetah/accelerator/transverse_deflecting_cavity.py index c2b2e30a..5ac9d6b3 100644 --- a/cheetah/accelerator/transverse_deflecting_cavity.py +++ b/cheetah/accelerator/transverse_deflecting_cavity.py @@ -4,7 +4,7 @@ import torch from matplotlib.patches import Rectangle from scipy.constants import physical_constants, speed_of_light -from torch import Size, nn +from torch import nn from cheetah.particles import Beam, ParticleBeam from cheetah.utils import UniqueNameGenerator, bmadx @@ -170,16 +170,24 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: ) ) - px = px + voltage * torch.sin(phase) + # TODO: Assigning px to px is really bad practice and should be separated into + # two separate variables + px = px + voltage.unsqueeze(-1) * torch.sin(phase) - beta = (1 + pz) * p0c / torch.sqrt(((1 + pz) * p0c) ** 2 + electron_mass_eV**2) + beta = ( + (1 + pz) + * p0c.unsqueeze(-1) + / torch.sqrt(((1 + pz) * p0c.unsqueeze(-1)) ** 2 + electron_mass_eV**2) + ) beta_old = beta - E_old = (1 + pz) * p0c / beta_old - E_new = E_old + voltage * torch.cos(phase) * k_rf * x * p0c + E_old = (1 + pz) * p0c.unsqueeze(-1) / beta_old + E_new = E_old + voltage.unsqueeze(-1) * torch.cos( + phase + ) * k_rf * x * p0c.unsqueeze(-1) pc = torch.sqrt(E_new**2 - electron_mass_eV**2) beta = pc / E_new - pz = (pc - p0c) / p0c + pz = (pc - p0c.unsqueeze(-1)) / p0c.unsqueeze(-1) z = z * beta / beta_old x, y, z = bmadx.track_a_drift( diff --git a/cheetah/utils/bmadx.py b/cheetah/utils/bmadx.py index 75d850be..b5dda62f 100644 --- a/cheetah/utils/bmadx.py +++ b/cheetah/utils/bmadx.py @@ -284,14 +284,16 @@ def track_a_drift( Pxy2 = Px**2 + Py**2 # Particle's transverse mometum^2 over p0^2 Pl = torch.sqrt(1.0 - Pxy2) # Particle's longitudinal momentum over p0 - # z = z + L * ( beta/beta_ref - 1.0/Pl ) but numerically accurate: - dz = length * ( - sqrt_one((mc2**2 * (2 * pz_in + pz_in**2)) / ((p0c * P) ** 2 + mc2**2)) + # z = z + L * ( beta / beta_ref - 1.0 / Pl ) but numerically accurate: + dz = length.unsqueeze(-1) * ( + sqrt_one( + (mc2**2 * (2 * pz_in + pz_in**2)) / ((p0c.unsqueeze(-1) * P) ** 2 + mc2**2) + ) + sqrt_one(-Pxy2) / Pl ) - x_out = x_in + length * Px / Pl - y_out = y_in + length * Py / Pl + x_out = x_in + length.unsqueeze(-1) * Px / Pl + y_out = y_in + length.unsqueeze(-1) * Py / Pl z_out = z_in + dz return x_out, y_out, z_out @@ -299,7 +301,11 @@ def track_a_drift( def particle_rf_time(z, pz, p0c, mc2): """Returns rf time of Particle p.""" - beta = (1 + pz) * p0c / torch.sqrt(((1 + pz) * p0c) ** 2 + mc2**2) + beta = ( + (1 + pz) + * p0c.unsqueeze(-1) + / torch.sqrt(((1 + pz) * p0c.unsqueeze(-1)) ** 2 + mc2**2) + ) time = -z / (beta * speed_of_light) return time From 040942a62b8a026caebccd34e4f5c56589c2c55a Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Tue, 24 Sep 2024 15:52:22 +0200 Subject: [PATCH 082/111] Add test for different dimension inputs to run over all bmadx tracking methods --- tests/test_vectorized.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index b6ea91ad..12cceadb 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -362,3 +362,31 @@ def test_drift_broadcasting_two_different_inputs(ElementClass): assert outgoing.particles.shape == (3, 2, 100_000, 7) assert outgoing.particle_charges.shape == (100_000,) assert outgoing.energy.shape == (2,) + + +@pytest.mark.parametrize( + "ElementClass", + [ + cheetah.Dipole, + cheetah.Drift, + cheetah.Quadrupole, + cheetah.TransverseDeflectingCavity, + ], +) +def test_drift_broadcasting_two_different_inputs_bmadx(ElementClass): + """ + Test that broadcasting rules are correctly applied to a elements with two different + input shapes for elements that have a `"bmadx"` tracking method. + """ + incoming = cheetah.ParticleBeam.from_parameters( + num_particles=100_000, energy=torch.tensor([154e6, 14e9]) + ) + element = ElementClass( + tracking_method="bmadx", length=torch.tensor([[0.6], [0.5], [0.4]]) + ) + + outgoing = element.track(incoming) + + assert outgoing.particles.shape == (3, 2, 100_000, 7) + assert outgoing.particle_charges.shape == (100_000,) + assert outgoing.energy.shape == (2,) From 0d3eb38e0f0c569b3d013c8bd883f3ab8279efef Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Tue, 24 Sep 2024 15:53:08 +0200 Subject: [PATCH 083/111] Fix flake8 warning --- cheetah/accelerator/quadrupole.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cheetah/accelerator/quadrupole.py b/cheetah/accelerator/quadrupole.py index 3386279e..4123121c 100644 --- a/cheetah/accelerator/quadrupole.py +++ b/cheetah/accelerator/quadrupole.py @@ -5,7 +5,7 @@ import torch from matplotlib.patches import Rectangle from scipy.constants import physical_constants -from torch import Size, nn +from torch import nn from ..particles import Beam, ParticleBeam from ..track_methods import base_rmatrix, misalignment_matrix From 15ea7540598c2890b93e85045e0274912187409f Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Tue, 24 Sep 2024 17:35:08 +0200 Subject: [PATCH 084/111] Fix broadcasting issues in elements with `"bmax"` tracking methods --- cheetah/accelerator/dipole.py | 72 +++++++++++++++++++++++++---------- cheetah/accelerator/drift.py | 9 ++++- 2 files changed, 59 insertions(+), 22 deletions(-) diff --git a/cheetah/accelerator/dipole.py b/cheetah/accelerator/dipole.py index fe58c71a..d6aab27c 100644 --- a/cheetah/accelerator/dipole.py +++ b/cheetah/accelerator/dipole.py @@ -184,6 +184,11 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: `ParticleBeam`. :return: Beam exiting the element. """ + # TODO: The renaming of the compinents of `incoming` to just the component name + # makes things hard to read. The resuse and overwriting of those component names + # throughout the function makes it even hard, is bad practice and should really + # be fixed! + # Compute Bmad coordinates and p0c x = incoming.x px = incoming.px @@ -219,9 +224,14 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: z, pz, p0c, electron_mass_eV ) + # Broadcast to align their shapes so that they can be stacked + x, px, y, py, tau, delta = torch.broadcast_tensors(x, px, y, py, tau, delta) + outgoing_beam = ParticleBeam( - torch.stack((x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1), - ref_energy, + particles=torch.stack( + (x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1 + ), + energy=ref_energy, particle_charges=incoming.particle_charges, device=incoming.particles.device, dtype=incoming.particles.dtype, @@ -255,48 +265,70 @@ def _bmadx_body( px_norm = torch.sqrt((1 + pz) ** 2 - py**2) # For simplicity phi1 = torch.arcsin(px / px_norm) g = self.angle / self.length - gp = g / px_norm + gp = g.unsqueeze(-1) / px_norm alpha = ( 2 - * (1 + g * x) - * torch.sin(self.angle + phi1) - * self.length - * bmadx.sinc(self.angle) - - gp * ((1 + g * x) * self.length * bmadx.sinc(self.angle)) ** 2 + * (1 + g.unsqueeze(-1) * x) + * torch.sin(self.angle.unsqueeze(-1) + phi1) + * self.length.unsqueeze(-1) + * bmadx.sinc(self.angle).unsqueeze(-1) + - gp + * ( + (1 + g.unsqueeze(-1) * x) + * self.length.unsqueeze(-1) + * bmadx.sinc(self.angle).unsqueeze(-1) + ) + ** 2 ) - x2_t1 = x * torch.cos(self.angle) + self.length**2 * g * bmadx.cosc(self.angle) + x2_t1 = x * torch.cos(self.angle.unsqueeze(-1)) + self.length.unsqueeze( + -1 + ) ** 2 * g.unsqueeze(-1) * bmadx.cosc(self.angle.unsqueeze(-1)) - x2_t2 = torch.sqrt((torch.cos(self.angle + phi1) ** 2) + gp * alpha) - x2_t3 = torch.cos(self.angle + phi1) + x2_t2 = torch.sqrt( + (torch.cos(self.angle.unsqueeze(-1) + phi1) ** 2) + gp * alpha + ) + x2_t3 = torch.cos(self.angle.unsqueeze(-1) + phi1) c1 = x2_t1 + alpha / (x2_t2 + x2_t3) c2 = x2_t1 + (x2_t2 - x2_t3) / gp - temp = torch.abs(self.angle + phi1) + temp = torch.abs(self.angle.unsqueeze(-1) + phi1) x2 = c1 * (temp < torch.pi / 2) + c2 * (temp >= torch.pi / 2) Lcu = ( - x2 - self.length**2 * g * bmadx.cosc(self.angle) - x * torch.cos(self.angle) + x2 + - self.length.unsqueeze(-1) ** 2 + * g.unsqueeze(-1) + * bmadx.cosc(self.angle.unsqueeze(-1)) + - x * torch.cos(self.angle.unsqueeze(-1)) ) - Lcv = -self.length * bmadx.sinc(self.angle) - x * torch.sin(self.angle) + Lcv = -self.length.unsqueeze(-1) * bmadx.sinc( + self.angle.unsqueeze(-1) + ) - x * torch.sin(self.angle.unsqueeze(-1)) - theta_p = 2 * (self.angle + phi1 - torch.pi / 2 - torch.arctan2(Lcv, Lcu)) + theta_p = 2 * ( + self.angle.unsqueeze(-1) + phi1 - torch.pi / 2 - torch.arctan2(Lcv, Lcu) + ) Lc = torch.sqrt(Lcu**2 + Lcv**2) Lp = Lc / bmadx.sinc(theta_p / 2) - P = p0c * (1 + pz) # In eV + P = p0c.unsqueeze(-1) * (1 + pz) # In eV E = torch.sqrt(P**2 + mc2**2) # In eV E0 = torch.sqrt(p0c**2 + mc2**2) # In eV beta = P / E beta0 = p0c / E0 x_f = x2 - px_f = px_norm * torch.sin(self.angle + phi1 - theta_p) + px_f = px_norm * torch.sin(self.angle.unsqueeze(-1) + phi1 - theta_p) y_f = y + py * Lp / px_norm - z_f = z + (beta * self.length / beta0) - ((1 + pz) * Lp / px_norm) + z_f = ( + z + + (beta * self.length.unsqueeze(-1) / beta0.unsqueeze(-1)) + - ((1 + pz) * Lp / px_norm) + ) return x_f, px_f, y_f, py, z_f, pz @@ -331,8 +363,8 @@ def _bmadx_fringe_linear( hy = -g * torch.tan( e - 2 * f_int * h_gap * g * (1 + torch.sin(e) ** 2) / torch.cos(e) ) - px_f = px + x * hx - py_f = py + y * hy + px_f = px + x * hx.unsqueeze(-1) + py_f = py + y * hy.unsqueeze(-1) return px_f, py_f diff --git a/cheetah/accelerator/drift.py b/cheetah/accelerator/drift.py index 270f923b..4438c376 100644 --- a/cheetah/accelerator/drift.py +++ b/cheetah/accelerator/drift.py @@ -106,9 +106,14 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: z, pz, p0c, electron_mass_eV ) + # Broadcast to align their shapes so that they can be stacked + x, px, y, py, tau, delta = torch.broadcast_tensors(x, px, y, py, tau, delta) + outgoing_beam = ParticleBeam( - torch.stack((x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1), - ref_energy, + particles=torch.stack( + [x, px, y, py, tau, delta, torch.ones_like(x)], dim=-1 + ), + energy=ref_energy, particle_charges=incoming.particle_charges, device=incoming.particles.device, dtype=incoming.particles.dtype, From c5df11a954e5751b008e3011858c34ac1ea19ada Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Tue, 24 Sep 2024 17:52:13 +0200 Subject: [PATCH 085/111] A little cleanup --- tests/test_split.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_split.py b/tests/test_split.py index 868c3f9c..e2d20e96 100644 --- a/tests/test_split.py +++ b/tests/test_split.py @@ -9,7 +9,6 @@ def test_drift_end(): Test that at the end of a split drift the result is the same as at the end of the original drift. """ - original_drift = cheetah.Drift(length=torch.tensor([2.0, 2.5])) split_drift = cheetah.Segment(original_drift.split(resolution=torch.tensor(0.1))) From 6f71b77fac968504280f393b6e35cb35c8bf595b Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 25 Sep 2024 09:22:31 +0200 Subject: [PATCH 086/111] Fix issues with previously existing `Screen` tests --- tests/test_screen.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/tests/test_screen.py b/tests/test_screen.py index b5ebe065..7bc57808 100644 --- a/tests/test_screen.py +++ b/tests/test_screen.py @@ -1,3 +1,4 @@ +import numpy as np import pytest import torch @@ -25,8 +26,9 @@ def test_reading_shows_beam_particle(screen_method): ) beam = cheetah.ParticleBeam.from_astra("tests/resources/ACHIP_EA1_2021.1351.001") - # before tracking the reading should be an empty tensor - assert torch.numel(segment.my_screen.reading) == 0 + assert isinstance(segment.my_screen.reading, torch.Tensor) + assert segment.my_screen.reading.shape == (100, 100) + assert np.allclose(segment.my_screen.reading, 0.0) _ = segment.track(beam) @@ -55,6 +57,10 @@ def test_screen_kde_bandwidth(kde_bandwidth): ) beam = cheetah.ParticleBeam.from_astra("tests/resources/ACHIP_EA1_2021.1351.001") + assert isinstance(segment.my_screen.reading, torch.Tensor) + assert segment.my_screen.reading.shape == (100, 100) + assert np.allclose(segment.my_screen.reading, 0.0) + _ = segment.track(beam) assert isinstance(segment.my_screen.reading, torch.Tensor) @@ -84,6 +90,9 @@ def test_reading_shows_beam_parameter(screen_method): beam = cheetah.ParameterBeam.from_astra("tests/resources/ACHIP_EA1_2021.1351.001") assert isinstance(segment.my_screen.reading, torch.Tensor) + assert segment.my_screen.reading.shape == (100, 100) + assert np.allclose(segment.my_screen.reading, 0.0) + _ = segment.track(beam) assert isinstance(segment.my_screen.reading, torch.Tensor) @@ -140,6 +149,10 @@ def test_reading_shows_beam_ares(screen_method): segment.AREABSCR1.binning = torch.tensor(1, device=segment.AREABSCR1.binning.device) segment.AREABSCR1.is_active = True + assert isinstance(segment.AREABSCR1.reading, torch.Tensor) + assert segment.AREABSCR1.reading.shape == (2040, 2448) + assert np.allclose(segment.AREABSCR1.reading, 0.0) + _ = segment.track(beam) assert isinstance(segment.AREABSCR1.reading, torch.Tensor) From db4c8919a6860f02e98ca55d83be0b49de31c9c7 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 25 Sep 2024 09:24:56 +0200 Subject: [PATCH 087/111] Remove vectorised screen test --- tests/test_screen.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/tests/test_screen.py b/tests/test_screen.py index 7bc57808..0a1a43c3 100644 --- a/tests/test_screen.py +++ b/tests/test_screen.py @@ -101,31 +101,6 @@ def test_reading_shows_beam_parameter(screen_method): assert torch.any(segment.my_screen.reading > 0.0) -def test_reading_shows_beam_parameter_vectorized(): - """ - Test that a screen has a reading that shows some sign of the beam having hit it. - """ - segment = cheetah.Segment( - elements=[ - cheetah.Drift(length=torch.tensor((1.0, 0.5))), - cheetah.Screen( - resolution=torch.tensor((100, 100)), - pixel_size=torch.tensor((1e-5, 1e-5)), - is_active=True, - name="my_screen", - ), - ], - name="my_segment", - ) - beam = cheetah.ParameterBeam.from_astra("tests/resources/ACHIP_EA1_2021.1351.001") - - assert isinstance(segment.my_screen.reading, torch.Tensor) - _ = segment.track(beam) - - with pytest.raises(NotImplementedError): - segment.my_screen.reading - - @pytest.mark.parametrize("screen_method", ["histogram", "kde"]) def test_reading_shows_beam_ares(screen_method): """ From 091ab6a1443e1210aee7a76501a870deb3ce5387 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 25 Sep 2024 09:50:19 +0200 Subject: [PATCH 088/111] Some cleanup in `Screen` code --- cheetah/accelerator/screen.py | 45 +++++++++++++++++++---------------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/cheetah/accelerator/screen.py b/cheetah/accelerator/screen.py index 5d7fd042..8b4019ff 100644 --- a/cheetah/accelerator/screen.py +++ b/cheetah/accelerator/screen.py @@ -25,14 +25,15 @@ class Screen(Element): :param binning: Binning used by the camera. :param misalignment: Misalignment of the screen in meters given as a Tensor `(x, y)`. + :param method: Method used to generate the screen's reading. Can be either + "histogram" or "kde", defaults to "histogram". KDE will be slower but allows + backward differentiation. :param kde_bandwidth: Bandwidth used for the kernel density estimation in meters. Controls the smoothness of the distribution. + :param is_blocking: If `True` the screen is blocking and will stop the beam. :param is_active: If `True` the screen is active and will record the beam's distribution. If `False` the screen is inactive and will not record the beam's distribution. - :param method: Method used to generate the screen's reading. Can be either - "histogram" or "kde", defaults to "histogram". KDE will be slower but allows - backward differentiation. :param name: Unique identifier of the element. NOTE: `method='histogram'` currently does not support vectorisation. Please use @@ -45,9 +46,10 @@ def __init__( pixel_size: Optional[Union[torch.Tensor, nn.Parameter]] = None, binning: Optional[Union[torch.Tensor, nn.Parameter]] = None, misalignment: Optional[Union[torch.Tensor, nn.Parameter]] = None, + method: Literal["histogram", "kde"] = "histogram", kde_bandwidth: Optional[Union[torch.Tensor, nn.Parameter]] = None, + is_blocking: bool = False, is_active: bool = False, - method: Literal["histogram", "kde"] = "histogram", name: Optional[str] = None, device=None, dtype=torch.float32, @@ -104,6 +106,7 @@ def __init__( else torch.clone(self.pixel_size[0]) ), ) + self.is_blocking = is_blocking self.is_active = is_active self.set_read_beam(None) @@ -188,7 +191,7 @@ def track(self, incoming: Beam) -> Beam: self.set_read_beam(copy_of_incoming) - return incoming + return Beam.empty if self.is_blocking else incoming @property def reading(self) -> torch.Tensor: @@ -198,13 +201,10 @@ def reading(self) -> torch.Tensor: read_beam = self.get_read_beam() if read_beam is Beam.empty or read_beam is None: - return torch.tensor([]) + image = torch.zeros( + (int(self.effective_resolution[1]), int(self.effective_resolution[0])) + ) elif isinstance(read_beam, ParameterBeam): - if torch.numel(read_beam._mu[..., 0]) > 1: - raise NotImplementedError( - "Cannot perform batch screen predictions with ParameterBeam" - ) - transverse_mu = torch.stack( [read_beam._mu[..., 0], read_beam._mu[..., 2]], dim=-1 ) @@ -219,9 +219,14 @@ def reading(self) -> torch.Tensor: ], dim=-1, ) - dist = MultivariateNormal( - loc=transverse_mu, covariance_matrix=transverse_cov - ) + dist = [ + MultivariateNormal( + loc=transverse_mu_sample, covariance_matrix=transverse_cov_sample + ) + for transverse_mu_sample, transverse_cov_sample in zip( + transverse_mu.cpu(), transverse_cov.cpu() + ) + ] left = self.extent[0] right = self.extent[1] @@ -235,9 +240,10 @@ def reading(self) -> torch.Tensor: indexing="ij", ) pos = torch.dstack((x, y)) - image = dist.log_prob(pos).exp() + image = torch.stack( + [dist_sample.log_prob(pos).exp() for dist_sample in dist] + ) image = torch.flip(image, dims=[1]) - elif isinstance(read_beam, ParticleBeam): if self.method == "histogram": # Catch vectorisation, which is currently not supported by "histogram" @@ -248,15 +254,14 @@ def reading(self) -> torch.Tensor: ): raise NotImplementedError( "The `'histogram'` method of `Screen` does not support " - "vectorization. Use `'kde'` instead." + "vectorization. Use `'kde'` instead. If this is a feature you " + "would like to see, please open an issue on GitHub." ) image, _ = torch.histogramdd( torch.stack((read_beam.x, read_beam.y)).T, bins=self.pixel_bin_edges ) - image = torch.flipud(image.T) - elif self.method == "kde": image = kde_histogram_2d( x1=read_beam.x, @@ -279,7 +284,7 @@ def get_read_beam(self) -> Beam: # Using these get and set methods instead of Python's property decorator to # prevent `nn.Module` from intercepting the read beam, which is itself an # `nn.Module`, and registering it as a submodule of the screen. - return self._read_beam if self._read_beam is not None else None + return self._read_beam def set_read_beam(self, value: Beam) -> None: # Using these get and set methods instead of Python's property decorator to From b07a619b87d96fc3e784ca2e1531b30ba9332910 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 25 Sep 2024 10:03:57 +0200 Subject: [PATCH 089/111] Fix tests by reinstating some changes to `Screen` with `ParameterBeam` --- cheetah/accelerator/screen.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/cheetah/accelerator/screen.py b/cheetah/accelerator/screen.py index 8b4019ff..81a946eb 100644 --- a/cheetah/accelerator/screen.py +++ b/cheetah/accelerator/screen.py @@ -37,7 +37,8 @@ class Screen(Element): :param name: Unique identifier of the element. NOTE: `method='histogram'` currently does not support vectorisation. Please use - `method=`kde` instead. + `method=`kde` instead. Similarly, `ParameterBeam` can also not be vectorised. + Please use `ParticleBeam` instead. """ def __init__( @@ -205,6 +206,13 @@ def reading(self) -> torch.Tensor: (int(self.effective_resolution[1]), int(self.effective_resolution[0])) ) elif isinstance(read_beam, ParameterBeam): + if torch.numel(read_beam._mu[..., 0]) > 1: + raise NotImplementedError( + "`Screen` does not support vectorization of `ParameterBeam`. " + "Please use `ParticleBeam` instead. If this is a feature you would " + "like to see, please open an issue on GitHub." + ) + transverse_mu = torch.stack( [read_beam._mu[..., 0], read_beam._mu[..., 2]], dim=-1 ) @@ -219,14 +227,9 @@ def reading(self) -> torch.Tensor: ], dim=-1, ) - dist = [ - MultivariateNormal( - loc=transverse_mu_sample, covariance_matrix=transverse_cov_sample - ) - for transverse_mu_sample, transverse_cov_sample in zip( - transverse_mu.cpu(), transverse_cov.cpu() - ) - ] + dist = MultivariateNormal( + loc=transverse_mu, covariance_matrix=transverse_cov + ) left = self.extent[0] right = self.extent[1] @@ -240,9 +243,7 @@ def reading(self) -> torch.Tensor: indexing="ij", ) pos = torch.dstack((x, y)) - image = torch.stack( - [dist_sample.log_prob(pos).exp() for dist_sample in dist] - ) + image = dist.log_prob(pos).exp() image = torch.flip(image, dims=[1]) elif isinstance(read_beam, ParticleBeam): if self.method == "histogram": From 4af663070b5cc6daab8031b73f59b1cba965e4ef Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 25 Sep 2024 10:04:20 +0200 Subject: [PATCH 090/111] Add changelog entry for `Screen.is_blocking` --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d0c7c456..9dd601e6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ This is a major release with significant upgrades under the hood of Cheetah. Des - Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #208, #215, #218, #229, #233) (@jank324, @cr-xu, @hespe, @roussel-ryan) - The fifth particle coordinate `s` is renamed to `tau`. Now Cheetah uses the canonical variables in phase space $(x,px=\frac{P_x}{p_0},y,py, \tau=c\Delta t, \delta=\Delta E/{p_0 c})$. In addition, the trailing "s" was removed from some beam property names (e.g. `beam.xs` becomes `beam.x`). (see #163) (@cr-xu) +- `Screen` no longer blocks the beam (by default). To return to old behaviour, set `Screen.is_blocking = True`. (see #208) (@jank324, @roussel-ryan) ### 🚀 Features From 7f875cf903dfde3ef2588d9713de3818918967cb Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 25 Sep 2024 10:25:23 +0200 Subject: [PATCH 091/111] Benchmark timing of broadcasting and sum against reduce and add --- benchmark_sum_reduce.ipynb | 122 +++++++++++++++++++++++++++++++++++++ 1 file changed, 122 insertions(+) create mode 100644 benchmark_sum_reduce.ipynb diff --git a/benchmark_sum_reduce.ipynb b/benchmark_sum_reduce.ipynb new file mode 100644 index 00000000..18cf353c --- /dev/null +++ b/benchmark_sum_reduce.ipynb @@ -0,0 +1,122 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from functools import reduce\n", + "\n", + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[torch.Size([]), torch.Size([3]), torch.Size([2, 1])]" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "xs = [\n", + " torch.tensor(42.0),\n", + " torch.tensor([1.0, 2.0, 3.0]),\n", + " torch.tensor([[4.0], [5.0]]),\n", + "]\n", + "[x.shape for x in xs]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[torch.Size([2, 3]), torch.Size([2, 3]), torch.Size([2, 3])]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "broadcast_xs = torch.broadcast_tensors(*xs)\n", + "[bx.shape for bx in broadcast_xs]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "9.63 μs ± 16.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n" + ] + } + ], + "source": [ + "%%timeit\n", + "\n", + "broadcast_xs = torch.broadcast_tensors(*xs)\n", + "stacked_xs = torch.stack(broadcast_xs)\n", + "torch.sum(stacked_xs, dim=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1.91 μs ± 12 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n" + ] + } + ], + "source": [ + "%%timeit\n", + "\n", + "reduce(torch.add, xs)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "cheetah-dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From cff1416f71192f18def0bad07170c103bb6ecc45 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 25 Sep 2024 10:27:51 +0200 Subject: [PATCH 092/111] Remove not needed special case for segment of length one --- benchmark_sum_reduce.ipynb | 20 ++++++++++++++++++++ cheetah/accelerator/segment.py | 3 --- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/benchmark_sum_reduce.ipynb b/benchmark_sum_reduce.ipynb index 18cf353c..3c52a4fe 100644 --- a/benchmark_sum_reduce.ipynb +++ b/benchmark_sum_reduce.ipynb @@ -96,6 +96,26 @@ "\n", "reduce(torch.add, xs)" ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(42.)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reduce(torch.add, xs[:1])" + ] } ], "metadata": { diff --git a/cheetah/accelerator/segment.py b/cheetah/accelerator/segment.py index 38e3d5b8..c2aaf65d 100644 --- a/cheetah/accelerator/segment.py +++ b/cheetah/accelerator/segment.py @@ -348,9 +348,6 @@ def is_skippable(self) -> bool: @property def length(self) -> torch.Tensor: - if len(self.elements) == 1: - return self.elements[0].length - lengths = [element.length for element in self.elements] return reduce(torch.add, lengths) From eefc022f17a900ccbaebbe25c3c7081b3aaf1043 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 25 Sep 2024 10:49:22 +0200 Subject: [PATCH 093/111] Fix vectorisation issue in length computation with zero-length elements --- cheetah/accelerator/element.py | 2 +- test_space_charge_kick_length_shape.ipynb | 178 ++++++++++++++++++++++ tests/test_space_charge_kick.py | 2 +- 3 files changed, 180 insertions(+), 2 deletions(-) create mode 100644 test_space_charge_kick_length_shape.ipynb diff --git a/cheetah/accelerator/element.py b/cheetah/accelerator/element.py index 0aba489f..6cbf433a 100644 --- a/cheetah/accelerator/element.py +++ b/cheetah/accelerator/element.py @@ -22,7 +22,7 @@ def __init__(self, name: Optional[str] = None) -> None: super().__init__() self.name = name if name is not None else generate_unique_name() - self.register_buffer("length", torch.zeros((1,))) + self.register_buffer("length", torch.tensor(0.0)) def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: r""" diff --git a/test_space_charge_kick_length_shape.ipynb b/test_space_charge_kick_length_shape.ipynb new file mode 100644 index 00000000..cac60b50 --- /dev/null +++ b/test_space_charge_kick_length_shape.ipynb @@ -0,0 +1,178 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import cheetah" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Segment(elements=ModuleList(\n", + " (0): Drift(length=tensor(1.), tracking_method='cheetah', name='unnamed_element_0')\n", + " (1): SpaceChargeKick(effect_length=tensor(1.), num_grid_points_x=32, num_grid_points_y=32, num_grid_points_tau=32, grid_extend_x=tensor(3.), grid_extend_y=tensor(3.), grid_extend_tau=tensor(3.), name='unnamed_element_1')\n", + "), name='unnamed_element_2')" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "segment = cheetah.Segment(\n", + " [\n", + " cheetah.Drift(length=1.0),\n", + " cheetah.SpaceChargeKick(effect_length=1.0),\n", + " ]\n", + ")\n", + "segment" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "ic| lengths: [tensor(1.), tensor(0.)]\n" + ] + }, + { + "data": { + "text/plain": [ + "tensor(1.)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "segment.length" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ParticleBeam(n=100000, mu_x=tensor(8.2413e-07), mu_px=tensor(5.9885e-08), mu_y=tensor(-1.7276e-06), mu_py=tensor(-1.1746e-07), sigma_x=tensor(0.0002), sigma_px=tensor(3.6794e-06), sigma_y=tensor(0.0002), sigma_py=tensor(3.6941e-06), sigma_tau=tensor(8.0116e-06), sigma_p=tensor(0.0023), energy=tensor(1.0732e+08)) total_charge=tensor(5.0000e-13))" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "beam = cheetah.ParticleBeam.from_astra(\"tests/resources/ACHIP_EA1_2021.1351.001\")\n", + "beam" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "ename": "IndexError", + "evalue": "index -38 is out of bounds for dimension 3 with size 32", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[5], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43msegment\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrack\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbeam\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Documents/DESY/cheetah/cheetah/accelerator/segment.py:380\u001b[0m, in \u001b[0;36mSegment.track\u001b[0;34m(self, incoming)\u001b[0m\n\u001b[1;32m 377\u001b[0m todos[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m.\u001b[39melements\u001b[38;5;241m.\u001b[39mappend(element)\n\u001b[1;32m 379\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m todo \u001b[38;5;129;01min\u001b[39;00m todos:\n\u001b[0;32m--> 380\u001b[0m incoming \u001b[38;5;241m=\u001b[39m \u001b[43mtodo\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrack\u001b[49m\u001b[43m(\u001b[49m\u001b[43mincoming\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 382\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m incoming\n", + "File \u001b[0;32m~/Documents/DESY/cheetah/cheetah/accelerator/space_charge_kick.py:589\u001b[0m, in \u001b[0;36mSpaceChargeKick.track\u001b[0;34m(self, incoming)\u001b[0m\n\u001b[1;32m 587\u001b[0m \u001b[38;5;66;03m# Change coordinates to apply the space charge effect\u001b[39;00m\n\u001b[1;32m 588\u001b[0m xp_coordinates \u001b[38;5;241m=\u001b[39m flattened_incoming\u001b[38;5;241m.\u001b[39mto_xyz_pxpypz()\n\u001b[0;32m--> 589\u001b[0m forces \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_compute_forces\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 590\u001b[0m \u001b[43m \u001b[49m\u001b[43mflattened_incoming\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mxp_coordinates\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcell_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrid_dimensions\u001b[49m\n\u001b[1;32m 591\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 592\u001b[0m xp_coordinates[\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m, \u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m=\u001b[39m xp_coordinates[\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m, \u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m+\u001b[39m forces[\n\u001b[1;32m 593\u001b[0m \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m, \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 594\u001b[0m ] \u001b[38;5;241m*\u001b[39m dt\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 595\u001b[0m xp_coordinates[\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m, \u001b[38;5;241m3\u001b[39m] \u001b[38;5;241m=\u001b[39m xp_coordinates[\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m, \u001b[38;5;241m3\u001b[39m] \u001b[38;5;241m+\u001b[39m forces[\n\u001b[1;32m 596\u001b[0m \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m, \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 597\u001b[0m ] \u001b[38;5;241m*\u001b[39m dt\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n", + "File \u001b[0;32m~/Documents/DESY/cheetah/cheetah/accelerator/space_charge_kick.py:513\u001b[0m, in \u001b[0;36mSpaceChargeKick._compute_forces\u001b[0;34m(self, beam, xp_coordinates, cell_size, grid_dimensions)\u001b[0m\n\u001b[1;32m 505\u001b[0m \u001b[38;5;66;03m# Keep dimensions, and set F to zero if non-valid\u001b[39;00m\n\u001b[1;32m 506\u001b[0m force_indices \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 507\u001b[0m idx_vector,\n\u001b[1;32m 508\u001b[0m torch\u001b[38;5;241m.\u001b[39mclamp(idx_x, \u001b[38;5;28mmax\u001b[39m\u001b[38;5;241m=\u001b[39mgrid_shape[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m),\n\u001b[1;32m 509\u001b[0m torch\u001b[38;5;241m.\u001b[39mclamp(idx_y, \u001b[38;5;28mmax\u001b[39m\u001b[38;5;241m=\u001b[39mgrid_shape[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m),\n\u001b[1;32m 510\u001b[0m torch\u001b[38;5;241m.\u001b[39mclamp(idx_tau, \u001b[38;5;28mmax\u001b[39m\u001b[38;5;241m=\u001b[39mgrid_shape[\u001b[38;5;241m2\u001b[39m] \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m),\n\u001b[1;32m 511\u001b[0m )\n\u001b[0;32m--> 513\u001b[0m Fx_values \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mwhere(valid_mask, \u001b[43mgrad_x\u001b[49m\u001b[43m[\u001b[49m\u001b[43mforce_indices\u001b[49m\u001b[43m]\u001b[49m, \u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m 514\u001b[0m Fy_values \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mwhere(valid_mask, grad_y[force_indices], \u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m 515\u001b[0m Fz_values \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mwhere(\n\u001b[1;32m 516\u001b[0m valid_mask, grad_z[force_indices], \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 517\u001b[0m ) \u001b[38;5;66;03m# (..., 8 * num_particles)\u001b[39;00m\n", + "\u001b[0;31mIndexError\u001b[0m: index -38 is out of bounds for dimension 3 with size 32" + ] + } + ], + "source": [ + "segment.track(beam)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Segment(elements=ModuleList(\n", + " (0): Drift(length=tensor(1.), tracking_method='cheetah', name='unnamed_element_4')\n", + "), name='unnamed_element_5')" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "other_segment = cheetah.Segment([cheetah.Drift(length=1.0)])\n", + "other_segment" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(1.)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "other_segment.length" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "cheetah-dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/test_space_charge_kick.py b/tests/test_space_charge_kick.py index c2c66681..01418d92 100644 --- a/tests/test_space_charge_kick.py +++ b/tests/test_space_charge_kick.py @@ -258,5 +258,5 @@ def test_does_not_break_segment_length(): ] ) - assert segment.length.shape == torch.Size([1]) + assert segment.length.shape == torch.Size([]) assert torch.allclose(segment.length, torch.tensor(1.0)) From 90ffed076b63d432ab1b8de85ff2a3e92a575496 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 25 Sep 2024 10:56:14 +0200 Subject: [PATCH 094/111] Add tests for new error with `SpaceChargeKick` I ran into that I think is the result of vectorisation --- tests/test_space_charge_kick.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/test_space_charge_kick.py b/tests/test_space_charge_kick.py index 01418d92..278aa6ca 100644 --- a/tests/test_space_charge_kick.py +++ b/tests/test_space_charge_kick.py @@ -260,3 +260,20 @@ def test_does_not_break_segment_length(): assert segment.length.shape == torch.Size([]) assert torch.allclose(segment.length, torch.tensor(1.0)) + + +def test_space_charge_with_ares_astra_beam(): + """ + Tests running space charge through a 1m drift with an Astra beam from the ARES + linac. This test is added because running this code would throw an error: + `IndexError: index -38 is out of bounds for dimension 3 with size 32`. + """ + segment = cheetah.Segment( + [ + cheetah.Drift(length=1.0), + cheetah.SpaceChargeKick(effect_length=1.0), + ] + ) + beam = cheetah.ParticleBeam.from_astra("tests/resources/ACHIP_EA1_2021.1351.001") + + _ = segment.track(beam) From d78750c68c837f20731a3fbaf9522562dbee79eb Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 25 Sep 2024 14:54:19 +0200 Subject: [PATCH 095/111] Fix test that failed on space charge code --- cheetah/accelerator/space_charge_kick.py | 56 +++++++++++++---------- cheetah/particles/particle_beam.py | 18 ++++---- test_space_charge_kick_length_shape.ipynb | 51 +++++++++------------ tests/test_space_charge_kick.py | 5 +- 4 files changed, 66 insertions(+), 64 deletions(-) diff --git a/cheetah/accelerator/space_charge_kick.py b/cheetah/accelerator/space_charge_kick.py index 3ad54243..51c97465 100644 --- a/cheetah/accelerator/space_charge_kick.py +++ b/cheetah/accelerator/space_charge_kick.py @@ -505,9 +505,9 @@ def _compute_forces( # Keep dimensions, and set F to zero if non-valid force_indices = ( idx_vector, - torch.clamp(idx_x, max=grid_shape[0] - 1), - torch.clamp(idx_y, max=grid_shape[1] - 1), - torch.clamp(idx_tau, max=grid_shape[2] - 1), + torch.clamp(idx_x, min=0, max=grid_shape[0] - 1), + torch.clamp(idx_y, min=0, max=grid_shape[1] - 1), + torch.clamp(idx_tau, min=0, max=grid_shape[2] - 1), ) Fx_values = torch.where(valid_mask, grad_x[force_indices], 0) @@ -554,19 +554,29 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: # following code. It is reversed at the end of the function. # Make sure that the incoming beam has at least one vector dimension - is_incoming_vectorized = True if len(incoming.particles.shape) == 2: is_incoming_vectorized = False - incoming.particles = incoming.particles.unsqueeze(0) - incoming.energy = incoming.energy.unsqueeze(0) - incoming.particle_charges = incoming.particle_charges.unsqueeze(0) + + vectorized_incoming = ParticleBeam( + particles=incoming.particles.unsqueeze(0), + energy=incoming.energy.unsqueeze(0), + particle_charges=incoming.particle_charges.unsqueeze(0), + device=incoming.particles.device, + dtype=incoming.particles.dtype, + ) + else: + is_incoming_vectorized = True + + vectorized_incoming = incoming flattened_incoming = ParticleBeam( - particles=incoming.particles.flatten(end_dim=-3), - energy=incoming.energy.flatten(end_dim=-1), - particle_charges=incoming.particle_charges.flatten(end_dim=-2), - device=incoming.particles.device, - dtype=incoming.particles.dtype, + particles=vectorized_incoming.particles.flatten(end_dim=-3), + energy=vectorized_incoming.energy.flatten(end_dim=-1), + particle_charges=vectorized_incoming.particle_charges.flatten( + end_dim=-2 + ), + device=vectorized_incoming.particles.device, + dtype=vectorized_incoming.particles.dtype, ) flattened_length_effect = self.effect_length.flatten(end_dim=-1) @@ -600,24 +610,24 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: ] * dt.unsqueeze(-1) if not is_incoming_vectorized: - # Reshape to the original shape + # Reshape to the original non-vectorised shape outgoing = ParticleBeam.from_xyz_pxpypz( xp_coordinates.squeeze(0), - incoming.energy.squeeze(0), - incoming.particle_charges.squeeze(0), - incoming.particles.device, - incoming.particles.dtype, + vectorized_incoming.energy.squeeze(0), + vectorized_incoming.particle_charges.squeeze(0), + vectorized_incoming.particles.device, + vectorized_incoming.particles.dtype, ) else: - # Reverse the flattening + # Reverse the flattening of the vector dimensions outgoing = ParticleBeam.from_xyz_pxpypz( xp_coordinates.unflatten( - dim=0, sizes=incoming.particles.shape[:-2] + dim=0, sizes=vectorized_incoming.particles.shape[:-2] ), - incoming.energy, - incoming.particle_charges, - incoming.particles.device, - incoming.particles.dtype, + vectorized_incoming.energy, + vectorized_incoming.particle_charges, + vectorized_incoming.particles.device, + vectorized_incoming.particles.dtype, ) return outgoing else: diff --git a/cheetah/particles/particle_beam.py b/cheetah/particles/particle_beam.py index 20d20aa7..5f3e0e38 100644 --- a/cheetah/particles/particle_beam.py +++ b/cheetah/particles/particle_beam.py @@ -727,7 +727,7 @@ def transformed_to( @classmethod def from_xyz_pxpypz( cls, - xp_coords: torch.Tensor, + xp_coordinates: torch.Tensor, energy: torch.Tensor, particle_charges: Optional[torch.Tensor] = None, device=None, @@ -739,7 +739,7 @@ def from_xyz_pxpypz( is the moment vector $(x, p_x, y, p_y, z, p_z, 1)$. """ beam = cls( - particles=xp_coords.clone(), + particles=xp_coordinates.clone(), energy=energy, particle_charges=particle_charges, device=device, @@ -753,15 +753,17 @@ def from_xyz_pxpypz( * speed_of_light ) p = torch.sqrt( - xp_coords[..., 1] ** 2 + xp_coords[..., 3] ** 2 + xp_coords[..., 5] ** 2 + xp_coordinates[..., 1] ** 2 + + xp_coordinates[..., 3] ** 2 + + xp_coordinates[..., 5] ** 2 ) gamma = torch.sqrt(1 + (p / (electron_mass * speed_of_light)) ** 2) - beam.particles[..., 1] = xp_coords[..., 1] / p0.unsqueeze(-1) - beam.particles[..., 3] = xp_coords[..., 3] / p0.unsqueeze(-1) - beam.particles[..., 4] = -xp_coords[..., 4] / beam.relativistic_beta.unsqueeze( - -1 - ) + beam.particles[..., 1] = xp_coordinates[..., 1] / p0.unsqueeze(-1) + beam.particles[..., 3] = xp_coordinates[..., 3] / p0.unsqueeze(-1) + beam.particles[..., 4] = -xp_coordinates[ + ..., 4 + ] / beam.relativistic_beta.unsqueeze(-1) beam.particles[..., 5] = (gamma - beam.relativistic_gamma.unsqueeze(-1)) / ( (beam.relativistic_beta * beam.relativistic_gamma).unsqueeze(-1) ) diff --git a/test_space_charge_kick_length_shape.ipynb b/test_space_charge_kick_length_shape.ipynb index cac60b50..e78a63d8 100644 --- a/test_space_charge_kick_length_shape.ipynb +++ b/test_space_charge_kick_length_shape.ipynb @@ -30,10 +30,7 @@ ], "source": [ "segment = cheetah.Segment(\n", - " [\n", - " cheetah.Drift(length=1.0),\n", - " cheetah.SpaceChargeKick(effect_length=1.0),\n", - " ]\n", + " [cheetah.Drift(length=1.0), cheetah.SpaceChargeKick(effect_length=1.0)]\n", ")\n", "segment" ] @@ -43,13 +40,6 @@ "execution_count": 3, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "ic| lengths: [tensor(1.), tensor(0.)]\n" - ] - }, { "data": { "text/plain": [ @@ -92,18 +82,14 @@ "metadata": {}, "outputs": [ { - "ename": "IndexError", - "evalue": "index -38 is out of bounds for dimension 3 with size 32", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[5], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43msegment\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrack\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbeam\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/Documents/DESY/cheetah/cheetah/accelerator/segment.py:380\u001b[0m, in \u001b[0;36mSegment.track\u001b[0;34m(self, incoming)\u001b[0m\n\u001b[1;32m 377\u001b[0m todos[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m.\u001b[39melements\u001b[38;5;241m.\u001b[39mappend(element)\n\u001b[1;32m 379\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m todo \u001b[38;5;129;01min\u001b[39;00m todos:\n\u001b[0;32m--> 380\u001b[0m incoming \u001b[38;5;241m=\u001b[39m \u001b[43mtodo\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrack\u001b[49m\u001b[43m(\u001b[49m\u001b[43mincoming\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 382\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m incoming\n", - "File \u001b[0;32m~/Documents/DESY/cheetah/cheetah/accelerator/space_charge_kick.py:589\u001b[0m, in \u001b[0;36mSpaceChargeKick.track\u001b[0;34m(self, incoming)\u001b[0m\n\u001b[1;32m 587\u001b[0m \u001b[38;5;66;03m# Change coordinates to apply the space charge effect\u001b[39;00m\n\u001b[1;32m 588\u001b[0m xp_coordinates \u001b[38;5;241m=\u001b[39m flattened_incoming\u001b[38;5;241m.\u001b[39mto_xyz_pxpypz()\n\u001b[0;32m--> 589\u001b[0m forces \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_compute_forces\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 590\u001b[0m \u001b[43m \u001b[49m\u001b[43mflattened_incoming\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mxp_coordinates\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcell_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrid_dimensions\u001b[49m\n\u001b[1;32m 591\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 592\u001b[0m xp_coordinates[\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m, \u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m=\u001b[39m xp_coordinates[\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m, \u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m+\u001b[39m forces[\n\u001b[1;32m 593\u001b[0m \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m, \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 594\u001b[0m ] \u001b[38;5;241m*\u001b[39m dt\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 595\u001b[0m xp_coordinates[\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m, \u001b[38;5;241m3\u001b[39m] \u001b[38;5;241m=\u001b[39m xp_coordinates[\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m, \u001b[38;5;241m3\u001b[39m] \u001b[38;5;241m+\u001b[39m forces[\n\u001b[1;32m 596\u001b[0m \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m, \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 597\u001b[0m ] \u001b[38;5;241m*\u001b[39m dt\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n", - "File \u001b[0;32m~/Documents/DESY/cheetah/cheetah/accelerator/space_charge_kick.py:513\u001b[0m, in \u001b[0;36mSpaceChargeKick._compute_forces\u001b[0;34m(self, beam, xp_coordinates, cell_size, grid_dimensions)\u001b[0m\n\u001b[1;32m 505\u001b[0m \u001b[38;5;66;03m# Keep dimensions, and set F to zero if non-valid\u001b[39;00m\n\u001b[1;32m 506\u001b[0m force_indices \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 507\u001b[0m idx_vector,\n\u001b[1;32m 508\u001b[0m torch\u001b[38;5;241m.\u001b[39mclamp(idx_x, \u001b[38;5;28mmax\u001b[39m\u001b[38;5;241m=\u001b[39mgrid_shape[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m),\n\u001b[1;32m 509\u001b[0m torch\u001b[38;5;241m.\u001b[39mclamp(idx_y, \u001b[38;5;28mmax\u001b[39m\u001b[38;5;241m=\u001b[39mgrid_shape[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m),\n\u001b[1;32m 510\u001b[0m torch\u001b[38;5;241m.\u001b[39mclamp(idx_tau, \u001b[38;5;28mmax\u001b[39m\u001b[38;5;241m=\u001b[39mgrid_shape[\u001b[38;5;241m2\u001b[39m] \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m),\n\u001b[1;32m 511\u001b[0m )\n\u001b[0;32m--> 513\u001b[0m Fx_values \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mwhere(valid_mask, \u001b[43mgrad_x\u001b[49m\u001b[43m[\u001b[49m\u001b[43mforce_indices\u001b[49m\u001b[43m]\u001b[49m, \u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m 514\u001b[0m Fy_values \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mwhere(valid_mask, grad_y[force_indices], \u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m 515\u001b[0m Fz_values \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mwhere(\n\u001b[1;32m 516\u001b[0m valid_mask, grad_z[force_indices], \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 517\u001b[0m ) \u001b[38;5;66;03m# (..., 8 * num_particles)\u001b[39;00m\n", - "\u001b[0;31mIndexError\u001b[0m: index -38 is out of bounds for dimension 3 with size 32" - ] + "data": { + "text/plain": [ + "ParticleBeam(n=100000, mu_x=tensor(8.8401e-07), mu_px=tensor(5.9885e-08), mu_y=tensor(-1.8451e-06), mu_py=tensor(-1.1746e-07), sigma_x=tensor(0.0002), sigma_px=tensor(3.8466e-06), sigma_y=tensor(0.0002), sigma_py=tensor(3.8615e-06), sigma_tau=tensor(8.0032e-06), sigma_p=tensor(0.0023), energy=tensor(1.0732e+08)) total_charge=tensor(5.0000e-13))" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -112,18 +98,18 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Segment(elements=ModuleList(\n", - " (0): Drift(length=tensor(1.), tracking_method='cheetah', name='unnamed_element_4')\n", - "), name='unnamed_element_5')" + " (0): Drift(length=tensor(1.), tracking_method='cheetah', name='unnamed_element_3')\n", + "), name='unnamed_element_4')" ] }, - "execution_count": 7, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -135,7 +121,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -144,7 +130,7 @@ "tensor(1.)" ] }, - "execution_count": 8, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -152,6 +138,13 @@ "source": [ "other_segment.length" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/tests/test_space_charge_kick.py b/tests/test_space_charge_kick.py index 278aa6ca..8d91cb32 100644 --- a/tests/test_space_charge_kick.py +++ b/tests/test_space_charge_kick.py @@ -269,10 +269,7 @@ def test_space_charge_with_ares_astra_beam(): `IndexError: index -38 is out of bounds for dimension 3 with size 32`. """ segment = cheetah.Segment( - [ - cheetah.Drift(length=1.0), - cheetah.SpaceChargeKick(effect_length=1.0), - ] + [cheetah.Drift(length=1.0), cheetah.SpaceChargeKick(effect_length=1.0)] ) beam = cheetah.ParticleBeam.from_astra("tests/resources/ACHIP_EA1_2021.1351.001") From 7840113cf446f398af097a47d9d9457c0511d064 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 25 Sep 2024 15:07:27 +0200 Subject: [PATCH 096/111] Clean up vector shape computation in `Undulator` --- cheetah/accelerator/undulator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cheetah/accelerator/undulator.py b/cheetah/accelerator/undulator.py index 85360e0c..c4c72e2c 100644 --- a/cheetah/accelerator/undulator.py +++ b/cheetah/accelerator/undulator.py @@ -47,7 +47,8 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: gamma = energy / electron_mass_eV igamma2 = torch.where(gamma != 0, 1 / gamma**2, torch.zeros_like(gamma)) - vector_shape = torch.broadcast_tensors(self.length, energy)[0].shape + vector_shape = torch.broadcast_shapes(self.length.shape, igamma2.shape) + tm = torch.eye(7, device=device, dtype=dtype).repeat((*vector_shape, 1, 1)) tm[..., 0, 1] = self.length tm[..., 2, 3] = self.length From 3a446f02327b6ab623896377f8cd11e85f29b911 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 25 Sep 2024 15:13:29 +0200 Subject: [PATCH 097/111] Add test that finds issue in the broadcasting when creating a `ParameterBeam` --- tests/test_vectorized.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index 12cceadb..1169c33d 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -390,3 +390,18 @@ def test_drift_broadcasting_two_different_inputs_bmadx(ElementClass): assert outgoing.particles.shape == (3, 2, 100_000, 7) assert outgoing.particle_charges.shape == (100_000,) assert outgoing.energy.shape == (2,) + + +def test_vectorized_parameter_beam_creation(): + """ + Tests that creating a parameter beam with a few vectorised parameters works as + expected. + """ + beam = cheetah.ParameterBeam.from_parameters( + mu_x=torch.tensor([2e-4, 3e-4]), sigma_x=torch.tensor([1e-5, 2e-5]) + ) + + assert beam.mu_x.shape == (2,) + assert torch.allclose(beam.mu_x, torch.tensor([2e-4, 3e-4])) + assert beam.sigma_x.shape == (2,) + assert torch.allclose(beam.sigma_x, torch.tensor([1e-5, 2e-5])) From f2dc6e3d520fd952472bbf1dd164f901e9fcf253 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 25 Sep 2024 15:17:07 +0200 Subject: [PATCH 098/111] Fix issue with broadcasting when creating a `ParameterBeam` --- cheetah/particles/parameter_beam.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/cheetah/particles/parameter_beam.py b/cheetah/particles/parameter_beam.py index f5189b15..ccebd331 100644 --- a/cheetah/particles/parameter_beam.py +++ b/cheetah/particles/parameter_beam.py @@ -80,15 +80,16 @@ def from_parameters( energy = energy if energy is not None else torch.tensor(1e8) total_charge = total_charge if total_charge is not None else torch.tensor(0.0) + mu_x, mu_px, mu_y, mu_py = torch.broadcast_tensors(mu_x, mu_px, mu_y, mu_py) mu = torch.stack( [ mu_x, mu_px, mu_y, mu_py, - torch.tensor(0.0), - torch.tensor(0.0), - torch.tensor(1.0), + torch.zeros_like(mu_x), + torch.zeros_like(mu_x), + torch.ones_like(mu_x), ], dim=-1, ) From cac728dd371799a0cbfa4add170ad27eda4e8203 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 25 Sep 2024 15:20:35 +0200 Subject: [PATCH 099/111] Fix broadcasting issue in `ParticleBeam.transformed_to` --- cheetah/particles/particle_beam.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/cheetah/particles/particle_beam.py b/cheetah/particles/particle_beam.py index 5f3e0e38..afe24238 100644 --- a/cheetah/particles/particle_beam.py +++ b/cheetah/particles/particle_beam.py @@ -678,6 +678,11 @@ def transformed_to( [mu_x, mu_px, mu_y, mu_py, torch.zeros_like(mu_x), torch.zeros_like(mu_x)], dim=-1, ) + sigma_x, sigma_px, sigma_y, sigma_py, sigma_tau, sigma_p = ( + torch.broadcast_tensors( + sigma_x, sigma_px, sigma_y, sigma_py, sigma_tau, sigma_p + ) + ) new_sigma = torch.stack( [sigma_x, sigma_px, sigma_y, sigma_py, sigma_tau, sigma_p], dim=-1 ) From b4ee9ba56462ecc858523a38a823773571eca78c Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 25 Sep 2024 15:24:41 +0200 Subject: [PATCH 100/111] Fix broadcasting issue in `base_rmatrix` --- cheetah/track_methods.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cheetah/track_methods.py b/cheetah/track_methods.py index 26e17804..1f6b6ba1 100644 --- a/cheetah/track_methods.py +++ b/cheetah/track_methods.py @@ -50,8 +50,10 @@ def base_rmatrix( device = length.device dtype = length.dtype - tilt = tilt if tilt is not None else torch.zeros_like(length) - energy = energy if energy is not None else torch.zeros(1) + tilt = tilt if tilt is not None else torch.tensor(0.0, device=device, dtype=dtype) + energy = ( + energy if energy is not None else torch.tensor(0.0, device=device, dtype=dtype) + ) _, igamma2, beta = compute_relativistic_factors(energy) From d4ac307ecfaae70a7638ab0f42fba695b91ba70c Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 25 Sep 2024 15:36:20 +0200 Subject: [PATCH 101/111] Incomplete cleanup of KDE tests --- tests/test_kde.py | 92 +++++++++++++++++++++++------------------------ 1 file changed, 46 insertions(+), 46 deletions(-) diff --git a/tests/test_kde.py b/tests/test_kde.py index 5ba8a61d..11fc6840 100644 --- a/tests/test_kde.py +++ b/tests/test_kde.py @@ -8,8 +8,8 @@ def test_weighted_samples_1d(): """ - Test that the 1d KDE histogram implementation correctly handles - heterogeneously weighted samples. + Test that the 1D KDE histogram implementation correctly handles heterogeneously + weighted samples. """ x_unweighted = torch.tensor([1.0, 1.0, 1.0, 2.0]) x_weighted = torch.tensor([1.0, 2.0]) @@ -30,32 +30,66 @@ def test_weighted_samples_1d(): assert not torch.allclose(hist_weighted, hist_neglect_weights) +def test_weighted_samples_2d(): + """ + Test that the 2D KDE histogram implementation correctly handles heterogeneously + weighted samples. + """ + x_unweighted = torch.tensor([[1.0, 2.0], [1.0, 2.0], [1.0, 2.0], [2.0, 1.0]]) + x_weighted = torch.tensor([[1.0, 2.0], [2.0, 1.0]]) + + bins1 = torch.linspace(0, 3, 10) + bins2 = torch.linspace(0, 3, 10) + sigma = torch.tensor(0.3) + + # Explicitly use all the samples with the same weights + hist_unweighted = kde_histogram_2d( + x_unweighted[:, 0], x_unweighted[:, 1], bins1, bins2, sigma + ) + # Use samples and taking the weights into account + hist_weighted = kde_histogram_2d( + x_weighted[:, 0], + x_weighted[:, 1], + bins1, + bins2, + sigma, + weights=torch.tensor([3.0, 1.0]), + ) + # Use samples but neglect the weights + hist_neglect_weights = kde_histogram_2d( + x_weighted[:, 0], x_weighted[:, 1], bins1, bins2, sigma + ) + + assert torch.allclose(hist_unweighted, hist_weighted) + assert not torch.allclose(hist_weighted, hist_neglect_weights) + + def test_kde_1d(): - # test basic usage + # Test basic usage data = torch.randn(100) # 5 beamline states, 100 particles in 1D - bins = torch.linspace(0, 1, 10) # a single histogram - sigma = torch.tensor(0.1) # a single bandwidth + bins = torch.linspace(0, 1, 10) # A single histogram + sigma = torch.tensor(0.1) # A single bandwidth pdf = kde_histogram_1d(data, bins, sigma) assert pdf.shape == Size([10]) # 5 histograms at 10 points - # test bad bins + # Test bad bins with pytest.raises(ValueError): _kde_marginal_pdf(data, bins, torch.rand(3) + 0.1) def test_kde_1d_vectorized(): - # test basic usage + # Test basic usage data = torch.randn((5, 100)) # 5 beamline states, 100 particles in 1D - bins = torch.linspace(0, 1, 10) # a single histogram - sigma = torch.tensor(0.1) # a single bandwidth + bins = torch.linspace(0, 1, 10) # A single histogram + sigma = torch.tensor(0.1) # A single bandwidth pdf = kde_histogram_1d(data, bins, sigma) assert pdf.shape == Size([5, 10]) # 5 histograms at 10 points - # test bad bins + # Test bad bins with pytest.raises(ValueError): _kde_marginal_pdf(data, bins, torch.rand(3) + 0.1) @@ -66,46 +100,12 @@ def test_kde_2d_vectorized(): # 3 states per diagnostic paths, # 100 particles in 6D space - # two different bins (1 per path) + # Two different bins (1 per path) n = 30 bins_x = torch.linspace(-20, 20, n) - sigma = torch.tensor(0.1) # a single bandwidth + sigma = torch.tensor(0.1) # A single bandwidth pdf = kde_histogram_2d(data[..., 0], data[..., 1], bins_x, bins_x, sigma) assert pdf.shape == Size([3, 2, n, n]) - - -def test_weighted_samples_2d(): - """ - Test that the 2d KDE histogram implementation correctly handles - heterogeneously weighted samples. - """ - x_unweighted = torch.tensor([[1.0, 2.0], [1.0, 2.0], [1.0, 2.0], [2.0, 1.0]]) - x_weighted = torch.tensor([[1.0, 2.0], [2.0, 1.0]]) - - bins1 = torch.linspace(0, 3, 10) - bins2 = torch.linspace(0, 3, 10) - sigma = torch.tensor(0.3) - - # Explicitly use all the samples with the same weights - hist_unweighted = kde_histogram_2d( - x_unweighted[:, 0], x_unweighted[:, 1], bins1, bins2, sigma - ) - # Use samples and taking the weights into account - hist_weighted = kde_histogram_2d( - x_weighted[:, 0], - x_weighted[:, 1], - bins1, - bins2, - sigma, - weights=torch.tensor([3.0, 1.0]), - ) - # Use samples but neglect the weights - hist_neglect_weights = kde_histogram_2d( - x_weighted[:, 0], x_weighted[:, 1], bins1, bins2, sigma - ) - - assert torch.allclose(hist_unweighted, hist_weighted) - assert not torch.allclose(hist_weighted, hist_neglect_weights) From 5a9f560c4c7ab385c9f68ef127b05d262e030eeb Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 25 Sep 2024 15:57:25 +0200 Subject: [PATCH 102/111] Clean up KDE tests --- cheetah/utils/kde.py | 12 +++++------ tests/test_kde.py | 50 ++++++++++++++++++++------------------------ 2 files changed, 29 insertions(+), 33 deletions(-) diff --git a/cheetah/utils/kde.py b/cheetah/utils/kde.py index 3f873ef6..36634fff 100644 --- a/cheetah/utils/kde.py +++ b/cheetah/utils/kde.py @@ -12,11 +12,11 @@ def _kde_marginal_pdf( epsilon: Union[torch.Tensor, float] = 1e-10, ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Calculate the 1D marginal probability distribution function of the input tensor - based on the number of histogram bins. + Compute the 1D marginal probability distribution function of the input tensor based + on the number of histogram bins. :param values: Input tensor with shape :math:`(B, N)`. `B` is the vector shape. - :param bins: Positions of the bins where KDE is calculated. + :param bins: Positions of the bins where KDE is computed. Shape :math:`(N_{bins})`. :param sigma: Gaussian smoothing factor with shape `(1,)`. :param weights: Input data weights of shape :math:`(B, N)`. Default to None. @@ -80,7 +80,7 @@ def _kde_joint_pdf_2d( epsilon: Union[torch.Tensor, float] = 1e-10, ) -> torch.Tensor: """ - Calculate the joint probability distribution function of the input tensors based on + Compute the joint probability distribution function of the input tensors based on the number of histogram bins. :param kernel_values1: shape :math:`(B, N, N_{bins})`. @@ -122,7 +122,7 @@ def kde_histogram_1d( """ Estimate the histogram using KDE of the input tensor. - The calculation uses kernel density estimation which requires a bandwidth + The computation uses kernel density estimation which requires a bandwidth (smoothing) parameter. :param x: Input tensor to compute the histogram with shape :math:`(B, D)`. @@ -163,7 +163,7 @@ def kde_histogram_2d( """ Estimate the 2D histogram of the input tensor. - The calculation uses kernel density estimation which requires a bandwidth + The computation uses kernel density estimation which requires a bandwidth (smoothing) parameter. This is a modified version of the `kornia.enhance.histogram` implementation. diff --git a/tests/test_kde.py b/tests/test_kde.py index 11fc6840..758ecaf1 100644 --- a/tests/test_kde.py +++ b/tests/test_kde.py @@ -3,7 +3,6 @@ from torch import Size from cheetah.utils import kde_histogram_1d, kde_histogram_2d -from cheetah.utils.kde import _kde_marginal_pdf def test_weighted_samples_1d(): @@ -64,8 +63,11 @@ def test_weighted_samples_2d(): assert not torch.allclose(hist_weighted, hist_neglect_weights) -def test_kde_1d(): - # Test basic usage +def test_kde_1d_basic_usage(): + """ + Test that basic usage of the 1D KDE histogram implementation works, and that the + output has the correct shape. + """ data = torch.randn(100) # 5 beamline states, 100 particles in 1D bins = torch.linspace(0, 1, 10) # A single histogram sigma = torch.tensor(0.1) # A single bandwidth @@ -74,38 +76,32 @@ def test_kde_1d(): assert pdf.shape == Size([10]) # 5 histograms at 10 points - # Test bad bins - with pytest.raises(ValueError): - _kde_marginal_pdf(data, bins, torch.rand(3) + 0.1) - -def test_kde_1d_vectorized(): - # Test basic usage - data = torch.randn((5, 100)) # 5 beamline states, 100 particles in 1D +def test_kde_1d_enforce_bins_shape(): + """ + Test that the 1D KDE histogram implementation correctly enforces the shape of the + bins tensor, i.e. throws an error if the bins tensor has the wrong shape. + """ + data = torch.randn(100) # 5 beamline states, 100 particles in 1D bins = torch.linspace(0, 1, 10) # A single histogram - sigma = torch.tensor(0.1) # A single bandwidth - pdf = kde_histogram_1d(data, bins, sigma) - - assert pdf.shape == Size([5, 10]) # 5 histograms at 10 points - - # Test bad bins with pytest.raises(ValueError): - _kde_marginal_pdf(data, bins, torch.rand(3) + 0.1) + kde_histogram_1d(data, bins, torch.rand(3) + 0.1) -def test_kde_2d_vectorized(): +def test_kde_2d_vectorized_basic_usage(): + """ + Test that the 2D KDE histogram implementation correctly handles vectorized inputs, + and that the output has the correct shape. + """ + # 2 diagnostic paths, 3 states per diagnostic path, 100 particles in 6D space data = torch.randn((3, 2, 100, 6)) - # 2 diagnostic paths, - # 3 states per diagnostic paths, - # 100 particles in 6D space - # Two different bins (1 per path) - n = 30 - bins_x = torch.linspace(-20, 20, n) - - sigma = torch.tensor(0.1) # A single bandwidth + num_bins = 30 + bins_x = torch.linspace(-20, 20, num_bins) + # A single bandwidth + sigma = torch.tensor(0.1) pdf = kde_histogram_2d(data[..., 0], data[..., 1], bins_x, bins_x, sigma) - assert pdf.shape == Size([3, 2, n, n]) + assert pdf.shape == Size([3, 2, num_bins, num_bins]) From 499e94a03075ba6a9ae1af8276837285e02b25b9 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 25 Sep 2024 16:02:27 +0200 Subject: [PATCH 103/111] Remove development notebooks --- benchmark_sum_reduce.ipynb | 142 ------------------ test_space_charge_kick_length_shape.ipynb | 171 ---------------------- 2 files changed, 313 deletions(-) delete mode 100644 benchmark_sum_reduce.ipynb delete mode 100644 test_space_charge_kick_length_shape.ipynb diff --git a/benchmark_sum_reduce.ipynb b/benchmark_sum_reduce.ipynb deleted file mode 100644 index 3c52a4fe..00000000 --- a/benchmark_sum_reduce.ipynb +++ /dev/null @@ -1,142 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "from functools import reduce\n", - "\n", - "import torch" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[torch.Size([]), torch.Size([3]), torch.Size([2, 1])]" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "xs = [\n", - " torch.tensor(42.0),\n", - " torch.tensor([1.0, 2.0, 3.0]),\n", - " torch.tensor([[4.0], [5.0]]),\n", - "]\n", - "[x.shape for x in xs]" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[torch.Size([2, 3]), torch.Size([2, 3]), torch.Size([2, 3])]" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "broadcast_xs = torch.broadcast_tensors(*xs)\n", - "[bx.shape for bx in broadcast_xs]" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "9.63 μs ± 16.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n" - ] - } - ], - "source": [ - "%%timeit\n", - "\n", - "broadcast_xs = torch.broadcast_tensors(*xs)\n", - "stacked_xs = torch.stack(broadcast_xs)\n", - "torch.sum(stacked_xs, dim=0)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1.91 μs ± 12 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n" - ] - } - ], - "source": [ - "%%timeit\n", - "\n", - "reduce(torch.add, xs)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor(42.)" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "reduce(torch.add, xs[:1])" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "cheetah-dev", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.5" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/test_space_charge_kick_length_shape.ipynb b/test_space_charge_kick_length_shape.ipynb deleted file mode 100644 index e78a63d8..00000000 --- a/test_space_charge_kick_length_shape.ipynb +++ /dev/null @@ -1,171 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import cheetah" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Segment(elements=ModuleList(\n", - " (0): Drift(length=tensor(1.), tracking_method='cheetah', name='unnamed_element_0')\n", - " (1): SpaceChargeKick(effect_length=tensor(1.), num_grid_points_x=32, num_grid_points_y=32, num_grid_points_tau=32, grid_extend_x=tensor(3.), grid_extend_y=tensor(3.), grid_extend_tau=tensor(3.), name='unnamed_element_1')\n", - "), name='unnamed_element_2')" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "segment = cheetah.Segment(\n", - " [cheetah.Drift(length=1.0), cheetah.SpaceChargeKick(effect_length=1.0)]\n", - ")\n", - "segment" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor(1.)" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "segment.length" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "ParticleBeam(n=100000, mu_x=tensor(8.2413e-07), mu_px=tensor(5.9885e-08), mu_y=tensor(-1.7276e-06), mu_py=tensor(-1.1746e-07), sigma_x=tensor(0.0002), sigma_px=tensor(3.6794e-06), sigma_y=tensor(0.0002), sigma_py=tensor(3.6941e-06), sigma_tau=tensor(8.0116e-06), sigma_p=tensor(0.0023), energy=tensor(1.0732e+08)) total_charge=tensor(5.0000e-13))" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "beam = cheetah.ParticleBeam.from_astra(\"tests/resources/ACHIP_EA1_2021.1351.001\")\n", - "beam" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "ParticleBeam(n=100000, mu_x=tensor(8.8401e-07), mu_px=tensor(5.9885e-08), mu_y=tensor(-1.8451e-06), mu_py=tensor(-1.1746e-07), sigma_x=tensor(0.0002), sigma_px=tensor(3.8466e-06), sigma_y=tensor(0.0002), sigma_py=tensor(3.8615e-06), sigma_tau=tensor(8.0032e-06), sigma_p=tensor(0.0023), energy=tensor(1.0732e+08)) total_charge=tensor(5.0000e-13))" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "segment.track(beam)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Segment(elements=ModuleList(\n", - " (0): Drift(length=tensor(1.), tracking_method='cheetah', name='unnamed_element_3')\n", - "), name='unnamed_element_4')" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "other_segment = cheetah.Segment([cheetah.Drift(length=1.0)])\n", - "other_segment" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor(1.)" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "other_segment.length" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "cheetah-dev", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.5" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} From eea9487ec5bbe7858e8f3ee17c378b0a8a8bf0f8 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Tue, 1 Oct 2024 10:17:15 +0200 Subject: [PATCH 104/111] Restore meaningfulness of Bmad-X quadrupole tracking test for 64-bit --- tests/test_quadrupole.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_quadrupole.py b/tests/test_quadrupole.py index 2b1b3e57..ffda7a77 100644 --- a/tests/test_quadrupole.py +++ b/tests/test_quadrupole.py @@ -176,7 +176,7 @@ def test_quadrupole_bmadx_tracking(dtype): quadrupole = Quadrupole( length=torch.tensor(1.0), k1=torch.tensor(10.0), - misalignment=torch.tensor([0.01, -0.02]), + misalignment=torch.tensor([0.01, -0.02], dtype=dtype), tilt=torch.tensor(0.5), num_steps=10, tracking_method="bmadx", @@ -195,8 +195,8 @@ def test_quadrupole_bmadx_tracking(dtype): assert torch.allclose( outgoing.particles, outgoing_bmadx.to(dtype), - atol=1e-7 if dtype == torch.float64 else 0.00001, - rtol=1e-7 if dtype == torch.float64 else 1e-6, + atol=1e-14 if dtype == torch.float64 else 1e-5, + rtol=1e-14 if dtype == torch.float64 else 1e-6, ) From c874e5e4c67cc0281c5185ee3d8e67d531e47363 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Tue, 1 Oct 2024 10:26:13 +0200 Subject: [PATCH 105/111] Remove too large tolerances from test --- tests/test_quadrupole.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_quadrupole.py b/tests/test_quadrupole.py index ffda7a77..b75cf109 100644 --- a/tests/test_quadrupole.py +++ b/tests/test_quadrupole.py @@ -133,9 +133,7 @@ def test_tilted_quadrupole_multiple_vector_dimensions(): outgoing = segment(incoming) - assert torch.allclose( - outgoing.particles[0, 0], outgoing.particles[0, 1], rtol=1e-1, atol=1e-5 - ) + assert torch.allclose(outgoing.particles[0, 0], outgoing.particles[0, 1]) assert outgoing.particles.shape == (2, 3, 10_000, 7) From 6162c57c8870944d979141b726f568266a735e87 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Tue, 1 Oct 2024 11:35:48 +0200 Subject: [PATCH 106/111] Revectorise test that had its vectorisation mistakenly removed --- tests/test_speed_optimizations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_speed_optimizations.py b/tests/test_speed_optimizations.py index 8d760f78..a6d96edc 100644 --- a/tests/test_speed_optimizations.py +++ b/tests/test_speed_optimizations.py @@ -55,7 +55,7 @@ def test_merged_transfer_maps_tracking_vectorized(): elements=[ cheetah.Drift(length=torch.tensor(0.6)), cheetah.Quadrupole(length=torch.tensor(0.2), k1=torch.tensor(4.2)), - cheetah.Drift(length=torch.tensor(0.4)), + cheetah.Drift(length=torch.linspace(0.3, 0.5, 10)), cheetah.HorizontalCorrector( length=torch.tensor(0.1), angle=torch.tensor(1e-4) ), From 46e8c263e311879f553c03e12d84556f64f9ce84 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Tue, 1 Oct 2024 12:09:15 +0200 Subject: [PATCH 107/111] Clean up quadrupole tests --- tests/test_quadrupole.py | 46 +++++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/tests/test_quadrupole.py b/tests/test_quadrupole.py index b75cf109..5100edfc 100644 --- a/tests/test_quadrupole.py +++ b/tests/test_quadrupole.py @@ -27,7 +27,6 @@ def test_quadrupole_with_misalignments_vectorized(): """ Test that a quadrupole with misalignments behaves as expected. """ - quad_with_misalignment = Quadrupole( length=torch.tensor(1.0), k1=torch.tensor(1.0), @@ -54,29 +53,29 @@ def test_quadrupole_with_misalignments_multiple_vector_dimensions(): Test that a quadrupole with misalignments that have multiple vector dimensions does not raise an error and behaves as expected. """ - - misalignments = torch.randn((4, 3, 2)) quad_with_misalignment = Quadrupole( - length=torch.tensor(1.0), k1=torch.tensor(1.0), misalignment=misalignments + length=torch.tensor(1.0), + k1=torch.tensor(1.0), + misalignment=torch.randn((4, 3, 2)) * 5e-4, ) - quad_without_misalignment = Quadrupole( length=torch.tensor(1.0), k1=torch.tensor(1.0) ) - incoming_beam = ParameterBeam.from_parameters( + + incoming = ParameterBeam.from_parameters( sigma_px=torch.tensor(2e-7), sigma_py=torch.tensor(2e-7) ) - outbeam_quad_with_misalignment = quad_with_misalignment(incoming_beam) - outbeam_quad_without_misalignment = quad_without_misalignment(incoming_beam) + + outgoing_with_misalignment = quad_with_misalignment(incoming) + outgoing_without_misalignment = quad_without_misalignment(incoming) # Check that the misalignment has an effect assert not torch.allclose( - outbeam_quad_with_misalignment.mu_x, - outbeam_quad_without_misalignment.mu_x, + outgoing_with_misalignment.mu_x, outgoing_without_misalignment.mu_x ) # Check that the output shape is correct - assert outbeam_quad_with_misalignment.mu_x.shape == misalignments.shape[:-1] + assert outgoing_with_misalignment.mu_x.shape == (4, 3) def test_tilted_quadrupole_vectorized(): @@ -91,8 +90,8 @@ def test_tilted_quadrupole_vectorized(): segment = Segment( [ Quadrupole( - length=torch.tensor([0.5, 0.5, 0.5]), - k1=torch.tensor([1.0, 1.0, 1.0]), + length=torch.tensor(0.5), + k1=torch.tensor(1.0), tilt=torch.tensor([torch.pi / 4, torch.pi / 2, torch.pi * 5 / 4]), ), Drift(length=torch.tensor(0.5)), @@ -100,10 +99,10 @@ def test_tilted_quadrupole_vectorized(): ) outgoing = segment(incoming) - # Check pi/4 and 5/4*pi rotations is the same for quadrupole + # Check that pi/4 and 5/4*pi rotations is the same for quadrupole assert torch.allclose(outgoing.particles[0], outgoing.particles[2]) - # Check pi/2 rotation is different + # Check that pi/2 rotation is different assert not torch.allclose(outgoing.particles[0], outgoing.particles[1]) @@ -112,15 +111,18 @@ def test_tilted_quadrupole_multiple_vector_dimensions(): Test that a quadrupole with tilts that have multiple vectorisation dimensions does not raise an error and behaves as expected. """ - tilts = torch.tensor( - [ - [torch.pi / 4, torch.pi / 2, torch.pi * 5 / 4], - [torch.pi * 5 / 4, torch.pi / 2, torch.pi / 4], - ] - ) segment = Segment( [ - Quadrupole(length=torch.tensor(0.5), k1=torch.tensor(1.0), tilt=tilts), + Quadrupole( + length=torch.tensor(0.5), + k1=torch.tensor(1.0), + tilt=torch.tensor( + [ + [torch.pi / 4, torch.pi / 2, torch.pi * 5 / 4], + [torch.pi * 5 / 4, torch.pi / 2, torch.pi / 4], + ] + ), + ), Drift(length=torch.tensor(0.5)), ] ) From 5a2862df1967fae0733e7c4f7e8cec892f17e342 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Tue, 1 Oct 2024 12:17:30 +0200 Subject: [PATCH 108/111] Correct assertions in multiple dimensions quadrupole tilt test --- tests/test_quadrupole.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/test_quadrupole.py b/tests/test_quadrupole.py index 5100edfc..11e9b330 100644 --- a/tests/test_quadrupole.py +++ b/tests/test_quadrupole.py @@ -135,9 +135,14 @@ def test_tilted_quadrupole_multiple_vector_dimensions(): outgoing = segment(incoming) - assert torch.allclose(outgoing.particles[0, 0], outgoing.particles[0, 1]) + # Test that shape is correct assert outgoing.particles.shape == (2, 3, 10_000, 7) + # Check that same tilts give same results + assert torch.allclose(outgoing.particles[0, 0], outgoing.particles[1, 2]) + assert torch.allclose(outgoing.particles[0, 1], outgoing.particles[1, 1]) + assert torch.allclose(outgoing.particles[0, 2], outgoing.particles[1, 0]) + def test_quadrupole_length_multiple_vector_dimensions(): """ From 68de5145e91f147bb42c4b95758eb80a257d7b73 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Tue, 1 Oct 2024 12:19:27 +0200 Subject: [PATCH 109/111] Minor code readability improvement --- tests/test_quadrupole.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_quadrupole.py b/tests/test_quadrupole.py index 11e9b330..dbdedac3 100644 --- a/tests/test_quadrupole.py +++ b/tests/test_quadrupole.py @@ -166,7 +166,7 @@ def test_quadrupole_length_multiple_vector_dimensions(): outgoing = segment(incoming) assert outgoing.particles.shape == (2, 3, 10_000, 7) - assert torch.allclose(outgoing.particles[0, -1], outgoing.particles[1, -2]) + assert torch.allclose(outgoing.particles[0, 2], outgoing.particles[1, 1]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) From e1e4b952301f8c1adb65327dca8705c56f4a3d7d Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Tue, 1 Oct 2024 14:06:50 +0200 Subject: [PATCH 110/111] Minor cleanup in vectorisation tests --- tests/test_vectorized.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index 1169c33d..87ebd03e 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -217,7 +217,7 @@ def test_cavity_with_zero_and_non_zero_voltage(BeamClass): """ cavity = cheetah.Cavity( length=torch.tensor(3.0441), - voltage=torch.tensor([0.0, 48198468.0, 0.0]), + voltage=torch.tensor([0.0, 48_198_468.0, 0.0]), phase=torch.tensor(48198468.0), frequency=torch.tensor(2.8560e09), name="my_test_cavity", @@ -299,7 +299,7 @@ def test_vectorized_solenoid(BeamClass): @pytest.mark.parametrize("BeamClass", [cheetah.ParticleBeam]) -@pytest.mark.parametrize("method", ["kde"]) +@pytest.mark.parametrize("method", ["kde"]) # Currently only KDE supports vectorisation def test_vectorized_screen_2d(BeamClass, method): """ Test that a vectorized `Screen` is able to track a particle beam and produce a From d9816db99e125b665163dbdd30cc432a7293385e Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Tue, 1 Oct 2024 15:39:47 +0200 Subject: [PATCH 111/111] Fix missing vector dimension in KDE tests --- tests/test_kde.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_kde.py b/tests/test_kde.py index 758ecaf1..da08b473 100644 --- a/tests/test_kde.py +++ b/tests/test_kde.py @@ -68,13 +68,13 @@ def test_kde_1d_basic_usage(): Test that basic usage of the 1D KDE histogram implementation works, and that the output has the correct shape. """ - data = torch.randn(100) # 5 beamline states, 100 particles in 1D + data = torch.randn((5, 100)) # 5 beamline states, 100 particles in 1D bins = torch.linspace(0, 1, 10) # A single histogram sigma = torch.tensor(0.1) # A single bandwidth pdf = kde_histogram_1d(data, bins, sigma) - assert pdf.shape == Size([10]) # 5 histograms at 10 points + assert pdf.shape == Size([5, 10]) # 5 histograms at 10 points def test_kde_1d_enforce_bins_shape(): @@ -82,7 +82,7 @@ def test_kde_1d_enforce_bins_shape(): Test that the 1D KDE histogram implementation correctly enforces the shape of the bins tensor, i.e. throws an error if the bins tensor has the wrong shape. """ - data = torch.randn(100) # 5 beamline states, 100 particles in 1D + data = torch.randn((5, 100)) # 5 beamline states, 100 particles in 1D bins = torch.linspace(0, 1, 10) # A single histogram with pytest.raises(ValueError):