Enable automatic broadcasting #208
Conversation
If I can get feedback on the changes made here, I can go ahead and make the same changes for the other elements.
So, I managed to break it. I really don't know what the best way to implement this automatic broadcasting will be. Ideally, we would find a module in PyTorch itself that does something similar and copy what it does. Until then, maybe we should follow a test-driven approach: implement tests that use Cheetah the way we would like to use it, and then fix them.
This is the current form of the dipole test that passes in my version of the repo. I can correct the rest of the tests if this looks ok to you all. @jank324 does this give you a clearer idea of what we meant by batching? Also, what are your thoughts on using pytest to check that errors are raised? Is there something else already being used?

```python
import pytest
import torch

from cheetah import Dipole, Drift, ParameterBeam, ParticleBeam, Segment


def test_dipole_off():
    """Test that a dipole with angle=0 still behaves like a drift."""
    dipole = Dipole(length=torch.tensor(1.0), angle=torch.tensor(0.0))
    drift = Drift(length=torch.tensor(1.0))
    incoming_beam = ParameterBeam.from_parameters(
        sigma_xp=torch.tensor(2e-7), sigma_yp=torch.tensor(2e-7)
    )

    outbeam_dipole_off = dipole(incoming_beam)
    outbeam_drift = drift(incoming_beam)

    dipole.angle = torch.tensor(1.0, device=dipole.angle.device)
    outbeam_dipole_on = dipole(incoming_beam)

    assert dipole.name is not None
    assert torch.allclose(outbeam_dipole_off.sigma_x, outbeam_drift.sigma_x)
    assert not torch.allclose(outbeam_dipole_on.sigma_x, outbeam_drift.sigma_x)


def test_dipole_batched_execution():
    """Test that a dipole with batch dimensions behaves as expected."""
    incoming = ParticleBeam.from_parameters(
        num_particles=torch.tensor(100),
        energy=torch.tensor(1e9),
        mu_x=torch.tensor(1e-5),
    )

    # Test batching to generate 3 beam lines
    segment = Segment(
        [
            Dipole(
                length=torch.tensor([0.5, 0.5, 0.5]),
                angle=torch.tensor([0.1, 0.2, 0.1]),
            ),
            Drift(length=torch.tensor(0.5)),
        ]
    )
    outgoing = segment(incoming)

    assert outgoing.particles.shape == torch.Size([3, 100, 7])
    assert outgoing.mu_x.shape == torch.Size([3])
    # Check that dipoles with the same bend angle produce the same output
    assert torch.allclose(outgoing.particles[0], outgoing.particles[2])
    # Check that different angles do make a difference
    assert not torch.allclose(outgoing.particles[0], outgoing.particles[1])

    # Test batching to generate 18 beam lines
    segment = Segment(
        [
            Dipole(
                length=torch.tensor([0.5, 0.5, 0.5]).reshape(3, 1),
                angle=torch.tensor([0.1, 0.2, 0.1]).reshape(1, 3),
            ),
            Drift(length=torch.tensor([0.5, 1.0]).reshape(2, 1, 1)),
        ]
    )
    outgoing = segment(incoming)

    assert outgoing.particles.shape == torch.Size([2, 3, 3, 100, 7])

    # Test improper batching -- this does not obey torch broadcasting rules
    segment = Segment(
        [
            Dipole(
                length=torch.tensor([0.5, 0.5, 0.5]).reshape(3, 1),
                angle=torch.tensor([0.1, 0.2, 0.1]).reshape(1, 3),
            ),
            Drift(length=torch.tensor([0.5, 1.0]).reshape(2, 1)),
        ]
    )
    with pytest.raises(RuntimeError):
        segment(incoming)
```
I would say that's how you are supposed to do it? Or is there a reason you wouldn't? I think something very similar is already done somewhere in the Cheetah tests, but I'm not sure where at this point.
Ok, that's how I would do it. I wasn't sure whether pytest was being used or not, and I didn't want to add a dependency if there was something else you wanted to use.
So, for my understanding:

```python
incoming = ParticleBeam.from_parameters(
    num_particles=torch.tensor(100),
    energy=torch.tensor(1e9),
    mu_x=torch.tensor(1e-5),
)

# Gives me the "product" of all settings
Dipole(
    length=torch.tensor([0.5, 0.5, 0.5]).reshape(3, 1),
    angle=torch.tensor([0.1, 0.2, 0.1]).reshape(1, 3),
)
...
assert outgoing.particles.shape == (3, 3, num_particles, 7)  # passes

# Matches the settings one-to-one
Dipole(
    length=torch.tensor([0.5, 0.5, 0.5]),
    angle=torch.tensor([0.1, 0.2, 0.1]),
)
...
assert outgoing.particles.shape == (3, num_particles, 7)  # passes
```

Correct?
Yes, that is correct.
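The two behaviours can be checked with plain PyTorch shape arithmetic, independently of Cheetah. This is a minimal sketch of how the shapes above combine under broadcasting:

```python
import torch

# "Product" of all settings: (3, 1) and (1, 3) broadcast to a (3, 3) grid.
length = torch.tensor([0.5, 0.5, 0.5]).reshape(3, 1)
angle = torch.tensor([0.1, 0.2, 0.1]).reshape(1, 3)
assert torch.broadcast_shapes(length.shape, angle.shape) == (3, 3)

# One-to-one matching: two (3,) tensors stay (3,).
assert torch.broadcast_shapes((3,), (3,)) == (3,)

# Appending the particle dimensions gives the outgoing particles shape.
num_particles = 100
batch = torch.broadcast_shapes((3, 1), (1, 3))
assert batch + (num_particles, 7) == (3, 3, num_particles, 7)
```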
Seems like a good way of doing it to me then. Are there any places where this could cause problems? And is there a straightforward way of implementing this relatively easily for every element, such that we can give a general guideline when implementing new elements: do it like this, and then you can be sure it will follow those rules?
The main requirement is that you need to use PyTorch operations that are broadcastable. For example, most of the changes that I needed to make (so far) are replacing the boolean indexing that you have been using to avoid div0 errors with

```python
gamma = energy / electron_mass_eV.to(device=device, dtype=dtype)
igamma2 = torch.where(gamma == 0.0, 0.0, 1 / gamma**2)
```

I'll update this list as I go through the other elements.
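As a quick sanity check (pure PyTorch, with a made-up batched `energy` tensor and an electron mass value included only for illustration), `torch.where` preserves the full broadcast shape, whereas boolean indexing would flatten the selected elements:

```python
import torch

electron_mass_eV = torch.tensor(510998.95)  # assumed value, for illustration
energy = torch.tensor([[0.0, 1e9], [5e8, 2e9]])  # arbitrary batched energies

gamma = energy / electron_mass_eV
# torch.where evaluates elementwise over the full (broadcast) shape, so no
# boolean indexing is needed to guard against division by zero. (Note that
# 1 / gamma**2 is still computed everywhere; the zero entries are replaced.)
igamma2 = torch.where(gamma == 0.0, 0.0, 1 / gamma**2)

assert igamma2.shape == gamma.shape
assert torch.isfinite(igamma2).all()
```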
Note that I added a utility function to calculate the relativistic factors that are used throughout different elements.
I fixed all the issues I could spot just looking at the code changes and did a bit of cleanup along the way. It turned out to be quite a bit more than I expected. Either way, I think as it is now, we can merge this PR.
```diff
-px = px + voltage * torch.sin(phase)
+# TODO: Assigning px to px is really bad practice and should be separated into
+# two separate variables
+px = px + voltage.unsqueeze(-1) * torch.sin(phase)
```
Why do we need `.unsqueeze(-1)` here (I saw this also in the quadrupole Bmad tracking)? I'm probably just confused.

W.r.t. the variable naming, it can probably be `px_new` or `px_out`.
Because `px` has shape `(vector, num_particles)` and `voltage` has shape `(vector,)` ... for PyTorch broadcasting, the vector dimension of `voltage` basically has to be pushed to the left, and that's what the `unsqueeze` does.
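A minimal sketch of that alignment (the shapes here are made up for illustration):

```python
import torch

vector, num_particles = 3, 5
px = torch.zeros(vector, num_particles)  # shape (vector, num_particles)
voltage = torch.tensor([1.0, 2.0, 3.0])  # shape (vector,)
phase = torch.tensor(0.5)

# Without the unsqueeze, the (vector,) dimension of voltage would align with
# num_particles (trailing-dimension alignment); unsqueeze(-1) turns it into
# (vector, 1), which broadcasts against (vector, num_particles) as intended.
px_new = px + voltage.unsqueeze(-1) * torch.sin(phase)

assert px_new.shape == (vector, num_particles)
```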
All the Bmad-X code needs refactoring, including that variable name ... I've had to keep myself from doing that in this PR.
I see! Then we should merge this first and do the refactoring later.
Description
Adds the ability to do automatic broadcasting according to PyTorch conventions, without the use of a broadcast function.
PyTorch broadcast semantics rely on the following rules: two tensors are broadcastable if, comparing their shapes from the trailing dimension backwards, each pair of dimension sizes is either equal, one of them is 1, or one of them does not exist.
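These rules can be illustrated with a couple of plain tensor operations:

```python
import torch

# Dimensions are compared from the trailing end: equal sizes match, a size
# of 1 is stretched, and missing leading dimensions are treated as size 1.
a = torch.ones(3, 1)
b = torch.ones(1, 4)
assert (a + b).shape == (3, 4)

c = torch.ones(2, 1, 1)
assert (c + a).shape == (2, 3, 1)

# Incompatible trailing dimensions raise a RuntimeError.
try:
    torch.ones(3) + torch.ones(2)
except RuntimeError:
    pass
```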
This PR allows the specification of arbitrary tensor sizes for beamline attributes and particle beam attributes as long as they abide by the rules above. For example, a batched quadrupole and beam produce an outgoing beam with

`outgoing.particles.shape == (3, 2, 100_000, 7)`

where the beam size and quadrupole length are changed along each batch dimension (3, 2) and the quadrupole strength changes along the first dimension.

Note: The current implementation requires the same tensor dimensions for specifying beam attributes using `from_parameters`.

Motivation and Context
Addresses #138
Types of changes

Checklist

- `flake8` passes (required).
- `pytest` tests pass (required).
- I have run `pytest` on a machine with a CUDA GPU and made sure all tests pass (required).

Note: We are using a maximum length of 88 characters per line.