quadrupole default dtype #191

Open
jp-ga opened this issue Jun 20, 2024 · 1 comment
Labels
enhancement New feature or request

Comments

@jp-ga
Collaborator

jp-ga commented Jun 20, 2024

The default dtype of Quadrupole (and probably of other elements) is set to torch.float32, which leads to errors when tracking beams built from default tensors (which default to torch.double). See https://github.com/desy-ml/cheetah/blob/master/cheetah/accelerator/quadrupole.py#L37

This can be annoying when trying to rely on the default kwargs.
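
For illustration, the kind of error described above can be reproduced in plain PyTorch, without Cheetah. This is only a sketch: `transfer_map` and `particles` are made-up stand-ins for an element's float32 tensors and a float64 beam, not Cheetah objects.

```python
import torch

# Stand-ins (assumptions, not Cheetah objects): a float32 "element"
# transfer map and a float64 "beam" vector.
transfer_map = torch.eye(7, dtype=torch.float32)
particles = torch.ones(7, dtype=torch.float64)

try:
    # Matrix products do not promote dtypes, so mixing them raises.
    torch.matmul(transfer_map, particles)
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print(f"dtype mismatch: {err}")
```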

@jp-ga jp-ga added the enhancement New feature or request label Jun 20, 2024
@jp-ga jp-ga changed the title from "quadrupole dtype" to "quadrupole default dtype" Jun 20, 2024
@jank324
Member

jank324 commented Jun 20, 2024

This is actually a big item I still want to address. It's kind of related to #113.

In the end, I think everything should work in a similar way to, say, nn.Linear in PyTorch, where you can do either of

quad = quad.double()
quad = quad.cuda()
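
For reference, this is roughly how it looks on nn.Linear itself. Note that on nn.Module the device move is spelled `.cuda()` or `.to(device)`. The sketch falls back to the CPU when no GPU is available and assumes PyTorch's default dtype has not been changed.

```python
import torch
from torch import nn

layer = nn.Linear(2, 2)
print(layer.weight.dtype)  # float32 under PyTorch's default settings

# .double() casts all floating-point parameters and buffers in place
# and returns the module itself.
layer = layer.double()
print(layer.weight.dtype)

# .to(device) moves parameters and buffers; use the GPU only if present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
layer = layer.to(device)
```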

This requires registering the element's tensors with PyTorch (as parameters or buffers), and we haven't quite figured out yet how to do this while keeping assignments like

quad.k1 = torch.tensor([4.2])

functional.
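
To make the difficulty concrete, here is a minimal sketch of one possible approach using buffers. `ToyQuadrupole` is a hypothetical stand-in, not Cheetah's Quadrupole and not a proposal for how it should be implemented. Plain assignment keeps working with a registered buffer, but the newly assigned tensor brings its own dtype with it.

```python
import torch
from torch import nn


class ToyQuadrupole(nn.Module):
    """Hypothetical stand-in for an element; not Cheetah's Quadrupole."""

    def __init__(self, k1: torch.Tensor) -> None:
        super().__init__()
        # Buffers follow .double(), .to(device) and friends.
        self.register_buffer("k1", k1)


quad = ToyQuadrupole(k1=torch.tensor([0.0])).double()
dtype_after_double = quad.k1.dtype  # torch.float64

# nn.Module.__setattr__ keeps "k1" registered as a buffer on assignment ...
quad.k1 = torch.tensor([4.2])
still_registered = "k1" in dict(quad.named_buffers())

# ... but the new tensor keeps its own dtype, so the module is mixed again.
dtype_after_assign = quad.k1.dtype  # torch.float32 with default settings
print(dtype_after_double, still_registered, dtype_after_assign)
```

Had `k1` been registered as an nn.Parameter instead, assigning a plain tensor to it would raise a TypeError, which is part of why keeping such assignments functional is not straightforward.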

As a workaround for now, you basically have to remember to pass the dtype you want to both the beam and all elements on initialisation, and then it should work.
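
The idea behind the workaround, sketched in plain PyTorch: pick the dtype once and pass it to everything you create. `make_element_map` and `make_beam` are hypothetical helpers made up for this example; they are not Cheetah functions, and the toy 2x2 thin-lens map is only there to have something to multiply.

```python
import torch

DTYPE = torch.float64  # chosen once, passed to everything below


def make_element_map(inverse_focal_length: float, dtype: torch.dtype) -> torch.Tensor:
    """Hypothetical helper: toy 2x2 thin-lens transfer map."""
    return torch.tensor([[1.0, 0.0], [-inverse_focal_length, 1.0]], dtype=dtype)


def make_beam(dtype: torch.dtype) -> torch.Tensor:
    """Hypothetical helper: toy (position, angle) vector."""
    return torch.tensor([1e-3, 0.0], dtype=dtype)


element_map = make_element_map(0.5, dtype=DTYPE)
beam = make_beam(dtype=DTYPE)

# Both operands share a dtype, so the product goes through.
tracked = element_map @ beam
print(tracked.dtype)
```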
