Skip to content

Commit

Permalink
Adjust Elegant converter to new broadcasting convention
Browse files Browse the repository at this point in the history
  • Loading branch information
Hespe committed Oct 2, 2024
1 parent 2c4c66d commit a9d522a
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 60 deletions.
104 changes: 52 additions & 52 deletions cheetah/converters/elegant.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def convert_element(
# The group property does not have an analoge in Cheetah, so it is neglected
validate_understood_properties(["element_type", "l", "group"], parsed)
return cheetah.Solenoid(
length=torch.tensor([parsed["l"]]),
length=torch.tensor(parsed["l"]),
name=name,
device=device,
dtype=dtype,
Expand All @@ -54,8 +54,8 @@ def convert_element(
["element_type", "l", "kick", "group"], parsed
)
return cheetah.HorizontalCorrector(
length=torch.tensor([parsed.get("l", 0.0)]),
angle=torch.tensor([parsed.get("kick", 0.0)]),
length=torch.tensor(parsed.get("l", 0.0)),
angle=torch.tensor(parsed.get("kick", 0.0)),
name=name,
device=device,
dtype=dtype,
Expand All @@ -65,8 +65,8 @@ def convert_element(
["element_type", "l", "kick", "group"], parsed
)
return cheetah.VerticalCorrector(
length=torch.tensor([parsed.get("l", 0.0)]),
angle=torch.tensor([parsed.get("kick", 0.0)]),
length=torch.tensor(parsed.get("l", 0.0)),
angle=torch.tensor(parsed.get("kick", 0.0)),
name=name,
device=device,
dtype=dtype,
Expand All @@ -79,15 +79,15 @@ def convert_element(

# TODO Find proper element class
return cheetah.Drift(
length=torch.tensor([parsed.get("l", 0.0)]),
length=torch.tensor(parsed.get("l", 0.0)),
name=name,
device=device,
dtype=dtype,
)
elif parsed["element_type"] in ["drift", "drif"]:
validate_understood_properties(["element_type", "l", "group"], parsed)
return cheetah.Drift(
length=torch.tensor([parsed.get("l", 0.0)]),
length=torch.tensor(parsed.get("l", 0.0)),
name=name,
device=device,
dtype=dtype,
Expand All @@ -98,7 +98,7 @@ def convert_element(
["element_type", "l", "group", "use_stupakov", "n_kicks", "csr"], parsed
)
return cheetah.Drift(
length=torch.tensor([parsed.get("l", 0.0)]),
length=torch.tensor(parsed.get("l", 0.0)),
name=name,
device=device,
dtype=dtype,
Expand All @@ -120,7 +120,7 @@ def convert_element(
parsed,
)
return cheetah.Drift(
length=torch.tensor([parsed.get("l", 0.0)]),
length=torch.tensor(parsed.get("l", 0.0)),
name=name,
device=device,
dtype=dtype,
Expand All @@ -133,14 +133,14 @@ def convert_element(
return cheetah.Segment(
elements=[
cheetah.Drift(
length=torch.tensor([parsed.get("l", 0.0)]),
length=torch.tensor(parsed.get("l", 0.0)),
name=name + "_drift",
device=device,
dtype=dtype,
),
cheetah.Aperture(
x_max=torch.tensor([parsed.get("x_max", torch.inf)]),
y_max=torch.tensor([parsed.get("y_max", torch.inf)]),
x_max=torch.tensor(parsed.get("x_max", torch.inf)),
y_max=torch.tensor(parsed.get("y_max", torch.inf)),
shape="elliptical",
name=name + "_aperture",
device=device,
Expand All @@ -156,14 +156,14 @@ def convert_element(
return cheetah.Segment(
elements=[
cheetah.Drift(
length=torch.tensor([parsed.get("l", 0.0)]),
length=torch.tensor(parsed.get("l", 0.0)),
name=name + "_drift",
device=device,
dtype=dtype,
),
cheetah.Aperture(
x_max=torch.tensor([parsed.get("x_max", torch.inf)]),
y_max=torch.tensor([parsed.get("y_max", torch.inf)]),
x_max=torch.tensor(parsed.get("x_max", torch.inf)),
y_max=torch.tensor(parsed.get("y_max", torch.inf)),
shape="rectangular",
name=name + "_aperture",
device=device,
Expand All @@ -177,9 +177,9 @@ def convert_element(
parsed,
)
return cheetah.Quadrupole(
length=torch.tensor([parsed["l"]]),
k1=torch.tensor([parsed["k1"]]),
tilt=torch.tensor([parsed.get("tilt", 0.0)]),
length=torch.tensor(parsed["l"]),
k1=torch.tensor(parsed["k1"]),
tilt=torch.tensor(parsed.get("tilt", 0.0)),
name=name,
device=device,
dtype=dtype,
Expand All @@ -192,7 +192,7 @@ def convert_element(

# TODO Parse properly! Missing element class
return cheetah.Drift(
length=torch.tensor([parsed["l"]]),
length=torch.tensor(parsed["l"]),
name=name,
device=device,
dtype=dtype,
Expand All @@ -202,7 +202,7 @@ def convert_element(
return cheetah.Segment(
elements=[
cheetah.Drift(
length=torch.tensor([parsed.get("l", 0.0)]),
length=torch.tensor(parsed.get("l", 0.0)),
name=name + "_drift",
device=device,
dtype=dtype,
Expand All @@ -229,10 +229,10 @@ def convert_element(
]
)
# Add affine component (constant offset)
R[:6, 6] = torch.tensor([parsed.get(f"c{i + 1}", 0.0) for i in range(6)])
R[:6, 6] = torch.tensor(parsed.get(f"c{i + 1}", 0.0) for i in range(6))

return cheetah.CustomTransferMap(
length=torch.tensor([parsed["l"]]),
length=torch.tensor(parsed["l"]),
transfer_map=R,
device=device,
dtype=dtype,
Expand All @@ -256,12 +256,12 @@ def convert_element(

# TODO Properly handle all parameters
return cheetah.Cavity(
length=torch.tensor([parsed["l"]]),
length=torch.tensor(parsed["l"]),
# Elegant defines 90° as the phase of maximum acceleration,
# while Cheetah uses 0°. We therefore add a phase offset to compensate.
phase=torch.tensor([parsed["phase"] - 90]),
voltage=torch.tensor([parsed["volt"]]),
frequency=torch.tensor([parsed["freq"]]),
phase=torch.tensor(parsed["phase"] - 90),
voltage=torch.tensor(parsed["volt"]),
frequency=torch.tensor(parsed["freq"]),
name=name,
device=device,
dtype=dtype,
Expand Down Expand Up @@ -300,12 +300,12 @@ def convert_element(

# TODO Properly handle all parameters
return cheetah.Cavity(
length=torch.tensor([parsed["l"]]),
length=torch.tensor(parsed["l"]),
# Elegant defines 90° as the phase of maximum acceleration,
# while Cheetah uses 0°. We therefore add a phase offset to compensate.
phase=torch.tensor([parsed["phase"] - 90]),
voltage=torch.tensor([parsed["volt"]]),
frequency=torch.tensor([parsed["freq"]]),
phase=torch.tensor(parsed["phase"] - 90),
voltage=torch.tensor(parsed["volt"]),
frequency=torch.tensor(parsed["freq"]),
name=name,
device=device,
dtype=dtype,
Expand All @@ -325,12 +325,12 @@ def convert_element(

# TODO Properly handle all parameters
return cheetah.TransverseDeflectingCavity(
length=torch.tensor([parsed["l"]]),
length=torch.tensor(parsed["l"]),
# Elegant defines 90° as the phase of maximum acceleration,
# while Cheetah uses 0°. We therefore add a phase offset to compensate.
phase=torch.tensor([parsed["phase"] - 90]),
voltage=torch.tensor([parsed["voltage"]]),
frequency=torch.tensor([parsed["frequency"]]),
phase=torch.tensor(parsed["phase"] - 90),
voltage=torch.tensor(parsed["voltage"]),
frequency=torch.tensor(parsed["frequency"]),
name=name,
device=device,
dtype=dtype,
Expand All @@ -341,12 +341,12 @@ def convert_element(
parsed,
)
return cheetah.Dipole(
length=torch.tensor([parsed["l"]]),
angle=torch.tensor([parsed.get("angle", 0.0)]),
k1=torch.tensor([parsed.get("k1", 0.0)]),
e1=torch.tensor([parsed["e1"]]),
e2=torch.tensor([parsed.get("e2", 0.0)]),
tilt=torch.tensor([parsed.get("tilt", 0.0)]),
length=torch.tensor(parsed["l"]),
angle=torch.tensor(parsed.get("angle", 0.0)),
k1=torch.tensor(parsed.get("k1", 0.0)),
e1=torch.tensor(parsed["e1"]),
e2=torch.tensor(parsed.get("e2", 0.0)),
tilt=torch.tensor(parsed.get("tilt", 0.0)),
name=name,
device=device,
dtype=dtype,
Expand All @@ -357,11 +357,11 @@ def convert_element(
parsed,
)
return cheetah.RBend(
length=torch.tensor([parsed["l"]]),
angle=torch.tensor([parsed.get("angle", 0.0)]),
e1=torch.tensor([parsed["e1"]]),
e2=torch.tensor([parsed.get("e2", 0.0)]),
tilt=torch.tensor([parsed.get("tilt", 0.0)]),
length=torch.tensor(parsed["l"]),
angle=torch.tensor(parsed.get("angle", 0.0)),
e1=torch.tensor(parsed["e1"]),
e2=torch.tensor(parsed.get("e2", 0.0)),
tilt=torch.tensor(parsed.get("tilt", 0.0)),
name=name,
device=device,
dtype=dtype,
Expand Down Expand Up @@ -392,12 +392,12 @@ def convert_element(
parsed,
)
return cheetah.Dipole(
length=torch.tensor([parsed["l"]]),
angle=torch.tensor([parsed.get("angle", 0.0)]),
k1=torch.tensor([parsed.get("k1", 0.0)]),
e1=torch.tensor([parsed["e1"]]),
e2=torch.tensor([parsed.get("e2", 0.0)]),
tilt=torch.tensor([parsed.get("tilt", 0.0)]),
length=torch.tensor(parsed["l"]),
angle=torch.tensor(parsed.get("angle", 0.0)),
k1=torch.tensor(parsed.get("k1", 0.0)),
e1=torch.tensor(parsed["e1"]),
e2=torch.tensor(parsed.get("e2", 0.0)),
tilt=torch.tensor(parsed.get("tilt", 0.0)),
name=name,
device=device,
dtype=dtype,
Expand All @@ -422,7 +422,7 @@ def convert_element(
# TODO: Remove the length if by adding markers to Cheetah
return cheetah.Drift(
name=name,
length=torch.tensor([parsed.get("l", 0.0)]),
length=torch.tensor(parsed.get("l", 0.0)),
device=device,
dtype=dtype,
)
Expand Down
14 changes: 6 additions & 8 deletions tests/test_elegant_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,16 @@ def test_fodo():
[
cheetah.Marker(name="c"),
cheetah.Quadrupole(
name="q1", length=torch.tensor([0.1]), k1=torch.tensor([1.5])
name="q1", length=torch.tensor(0.1), k1=torch.tensor(1.5)
),
cheetah.Drift(name="d1", length=torch.tensor([1])),
cheetah.Drift(name="d1", length=torch.tensor(1)),
cheetah.Marker(name="m1"),
cheetah.Dipole(
name="s1", length=torch.tensor([0.3]), e1=torch.tensor([0.25])
),
cheetah.Drift(name="d1", length=torch.tensor([1])),
cheetah.Dipole(name="s1", length=torch.tensor(0.3), e1=torch.tensor(0.25)),
cheetah.Drift(name="d1", length=torch.tensor(1)),
cheetah.Quadrupole(
name="q2", length=torch.tensor([0.2]), k1=torch.tensor([-3])
name="q2", length=torch.tensor(0.2), k1=torch.tensor(-3)
),
cheetah.Drift(name="d2", length=torch.tensor([2])),
cheetah.Drift(name="d2", length=torch.tensor(2)),
],
name="fodo",
)
Expand Down

0 comments on commit a9d522a

Please sign in to comment.