Skip to content

Commit

Permalink
Merge branch 'master' into 87-issues-when-running-on-a-machine-with-c…
Browse files Browse the repository at this point in the history
…uda-gpus
  • Loading branch information
jank324 authored Nov 14, 2023
2 parents 763f4ef + dd7adb2 commit f774715
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ None
- Fix the transfer maps in `Drift` and `Dipole`; Add R56 in horizontal and vertical correctors modelling (see #90) (@cr-xu)
- Fix fringe_field_exit of `Dipole` is overwritten by `fringe_field` bug (see #99) (@cr-xu)
- Fix error caused by mismatched devices on machines with CUDA GPUs (see #97) (@jank324)
- Fix error raised when tracking a `ParameterBeam` through an active `BPM` (see #101) (@jank324)

### 🐆 Other

Expand Down
2 changes: 1 addition & 1 deletion cheetah/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1186,7 +1186,7 @@ def track(self, incoming: Beam) -> Beam:
if incoming is Beam.empty:
self.reading = None
elif isinstance(incoming, ParameterBeam):
self.reading = torch.stack([incoming._mu_x, incoming._mu_y])
self.reading = torch.stack([incoming.mu_x, incoming.mu_y])
elif isinstance(incoming, ParticleBeam):
self.reading = torch.stack([incoming.mu_x, incoming.mu_y])
else:
Expand Down
22 changes: 22 additions & 0 deletions tests/test_bpm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pytest
import torch

import cheetah


@pytest.mark.parametrize("is_bpm_active", [True, False])
@pytest.mark.parametrize("beam_class", [cheetah.ParticleBeam, cheetah.ParameterBeam])
def test_no_tracking_error(is_bpm_active, beam_class):
"""Test that tracking a beam through an inactive BPM does not raise an error."""
segment = cheetah.Segment(
elements=[
cheetah.Drift(length=torch.tensor(1.0)),
cheetah.BPM(name="my_bpm"),
cheetah.Drift(length=torch.tensor(1.0)),
],
)
beam = beam_class.from_astra("benchmark/cheetah/ACHIP_EA1_2021.1351.001")

segment.my_bpm.is_active = is_bpm_active

_ = segment.track(beam)

0 comments on commit f774715

Please sign in to comment.