Skip to content

Commit

Permalink
Merge branch 'master' into fix_twiss_plot
Browse files Browse the repository at this point in the history
  • Loading branch information
jank324 authored Oct 3, 2024
2 parents 8871f10 + dcb6318 commit 5e9a0ef
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 4 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ This is a major release with significant upgrades under the hood of Cheetah. Des

### 🚨 Breaking Changes

- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #208, #213, #215, #218, #229, #233) (@jank324, @cr-xu, @hespe, @roussel-ryan)
- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #208, #213, #215, #218, #229, #233, #258) (@jank324, @cr-xu, @hespe, @roussel-ryan)
- The fifth particle coordinate `s` is renamed to `tau`. Now Cheetah uses the canonical variables in phase space $(x,px=\frac{P_x}{p_0},y,py, \tau=c\Delta t, \delta=\Delta E/{p_0 c})$. In addition, the trailing "s" was removed from some beam property names (e.g. `beam.xs` becomes `beam.x`). (see #163) (@cr-xu)
- `Screen` no longer blocks the beam (by default). To return to old behaviour, set `Screen.is_blocking = True`. (see #208) (@jank324, @roussel-ryan)

Expand Down
5 changes: 3 additions & 2 deletions cheetah/accelerator/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def without_inactive_zero_length_elements(
elements=[
element
for element in self.elements
if element.length > 0.0
if torch.any(element.length > 0.0)
or (hasattr(element, "is_active") and element.is_active)
or element.name in except_for
],
Expand Down Expand Up @@ -199,10 +199,11 @@ def inactive_elements_as_drifts(
(
element
if (hasattr(element, "is_active") and element.is_active)
or element.length == 0.0
or torch.all(element.length == 0.0)
or element.name in except_for
else Drift(
element.length,
name=element.name,
device=element.length.device,
dtype=element.length.dtype,
)
Expand Down
50 changes: 49 additions & 1 deletion tests/test_speed_optimizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def test_inactive_magnet_is_replaced_by_drift():
assert all(
isinstance(element, cheetah.Drift) for element in optimized_segment.elements
)
assert torch.allclose(segment.length, optimized_segment.length)


def test_active_elements_not_replaced_by_drift():
Expand All @@ -153,7 +154,7 @@ def test_active_elements_not_replaced_by_drift():
segment = cheetah.Segment(
elements=[
cheetah.Drift(length=torch.tensor(0.6)),
cheetah.Quadrupole(length=torch.tensor(0.2), k1=torch.tensor(4.2)),
cheetah.Quadrupole(length=torch.tensor(0.2), k1=torch.tensor([4.2, 0.0])),
cheetah.Drift(length=torch.tensor(0.4)),
]
)
Expand Down Expand Up @@ -184,6 +185,26 @@ def test_inactive_magnet_drift_replacement_dtype(dtype: torch.dtype):
assert all(element.length.dtype == dtype for element in optimized_segment.elements)


def test_inactive_magnet_drift_except_for():
"""
Test that an inactive magnet is not replaced by a drift when it is included in the
list of exceptions.
"""
segment = cheetah.Segment(
elements=[
cheetah.Drift(length=torch.tensor(0.6)),
cheetah.Quadrupole(
length=torch.tensor(0.2), k1=torch.tensor(0.0), name="my_quad"
),
cheetah.Drift(length=torch.tensor(0.4)),
]
)

optimized_segment = segment.inactive_elements_as_drifts(except_for=["my_quad"])

assert isinstance(optimized_segment.elements[1], cheetah.Quadrupole)


def test_skippable_elements_reset():
"""
@cr-xu pointed out that the skippable elements are not always reset appropriately
Expand Down Expand Up @@ -214,3 +235,30 @@ def test_skippable_elements_reset():
merged_tm = merged_segment.elements[2].transfer_map(energy=incoming_beam.energy)

assert torch.allclose(original_tm, merged_tm)


def test_without_zero_length_elements():
"""Test that zero-length elements are properly recognized and removed."""
segment = cheetah.Segment(
elements=[
cheetah.Drift(length=torch.tensor(1.0)),
cheetah.Dipole(length=torch.tensor(0.0), angle=torch.tensor(0.0)),
cheetah.Dipole(
length=torch.tensor(0.0), angle=torch.tensor(0.0), name="my_dipole"
),
cheetah.Dipole(length=torch.tensor([0.0, 0.1]), angle=torch.tensor(0.0)),
cheetah.Drift(length=torch.tensor(0.0)),
cheetah.Dipole(length=torch.tensor(0.0), angle=torch.tensor([0.5, 0.0])),
]
)

pruned = segment.without_inactive_zero_length_elements()
pruned_except = segment.without_inactive_zero_length_elements(
except_for=["my_dipole"]
)

assert len(segment.elements) == 6
assert len(pruned.elements) == 3
assert len(pruned_except.elements) == 4
assert torch.allclose(segment.length, pruned.length)
assert torch.allclose(segment.length, pruned_except.length)

0 comments on commit 5e9a0ef

Please sign in to comment.