Skip to content

Commit

Permalink
Further cleanup of the vectorized speed optimization tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Hespe committed Oct 2, 2024
1 parent 4a7c90c commit 777346b
Showing 1 changed file with 19 additions and 32 deletions.
51 changes: 19 additions & 32 deletions tests/test_speed_optimizations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import numpy as np
import pytest
import torch

Expand Down Expand Up @@ -134,11 +133,9 @@ def test_inactive_magnet_is_replaced_by_drift():
"""
segment = cheetah.Segment(
elements=[
cheetah.Drift(length=torch.tensor([0.6, 0.5])),
cheetah.Quadrupole(
length=torch.tensor([0.2, 0.3]), k1=torch.tensor(0.0)
),
cheetah.Drift(length=torch.tensor([0.4, 0.1])),
cheetah.Drift(length=torch.tensor(0.6)),
cheetah.Quadrupole(length=torch.tensor(0.2), k1=torch.tensor(0.0)),
cheetah.Drift(length=torch.tensor(0.4)),
]
)

Expand All @@ -147,7 +144,7 @@ def test_inactive_magnet_is_replaced_by_drift():
assert all(
isinstance(element, cheetah.Drift) for element in optimized_segment.elements
)
assert np.allclose(segment.length, optimized_segment.length)
assert torch.allclose(segment.length, optimized_segment.length)


def test_active_elements_not_replaced_by_drift():
Expand All @@ -156,11 +153,9 @@ def test_active_elements_not_replaced_by_drift():
"""
segment = cheetah.Segment(
elements=[
cheetah.Drift(length=torch.tensor([0.6, 0.5])),
cheetah.Quadrupole(
length=torch.tensor([0.2, 0.3]), k1=torch.tensor([4.2, 0.0])
),
cheetah.Drift(length=torch.tensor([0.4, 0.1])),
cheetah.Drift(length=torch.tensor(0.6)),
cheetah.Quadrupole(length=torch.tensor(0.2), k1=torch.tensor([4.2, 0.0])),
cheetah.Drift(length=torch.tensor(0.4)),
]
)

Expand Down Expand Up @@ -192,7 +187,8 @@ def test_inactive_magnet_drift_replacement_dtype(dtype: torch.dtype):

def test_inactive_magnet_drift_except_for():
"""
Test that an inactive magnet is not replaced by a drift when it is included in the list of exceptions.
Test that an inactive magnet is not replaced by a drift when it is included in the
list of exceptions.
"""
segment = cheetah.Segment(
elements=[
Expand Down Expand Up @@ -245,33 +241,24 @@ def test_without_zero_length_elements():
"""Test that zero-length elements are properly recognized and removed."""
segment = cheetah.Segment(
elements=[
cheetah.Drift(length=torch.tensor([1.0, 2.0])),
cheetah.Drift(length=torch.tensor(1.0)),
cheetah.Dipole(length=torch.tensor(0.0), angle=torch.tensor(0.0)),
cheetah.Dipole(
length=torch.tensor(0.0),
angle=torch.tensor(0.0)
),
cheetah.Dipole(
length=torch.tensor(0.0),
angle=torch.tensor(0.0),
name="my_dipole"
),
cheetah.Dipole(
length=torch.tensor([0.0, 0.1]),
angle=torch.tensor(0.0)
length=torch.tensor(0.0), angle=torch.tensor(0.0), name="my_dipole"
),
cheetah.Dipole(length=torch.tensor([0.0, 0.1]), angle=torch.tensor(0.0)),
cheetah.Drift(length=torch.tensor(0.0)),
cheetah.Dipole(
length=torch.tensor(0.0),
angle=torch.tensor([0.5, 0.0])
),
cheetah.Dipole(length=torch.tensor(0.0), angle=torch.tensor([0.5, 0.0])),
]
)

pruned = segment.without_inactive_zero_length_elements()
pruned_except = segment.without_inactive_zero_length_elements(except_for=["my_dipole"])
pruned_except = segment.without_inactive_zero_length_elements(
except_for=["my_dipole"]
)

assert len(segment.elements) == 6
assert len(pruned.elements) == 3
assert len(pruned_except.elements) == 4
assert np.allclose(segment.length, pruned.length)
assert np.allclose(segment.length, pruned_except.length)
assert torch.allclose(segment.length, pruned.length)
assert torch.allclose(segment.length, pruned_except.length)

0 comments on commit 777346b

Please sign in to comment.