Skip to content

Commit

Permalink
Merge pull request #265 from desy-ml/264-broadcasting-error-in-correc…
Browse files Browse the repository at this point in the history
…tor-elements

Fix corrector broadcasting bug
  • Loading branch information
jank324 authored Oct 4, 2024
2 parents c96f3c4 + 2ffd048 commit 73f53e6
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 3 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ This is a major release with significant upgrades under the hood of Cheetah. Des

### 🚨 Breaking Changes

- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #208, #213, #215, #218, #229, #233, #258) (@jank324, @cr-xu, @hespe, @roussel-ryan)
- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #208, #213, #215, #218, #229, #233, #258, #265) (@jank324, @cr-xu, @hespe, @roussel-ryan)
- The fifth particle coordinate `s` is renamed to `tau`. Now Cheetah uses the canonical variables in phase space $(x,px=\frac{P_x}{p_0},y,py, \tau=c\Delta t, \delta=\Delta E/{p_0 c})$. In addition, the trailing "s" was removed from some beam property names (e.g. `beam.xs` becomes `beam.x`). (see #163) (@cr-xu)
- `Screen` no longer blocks the beam (by default). To return to old behaviour, set `Screen.is_blocking = True`. (see #208) (@jank324, @roussel-ryan)

Expand Down
4 changes: 3 additions & 1 deletion cheetah/accelerator/horizontal_corrector.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:

_, igamma2, beta = compute_relativistic_factors(energy)

vector_shape = torch.broadcast_shapes(self.length.shape, igamma2.shape)
vector_shape = torch.broadcast_shapes(
self.length.shape, igamma2.shape, self.angle.shape
)

tm = torch.eye(7, device=device, dtype=dtype).repeat((*vector_shape, 1, 1))
tm[..., 0, 1] = self.length
Expand Down
4 changes: 3 additions & 1 deletion cheetah/accelerator/vertical_corrector.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:

_, igamma2, beta = compute_relativistic_factors(energy)

vector_shape = torch.broadcast_shapes(self.length.shape, igamma2.shape)
vector_shape = torch.broadcast_shapes(
self.length.shape, igamma2.shape, self.angle.shape
)

tm = torch.eye(7, device=device, dtype=dtype).repeat((*vector_shape, 1, 1))
tm[..., 0, 1] = self.length
Expand Down
44 changes: 44 additions & 0 deletions tests/test_vectorized.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,3 +405,47 @@ def test_vectorized_parameter_beam_creation():
assert torch.allclose(beam.mu_x, torch.tensor([2e-4, 3e-4]))
assert beam.sigma_x.shape == (2,)
assert torch.allclose(beam.sigma_x, torch.tensor([1e-5, 2e-5]))


@pytest.mark.parametrize(
"ElementClass", [cheetah.HorizontalCorrector, cheetah.VerticalCorrector]
)
def test_broadcasting_corrector_angles(ElementClass):
"""Test that broadcasting rules are correctly applied to with corrector angles."""
incoming = cheetah.ParticleBeam.from_parameters(
num_particles=100_000, energy=torch.tensor([154e6, 14e9])
)
element = ElementClass(
length=torch.tensor(0.15), angle=torch.tensor([[1e-5], [2e-5], [3e-5]])
)

outgoing = element.track(incoming)

assert outgoing.particles.shape == (3, 2, 100_000, 7)
assert outgoing.particle_charges.shape == (100_000,)
assert outgoing.energy.shape == (2,)


def test_broadcasting_solenoid_misalignment():
"""
Test that broadcasting rules are correctly applied to the misalignment in solenoids.
"""
incoming = cheetah.ParticleBeam.from_parameters(
num_particles=100_000, energy=torch.tensor([154e6, 14e9])
)
element = cheetah.Solenoid(
length=torch.tensor(0.15),
misalignment=torch.tensor(
[
[[1e-5, 2e-5], [2e-5, 3e-5]],
[[3e-5, 4e-5], [4e-5, 5e-5]],
[[5e-5, 6e-5], [6e-5, 7e-5]],
]
),
)

outgoing = element.track(incoming)

assert outgoing.particles.shape == (3, 2, 100_000, 7)
assert outgoing.particle_charges.shape == (100_000,)
assert outgoing.energy.shape == (2,)

0 comments on commit 73f53e6

Please sign in to comment.