Skip to content

Commit

Permalink
Fix cavity tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jank324 committed Jan 6, 2024
1 parent dab3f13 commit a077aa2
Showing 1 changed file with 26 additions and 26 deletions.
52 changes: 26 additions & 26 deletions cheetah/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,22 +953,22 @@ def _track_beam(self, incoming: ParticleBeam) -> ParticleBeam:
T566 = 1.5 * self.length * igamma2 / beta0**3
T556 = 0
T555 = 0
if incoming.energy + delta_energy > 0:
if any(incoming.energy + delta_energy > 0):
k = 2 * torch.pi * self.frequency / constants.speed_of_light
outgoing_energy = incoming.energy + delta_energy
g1 = outgoing_energy / electron_mass_eV
beta1 = torch.sqrt(1 - 1 / g1**2)

if isinstance(incoming, ParameterBeam):
outgoing_mu[5] = (
incoming._mu[5]
outgoing_mu[:, 5] = (
incoming._mu[:, 5]
+ incoming.energy * beta0 / (outgoing_energy * beta1)
+ self.voltage
* beta0
/ (outgoing_energy * beta1)
* (torch.cos(incoming._mu[4] * beta0 * k + phi) - torch.cos(phi))
* (torch.cos(incoming._mu[:, 4] * beta0 * k + phi) - torch.cos(phi))
)
outgoing_cov[5, 5] = incoming._cov[5, 5]
outgoing_cov[:, 5, 5] = incoming._cov[:, 5, 5]
# outgoing_cov[5, 5] = (
# incoming._cov[5, 5]
# + incoming.energy * beta0 / (outgoing_energy * beta1)
Expand All @@ -978,20 +978,20 @@ def _track_beam(self, incoming: ParticleBeam) -> ParticleBeam:
# * (torch.cos(incoming._mu[4] * beta0 * k + phi) - torch.cos(phi))
# )
else: # ParticleBeam
outgoing_particles[:, 5] = (
incoming.particles[:, 5]
outgoing_particles[:, :, 5] = (
incoming.particles[:, :, 5]
+ incoming.energy * beta0 / (outgoing_energy * beta1)
+ self.voltage
* beta0
/ (outgoing_energy * beta1)
* (
torch.cos(incoming.particles[:, 4] * beta0 * k + phi)
torch.cos(incoming.particles[:, :, 4] * beta0 * k + phi)
- torch.cos(phi)
)
)

dgamma = self.voltage / electron_mass_eV
if delta_energy > 0:
if any(delta_energy > 0):
T566 = (
self.length
* (beta0**3 * g0**3 - beta1**3 * g1**3)
Expand Down Expand Up @@ -1030,27 +1030,27 @@ def _track_beam(self, incoming: ParticleBeam) -> ParticleBeam:
)

if isinstance(incoming, ParameterBeam):
outgoing_mu[4] = (
T566 * incoming._mu[5] ** 2
+ T556 * incoming._mu[4] * incoming._mu[5]
+ T555 * incoming._mu[4] ** 2
outgoing_mu[:, 4] = (
T566 * incoming._mu[:, 5] ** 2
+ T556 * incoming._mu[:, 4] * incoming._mu[:, 5]
+ T555 * incoming._mu[:, 4] ** 2
)
outgoing_cov[4, 4] = (
T566 * incoming._cov[5, 5] ** 2
+ T556 * incoming._cov[4, 5] * incoming._cov[5, 5]
+ T555 * incoming._cov[4, 4] ** 2
outgoing_cov[:, 4, 4] = (
T566 * incoming._cov[:, 5, 5] ** 2
+ T556 * incoming._cov[:, 4, 5] * incoming._cov[:, 5, 5]
+ T555 * incoming._cov[:, 4, 4] ** 2
)
outgoing_cov[4, 5] = (
T566 * incoming._cov[5, 5] ** 2
+ T556 * incoming._cov[4, 5] * incoming._cov[5, 5]
+ T555 * incoming._cov[4, 4] ** 2
outgoing_cov[:, 4, 5] = (
T566 * incoming._cov[:, 5, 5] ** 2
+ T556 * incoming._cov[:, 4, 5] * incoming._cov[:, 5, 5]
+ T555 * incoming._cov[:, 4, 4] ** 2
)
outgoing_cov[5, 4] = outgoing_cov[4, 5]
outgoing_cov[:, 5, 4] = outgoing_cov[:, 4, 5]
else: # ParticleBeam
outgoing_particles[:, 4] = (
T566 * incoming.particles[:, 5] ** 2
+ T556 * incoming.particles[:, 4] * incoming.particles[:, 5]
+ T555 * incoming.particles[:, 4] ** 2
outgoing_particles[:, :, 4] = (
T566 * incoming.particles[:, :, 5] ** 2
+ T556 * incoming.particles[:, :, 4] * incoming.particles[:, :, 5]
+ T555 * incoming.particles[:, :, 4] ** 2
)

if isinstance(incoming, ParameterBeam):
Expand Down

0 comments on commit a077aa2

Please sign in to comment.