diff --git a/cheetah/accelerator.py b/cheetah/accelerator.py index c4c15b3a..ec6915ec 100644 --- a/cheetah/accelerator.py +++ b/cheetah/accelerator.py @@ -953,22 +953,22 @@ def _track_beam(self, incoming: ParticleBeam) -> ParticleBeam: T566 = 1.5 * self.length * igamma2 / beta0**3 T556 = 0 T555 = 0 - if incoming.energy + delta_energy > 0: + if any(incoming.energy + delta_energy > 0): k = 2 * torch.pi * self.frequency / constants.speed_of_light outgoing_energy = incoming.energy + delta_energy g1 = outgoing_energy / electron_mass_eV beta1 = torch.sqrt(1 - 1 / g1**2) if isinstance(incoming, ParameterBeam): - outgoing_mu[5] = ( - incoming._mu[5] + outgoing_mu[:, 5] = ( + incoming._mu[:, 5] + incoming.energy * beta0 / (outgoing_energy * beta1) + self.voltage * beta0 / (outgoing_energy * beta1) - * (torch.cos(incoming._mu[4] * beta0 * k + phi) - torch.cos(phi)) + * (torch.cos(incoming._mu[:, 4] * beta0 * k + phi) - torch.cos(phi)) ) - outgoing_cov[5, 5] = incoming._cov[5, 5] + outgoing_cov[:, 5, 5] = incoming._cov[:, 5, 5] # outgoing_cov[5, 5] = ( # incoming._cov[5, 5] # + incoming.energy * beta0 / (outgoing_energy * beta1) @@ -978,20 +978,20 @@ def _track_beam(self, incoming: ParticleBeam) -> ParticleBeam: # * (torch.cos(incoming._mu[4] * beta0 * k + phi) - torch.cos(phi)) # ) else: # ParticleBeam - outgoing_particles[:, 5] = ( - incoming.particles[:, 5] + outgoing_particles[:, :, 5] = ( + incoming.particles[:, :, 5] + incoming.energy * beta0 / (outgoing_energy * beta1) + self.voltage * beta0 / (outgoing_energy * beta1) * ( - torch.cos(incoming.particles[:, 4] * beta0 * k + phi) + torch.cos(incoming.particles[:, :, 4] * beta0 * k + phi) - torch.cos(phi) ) ) dgamma = self.voltage / electron_mass_eV - if delta_energy > 0: + if any(delta_energy > 0): T566 = ( self.length * (beta0**3 * g0**3 - beta1**3 * g1**3) @@ -1030,27 +1030,27 @@ def _track_beam(self, incoming: ParticleBeam) -> ParticleBeam: ) if isinstance(incoming, ParameterBeam): - outgoing_mu[4] = ( - T566 * incoming._mu[5] ** 2 - + T556 * incoming._mu[4] * incoming._mu[5] - + T555 * incoming._mu[4] ** 2 + outgoing_mu[:, 4] = ( + T566 * incoming._mu[:, 5] ** 2 + + T556 * incoming._mu[:, 4] * incoming._mu[:, 5] + + T555 * incoming._mu[:, 4] ** 2 ) - outgoing_cov[4, 4] = ( - T566 * incoming._cov[5, 5] ** 2 - + T556 * incoming._cov[4, 5] * incoming._cov[5, 5] - + T555 * incoming._cov[4, 4] ** 2 + outgoing_cov[:, 4, 4] = ( + T566 * incoming._cov[:, 5, 5] ** 2 + + T556 * incoming._cov[:, 4, 5] * incoming._cov[:, 5, 5] + + T555 * incoming._cov[:, 4, 4] ** 2 ) - outgoing_cov[4, 5] = ( - T566 * incoming._cov[5, 5] ** 2 - + T556 * incoming._cov[4, 5] * incoming._cov[5, 5] - + T555 * incoming._cov[4, 4] ** 2 + outgoing_cov[:, 4, 5] = ( + T566 * incoming._cov[:, 5, 5] ** 2 + + T556 * incoming._cov[:, 4, 5] * incoming._cov[:, 5, 5] + + T555 * incoming._cov[:, 4, 4] ** 2 ) - outgoing_cov[5, 4] = outgoing_cov[4, 5] + outgoing_cov[:, 5, 4] = outgoing_cov[:, 4, 5] else: # ParticleBeam - outgoing_particles[:, 4] = ( - T566 * incoming.particles[:, 5] ** 2 - + T556 * incoming.particles[:, 4] * incoming.particles[:, 5] - + T555 * incoming.particles[:, 4] ** 2 + outgoing_particles[:, :, 4] = ( + T566 * incoming.particles[:, :, 5] ** 2 + + T556 * incoming.particles[:, :, 4] * incoming.particles[:, :, 5] + + T555 * incoming.particles[:, :, 4] ** 2 ) if isinstance(incoming, ParameterBeam):