This is actually a big item I still want to address. It's kind of related to #113.
In the end, I think everything should work similarly to, say, `nn.Linear` in PyTorch, where you can do either of

```python
quad = quad.double()
quad = quad.cuda()
```
This requires registering the tensors properly with PyTorch, and we haven't quite figured out yet how to do this while keeping plain assignments like

```python
quad.k1 = torch.tensor([4.2])
```

functional.
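One possible direction (a sketch only, not Cheetah's actual implementation, and `Quadrupole`/`k1` here are just stand-ins) is registering such tensors as buffers on an `nn.Module`: PyTorch then converts them with `.double()`/`.cuda()` like `nn.Linear`'s weights, and `nn.Module.__setattr__` routes plain tensor assignment to a registered buffer, so attribute assignment stays functional:

```python
import torch
from torch import nn

class Quadrupole(nn.Module):
    # Hypothetical sketch -- NOT Cheetah's actual implementation.
    def __init__(self, k1: float = 0.0):
        super().__init__()
        # A registered buffer follows the module through
        # .double() / .cuda(), just like nn.Linear's weights.
        self.register_buffer("k1", torch.tensor(k1))

quad = Quadrupole(k1=1.5)
quad = quad.double()  # converts registered buffers to float64
assert quad.k1.dtype == torch.float64

# nn.Module.__setattr__ assigns tensors to an existing registered
# buffer of the same name, so plain assignment still works:
quad.k1 = torch.tensor([4.2], dtype=torch.float64)
```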
As a workaround for now, you basically have to remember to pass the dtype you want to both the beam and all elements on initialisation; then it should work.
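The failure mode and the workaround can be illustrated with plain `torch` (the tensors here merely stand in for element and beam quantities; the actual Cheetah constructors would take the same `dtype` argument):

```python
import torch

# A float32 tensor (the torch default) standing in for an element's
# transfer matrix, and a float64 one standing in for beam coordinates:
R = torch.eye(2)                                   # torch.float32
x = torch.tensor([1.0, 2.0], dtype=torch.float64)  # torch.float64
try:
    R @ x  # matmul does not promote dtypes, so this raises
except RuntimeError as e:
    print(f"dtype mismatch: {e}")

# Workaround: pick one dtype and pass it to everything on initialisation.
dtype = torch.float64
R = torch.eye(2, dtype=dtype)
x = torch.tensor([1.0, 2.0], dtype=dtype)
y = R @ x
```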
Quadrupole (and probably other elements): the `dtype` default is set to `torch.float32`, which leads to errors when using beams with default tensors (which default to `torch.double`). See https://github.com/desy-ml/cheetah/blob/master/cheetah/accelerator/quadrupole.py#L37. This can be annoying when trying to use default kwargs.