From f9cacef24c884b223d384672e2d2fd467420371b Mon Sep 17 00:00:00 2001 From: Christian Hespe Date: Tue, 24 Sep 2024 16:39:14 +0200 Subject: [PATCH 01/12] Add tests for lattice pruning --- tests/test_segment_pruning.py | 98 +++++++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 tests/test_segment_pruning.py diff --git a/tests/test_segment_pruning.py b/tests/test_segment_pruning.py new file mode 100644 index 00000000..baa5fa5e --- /dev/null +++ b/tests/test_segment_pruning.py @@ -0,0 +1,98 @@ +import numpy as np +import torch + +import cheetah + + +def test_inactive_elements_as_drifts(): + """Test that the conversion into drifts works properly.""" + segment = cheetah.Segment( + elements=[ + cheetah.Drift(length=torch.tensor([1.0, 2.0])), + cheetah.Dipole( + length=torch.tensor([0.1, 0.0]), + angle=torch.tensor([0.0, 0.0]), + name="Dipole1", + ), + cheetah.Dipole( + length=torch.tensor([0.0, 0.1]), + angle=torch.tensor([0.0, 0.0]), + name="Dipole2", + ), + cheetah.Dipole( + length=torch.tensor([0.2, 0.1]), + angle=torch.tensor([0.5, 0.0]), + name="Dipole3", + ), + cheetah.Dipole( + length=torch.tensor([0.0, 0.0]), + angle=torch.tensor([0.0, 0.0]), + name="Dipole4", + ), + cheetah.Drift(length=torch.tensor([0.0, 2.0])), + cheetah.BPM(is_active=torch.tensor([False, False]), name="Bpm"), + ] + ) + + pruned = segment.inactive_elements_as_drifts() + pruned_except = segment.inactive_elements_as_drifts(except_for=["Dipole2"]) + + assert len(segment.elements) == len(pruned.elements) + assert len(segment.elements) == len(pruned_except.elements) + assert np.allclose(segment.length, pruned.length) + assert np.allclose(segment.length, pruned_except.length) + + assert isinstance(pruned.Dipole1, cheetah.Drift) + assert isinstance(pruned.Dipole2, cheetah.Drift) + assert isinstance(pruned.Dipole3, cheetah.Dipole) + assert isinstance(pruned.Dipole4, cheetah.Dipole) + assert isinstance(pruned.Bpm, cheetah.BPM) + assert isinstance(pruned_except.Dipole1, cheetah.Drift) + assert isinstance(pruned_except.Dipole2, cheetah.Dipole) + assert isinstance(pruned_except.Dipole3, cheetah.Dipole) + assert isinstance(pruned_except.Dipole4, cheetah.Dipole) + assert isinstance(pruned_except.Bpm, cheetah.BPM) + + +def test_without_zerolength_elements(): + """Test that zerolength elements are properly recognized and removed.""" + segment = cheetah.Segment( + elements=[ + cheetah.Drift(length=torch.tensor([1.0, 2.0])), + cheetah.Dipole( + length=torch.tensor([0.0, 0.0]), + angle=torch.tensor([0.0, 0.0]), + name="Dipole1", + ), + cheetah.Drift(length=torch.tensor([1.0, 0.0])), + cheetah.Dipole( + length=torch.tensor([0.0, 0.0]), + angle=torch.tensor([0.0, 0.0]), + name="Dipole2", + ), + cheetah.Drift(length=torch.tensor([0.0, 2.0])), + cheetah.Dipole( + length=torch.tensor([0.0, 0.1]), + angle=torch.tensor([0.0, 0.0]), + name="Dipole3", + ), + cheetah.Drift(length=torch.tensor([0.0, 0.0])), + cheetah.Dipole( + length=torch.tensor([0.0, 0.0]), + angle=torch.tensor([0.5, 0.0]), + name="Dipole4", + ), + ] + ) + + pruned = segment.without_inactive_zero_length_elements() + pruned_except = segment.without_inactive_zero_length_elements( + except_for=["Dipole2"] + ) + + assert len(segment.elements) == 8 + assert len(pruned.elements) == 5 + assert len(pruned_except.elements) == 6 + assert np.allclose(segment.length, pruned.length) + assert np.allclose(segment.length, pruned_except.length) + assert not torch.all(pruned_except.Dipole2.is_active) From 892cdd0a0384954cf3d8296c7d14f61f7356fb8e Mon Sep 17 00:00:00 2001 From: Christian Hespe Date: Tue, 24 Sep 2024 16:47:55 +0200 Subject: [PATCH 02/12] Fix lattice pruning methods --- cheetah/accelerator/segment.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cheetah/accelerator/segment.py b/cheetah/accelerator/segment.py index d362a4b1..06df0602 100644 --- a/cheetah/accelerator/segment.py +++ b/cheetah/accelerator/segment.py @@ -170,8 +170,8 @@ def without_inactive_zero_length_elements( elements=[ element for element in self.elements - if element.length > 0.0 - or (hasattr(element, "is_active") and element.is_active) + if torch.any(element.length > 0.0) + or (hasattr(element, "is_active") and torch.any(element.is_active)) or element.name in except_for ], name=self.name, @@ -198,8 +198,8 @@ def inactive_elements_as_drifts( elements=[ ( element - if (hasattr(element, "is_active") and element.is_active) - or element.length == 0.0 + if (hasattr(element, "is_active") and torch.any(element.is_active)) + or torch.all(element.length == 0.0) or element.name in except_for else Drift( element.length, From 7e37974f835b519ada3a8d992ec9ee4880754b00 Mon Sep 17 00:00:00 2001 From: Christian Hespe Date: Wed, 25 Sep 2024 08:50:43 +0200 Subject: [PATCH 03/12] Work around for BPM vectorization issue --- cheetah/accelerator/segment.py | 1 + tests/test_segment_pruning.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/cheetah/accelerator/segment.py b/cheetah/accelerator/segment.py index 06df0602..e046a432 100644 --- a/cheetah/accelerator/segment.py +++ b/cheetah/accelerator/segment.py @@ -203,6 +203,7 @@ def inactive_elements_as_drifts( or element.name in except_for else Drift( element.length, + name=element.name, device=element.length.device, dtype=element.length.dtype, ) diff --git a/tests/test_segment_pruning.py b/tests/test_segment_pruning.py index baa5fa5e..d45d4d5c 100644 --- a/tests/test_segment_pruning.py +++ b/tests/test_segment_pruning.py @@ -30,7 +30,9 @@ def test_inactive_elements_as_drifts(): name="Dipole4", ), cheetah.Drift(length=torch.tensor([0.0, 2.0])), - cheetah.BPM(is_active=torch.tensor([False, False]), name="Bpm"), + cheetah.BPM(is_active=torch.tensor([False, False]), name="Bpm").broadcast( + (2,) + ), ] ) From 0ed4d6f223ee2413ca296ce6da6c2481c9b50c95 Mon Sep 17 00:00:00 2001 From: Christian Hespe Date: Wed, 25 Sep 2024 09:22:23 +0200 Subject: [PATCH 04/12] Move tests to proper file --- tests/test_segment_pruning.py | 100 ------------------------------ tests/test_speed_optimizations.py | 73 ++++++++++++++++++++-- 2 files changed, 67 insertions(+), 106 deletions(-) delete mode 100644 tests/test_segment_pruning.py diff --git a/tests/test_segment_pruning.py b/tests/test_segment_pruning.py deleted file mode 100644 index d45d4d5c..00000000 --- a/tests/test_segment_pruning.py +++ /dev/null @@ -1,100 +0,0 @@ -import numpy as np -import torch - -import cheetah - - -def test_inactive_elements_as_drifts(): - """Test that the conversion into drifts works properly.""" - segment = cheetah.Segment( - elements=[ - cheetah.Drift(length=torch.tensor([1.0, 2.0])), - cheetah.Dipole( - length=torch.tensor([0.1, 0.0]), - angle=torch.tensor([0.0, 0.0]), - name="Dipole1", - ), - cheetah.Dipole( - length=torch.tensor([0.0, 0.1]), - angle=torch.tensor([0.0, 0.0]), - name="Dipole2", - ), - cheetah.Dipole( - length=torch.tensor([0.2, 0.1]), - angle=torch.tensor([0.5, 0.0]), - name="Dipole3", - ), - cheetah.Dipole( - length=torch.tensor([0.0, 0.0]), - angle=torch.tensor([0.0, 0.0]), - name="Dipole4", - ), - cheetah.Drift(length=torch.tensor([0.0, 2.0])), - cheetah.BPM(is_active=torch.tensor([False, False]), name="Bpm").broadcast( - (2,) - ), - ] - ) - - pruned = segment.inactive_elements_as_drifts() - pruned_except = segment.inactive_elements_as_drifts(except_for=["Dipole2"]) - - assert len(segment.elements) == len(pruned.elements) - assert len(segment.elements) == len(pruned_except.elements) - assert np.allclose(segment.length, pruned.length) - assert np.allclose(segment.length, pruned_except.length) - - assert isinstance(pruned.Dipole1, cheetah.Drift) - assert isinstance(pruned.Dipole2, cheetah.Drift) - assert isinstance(pruned.Dipole3, cheetah.Dipole) - assert isinstance(pruned.Dipole4, cheetah.Dipole) - assert isinstance(pruned.Bpm, cheetah.BPM) - assert isinstance(pruned_except.Dipole1, cheetah.Drift) - assert isinstance(pruned_except.Dipole2, cheetah.Dipole) - assert isinstance(pruned_except.Dipole3, cheetah.Dipole) - assert isinstance(pruned_except.Dipole4, cheetah.Dipole) - assert isinstance(pruned_except.Bpm, cheetah.BPM) - - -def test_without_zerolength_elements(): - """Test that zerolength elements are properly recognized and removed.""" - segment = cheetah.Segment( - elements=[ - cheetah.Drift(length=torch.tensor([1.0, 2.0])), - cheetah.Dipole( - length=torch.tensor([0.0, 0.0]), - angle=torch.tensor([0.0, 0.0]), - name="Dipole1", - ), - cheetah.Drift(length=torch.tensor([1.0, 0.0])), - cheetah.Dipole( - length=torch.tensor([0.0, 0.0]), - angle=torch.tensor([0.0, 0.0]), - name="Dipole2", - ), - cheetah.Drift(length=torch.tensor([0.0, 2.0])), - cheetah.Dipole( - length=torch.tensor([0.0, 0.1]), - angle=torch.tensor([0.0, 0.0]), - name="Dipole3", - ), - cheetah.Drift(length=torch.tensor([0.0, 0.0])), - cheetah.Dipole( - length=torch.tensor([0.0, 0.0]), - angle=torch.tensor([0.5, 0.0]), - name="Dipole4", - ), - ] - ) - - pruned = segment.without_inactive_zero_length_elements() - pruned_except = segment.without_inactive_zero_length_elements( - except_for=["Dipole2"] - ) - - assert len(segment.elements) == 8 - assert len(pruned.elements) == 5 - assert len(pruned_except.elements) == 6 - assert np.allclose(segment.length, pruned.length) - assert np.allclose(segment.length, pruned_except.length) - assert not torch.all(pruned_except.Dipole2.is_active) diff --git a/tests/test_speed_optimizations.py b/tests/test_speed_optimizations.py index 25b1c0b4..04709f7a 100644 --- a/tests/test_speed_optimizations.py +++ b/tests/test_speed_optimizations.py @@ -1,3 +1,4 @@ +import numpy as np import pytest import torch @@ -133,9 +134,11 @@ def test_inactive_magnet_is_replaced_by_drift(): """ segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([0.6])), - cheetah.Quadrupole(length=torch.tensor([0.2]), k1=torch.tensor([0.0])), - cheetah.Drift(length=torch.tensor([0.4])), + cheetah.Drift(length=torch.tensor([0.6, 0.5])), + cheetah.Quadrupole( + length=torch.tensor([0.2, 0.3]), k1=torch.tensor([0.0, 0.0]) + ), + cheetah.Drift(length=torch.tensor([0.4, 0.1])), ] ) @@ -144,6 +147,7 @@ def test_inactive_magnet_is_replaced_by_drift(): assert all( isinstance(element, cheetah.Drift) for element in optimized_segment.elements ) + assert np.allclose(segment.length, optimized_segment.length) def test_active_elements_not_replaced_by_drift(): @@ -152,9 +156,11 @@ def test_active_elements_not_replaced_by_drift(): """ segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([0.6])), - cheetah.Quadrupole(length=torch.tensor([0.2]), k1=torch.tensor([4.2])), - cheetah.Drift(length=torch.tensor([0.4])), + cheetah.Drift(length=torch.tensor([0.6, 0.5])), + cheetah.Quadrupole( + length=torch.tensor([0.2, 0.3]), k1=torch.tensor([4.2, 0.0]) + ), + cheetah.Drift(length=torch.tensor([0.4, 0.1])), ] ) @@ -184,6 +190,25 @@ def test_inactive_magnet_drift_replacement_dtype(dtype: torch.dtype): assert all(element.length.dtype == dtype for element in optimized_segment.elements) +def test_inactive_magnet_drift_except_for(): + """ + Test that an inactive magnet is not replaced by a drift if except is used + """ + segment = cheetah.Segment( + elements=[ + cheetah.Drift(length=torch.tensor([0.6])), + cheetah.Quadrupole( + length=torch.tensor([0.2]), k1=torch.tensor([4.2]), name="quad" + ), + cheetah.Drift(length=torch.tensor([0.4])), + ] + ) + + optimized_segment = segment.inactive_elements_as_drifts(except_for=["quad"]) + + assert isinstance(optimized_segment.elements[1], cheetah.Quadrupole) + + def test_skippable_elements_reset(): """ @cr-xu pointed out that the skippable elements are not always reset appropriately @@ -214,3 +239,39 @@ def test_skippable_elements_reset(): merged_tm = merged_segment.elements[2].transfer_map(energy=incoming_beam.energy) assert torch.allclose(original_tm, merged_tm) + + +def test_without_zerolength_elements(): + """Test that zerolength elements are properly recognized and removed.""" + segment = cheetah.Segment( + elements=[ + cheetah.Drift(length=torch.tensor([1.0, 2.0])), + cheetah.Dipole( + length=torch.tensor([0.0, 0.0]), + angle=torch.tensor([0.0, 0.0]), + ), + cheetah.Dipole( + length=torch.tensor([0.0, 0.0]), + angle=torch.tensor([0.0, 0.0]), + name="dipole", + ), + cheetah.Dipole( + length=torch.tensor([0.0, 0.1]), + angle=torch.tensor([0.0, 0.0]), + ), + cheetah.Drift(length=torch.tensor([0.0, 0.0])), + cheetah.Dipole( + length=torch.tensor([0.0, 0.0]), + angle=torch.tensor([0.5, 0.0]), + ), + ] + ) + + pruned = segment.without_inactive_zero_length_elements() + pruned_except = segment.without_inactive_zero_length_elements(except_for=["dipole"]) + + assert len(segment.elements) == 6 + assert len(pruned.elements) == 3 + assert len(pruned_except.elements) == 4 + assert np.allclose(segment.length, pruned.length) + assert np.allclose(segment.length, pruned_except.length) From b9e46410187ab1d4952f4c6f94bb9703a9d2be6a Mon Sep 17 00:00:00 2001 From: Christian Hespe Date: Wed, 25 Sep 2024 09:23:17 +0200 Subject: [PATCH 05/12] Fix type error in lattice optimization methods --- cheetah/accelerator/segment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cheetah/accelerator/segment.py b/cheetah/accelerator/segment.py index e046a432..f4077a6d 100644 --- a/cheetah/accelerator/segment.py +++ b/cheetah/accelerator/segment.py @@ -171,7 +171,7 @@ def without_inactive_zero_length_elements( element for element in self.elements if torch.any(element.length > 0.0) - or (hasattr(element, "is_active") and torch.any(element.is_active)) + or (hasattr(element, "is_active") and element.is_active) or element.name in except_for ], name=self.name, @@ -198,7 +198,7 @@ def inactive_elements_as_drifts( elements=[ ( element - if (hasattr(element, "is_active") and torch.any(element.is_active)) + if (hasattr(element, "is_active") and element.is_active) or torch.all(element.length == 0.0) or element.name in except_for else Drift( From b4477dd7f833631edfbcc01b952a17f4315157d7 Mon Sep 17 00:00:00 2001 From: Christian Hespe Date: Wed, 25 Sep 2024 09:24:19 +0200 Subject: [PATCH 06/12] Update changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 96765227..113d1712 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,7 @@ ### 🚨 Breaking Changes -- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #215, #218, #229, #233) (@jank324, @cr-xu, @hespe) +- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #215, #218, #229, #233, #258) (@jank324, @cr-xu, @hespe) - The fifth particle coordinate `s` is renamed to `tau`. Now Cheetah uses the canonical variables in phase space $(x,px=\frac{P_x}{p_0},y,py, \tau=c\Delta t, \delta=\Delta E/{p_0 c})$. In addition, the trailing "s" was removed from some beam property names (e.g. `beam.xs` becomes `beam.x`). (see #163) (@cr-xu) ### 🚀 Features From c5ce46f5c61a6be253a9ff7920bbbae2fdd40317 Mon Sep 17 00:00:00 2001 From: Christian Hespe Date: Wed, 25 Sep 2024 10:00:42 +0200 Subject: [PATCH 07/12] Make tested Quad actually inactive --- tests/test_speed_optimizations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_speed_optimizations.py b/tests/test_speed_optimizations.py index 04709f7a..cd3d8dfe 100644 --- a/tests/test_speed_optimizations.py +++ b/tests/test_speed_optimizations.py @@ -198,7 +198,7 @@ def test_inactive_magnet_drift_except_for(): elements=[ cheetah.Drift(length=torch.tensor([0.6])), cheetah.Quadrupole( - length=torch.tensor([0.2]), k1=torch.tensor([4.2]), name="quad" + length=torch.tensor([0.2]), k1=torch.tensor([0.0]), name="quad" ), cheetah.Drift(length=torch.tensor([0.4])), ] From 56d528e1f8179b577affaa4de392ee214e92d777 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Tue, 1 Oct 2024 18:16:28 +0200 Subject: [PATCH 08/12] Update tests/test_speed_optimizations.py --- tests/test_speed_optimizations.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/test_speed_optimizations.py b/tests/test_speed_optimizations.py index 3c9411c5..45d053a9 100644 --- a/tests/test_speed_optimizations.py +++ b/tests/test_speed_optimizations.py @@ -247,22 +247,22 @@ def test_without_zerolength_elements(): elements=[ cheetah.Drift(length=torch.tensor([1.0, 2.0])), cheetah.Dipole( - length=torch.tensor([0.0, 0.0]), - angle=torch.tensor([0.0, 0.0]), + length=torch.tensor(0.0), + angle=torch.tensor(0.0) ), cheetah.Dipole( - length=torch.tensor([0.0, 0.0]), - angle=torch.tensor([0.0, 0.0]), - name="dipole", + length=torch.tensor(0.0), + angle=torch.tensor(0.0), + name="dipole" ), cheetah.Dipole( length=torch.tensor([0.0, 0.1]), - angle=torch.tensor([0.0, 0.0]), + angle=torch.tensor(0.0) ), - cheetah.Drift(length=torch.tensor([0.0, 0.0])), + cheetah.Drift(length=torch.tensor(0.0)), cheetah.Dipole( - length=torch.tensor([0.0, 0.0]), - angle=torch.tensor([0.5, 0.0]), + length=torch.tensor(0.0), + angle=torch.tensor([0.5, 0.0]) ), ] ) From 4f3d2d42986adf0c733235d33b61a352c84e289b Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Tue, 1 Oct 2024 18:17:02 +0200 Subject: [PATCH 09/12] Apply suggestions from code review --- tests/test_speed_optimizations.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_speed_optimizations.py b/tests/test_speed_optimizations.py index 45d053a9..826b0915 100644 --- a/tests/test_speed_optimizations.py +++ b/tests/test_speed_optimizations.py @@ -192,15 +192,15 @@ def test_inactive_magnet_drift_replacement_dtype(dtype: torch.dtype): def test_inactive_magnet_drift_except_for(): """ - Test that an inactive magnet is not replaced by a drift if except is used + Test that an inactive magnet is not replaced by a drift when it is included in the list of exceptions. """ segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([0.6])), + cheetah.Drift(length=torch.tensor(0.6)), cheetah.Quadrupole( - length=torch.tensor([0.2]), k1=torch.tensor([0.0]), name="quad" + length=torch.tensor(0.2), k1=torch.tensor(0.0), name="quad" ), - cheetah.Drift(length=torch.tensor([0.4])), + cheetah.Drift(length=torch.tensor(0.4)), ] ) From 2da552d38f24031a9407a5f0f38ff700d01fe2d9 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Tue, 1 Oct 2024 18:18:31 +0200 Subject: [PATCH 10/12] Apply suggestions from code review --- tests/test_speed_optimizations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_speed_optimizations.py b/tests/test_speed_optimizations.py index 826b0915..28f3fb8d 100644 --- a/tests/test_speed_optimizations.py +++ b/tests/test_speed_optimizations.py @@ -241,8 +241,8 @@ def test_skippable_elements_reset(): assert torch.allclose(original_tm, merged_tm) -def test_without_zerolength_elements(): - """Test that zerolength elements are properly recognized and removed.""" +def test_without_zero_length_elements(): + """Test that zero-length elements are properly recognized and removed.""" segment = cheetah.Segment( elements=[ cheetah.Drift(length=torch.tensor([1.0, 2.0])), From ea2f6d293ba705a3c3925db639a50b7fb411df89 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Tue, 1 Oct 2024 18:20:10 +0200 Subject: [PATCH 11/12] Apply suggestions from code review --- tests/test_speed_optimizations.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_speed_optimizations.py b/tests/test_speed_optimizations.py index 28f3fb8d..5da5a4f4 100644 --- a/tests/test_speed_optimizations.py +++ b/tests/test_speed_optimizations.py @@ -198,13 +198,13 @@ def test_inactive_magnet_drift_except_for(): elements=[ cheetah.Drift(length=torch.tensor(0.6)), cheetah.Quadrupole( - length=torch.tensor(0.2), k1=torch.tensor(0.0), name="quad" + length=torch.tensor(0.2), k1=torch.tensor(0.0), name="my_quad" ), cheetah.Drift(length=torch.tensor(0.4)), ] ) - optimized_segment = segment.inactive_elements_as_drifts(except_for=["quad"]) + optimized_segment = segment.inactive_elements_as_drifts(except_for=["my_quad"]) assert isinstance(optimized_segment.elements[1], cheetah.Quadrupole) @@ -253,7 +253,7 @@ def test_without_zero_length_elements(): cheetah.Dipole( length=torch.tensor(0.0), angle=torch.tensor(0.0), - name="dipole" + name="my_dipole" ), cheetah.Dipole( length=torch.tensor([0.0, 0.1]), @@ -268,7 +268,7 @@ def test_without_zero_length_elements(): ) pruned = segment.without_inactive_zero_length_elements() - pruned_except = segment.without_inactive_zero_length_elements(except_for=["dipole"]) + pruned_except = segment.without_inactive_zero_length_elements(except_for=["my_dipole"]) assert len(segment.elements) == 6 assert len(pruned.elements) == 3 From 777346bb0ae4bc92a119e53327c743e1aebd20ff Mon Sep 17 00:00:00 2001 From: Christian Hespe Date: Wed, 2 Oct 2024 08:22:28 +0200 Subject: [PATCH 12/12] Further cleanup of the vectorized speed optimization tests --- tests/test_speed_optimizations.py | 51 ++++++++++++------------------- 1 file changed, 19 insertions(+), 32 deletions(-) diff --git a/tests/test_speed_optimizations.py b/tests/test_speed_optimizations.py index 5da5a4f4..56d2663d 100644 --- a/tests/test_speed_optimizations.py +++ b/tests/test_speed_optimizations.py @@ -1,4 +1,3 @@ -import numpy as np import pytest import torch @@ -134,11 +133,9 @@ def test_inactive_magnet_is_replaced_by_drift(): """ segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([0.6, 0.5])), - cheetah.Quadrupole( - length=torch.tensor([0.2, 0.3]), k1=torch.tensor(0.0) - ), - cheetah.Drift(length=torch.tensor([0.4, 0.1])), + cheetah.Drift(length=torch.tensor(0.6)), + cheetah.Quadrupole(length=torch.tensor(0.2), k1=torch.tensor(0.0)), + cheetah.Drift(length=torch.tensor(0.4)), ] ) @@ -147,7 +144,7 @@ def test_inactive_magnet_is_replaced_by_drift(): assert all( isinstance(element, cheetah.Drift) for element in optimized_segment.elements ) - assert np.allclose(segment.length, optimized_segment.length) + assert torch.allclose(segment.length, optimized_segment.length) def test_active_elements_not_replaced_by_drift(): @@ -156,11 +153,9 @@ def test_active_elements_not_replaced_by_drift(): """ segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([0.6, 0.5])), - cheetah.Quadrupole( - length=torch.tensor([0.2, 0.3]), k1=torch.tensor([4.2, 0.0]) - ), - cheetah.Drift(length=torch.tensor([0.4, 0.1])), + cheetah.Drift(length=torch.tensor(0.6)), + cheetah.Quadrupole(length=torch.tensor(0.2), k1=torch.tensor([4.2, 0.0])), + cheetah.Drift(length=torch.tensor(0.4)), ] ) @@ -192,7 +187,8 @@ def test_inactive_magnet_drift_replacement_dtype(dtype: torch.dtype): def test_inactive_magnet_drift_except_for(): """ - Test that an inactive magnet is not replaced by a drift when it is included in the list of exceptions. + Test that an inactive magnet is not replaced by a drift when it is included in the + list of exceptions. """ segment = cheetah.Segment( elements=[ @@ -245,33 +241,24 @@ def test_without_zero_length_elements(): """Test that zero-length elements are properly recognized and removed.""" segment = cheetah.Segment( elements=[ - cheetah.Drift(length=torch.tensor([1.0, 2.0])), + cheetah.Drift(length=torch.tensor(1.0)), + cheetah.Dipole(length=torch.tensor(0.0), angle=torch.tensor(0.0)), cheetah.Dipole( - length=torch.tensor(0.0), - angle=torch.tensor(0.0) - ), - cheetah.Dipole( - length=torch.tensor(0.0), - angle=torch.tensor(0.0), - name="my_dipole" - ), - cheetah.Dipole( - length=torch.tensor([0.0, 0.1]), - angle=torch.tensor(0.0) + length=torch.tensor(0.0), angle=torch.tensor(0.0), name="my_dipole" ), + cheetah.Dipole(length=torch.tensor([0.0, 0.1]), angle=torch.tensor(0.0)), cheetah.Drift(length=torch.tensor(0.0)), - cheetah.Dipole( - length=torch.tensor(0.0), - angle=torch.tensor([0.5, 0.0]) - ), + cheetah.Dipole(length=torch.tensor(0.0), angle=torch.tensor([0.5, 0.0])), ] ) pruned = segment.without_inactive_zero_length_elements() - pruned_except = segment.without_inactive_zero_length_elements(except_for=["my_dipole"]) + pruned_except = segment.without_inactive_zero_length_elements( + except_for=["my_dipole"] + ) assert len(segment.elements) == 6 assert len(pruned.elements) == 3 assert len(pruned_except.elements) == 4 - assert np.allclose(segment.length, pruned.length) - assert np.allclose(segment.length, pruned_except.length) + assert torch.allclose(segment.length, pruned.length) + assert torch.allclose(segment.length, pruned_except.length)