Skip to content

Commit

Permalink
Cleanup tests comparing to Ocelot
Browse files Browse the repository at this point in the history
  • Loading branch information
jank324 committed Sep 3, 2023
1 parent 430c42b commit c1553ec
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 95 deletions.
4 changes: 2 additions & 2 deletions cheetah/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,7 +928,7 @@ def __init__(
self,
xmax: float = np.inf,
ymax: float = np.inf,
type: str = "rect",
type: str = "rect", # TODO: Better strings ellipciatl and rectangular
name: Optional[str] = None,
**kwargs,
) -> None:
Expand All @@ -942,7 +942,7 @@ def __init__(
super().__init__(name, **kwargs)

@property
def is_skippable(self) -> bool:
def is_skippable(self) -> bool: # TODO: Aperatures should always be active
return not self.is_active

def transfer_map(self, energy: float) -> torch.Tensor:
Expand Down
225 changes: 132 additions & 93 deletions test/test_compare_ocelot.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,115 +19,154 @@
)


def test_benchmark_ocelot_dipole():
length = 0.1
angle = 0.1
cheetah_bend = cheetah.Dipole(length=length, angle=angle)
ocelot_bend = ocelot.Bend(l=length, angle=angle)
p_array = deepcopy(PARRAY_OCELOT)
p_in_cheetah = deepcopy(PARTICLEBEAM_CHEETAH)
pb_out_cheetah = cheetah_bend(p_in_cheetah)

lat = ocelot.MagneticLattice([ocelot_bend], stop=None)
navi = ocelot.Navigator(lat)
_, p_array = ocelot.track(lat, p_array, navi)
def test_dipole():
"""
Test that the tracking results through a Cheeath `Dipole` element match those
through an Oclet `Bend` element.
"""
# Cheetah
incoming_beam = cheetah.ParticleBeam.from_astra(
"benchmark/cheetah/ACHIP_EA1_2021.1351.001"
)
cheetah_dipole = cheetah.Dipole(length=0.1, angle=0.1)
outgoing_beam = cheetah_dipole(incoming_beam)

# Ocelot
incoming_p_array = ocelot.astraBeam2particleArray(
"benchmark/cheetah/ACHIP_EA1_2021.1351.001"
)
ocelot_bend = ocelot.Bend(l=0.1, angle=0.1)
lattice = ocelot.MagneticLattice([ocelot_bend])
navigator = ocelot.Navigator(lattice)
_, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator)

assert np.allclose(
p_array.rparticles,
pb_out_cheetah.particles[:, :6].t().numpy(),
rtol=1e-4,
atol=1e-10,
equal_nan=False,
outgoing_beam.particles[:, :6], outgoing_p_array.rparticles.transpose()
)


def test_benchmark_ocelot_dipole_with_fringe_field():
length = 0.1
angle = 0.1
fint = 0.1
gap = 0.2
cheetah_bend = cheetah.Dipole(length=length, angle=angle, fint=fint, gap=gap)
ocelot_bend = ocelot.Bend(l=length, angle=angle, fint=fint, gap=gap)
p_array = deepcopy(PARRAY_OCELOT)
p_in_cheetah = deepcopy(PARTICLEBEAM_CHEETAH)
pb_out_cheetah = cheetah_bend(p_in_cheetah)
def test_dipole_with_fringe_field():
"""
Test that the tracking results through a Cheeath `Dipole` element match those
through an Oclet `Bend` element when there are fringe fields.
"""
# Cheetah
incoming_beam = cheetah.ParticleBeam.from_astra(
"benchmark/cheetah/ACHIP_EA1_2021.1351.001"
)
cheetah_dipole = cheetah.Dipole(length=0.1, angle=0.1, fint=0.1, gap=0.2)
outgoing_beam = cheetah_dipole(incoming_beam)

lat = ocelot.MagneticLattice([ocelot_bend], stop=None)
navi = ocelot.Navigator(lat)
_, p_array = ocelot.track(lat, p_array, navi)
# Ocelot
incoming_p_array = ocelot.astraBeam2particleArray(
"benchmark/cheetah/ACHIP_EA1_2021.1351.001"
)
ocelot_bend = ocelot.Bend(l=0.1, angle=0.1, fint=0.1, gap=0.2)
lattice = ocelot.MagneticLattice([ocelot_bend])
navigator = ocelot.Navigator(lattice)
_, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator)

assert np.allclose(
p_array.rparticles,
pb_out_cheetah.particles[:, :6].t().numpy(),
rtol=1e-4,
atol=1e-10,
equal_nan=False,
outgoing_beam.particles[:, :6], outgoing_p_array.rparticles.transpose()
)


