Enable automatic broadcasting #208

Merged
merged 121 commits into from
Oct 1, 2024
Changes from 105 commits

Commits (121)
404456a
update track methods to enable automatic broadcasting for drifts/quads
roussel-ryan Jul 2, 2024
849421c
update vectorize calcs and tests
roussel-ryan Jul 2, 2024
078ebb4
Add test that breaks current automatic broadcasting idea
jank324 Jul 9, 2024
7c98657
Merge branch 'master' into master
jank324 Jul 9, 2024
1766118
remove drift broadcast method, add utility function to calculate inv …
roussel-ryan Jul 11, 2024
44e589f
update relativistic factor calc util, cavity
roussel-ryan Jul 11, 2024
7edad77
Update cavity.py
roussel-ryan Jul 11, 2024
7150248
update quadrupole and segment for batching
roussel-ryan Jul 11, 2024
a0aef4f
Update quadrupole.py
roussel-ryan Jul 11, 2024
b37b3cf
fix test errors and bugs
roussel-ryan Jul 11, 2024
a17dd65
implement batch calculation utility and fix vectorize tests
roussel-ryan Jul 11, 2024
35ebe23
ufmt formatting
roussel-ryan Jul 11, 2024
42fcd83
remove broadcasting methods, fix vectorized test
roussel-ryan Jul 11, 2024
8ed54b7
remove broadcast methods
roussel-ryan Jul 11, 2024
5211369
remove broadcast methods
roussel-ryan Jul 11, 2024
bfddf8d
remove utility function, add element property
roussel-ryan Jul 11, 2024
8749e44
Merge branch 'master' into broadcasting
jank324 Jul 15, 2024
230a0cf
Fix vectorisation tests (not the code that causes one to fail)
jank324 Jul 15, 2024
dfcd73d
Remove vectorisation tests that no longer make sense
jank324 Jul 15, 2024
c54d04f
Fix formatting
jank324 Jul 15, 2024
a967010
Remove `broadcast` method from all beams
jank324 Jul 15, 2024
a67310f
Fix remaining `xp`s and `yp`s
jank324 Jul 15, 2024
d39d008
Remove unnecessary dimensions from elements (not yet beams) in tests
jank324 Jul 15, 2024
247919c
Cleanup imports
jank324 Jul 15, 2024
da1bd2b
fix some tests + solenoid fixes
roussel-ryan Jul 16, 2024
374b55b
Update solenoid.py
roussel-ryan Jul 16, 2024
21cc854
Merge branch 'master' into master
cr-xu Jul 17, 2024
9945b50
Fix non-batched error in space charge and particle beam creation
cr-xu Jul 17, 2024
5d8be5f
Procrastinate by removing some more brackets that are no longer needed
jank324 Jul 17, 2024
9b1db53
Fix expected values for `energy` and `total_charge` shapes
jank324 Jul 24, 2024
039cd16
Fix `total_charge` and `energy` broadcast issue
jank324 Jul 24, 2024
a97fcf2
Fix remaining failing tests
jank324 Jul 24, 2024
3aaae6f
Fix `any`s to `torch.any`
jank324 Jul 25, 2024
604fa81
Remove no longer needed dimensions
jank324 Jul 25, 2024
75a10b3
Merge branch 'master' into broadcasting
jank324 Jul 25, 2024
e1aa83c
test fixes/improvements
roussel-ryan Aug 6, 2024
12cfbb4
fix test to require matching batch sizes for particle creation
roussel-ryan Aug 8, 2024
df3d925
require that transform_to method takes scalar arguments
roussel-ryan Aug 8, 2024
d482dd6
Update parameter_beam.py
roussel-ryan Aug 8, 2024
de2679e
Update test_particle_beam.py
roussel-ryan Aug 8, 2024
9e298d4
Merge branch 'master' into master
roussel-ryan Aug 8, 2024
029187e
fix tests
roussel-ryan Aug 13, 2024
4761e93
apply formatting
roussel-ryan Aug 13, 2024
0b86926
Merge branch 'master' into master
roussel-ryan Aug 26, 2024
77cbae4
Merge branch 'master' into broadcasting
jank324 Sep 2, 2024
9be43e1
Clean up tests
jank324 Sep 2, 2024
4606ccf
Minor fixes
jank324 Sep 2, 2024
511fe7f
Fix flake8 warning
jank324 Sep 2, 2024
0703c13
Fix messed up formatting
jank324 Sep 2, 2024
862087e
Another remaining formatting fix
jank324 Sep 2, 2024
e284503
Fix all but screen and space charge tests
jank324 Sep 2, 2024
8edf7e4
Fix format
jank324 Sep 2, 2024
35060cd
Add more tests for Screen
jank324 Sep 2, 2024
914d9b1
A few initial fixes to the vectorisation of Screens
jank324 Sep 2, 2024
785e2ae
Merge branch 'master' into broadcasting
jank324 Sep 4, 2024
24cc278
fixes to functionality and tests for screen batching
roussel-ryan Sep 6, 2024
b320971
fix misalignment broadcasting issue for particle beam
roussel-ryan Sep 9, 2024
78d81bf
fix misalignment broadcasting issue with parameter beam
roussel-ryan Sep 9, 2024
06435a9
updated error message
roussel-ryan Sep 9, 2024
3152832
remove reading from test due to the additional batch dimension from o…
roussel-ryan Sep 9, 2024
ad4ad1c
Some cleanup
jank324 Sep 20, 2024
149ca84
Merge branch 'master' into master
jank324 Sep 20, 2024
a09c2bd
Fix format
jank324 Sep 20, 2024
2b350a3
Another format fix
jank324 Sep 20, 2024
d080311
Fix bug that needlessly added batch dimension during tracking
jank324 Sep 21, 2024
dc92a36
Proper fix for remaining failing tests
jank324 Sep 21, 2024
cd71a05
Clean up conditions for catching unsupported vectorisation with Scree…
jank324 Sep 21, 2024
7a4c965
Add please report bugs note to changelog
jank324 Sep 21, 2024
ba36076
Some function name and docstring cleanup
jank324 Sep 23, 2024
1ddbbb6
Clean up needless conversion of constants to tensors
jank324 Sep 23, 2024
7daa7b1
Clean up relativistic factors function signature
jank324 Sep 23, 2024
f252c8f
Comment cleanup
jank324 Sep 23, 2024
a87227a
Remove gradient removal
jank324 Sep 23, 2024
e91b179
Return `misalignment` and `binning` order
jank324 Sep 23, 2024
f7b9c01
Update CHANGELOG.md
jank324 Sep 23, 2024
d878440
Fix vector vs. batch terminology
jank324 Sep 23, 2024
f15e8bf
Merge branch 'master' of https://github.com/roussel-ryan/cheetah into…
jank324 Sep 23, 2024
0d15a7d
Remove `batch_shape` property again
jank324 Sep 23, 2024
e8bd8c2
Clean up unnecessary `from torch import Tensor`
jank324 Sep 23, 2024
82df3af
Cleanup docstring
jank324 Sep 23, 2024
ca23f5c
Add test that breaks with `repeat` used for example in `Drift.transfe…
jank324 Sep 23, 2024
5d7c363
Fix test to ask for correct specification
jank324 Sep 23, 2024
3036b0a
Clean up `Drift.transfer_map` broadcasting
jank324 Sep 23, 2024
c15980e
Fix one of the failing tests; remove old broadcast methods
jank324 Sep 23, 2024
3aef2f4
Clean up some more of the automatic broadcasting
jank324 Sep 23, 2024
2d72fcf
Fix broken Ocelot comparison test
jank324 Sep 23, 2024
ded28e3
Complete expected vectorisation test results
jank324 Sep 24, 2024
24e7d2a
Correct expected test results for vectorisation with different input …
jank324 Sep 24, 2024
2569502
Add special test case for `Cavity` that affects outgoing beam energy
jank324 Sep 24, 2024
1a6df80
Remove test for cavity affecting energy because it turns out it makes…
jank324 Sep 24, 2024
852aef0
Fix broadcasting issue in TDC code
jank324 Sep 24, 2024
040942a
Add test for different dimension inputs to run over all bmadx trackin…
jank324 Sep 24, 2024
0d3eb38
Fix flake8 warning
jank324 Sep 24, 2024
15ea754
Fix broadcasting issues in elements with `"bmax"` tracking methods
jank324 Sep 24, 2024
c5df11a
A little cleanup
jank324 Sep 24, 2024
6f71b77
Fix issues with previously existing `Screen` tests
jank324 Sep 25, 2024
db4c891
Remove vectorised screen test
jank324 Sep 25, 2024
091ab6a
Some cleanup in `Screen` code
jank324 Sep 25, 2024
b07a619
Fix tests by reinstating some changes to `Screen` with `ParameterBeam`
jank324 Sep 25, 2024
4af6630
Add changelog entry for `Screen.is_blocking`
jank324 Sep 25, 2024
7f875cf
Benchmark timing of broadcasting and sum against reduce and add
jank324 Sep 25, 2024
cff1416
Remove not needed special case for segment of length one
jank324 Sep 25, 2024
eefc022
Fix vectorisation issue in length computation with zero-length elements
jank324 Sep 25, 2024
90ffed0
Add tests for new error with `SpaceChargeKick` I ran into that I thin…
jank324 Sep 25, 2024
d78750c
Fix test that failed on space charge code
jank324 Sep 25, 2024
7840113
Clean up vector shape computation in `Undulator`
jank324 Sep 25, 2024
3a446f0
Add test that finds issue in the broadcasting when creating a `Parame…
jank324 Sep 25, 2024
f2dc6e3
Fix issue with broadcasting when creating a `ParameterBeam`
jank324 Sep 25, 2024
cac728d
Fix broadcasting issue in `ParticleBeam.transformed_to`
jank324 Sep 25, 2024
b4ee9ba
Fix broadcasting issue in `base_rmatrix`
jank324 Sep 25, 2024
d4ac307
Incomplete cleanup of KDE tests
jank324 Sep 25, 2024
5a9f560
Clean up KDE tests
jank324 Sep 25, 2024
499e94a
Remove development notebooks
jank324 Sep 25, 2024
eea9487
Restore meaningfulness of Bmad-X quadrupole tracking test for 64-bit
jank324 Oct 1, 2024
c874e5e
Remove too large tolerances from test
jank324 Oct 1, 2024
6162c57
Revectorise test that had its vectorisation mistakenly removed
jank324 Oct 1, 2024
46e8c26
Clean up quadrupole tests
jank324 Oct 1, 2024
5a2862d
Correct assertions in multiple dimensions quadrupole tilt test
jank324 Oct 1, 2024
68de514
Minor code readability improvement
jank324 Oct 1, 2024
e1e4b95
Minor cleanup in vectorisation tests
jank324 Oct 1, 2024
d9816db
Fix missing vector dimension in KDE tests
jank324 Oct 1, 2024
4 changes: 3 additions & 1 deletion .gitignore
@@ -5,6 +5,7 @@ cheetah.egg-info
.vscode
dist
.coverage
.idea

*.egg-info

@@ -14,4 +15,5 @@ build
distributions

docs/_build
dev*

dev*
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -2,10 +2,13 @@

## v0.7.0 [🚧 Work in Progress]

This is a major release with significant upgrades under the hood of Cheetah. Despite extensive testing, you might still encounter a few bugs. Please report them by opening an issue, so we can fix them as soon as possible and improve the experience for everyone.

### 🚨 Breaking Changes

- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting in a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor for most of the properties of your element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #215, #218, #229, #233) (@jank324, @cr-xu, @hespe)
- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting in a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor for most of the properties of your element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #208, #215, #218, #229, #233) (@jank324, @cr-xu, @hespe, @roussel-ryan)
- The fifth particle coordinate `s` is renamed to `tau`. Now Cheetah uses the canonical variables in phase space $(x, px=\frac{P_x}{p_0}, y, py, \tau=c\Delta t, \delta=\Delta E/(p_0 c))$. In addition, the trailing "s" was removed from some beam property names (e.g. `beam.xs` becomes `beam.x`). (see #163) (@cr-xu)
- `Screen` no longer blocks the beam (by default). To return to old behaviour, set `Screen.is_blocking = True`. (see #208) (@jank324, @roussel-ryan)

### 🚀 Features

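For context on the vectorisation described in the changelog entry above, here is a minimal usage sketch. It is illustrative only and not part of this diff; it assumes that element parameters may carry a leading vector dimension, which is what this PR's automatic broadcasting enables, and the exact shapes accepted may differ.

```python
import torch

import cheetah

# Track one incoming beam through ten quadrupole settings in parallel.
# The batched `k1` tensor is an assumption about the vectorised API.
segment = cheetah.Segment(
    elements=[
        cheetah.Drift(length=torch.tensor(0.5)),
        cheetah.Quadrupole(
            length=torch.tensor(0.2), k1=torch.linspace(-10.0, 10.0, 10)
        ),
        cheetah.Drift(length=torch.tensor(0.5)),
    ]
)
incoming = cheetah.ParameterBeam.from_parameters(energy=torch.tensor(100e6))
outgoing = segment.track(incoming)
# Outgoing beam properties carry the broadcast vector dimension, e.g. (10,).
print(outgoing.sigma_x.shape)
```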
142 changes: 142 additions & 0 deletions benchmark_sum_reduce.ipynb
@@ -0,0 +1,142 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from functools import reduce\n",
"\n",
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[torch.Size([]), torch.Size([3]), torch.Size([2, 1])]"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"xs = [\n",
" torch.tensor(42.0),\n",
" torch.tensor([1.0, 2.0, 3.0]),\n",
" torch.tensor([[4.0], [5.0]]),\n",
"]\n",
"[x.shape for x in xs]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[torch.Size([2, 3]), torch.Size([2, 3]), torch.Size([2, 3])]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"broadcast_xs = torch.broadcast_tensors(*xs)\n",
"[bx.shape for bx in broadcast_xs]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"9.63 μs ± 16.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"\n",
"broadcast_xs = torch.broadcast_tensors(*xs)\n",
"stacked_xs = torch.stack(broadcast_xs)\n",
"torch.sum(stacked_xs, dim=0)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.91 μs ± 12 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"\n",
"reduce(torch.add, xs)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(42.)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"reduce(torch.add, xs[:1])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "cheetah-dev",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
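The notebook above suggests that pairwise `reduce(torch.add, ...)`, which broadcasts differently shaped tensors implicitly, is noticeably faster (roughly 5× in the timings shown) than broadcasting, stacking and summing. The sketch below shows the favoured pattern on its own; applying it to, say, accumulating element lengths is an assumption based on the benchmark, not a quote of the implementation.

```python
from functools import reduce

import torch

# Hypothetical lengths with mismatched vector shapes; `torch.add` broadcasts
# them pairwise without materialising a stacked copy.
lengths = [
    torch.tensor(0.5),
    torch.tensor([0.1, 0.2, 0.3]),
    torch.tensor([[1.0], [2.0]]),
]
total_length = reduce(torch.add, lengths)
print(total_length.shape)  # torch.Size([2, 3])
```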
6 changes: 3 additions & 3 deletions cheetah/__init__.py
@@ -1,5 +1,5 @@
import cheetah.converters # noqa: F401
from cheetah.accelerator import ( # noqa: F401
from . import converters # noqa: F401
from .accelerator import ( # noqa: F401
BPM,
Aperture,
Cavity,
@@ -18,4 +18,4 @@
Undulator,
VerticalCorrector,
)
from cheetah.particles import ParameterBeam, ParticleBeam # noqa: F401
from .particles import ParameterBeam, ParticleBeam # noqa: F401
20 changes: 3 additions & 17 deletions cheetah/accelerator/aperture.py
@@ -3,11 +3,10 @@
import matplotlib.pyplot as plt
import torch
from matplotlib.patches import Rectangle
from torch import Size, nn

from cheetah.particles import Beam, ParticleBeam
from cheetah.utils import UniqueNameGenerator
from torch import nn

from ..particles import Beam, ParticleBeam
from ..utils import UniqueNameGenerator
from .element import Element

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")
@@ -110,19 +109,6 @@ def track(self, incoming: Beam) -> Beam:
else ParticleBeam.empty
)

def broadcast(self, shape: Size) -> Element:
new_aperture = self.__class__(
x_max=self.x_max.repeat(shape),
y_max=self.y_max.repeat(shape),
shape=self.shape,
is_active=self.is_active,
name=self.name,
device=self.x_max.device,
dtype=self.x_max.dtype,
)
new_aperture.length = self.length.repeat(shape)
return new_aperture

def split(self, resolution: torch.Tensor) -> list[Element]:
# TODO: Implement splitting for aperture properly, for now just return self
return [self]
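The removed `broadcast` methods above are superseded by automatic broadcasting: a batch of settings is now expressed directly in the parameter tensors. A minimal, purely illustrative sketch follows; the exact accepted shapes are an assumption about the new API.

```python
import torch

import cheetah

# Two aperture settings carried in a leading vector dimension instead of an
# explicit `aperture.broadcast(shape)` call, which no longer exists.
aperture = cheetah.Aperture(
    x_max=torch.tensor([2e-3, 4e-3]),
    y_max=torch.tensor([2e-3, 4e-3]),
    is_active=True,
)
```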
11 changes: 2 additions & 9 deletions cheetah/accelerator/bpm.py
@@ -4,11 +4,9 @@
import matplotlib.pyplot as plt
import torch
from matplotlib.patches import Rectangle
from torch import Size

from cheetah.particles import Beam, ParameterBeam, ParticleBeam
from cheetah.utils import UniqueNameGenerator

from ..particles import Beam, ParameterBeam, ParticleBeam
from ..utils import UniqueNameGenerator
from .element import Element

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")
@@ -50,11 +48,6 @@ def track(self, incoming: Beam) -> Beam:

return deepcopy(incoming)

def broadcast(self, shape: Size) -> Element:
new_bpm = self.__class__(is_active=self.is_active, name=self.name)
new_bpm.length = self.length.repeat(shape)
return new_bpm

def split(self, resolution: torch.Tensor) -> list[Element]:
return [self]

69 changes: 27 additions & 42 deletions cheetah/accelerator/cavity.py
@@ -5,12 +5,11 @@
from matplotlib.patches import Rectangle
from scipy import constants
from scipy.constants import physical_constants
from torch import Size, nn

from cheetah.particles import Beam, ParameterBeam, ParticleBeam
from cheetah.track_methods import base_rmatrix
from cheetah.utils import UniqueNameGenerator
from torch import nn

from ..particles import Beam, ParameterBeam, ParticleBeam
from ..track_methods import base_rmatrix
from ..utils import UniqueNameGenerator, compute_relativistic_factors
from .element import Element

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")
@@ -110,14 +109,7 @@ def _track_beam(self, incoming: Beam) -> Beam:
Track particles through the cavity. The input can be a `ParameterBeam` or a
`ParticleBeam`.
"""
beta0 = torch.full_like(self.length, 1.0)
igamma2 = torch.full_like(self.length, 0.0)
g0 = torch.full_like(self.length, 1e10)

mask = incoming.energy != 0
g0[mask] = incoming.energy[mask] / electron_mass_eV
igamma2[mask] = 1 / g0[mask] ** 2
beta0[mask] = torch.sqrt(1 - igamma2[mask])
gamma0, igamma2, beta0 = compute_relativistic_factors(incoming.energy)

phi = torch.deg2rad(self.phase)

@@ -138,8 +130,7 @@
if torch.any(incoming.energy + delta_energy > 0):
k = 2 * torch.pi * self.frequency / constants.speed_of_light
outgoing_energy = incoming.energy + delta_energy
g1 = outgoing_energy / electron_mass_eV
beta1 = torch.sqrt(1 - 1 / g1**2)
gamma1, _, beta1 = compute_relativistic_factors(outgoing_energy)

if isinstance(incoming, ParameterBeam):
outgoing_mu[..., 5] = incoming._mu[..., 5] * incoming.energy * beta0 / (
@@ -174,18 +165,18 @@
if torch.any(delta_energy > 0):
T566 = (
self.length
* (beta0**3 * g0**3 - beta1**3 * g1**3)
/ (2 * beta0 * beta1**3 * g0 * (g0 - g1) * g1**3)
* (beta0**3 * gamma0**3 - beta1**3 * gamma1**3)
/ (2 * beta0 * beta1**3 * gamma0 * (gamma0 - gamma1) * gamma1**3)
)
T556 = (
beta0
* k
* self.length
* dgamma
* g0
* (beta1**3 * g1**3 + beta0 * (g0 - g1**3))
* gamma0
* (beta1**3 * gamma1**3 + beta0 * (gamma0 - gamma1**3))
* torch.sin(phi)
/ (beta1**3 * g1**3 * (g0 - g1) ** 2)
/ (beta1**3 * gamma1**3 * (gamma0 - gamma1) ** 2)
)
T555 = (
beta0**2
@@ -196,15 +187,15 @@
* (
dgamma
* (
2 * g0 * g1**3 * (beta0 * beta1**3 - 1)
+ g0**2
+ 3 * g1**2
2 * gamma0 * gamma1**3 * (beta0 * beta1**3 - 1)
+ gamma0**2
+ 3 * gamma1**2
- 2
)
/ (beta1**3 * g1**3 * (g0 - g1) ** 3)
/ (beta1**3 * gamma1**3 * (gamma0 - gamma1) ** 3)
* torch.sin(phi) ** 2
- (g1 * g0 * (beta1 * beta0 - 1) + 1)
/ (beta1 * g1 * (g0 - g1) ** 2)
- (gamma1 * gamma0 * (beta1 * beta0 - 1) + 1)
/ (beta1 * gamma1 * (gamma0 - gamma1) ** 2)
* torch.cos(phi)
)
)
@@ -237,9 +228,9 @@

if isinstance(incoming, ParameterBeam):
outgoing = ParameterBeam(
outgoing_mu,
outgoing_cov,
outgoing_energy,
mu=outgoing_mu,
cov=outgoing_cov,
energy=outgoing_energy,
total_charge=incoming.total_charge,
device=outgoing_mu.device,
dtype=outgoing_mu.dtype,
@@ -302,7 +293,7 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor:
beta1 = torch.tensor(1.0)

k = 2 * torch.pi * self.frequency / torch.tensor(constants.speed_of_light)
r55_cor = 0.0
r55_cor = torch.tensor(0.0)
if torch.any((self.voltage != 0) & (energy != 0)): # TODO: Do we need this if?
beta0 = torch.sqrt(1 - 1 / Ei**2)
beta1 = torch.sqrt(1 - 1 / Ef**2)
@@ -324,7 +315,12 @@
r66 = Ei / Ef * beta0 / beta1
r65 = k * torch.sin(phi) * self.voltage / (Ef * beta1 * electron_mass_eV)

R = torch.eye(7, device=device, dtype=dtype).repeat((*self.length.shape, 1, 1))
# Make sure that all matrix elements have the same shape
r11, r12, r21, r22, r55_cor, r56, r65, r66 = torch.broadcast_tensors(
r11, r12, r21, r22, r55_cor, r56, r65, r66
)

R = torch.eye(7, device=device, dtype=dtype).repeat((*r11.shape, 1, 1))
R[..., 0, 0] = r11
R[..., 0, 1] = r12
R[..., 1, 0] = r21
@@ -340,17 +336,6 @@

return R

def broadcast(self, shape: Size) -> Element:
return self.__class__(
length=self.length.repeat(shape),
voltage=self.voltage.repeat(shape),
phase=self.phase.repeat(shape),
frequency=self.frequency.repeat(shape),
name=self.name,
device=self.length.device,
dtype=self.length.dtype,
)

def split(self, resolution: torch.Tensor) -> list[Element]:
# TODO: Implement splitting for cavity properly, for now just returns the
# element itself
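The cavity changes above replace the inline γ, 1/γ² and β computation with a shared `compute_relativistic_factors` utility imported from `..utils`. The following is a hypothetical reconstruction based on the inline code it replaces, including the guard for zero energy; the actual utility may differ.

```python
import torch

# Assumed value of the electron rest energy in eV (scipy's
# physical_constants["electron mass energy equivalent in MeV"] * 1e6).
electron_mass_eV = 0.51099895e6


def compute_relativistic_factors(
    energy: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Return (gamma, 1 / gamma**2, beta), falling back to the same placeholder
    values the removed inline code used where the energy is zero."""
    is_nonzero = energy != 0
    gamma = torch.where(
        is_nonzero, energy / electron_mass_eV, torch.full_like(energy, 1e10)
    )
    igamma2 = torch.where(is_nonzero, 1 / gamma**2, torch.zeros_like(energy))
    beta = torch.sqrt(1 - igamma2)
    return gamma, igamma2, beta
```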