diff --git a/CHANGELOG.md b/CHANGELOG.md index 3385a850..7996495a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,7 @@ ### 🚨 Breaking Changes -None +- The handling of `device` and `dtype` was overhauled. They might not behave as expected. `Element`s also no longer have a `device` attribute. (see #115) (@jank324) ### 🚀 Features @@ -15,7 +15,7 @@ None - Fix the transfer maps in `Drift` and `Dipole`; Add R56 in horizontal and vertical correctors modelling (see #90) (@cr-xu) - Fix fringe_field_exit of `Dipole` is overwritten by `fringe_field` bug (see #99) (@cr-xu) -- Fix error caused by mismatched devices on machines with CUDA GPUs (see #97) (@jank324) +- Fix error caused by mismatched devices on machines with CUDA GPUs (see #97 and #115) (@jank324) - Fix error raised when tracking a `ParameterBeam` through an active `BPM` (see #101) (@jank324) - Fix error in ASTRA beam import where the energy was set to `float64` instead of `float32` (see #111) (@jank324) - Fix missing passing of `total_charge` in `ParameterBeam.transformed_to` (see #112) (@jank324) diff --git a/cheetah/accelerator.py b/cheetah/accelerator.py index 527e917b..8defa667 100644 --- a/cheetah/accelerator.py +++ b/cheetah/accelerator.py @@ -15,7 +15,6 @@ from cheetah.converters.dontbmad import convert_bmad_lattice from cheetah.converters.nxtables import read_nx_tables -from cheetah.error import DeviceError from cheetah.latticejson import load_cheetah_model, save_cheetah_model from cheetah.particles import Beam, ParameterBeam, ParticleBeam from cheetah.track_methods import base_rmatrix, misalignment_matrix, rotation_matrix @@ -38,19 +37,13 @@ class Element(ABC, nn.Module): Base class for elements of particle accelerators. :param name: Unique identifier of the element. - :param device: Device to move the beam's particle array to. If set to `"auto"` a - CUDA GPU is selected if available. The CPU is used otherwise. """ - def __init__(self, name: Optional[str] = None, device: str = "auto") -> None: + def __init__(self, name: Optional[str] = None) -> None: super().__init__() self.name = name if name is not None else generate_unique_name() - if device == "auto": - device = "cuda" if torch.cuda.is_available() else "cpu" - self.device = device - def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: """ Generates the element's transfer map that describes how the beam and its @@ -82,9 +75,6 @@ def track(self, incoming: Beam) -> Beam: if incoming is Beam.empty: return incoming elif isinstance(incoming, ParameterBeam): - if self.device != incoming.device: - raise DeviceError - tm = self.transfer_map(incoming.energy) mu = torch.matmul(tm, incoming._mu) cov = torch.matmul(tm, torch.matmul(incoming._cov, tm.t())) @@ -92,20 +82,19 @@ def track(self, incoming: Beam) -> Beam: mu, cov, incoming.energy, - device=incoming.device, total_charge=incoming.total_charge, + device=mu.device, + dtype=mu.dtype, ) elif isinstance(incoming, ParticleBeam): - if self.device != incoming.device: - raise DeviceError - tm = self.transfer_map(incoming.energy) new_particles = torch.matmul(incoming.particles, tm.t()) return ParticleBeam( new_particles, incoming.energy, - device=incoming.device, particle_charges=incoming.particle_charges, + device=new_particles.device, + dtype=new_particles.dtype, ) else: raise TypeError(f"Parameter incoming is of invalid type {type(incoming)}") @@ -159,10 +148,7 @@ def plot(self, ax: matplotlib.axes.Axes, s: float) -> None: raise NotImplementedError def __repr__(self) -> str: - return ( - f"{self.__class__.__name__}(name={repr(self.name)}," - f" device={repr(self.device)})" - ) + return f"{self.__class__.__name__}(name={repr(self.name)})" class CustomTransferMap(Element): @@ -175,18 +161,20 @@ def __init__( transfer_map: Union[torch.Tensor, nn.Parameter], length: Optional[torch.Tensor] = None, name: Optional[str] = None, - device: str = "auto", + device=None, + dtype=torch.float32, ) -> None: - super().__init__(name=name, device=device) + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(name=name) assert isinstance(transfer_map, torch.Tensor) assert transfer_map.shape == (7, 7) - self._transfer_map = transfer_map.to(self.device) + self._transfer_map = torch.as_tensor(transfer_map, **factory_kwargs) self.length = ( - length.to(self.device) + torch.as_tensor(length, **factory_kwargs) if length is not None - else torch.tensor(0.0, device=self.device) + else torch.tensor(0.0, **factory_kwargs) ) @classmethod @@ -210,7 +198,10 @@ def from_merging_elements( " incorrect tracking results." ) - tm = torch.eye(7, device=incoming_beam.device) + device = elements[0].transfer_map(incoming_beam.energy).device + dtype = elements[0].transfer_map(incoming_beam.energy).dtype + + tm = torch.eye(7, device=device, dtype=dtype) for element in elements: tm = torch.matmul(element.transfer_map(incoming_beam.energy), tm) incoming_beam = element.track(incoming_beam) @@ -219,7 +210,7 @@ def from_merging_elements( element.length for element in elements if hasattr(element, "length") ) - return cls(tm, length=combined_length, device=elements[0].device) + return cls(tm, length=combined_length, device=device, dtype=dtype) def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: return self._transfer_map @@ -248,28 +239,33 @@ class Drift(Element): :param length: Length in meters. :param name: Unique identifier of the element. - :param device: Device to move the beam's particle array to. If set to `"auto"` a - CUDA GPU is selected if available. The CPU is used otherwise. """ def __init__( self, length: Union[torch.Tensor, nn.Parameter], name: Optional[str] = None, - device: str = "auto", + device=None, + dtype=torch.float32, ) -> None: - super().__init__(name=name, device=device) + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(name=name) - self.length = length.to(self.device) + self.length = torch.as_tensor(length, **factory_kwargs) def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: - gamma = energy / rest_energy.to(self.device) + device = self.length.device + dtype = self.length.dtype + + gamma = energy / rest_energy.to(device=device, dtype=dtype) igamma2 = ( - 1 / gamma**2 if gamma != 0 else torch.tensor(0.0, device=self.device) + 1 / gamma**2 + if gamma != 0 + else torch.tensor(0.0, device=device, dtype=dtype) ) beta = torch.sqrt(1 - igamma2) - tm = torch.eye(7, device=self.device) + tm = torch.eye(7, device=device, dtype=dtype) tm[0, 1] = self.length tm[2, 3] = self.length tm[4, 5] = -self.length / beta**2 * igamma2 @@ -284,7 +280,7 @@ def split(self, resolution: torch.Tensor) -> list[Element]: split_elements = [] remaining = self.length while remaining > 0: - element = Drift(torch.min(resolution, remaining), device=self.device) + element = Drift(torch.min(resolution, remaining)) split_elements.append(element) remaining -= resolution return split_elements @@ -297,10 +293,7 @@ def defining_features(self) -> list[str]: return super().defining_features + ["length"] def __repr__(self) -> str: - return ( - f"{self.__class__.__name__}(length={repr(self.length)}," - f" name={repr(self.name)}, device={repr(self.device)})" - ) + return f"{self.__class__.__name__}(length={repr(self.length)})" class Quadrupole(Element): @@ -313,8 +306,6 @@ class Quadrupole(Element): :param tilt: Tilt angle of the quadrupole in x-y plane [rad]. pi/4 for skew-quadrupole. :param name: Unique identifier of the element. - :param device: Device to move the beam's particle array to. If set to `"auto"` a - CUDA GPU is selected if available. The CPU is used otherwise. """ def __init__( @@ -324,41 +315,45 @@ def __init__( misalignment: Optional[Union[torch.Tensor, nn.Parameter]] = None, tilt: Optional[Union[torch.Tensor, nn.Parameter]] = None, name: Optional[str] = None, - device: str = "auto", + device=None, + dtype=torch.float32, ) -> None: - super().__init__(name=name, device=device) + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(name=name) - self.length = length.to(self.device) + self.length = torch.as_tensor(length, **factory_kwargs) self.k1 = ( - k1.to(self.device) + torch.as_tensor(k1, **factory_kwargs) if k1 is not None - else torch.tensor(0.0, device=self.device) + else torch.tensor(0.0, **factory_kwargs) ) self.misalignment = ( - misalignment.to(self.device) + torch.as_tensor(misalignment, **factory_kwargs) if misalignment is not None - else torch.tensor([0.0, 0.0], device=self.device) + else torch.tensor([0.0, 0.0], **factory_kwargs) ) self.tilt = ( - tilt.to(self.device) + torch.as_tensor(tilt, **factory_kwargs) if tilt is not None - else torch.tensor(0.0, device=self.device) + else torch.tensor(0.0, **factory_kwargs) ) def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: + device = self.length.device + dtype = self.length.dtype + R = base_rmatrix( length=self.length, k1=self.k1, - hx=torch.tensor(0.0, device=self.device), + hx=torch.tensor(0.0, device=device, dtype=dtype), tilt=self.tilt, energy=energy, - device=self.device, ) if self.misalignment[0] == 0 and self.misalignment[1] == 0: return R else: - R_exit, R_entry = misalignment_matrix(self.misalignment, self.device) + R_exit, R_entry = misalignment_matrix(self.misalignment) R = torch.matmul(R_exit, torch.matmul(R, R_entry)) return R @@ -378,7 +373,6 @@ def split(self, resolution: torch.Tensor) -> list[Element]: torch.min(resolution, remaining), self.k1, misalignment=self.misalignment, - device=self.device, ) split_elements.append(element) remaining -= resolution @@ -402,8 +396,7 @@ def __repr__(self) -> None: + f"k1={repr(self.k1)}, " + f"misalignment={repr(self.misalignment)}, " + f"tilt={repr(self.tilt)}, " - + f"name={repr(self.name)}, " - + f'device="{repr(self.device)}")' + + f"name={repr(self.name)})" ) @@ -421,8 +414,6 @@ class Dipole(Element): integral of the exit face. :param gap: The magnet gap [m], NOTE in MAD and ELEGANT: HGAP = gap/2 :param name: Unique identifier of the element. - :param device: Device to move the beam's particle array to. If set to `"auto"` a - CUDA GPU is selected if available. The CPU is used otherwise. """ def __init__( @@ -436,53 +427,55 @@ def __init__( fringe_integral_exit: Optional[Union[torch.Tensor, nn.Parameter]] = None, gap: Optional[Union[torch.Tensor, nn.Parameter]] = None, name: Optional[str] = None, - device: str = "auto", + device=None, + dtype=torch.float32, ): - super().__init__(name=name, device=device) + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(name=name) - self.length = length.to(self.device) + self.length = torch.as_tensor(length, **factory_kwargs) self.angle = ( - angle.to(self.device) + torch.as_tensor(angle, **factory_kwargs) if angle is not None - else torch.tensor(0.0, device=self.device) + else torch.tensor(0.0, **factory_kwargs) ) self.gap = ( - gap.to(self.device) + torch.as_tensor(gap, **factory_kwargs) if gap is not None - else torch.tensor(0.0, device=self.device) + else torch.tensor(0.0, **factory_kwargs) ) self.tilt = ( - tilt.to(self.device) + torch.as_tensor(tilt, **factory_kwargs) if tilt is not None - else torch.tensor(0.0, device=self.device) + else torch.tensor(0.0, **factory_kwargs) ) self.name = name self.fringe_integral = ( - fringe_integral.to(self.device) + torch.as_tensor(fringe_integral, **factory_kwargs) if fringe_integral is not None - else torch.tensor(0.0, device=self.device) + else torch.tensor(0.0, **factory_kwargs) ) self.fringe_integral_exit = ( self.fringe_integral if fringe_integral_exit is None - else fringe_integral_exit.to(self.device) + else torch.as_tensor(fringe_integral_exit, **factory_kwargs) ) # Rectangular bend self.e1 = ( - e1.to(self.device) + torch.as_tensor(e1, **factory_kwargs) if e1 is not None - else torch.tensor(0.0, device=self.device) + else torch.tensor(0.0, **factory_kwargs) ) self.e2 = ( - e2.to(self.device) + torch.as_tensor(e2, **factory_kwargs) if e2 is not None - else torch.tensor(0.0, device=self.device) + else torch.tensor(0.0, **factory_kwargs) ) @property def hx(self) -> torch.Tensor: if self.length == 0.0: - return torch.tensor(0.0, device=self.device) + return torch.tensor(0.0, device=self.length.device, dtype=self.length.dtype) else: return self.angle / self.length @@ -495,20 +488,22 @@ def is_active(self): return self.angle != 0 def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: + device = self.length.device + dtype = self.length.dtype + R_enter = self._transfer_map_enter() R_exit = self._transfer_map_exit() if self.length != 0.0: # Bending magnet with finite length R = base_rmatrix( length=self.length, - k1=torch.tensor(0.0), + k1=torch.tensor(0.0, device=device, dtype=dtype), hx=self.hx, - tilt=torch.tensor(0.0), + tilt=torch.tensor(0.0, device=device, dtype=dtype), energy=energy, - device=self.device, ) else: # Reduce to Thin-Corrector - R = torch.eye(7, device=self.device) + R = torch.eye(7, device=device, dtype=dtype) R[0, 1] = self.length R[2, 6] = self.angle R[2, 3] = self.length @@ -524,7 +519,10 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: def _transfer_map_enter(self) -> torch.Tensor: """Linear transfer map for the entrance face of the dipole magnet.""" - sec_e = torch.tensor(1.0, device=self.device) / torch.cos(self.e1) + device = self.length.device + dtype = self.length.dtype + + sec_e = torch.tensor(1.0, device=device, dtype=dtype) / torch.cos(self.e1) phi = ( self.fringe_integral * self.hx @@ -533,7 +531,7 @@ def _transfer_map_enter(self) -> torch.Tensor: * (1 + torch.sin(self.e1) ** 2) ) - tm = torch.eye(7, device=self.device) + tm = torch.eye(7, device=device, dtype=dtype) tm[1, 0] = self.hx * torch.tan(self.e1) tm[3, 2] = -self.hx * torch.tan(self.e1 - phi) @@ -541,6 +539,9 @@ def _transfer_map_enter(self) -> torch.Tensor: def _transfer_map_exit(self) -> torch.Tensor: """Linear transfer map for the exit face of the dipole magnet.""" + device = self.length.device + dtype = self.length.dtype + sec_e = 1.0 / torch.cos(self.e2) phi = ( self.fringe_integral_exit @@ -550,7 +551,7 @@ def _transfer_map_exit(self) -> torch.Tensor: * (1 + torch.sin(self.e2) ** 2) ) - tm = torch.eye(7, device=self.device) + tm = torch.eye(7, device=device, dtype=dtype) tm[1, 0] = self.hx * torch.tan(self.e2) tm[3, 2] = -self.hx * torch.tan(self.e2 - phi) @@ -571,8 +572,7 @@ def __repr__(self): + f"fringe_integral={repr(self.fringe_integral)}," + f"fringe_integral_exit={repr(self.fringe_integral_exit)}," + f"gap={repr(self.gap)}," - + f"name={repr(self.name)}, " - + f"device={repr(self.device)})" + + f"name={repr(self.name)})" ) @property @@ -612,8 +612,6 @@ class RBend(Dipole): integral of the exit face. :param gap: The magnet gap [m], NOTE in MAD and ELEGANT: HGAP = gap/2 :param name: Unique identifier of the element. - :param device: Device to move the beam's particle array to. If set to `"auto"` a - CUDA GPU is selected if available. The CPU is used otherwise. """ def __init__( @@ -627,7 +625,8 @@ def __init__( fringe_integral_exit: Optional[Union[torch.Tensor, nn.Parameter]] = None, gap: Optional[Union[torch.Tensor, nn.Parameter]] = None, name: Optional[str] = None, - device: str = "auto", + device=None, + dtype=torch.float32, ): angle = angle if angle is not None else torch.tensor(0.0) e1 = e1 if e1 is not None else torch.tensor(0.0) @@ -653,6 +652,7 @@ def __init__( gap=gap, name=name, device=device, + dtype=dtype, ) @@ -665,8 +665,6 @@ class HorizontalCorrector(Element): :param length: Length in meters. :param angle: Particle deflection angle in the horizontal plane in rad. :param name: Unique identifier of the element. - :param device: Device to move the beam's particle array to. If set to `"auto"` a - CUDA GPU is selected if available. The CPU is used otherwise. """ def __init__( @@ -674,25 +672,32 @@ def __init__( length: Union[torch.Tensor, nn.Parameter], angle: Optional[Union[torch.Tensor, nn.Parameter]] = None, name: Optional[str] = None, - device: str = "auto", + device=None, + dtype=torch.float32, ) -> None: - super().__init__(name=name, device=device) + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(name=name) - self.length = length.to(self.device) + self.length = torch.as_tensor(length, **factory_kwargs) self.angle = ( - angle.to(self.device) + torch.as_tensor(angle, **factory_kwargs) if angle is not None - else torch.tensor(0.0, device=self.device) + else torch.tensor(0.0, **factory_kwargs) ) def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: - gamma = energy / rest_energy.to(self.device) + device = self.length.device + dtype = self.length.dtype + + gamma = energy / rest_energy.to(device=device, dtype=dtype) igamma2 = ( - 1 / gamma**2 if gamma != 0 else torch.tensor(0.0, device=self.device) + 1 / gamma**2 + if gamma != 0 + else torch.tensor(0.0, device=device, dtype=dtype) ) beta = torch.sqrt(1 - igamma2) - tm = torch.eye(7, device=self.device) + tm = torch.eye(7, device=device, dtype=dtype) tm[0, 1] = self.length tm[1, 6] = self.angle tm[2, 3] = self.length @@ -713,9 +718,7 @@ def split(self, resolution: torch.Tensor) -> list[Element]: remaining = self.length while remaining > 0: length = torch.min(resolution, remaining) - element = HorizontalCorrector( - length, self.angle * length / self.length, device=self.device - ) + element = HorizontalCorrector(length, self.angle * length / self.length) split_elements.append(element) remaining -= resolution return split_elements @@ -737,8 +740,7 @@ def __repr__(self) -> str: return ( f"{self.__class__.__name__}(length={repr(self.length)}, " + f"angle={repr(self.angle)}, " - + f"name={repr(self.name)}, " - + f"device={repr(self.device)})" + + f"name={repr(self.name)})" ) @@ -751,8 +753,6 @@ class VerticalCorrector(Element): :param length: Length in meters. :param angle: Particle deflection angle in the vertical plane in rad. :param name: Unique identifier of the element. - :param device: Device to move the beam's particle array to. If set to `"auto"` a - CUDA GPU is selected if available. The CPU is used otherwise. """ def __init__( @@ -760,25 +760,32 @@ def __init__( length: Union[torch.Tensor, nn.Parameter], angle: Optional[Union[torch.Tensor, nn.Parameter]] = None, name: Optional[str] = None, - device: str = "auto", + device=None, + dtype=torch.float32, ) -> None: - super().__init__(name=name, device=device) + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(name=name) - self.length = length.to(self.device) + self.length = torch.as_tensor(length, **factory_kwargs) self.angle = ( - angle.to(self.device) + torch.as_tensor(angle, **factory_kwargs) if angle is not None - else torch.tensor(0.0, device=self.device) + else torch.tensor(0.0, **factory_kwargs) ) def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: - gamma = energy / rest_energy.to(self.device) + device = self.length.device + dtype = self.length.dtype + + gamma = energy / rest_energy.to(device=device, dtype=dtype) igamma2 = ( - 1 / gamma**2 if gamma != 0 else torch.tensor(0.0, device=self.device) + 1 / gamma**2 + if gamma != 0 + else torch.tensor(0.0, device=device, dtype=dtype) ) beta = torch.sqrt(1 - igamma2) - tm = torch.eye(7, device=self.device) + tm = torch.eye(7, device=device, dtype=dtype) tm[0, 1] = self.length tm[2, 3] = self.length tm[3, 6] = self.angle @@ -798,9 +805,7 @@ def split(self, resolution: torch.Tensor) -> list[Element]: remaining = self.length while remaining > 0: length = torch.min(resolution, remaining) - element = VerticalCorrector( - length, self.angle * length / self.length, device=self.device - ) + element = VerticalCorrector(length, self.angle * length / self.length) split_elements.append(element) remaining -= resolution return split_elements @@ -822,8 +827,7 @@ def __repr__(self) -> str: return ( f"{self.__class__.__name__}(length={repr(self.length)}, " + f"angle={repr(self.angle)}, " - + f"name={repr(self.name)}, " - + f"device={repr(self.device)})" + + f"name={repr(self.name)})" ) @@ -836,8 +840,6 @@ class Cavity(Element): :param phase: Phase of the cavity in degrees. :param frequency: Frequency of the cavity in Hz. :param name: Unique identifier of the element. - :param device: Device to move the beam's particle array to. If set to `"auto"` a - CUDA GPU is selected if available. The CPU is used otherwise. """ def __init__( @@ -847,25 +849,27 @@ def __init__( phase: Optional[Union[torch.Tensor, nn.Parameter]] = None, frequency: Optional[Union[torch.Tensor, nn.Parameter]] = None, name: Optional[str] = None, - device: str = "auto", + device=None, + dtype=torch.float32, ) -> None: - super().__init__(name=name, device=device) + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(name=name) - self.length = length.to(self.device) + self.length = torch.as_tensor(length, **factory_kwargs) self.voltage = ( - voltage.to(self.device) + torch.as_tensor(voltage, **factory_kwargs) if voltage is not None - else torch.tensor(0.0, device=self.device) + else torch.tensor(0.0, **factory_kwargs) ) self.phase = ( - phase.to(self.device) + torch.as_tensor(phase, **factory_kwargs) if phase is not None - else torch.tensor(0.0, device=self.device) + else torch.tensor(0.0, **factory_kwargs) ) self.frequency = ( - frequency.to(self.device) + torch.as_tensor(frequency, **factory_kwargs) if frequency is not None - else torch.tensor(0.0, device=self.device) + else torch.tensor(0.0, **factory_kwargs) ) @property @@ -877,16 +881,18 @@ def is_skippable(self) -> bool: return not self.is_active def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: + device = self.length.device + dtype = self.length.dtype + if self.voltage > 0: return self._cavity_rmatrix(energy) else: return base_rmatrix( length=self.length, - k1=torch.tensor(0.0, device=self.device), - hx=torch.tensor(0.0, device=self.device), - tilt=torch.tensor(0.0, device=self.device), + k1=torch.tensor(0.0, device=device, dtype=dtype), + hx=torch.tensor(0.0, device=device, dtype=dtype), + tilt=torch.tensor(0.0, device=device, dtype=dtype), energy=energy, - device=self.device, ) def track(self, incoming: Beam) -> Beam: @@ -906,11 +912,14 @@ def track(self, incoming: Beam) -> Beam: raise TypeError(f"Parameter incoming is of invalid type {type(incoming)}") def _track_beam(self, incoming: ParticleBeam) -> ParticleBeam: - beta0 = torch.tensor(1.0, device=self.device) - igamma2 = torch.tensor(0.0, device=self.device) - g0 = torch.tensor(1e10, device=self.device) + device = self.length.device + dtype = self.length.dtype + + beta0 = torch.tensor(1.0, device=device, dtype=dtype) + igamma2 = torch.tensor(0.0, device=device, dtype=dtype) + g0 = torch.tensor(1e10, device=device, dtype=dtype) if incoming.energy != 0: - g0 = incoming.energy / electron_mass_eV.to(self.device) + g0 = incoming.energy / electron_mass_eV.to(device=device, dtype=dtype) igamma2 = 1 / g0**2 beta0 = torch.sqrt(1 - igamma2) @@ -1033,7 +1042,8 @@ def _track_beam(self, incoming: ParticleBeam) -> ParticleBeam: outgoing_cov, outgoing_energy, total_charge=incoming.total_charge, - device=incoming.device, + device=outgoing_mu.device, + dtype=outgoing_mu.dtype, ) return outgoing else: # ParticleBeam @@ -1041,16 +1051,20 @@ def _track_beam(self, incoming: ParticleBeam) -> ParticleBeam: outgoing_particles, outgoing_energy, particle_charges=incoming.particle_charges, - device=incoming.device, + device=outgoing_particles.device, + dtype=outgoing_particles.dtype, ) return outgoing def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor: """Produces an R-matrix for a cavity when it is on, i.e. voltage > 0.0.""" + device = self.length.device + dtype = self.length.dtype + phi = torch.deg2rad(self.phase) delta_energy = self.voltage * torch.cos(phi) # Comment from Ocelot: Pure pi-standing-wave case - eta = torch.tensor(1.0, device=self.device) + eta = torch.tensor(1.0, device=device, dtype=dtype) Ei = energy / electron_mass_eV Ef = (energy + delta_energy) / electron_mass_eV Ep = (Ef - Ei) / self.length # Derivative of the energy @@ -1111,7 +1125,7 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor: r66 = Ei / Ef * beta0 / beta1 r65 = k * torch.sin(phi) * self.voltage / (Ef * beta1 * electron_mass_eV) - R = torch.eye(7, device=self.device) + R = torch.eye(7, device=device, dtype=dtype) R[0, 0] = r11 R[0, 1] = r12 R[1, 0] = r21 @@ -1151,8 +1165,7 @@ def __repr__(self) -> str: + f"voltage={repr(self.voltage)}, " + f"phase={repr(self.voltage)}, " + f"frequency={repr(self.frequency)}, " - + f"name={repr(self.name)}, " - + f"device={repr(self.device)})" + + f"name={repr(self.name)})" ) @@ -1163,14 +1176,10 @@ class BPM(Element): :param is_active: If `True` the BPM is active and will record the beam's position. If `False` the BPM is inactive and will not record the beam's position. :param name: Unique identifier of the element. - :param device: Device to move the beam's particle array to. If set to `"auto"` a - CUDA GPU is selected if available. The CPU is used otherwise. """ - def __init__( - self, is_active: bool = False, name: Optional[str] = None, device: str = "auto" - ) -> None: - super().__init__(name=name, device=device) + def __init__(self, is_active: bool = False, name: Optional[str] = None) -> None: + super().__init__(name=name) self.is_active = is_active self.reading = None @@ -1180,7 +1189,7 @@ def is_skippable(self) -> bool: return not self.is_active def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: - return torch.eye(7, device=self.device) + return torch.eye(7, device=energy.device, dtype=energy.dtype) def track(self, incoming: Beam) -> Beam: if incoming is Beam.empty: @@ -1209,10 +1218,7 @@ def defining_features(self) -> list[str]: return super().defining_features def __repr__(self) -> str: - return ( - f"{self.__class__.__name__}(name={repr(self.name)}," - f" device={repr(self.device)})" - ) + return f"{self.__class__.__name__}(name={repr(self.name)})" class Marker(Element): @@ -1220,15 +1226,13 @@ class Marker(Element): General Marker / Monitor element :param name: Unique identifier of the element. - :param device: Device to move the beam's particle array to. If set to `"auto"` a - CUDA GPU is selected if available. The CPU is used otherwise. """ - def __init__(self, name: Optional[str] = None, device: str = "auto") -> None: - super().__init__(name=name, device=device) + def __init__(self, name: Optional[str] = None) -> None: + super().__init__(name=name) def transfer_map(self, energy): - return torch.eye(7, device=self.device) + return torch.eye(7, device=energy.device, dtype=energy.dtype) def track(self, incoming): # TODO: At some point Markers should be able to be active or inactive. Active @@ -1252,10 +1256,7 @@ def defining_features(self) -> list[str]: return super().defining_features def __repr__(self) -> str: - return ( - f"{self.__class__.__name__}(name={repr(self.name)}," - f" device={repr(self.device)})" - ) + return f"{self.__class__.__name__}(name={repr(self.name)})" class Screen(Element): @@ -1273,8 +1274,6 @@ class Screen(Element): distribution. If `False` the screen is inactive and will not record the beam's distribution. :param name: Unique identifier of the element. - :param device: Device to move the beam's particle array to. If set to `"auto"` a - CUDA GPU is selected if available. The CPU is used otherwise. """ def __init__( @@ -1285,29 +1284,31 @@ def __init__( misalignment: Optional[Union[torch.Tensor, nn.Parameter]] = None, is_active: bool = False, name: Optional[str] = None, - device: str = "auto", + device=None, + dtype=torch.float32, ) -> None: - super().__init__(name=name, device=device) + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(name=name) self.resolution = ( - resolution.to(self.device) + torch.as_tensor(resolution, device=device) if resolution is not None - else torch.tensor((1024, 1024), device=self.device) + else torch.tensor((1024, 1024), device=device) ) self.pixel_size = ( - pixel_size.to(self.device) + torch.as_tensor(pixel_size, **factory_kwargs) if pixel_size is not None - else torch.tensor((1e-3, 1e-3), device=self.device) + else torch.tensor((1e-3, 1e-3), **factory_kwargs) ) self.binning = ( - binning.to(self.device) + torch.as_tensor(binning, device=device) if binning is not None - else torch.tensor(1, device=self.device) + else torch.tensor(1, device=device) ) self.misalignment = ( - misalignment.to(self.device) + torch.as_tensor(misalignment, **factory_kwargs) if misalignment is not None - else torch.tensor((0.0, 0.0), device=self.device) + else torch.tensor((0.0, 0.0), **factory_kwargs) ) self.is_active = is_active @@ -1327,12 +1328,14 @@ def effective_pixel_size(self) -> torch.Tensor: return self.pixel_size * self.binning @property - def extent(self) -> tuple[float, float, float, float]: - return ( - -self.resolution[0] * self.pixel_size[0] / 2, - self.resolution[0] * self.pixel_size[0] / 2, - -self.resolution[1] * self.pixel_size[1] / 2, - self.resolution[1] * self.pixel_size[1] / 2, + def extent(self) -> torch.Tensor: + return torch.stack( + [ + -self.resolution[0] * self.pixel_size[0] / 2, + self.resolution[0] * self.pixel_size[0] / 2, + -self.resolution[1] * self.pixel_size[1] / 2, + self.resolution[1] * self.pixel_size[1] / 2, + ] ) @property @@ -1351,7 +1354,10 @@ def pixel_bin_edges(self) -> tuple[torch.Tensor, torch.Tensor]: ) def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: - return torch.eye(7, device=self.device) + device = self.misalignment.device + dtype = self.misalignment.dtype + + return torch.eye(7, device=device, dtype=dtype) def track(self, incoming: Beam) -> Beam: if self.is_active: @@ -1459,8 +1465,7 @@ def __repr__(self) -> str: + f"binning={repr(self.binning)}, " + f"misalignment={repr(self.misalignment)}, " + f"is_active={repr(self.is_active)}, " - + f"name={repr(self.name)}, " - + f"device={repr(self.device)})" + + f"name={repr(self.name)})" ) @@ -1473,8 +1478,6 @@ class Aperture(Element): :param shape: Shape of the aperture. Can be "rectangular" or "elliptical". :param is_active: If the aperture actually blocks particles. :param name: Unique identifier of the element. - :param device: Device to move the beam's particle array to. If set to `"auto"` a - CUDA GPU is selected if available. The CPU is used otherwise. """ def __init__( @@ -1484,19 +1487,21 @@ def __init__( shape: Literal["rectangular", "elliptical"] = "rectangular", is_active: bool = True, name: Optional[str] = None, - device: str = "auto", + device=None, + dtype=torch.float32, ) -> None: - super().__init__(name=name, device=device) + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(name=name) self.x_max = ( - x_max.to(self.device) + torch.as_tensor(x_max, **factory_kwargs) if x_max is not None - else torch.tensor(float("inf"), device=self.device) + else torch.tensor(float("inf"), **factory_kwargs) ) self.y_max = ( - y_max.to(self.device) + torch.as_tensor(y_max, **factory_kwargs) if y_max is not None - else torch.tensor(float("inf"), device=self.device) + else torch.tensor(float("inf"), **factory_kwargs) ) self.shape = shape self.is_active = is_active @@ -1508,7 +1513,10 @@ def is_skippable(self) -> bool: return not self.is_active def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: - return torch.eye(7, device=self.device) + device = self.x_max.device + dtype = self.x_max.dtype + + return torch.eye(7, device=device, dtype=dtype) def track(self, incoming: Beam) -> Beam: # Only apply aperture to particle beams and if the element is active @@ -1545,7 +1553,8 @@ def track(self, incoming: Beam) -> Beam: outgoing_particles, incoming.energy, particle_charges=outgoing_particle_charges, - device=incoming.device, + device=outgoing_particles.device, + dtype=outgoing_particles.dtype, ) if outgoing_particles.shape[0] > 0 else ParticleBeam.empty @@ -1581,8 +1590,7 @@ def __repr__(self) -> str: + f"y_max={repr(self.y_max)}, " + f"shape={repr(self.shape)}, " + f"is_active={repr(self.is_active)}, " - + f"name={repr(self.name)}, " - + f"device={repr(self.device)})" + + f"name={repr(self.name)})" ) @@ -1596,8 +1604,6 @@ class Undulator(Element): :param is_active: Indicates if the undulator is active or not. Currently has no effect. :param name: Unique identifier of the element. - :param device: Device to move the beam's particle array to. If set to `"auto"` a - CUDA GPU is selected if available. The CPU is used otherwise. """ def __init__( @@ -1605,20 +1611,27 @@ def __init__( length: Union[torch.Tensor, nn.Parameter], is_active: bool = False, name: Optional[str] = None, - device: str = "auto", + device=None, + dtype=torch.float32, ) -> None: - super().__init__(name=name, device=device) + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(name=name) - self.length = length.to(self.device) + self.length = torch.as_tensor(length, **factory_kwargs) self.is_active = is_active def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: - gamma = energy / rest_energy + device = self.length.device + dtype = self.length.dtype + + gamma = energy / rest_energy.to(device=device, dtype=dtype) igamma2 = ( - 1 / gamma**2 if gamma != 0 else torch.tensor(0.0, device=self.device) + 1 / gamma**2 + if gamma != 0 + else torch.tensor(0.0, device=device, dtype=dtype) ) - tm = torch.eye(7, device=self.device) + tm = torch.eye(7, device=device, dtype=dtype) tm[0, 1] = self.length tm[2, 3] = self.length tm[4, 5] = self.length * igamma2 @@ -1650,8 +1663,7 @@ def __repr__(self) -> str: return ( f"{self.__class__.__name__}(length={repr(self.length)}, " + f"is_active={repr(self.is_active)}, " - + f"name={repr(self.name)}, " - + f"device={repr(self.device)})" + + f"name={repr(self.name)})" ) @@ -1667,8 +1679,6 @@ class Solenoid(Element): :param misalignment: Misalignment vector of the solenoid magnet in x- and y-directions. :param name: Unique identifier of the element. - :param device: Device to move the beam's particle array to. If set to `"auto"` a - CUDA GPU is selected if available. The CPU is used otherwise. """ def __init__( @@ -1677,41 +1687,46 @@ def __init__( k: Optional[Union[torch.Tensor, nn.Parameter]] = None, misalignment: Optional[Union[torch.Tensor, nn.Parameter]] = None, name: Optional[str] = None, - device: str = "auto", + device=None, + dtype=torch.float32, ) -> None: - super().__init__(name=name, device=device) + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(name=name) self.length = ( - length.to(self.device) + torch.as_tensor(length, **factory_kwargs) if length is not None - else torch.tensor(0.0, device=self.device) + else torch.tensor(0.0, **factory_kwargs) ) self.k = ( - k.to(self.device) + torch.as_tensor(k, **factory_kwargs) if k is not None - else torch.tensor(0.0, device=self.device) + else torch.tensor(0.0, **factory_kwargs) ) self.misalignment = ( - misalignment.to(self.device) + torch.as_tensor(misalignment, **factory_kwargs) if misalignment is not None - else torch.tensor((0.0, 0.0), device=self.device) + else torch.tensor((0.0, 0.0), **factory_kwargs) ) def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: - gamma = energy / rest_energy + device = self.length.device + dtype = self.length.dtype + + gamma = energy / rest_energy.to(device=device, dtype=dtype) c = torch.cos(self.length * self.k) s = torch.sin(self.length * self.k) if self.k == 0: s_k = self.length else: s_k = s / self.k - r56 = torch.tensor(0.0, device=self.device) + r56 = torch.tensor(0.0, device=device, dtype=dtype) if gamma != 0: gamma2 = gamma * gamma beta = torch.sqrt(1.0 - 1.0 / gamma2) r56 -= self.length / (beta * beta * gamma2) - R = torch.eye(7, device=self.device) + R = torch.eye(7, device=device, dtype=dtype) R[0, 0] = c**2 R[0, 1] = c * s_k R[0, 2] = s * c @@ -1735,7 +1750,7 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: if self.misalignment[0] == 0 and self.misalignment[1] == 0: return R else: - R_exit, R_entry = misalignment_matrix(self.misalignment, self.device) + R_exit, R_entry = misalignment_matrix(self.misalignment) R = torch.matmul(R_exit, torch.matmul(R, R_entry)) return R @@ -1768,8 +1783,7 @@ def __repr__(self) -> str: f"{self.__class__.__name__}(length={repr(self.length)}, " + f"k={repr(self.k)}, " + f"misalignment={repr(self.misalignment)}, " - + f"name={repr(self.name)}, " - + f"device={repr(self.device)})" + + f"name={repr(self.name)})" ) @@ -1779,20 +1793,14 @@ class Segment(Element): :param cell: List of Cheetah elements that describe an accelerator (section). :param name: Unique identifier of the element. - :param device: Device to move the beam's particle array to. If set to `"auto"` a - CUDA GPU is selected if available. The CPU is used otherwise. """ - def __init__( - self, elements: list[Element], name: str = "unnamed", device: str = "auto" - ) -> None: - super().__init__(name=name, device=device) + def __init__(self, elements: list[Element], name: str = "unnamed") -> None: + super().__init__(name=name) self.elements = nn.ModuleList(elements) for element in self.elements: - element.device = self.device - # Make elements accessible via .name attribute. If multiple elements have # the same name, they are accessible via a list. if element.name in self.__dict__: @@ -1803,7 +1811,7 @@ def __init__( else: self.__dict__[element.name] = element - def subcell(self, start: str, end: str, **kwargs) -> "Segment": + def subcell(self, start: str, end: str) -> "Segment": """Extract a subcell `[start, end]` from an this segment.""" subcell = [] is_in_subcell = False @@ -1815,7 +1823,7 @@ def subcell(self, start: str, end: str, **kwargs) -> "Segment": if element.name == end: break - return self.__class__(subcell, device=self.device, **kwargs) + return self.__class__(subcell) def flattened(self) -> "Segment": """ @@ -1829,7 +1837,7 @@ def flattened(self) -> "Segment": else: flattened_elements.append(element) - return Segment(elements=flattened_elements, name=self.name, device=self.device) + return Segment(elements=flattened_elements, name=self.name) def transfer_maps_merged( self, incoming_beam: Beam, except_for: Optional[list[str]] = None @@ -1879,7 +1887,7 @@ def transfer_maps_merged( ) ) - return Segment(elements=merged_elements, name=self.name, device=self.device) + return Segment(elements=merged_elements, name=self.name) def without_inactive_markers( self, except_for: Optional[list[str]] = None @@ -1906,7 +1914,6 @@ def without_inactive_markers( if not isinstance(element, Marker) or element.name in except_for ], name=self.name, - device=self.device, ) def without_inactive_zero_length_elements( @@ -1935,7 +1942,6 @@ def without_inactive_zero_length_elements( or element.name in except_for ], name=self.name, - device=self.device, ) def inactive_elements_as_drifts( @@ -1967,7 +1973,6 @@ def inactive_elements_as_drifts( for element in self.elements ], name=self.name, - device=self.device, ) @classmethod @@ -2000,7 +2005,13 @@ def to_lattice_json( @classmethod def from_ocelot( - cls, cell, name: Optional[str] = None, warnings: bool = True, **kwargs + cls, + cell, + name: Optional[str] = None, + warnings: bool = True, + device=None, + dtype=torch.float32, + **kwargs, ) -> "Segment": """ Translate an Ocelot cell to a Cheetah `Segment`. @@ -2019,7 +2030,10 @@ def from_ocelot( """ from cheetah.converters.nocelot import ocelot2cheetah - converted = [ocelot2cheetah(element, warnings=warnings) for element in cell] + converted = [ + ocelot2cheetah(element, warnings=warnings, device=device, dtype=dtype) + for element in cell + ] return cls(converted, name=name, **kwargs) @classmethod @@ -2063,7 +2077,7 @@ def is_skippable(self) -> bool: return all(element.is_skippable for element in self.elements) @property - def length(self) -> float: + def length(self) -> torch.Tensor: lengths = torch.stack( [element.length for element in self.elements if hasattr(element, "length")] ) @@ -2071,7 +2085,7 @@ def length(self) -> float: def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: if self.is_skippable: - tm = torch.eye(7, dtype=torch.float32, device=self.device) + tm = torch.eye(7, device=energy.device, dtype=energy.dtype) for element in self.elements: tm = torch.matmul(element.transfer_map(energy), tm) return tm @@ -2087,7 +2101,7 @@ def track(self, incoming: Beam) -> Beam: if not element.is_skippable: todos.append(element) elif not todos or not todos[-1].is_skippable: - todos.append(Segment([element], device=self.device)) + todos.append(Segment([element])) else: todos[-1].elements.append(element) @@ -2275,6 +2289,5 @@ def plot_twiss_over_lattice(self, beam: Beam, figsize=(8, 4)) -> None: def __repr__(self) -> str: return ( f"{self.__class__.__name__}(elements={repr(self.elements)}, " - + f"name={repr(self.name)}, " - + f"device={repr(self.device)})" + + f"name={repr(self.name)})" ) diff --git a/cheetah/converters/nocelot.py b/cheetah/converters/nocelot.py index f47180d5..6b5f9385 100644 --- a/cheetah/converters/nocelot.py +++ b/cheetah/converters/nocelot.py @@ -3,7 +3,9 @@ import cheetah -def ocelot2cheetah(element, warnings: bool = True) -> "cheetah.Element": +def ocelot2cheetah( + element, warnings: bool = True, device=None, dtype=torch.float32 +) -> "cheetah.Element": """ Translate an Ocelot element to a Cheetah element. @@ -29,31 +31,42 @@ def ocelot2cheetah(element, warnings: bool = True) -> "cheetah.Element": if isinstance(element, ocelot.Drift): return cheetah.Drift( - length=torch.tensor(element.l, dtype=torch.float32), name=element.id + length=torch.tensor(element.l, dtype=torch.float32), + name=element.id, + device=device, + dtype=dtype, ) elif isinstance(element, ocelot.Quadrupole): return cheetah.Quadrupole( length=torch.tensor(element.l, dtype=torch.float32), k1=torch.tensor(element.k1, dtype=torch.float32), name=element.id, + device=device, + dtype=dtype, ) elif isinstance(element, ocelot.Solenoid): return cheetah.Solenoid( length=torch.tensor(element.l, dtype=torch.float32), k=torch.tensor(element.k, dtype=torch.float32), name=element.id, + device=device, + dtype=dtype, ) elif isinstance(element, ocelot.Hcor): return cheetah.HorizontalCorrector( length=torch.tensor(element.l, dtype=torch.float32), angle=torch.tensor(element.angle, dtype=torch.float32), name=element.id, + device=device, + dtype=dtype, ) elif isinstance(element, ocelot.Vcor): return cheetah.VerticalCorrector( length=torch.tensor(element.l, dtype=torch.float32), angle=torch.tensor(element.angle, dtype=torch.float32), name=element.id, + device=device, + dtype=dtype, ) elif isinstance(element, ocelot.Bend): return cheetah.Dipole( @@ -66,6 +79,8 @@ def ocelot2cheetah(element, warnings: bool = True) -> "cheetah.Element": fringe_integral_exit=torch.tensor(element.fintx, dtype=torch.float32), gap=torch.tensor(element.gap, dtype=torch.float32), name=element.id, + device=device, + dtype=dtype, ) elif isinstance(element, ocelot.SBend): return cheetah.Dipole( @@ -78,6 +93,8 @@ def ocelot2cheetah(element, warnings: bool = True) -> "cheetah.Element": fringe_integral_exit=torch.tensor(element.fintx, dtype=torch.float32), gap=torch.tensor(element.gap, dtype=torch.float32), name=element.id, + device=device, + dtype=dtype, ) elif isinstance(element, ocelot.RBend): return cheetah.RBend( @@ -90,6 +107,8 @@ def ocelot2cheetah(element, warnings: bool = True) -> "cheetah.Element": fringe_integral_exit=torch.tensor(element.fintx, dtype=torch.float32), gap=torch.tensor(element.gap, dtype=torch.float32), name=element.id, + device=device, + dtype=dtype, ) elif isinstance(element, ocelot.Cavity): return cheetah.Cavity( @@ -97,6 +116,9 @@ def ocelot2cheetah(element, warnings: bool = True) -> "cheetah.Element": voltage=torch.tensor(element.v, dtype=torch.float32) * 1e9, frequency=torch.tensor(element.freq, dtype=torch.float32), phase=torch.tensor(element.phi, dtype=torch.float32), + name=element.id, + device=device, + dtype=dtype, ) elif isinstance(element, ocelot.TDCavity): # TODO: Better replacement at some point? @@ -105,6 +127,9 @@ def ocelot2cheetah(element, warnings: bool = True) -> "cheetah.Element": voltage=torch.tensor(element.v, dtype=torch.float32) * 1e9, frequency=torch.tensor(element.freq, dtype=torch.float32), phase=torch.tensor(element.phi, dtype=torch.float32), + name=element.id, + device=device, + dtype=dtype, ) elif isinstance(element, ocelot.Monitor) and ("BSC" in element.id): # NOTE This pattern is very specific to ARES and will need a more complex @@ -118,6 +143,8 @@ def ocelot2cheetah(element, warnings: bool = True) -> "cheetah.Element": resolution=torch.tensor([2448, 2040]), pixel_size=torch.tensor([3.5488e-6, 2.5003e-6]), name=element.id, + device=device, + dtype=dtype, ) elif isinstance(element, ocelot.Monitor) and "BPM" in element.id: return cheetah.BPM(name=element.id) @@ -127,7 +154,10 @@ def ocelot2cheetah(element, warnings: bool = True) -> "cheetah.Element": return cheetah.Marker(name=element.id) elif isinstance(element, ocelot.Undulator): return cheetah.Undulator( - torch.tensor(element.l, dtype=torch.float32), name=element.id + torch.tensor(element.l, dtype=torch.float32), + name=element.id, + device=device, + dtype=dtype, ) elif isinstance(element, ocelot.Aperture): shape_translation = {"rect": "rectangular", "elip": "elliptical"} @@ -137,6 +167,8 @@ def ocelot2cheetah(element, warnings: bool = True) -> "cheetah.Element": shape=shape_translation[element.type], is_active=True, name=element.id, + device=device, + dtype=dtype, ) else: if warnings: @@ -145,7 +177,10 @@ def ocelot2cheetah(element, warnings: bool = True) -> "cheetah.Element": " replacing with drift section." ) return cheetah.Drift( - length=torch.tensor(element.l, dtype=torch.float32), name=element.id + length=torch.tensor(element.l, dtype=torch.float32), + name=element.id, + device=device, + dtype=dtype, ) diff --git a/cheetah/error.py b/cheetah/error.py deleted file mode 100644 index bdd8d6e3..00000000 --- a/cheetah/error.py +++ /dev/null @@ -1,11 +0,0 @@ -class DeviceError(Exception): - """ - Used to create an exception, in case the device used for the beam - and the elements are different. - """ - - def __init__(self): - super().__init__( - "Warning! The device used for calculating the elements is not the same, " - "as the device used to calculate the Beam." - ) diff --git a/cheetah/particles.py b/cheetah/particles.py index bcb58b0f..e31e31de 100644 --- a/cheetah/particles.py +++ b/cheetah/particles.py @@ -303,19 +303,21 @@ def __init__( cov: torch.Tensor, energy: torch.Tensor, total_charge: Optional[torch.Tensor] = None, - device: str = "auto", + device=None, + dtype=torch.float32, ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() - if device == "auto": - device = "cuda" if torch.cuda.is_available() else "cpu" - self.device = device - - self._mu = mu.to(device) - self._cov = cov.to(device) - total_charge = total_charge if total_charge is not None else torch.tensor(0.0) - self.total_charge = total_charge.to(device) - self.energy = energy + self._mu = torch.as_tensor(mu, **factory_kwargs) + self._cov = torch.as_tensor(cov, **factory_kwargs) + total_charge = ( + total_charge + if total_charge is not None + else torch.tensor(0.0, **factory_kwargs) + ) + self.total_charge = torch.as_tensor(total_charge, **factory_kwargs) + self.energy = torch.as_tensor(energy, **factory_kwargs) @classmethod def from_parameters( @@ -335,7 +337,8 @@ def from_parameters( cor_s: Optional[torch.Tensor] = None, energy: Optional[torch.Tensor] = None, total_charge: Optional[torch.Tensor] = None, - device: str = "auto", + device=None, + dtype=torch.float32, ) -> "ParameterBeam": # Set default values without function call in function signature mu_x = mu_x if mu_x is not None else torch.tensor(0.0) @@ -398,7 +401,8 @@ def from_twiss( cor_s: Optional[torch.Tensor] = None, energy: Optional[torch.Tensor] = None, total_charge: Optional[torch.Tensor] = None, - device: str = "auto", + device=None, + dtype=torch.float32, ) -> "ParameterBeam": # Set default values without function call in function signature beta_x = beta_x if beta_x is not None else torch.tensor(0.0) @@ -435,7 +439,7 @@ def from_twiss( ) @classmethod - def from_ocelot(cls, parray, device: str = "auto") -> "ParameterBeam": + def from_ocelot(cls, parray, device=None, dtype=torch.float32) -> "ParameterBeam": """Load an Ocelot ParticleArray `parray` as a Cheetah Beam.""" mu = torch.ones(7) mu[:6] = torch.tensor(parray.rparticles.mean(axis=1), dtype=torch.float32) @@ -447,11 +451,16 @@ def from_ocelot(cls, parray, device: str = "auto") -> "ParameterBeam": total_charge = torch.tensor(np.sum(parray.q_array), dtype=torch.float32) return cls( - mu=mu, cov=cov, energy=energy, total_charge=total_charge, device=device + mu=mu, + cov=cov, + energy=energy, + total_charge=total_charge, + device=device, + dtype=dtype, ) @classmethod - def from_astra(cls, path: str, **kwargs) -> "ParameterBeam": + def from_astra(cls, path: str, device=None, dtype=torch.float32) -> "ParameterBeam": """Load an Astra particle distribution as a Cheetah Beam.""" from cheetah.converters.astralavista import from_astrabeam @@ -469,7 +478,8 @@ def from_astra(cls, path: str, **kwargs) -> "ParameterBeam": cov=cov, energy=torch.tensor(energy, dtype=torch.float32), total_charge=total_charge, - **kwargs, + device=device, + dtype=dtype, ) def transformed_to( @@ -486,6 +496,8 @@ def transformed_to( sigma_p: Optional[torch.Tensor] = None, energy: Optional[torch.Tensor] = None, total_charge: Optional[torch.Tensor] = None, + device=None, + dtype=torch.float32, ) -> "ParameterBeam": """ Create version of this beam that is transformed to new beam parameters. @@ -504,6 +516,9 @@ def transformed_to( :param energy: Energy of the beam in eV. :param total_charge: Total charge of the beam in C. """ + device = device if device is not None else self.mu_x.device + dtype = dtype if dtype is not None else self.mu_x.dtype + mu_x = mu_x if mu_x is not None else self.mu_x mu_xp = mu_xp if mu_xp is not None else self.mu_xp mu_y = mu_y if mu_y is not None else self.mu_y @@ -530,7 +545,8 @@ def transformed_to( sigma_p=sigma_p, energy=energy, total_charge=total_charge, - device=self.device, + device=device, + dtype=dtype, ) @property @@ -617,26 +633,24 @@ def __init__( particles: torch.Tensor, energy: torch.Tensor, particle_charges: Optional[torch.Tensor] = None, - device: str = "auto", + device=None, + dtype=torch.float32, ) -> None: super().__init__() + factory_kwargs = {"device": device, "dtype": dtype} assert ( len(particles) > 0 and particles.shape[1] == 7 ), "Particle vectors must be 7-dimensional." - if device == "auto": - device = "cuda" if torch.cuda.is_available() else "cpu" - self.device = device - - self.particles = particles.to(self.device) + self.particles = particles.to(**factory_kwargs) num_particles = len(self.particles) self.particle_charges = ( - particle_charges.to(self.device) + particle_charges.to(**factory_kwargs) if particle_charges is not None - else torch.zeros(num_particles, dtype=torch.float32, device=self.device) + else torch.zeros(num_particles, **factory_kwargs) ) - self.energy = energy.to(self.device) + self.energy = energy.to(**factory_kwargs) @classmethod def from_parameters( @@ -657,7 +671,8 @@ def from_parameters( cor_s: Optional[torch.Tensor] = None, energy: Optional[torch.Tensor] = None, total_charge: Optional[torch.Tensor] = None, - device: str = "auto", + device=None, + dtype=torch.float32, ) -> "ParticleBeam": """ Generate Cheetah Beam of random particles. @@ -701,7 +716,7 @@ def from_parameters( energy = energy if energy is not None else torch.tensor(1e8) total_charge = total_charge if total_charge is not None else torch.tensor(0.0) particle_charges = ( - torch.ones(num_particles, dtype=torch.float32) + torch.ones(num_particles, device=device, dtype=dtype) * total_charge / num_particles ) @@ -724,11 +739,17 @@ def from_parameters( cov[5, 4] = cor_s cov[5, 5] = sigma_p**2 - particles = torch.ones((num_particles, 7), dtype=torch.float32) + particles = torch.ones((num_particles, 7)) distribution = MultivariateNormal(mean, covariance_matrix=cov) particles[:, :6] = distribution.sample((num_particles,)) - return cls(particles, energy, particle_charges=particle_charges, device=device) + return cls( + particles, + energy, + particle_charges=particle_charges, + device=device, + dtype=dtype, + ) @classmethod def from_twiss( @@ -745,7 +766,8 @@ def from_twiss( sigma_p: Optional[torch.Tensor] = None, cor_s: Optional[torch.Tensor] = None, total_charge: Optional[torch.Tensor] = None, - device: str = "auto", + device=None, + dtype=torch.float32, ) -> "ParticleBeam": # Set default values without function call in function signature num_particles = ( @@ -788,6 +810,7 @@ def from_twiss( cor_y=cor_y, total_charge=total_charge, device=device, + dtype=dtype, ) @classmethod @@ -806,7 +829,8 @@ def make_linspaced( sigma_p: Optional[torch.Tensor] = None, energy: Optional[torch.Tensor] = None, total_charge: Optional[torch.Tensor] = None, - device: str = "auto", + device=None, + dtype=torch.float32, ) -> "ParticleBeam": """ Generate Cheetah Beam of *n* linspaced particles. @@ -848,55 +872,47 @@ def make_linspaced( / num_particles ) - particles = torch.ones((num_particles, 7), dtype=torch.float32) + particles = torch.ones((num_particles, 7)) - particles[:, 0] = torch.linspace( - mu_x - sigma_x, mu_x + sigma_x, num_particles, dtype=torch.float32 - ) + particles[:, 0] = torch.linspace(mu_x - sigma_x, mu_x + sigma_x, num_particles) particles[:, 1] = torch.linspace( - mu_xp - sigma_xp, mu_xp + sigma_xp, num_particles, dtype=torch.float32 - ) - particles[:, 2] = torch.linspace( - mu_y - sigma_y, mu_y + sigma_y, num_particles, dtype=torch.float32 + mu_xp - sigma_xp, mu_xp + sigma_xp, num_particles ) + particles[:, 2] = torch.linspace(mu_y - sigma_y, mu_y + sigma_y, num_particles) particles[:, 3] = torch.linspace( - mu_yp - sigma_yp, mu_yp + sigma_yp, num_particles, dtype=torch.float32 - ) - particles[:, 4] = torch.linspace( - -sigma_s, sigma_s, num_particles, dtype=torch.float32 - ) - particles[:, 5] = torch.linspace( - -sigma_p, sigma_p, num_particles, dtype=torch.float32 + mu_yp - sigma_yp, mu_yp + sigma_yp, num_particles ) + particles[:, 4] = torch.linspace(-sigma_s, sigma_s, num_particles) + particles[:, 5] = torch.linspace(-sigma_p, sigma_p, num_particles) return cls( particles=particles, energy=energy, particle_charges=particle_charges, device=device, + dtype=dtype, ) @classmethod - def from_ocelot(cls, parray, device: str = "auto") -> "ParticleBeam": + def from_ocelot(cls, parray, device=None, dtype=torch.float32) -> "ParticleBeam": """ Convert an Ocelot ParticleArray `parray` to a Cheetah Beam. """ num_particles = parray.rparticles.shape[1] particles = torch.ones((num_particles, 7)) - particles[:, :6] = torch.tensor( - parray.rparticles.transpose(), dtype=torch.float32 - ) - particle_charges = torch.tensor(parray.q_array, dtype=torch.float32) + particles[:, :6] = torch.tensor(parray.rparticles.transpose()) + particle_charges = torch.tensor(parray.q_array) return cls( particles=particles, energy=torch.tensor(1e9 * parray.E), particle_charges=particle_charges, device=device, + dtype=dtype, ) @classmethod - def from_astra(cls, path: str, **kwargs) -> "ParticleBeam": + def from_astra(cls, path: str, device=None, dtype=torch.float32) -> "ParticleBeam": """Load an Astra particle distribution as a Cheetah Beam.""" from cheetah.converters.astralavista import from_astrabeam @@ -906,9 +922,10 @@ def from_astra(cls, path: str, **kwargs) -> "ParticleBeam": particle_charges = torch.from_numpy(particle_charges) return cls( particles=particles_7d, - energy=torch.tensor(energy, dtype=torch.float32), + energy=torch.tensor(energy), particle_charges=particle_charges, - **kwargs, + device=device, + dtype=dtype, ) def transformed_to( @@ -925,6 +942,8 @@ def transformed_to( sigma_p: Optional[torch.Tensor] = None, energy: Optional[torch.Tensor] = None, total_charge: Optional[torch.Tensor] = None, + device=None, + dtype=torch.float32, ) -> "ParticleBeam": """ Create version of this beam that is transformed to new beam parameters. @@ -945,39 +964,40 @@ def transformed_to( :param device: Device to move the beam's particle array to. If set to `"auto"` a CUDA GPU is selected if available. The CPU is used otherwise. """ - mu_x = mu_x.to(self.device) if mu_x is not None else self.mu_x - mu_y = mu_y.to(self.device) if mu_y is not None else self.mu_y - mu_xp = mu_xp.to(self.device) if mu_xp is not None else self.mu_xp - mu_yp = mu_yp.to(self.device) if mu_yp is not None else self.mu_yp - sigma_x = sigma_x.to(self.device) if sigma_x is not None else self.sigma_x - sigma_y = sigma_y.to(self.device) if sigma_y is not None else self.sigma_y - sigma_xp = sigma_xp.to(self.device) if sigma_xp is not None else self.sigma_xp - sigma_yp = sigma_yp.to(self.device) if sigma_yp is not None else self.sigma_yp - sigma_s = sigma_s.to(self.device) if sigma_s is not None else self.sigma_s - sigma_p = sigma_p.to(self.device) if sigma_p is not None else self.sigma_p - energy = energy.to(self.device) if energy is not None else self.energy + device = device if device is not None else self.mu_x.device + dtype = dtype if dtype is not None else self.mu_x.dtype + + mu_x = mu_x if mu_x is not None else self.mu_x + mu_y = mu_y if mu_y is not None else self.mu_y + mu_xp = mu_xp if mu_xp is not None else self.mu_xp + mu_yp = mu_yp if mu_yp is not None else self.mu_yp + sigma_x = sigma_x if sigma_x is not None else self.sigma_x + sigma_y = sigma_y if sigma_y is not None else self.sigma_y + sigma_xp = sigma_xp if sigma_xp is not None else self.sigma_xp + sigma_yp = sigma_yp if sigma_yp is not None else self.sigma_yp + sigma_s = sigma_s if sigma_s is not None else self.sigma_s + sigma_p = sigma_p if sigma_p is not None else self.sigma_p + energy = energy if energy is not None else self.energy if total_charge is None: particle_charges = self.particle_charges elif self.total_charge is None: # Scale to the new charge - particle_charges = ( - self.particle_charges * total_charge.to(self.device) / self.total_charge + total_charge = total_charge.to( + device=self.particle_charges.device, dtype=self.particle_charges.dtype ) + particle_charges = self.particle_charges * total_charge / self.total_charge else: particle_charges = ( - torch.ones(len(self.particles), device=self.device) - * total_charge.to(self.device) + torch.ones( + len(self.particles), + device=total_charge.device, + dtype=total_charge.dtype, + ) + * total_charge / len(self.particles) ) new_mu = torch.stack( - [ - mu_x, - mu_xp, - mu_y, - mu_yp, - torch.tensor(0.0, device=self.device), - torch.tensor(0.0, device=self.device), - ] + [mu_x, mu_xp, mu_y, mu_yp, torch.tensor(0.0), torch.tensor(0.0)] ) new_sigma = torch.stack( [sigma_x, sigma_xp, sigma_y, sigma_yp, sigma_s, sigma_p] @@ -989,8 +1009,8 @@ def transformed_to( self.mu_xp, self.mu_y, self.mu_yp, - torch.tensor(0.0, device=self.device), - torch.tensor(0.0, device=self.device), + torch.tensor(0.0), + torch.tensor(0.0), ] ) old_sigma = torch.stack( @@ -1007,16 +1027,15 @@ def transformed_to( phase_space = self.particles[:, :6] phase_space = (phase_space - old_mu) / old_sigma * new_sigma + new_mu - particles = torch.ones_like( - self.particles, dtype=torch.float32, device=self.device - ) + particles = torch.ones_like(self.particles) particles[:, :6] = phase_space return self.__class__( particles=particles, energy=energy, particle_charges=particle_charges, - device=self.device, + device=device, + dtype=dtype, ) def __len__(self) -> int: diff --git a/cheetah/track_methods.py b/cheetah/track_methods.py index ee719149..1926f9a5 100644 --- a/cheetah/track_methods.py +++ b/cheetah/track_methods.py @@ -1,6 +1,6 @@ """Utility functions for creating transfer maps for the elements.""" -from typing import Optional, Union +from typing import Optional import torch from scipy import constants @@ -9,25 +9,20 @@ constants.electron_mass * constants.speed_of_light**2 / constants.elementary_charge -) # electron mass +) # Electron mass -def rotation_matrix( - angle: torch.Tensor, device: Union[str, torch.device] = "auto" -) -> torch.Tensor: +def rotation_matrix(angle: torch.Tensor) -> torch.Tensor: """Rotate the transfer map in x-y plane :param angle: Rotation angle in rad, for example `angle = np.pi/2` for vertical = dipole. - :param device: Device used for tracking, by default "auto". :return: Rotation matrix to be multiplied to the element's transfer matrix. """ - if device == "auto": - device = "cuda" if torch.cuda.is_available() else "cpu" cs = torch.cos(angle) sn = torch.sin(angle) - tm = torch.eye(7, dtype=torch.float32, device=device) + tm = torch.eye(7, dtype=angle.dtype, device=angle.device) tm[0, 0] = cs tm[0, 2] = sn tm[1, 1] = cs @@ -46,7 +41,6 @@ def base_rmatrix( hx: torch.Tensor, tilt: Optional[torch.Tensor] = None, energy: Optional[torch.Tensor] = None, - device: Union[str, torch.device] = "auto", ) -> torch.Tensor: """ Create a universal transfer matrix for a beamline element. @@ -56,27 +50,31 @@ def base_rmatrix( :param hx: Curvature (1/radius) of the element in 1/m**2. :param tilt: Roation of the element relative to the longitudinal axis in rad. :param energy: Beam energy in eV. - :param device: Device where the transfer matrix is created. If "auto", the device - is selected automatically. :return: Transfer matrix for the element. """ - if device == "auto": - device = "cuda" if torch.cuda.is_available() else "cpu" + device = length.device + dtype = length.dtype - tilt = tilt if tilt is not None else torch.tensor(0.0, device=device) - energy = energy if energy is not None else torch.tensor(0.0, device=device) + tilt = tilt if tilt is not None else torch.tensor(0.0, device=device, dtype=dtype) + energy = ( + energy if energy is not None else torch.tensor(0.0, device=device, dtype=dtype) + ) - gamma = energy / REST_ENERGY - igamma2 = 1 / gamma**2 if gamma != 0 else torch.tensor(0.0, device=device) + gamma = energy / REST_ENERGY.to(device=device, dtype=dtype) + igamma2 = ( + 1 / gamma**2 if gamma != 0 else torch.tensor(0.0, device=device, dtype=dtype) + ) beta = torch.sqrt(1 - igamma2) if k1 == 0: - k1 = k1 + torch.tensor(1e-12, device=device) # Avoid division by zero + k1 = k1 + torch.tensor( + 1e-12, device=device, dtype=dtype + ) # Avoid division by zero kx2 = k1 + hx**2 ky2 = -k1 - kx = torch.sqrt(torch.complex(kx2, torch.tensor(0.0, device=device))) - ky = torch.sqrt(torch.complex(ky2, torch.tensor(0.0, device=device))) + kx = torch.sqrt(torch.complex(kx2, torch.tensor(0.0, device=device, dtype=dtype))) + ky = torch.sqrt(torch.complex(ky2, torch.tensor(0.0, device=device, dtype=dtype))) cx = torch.cos(kx * length).real cy = torch.cos(ky * length).real sy = (torch.sin(ky * length) / ky).real if ky != 0 else length @@ -87,7 +85,7 @@ def base_rmatrix( r56 = r56 - length / beta**2 * igamma2 - R = torch.eye(7, dtype=torch.float32, device=device) + R = torch.eye(7, dtype=dtype, device=device) R[0, 0] = cx R[0, 1] = sx R[0, 5] = dx / beta @@ -109,14 +107,17 @@ def base_rmatrix( def misalignment_matrix( - misalignment: torch.Tensor, device: torch.device + misalignment: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: """Shift the beam for tracking beam through misaligned elements""" - R_exit = torch.eye(7, dtype=torch.float32, device=device) + device = misalignment.device + dtype = misalignment.dtype + + R_exit = torch.eye(7, device=device, dtype=dtype) R_exit[0, 6] = misalignment[0] R_exit[2, 6] = misalignment[1] - R_entry = torch.eye(7, dtype=torch.float32, device=device) + R_entry = torch.eye(7, device=device, dtype=dtype) R_entry[0, 6] = -misalignment[0] R_entry[2, 6] = -misalignment[1] diff --git a/test_requirements.txt b/test_requirements.txt index bfc6e418..be468496 100644 --- a/test_requirements.txt +++ b/test_requirements.txt @@ -1,3 +1,3 @@ git+https://github.com/ocelot-collab/ocelot@v22.12.0 # Ocelot pytest -pytest-cov \ No newline at end of file +pytest-cov diff --git a/tests/test_compare_ocelot.py b/tests/test_compare_ocelot.py index 1f2900aa..035f3359 100644 --- a/tests/test_compare_ocelot.py +++ b/tests/test_compare_ocelot.py @@ -36,6 +36,37 @@ def test_dipole(): ) +def test_dipole_with_float64(): + """ + Test that the tracking results through a Cheeath `Dipole` element match those + through an Oclet `Bend` element using float64 precision. + """ + # Cheetah + incoming_beam = cheetah.ParticleBeam.from_astra( + "tests/resources/ACHIP_EA1_2021.1351.001", dtype=torch.float64 + ) + cheetah_dipole = cheetah.Dipole( + length=torch.tensor(0.1), + angle=torch.tensor(0.1), + dtype=torch.float64, + ) + outgoing_beam = cheetah_dipole.track(incoming_beam) + + # Ocelot + incoming_p_array = ocelot.astraBeam2particleArray( + "tests/resources/ACHIP_EA1_2021.1351.001" + ) + ocelot_bend = ocelot.Bend(l=0.1, angle=0.1) + lattice = ocelot.MagneticLattice([ocelot_bend]) + navigator = ocelot.Navigator(lattice) + _, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator) + + assert np.allclose( + outgoing_beam.particles[:, :6].cpu().numpy(), + outgoing_p_array.rparticles.transpose(), + ) + + def test_dipole_with_fringe_field(): """ Test that the tracking results through a Cheeath `Dipole` element match those @@ -203,18 +234,14 @@ def test_ares_ea(): assert np.isclose(outgoing_beam.mu_xp.cpu().numpy(), outgoing_p_array.px().mean()) assert np.isclose(outgoing_beam.mu_y.cpu().numpy(), outgoing_p_array.y().mean()) assert np.isclose(outgoing_beam.mu_yp.cpu().numpy(), outgoing_p_array.py().mean()) - assert np.isclose( - outgoing_beam.mu_s.cpu().numpy(), outgoing_p_array.tau().mean(), atol=1e-7 - ) + assert np.isclose(outgoing_beam.mu_s.cpu().numpy(), outgoing_p_array.tau().mean()) assert np.isclose(outgoing_beam.mu_p.cpu().numpy(), outgoing_p_array.p().mean()) assert np.allclose(outgoing_beam.xs.cpu().numpy(), outgoing_p_array.x()) assert np.allclose(outgoing_beam.xps.cpu().numpy(), outgoing_p_array.px()) assert np.allclose(outgoing_beam.ys.cpu().numpy(), outgoing_p_array.y()) assert np.allclose(outgoing_beam.yps.cpu().numpy(), outgoing_p_array.py()) - assert np.allclose( - outgoing_beam.ss.cpu().numpy(), outgoing_p_array.tau(), atol=1e-7, rtol=1e-1 - ) # TODO: Why do we need such large tolerances? + assert np.allclose(outgoing_beam.ss.cpu().numpy(), outgoing_p_array.tau()) assert np.allclose(outgoing_beam.ps.cpu().numpy(), outgoing_p_array.p()) @@ -544,9 +571,12 @@ def test_asymmetric_bend(): outgoing_beam = cheetah_segment.track(incoming_beam) assert np.allclose( - outgoing_beam.particles[:, :6], outgoing_p_array.rparticles.transpose() + outgoing_beam.particles[:, :6].cpu().numpy(), + outgoing_p_array.rparticles.transpose(), + ) + assert np.allclose( + outgoing_beam.particle_charges.cpu().numpy(), outgoing_p_array.q_array ) - assert np.allclose(outgoing_beam.particle_charges, outgoing_p_array.q_array) def test_cavity(): @@ -577,24 +607,6 @@ def test_cavity(): - alpha_x = -1.0160687592932345 - alpha_y = -1.0160687593664295 """ - # Cheetah - incoming_beam = cheetah.ParticleBeam.from_twiss( - beta_x=torch.tensor(5.91253677), - alpha_x=torch.tensor(3.55631308), - beta_y=torch.tensor(5.91253677), - alpha_y=torch.tensor(3.55631308), - emittance_x=torch.tensor(3.494768647122823e-09), - emittance_y=torch.tensor(3.497810737006068e-09), - energy=torch.tensor(6e6), - total_charge=5e-9, - ) - cheetah_cavity = cheetah.Cavity( - length=torch.tensor(1.0377), - voltage=torch.tensor(0.01815975e9), - frequency=torch.tensor(1.3e9), - phase=torch.tensor(0.0), - ) - outgoing_beam = cheetah_cavity.track(incoming_beam) # Ocelot tws = ocelot.Twiss() @@ -617,18 +629,27 @@ def test_cavity(): _, outgoing_parray = ocelot.track(lattice, deepcopy(p_array), navigator) derived_twiss = ocelot.cpbd.beam.get_envelope(outgoing_parray) - # Compare - assert np.isclose( - outgoing_beam.beta_x.cpu().numpy(), derived_twiss.beta_x, rtol=1e-2 + # Cheetah + incoming_beam = cheetah.ParticleBeam.from_ocelot( + parray=p_array, dtype=torch.float64 ) - assert np.isclose( - outgoing_beam.alpha_x.cpu().numpy(), derived_twiss.alpha_x, rtol=1e-2 + cheetah_cavity = cheetah.Cavity( + length=1.0377, + voltage=0.01815975e9, + frequency=1.3e9, + phase=0.0, + dtype=torch.float64, ) + outgoing_beam = cheetah_cavity.track(incoming_beam) + + # Compare + assert np.isclose(outgoing_beam.beta_x.cpu().numpy(), derived_twiss.beta_x) assert np.isclose( - outgoing_beam.beta_y.cpu().numpy(), derived_twiss.beta_y, rtol=1e-2 + outgoing_beam.alpha_x.cpu().numpy(), derived_twiss.alpha_x, rtol=1e-4 ) + assert np.isclose(outgoing_beam.beta_y.cpu().numpy(), derived_twiss.beta_y) assert np.isclose( - outgoing_beam.alpha_y.cpu().numpy(), derived_twiss.alpha_y, rtol=1e-2 + outgoing_beam.alpha_y.cpu().numpy(), derived_twiss.alpha_y, rtol=1e-4 ) assert np.isclose( outgoing_beam.total_charge.cpu().numpy(), np.sum(outgoing_parray.q_array) diff --git a/tests/test_drift.py b/tests/test_drift.py index e20c5aa6..983a8de9 100644 --- a/tests/test_drift.py +++ b/tests/test_drift.py @@ -1,6 +1,7 @@ +import pytest import torch -from cheetah import Drift, ParameterBeam, ParticleBeam +import cheetah def test_diverging_parameter_beam(): @@ -8,11 +9,11 @@ def test_diverging_parameter_beam(): Test that that a parameter beam with sigma_xp > 0 and sigma_yp > 0 increases in size in both dimensions when travelling through a drift section. """ - drift = Drift(length=torch.tensor(1.0)) - incoming_beam = ParameterBeam.from_parameters( + drift = cheetah.Drift(length=torch.tensor(1.0)) + incoming_beam = cheetah.ParameterBeam.from_parameters( sigma_xp=torch.tensor(2e-7), sigma_yp=torch.tensor(2e-7) ) - outgoing_beam = drift(incoming_beam) + outgoing_beam = drift.track(incoming_beam) assert outgoing_beam.sigma_x > incoming_beam.sigma_x assert outgoing_beam.sigma_y > incoming_beam.sigma_y @@ -24,16 +25,37 @@ def test_diverging_particle_beam(): Test that that a particle beam with sigma_xp > 0 and sigma_yp > 0 increases in size in both dimensions when travelling through a drift section. """ - drift = Drift(length=torch.tensor(1.0)) - incoming_beam = ParticleBeam.from_parameters( + drift = cheetah.Drift(length=torch.tensor(1.0)) + incoming_beam = cheetah.ParticleBeam.from_parameters( num_particles=torch.tensor(1000), sigma_xp=torch.tensor(2e-7), sigma_yp=torch.tensor(2e-7), ) - outgoing_beam = drift(incoming_beam) + outgoing_beam = drift.track(incoming_beam) assert outgoing_beam.sigma_x > incoming_beam.sigma_x assert outgoing_beam.sigma_y > incoming_beam.sigma_y assert torch.allclose( outgoing_beam.particle_charges, incoming_beam.particle_charges ) + + +@pytest.mark.skip( + reason="Requires rewriting Element and Beam member variables to be buffers." +) +def test_device_like_torch_module(): + """ + Test that when changing the device, Drift reacts like a `torch.nn.Module`. + """ + # There is no point in running this test, if there aren't two different devices to + # move between + if not torch.cuda.is_available(): + return + + element = cheetah.Drift(length=torch.tensor(0.2), device="cuda") + + assert element.length.device.type == "cuda" + + element = element.cpu() + + assert element.length.device.type == "cpu" diff --git a/tests/test_ocelot_import.py b/tests/test_ocelot_import.py index 7e890960..0f841c65 100644 --- a/tests/test_ocelot_import.py +++ b/tests/test_ocelot_import.py @@ -2,7 +2,7 @@ import ocelot import pytest -from cheetah import ParameterBeam, ParticleBeam, Screen, Segment +import cheetah from .resources import ARESlatticeStage3v1_9 as ares @@ -29,15 +29,15 @@ def test_screen_conversion(name: str): """ Test on the example of the ARES lattice that all screens are correctly converted to `cheetah.Screen`. - ˚""" - segment = Segment.from_ocelot(ares.cell) + """ + segment = cheetah.Segment.from_ocelot(ares.cell) screen = getattr(segment, name) - assert isinstance(screen, Screen) + assert isinstance(screen, cheetah.Screen) def test_ocelot_to_parameterbeam(): parray = ocelot.astraBeam2particleArray("tests/resources/ACHIP_EA1_2021.1351.001") - beam = ParameterBeam.from_ocelot(parray) + beam = cheetah.ParameterBeam.from_ocelot(parray) assert np.allclose(beam.mu_x.cpu().numpy(), np.mean(parray.x())) assert np.allclose(beam.mu_xp.cpu().numpy(), np.mean(parray.px())) @@ -55,7 +55,7 @@ def test_ocelot_to_parameterbeam(): def test_ocelot_to_particlebeam(): parray = ocelot.astraBeam2particleArray("tests/resources/ACHIP_EA1_2021.1351.001") - beam = ParticleBeam.from_ocelot(parray) + beam = cheetah.ParticleBeam.from_ocelot(parray) assert np.allclose(beam.particles[:, 0].cpu().numpy(), parray.x()) assert np.allclose(beam.particles[:, 1].cpu().numpy(), parray.px()) @@ -65,3 +65,21 @@ def test_ocelot_to_particlebeam(): assert np.allclose(beam.particles[:, 5].cpu().numpy(), parray.p()) assert np.allclose(beam.energy.cpu().numpy(), parray.E * 1e9) assert np.allclose(beam.particle_charges.cpu().numpy(), parray.q_array) + + +def test_ocelot_lattice_import(): + """ + Tests if a lattice is importet correctly (and to the device requested). + """ + cell = [ocelot.Drift(l=0.3), ocelot.Quadrupole(l=0.2), ocelot.Drift(l=1.0)] + segment = cheetah.Segment.from_ocelot(cell=cell) + + assert isinstance(segment.elements[0], cheetah.Drift) + assert isinstance(segment.elements[1], cheetah.Quadrupole) + assert isinstance(segment.elements[2], cheetah.Drift) + + assert segment.elements[0].length.device.type == "cpu" + assert segment.elements[1].length.device.type == "cpu" + assert segment.elements[1].k1.device.type == "cpu" + assert segment.elements[1].misalignment.device.type == "cpu" + assert segment.elements[2].length.device.type == "cpu" diff --git a/tests/test_quadrupole.py b/tests/test_quadrupole.py index e0b1bff8..297d0311 100644 --- a/tests/test_quadrupole.py +++ b/tests/test_quadrupole.py @@ -15,7 +15,7 @@ def test_quadrupole_off(): outbeam_quad = quadrupole(incoming_beam) outbeam_drift = drift(incoming_beam) - quadrupole.k1 = torch.tensor(1.0, device=quadrupole.device) + quadrupole.k1 = torch.tensor(1.0, device=quadrupole.k1.device) outbeam_quad_on = quadrupole(incoming_beam) assert torch.allclose(outbeam_quad.sigma_x, outbeam_drift.sigma_x) diff --git a/tests/test_screen.py b/tests/test_screen.py index 2576def7..0902e949 100644 --- a/tests/test_screen.py +++ b/tests/test_screen.py @@ -74,11 +74,15 @@ def test_reading_shows_beam_ares(): ) beam = cheetah.ParticleBeam.from_astra("tests/resources/ACHIP_EA1_2021.1351.001") - segment.AREABSCR1.resolution = torch.tensor((2448, 2040), device=segment.device) + segment.AREABSCR1.resolution = torch.tensor( + (2448, 2040), device=segment.AREABSCR1.resolution.device + ) segment.AREABSCR1.pixel_size = torch.tensor( - (3.3198e-6, 2.4469e-6), device=segment.device + (3.3198e-6, 2.4469e-6), + device=segment.AREABSCR1.pixel_size.device, + dtype=segment.AREABSCR1.pixel_size.dtype, ) - segment.AREABSCR1.binning = torch.tensor(1, device=segment.device) + segment.AREABSCR1.binning = torch.tensor(1, device=segment.AREABSCR1.binning.device) segment.AREABSCR1.is_active = True assert isinstance(segment.AREABSCR1.reading, torch.Tensor)