diff --git a/cheetah/converters/elegant.py b/cheetah/converters/elegant.py index f433e9c2..6c9b1305 100644 --- a/cheetah/converters/elegant.py +++ b/cheetah/converters/elegant.py @@ -44,7 +44,7 @@ def convert_element( # The group property does not have an analoge in Cheetah, so it is neglected validate_understood_properties(["element_type", "l", "group"], parsed) return cheetah.Solenoid( - length=torch.tensor([parsed["l"]]), + length=torch.tensor(parsed["l"]), name=name, device=device, dtype=dtype, @@ -54,8 +54,8 @@ def convert_element( ["element_type", "l", "kick", "group"], parsed ) return cheetah.HorizontalCorrector( - length=torch.tensor([parsed.get("l", 0.0)]), - angle=torch.tensor([parsed.get("kick", 0.0)]), + length=torch.tensor(parsed.get("l", 0.0)), + angle=torch.tensor(parsed.get("kick", 0.0)), name=name, device=device, dtype=dtype, @@ -65,8 +65,8 @@ def convert_element( ["element_type", "l", "kick", "group"], parsed ) return cheetah.VerticalCorrector( - length=torch.tensor([parsed.get("l", 0.0)]), - angle=torch.tensor([parsed.get("kick", 0.0)]), + length=torch.tensor(parsed.get("l", 0.0)), + angle=torch.tensor(parsed.get("kick", 0.0)), name=name, device=device, dtype=dtype, @@ -79,7 +79,7 @@ def convert_element( # TODO Find proper element class return cheetah.Drift( - length=torch.tensor([parsed.get("l", 0.0)]), + length=torch.tensor(parsed.get("l", 0.0)), name=name, device=device, dtype=dtype, @@ -87,7 +87,7 @@ def convert_element( elif parsed["element_type"] in ["drift", "drif"]: validate_understood_properties(["element_type", "l", "group"], parsed) return cheetah.Drift( - length=torch.tensor([parsed.get("l", 0.0)]), + length=torch.tensor(parsed.get("l", 0.0)), name=name, device=device, dtype=dtype, @@ -98,7 +98,7 @@ def convert_element( ["element_type", "l", "group", "use_stupakov", "n_kicks", "csr"], parsed ) return cheetah.Drift( - length=torch.tensor([parsed.get("l", 0.0)]), + length=torch.tensor(parsed.get("l", 0.0)), name=name, device=device, dtype=dtype, @@ -120,7 +120,7 @@ def convert_element( parsed, ) return cheetah.Drift( - length=torch.tensor([parsed.get("l", 0.0)]), + length=torch.tensor(parsed.get("l", 0.0)), name=name, device=device, dtype=dtype, @@ -133,14 +133,14 @@ def convert_element( return cheetah.Segment( elements=[ cheetah.Drift( - length=torch.tensor([parsed.get("l", 0.0)]), + length=torch.tensor(parsed.get("l", 0.0)), name=name + "_drift", device=device, dtype=dtype, ), cheetah.Aperture( - x_max=torch.tensor([parsed.get("x_max", torch.inf)]), - y_max=torch.tensor([parsed.get("y_max", torch.inf)]), + x_max=torch.tensor(parsed.get("x_max", torch.inf)), + y_max=torch.tensor(parsed.get("y_max", torch.inf)), shape="elliptical", name=name + "_aperture", device=device, @@ -156,14 +156,14 @@ def convert_element( return cheetah.Segment( elements=[ cheetah.Drift( - length=torch.tensor([parsed.get("l", 0.0)]), + length=torch.tensor(parsed.get("l", 0.0)), name=name + "_drift", device=device, dtype=dtype, ), cheetah.Aperture( - x_max=torch.tensor([parsed.get("x_max", torch.inf)]), - y_max=torch.tensor([parsed.get("y_max", torch.inf)]), + x_max=torch.tensor(parsed.get("x_max", torch.inf)), + y_max=torch.tensor(parsed.get("y_max", torch.inf)), shape="rectangular", name=name + "_aperture", device=device, @@ -177,9 +177,9 @@ def convert_element( parsed, ) return cheetah.Quadrupole( - length=torch.tensor([parsed["l"]]), - k1=torch.tensor([parsed["k1"]]), - tilt=torch.tensor([parsed.get("tilt", 0.0)]), + length=torch.tensor(parsed["l"]), + k1=torch.tensor(parsed["k1"]), + tilt=torch.tensor(parsed.get("tilt", 0.0)), name=name, device=device, dtype=dtype, @@ -192,7 +192,7 @@ def convert_element( # TODO Parse properly! Missing element class return cheetah.Drift( - length=torch.tensor([parsed["l"]]), + length=torch.tensor(parsed["l"]), name=name, device=device, dtype=dtype, @@ -202,7 +202,7 @@ def convert_element( return cheetah.Segment( elements=[ cheetah.Drift( - length=torch.tensor([parsed.get("l", 0.0)]), + length=torch.tensor(parsed.get("l", 0.0)), name=name + "_drift", device=device, dtype=dtype, @@ -229,10 +229,10 @@ def convert_element( ] ) # Add affine component (constant offset) - R[:6, 6] = torch.tensor([parsed.get(f"c{i + 1}", 0.0) for i in range(6)]) + R[:6, 6] = torch.tensor(parsed.get(f"c{i + 1}", 0.0) for i in range(6)) return cheetah.CustomTransferMap( - length=torch.tensor([parsed["l"]]), + length=torch.tensor(parsed["l"]), transfer_map=R, device=device, dtype=dtype, @@ -256,12 +256,12 @@ def convert_element( # TODO Properly handle all parameters return cheetah.Cavity( - length=torch.tensor([parsed["l"]]), + length=torch.tensor(parsed["l"]), # Elegant defines 90° as the phase of maximum acceleration, # while Cheetah uses 0°. We therefore add a phase offset to compensate. - phase=torch.tensor([parsed["phase"] - 90]), - voltage=torch.tensor([parsed["volt"]]), - frequency=torch.tensor([parsed["freq"]]), + phase=torch.tensor(parsed["phase"] - 90), + voltage=torch.tensor(parsed["volt"]), + frequency=torch.tensor(parsed["freq"]), name=name, device=device, dtype=dtype, @@ -300,12 +300,12 @@ def convert_element( # TODO Properly handle all parameters return cheetah.Cavity( - length=torch.tensor([parsed["l"]]), + length=torch.tensor(parsed["l"]), # Elegant defines 90° as the phase of maximum acceleration, # while Cheetah uses 0°. We therefore add a phase offset to compensate. - phase=torch.tensor([parsed["phase"] - 90]), - voltage=torch.tensor([parsed["volt"]]), - frequency=torch.tensor([parsed["freq"]]), + phase=torch.tensor(parsed["phase"] - 90), + voltage=torch.tensor(parsed["volt"]), + frequency=torch.tensor(parsed["freq"]), name=name, device=device, dtype=dtype, @@ -325,12 +325,12 @@ def convert_element( # TODO Properly handle all parameters return cheetah.TransverseDeflectingCavity( - length=torch.tensor([parsed["l"]]), + length=torch.tensor(parsed["l"]), # Elegant defines 90° as the phase of maximum acceleration, # while Cheetah uses 0°. We therefore add a phase offset to compensate. - phase=torch.tensor([parsed["phase"] - 90]), - voltage=torch.tensor([parsed["voltage"]]), - frequency=torch.tensor([parsed["frequency"]]), + phase=torch.tensor(parsed["phase"] - 90), + voltage=torch.tensor(parsed["voltage"]), + frequency=torch.tensor(parsed["frequency"]), name=name, device=device, dtype=dtype, @@ -341,12 +341,12 @@ def convert_element( parsed, ) return cheetah.Dipole( - length=torch.tensor([parsed["l"]]), - angle=torch.tensor([parsed.get("angle", 0.0)]), - k1=torch.tensor([parsed.get("k1", 0.0)]), - e1=torch.tensor([parsed["e1"]]), - e2=torch.tensor([parsed.get("e2", 0.0)]), - tilt=torch.tensor([parsed.get("tilt", 0.0)]), + length=torch.tensor(parsed["l"]), + angle=torch.tensor(parsed.get("angle", 0.0)), + k1=torch.tensor(parsed.get("k1", 0.0)), + e1=torch.tensor(parsed["e1"]), + e2=torch.tensor(parsed.get("e2", 0.0)), + tilt=torch.tensor(parsed.get("tilt", 0.0)), name=name, device=device, dtype=dtype, @@ -357,11 +357,11 @@ def convert_element( parsed, ) return cheetah.RBend( - length=torch.tensor([parsed["l"]]), - angle=torch.tensor([parsed.get("angle", 0.0)]), - e1=torch.tensor([parsed["e1"]]), - e2=torch.tensor([parsed.get("e2", 0.0)]), - tilt=torch.tensor([parsed.get("tilt", 0.0)]), + length=torch.tensor(parsed["l"]), + angle=torch.tensor(parsed.get("angle", 0.0)), + e1=torch.tensor(parsed["e1"]), + e2=torch.tensor(parsed.get("e2", 0.0)), + tilt=torch.tensor(parsed.get("tilt", 0.0)), name=name, device=device, dtype=dtype, @@ -392,12 +392,12 @@ def convert_element( parsed, ) return cheetah.Dipole( - length=torch.tensor([parsed["l"]]), - angle=torch.tensor([parsed.get("angle", 0.0)]), - k1=torch.tensor([parsed.get("k1", 0.0)]), - e1=torch.tensor([parsed["e1"]]), - e2=torch.tensor([parsed.get("e2", 0.0)]), - tilt=torch.tensor([parsed.get("tilt", 0.0)]), + length=torch.tensor(parsed["l"]), + angle=torch.tensor(parsed.get("angle", 0.0)), + k1=torch.tensor(parsed.get("k1", 0.0)), + e1=torch.tensor(parsed["e1"]), + e2=torch.tensor(parsed.get("e2", 0.0)), + tilt=torch.tensor(parsed.get("tilt", 0.0)), name=name, device=device, dtype=dtype, @@ -422,7 +422,7 @@ def convert_element( # TODO: Remove the length if by adding markers to Cheetah return cheetah.Drift( name=name, - length=torch.tensor([parsed.get("l", 0.0)]), + length=torch.tensor(parsed.get("l", 0.0)), device=device, dtype=dtype, ) diff --git a/tests/test_elegant_conversion.py b/tests/test_elegant_conversion.py index 877fb7fa..93b17ece 100644 --- a/tests/test_elegant_conversion.py +++ b/tests/test_elegant_conversion.py @@ -14,18 +14,16 @@ def test_fodo(): [ cheetah.Marker(name="c"), cheetah.Quadrupole( - name="q1", length=torch.tensor([0.1]), k1=torch.tensor([1.5]) + name="q1", length=torch.tensor(0.1), k1=torch.tensor(1.5) ), - cheetah.Drift(name="d1", length=torch.tensor([1])), + cheetah.Drift(name="d1", length=torch.tensor(1)), cheetah.Marker(name="m1"), - cheetah.Dipole( - name="s1", length=torch.tensor([0.3]), e1=torch.tensor([0.25]) - ), - cheetah.Drift(name="d1", length=torch.tensor([1])), + cheetah.Dipole(name="s1", length=torch.tensor(0.3), e1=torch.tensor(0.25)), + cheetah.Drift(name="d1", length=torch.tensor(1)), cheetah.Quadrupole( - name="q2", length=torch.tensor([0.2]), k1=torch.tensor([-3]) + name="q2", length=torch.tensor(0.2), k1=torch.tensor(-3) ), - cheetah.Drift(name="d2", length=torch.tensor([2])), + cheetah.Drift(name="d2", length=torch.tensor(2)), ], name="fodo", )