diff --git a/tests/test_speed_optimizations.py b/tests/test_speed_optimizations.py index 5da5a4f4..56d2663d 100644 --- a/tests/test_speed_optimizations.py +++ b/tests/test_speed_optimizations.py @@ -1,4 +1,3 @@ -import numpy as np import pytest import torch @@ -134,11 +133,9 @@ def test_inactive_magnet_is_replaced_by_drift(): """ segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([0.6, 0.5])), - cheetah.Quadrupole( - length=torch.tensor([0.2, 0.3]), k1=torch.tensor(0.0) - ), - cheetah.Drift(length=torch.tensor([0.4, 0.1])), + cheetah.Drift(length=torch.tensor(0.6)), + cheetah.Quadrupole(length=torch.tensor(0.2), k1=torch.tensor(0.0)), + cheetah.Drift(length=torch.tensor(0.4)), ] ) @@ -147,7 +144,7 @@ def test_inactive_magnet_is_replaced_by_drift(): assert all( isinstance(element, cheetah.Drift) for element in optimized_segment.elements ) - assert np.allclose(segment.length, optimized_segment.length) + assert torch.allclose(segment.length, optimized_segment.length) def test_active_elements_not_replaced_by_drift(): @@ -156,11 +153,9 @@ def test_active_elements_not_replaced_by_drift(): """ segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([0.6, 0.5])), - cheetah.Quadrupole( - length=torch.tensor([0.2, 0.3]), k1=torch.tensor([4.2, 0.0]) - ), - cheetah.Drift(length=torch.tensor([0.4, 0.1])), + cheetah.Drift(length=torch.tensor(0.6)), + cheetah.Quadrupole(length=torch.tensor(0.2), k1=torch.tensor([4.2, 0.0])), + cheetah.Drift(length=torch.tensor(0.4)), ] ) @@ -192,7 +187,8 @@ def test_inactive_magnet_drift_replacement_dtype(dtype: torch.dtype): def test_inactive_magnet_drift_except_for(): """ - Test that an inactive magnet is not replaced by a drift when it is included in the list of exceptions. + Test that an inactive magnet is not replaced by a drift when it is included in the + list of exceptions. """ segment = cheetah.Segment( elements=[ @@ -245,33 +241,24 @@ def test_without_zero_length_elements(): """Test that zero-length elements are properly recognized and removed.""" segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([1.0, 2.0])), + cheetah.Drift(length=torch.tensor(1.0)), + cheetah.Dipole(length=torch.tensor(0.0), angle=torch.tensor(0.0)), cheetah.Dipole( - length=torch.tensor(0.0), - angle=torch.tensor(0.0) - ), - cheetah.Dipole( - length=torch.tensor(0.0), - angle=torch.tensor(0.0), - name="my_dipole" - ), - cheetah.Dipole( - length=torch.tensor([0.0, 0.1]), - angle=torch.tensor(0.0) + length=torch.tensor(0.0), angle=torch.tensor(0.0), name="my_dipole" ), + cheetah.Dipole(length=torch.tensor([0.0, 0.1]), angle=torch.tensor(0.0)), cheetah.Drift(length=torch.tensor(0.0)), - cheetah.Dipole( - length=torch.tensor(0.0), - angle=torch.tensor([0.5, 0.0]) - ), + cheetah.Dipole(length=torch.tensor(0.0), angle=torch.tensor([0.5, 0.0])), ] ) pruned = segment.without_inactive_zero_length_elements() - pruned_except = segment.without_inactive_zero_length_elements(except_for=["my_dipole"]) + pruned_except = segment.without_inactive_zero_length_elements( + except_for=["my_dipole"] + ) assert len(segment.elements) == 6 assert len(pruned.elements) == 3 assert len(pruned_except.elements) == 4 - assert np.allclose(segment.length, pruned.length) - assert np.allclose(segment.length, pruned_except.length) + assert torch.allclose(segment.length, pruned.length) + assert torch.allclose(segment.length, pruned_except.length)