diff --git a/cheetah/accelerator.py b/cheetah/accelerator.py index b67f2fc0..3ce54d11 100644 --- a/cheetah/accelerator.py +++ b/cheetah/accelerator.py @@ -14,9 +14,9 @@ from scipy.stats import multivariate_normal from cheetah.dontbmad import convert_bmad_lattice +from cheetah.error import DeviceError from cheetah.particles import Beam, ParameterBeam, ParticleBeam from cheetah.track_methods import base_rmatrix, misalignment_matrix, rotation_matrix -from cheetah.utils import DeviceError REST_ENERGY = torch.tensor( constants.electron_mass @@ -40,6 +40,9 @@ class Element(ABC, pydantic.BaseModel): name: str = "unnamed" device: torch.device + class Config: + arbitrary_types_allowed = True + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -1533,7 +1536,7 @@ def from_ocelot( Cheetah or converted with potentially unexpected behavior. :return: Cheetah segment closely resembling the Ocelot cell. """ - from cheetah.utils import ocelot2cheetah + from cheetah.nocelot import ocelot2cheetah converted = [ocelot2cheetah(element, warnings=warnings) for element in cell] return cls(converted, name=name, **kwargs) diff --git a/cheetah/astralavista.py b/cheetah/astralavista.py new file mode 100644 index 00000000..68fcdef2 --- /dev/null +++ b/cheetah/astralavista.py @@ -0,0 +1,58 @@ +import numpy as np +from scipy.constants import physical_constants + +# Electron mass in eV +electron_mass_eV = physical_constants["electron mass energy equivalent in MeV"][0] * 1e6 + + +def from_astrabeam(path: str) -> tuple[np.ndarray, float]: + """ + Read from a ASTRA beam distribution, and prepare for conversion to a Cheetah + ParticleBeam or ParameterBeam. + + Adapted from the implementation in ocelot: + https://github.com/ocelot-collab/ocelot/blob/master/ocelot/adaptors/astra2ocelot.py + + :param path: Path to the ASTRA beam distribution file. + :return: Particle 6D phase space information and mean energy of the particle beam. + """ + P0 = np.loadtxt(path) + + # remove lost particles + inds = np.argwhere(P0[:, 9] > 0) + inds = inds.reshape(inds.shape[0]) + + P0 = P0[inds, :] + n_particles = P0.shape[0] + + # s_ref = P0[0, 2] + Pref = P0[0, 5] + + xp = P0[:, :6] + xp[0, 2] = 0.0 + xp[0, 5] = 0.0 + + gamref = np.sqrt((Pref / electron_mass_eV) ** 2 + 1) + # energy in eV: E = gamma * m_e + energy = gamref * electron_mass_eV + + n_particles = xp.shape[0] + particles = np.zeros((n_particles, 6)) + + u = np.c_[xp[:, 3], xp[:, 4], xp[:, 5] + Pref] + gamma = np.sqrt(1 + np.sum(u * u, 1) / electron_mass_eV**2) + beta = np.sqrt(1 - gamma**-2) + betaref = np.sqrt(1 - gamref**-2) + + p0 = np.linalg.norm(u, 2, 1).reshape((n_particles, 1)) + + u = u / p0 + cdt = -xp[:, 2] / (beta * u[:, 2]) + particles[:, 0] = xp[:, 0] + beta * u[:, 0] * cdt + particles[:, 2] = xp[:, 1] + beta * u[:, 1] * cdt + particles[:, 4] = cdt + particles[:, 1] = xp[:, 3] / Pref + particles[:, 3] = xp[:, 4] / Pref + particles[:, 5] = (gamma / gamref - 1) / betaref + + return particles, energy diff --git a/cheetah/error.py b/cheetah/error.py new file mode 100644 index 00000000..bdd8d6e3 --- /dev/null +++ b/cheetah/error.py @@ -0,0 +1,11 @@ +class DeviceError(Exception): + """ + Used to create an exception, in case the device used for the beam + and the elements are different. + """ + + def __init__(self): + super().__init__( + "Warning! The device used for calculating the elements is not the same, " + "as the device used to calculate the Beam." + ) diff --git a/cheetah/utils.py b/cheetah/latticejson.py similarity index 50% rename from cheetah/utils.py rename to cheetah/latticejson.py index bca07c46..f7f2789e 100644 --- a/cheetah/utils.py +++ b/cheetah/latticejson.py @@ -1,149 +1,8 @@ import json from typing import Optional -import numpy as np -import torch -from scipy.constants import physical_constants - import cheetah -# Electron mass in eV -electron_mass_eV = physical_constants["electron mass energy equivalent in MeV"][0] * 1e6 - - -def from_astrabeam(path: str) -> tuple[np.ndarray, float]: - """ - Read from a ASTRA beam distribution, and prepare for conversion to a Cheetah - ParticleBeam or ParameterBeam. - - Adapted from the implementation in ocelot: - https://github.com/ocelot-collab/ocelot/blob/master/ocelot/adaptors/astra2ocelot.py - - :param path: Path to the ASTRA beam distribution file. - :return: Particle 6D phase space information and mean energy of the particle beam. - """ - P0 = np.loadtxt(path) - - # remove lost particles - inds = np.argwhere(P0[:, 9] > 0) - inds = inds.reshape(inds.shape[0]) - - P0 = P0[inds, :] - n_particles = P0.shape[0] - - # s_ref = P0[0, 2] - Pref = P0[0, 5] - - xp = P0[:, :6] - xp[0, 2] = 0.0 - xp[0, 5] = 0.0 - - gamref = np.sqrt((Pref / electron_mass_eV) ** 2 + 1) - # energy in eV: E = gamma * m_e - energy = gamref * electron_mass_eV - - n_particles = xp.shape[0] - particles = np.zeros((n_particles, 6)) - - u = np.c_[xp[:, 3], xp[:, 4], xp[:, 5] + Pref] - gamma = np.sqrt(1 + np.sum(u * u, 1) / electron_mass_eV**2) - beta = np.sqrt(1 - gamma**-2) - betaref = np.sqrt(1 - gamref**-2) - - p0 = np.linalg.norm(u, 2, 1).reshape((n_particles, 1)) - - u = u / p0 - cdt = -xp[:, 2] / (beta * u[:, 2]) - particles[:, 0] = xp[:, 0] + beta * u[:, 0] * cdt - particles[:, 2] = xp[:, 1] + beta * u[:, 1] * cdt - particles[:, 4] = cdt - particles[:, 1] = xp[:, 3] / Pref - particles[:, 3] = xp[:, 4] / Pref - particles[:, 5] = (gamma / gamref - 1) / betaref - - return particles, energy - - -def ocelot2cheetah(element, warnings: bool = True) -> "cheetah.Element": - """ - Translate an Ocelot element to a Cheetah element. - - NOTE Object not supported by Cheetah are translated to drift sections. Screen - objects are created only from `ocelot.Monitor` objects when the string "BSC" is - contained in their `id` attribute. Their screen properties are always set to default - values and most likely need adjusting afterwards. BPM objects are only created from - `ocelot.Monitor` objects when their id has a substring "BPM". - - :param element: Ocelot element object representing an element of particle - accelerator. - :param warnings: Whether to print warnings when elements might not be converted as - expected. - :return: Cheetah element object representing an element of particle accelerator. - """ - try: - import ocelot - except ImportError: - raise ImportError( - """To use the ocelot2cheetah lattice converter, Ocelot must be first - installed, see https://github.com/ocelot-collab/ocelot """ - ) - - if isinstance(element, ocelot.Drift): - return cheetah.Drift(torch.tensor(element.l), name=element.id) - elif isinstance(element, ocelot.Quadrupole): - return cheetah.Quadrupole( - torch.tensor(element.l), torch.tensor(element.k1), name=element.id - ) - elif isinstance(element, ocelot.Hcor): - return cheetah.HorizontalCorrector( - torch.tensor(element.l), torch.tensor(element.angle), name=element.id - ) - elif isinstance(element, ocelot.Vcor): - return cheetah.VerticalCorrector( - torch.tensor(element.l), torch.tensor(element.angle), name=element.id - ) - elif isinstance(element, ocelot.Cavity): - return cheetah.Cavity(torch.tensor(element.l), name=element.id) - elif isinstance(element, ocelot.Monitor) and ("BSC" in element.id): - # NOTE This pattern is very specific to ARES and will need a more complex - # solution for other accelerators - if warnings: - print( - "WARNING: Diagnostic screen was converted with default screen" - " properties." - ) - return cheetah.Screen( - torch.tensor([2448, 2040]), - torch.tensor([3.5488e-6, 2.5003e-6]), - name=element.id, - ) - elif isinstance(element, ocelot.Monitor) and "BPM" in element.id: - return cheetah.BPM(name=element.id) - elif isinstance(element, ocelot.Undulator): - return cheetah.Undulator(torch.tensor(element.l), name=element.id) - else: - if warnings: - print( - f"WARNING: Unknown element {element.id} of type {type(element)}," - " replacing with drift section." - ) - return cheetah.Drift(torch.tensor(element.l), name=element.id) - - -def subcell_of_ocelot(cell: list, start: str, end: str) -> list: - """Extract a subcell `[start, end]` from an Ocelot cell.""" - subcell = [] - is_in_subcell = False - for el in cell: - if el.id == start: - is_in_subcell = True - if is_in_subcell: - subcell.append(el) - if el.id == end: - break - - return subcell - # Saving Cheetah to JSON def parse_cheetah_element(element: cheetah.Element): @@ -297,16 +156,3 @@ def load_cheetah_model(fname: str, name: Optional[str] = None) -> cheetah.Segmen def str_to_class(classname: str): # get class from string return getattr(cheetah, classname) - - -class DeviceError(Exception): - """ - Used to create an exception, in case the device used for the beam - and the elements are different. - """ - - def __init__(self): - super().__init__( - "Warning! The device used for calculating the elements is not the same, " - "as the device used to calculate the Beam." - ) diff --git a/cheetah/nocelot.py b/cheetah/nocelot.py new file mode 100644 index 00000000..a9179839 --- /dev/null +++ b/cheetah/nocelot.py @@ -0,0 +1,84 @@ +import torch + +import cheetah + + +def ocelot2cheetah(element, warnings: bool = True) -> "cheetah.Element": + """ + Translate an Ocelot element to a Cheetah element. + + NOTE Object not supported by Cheetah are translated to drift sections. Screen + objects are created only from `ocelot.Monitor` objects when the string "BSC" is + contained in their `id` attribute. Their screen properties are always set to default + values and most likely need adjusting afterwards. BPM objects are only created from + `ocelot.Monitor` objects when their id has a substring "BPM". + + :param element: Ocelot element object representing an element of particle + accelerator. + :param warnings: Whether to print warnings when elements might not be converted as + expected. + :return: Cheetah element object representing an element of particle accelerator. + """ + try: + import ocelot + except ImportError: + raise ImportError( + """To use the ocelot2cheetah lattice converter, Ocelot must be first + installed, see https://github.com/ocelot-collab/ocelot """ + ) + + if isinstance(element, ocelot.Drift): + return cheetah.Drift(torch.tensor(element.l), name=element.id) + elif isinstance(element, ocelot.Quadrupole): + return cheetah.Quadrupole( + torch.tensor(element.l), torch.tensor(element.k1), name=element.id + ) + elif isinstance(element, ocelot.Hcor): + return cheetah.HorizontalCorrector( + torch.tensor(element.l), torch.tensor(element.angle), name=element.id + ) + elif isinstance(element, ocelot.Vcor): + return cheetah.VerticalCorrector( + torch.tensor(element.l), torch.tensor(element.angle), name=element.id + ) + elif isinstance(element, ocelot.Cavity): + return cheetah.Cavity(torch.tensor(element.l), name=element.id) + elif isinstance(element, ocelot.Monitor) and ("BSC" in element.id): + # NOTE This pattern is very specific to ARES and will need a more complex + # solution for other accelerators + if warnings: + print( + "WARNING: Diagnostic screen was converted with default screen" + " properties." + ) + return cheetah.Screen( + torch.tensor([2448, 2040]), + torch.tensor([3.5488e-6, 2.5003e-6]), + name=element.id, + ) + elif isinstance(element, ocelot.Monitor) and "BPM" in element.id: + return cheetah.BPM(name=element.id) + elif isinstance(element, ocelot.Undulator): + return cheetah.Undulator(torch.tensor(element.l), name=element.id) + else: + if warnings: + print( + f"WARNING: Unknown element {element.id} of type {type(element)}," + " replacing with drift section." + ) + return cheetah.Drift(torch.tensor(element.l), name=element.id) + + +def subcell_of_ocelot(cell: list, start: str, end: str) -> list: + """Extract a subcell `[start, end]` from an Ocelot cell.""" + subcell = [] + is_in_subcell = False + for el in cell: + if el.id == start: + is_in_subcell = True + if is_in_subcell: + subcell.append(el) + if el.id == end: + break + + return subcell diff --git a/cheetah/particles.py b/cheetah/particles.py index 8a96e757..37c689b9 100644 --- a/cheetah/particles.py +++ b/cheetah/particles.py @@ -388,7 +388,7 @@ def from_ocelot(cls, parray, device: str = "auto") -> "ParameterBeam": @classmethod def from_astra(cls, path: str, **kwargs) -> "ParameterBeam": """Load an Astra particle distribution as a Cheetah Beam.""" - from cheetah.utils import from_astrabeam + from cheetah.astralavista import from_astrabeam particles, energy = from_astrabeam(path) mu = torch.ones(7) @@ -727,7 +727,7 @@ def from_ocelot(cls, parray, device: str = "auto") -> "ParticleBeam": @classmethod def from_astra(cls, path: str, **kwargs) -> "ParticleBeam": """Load an Astra particle distribution as a Cheetah Beam.""" - from cheetah.utils import from_astrabeam + from cheetah.astralavista import from_astrabeam particles, energy = from_astrabeam(path) particles_7d = torch.ones((particles.shape[0], 7)) diff --git a/docs/astralavista.rst b/docs/astralavista.rst new file mode 100644 index 00000000..9f25de0e --- /dev/null +++ b/docs/astralavista.rst @@ -0,0 +1,8 @@ +.. Documents astralavista.py + +Astralavista +============ + +.. automodule:: astralavista + :members: + :undoc-members: diff --git a/docs/error.rst b/docs/error.rst new file mode 100644 index 00000000..91ef45d8 --- /dev/null +++ b/docs/error.rst @@ -0,0 +1,8 @@ +.. Documents error.py + +Error +===== + +.. automodule:: error + :members: + :undoc-members: diff --git a/docs/index.rst b/docs/index.rst index 78e7f338..161dcdf2 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -11,10 +11,13 @@ Welcome to Cheetah's documentation! :caption: Contents: accelerator + astralavista dontbmad + error + latticejson + nocelot particles track_methods - utils `Cheetah `_ is a particle tracking accelerator we built specifically to speed up the training of reinforcement learning models. diff --git a/docs/latticejson.rst b/docs/latticejson.rst new file mode 100644 index 00000000..1df0c20e --- /dev/null +++ b/docs/latticejson.rst @@ -0,0 +1,8 @@ +.. Documents latticejson.py + +LatticeJSON +=========== + +.. automodule:: latticejson + :members: + :undoc-members: diff --git a/docs/nocelot.rst b/docs/nocelot.rst new file mode 100644 index 00000000..83f75766 --- /dev/null +++ b/docs/nocelot.rst @@ -0,0 +1,8 @@ +.. Documents nocelot.py + +NOcelot +======= + +.. automodule:: nocelot + :members: + :undoc-members: diff --git a/docs/utils.rst b/docs/utils.rst deleted file mode 100644 index a7197ce9..00000000 --- a/docs/utils.rst +++ /dev/null @@ -1,8 +0,0 @@ -.. Documents utils.py - -Utilities -========= - -.. automodule:: utils - :members: - :undoc-members: diff --git a/test/test_compare_ocelot.py b/test/test_compare_ocelot.py index db87b063..08332c4f 100644 --- a/test/test_compare_ocelot.py +++ b/test/test_compare_ocelot.py @@ -165,7 +165,7 @@ def test_ares_ea(): Test that the tracking results through a Experimental Area (EA) lattice of the ARES accelerator at DESY match those using Ocelot. """ - cell = cheetah.utils.subcell_of_ocelot(ares.cell, "AREASOLA1", "AREABSCR1") + cell = cheetah.nocelot.subcell_of_ocelot(ares.cell, "AREASOLA1", "AREABSCR1") ares.areamqzm1.k1 = 5.0 ares.areamqzm2.k1 = -5.0 ares.areamcvm1.k1 = 1e-3 diff --git a/test/test_lattice_json.py b/test/test_lattice_json.py index 8555c9e1..45a87d9b 100644 --- a/test/test_lattice_json.py +++ b/test/test_lattice_json.py @@ -2,7 +2,7 @@ import test.ARESlatticeStage3v1_9 as ares from cheetah.accelerator import Segment -from cheetah.utils import load_cheetah_model, save_cheetah_model +from cheetah.latticejson import load_cheetah_model, save_cheetah_model cheetah_segment = Segment.from_ocelot(ares.cell, name="ARES_Segment") diff --git a/test/test_speed.py b/test/test_speed.py index 7c6cd476..4bf17bda 100644 --- a/test/test_speed.py +++ b/test/test_speed.py @@ -7,7 +7,7 @@ # TODO: Test that Cheeath tracks faster than Ocelot def test_tracking_speed(): """Really only tests that Cheetah isn't super slow.""" - cell = cheetah.utils.subcell_of_ocelot(ares.cell, "AREASOLA1", "AREABSCR1") + cell = cheetah.nocelot.subcell_of_ocelot(ares.cell, "AREASOLA1", "AREABSCR1") segment = cheetah.Segment.from_ocelot(cell) segment.AREABSCR1.is_active = True # Turn screen on and off