Skip to content

Commit

Permalink
Split utils.py into more appropriate modules
Browse files Browse the repository at this point in the history
  • Loading branch information
jank324 committed Sep 12, 2023
1 parent c6fbc9b commit d60af18
Show file tree
Hide file tree
Showing 15 changed files with 199 additions and 170 deletions.
7 changes: 5 additions & 2 deletions cheetah/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from scipy.stats import multivariate_normal

from cheetah.dontbmad import convert_bmad_lattice
from cheetah.error import DeviceError
from cheetah.particles import Beam, ParameterBeam, ParticleBeam
from cheetah.track_methods import base_rmatrix, misalignment_matrix, rotation_matrix
from cheetah.utils import DeviceError

REST_ENERGY = torch.tensor(
constants.electron_mass
Expand All @@ -40,6 +40,9 @@ class Element(ABC, pydantic.BaseModel):
name: str = "unnamed"
device: torch.device

class Config:
arbitrary_types_allowed = True

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

Expand Down Expand Up @@ -1533,7 +1536,7 @@ def from_ocelot(
Cheetah or converted with potentially unexpected behavior.
:return: Cheetah segment closely resembling the Ocelot cell.
"""
from cheetah.utils import ocelot2cheetah
from cheetah.nocelot import ocelot2cheetah

converted = [ocelot2cheetah(element, warnings=warnings) for element in cell]
return cls(converted, name=name, **kwargs)
Expand Down
58 changes: 58 additions & 0 deletions cheetah/astralavista.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import numpy as np
from scipy.constants import physical_constants

# Electron mass in eV
electron_mass_eV = physical_constants["electron mass energy equivalent in MeV"][0] * 1e6


def from_astrabeam(path: str) -> tuple[np.ndarray, float]:
"""
Read from a ASTRA beam distribution, and prepare for conversion to a Cheetah
ParticleBeam or ParameterBeam.
Adapted from the implementation in ocelot:
https://github.com/ocelot-collab/ocelot/blob/master/ocelot/adaptors/astra2ocelot.py
:param path: Path to the ASTRA beam distribution file.
:return: Particle 6D phase space information and mean energy of the particle beam.
"""
P0 = np.loadtxt(path)

# remove lost particles
inds = np.argwhere(P0[:, 9] > 0)
inds = inds.reshape(inds.shape[0])

P0 = P0[inds, :]
n_particles = P0.shape[0]

# s_ref = P0[0, 2]
Pref = P0[0, 5]

xp = P0[:, :6]
xp[0, 2] = 0.0
xp[0, 5] = 0.0

gamref = np.sqrt((Pref / electron_mass_eV) ** 2 + 1)
# energy in eV: E = gamma * m_e
energy = gamref * electron_mass_eV

n_particles = xp.shape[0]
particles = np.zeros((n_particles, 6))

u = np.c_[xp[:, 3], xp[:, 4], xp[:, 5] + Pref]
gamma = np.sqrt(1 + np.sum(u * u, 1) / electron_mass_eV**2)
beta = np.sqrt(1 - gamma**-2)
betaref = np.sqrt(1 - gamref**-2)

p0 = np.linalg.norm(u, 2, 1).reshape((n_particles, 1))

u = u / p0
cdt = -xp[:, 2] / (beta * u[:, 2])
particles[:, 0] = xp[:, 0] + beta * u[:, 0] * cdt
particles[:, 2] = xp[:, 1] + beta * u[:, 1] * cdt
particles[:, 4] = cdt
particles[:, 1] = xp[:, 3] / Pref
particles[:, 3] = xp[:, 4] / Pref
particles[:, 5] = (gamma / gamref - 1) / betaref

return particles, energy
11 changes: 11 additions & 0 deletions cheetah/error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
class DeviceError(Exception):
"""
Used to create an exception, in case the device used for the beam
and the elements are different.
"""

def __init__(self):
super().__init__(
"Warning! The device used for calculating the elements is not the same, "
"as the device used to calculate the Beam."
)
154 changes: 0 additions & 154 deletions cheetah/utils.py → cheetah/latticejson.py
Original file line number Diff line number Diff line change
@@ -1,149 +1,8 @@
import json
from typing import Optional

import numpy as np
import torch
from scipy.constants import physical_constants

import cheetah

# Electron mass in eV
electron_mass_eV = physical_constants["electron mass energy equivalent in MeV"][0] * 1e6


def from_astrabeam(path: str) -> tuple[np.ndarray, float]:
"""
Read from a ASTRA beam distribution, and prepare for conversion to a Cheetah
ParticleBeam or ParameterBeam.
Adapted from the implementation in ocelot:
https://github.com/ocelot-collab/ocelot/blob/master/ocelot/adaptors/astra2ocelot.py
:param path: Path to the ASTRA beam distribution file.
:return: Particle 6D phase space information and mean energy of the particle beam.
"""
P0 = np.loadtxt(path)

# remove lost particles
inds = np.argwhere(P0[:, 9] > 0)
inds = inds.reshape(inds.shape[0])

P0 = P0[inds, :]
n_particles = P0.shape[0]

# s_ref = P0[0, 2]
Pref = P0[0, 5]

xp = P0[:, :6]
xp[0, 2] = 0.0
xp[0, 5] = 0.0

gamref = np.sqrt((Pref / electron_mass_eV) ** 2 + 1)
# energy in eV: E = gamma * m_e
energy = gamref * electron_mass_eV

n_particles = xp.shape[0]
particles = np.zeros((n_particles, 6))

u = np.c_[xp[:, 3], xp[:, 4], xp[:, 5] + Pref]
gamma = np.sqrt(1 + np.sum(u * u, 1) / electron_mass_eV**2)
beta = np.sqrt(1 - gamma**-2)
betaref = np.sqrt(1 - gamref**-2)

p0 = np.linalg.norm(u, 2, 1).reshape((n_particles, 1))

u = u / p0
cdt = -xp[:, 2] / (beta * u[:, 2])
particles[:, 0] = xp[:, 0] + beta * u[:, 0] * cdt
particles[:, 2] = xp[:, 1] + beta * u[:, 1] * cdt
particles[:, 4] = cdt
particles[:, 1] = xp[:, 3] / Pref
particles[:, 3] = xp[:, 4] / Pref
particles[:, 5] = (gamma / gamref - 1) / betaref

return particles, energy


def ocelot2cheetah(element, warnings: bool = True) -> "cheetah.Element":
"""
Translate an Ocelot element to a Cheetah element.
NOTE Object not supported by Cheetah are translated to drift sections. Screen
objects are created only from `ocelot.Monitor` objects when the string "BSC" is
contained in their `id` attribute. Their screen properties are always set to default
values and most likely need adjusting afterwards. BPM objects are only created from
`ocelot.Monitor` objects when their id has a substring "BPM".
:param element: Ocelot element object representing an element of particle
accelerator.
:param warnings: Whether to print warnings when elements might not be converted as
expected.
:return: Cheetah element object representing an element of particle accelerator.
"""
try:
import ocelot
except ImportError:
raise ImportError(
"""To use the ocelot2cheetah lattice converter, Ocelot must be first
installed, see https://github.com/ocelot-collab/ocelot """
)

if isinstance(element, ocelot.Drift):
return cheetah.Drift(torch.tensor(element.l), name=element.id)
elif isinstance(element, ocelot.Quadrupole):
return cheetah.Quadrupole(
torch.tensor(element.l), torch.tensor(element.k1), name=element.id
)
elif isinstance(element, ocelot.Hcor):
return cheetah.HorizontalCorrector(
torch.tensor(element.l), torch.tensor(element.angle), name=element.id
)
elif isinstance(element, ocelot.Vcor):
return cheetah.VerticalCorrector(
torch.tensor(element.l), torch.tensor(element.angle), name=element.id
)
elif isinstance(element, ocelot.Cavity):
return cheetah.Cavity(torch.tensor(element.l), name=element.id)
elif isinstance(element, ocelot.Monitor) and ("BSC" in element.id):
# NOTE This pattern is very specific to ARES and will need a more complex
# solution for other accelerators
if warnings:
print(
"WARNING: Diagnostic screen was converted with default screen"
" properties."
)
return cheetah.Screen(
torch.tensor([2448, 2040]),
torch.tensor([3.5488e-6, 2.5003e-6]),
name=element.id,
)
elif isinstance(element, ocelot.Monitor) and "BPM" in element.id:
return cheetah.BPM(name=element.id)
elif isinstance(element, ocelot.Undulator):
return cheetah.Undulator(torch.tensor(element.l), name=element.id)
else:
if warnings:
print(
f"WARNING: Unknown element {element.id} of type {type(element)},"
" replacing with drift section."
)
return cheetah.Drift(torch.tensor(element.l), name=element.id)


def subcell_of_ocelot(cell: list, start: str, end: str) -> list:
"""Extract a subcell `[start, end]` from an Ocelot cell."""
subcell = []
is_in_subcell = False
for el in cell:
if el.id == start:
is_in_subcell = True
if is_in_subcell:
subcell.append(el)
if el.id == end:
break

return subcell


# Saving Cheetah to JSON
def parse_cheetah_element(element: cheetah.Element):
Expand Down Expand Up @@ -297,16 +156,3 @@ def load_cheetah_model(fname: str, name: Optional[str] = None) -> cheetah.Segmen
def str_to_class(classname: str):
# get class from string
return getattr(cheetah, classname)


class DeviceError(Exception):
"""
Used to create an exception, in case the device used for the beam
and the elements are different.
"""

def __init__(self):
super().__init__(
"Warning! The device used for calculating the elements is not the same, "
"as the device used to calculate the Beam."
)
84 changes: 84 additions & 0 deletions cheetah/nocelot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import torch

import cheetah


def ocelot2cheetah(element, warnings: bool = True) -> "cheetah.Element":
"""
Translate an Ocelot element to a Cheetah element.
NOTE Object not supported by Cheetah are translated to drift sections. Screen
objects are created only from `ocelot.Monitor` objects when the string "BSC" is
contained in their `id` attribute. Their screen properties are always set to default
values and most likely need adjusting afterwards. BPM objects are only created from
`ocelot.Monitor` objects when their id has a substring "BPM".
:param element: Ocelot element object representing an element of particle
accelerator.
:param warnings: Whether to print warnings when elements might not be converted as
expected.
:return: Cheetah element object representing an element of particle accelerator.
"""
try:
import ocelot
except ImportError:
raise ImportError(
"""To use the ocelot2cheetah lattice converter, Ocelot must be first
installed, see https://github.com/ocelot-collab/ocelot """
)

if isinstance(element, ocelot.Drift):
return cheetah.Drift(torch.tensor(element.l), name=element.id)
elif isinstance(element, ocelot.Quadrupole):
return cheetah.Quadrupole(
torch.tensor(element.l), torch.tensor(element.k1), name=element.id
)
elif isinstance(element, ocelot.Hcor):
return cheetah.HorizontalCorrector(
torch.tensor(element.l), torch.tensor(element.angle), name=element.id
)
elif isinstance(element, ocelot.Vcor):
return cheetah.VerticalCorrector(
torch.tensor(element.l), torch.tensor(element.angle), name=element.id
)
elif isinstance(element, ocelot.Cavity):
return cheetah.Cavity(torch.tensor(element.l), name=element.id)
elif isinstance(element, ocelot.Monitor) and ("BSC" in element.id):
# NOTE This pattern is very specific to ARES and will need a more complex
# solution for other accelerators
if warnings:
print(
"WARNING: Diagnostic screen was converted with default screen"
" properties."
)
return cheetah.Screen(
torch.tensor([2448, 2040]),
torch.tensor([3.5488e-6, 2.5003e-6]),
name=element.id,
)
elif isinstance(element, ocelot.Monitor) and "BPM" in element.id:
return cheetah.BPM(name=element.id)
elif isinstance(element, ocelot.Undulator):
return cheetah.Undulator(torch.tensor(element.l), name=element.id)
else:
if warnings:
print(
f"WARNING: Unknown element {element.id} of type {type(element)},"
" replacing with drift section."
)
return cheetah.Drift(torch.tensor(element.l), name=element.id)


def subcell_of_ocelot(cell: list, start: str, end: str) -> list:
"""Extract a subcell `[start, end]` from an Ocelot cell."""
subcell = []
is_in_subcell = False
for el in cell:
if el.id == start:
is_in_subcell = True
if is_in_subcell:
subcell.append(el)
if el.id == end:
break

return subcell
4 changes: 2 additions & 2 deletions cheetah/particles.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def from_ocelot(cls, parray, device: str = "auto") -> "ParameterBeam":
@classmethod
def from_astra(cls, path: str, **kwargs) -> "ParameterBeam":
"""Load an Astra particle distribution as a Cheetah Beam."""
from cheetah.utils import from_astrabeam
from cheetah.astralavista import from_astrabeam

particles, energy = from_astrabeam(path)
mu = torch.ones(7)
Expand Down Expand Up @@ -727,7 +727,7 @@ def from_ocelot(cls, parray, device: str = "auto") -> "ParticleBeam":
@classmethod
def from_astra(cls, path: str, **kwargs) -> "ParticleBeam":
"""Load an Astra particle distribution as a Cheetah Beam."""
from cheetah.utils import from_astrabeam
from cheetah.astralavista import from_astrabeam

particles, energy = from_astrabeam(path)
particles_7d = torch.ones((particles.shape[0], 7))
Expand Down
8 changes: 8 additions & 0 deletions docs/astralavista.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
.. Documents astralavista.py
Astralavista
============

.. automodule:: astralavista
:members:
:undoc-members:
8 changes: 8 additions & 0 deletions docs/error.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
.. Documents error.py
Error
=====

.. automodule:: error
:members:
:undoc-members:
Loading

0 comments on commit d60af18

Please sign in to comment.