Skip to content

Commit

Permalink
Add quick little first test to check for grad_fn in results
Browse files Browse the repository at this point in the history
  • Loading branch information
jank324 committed Sep 12, 2023
1 parent 96d0bb1 commit d44eca5
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions test/test_differentiable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import torch

import cheetah


def test_simple_quadrupole():
"""
Simple test on a [D, Q, D] lattice with the qudrupole's k1 requiring grad, checking
if PyTorch tracked a grad_fn into the outgoing beam.
"""
segment = cheetah.Segment(
[
cheetah.Drift(torch.tensor(1.0)),
cheetah.Quadrupole(torch.tensor(0.2), k1=torch.tensor(1.0), name="my_quad"),
cheetah.Drift(1.0),
]
)
incoming_beam = cheetah.ParticleBeam.from_astra(
"benchmark/astra/ACHIP_EA1_2021.1351.001"
)

segment.my_quad.k1.requires_grad = True

outgoing_beam = segment.track(incoming_beam)

assert hasattr(outgoing_beam, "grad_fn")

0 comments on commit d44eca5

Please sign in to comment.