diff --git a/CHANGELOG.md b/CHANGELOG.md index 524c10ae..581f8d7e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ None - Fix the transfer maps in `Drift` and `Dipole`; Add R56 in horizontal and vertical correctors modelling (see #90) (@cr-xu) - Fix fringe_field_exit of `Dipole` is overwritten by `fringe_field` bug (see #99) (@cr-xu) - Fix error caused by mismatched devices on machines with CUDA GPUs (see #97) (@jank324) +- Fix error raised when tracking a `ParameterBeam` through an active `BPM` (see #101) (@jank324) ### 🐆 Other diff --git a/cheetah/accelerator.py b/cheetah/accelerator.py index 4f004dd4..527e917b 100644 --- a/cheetah/accelerator.py +++ b/cheetah/accelerator.py @@ -1186,7 +1186,7 @@ def track(self, incoming: Beam) -> Beam: if incoming is Beam.empty: self.reading = None elif isinstance(incoming, ParameterBeam): - self.reading = torch.stack([incoming._mu_x, incoming._mu_y]) + self.reading = torch.stack([incoming.mu_x, incoming.mu_y]) elif isinstance(incoming, ParticleBeam): self.reading = torch.stack([incoming.mu_x, incoming.mu_y]) else: diff --git a/tests/test_bpm.py b/tests/test_bpm.py new file mode 100644 index 00000000..2dc51fe4 --- /dev/null +++ b/tests/test_bpm.py @@ -0,0 +1,22 @@ +import pytest +import torch + +import cheetah + + +@pytest.mark.parametrize("is_bpm_active", [True, False]) +@pytest.mark.parametrize("beam_class", [cheetah.ParticleBeam, cheetah.ParameterBeam]) +def test_no_tracking_error(is_bpm_active, beam_class): + """Test that tracking a beam through an inactive BPM does not raise an error.""" + segment = cheetah.Segment( + elements=[ + cheetah.Drift(length=torch.tensor(1.0)), + cheetah.BPM(name="my_bpm"), + cheetah.Drift(length=torch.tensor(1.0)), + ], + ) + beam = beam_class.from_astra("benchmark/cheetah/ACHIP_EA1_2021.1351.001") + + segment.my_bpm.is_active = is_bpm_active + + _ = segment.track(beam)