Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix JAX error with FourierCurrentPotentialField in Flux objectives #1002

Merged
merged 30 commits into from
Jul 5, 2024

Conversation

dpanici
Copy link
Collaborator

@dpanici dpanici commented Apr 19, 2024

On master, an optimization with a FourierCurrentPotentialField and the QuadraticFlux objective would fail. This happens because in QuadraticFlux.compute, field.compute_magnetic_field is called. If the field needs transforms to evaluate, then these transforms will be created on the fly if they are not provided, resulting in an error.

This PR fixes that by adding jitable=True to the .compute call inside .compute_magnetic_field

It also adds transform as an argument to MagneticField.compute_magnetic_field methods, in preparation for when #1079 is resolved and objectives can pre-compute and pass in transform objects for all magnetic fields easily.

This PR kicks the can down the road of making get_transforms work with MagneticField.compute_magnetic_field objects (to allow us to pre-compute the transforms used for magnetic field computation #1079 ) and instead opts for the simple fix.

Copy link
Contributor

github-actions bot commented Apr 19, 2024

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        | 
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
 test_build_transform_fft_lowres         |     +7.06 +/- 8.20     | +3.58e-02 +/- 4.16e-02 |  5.43e-01 +/- 4.1e-02  |  5.07e-01 +/- 6.4e-03  |
 test_build_transform_fft_midres         |     +6.67 +/- 3.29     | +3.96e-02 +/- 1.95e-02 |  6.33e-01 +/- 1.7e-02  |  5.93e-01 +/- 9.9e-03  |
 test_build_transform_fft_highres        |     +3.78 +/- 3.51     | +3.73e-02 +/- 3.46e-02 |  1.02e+00 +/- 1.5e-02  |  9.87e-01 +/- 3.1e-02  |
 test_equilibrium_init_lowres            |     +1.75 +/- 2.66     | +6.43e-02 +/- 9.76e-02 |  3.73e+00 +/- 9.7e-02  |  3.67e+00 +/- 9.8e-03  |
 test_equilibrium_init_medres            |     +1.82 +/- 2.17     | +7.53e-02 +/- 8.98e-02 |  4.22e+00 +/- 8.9e-02  |  4.14e+00 +/- 1.0e-02  |
 test_equilibrium_init_highres           |     +1.21 +/- 1.88     | +6.71e-02 +/- 1.04e-01 |  5.60e+00 +/- 8.5e-02  |  5.54e+00 +/- 6.0e-02  |
 test_objective_compile_dshape_current   |     +0.64 +/- 5.96     | +2.43e-02 +/- 2.26e-01 |  3.81e+00 +/- 2.2e-01  |  3.79e+00 +/- 1.9e-02  |
 test_objective_compile_atf              |     +1.09 +/- 3.08     | +8.93e-02 +/- 2.54e-01 |  8.32e+00 +/- 1.5e-01  |  8.23e+00 +/- 2.0e-01  |
 test_objective_compute_dshape_current   |     -1.81 +/- 4.61     | -2.31e-05 +/- 5.87e-05 |  1.25e-03 +/- 2.9e-05  |  1.27e-03 +/- 5.1e-05  |
 test_objective_compute_atf              |     +0.42 +/- 6.38     | +1.76e-05 +/- 2.70e-04 |  4.25e-03 +/- 2.3e-04  |  4.23e-03 +/- 1.5e-04  |
 test_objective_jac_dshape_current       |     -1.67 +/- 11.51    | -6.06e-04 +/- 4.18e-03 |  3.57e-02 +/- 2.3e-03  |  3.64e-02 +/- 3.5e-03  |
 test_objective_jac_atf                  |     +3.58 +/- 2.49     | +6.67e-02 +/- 4.63e-02 |  1.93e+00 +/- 3.4e-02  |  1.86e+00 +/- 3.2e-02  |
 test_perturb_1                          |     +0.51 +/- 0.68     | +6.71e-02 +/- 8.89e-02 |  1.31e+01 +/- 7.4e-02  |  1.30e+01 +/- 4.9e-02  |
 test_perturb_2                          |     +0.74 +/- 1.23     | +1.33e-01 +/- 2.22e-01 |  1.81e+01 +/- 1.3e-01  |  1.80e+01 +/- 1.8e-01  |
 test_proximal_jac_atf                   |     -0.61 +/- 1.44     | -4.48e-02 +/- 1.05e-01 |  7.28e+00 +/- 6.0e-02  |  7.32e+00 +/- 8.7e-02  |
 test_proximal_freeb_compute             |     -1.67 +/- 0.74     | -3.00e-03 +/- 1.32e-03 |  1.76e-01 +/- 1.0e-03  |  1.79e-01 +/- 8.2e-04  |
 test_proximal_freeb_jac                 |     -0.12 +/- 1.09     | -8.66e-03 +/- 7.98e-02 |  7.31e+00 +/- 5.9e-02  |  7.32e+00 +/- 5.4e-02  |
 test_solve_fixed_iter                   |     -0.38 +/- 9.08     | -5.59e-02 +/- 1.34e+00 |  1.47e+01 +/- 9.9e-01  |  1.48e+01 +/- 9.0e-01  |

Copy link

codecov bot commented Apr 19, 2024

Codecov Report

Attention: Patch coverage is 78.94737% with 4 lines in your changes missing coverage. Please review.

Project coverage is 94.97%. Comparing base (0a6b995) to head (63e1cd9).

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1002      +/-   ##
==========================================
- Coverage   94.98%   94.97%   -0.01%     
==========================================
  Files          87       87              
  Lines       21749    21762      +13     
==========================================
+ Hits        20658    20669      +11     
- Misses       1091     1093       +2     
Files Coverage Δ
desc/geometry/core.py 95.91% <100.00%> (ø)
desc/objectives/_coils.py 99.12% <ø> (ø)
desc/coils.py 96.88% <85.71%> (+0.02%) ⬆️
desc/magnetic_fields/_current_potential.py 98.83% <75.00%> (-0.58%) ⬇️
desc/magnetic_fields/_core.py 96.40% <71.42%> (-0.27%) ⬇️

... and 1 file with indirect coverage changes

Copy link
Member

@f0uriest f0uriest left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the reason coils worked while the current potential doesnt is that by default the coil classes create jitable transforms in their compute method, since they're only 1d its not really worth it to try to do the fft.

I'd recommend adding the new classes to the existing get_transforms logic, I think most of it should "just work" assuming attributes are named correctly. (might need some logic somewhere for tree-like coilsets etc, but do-able.

@@ -753,13 +753,28 @@ def build(self, use_jit=True, verbose=1):
Bplasma = compute_B_plasma(
eq, eval_grid, self._source_grid, normal_only=True
)
field = self._field
if hasattr(field, "Phi_mn"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it be better to get this all to work with the existing desc.compute.utils.get_transforms helper? That would work for coils etc as well.

@dpanici dpanici marked this pull request as draft April 25, 2024 15:12
@dpanici
Copy link
Collaborator Author

dpanici commented Apr 30, 2024

I think the reason coils worked while the current potential doesnt is that by default the coil classes create jitable transforms in their compute method, since they're only 1d its not really worth it to try to do the fft.

I'd recommend adding the new classes to the existing get_transforms logic, I think most of it should "just work" assuming attributes are named correctly. (might need some logic somewhere for tree-like coilsets etc, but do-able.

The problem is less the classes and more that not every magnetic field class needs the same keys to calculate the magnetic field. Coils need "I" and "x", the current potential fields need "K", toroidal fields dont need any keys as they just have their own compute magnetic field function.

@dpanici
Copy link
Collaborator Author

dpanici commented May 1, 2024

add case for coils and 1D transforms (by calling get_transforms in build always asking for "x" (if is a curve class or CoilSet) and maybe also "K" if it is a current potential field.

Also add in the current potential field compute, a jitable=True (as if there is a current potential field in a SumMagneticField it might not work correctly...)

tests/test_optimizer.py Outdated Show resolved Hide resolved
tests/test_optimizer.py Outdated Show resolved Hide resolved
@dpanici
Copy link
Collaborator Author

dpanici commented Jun 10, 2024

@f0uriest I don't exactly remember the issue with the current method/how putting the logic in get_transforms would improve it.

We still would need to call get_transforms with the correct grid based off of what the underlying object is (which would still require a check in the build method to see if it is a coil or a field), and to ask for the correct keys ("x" and maybe "K" as well if it is a current potential field).

I think I also found in this PR a fix to something causing a JAX error in a few places when current potential fields are being optimized, so if possible I'd like to get this one in. Unless there is a better case for putting more logic into get_transforms, which I think should wait until we make "B" a data index quantity for MagneticField objects and then can go through that route to get all the proper transforms etc based off of the specific MagneticField parameterization

@dpanici dpanici marked this pull request as ready for review June 10, 2024 17:48
desc/coils.py Show resolved Hide resolved
desc/objectives/_coils.py Outdated Show resolved Hide resolved
desc/objectives/_coils.py Outdated Show resolved Hide resolved


@pytest.mark.unit
def test_tor_flux_with_surface_current_field():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could these tests be one test with an inner test() function or is it better to keep them separate?

desc/coils.py Show resolved Hide resolved
@@ -644,7 +647,23 @@ def _compute_magnetic_field_from_CurrentPotentialField(
# compute surface current, and store grid quantities
# needed for integration in class
# TODO: does this have to be xyz, or can it be computed in rpz as well?
data = field.compute(["K", "x"], grid=source_grid, basis="xyz", params=params)
if not params and not transforms:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why this conditional?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To avoid calling field.compute inside a compute function of an objective, as there is some logic there that jit does not like

@dpanici
Copy link
Collaborator Author

dpanici commented Jun 19, 2024

Equinox for runtime error checking in jitted functions

@dpanici
Copy link
Collaborator Author

dpanici commented Jun 19, 2024

@f0uriest check this again for placing the code in a util function

@dpanici dpanici requested review from f0uriest, ddudt and kianorr and removed request for f0uriest, ddudt and YigitElma July 2, 2024 14:36
@dpanici dpanici changed the title Add transform pre-computation for FourierCurrentPotentialField in Flux objectives Fix JAX error with FourierCurrentPotentialField in Flux objectives Jul 2, 2024
rahulgaur104
rahulgaur104 previously approved these changes Jul 2, 2024
@rahulgaur104
Copy link
Collaborator

Merge after increasing the coverage.

kianorr
kianorr previously approved these changes Jul 2, 2024
@dpanici dpanici dismissed stale reviews from kianorr and rahulgaur104 via 07d110b July 3, 2024 00:02
tests/test_optimizer.py Show resolved Hide resolved


@pytest.mark.unit
def test_quad_flux_with_surface_current_field():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the plan to remove/combine this test with #1025 ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure, as that one will add coil tests while this one is using surface current fields. If one test on that branch involves a sum magnetic field with coils and a surface current field then we can remove this in favor of that one

@dpanici dpanici requested a review from daniel-dudt July 4, 2024 15:33
@@ -1272,7 +1272,10 @@ def compute(self, field_params, constants=None):

# B_ext is not pre-computed because field is not fixed
B_ext = constants["field"].compute_magnetic_field(
x, source_grid=constants["field_grid"], basis="rpz", params=field_params
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No real change to this file, right? You just like to make the code longer to bother me?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I blame my IDE

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but yea no change here just formatting

kianorr added a commit that referenced this pull request Jul 4, 2024
- doesn't pass due to jax bug, maybe PR #1002 will fix it?
@kianorr kianorr mentioned this pull request Jul 4, 2024
@dpanici dpanici merged commit f021ce2 into master Jul 5, 2024
17 of 18 checks passed
@dpanici dpanici deleted the dp/quadratic-flux-transform branch July 5, 2024 13:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants