diff --git a/CHANGELOG.md b/CHANGELOG.md index 30213a7f..ac937eb6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,7 +23,7 @@ This is a major release with significant upgrades under the hood of Cheetah. Des - Port Bmad-X tracking methods to Cheetah for `Quadrupole`, `Drift`, and `Dipole` (see #153, #240) (@jp-ga, @jank324) - Add `TransverseDeflectingCavity` element (following the Bmad-X implementation) (see #240) (@jp-ga) - `Dipole` and `RBend` now take a focusing moment `k1` (see #235, #247) (@hespe) -- Implement a converter for lattice files imported from Elegant (see #222) (@hespe) +- Implement a converter for lattice files imported from Elegant (see #222, #251) (@hespe) ### 🐛 Bug fixes diff --git a/cheetah/converters/elegant.py b/cheetah/converters/elegant.py index 25d25aa9..6c9b1305 100644 --- a/cheetah/converters/elegant.py +++ b/cheetah/converters/elegant.py @@ -41,84 +41,178 @@ def convert_element( ) elif isinstance(parsed, dict) and "element_type" in parsed: if parsed["element_type"] == "sole": - validate_understood_properties(["element_type", "l"], parsed) + # The group property does not have an analoge in Cheetah, so it is neglected + validate_understood_properties(["element_type", "l", "group"], parsed) return cheetah.Solenoid( - length=torch.tensor([parsed["l"]]), + length=torch.tensor(parsed["l"]), name=name, device=device, dtype=dtype, ) - elif parsed["element_type"] == "hkick": - validate_understood_properties(["element_type", "l", "kick"], parsed) + elif parsed["element_type"] in ["hkick", "hkic"]: + validate_understood_properties( + ["element_type", "l", "kick", "group"], parsed + ) return cheetah.HorizontalCorrector( - length=torch.tensor([parsed.get("l", 0.0)]), - angle=torch.tensor([parsed.get("kick", 0.0)]), + length=torch.tensor(parsed.get("l", 0.0)), + angle=torch.tensor(parsed.get("kick", 0.0)), name=name, device=device, dtype=dtype, ) - elif parsed["element_type"] == "vkick": - validate_understood_properties(["element_type", "l", "kick"], parsed) + elif parsed["element_type"] in ["vkick", "vkic"]: + validate_understood_properties( + ["element_type", "l", "kick", "group"], parsed + ) return cheetah.VerticalCorrector( - length=torch.tensor([parsed.get("l", 0.0)]), - angle=torch.tensor([parsed.get("kick", 0.0)]), + length=torch.tensor(parsed.get("l", 0.0)), + angle=torch.tensor(parsed.get("kick", 0.0)), name=name, device=device, dtype=dtype, ) elif parsed["element_type"] == "mark": - validate_understood_properties(["element_type"], parsed) + validate_understood_properties(["element_type", "group"], parsed) return cheetah.Marker(name=name) elif parsed["element_type"] == "kick": - validate_understood_properties(["element_type", "l"], parsed) + validate_understood_properties(["element_type", "l", "group"], parsed) # TODO Find proper element class return cheetah.Drift( - length=torch.tensor([parsed.get("l", 0.0)]), + length=torch.tensor(parsed.get("l", 0.0)), + name=name, + device=device, + dtype=dtype, + ) + elif parsed["element_type"] in ["drift", "drif"]: + validate_understood_properties(["element_type", "l", "group"], parsed) + return cheetah.Drift( + length=torch.tensor(parsed.get("l", 0.0)), name=name, device=device, dtype=dtype, ) - elif parsed["element_type"] == "drift": - validate_understood_properties(["element_type", "l"], parsed) + elif parsed["element_type"] in ["csrdrift", "csrdrif"]: + # Drift that includes effects from coherent synchrotron radiation + validate_understood_properties( + ["element_type", "l", "group", "use_stupakov", "n_kicks", "csr"], parsed + ) + return cheetah.Drift( + length=torch.tensor(parsed.get("l", 0.0)), + name=name, + device=device, + dtype=dtype, + ) + elif parsed["element_type"] in ["lscdrift", "lscdrif"]: + # Drift that includes space charge effects + validate_understood_properties( + [ + "element_type", + "l", + "group", + "interpolate", + "smoothing", + "bins", + "high_frequency_cutoff0", + "high_frequency_cutoff1", + "lsc", + ], + parsed, + ) return cheetah.Drift( - length=torch.tensor([parsed.get("l", 0.0)]), + length=torch.tensor(parsed.get("l", 0.0)), name=name, device=device, dtype=dtype, ) + elif parsed["element_type"] == "ecol": + validate_understood_properties( + ["element_type", "l", "x_max", "y_max"], + parsed, + ) + return cheetah.Segment( + elements=[ + cheetah.Drift( + length=torch.tensor(parsed.get("l", 0.0)), + name=name + "_drift", + device=device, + dtype=dtype, + ), + cheetah.Aperture( + x_max=torch.tensor(parsed.get("x_max", torch.inf)), + y_max=torch.tensor(parsed.get("y_max", torch.inf)), + shape="elliptical", + name=name + "_aperture", + device=device, + dtype=dtype, + ), + ], + ) + elif parsed["element_type"] == "rcol": + validate_understood_properties( + ["element_type", "l", "x_max", "y_max"], + parsed, + ) + return cheetah.Segment( + elements=[ + cheetah.Drift( + length=torch.tensor(parsed.get("l", 0.0)), + name=name + "_drift", + device=device, + dtype=dtype, + ), + cheetah.Aperture( + x_max=torch.tensor(parsed.get("x_max", torch.inf)), + y_max=torch.tensor(parsed.get("y_max", torch.inf)), + shape="rectangular", + name=name + "_aperture", + device=device, + dtype=dtype, + ), + ], + ) elif parsed["element_type"] == "quad": validate_understood_properties( - ["element_type", "l", "k1", "tilt"], + ["element_type", "l", "k1", "tilt", "group"], parsed, ) return cheetah.Quadrupole( - length=torch.tensor([parsed["l"]]), - k1=torch.tensor([parsed["k1"]]), - tilt=torch.tensor([parsed.get("tilt", 0.0)]), + length=torch.tensor(parsed["l"]), + k1=torch.tensor(parsed["k1"]), + tilt=torch.tensor(parsed.get("tilt", 0.0)), name=name, device=device, dtype=dtype, ) elif parsed["element_type"] == "sext": # validate_understood_properties( - # ["element_type", "l"], + # ["element_type", "l", "group"], # parsed, # ) # TODO Parse properly! Missing element class return cheetah.Drift( - length=torch.tensor([parsed["l"]]), + length=torch.tensor(parsed["l"]), name=name, device=device, dtype=dtype, ) elif parsed["element_type"] == "moni": - validate_understood_properties(["element_type"], parsed) - return cheetah.Marker(name=name) + validate_understood_properties(["element_type", "group", "l"], parsed) + return cheetah.Segment( + elements=[ + cheetah.Drift( + length=torch.tensor(parsed.get("l", 0.0)), + name=name + "_drift", + device=device, + dtype=dtype, + ), + cheetah.Marker(name=name), + ] + ) elif parsed["element_type"] == "ematrix": validate_understood_properties( - ["element_type", "l", "order", "c[1-6]", "r[1-6][1-6]"], + ["element_type", "l", "order", "c[1-6]", "r[1-6][1-6]", "group"], parsed, ) @@ -135,10 +229,10 @@ def convert_element( ] ) # Add affine component (constant offset) - R[:6, 6] = torch.tensor([parsed.get(f"c{i + 1}", 0.0) for i in range(6)]) + R[:6, 6] = torch.tensor(parsed.get(f"c{i + 1}", 0.0) for i in range(6)) return cheetah.CustomTransferMap( - length=torch.tensor([parsed["l"]]), + length=torch.tensor(parsed["l"]), transfer_map=R, device=device, dtype=dtype, @@ -155,53 +249,171 @@ def convert_element( "end1_focus", "end2_focus", "body_focus_model", + "group", + ], + parsed, + ) + + # TODO Properly handle all parameters + return cheetah.Cavity( + length=torch.tensor(parsed["l"]), + # Elegant defines 90° as the phase of maximum acceleration, + # while Cheetah uses 0°. We therefore add a phase offset to compensate. + phase=torch.tensor(parsed["phase"] - 90), + voltage=torch.tensor(parsed["volt"]), + frequency=torch.tensor(parsed["freq"]), + name=name, + device=device, + dtype=dtype, + ) + elif parsed["element_type"] == "rfcw": + validate_understood_properties( + [ + "element_type", + "l", + "phase", + "volt", + "freq", + "change_p0", + "end1_focus", + "end2_focus", + "cell_length", + "zwakefile", + "trwakefile", + "tcolumn", + "wxcolumn", + "wycolumn", + "wzcolumn", + "interpolate", + "n_kicks", + "smoothing", + "zwake", + "trwake", + "lsc", + "lsc_bins", + "lsc_high_frequency_cutoff0", + "lsc_high_frequency_cutoff1", + "group", ], parsed, ) # TODO Properly handle all parameters return cheetah.Cavity( - length=torch.tensor([parsed["l"]]), + length=torch.tensor(parsed["l"]), + # Elegant defines 90° as the phase of maximum acceleration, + # while Cheetah uses 0°. We therefore add a phase offset to compensate. + phase=torch.tensor(parsed["phase"] - 90), + voltage=torch.tensor(parsed["volt"]), + frequency=torch.tensor(parsed["freq"]), + name=name, + device=device, + dtype=dtype, + ) + elif parsed["element_type"] == "rfdf": + validate_understood_properties( + [ + "element_type", + "l", + "phase", + "voltage", + "frequency", + "group", + ], + parsed, + ) + + # TODO Properly handle all parameters + return cheetah.TransverseDeflectingCavity( + length=torch.tensor(parsed["l"]), # Elegant defines 90° as the phase of maximum acceleration, # while Cheetah uses 0°. We therefore add a phase offset to compensate. - phase=torch.tensor([parsed["phase"] - 90]), - voltage=torch.tensor([parsed["volt"]]), - frequency=torch.tensor([parsed["freq"]]), + phase=torch.tensor(parsed["phase"] - 90), + voltage=torch.tensor(parsed["voltage"]), + frequency=torch.tensor(parsed["frequency"]), name=name, device=device, dtype=dtype, ) elif parsed["element_type"] == "sben": validate_understood_properties( - ["element_type", "l", "angle", "k1", "e1", "e2", "tilt"], + ["element_type", "l", "angle", "k1", "e1", "e2", "tilt", "group"], parsed, ) return cheetah.Dipole( - length=torch.tensor([parsed["l"]]), - angle=torch.tensor([parsed.get("angle", 0.0)]), - k1=torch.tensor([parsed.get("k1", 0.0)]), - e1=torch.tensor([parsed["e1"]]), - e2=torch.tensor([parsed.get("e2", 0.0)]), - tilt=torch.tensor([parsed.get("tilt", 0.0)]), + length=torch.tensor(parsed["l"]), + angle=torch.tensor(parsed.get("angle", 0.0)), + k1=torch.tensor(parsed.get("k1", 0.0)), + e1=torch.tensor(parsed["e1"]), + e2=torch.tensor(parsed.get("e2", 0.0)), + tilt=torch.tensor(parsed.get("tilt", 0.0)), name=name, device=device, dtype=dtype, ) elif parsed["element_type"] == "rben": validate_understood_properties( - ["element_type", "l", "angle", "e1", "e2", "tilt"], + ["element_type", "l", "angle", "e1", "e2", "tilt", "group"], parsed, ) return cheetah.RBend( - length=torch.tensor([parsed["l"]]), - angle=torch.tensor([parsed.get("angle", 0.0)]), - e1=torch.tensor([parsed["e1"]]), - e2=torch.tensor([parsed.get("e2", 0.0)]), - tilt=torch.tensor([parsed.get("tilt", 0.0)]), + length=torch.tensor(parsed["l"]), + angle=torch.tensor(parsed.get("angle", 0.0)), + e1=torch.tensor(parsed["e1"]), + e2=torch.tensor(parsed.get("e2", 0.0)), + tilt=torch.tensor(parsed.get("tilt", 0.0)), name=name, device=device, dtype=dtype, ) + elif parsed["element_type"] == "csrcsben": + validate_understood_properties( + [ + "element_type", + "l", + "angle", + "e1", + "e2", + "edge1_effects", + "edge2_effects", + "tilt", + "hgap", + "fint", + "sg_halfwidth", + "sg_order", + "steady_state", + "bins", + "n_kicks", + "integration_order", + "isr", + "csr", + "group", + ], + parsed, + ) + return cheetah.Dipole( + length=torch.tensor(parsed["l"]), + angle=torch.tensor(parsed.get("angle", 0.0)), + k1=torch.tensor(parsed.get("k1", 0.0)), + e1=torch.tensor(parsed["e1"]), + e2=torch.tensor(parsed.get("e2", 0.0)), + tilt=torch.tensor(parsed.get("tilt", 0.0)), + name=name, + device=device, + dtype=dtype, + ) + elif parsed["element_type"] == "watch": + validate_understood_properties( + ["element_type", "group", "filename"], parsed + ) + return cheetah.Marker(name=name) + elif parsed["element_type"] in ["charge", "wake"]: + print( + f"WARNING: Information provided in element {name} of type" + f" {parsed['element_type']} cannot be imported automatically. Consider" + " manually providing the correct information." + ) + return cheetah.Marker(name=name) else: print( f"WARNING: Element {name} of type {parsed['element_type']} cannot" @@ -210,7 +422,7 @@ def convert_element( # TODO: Remove the length if by adding markers to Cheetah return cheetah.Drift( name=name, - length=torch.tensor([parsed.get("l", 0.0)]), + length=torch.tensor(parsed.get("l", 0.0)), device=device, dtype=dtype, ) diff --git a/cheetah/converters/utils/__init__.py b/cheetah/converters/utils/__init__.py index d33ae13b..bb6c45e8 100644 --- a/cheetah/converters/utils/__init__.py +++ b/cheetah/converters/utils/__init__.py @@ -1 +1 @@ -from . import fortran_namelist # noqa: F401 +from . import fortran_namelist, rpn # noqa: F401 diff --git a/cheetah/converters/utils/fortran_namelist.py b/cheetah/converters/utils/fortran_namelist.py index 53476ee8..9765a8cf 100644 --- a/cheetah/converters/utils/fortran_namelist.py +++ b/cheetah/converters/utils/fortran_namelist.py @@ -8,6 +8,8 @@ import scipy from scipy.constants import physical_constants +from . import rpn + def read_clean_lines(lattice_file_path: Path) -> list[str]: """ @@ -125,7 +127,7 @@ def evaluate_expression(expression: str, context: dict) -> Any: # Evaluate as a mathematical expression try: - # Surround expressions in bracks with quotes + # Surround expressions in brackets with quotes expression = re.sub(r"\[([a-z0-9_%]+)\]", r"['\1']", expression) # Replace power operator with python equivalent expression = re.sub(r"\^", r"**", expression) @@ -135,7 +137,11 @@ def evaluate_expression(expression: str, context: dict) -> Any: # behaviour. expression = re.sub(r"abs\(", r"abs_func(", expression) - return eval(expression, context) + return ( + eval(expression, context) + if not rpn.is_valid_expression(expression) + else rpn.eval_expression(expression, context) + ) except SyntaxError: if not ( len(expression.split(":")) == 3 or len(expression.split(":")) == 4 diff --git a/cheetah/converters/utils/rpn.py b/cheetah/converters/utils/rpn.py new file mode 100644 index 00000000..2ccdd2e4 --- /dev/null +++ b/cheetah/converters/utils/rpn.py @@ -0,0 +1,17 @@ +from typing import Any + + +def is_valid_expression(expression: str) -> bool: + """Checks if expression is a reverse Polish notation.""" + stripped = expression[1:-1].strip() + return stripped[-1] in "+-/*" and len(stripped.split(" ")) == 3 + + +def eval_expression(expression: str, context: dict) -> Any: + """ + Evaluates an expression in reverse Polish notation. + + NOTE: Does not support nested expressions. + """ + splits = expression[1:-1].strip().split(" ") + return eval(" ".join([splits[0], splits[2], splits[1]]), context) diff --git a/tests/resources/fodo.lte b/tests/resources/fodo.lte index 8361cae9..8e6a1816 100644 --- a/tests/resources/fodo.lte +++ b/tests/resources/fodo.lte @@ -1,7 +1,8 @@ +c: charge,total=0.25e-9 q1: quad,l=0.1,k1=1.5 q2: quad,l=0.2,k1=-3 d1: drift,l=1 d2: drift,l=2 s1: sben, l=0.3,e1=0.25 m1: mark -fodo: line=(q1,d1,m1,s1,d1,q2,d2) +fodo: line=(c,q1,d1,m1,s1,d1,q2,d2) diff --git a/tests/test_elegant_conversion.py b/tests/test_elegant_conversion.py index 18d0d83b..93b17ece 100644 --- a/tests/test_elegant_conversion.py +++ b/tests/test_elegant_conversion.py @@ -12,19 +12,18 @@ def test_fodo(): correct_lattice = cheetah.Segment( [ + cheetah.Marker(name="c"), cheetah.Quadrupole( - name="q1", length=torch.tensor([0.1]), k1=torch.tensor([1.5]) + name="q1", length=torch.tensor(0.1), k1=torch.tensor(1.5) ), - cheetah.Drift(name="d1", length=torch.tensor([1])), + cheetah.Drift(name="d1", length=torch.tensor(1)), cheetah.Marker(name="m1"), - cheetah.Dipole( - name="s1", length=torch.tensor([0.3]), e1=torch.tensor([0.25]) - ), - cheetah.Drift(name="d1", length=torch.tensor([1])), + cheetah.Dipole(name="s1", length=torch.tensor(0.3), e1=torch.tensor(0.25)), + cheetah.Drift(name="d1", length=torch.tensor(1)), cheetah.Quadrupole( - name="q2", length=torch.tensor([0.2]), k1=torch.tensor([-3]) + name="q2", length=torch.tensor(0.2), k1=torch.tensor(-3) ), - cheetah.Drift(name="d2", length=torch.tensor([2])), + cheetah.Drift(name="d2", length=torch.tensor(2)), ], name="fodo", )