Skip to content

Commit

Permalink
Fix failing NX tables test
Browse files Browse the repository at this point in the history
  • Loading branch information
jank324 committed Jan 6, 2024
1 parent 94064cb commit 4b7c15f
Showing 1 changed file with 24 additions and 24 deletions.
48 changes: 24 additions & 24 deletions cheetah/converters/nxtables.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ def translate_element(row: list[str], header: list[str]) -> Optional[Dict]:
elif class_name == "MCXG": # TODO: Check length with Willi
assert name[6] == "X"
horizontal_coil = cheetah.HorizontalCorrector(
name=name[:6] + "H" + name[6 + 1 :], length=torch.tensor(5e-05)
name=name[:6] + "H" + name[6 + 1 :], length=torch.tensor([5e-05])
)
vertical_coil = cheetah.VerticalCorrector(
name=name[:6] + "V" + name[6 + 1 :], length=torch.tensor(5e-05)
name=name[:6] + "V" + name[6 + 1 :], length=torch.tensor([5e-05])
)
element = cheetah.Segment(elements=[horizontal_coil, vertical_coil], name=name)
elif class_name == "BSCX":
Expand Down Expand Up @@ -115,57 +115,57 @@ def translate_element(row: list[str], header: list[str]) -> Optional[Dict]:
elif class_name == "SLHG":
element = cheetah.Aperture( # TODO: Ask for actual size and shape
name=name,
x_max=torch.tensor(float("inf")),
y_max=torch.tensor(float("inf")),
x_max=torch.tensor([float("inf")]),
y_max=torch.tensor([float("inf")]),
shape="elliptical",
)
elif class_name == "SLHB":
element = cheetah.Aperture( # TODO: Ask for actual size and shape
name=name,
x_max=torch.tensor(float("inf")),
y_max=torch.tensor(float("inf")),
x_max=torch.tensor([float("inf")]),
y_max=torch.tensor([float("inf")]),
shape="rectangular",
)
elif class_name == "SLHS":
element = cheetah.Aperture( # TODO: Ask for actual size and shape
name=name,
x_max=torch.tensor(float("inf")),
y_max=torch.tensor(float("inf")),
x_max=torch.tensor([float("inf")]),
y_max=torch.tensor([float("inf")]),
shape="rectangular",
)
elif class_name == "MCHM":
element = cheetah.HorizontalCorrector(name=name, length=torch.tensor(0.02))
element = cheetah.HorizontalCorrector(name=name, length=torch.tensor([0.02]))
elif class_name == "MCVM":
element = cheetah.VerticalCorrector(name=name, length=torch.tensor(0.02))
element = cheetah.VerticalCorrector(name=name, length=torch.tensor([0.02]))
elif class_name == "MBHL":
element = cheetah.Dipole(name=name, length=torch.tensor(0.322))
element = cheetah.Dipole(name=name, length=torch.tensor([0.322]))
elif class_name == "MBHB":
element = cheetah.Dipole(name=name, length=torch.tensor(0.22))
element = cheetah.Dipole(name=name, length=torch.tensor([0.22]))
elif class_name == "MBHO":
element = cheetah.Dipole(
name=name,
length=torch.tensor(0.43852543421396856),
angle=torch.tensor(0.8203047484373349),
e2=torch.tensor(-0.7504915783575616),
length=torch.tensor([0.43852543421396856]),
angle=torch.tensor([0.8203047484373349]),
e2=torch.tensor([-0.7504915783575616]),
)
elif class_name == "MQZM":
element = cheetah.Quadrupole(name=name, length=torch.tensor(0.122))
element = cheetah.Quadrupole(name=name, length=torch.tensor([0.122]))
elif class_name == "RSBL":
element = cheetah.Cavity(
name=name,
length=torch.tensor(4.139),
frequency=torch.tensor(2.998e9),
voltage=torch.tensor(76e6),
length=torch.tensor([4.139]),
frequency=torch.tensor([2.998e9]),
voltage=torch.tensor([76e6]),
)
elif class_name == "RXBD":
element = cheetah.Cavity( # TODO: TD? and tilt?
name=name,
length=torch.tensor(1.0),
frequency=torch.tensor(11.9952e9),
voltage=torch.tensor(0.0),
length=torch.tensor([1.0]),
frequency=torch.tensor([11.9952e9]),
voltage=torch.tensor([0.0]),
)
elif class_name == "UNDA": # TODO: Figure out actual length
element = cheetah.Undulator(name=name, length=torch.tensor(0.25))
element = cheetah.Undulator(name=name, length=torch.tensor([0.25]))
elif class_name in [
"SOLG",
"BCMG",
Expand Down Expand Up @@ -252,7 +252,7 @@ def read_nx_tables(filepath: Path) -> "cheetah.Element":
filled_with_drifts.append(
cheetah.Drift(
name=f"DRIFT_{previous['element'].name}_{current['element'].name}",
length=torch.as_tensor(drift_length),
length=torch.as_tensor([drift_length]),
)
)

Expand Down

0 comments on commit 4b7c15f

Please sign in to comment.