def test_benchmark_ocelot_aperture():
xmax = 2e-4
ymax = 2e-4
drift_length = 0.1 # so that ocelot starts tracking
ocelot_aperture = ocelot.Aperture(xmax=xmax, ymax=xmax)
cheetah_aperture = cheetah.Aperture(xmax=xmax, ymax=ymax)
cheetah_aperture.is_active = True
p_array = deepcopy(PARRAY_OCELOT)
p_in_cheetah = deepcopy(PARTICLEBEAM_CHEETAH)
# Cheetah Tracking
segment = cheetah.Segment([cheetah_aperture, cheetah.Drift(length=drift_length)])
p_out_cheetah = segment(p_in_cheetah)
# Ocelot Tracking
lat = ocelot.MagneticLattice(
[ocelot_aperture, ocelot.Drift(drift_length)], stop=None
def test_aperture():
"""
Test that the tracking results through a Cheeath `Aperture` element match those
through an Oclet `Aperture` element.
"""
# Cheetah
incoming_beam = cheetah.ParticleBeam.from_astra(
"benchmark/cheetah/ACHIP_EA1_2021.1351.001"
)
navi = ocelot.Navigator(lat)
navi.activate_apertures()
_, p_array = ocelot.track(lat, p_array, navi)

assert p_out_cheetah.n == p_array.rparticles.shape[1]


def test_benchmark_ocelot_aperture_elliptical():
xmax = 2e-4
ymax = 2e-4
drift_length = 0.1 # so that ocelot starts tracking
ocelot_aperture = ocelot.Aperture(xmax=xmax, ymax=xmax, type="ellipt")
cheetah_aperture = cheetah.Aperture(xmax=xmax, ymax=ymax, type="ellipt")
cheetah_aperture.is_active = True
p_array = deepcopy(PARRAY_OCELOT)
p_in_cheetah = deepcopy(PARTICLEBEAM_CHEETAH)
# Cheetah Tracking
segment = cheetah.Segment([cheetah_aperture, cheetah.Drift(length=drift_length)])
p_out_cheetah = segment(p_in_cheetah)
# Ocelot Tracking
lat = ocelot.MagneticLattice(
[ocelot_aperture, ocelot.Drift(drift_length)], stop=None
cheetah_segment = cheetah.Segment(
[
cheetah.Aperture(
xmax=2e-4,
ymax=2e-4,
type="rect",
name="aperture", # TODO: Don't use type keyword
), # TODO: is_active on init
cheetah.Drift(length=0.1),
]
)
navi = ocelot.Navigator(lat)
navi.activate_apertures()
_, p_array = ocelot.track(lat, p_array, navi)

assert p_out_cheetah.n == p_array.rparticles.shape[1]
cheetah_segment.aperture.is_active = True
outgoing_beam = cheetah_segment(incoming_beam)

# Ocelot
incoming_p_array = ocelot.astraBeam2particleArray(
"benchmark/cheetah/ACHIP_EA1_2021.1351.001"
)
ocelot_cell = [ocelot.Aperture(xmax=2e-4, ymax=2e-4), ocelot.Drift(l=0.1)]
lattice = ocelot.MagneticLattice([ocelot_cell])
navigator = ocelot.Navigator(lattice)
navigator.activate_apertures()
_, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator)

assert outgoing_beam.n == outgoing_p_array.rparticles.shape[1]


def test_aperture_elliptical():
"""
Test that the tracking results through an elliptical Cheeath `Aperture` element
match those through an elliptical Oclet `Aperture` element.
"""
# Cheetah
incoming_beam = cheetah.ParticleBeam.from_astra(
"benchmark/cheetah/ACHIP_EA1_2021.1351.001"
)
cheetah_segment = cheetah.Segment(
[
cheetah.Aperture(
xmax=2e-4,
ymax=2e-4,
type="ellipt",
name="aperture", # TODO: Don't use type keyword
), # TODO: is_active on init
cheetah.Drift(length=0.1),
]
)
cheetah_segment.aperture.is_active = True
outgoing_beam = cheetah_segment(incoming_beam)

def test_benchmark_ocelot_solenoid():
length = 0.5
k = 5
cheetah_bend = cheetah.Solenoid(length=length, k=k)
ocelot_bend = ocelot.Solenoid(l=length, k=k)
p_array = deepcopy(PARRAY_OCELOT)
p_in_cheetah = deepcopy(PARTICLEBEAM_CHEETAH)
pb_out_cheetah = cheetah_bend(p_in_cheetah)
# Ocelot
incoming_p_array = ocelot.astraBeam2particleArray(
"benchmark/cheetah/ACHIP_EA1_2021.1351.001"
)
ocelot_cell = [
ocelot.Aperture(xmax=2e-4, ymax=2e-4, type="ellipt"),
ocelot.Drift(l=0.1),
]
lattice = ocelot.MagneticLattice([ocelot_cell])
navigator = ocelot.Navigator(lattice)
navigator.activate_apertures()
_, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator)

assert outgoing_beam.n == outgoing_p_array.rparticles.shape[1]


def test_solenoid():
"""
Test that the tracking results through a Cheeath `Solenoid` element match those
through an Oclet `Solenoid` element.
"""
# Cheetah
incoming_beam = cheetah.ParticleBeam.from_astra(
"benchmark/cheetah/ACHIP_EA1_2021.1351.001"
)
cheetah_solenoid = cheetah.Solenoid(length=0.5, k=5)
outgoing_beam = cheetah_solenoid(incoming_beam)

lat = ocelot.MagneticLattice([ocelot_bend], stop=None)
navi = ocelot.Navigator(lat)
_, p_array = ocelot.track(lat, p_array, navi)
# Ocelot
incoming_p_array = ocelot.astraBeam2particleArray(
"benchmark/cheetah/ACHIP_EA1_2021.1351.001"
)
ocelot_solenoid = ocelot.Solenoid(l=0.5, k=5)
lattice = ocelot.MagneticLattice([ocelot_solenoid])
navigator = ocelot.Navigator(lattice)
_, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator)

assert np.allclose(
p_array.rparticles,
pb_out_cheetah.particles[:, :6].t().numpy(),
rtol=1e-4,
atol=1e-10,
equal_nan=False,
outgoing_beam.particles[:, :6], outgoing_p_array.rparticles.transpose()
)

0 comments on commit c1553ec

Please sign in to comment.