diff --git a/cheetah/converters/nxtables.py b/cheetah/converters/nxtables.py index 6c962616..601d1552 100644 --- a/cheetah/converters/nxtables.py +++ b/cheetah/converters/nxtables.py @@ -53,10 +53,10 @@ def translate_element(row: list[str], header: list[str]) -> Optional[Dict]: elif class_name == "MCXG": # TODO: Check length with Willi assert name[6] == "X" horizontal_coil = cheetah.HorizontalCorrector( - name=name[:6] + "H" + name[6 + 1 :], length=torch.tensor(5e-05) + name=name[:6] + "H" + name[6 + 1 :], length=torch.tensor([5e-05]) ) vertical_coil = cheetah.VerticalCorrector( - name=name[:6] + "V" + name[6 + 1 :], length=torch.tensor(5e-05) + name=name[:6] + "V" + name[6 + 1 :], length=torch.tensor([5e-05]) ) element = cheetah.Segment(elements=[horizontal_coil, vertical_coil], name=name) elif class_name == "BSCX": @@ -115,57 +115,57 @@ def translate_element(row: list[str], header: list[str]) -> Optional[Dict]: elif class_name == "SLHG": element = cheetah.Aperture( # TODO: Ask for actual size and shape name=name, - x_max=torch.tensor(float("inf")), - y_max=torch.tensor(float("inf")), + x_max=torch.tensor([float("inf")]), + y_max=torch.tensor([float("inf")]), shape="elliptical", ) elif class_name == "SLHB": element = cheetah.Aperture( # TODO: Ask for actual size and shape name=name, - x_max=torch.tensor(float("inf")), - y_max=torch.tensor(float("inf")), + x_max=torch.tensor([float("inf")]), + y_max=torch.tensor([float("inf")]), shape="rectangular", ) elif class_name == "SLHS": element = cheetah.Aperture( # TODO: Ask for actual size and shape name=name, - x_max=torch.tensor(float("inf")), - y_max=torch.tensor(float("inf")), + x_max=torch.tensor([float("inf")]), + y_max=torch.tensor([float("inf")]), shape="rectangular", ) elif class_name == "MCHM": - element = cheetah.HorizontalCorrector(name=name, length=torch.tensor(0.02)) + element = cheetah.HorizontalCorrector(name=name, length=torch.tensor([0.02])) elif class_name == "MCVM": - element = cheetah.VerticalCorrector(name=name, length=torch.tensor(0.02)) + element = cheetah.VerticalCorrector(name=name, length=torch.tensor([0.02])) elif class_name == "MBHL": - element = cheetah.Dipole(name=name, length=torch.tensor(0.322)) + element = cheetah.Dipole(name=name, length=torch.tensor([0.322])) elif class_name == "MBHB": - element = cheetah.Dipole(name=name, length=torch.tensor(0.22)) + element = cheetah.Dipole(name=name, length=torch.tensor([0.22])) elif class_name == "MBHO": element = cheetah.Dipole( name=name, - length=torch.tensor(0.43852543421396856), - angle=torch.tensor(0.8203047484373349), - e2=torch.tensor(-0.7504915783575616), + length=torch.tensor([0.43852543421396856]), + angle=torch.tensor([0.8203047484373349]), + e2=torch.tensor([-0.7504915783575616]), ) elif class_name == "MQZM": - element = cheetah.Quadrupole(name=name, length=torch.tensor(0.122)) + element = cheetah.Quadrupole(name=name, length=torch.tensor([0.122])) elif class_name == "RSBL": element = cheetah.Cavity( name=name, - length=torch.tensor(4.139), - frequency=torch.tensor(2.998e9), - voltage=torch.tensor(76e6), + length=torch.tensor([4.139]), + frequency=torch.tensor([2.998e9]), + voltage=torch.tensor([76e6]), ) elif class_name == "RXBD": element = cheetah.Cavity( # TODO: TD? and tilt? name=name, - length=torch.tensor(1.0), - frequency=torch.tensor(11.9952e9), - voltage=torch.tensor(0.0), + length=torch.tensor([1.0]), + frequency=torch.tensor([11.9952e9]), + voltage=torch.tensor([0.0]), ) elif class_name == "UNDA": # TODO: Figure out actual length - element = cheetah.Undulator(name=name, length=torch.tensor(0.25)) + element = cheetah.Undulator(name=name, length=torch.tensor([0.25])) elif class_name in [ "SOLG", "BCMG", @@ -252,7 +252,7 @@ def read_nx_tables(filepath: Path) -> "cheetah.Element": filled_with_drifts.append( cheetah.Drift( name=f"DRIFT_{previous['element'].name}_{current['element'].name}", - length=torch.as_tensor(drift_length), + length=torch.as_tensor([drift_length]), ) )