Skip to content

Commit

Permalink
Add test to detect Aperture vectorisation issue
Browse files Browse the repository at this point in the history
  • Loading branch information
jank324 committed Oct 3, 2024
1 parent dcb6318 commit dada294
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions tests/test_vectorized.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,3 +405,28 @@ def test_vectorized_parameter_beam_creation():
assert torch.allclose(beam.mu_x, torch.tensor([2e-4, 3e-4]))
assert beam.sigma_x.shape == (2,)
assert torch.allclose(beam.sigma_x, torch.tensor([1e-5, 2e-5]))


def test_vectorized_aperture_broadcasting():
"""
Test that apertures work in a vectorised setting and that broadcasting rules are
applied correctly.
"""
incoming = cheetah.ParticleBeam.from_parameters(
num_particles=100_000, energy=torch.tensor([154e6, 14e9, 5e9])
)
segment = cheetah.Segment(
elements=[
cheetah.Drift(length=torch.tensor(0.5)),
cheetah.Aperture(
x_max=torch.tensor([[1e-3], [2e-3], [3e-3]]), y_max=torch.tensor(1e-3)
),
cheetah.Drift(length=torch.tensor(0.5)),
]
)

outgoing = segment.track(incoming)

assert outgoing.particles.shape == (3, 2, 100_000, 7)
assert outgoing.particle_charges.shape == (100_000,)
assert outgoing.energy.shape == (2,)

0 comments on commit dada294

Please sign in to comment.