
Enable automatic broadcasting #208

Merged (121 commits into desy-ml:master) on Oct 1, 2024

Conversation

roussel-ryan (Contributor) commented Jul 2, 2024

Description

Adds the ability to do automatic broadcasting according to PyTorch conventions without the use of broadcast function.

PyTorch broadcast semantics rely on the following rules

  • Each tensor has at least one dimension.
  • When iterating over the dimension sizes, starting at the trailing dimension, the dimension sizes must be equal, one of them must be 1, or one of them must not exist.
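
These rules can be sketched in plain Python with a small helper (a hypothetical broadcast_shape function mirroring torch.broadcast_shapes; it is not part of this PR, just an illustration of the rules):

```python
def broadcast_shape(*shapes):
    """Compute the broadcast result shape per the PyTorch rules above.

    Hypothetical helper for illustration; torch.broadcast_shapes does this
    for real tensor shapes.
    """
    ndim = max(len(shape) for shape in shapes)
    result = []
    for i in range(1, ndim + 1):  # walk dimensions from the trailing end
        # Collect the non-1 sizes present at this (right-aligned) position.
        sizes = {shape[-i] for shape in shapes if len(shape) >= i and shape[-i] != 1}
        if len(sizes) > 1:
            raise ValueError(f"incompatible sizes at dimension -{i}: {sorted(sizes)}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))
```

For instance, broadcast_shape((3, 2), (3, 1)) gives (3, 2), matching the batch shape produced by the quadrupole example in this PR description.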

This PR allows the specification of arbitrary tensor sizes for beamline attributes and particle beam attributes as long as they abide by the rules above. For example,

quadrupole = cheetah.Quadrupole(
    length=torch.tensor([[0.2, 0.25], [0.3, 0.35], [0.4, 0.45]]), # shape = (3,2)
    k1=torch.tensor([[4.2], [4.3], [4.4]]), # shape = (3,1) 
)
incoming = cheetah.ParticleBeam.from_parameters(
    num_particles=100_000,
    sigma_x=torch.tensor([[1e-5, 2e-5], [2e-5, 3e-5], [3e-5, 4e-5]]), # shape = (3,2)
)

outgoing = quadrupole.track(incoming)

produces an outgoing beam with outgoing.particles.shape == (3, 2, 100_000, 7), where the beam size and quadrupole length vary along both batch dimensions (3, 2) and the quadrupole strength varies only along the first dimension.

Note: The current implementation requires that all beam attributes passed to from_parameters have the same tensor dimensions.

Motivation and Context

Addresses #138

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist

  • I have updated the changelog accordingly (required).
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.
  • I have reformatted the code and checked that formatting passes (required).
  • I have fixed all issues found by flake8 (required).
  • I have ensured that all pytest tests pass (required).
  • I have run pytest on a machine with a CUDA GPU and made sure all tests pass (required).
  • I have checked that the documentation builds (required).

Note: We are using a maximum length of 88 characters per line

roussel-ryan (Contributor, Author)

If I can get feedback on the changes made here I can go ahead and make changes for other elements

jank324 (Member) commented Jul 9, 2024

> If I can get feedback on the changes made here I can go ahead and make changes for other elements

So, I managed to break it. I really don't know what will be the best way to implement this automatic broadcasting. Ideally, we would find a module in PyTorch itself that does something similar and copy what it does.

Until then, maybe we should follow a test-driven approach, and implement tests that use Cheetah the way we would like to use it and then fix them.

@jank324 jank324 added the enhancement New feature or request label Jul 9, 2024
@jank324 jank324 mentioned this pull request Jul 10, 2024
7 tasks
roussel-ryan (Contributor, Author) commented Jul 10, 2024

This is the current form of the dipole test that passes in my version of the repo. I can correct the rest of the tests if this looks ok to you all. @jank324, does this give you a clearer idea of what we meant by batching? Also, what are your thoughts on using pytest to check that errors are raised? Is there something else already being used?

import pytest
import torch

from cheetah import Dipole, Drift, ParameterBeam, ParticleBeam, Segment


def test_dipole_off():
    """
    Test that a dipole with angle=0 behaves still like a drift.
    """
    dipole = Dipole(length=torch.tensor(1.0), angle=torch.tensor(0.0))
    drift = Drift(length=torch.tensor(1.0))
    incoming_beam = ParameterBeam.from_parameters(
        sigma_xp=torch.tensor(2e-7), sigma_yp=torch.tensor(2e-7)
    )
    outbeam_dipole_off = dipole(incoming_beam)
    outbeam_drift = drift(incoming_beam)

    dipole.angle = torch.tensor(1.0, device=dipole.angle.device)
    outbeam_dipole_on = dipole(incoming_beam)

    assert dipole.name is not None
    assert torch.allclose(outbeam_dipole_off.sigma_x, outbeam_drift.sigma_x)
    assert not torch.allclose(outbeam_dipole_on.sigma_x, outbeam_drift.sigma_x)


def test_dipole_batched_execution():
    """
    Test that a dipole with batch dimensions behaves as expected.
    """
    incoming = ParticleBeam.from_parameters(
        num_particles=torch.tensor(100),
        energy=torch.tensor(1e9),
        mu_x=torch.tensor(1e-5),
    )

    # test batching to generate 3 beam lines
    segment = Segment([
        Dipole(
            length=torch.tensor([0.5, 0.5, 0.5]),
            angle=torch.tensor([0.1, 0.2, 0.1]),
        ),
        Drift(length=torch.tensor(0.5)),
])
    outgoing = segment(incoming)

    assert outgoing.particles.shape == torch.Size([3, 100, 7])
    assert outgoing.mu_x.shape == torch.Size([3])

    # Check that dipole with same bend angle produce same output
    assert torch.allclose(outgoing.particles[0], outgoing.particles[2])

    # Check different angles do make a difference
    assert not torch.allclose(outgoing.particles[0], outgoing.particles[1])

    # test batching to generate 18 beamlines
    segment = Segment([
        Dipole(
            length=torch.tensor([0.5, 0.5, 0.5]).reshape(3, 1),
            angle=torch.tensor([0.1, 0.2, 0.1]).reshape(1, 3),
        ),
        Drift(length=torch.tensor([0.5, 1.0]).reshape(2, 1, 1)),
])
    outgoing = segment(incoming)
    assert outgoing.particles.shape == torch.Size([2, 3, 3, 100, 7])

    # test improper batching -- this does not obey torch broadcasting rules
    segment = Segment([
        Dipole(
            length=torch.tensor([0.5, 0.5, 0.5]).reshape(3, 1),
            angle=torch.tensor([0.1, 0.2, 0.1]).reshape(1, 3),
        ),
        Drift(length=torch.tensor([0.5, 1.0]).reshape(2, 1)),
])
    with pytest.raises(RuntimeError):
        segment(incoming)

jank324 (Member) commented Jul 10, 2024

> Also what are the thoughts on using pytest to check that errors are raised? Is there something else already being used?

I would say, that's how you are supposed to do it? Or is there a reason you wouldn't?

I think something very similar is already done somewhere in the Cheetah tests, but I'm not sure where at this point.

roussel-ryan (Contributor, Author)

Ok, that's how I would do it. I wasn't sure whether pytest was already being used, and I didn't want to add a dependency if you had something else you wanted to use.

jank324 (Member) commented Jul 10, 2024

So, for my understanding:

incoming = ParticleBeam.from_parameters(
    num_particles=torch.tensor(100),
    energy=torch.tensor(1e9),
    mu_x=torch.tensor(1e-5),
)

# Gives me the "product" of all settings
Dipole(
    length=torch.tensor([0.5, 0.5, 0.5]).reshape(3, 1),
    angle=torch.tensor([0.1, 0.2, 0.1]).reshape(1, 3),
)
...
assert outgoing.particles.shape == (3, 3, num_particles, 7)  # passes

# Matches the settings one-to-one
Dipole(
    length=torch.tensor([0.5, 0.5, 0.5]),
    angle=torch.tensor([0.1, 0.2, 0.1]),
)
...
assert outgoing.particles.shape == (3, num_particles, 7)  # passes

Correct?

roussel-ryan (Contributor, Author)

Yes, that is correct

jank324 (Member) commented Jul 10, 2024

> Yes, that is correct

Seems like a good way of doing it to me then. Are there any places this could cause problems? And is there a straightforward way of implementing this relatively easily for every element ... such that we can give a general guideline when implementing new elements: do it like ..., and then you can be sure it will follow those rules?

roussel-ryan (Contributor, Author) commented Jul 10, 2024

The main requirement is that you need to use PyTorch operations that are broadcastable. Most of the changes I have needed to make (so far) involve replacing the boolean indexing that was being used to avoid division-by-zero errors with torch.where operations. For example,

gamma = energy / electron_mass_eV.to(device=device, dtype=dtype)
igamma2 = torch.where(gamma == 0.0, 0.0, 1 / gamma**2)

I'll update this list as I go through the other elements
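
A minimal, self-contained sketch of why torch.where composes with broadcasting while boolean indexing does not (the tensor values here are illustrative, not from the PR):

```python
import torch

# Batched Lorentz factors with a zero entry; any batch shape works.
gamma = torch.tensor([[0.0, 2.0], [4.0, 10.0]])

# Boolean indexing (e.g. igamma2[gamma != 0] = ...) pins the code to one
# concrete shape, whereas torch.where is an elementwise op that broadcasts
# like any other. The inf produced by 1/0 is discarded by the selection.
igamma2 = torch.where(gamma == 0.0, torch.tensor(0.0), 1 / gamma**2)
```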

roussel-ryan (Contributor, Author)

Note that I added a utility function to calculate the relativistic factors that are used throughout different elements
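
A sketch of what such a utility could look like (the function name, signature, and constant handling here are assumptions for illustration, not necessarily what the PR implements):

```python
import torch

ELECTRON_MASS_EV = torch.tensor(510_998.95)  # electron rest energy in eV

def compute_relativistic_factors(energy: torch.Tensor):
    """Return (gamma, igamma2, beta) for a batched energy tensor.

    Hypothetical sketch: every operation used here broadcasts, so energy
    may carry arbitrary batch dimensions.
    """
    gamma = energy / ELECTRON_MASS_EV
    # torch.where instead of boolean indexing, so batch shapes broadcast.
    igamma2 = torch.where(gamma == 0.0, torch.tensor(0.0), 1 / gamma**2)
    beta = torch.sqrt(1 - igamma2)
    return gamma, igamma2, beta
```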

@roussel-ryan roussel-ryan marked this pull request as ready for review July 15, 2024 14:44
@jank324 jank324 self-requested a review July 15, 2024 19:12
@jank324 jank324 marked this pull request as ready for review October 1, 2024 12:13
jank324 (Member) commented Oct 1, 2024

I fixed all the issues I could just see looking at the code changes and did a bit of cleanup on the way. It turned out to be quite a bit more than I expected.

Either way, I think as it is now, we can merge this PR.

-px = px + voltage * torch.sin(phase)
+# TODO: Assigning px to px is really bad practice and should be separated into
+# two separate variables
+px = px + voltage.unsqueeze(-1) * torch.sin(phase)
cr-xu (Member) commented Oct 1, 2024

Why do we need .unsqueeze(-1) here (I saw this also in the quadrupole bmad tracking)? I'm probably just confused.

W.r.t. the variable naming, it can probably be px_new or px_out

Member

Because px has shape (vector, num_particles) and voltage has shape (vector,) ... for PyTorch broadcasting, the vector dimension of voltage basically has to be pushed to the left ... and that's what the unsqueeze does.

All the Bmad-X code needs refactoring, including that variable name ... I've had to keep myself from doing that in this PR.
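
That shape arithmetic can be checked in isolation (the values here are illustrative, not from the Bmad-X code):

```python
import torch

px = torch.zeros(3, 5)                      # (vector=3, num_particles=5)
voltage = torch.tensor([10.0, 20.0, 30.0])  # (vector=3,)
phase = torch.tensor(0.5)

# Broadcasting aligns trailing dimensions, so voltage needs a trailing
# singleton axis: (3,) -> (3, 1), which then broadcasts against (3, 5),
# pairing each setting with its own block of particles.
px_new = px + voltage.unsqueeze(-1) * torch.sin(phase)
# Without the unsqueeze, (3, 5) + (3,) would raise: trailing sizes 5 vs 3.
```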

Member

I see! Then we should merge this first and do the refactoring later.

tests/test_kde.py: review comment resolved (outdated)
@jank324 jank324 merged commit 8a38b63 into desy-ml:master Oct 1, 2024
8 checks passed
@cr-xu cr-xu mentioned this pull request Oct 2, 2024