diff --git a/CHANGELOG.md b/CHANGELOG.md index ac937eb6..dc2bd029 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,7 @@ This is a major release with significant upgrades under the hood of Cheetah. Des ### 🚨 Breaking Changes -- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #208, #215, #218, #229, #233) (@jank324, @cr-xu, @hespe, @roussel-ryan) +- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #208, #215, #218, #229, #233, #258) (@jank324, @cr-xu, @hespe, @roussel-ryan) - The fifth particle coordinate `s` is renamed to `tau`. Now Cheetah uses the canonical variables in phase space $(x,px=\frac{P_x}{p_0},y,py, \tau=c\Delta t, \delta=\Delta E/{p_0 c})$. In addition, the trailing "s" was removed from some beam property names (e.g. `beam.xs` becomes `beam.x`). (see #163) (@cr-xu) - `Screen` no longer blocks the beam (by default). To return to old behaviour, set `Screen.is_blocking = True`. (see #208) (@jank324, @roussel-ryan) diff --git a/cheetah/accelerator/segment.py b/cheetah/accelerator/segment.py index c2aaf65d..d9a19389 100644 --- a/cheetah/accelerator/segment.py +++ b/cheetah/accelerator/segment.py @@ -170,7 +170,7 @@ def without_inactive_zero_length_elements( elements=[ element for element in self.elements - if element.length > 0.0 + if torch.any(element.length > 0.0) or (hasattr(element, "is_active") and element.is_active) or element.name in except_for ], @@ -199,10 +199,11 @@ def inactive_elements_as_drifts( ( element if (hasattr(element, "is_active") and element.is_active) - or element.length == 0.0 + or torch.all(element.length == 0.0) or element.name in except_for else Drift( element.length, + name=element.name, device=element.length.device, dtype=element.length.dtype, ) diff --git a/tests/test_speed_optimizations.py b/tests/test_speed_optimizations.py index a6d96edc..56d2663d 100644 --- a/tests/test_speed_optimizations.py +++ b/tests/test_speed_optimizations.py @@ -144,6 +144,7 @@ def test_inactive_magnet_is_replaced_by_drift(): assert all( isinstance(element, cheetah.Drift) for element in optimized_segment.elements ) + assert torch.allclose(segment.length, optimized_segment.length) def test_active_elements_not_replaced_by_drift(): @@ -153,7 +154,7 @@ def test_active_elements_not_replaced_by_drift(): segment = cheetah.Segment( elements=[ cheetah.Drift(length=torch.tensor(0.6)), - cheetah.Quadrupole(length=torch.tensor(0.2), k1=torch.tensor(4.2)), + cheetah.Quadrupole(length=torch.tensor(0.2), k1=torch.tensor([4.2, 0.0])), cheetah.Drift(length=torch.tensor(0.4)), ] ) @@ -184,6 +185,26 @@ def test_inactive_magnet_drift_replacement_dtype(dtype: torch.dtype): assert all(element.length.dtype == dtype for element in optimized_segment.elements) +def test_inactive_magnet_drift_except_for(): + """ + Test that an inactive magnet is not replaced by a drift when it is included in the + list of exceptions. + """ + segment = cheetah.Segment( + elements=[ + cheetah.Drift(length=torch.tensor(0.6)), + cheetah.Quadrupole( + length=torch.tensor(0.2), k1=torch.tensor(0.0), name="my_quad" + ), + cheetah.Drift(length=torch.tensor(0.4)), + ] + ) + + optimized_segment = segment.inactive_elements_as_drifts(except_for=["my_quad"]) + + assert isinstance(optimized_segment.elements[1], cheetah.Quadrupole) + + def test_skippable_elements_reset(): """ @cr-xu pointed out that the skippable elements are not always reset appropriately @@ -214,3 +235,30 @@ def test_skippable_elements_reset(): merged_tm = merged_segment.elements[2].transfer_map(energy=incoming_beam.energy) assert torch.allclose(original_tm, merged_tm) + + +def test_without_zero_length_elements(): + """Test that zero-length elements are properly recognized and removed.""" + segment = cheetah.Segment( + elements=[ + cheetah.Drift(length=torch.tensor(1.0)), + cheetah.Dipole(length=torch.tensor(0.0), angle=torch.tensor(0.0)), + cheetah.Dipole( + length=torch.tensor(0.0), angle=torch.tensor(0.0), name="my_dipole" + ), + cheetah.Dipole(length=torch.tensor([0.0, 0.1]), angle=torch.tensor(0.0)), + cheetah.Drift(length=torch.tensor(0.0)), + cheetah.Dipole(length=torch.tensor(0.0), angle=torch.tensor([0.5, 0.0])), + ] + ) + + pruned = segment.without_inactive_zero_length_elements() + pruned_except = segment.without_inactive_zero_length_elements( + except_for=["my_dipole"] + ) + + assert len(segment.elements) == 6 + assert len(pruned.elements) == 3 + assert len(pruned_except.elements) == 4 + assert torch.allclose(segment.length, pruned.length) + assert torch.allclose(segment.length, pruned_except.length)