114 tests are failing on GPU machine #115
Conversation
If I read the changes correctly, we can now safely switch between [float16, float32, float64] for elements, right?
I haven't tested it (other than quickly trying to track through a quadrupole with MPS, see #61), but in theory we should now be able to safely switch between all dtypes and devices, yes.
- The handling of `device` and `dtype` was overhauled. They might not behave as expected. `Element`s also no longer have a `device` attribute. (see #115) (@jank324)
I think it's worth mentioning explicitly in the documentation that this actually changes the behaviour of element creation, i.e. one can now conveniently create elements as before v.6.0, without the requirement to wrap every parameter as a tensor first:
```python
dipole = cheetah.Dipole(length=0.1, angle=0.1, dtype=torch.float32)
```
I think you are right, but I also think this should be tested before making that claim.
Description
This PR overhauls the way devices and dtypes are handled by `Element`s and `Beam`s. The idea is to bring it more in line with how original PyTorch modules like `nn.Linear` handle them (https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear). This also addresses the fact that, according to https://stackoverflow.com/questions/58926054/how-to-get-the-device-type-of-a-pytorch-module-conveniently, modules shouldn't have overarching dtype and device properties, because their parameters and buffers may be on different devices and of different dtypes.

On a more practical note, this makes the entire device system much more robust and essentially completes MPS support from Cheetah's side (#61; not all operations we use are implemented for MPS in PyTorch yet).
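For illustration, here is a minimal pure-Python sketch of the `factory_kwargs` pattern that `nn.Linear` follows. The class and attribute names below are hypothetical stand-ins (not Cheetah's or PyTorch's actual API): the point is that `device` and `dtype` are forwarded to each parameter at construction time and inferred from a parameter when needed, rather than stored as module-level attributes.

```python
# Hypothetical, simplified stand-ins to illustrate the pattern only;
# real code would use torch.Tensor / torch.nn.Parameter.

class FakeParameter:
    """Stores a value together with its device and dtype."""

    def __init__(self, value, device="cpu", dtype="float32"):
        self.value = value
        self.device = device
        self.dtype = dtype


class FakeDipole:
    """Mimics how nn.Linear forwards device/dtype to its parameters."""

    def __init__(self, length, angle, device="cpu", dtype="float32"):
        factory_kwargs = {"device": device, "dtype": dtype}
        self.length = FakeParameter(length, **factory_kwargs)
        self.angle = FakeParameter(angle, **factory_kwargs)

    def parameter_device(self):
        # The module itself has no device attribute; the device is
        # inferred from one of its parameters when needed.
        return self.length.device


dipole = FakeDipole(length=0.1, angle=0.1, dtype="float64")
print(dipole.parameter_device())  # "cpu"
print(dipole.angle.dtype)         # "float64"
```

This mirrors why the PR removes the `device` attribute from `Element`: once parameters own their device and dtype, a single module-level attribute can become stale or wrong.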
A minor downside is that we no longer do automatic device selection. This does, however, make that part of Cheetah more predictable (I previously had a lot of code using Cheetah that had to fix the device because weird things would happen otherwise).
This PR also prepares an eventual fix for #113.
Motivation and Context
Types of changes
Checklist
- My code passes `flake8` (required).
- All `pytest` tests pass (required).
- I have run `pytest` on a machine with a CUDA GPU and made sure all tests pass (required).

Note: We are using a maximum length of 88 characters per line.