Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

plot elongation #807

Merged
merged 12 commits into from
Jan 17, 2024
1 change: 0 additions & 1 deletion desc/compute/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,6 @@ def _R0_over_a(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="z",
data=["rho", "sqrt(g)", "g_tt", "R"],
grid_type="quad",
)
def _a_major_over_a_minor(params, transforms, profiles, data, **kwargs):
max_rho = jnp.max(data["rho"])
Expand Down
77 changes: 55 additions & 22 deletions desc/equilibrium/equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,30 +816,32 @@
dep0d = [
dep
for dep in deps
if (
(data_index[p][dep]["coordinates"] == "")
or (data_index[p][dep]["grid_type"] == "quad")
)
and (dep not in data)
if (data_index[p][dep]["coordinates"] == "") and (dep not in data)
]
dep1d = [
dep1dr = [
dep
for dep in deps
if (data_index[p][dep]["coordinates"] == "r") and (dep not in data)
]
dep1dz = [
dep
for dep in deps
if (data_index[p][dep]["coordinates"] == "z") and (dep not in data)
]

# whether we need to calculate 0d or 1d quantities on a special grid
calc0d = bool(len(dep0d))
calc1d = bool(len(dep1d))
calc1dr = bool(len(dep1dr))
calc1dz = bool(len(dep1dz))
if ( # see if the grid we're already using will work for desired qtys
(grid.L >= self.L_grid)
and (grid.M >= self.M_grid)
and (grid.N >= self.N_grid)
):
if isinstance(grid, QuadratureGrid):
calc0d = calc1d = False
calc0d = calc1dr = calc1dz = False
if isinstance(grid, LinearGrid):
calc1d = False
calc1dr = calc1dz = False

if calc0d and override_grid:
grid0d = QuadratureGrid(self.L_grid, self.M_grid, self.N_grid, self.NFP)
Expand All @@ -856,32 +858,63 @@
data0d = {key: val for key, val in data0d.items() if key in dep0d}
data.update(data0d)

if calc1d and override_grid:
grid1d = LinearGrid(
if calc1dr and override_grid:
grid1dr = LinearGrid(
rho=grid.nodes[grid.unique_rho_idx, 0],
M=self.M_grid,
N=self.N_grid,
NFP=self.NFP,
sym=self.sym,
)
# Todo: Pass in data0d as a seed once there are 1d quantities that
# depend on 0d quantities in data_index.
data1d = compute_fun(
# TODO: Pass in data0d as a seed once there are 1d quantities that
# depend on 0d quantities in data_index.
data1dr = compute_fun(
self,
dep1d,
dep1dr,
params=params,
transforms=get_transforms(dep1d, obj=self, grid=grid1d, **kwargs),
profiles=get_profiles(dep1d, obj=self, grid=grid1d),
transforms=get_transforms(dep1dr, obj=self, grid=grid1dr, **kwargs),
profiles=get_profiles(dep1dr, obj=self, grid=grid1dr),
data=None,
**kwargs,
)
# need to make this data broadcast with the data on the original grid
data1d = {
key: grid.expand(grid1d.compress(val))
for key, val in data1d.items()
if key in dep1d
data1dr = {
key: grid.expand(
grid1dr.compress(val, surface_label="rho"), surface_label="rho"
)
for key, val in data1dr.items()
if key in dep1dr
}
data.update(data1dr)

if calc1dz and override_grid:
grid1dz = LinearGrid(

Check warning on line 891 in desc/equilibrium/equilibrium.py

View check run for this annotation

Codecov / codecov/patch

desc/equilibrium/equilibrium.py#L891

Added line #L891 was not covered by tests
zeta=grid.nodes[grid.unique_zeta_idx, 2],
L=self.L_grid,
M=self.M_grid,
NFP=grid.NFP, # ex: self.NFP>1 but grid.NFP=1 for plot_3d
sym=self.sym,
)
# TODO: Pass in data0d as a seed once there are 1d quantities that
# depend on 0d quantities in data_index.
data1dz = compute_fun(

Check warning on line 900 in desc/equilibrium/equilibrium.py

View check run for this annotation

Codecov / codecov/patch

desc/equilibrium/equilibrium.py#L900

Added line #L900 was not covered by tests
self,
dep1dz,
params=params,
transforms=get_transforms(dep1dz, obj=self, grid=grid1dz, **kwargs),
profiles=get_profiles(dep1dz, obj=self, grid=grid1dz),
data=None,
**kwargs,
)
# need to make this data broadcast with the data on the original grid
data1dz = {

Check warning on line 910 in desc/equilibrium/equilibrium.py

View check run for this annotation

Codecov / codecov/patch

desc/equilibrium/equilibrium.py#L910

Added line #L910 was not covered by tests
key: grid.expand(
grid1dz.compress(val, surface_label="zeta"), surface_label="zeta"
)
for key, val in data1dz.items()
if key in dep1dz
}
data.update(data1d)
data.update(data1dz)

Check warning on line 917 in desc/equilibrium/equilibrium.py

View check run for this annotation

Codecov / codecov/patch

desc/equilibrium/equilibrium.py#L917

Added line #L917 was not covered by tests

# TODO: we can probably reduce the number of deps computed here if some are only
# needed as inputs for 0d and 1d qtys, unless the user asks for them
Expand Down
Binary file modified tests/baseline/test_1d_dpdr.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/baseline/test_1d_elongation.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/baseline/test_1d_iota.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/baseline/test_1d_iota_radial.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/baseline/test_1d_logpsi.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/baseline/test_1d_p.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/inputs/master_compute_data.pkl
Binary file not shown.
3 changes: 1 addition & 2 deletions tests/test_compute_funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,7 @@ def test_elongation():
eq1 = Equilibrium() # elongation = 1
eq2 = Equilibrium(surface=surf2) # elongation = 2
eq3 = Equilibrium(surface=surf3) # elongation = 3
rho = np.linspace(0, 1, 128)
grid = LinearGrid(M=eq3.M_grid, N=eq3.N_grid, NFP=eq3.NFP, sym=eq3.sym, rho=rho)
grid = LinearGrid(L=5, M=2 * eq3.M_grid, N=eq3.N_grid, NFP=eq3.NFP, sym=eq3.sym)
data1 = eq1.compute(["a_major/a_minor"], grid=grid)
data2 = eq2.compute(["a_major/a_minor"], grid=grid)
data3 = eq3.compute(["a_major/a_minor"], grid=grid)
Expand Down
37 changes: 23 additions & 14 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,28 @@ def test_kwarg_future_warning(DummyStellarator):


@pytest.mark.unit
@pytest.mark.solve
@pytest.mark.mpl_image_compare(remove_text=True, tolerance=tol_1d)
def test_1d_p(SOLOVEV):
def test_1d_p():
"""Test plotting 1d pressure profile."""
eq = load(load_from=str(SOLOVEV["desc_h5_path"]))[-1]
eq = get("SOLOVEV")
fig, ax, data = plot_1d(eq, "p", figsize=(4, 4), return_data=True)
assert "p" in data.keys()
return fig


@pytest.mark.unit
@pytest.mark.mpl_image_compare(remove_text=True, tolerance=tol_1d)
def test_1d_elongation():
f0uriest marked this conversation as resolved.
Show resolved Hide resolved
"""Test plotting 1d elongation as a function of toroidal angle."""
eq = get("precise_QA")
grid = LinearGrid(N=20, NFP=eq.NFP)
fig, ax, data = plot_1d(
eq, "a_major/a_minor", grid=grid, figsize=(4, 4), return_data=True
)
assert "a_major/a_minor" in data.keys()
return fig


@pytest.mark.unit
def test_1d_fsa_consistency():
"""Test that plot_1d uses 2d grid to compute quantities with surface averages."""
Expand All @@ -101,11 +113,10 @@ def test(name, with_sqrt_g=True, grid=None):


@pytest.mark.unit
@pytest.mark.solve
@pytest.mark.mpl_image_compare(remove_text=True, tolerance=tol_1d)
def test_1d_dpdr(DSHAPE_current):
def test_1d_dpdr():
"""Test plotting 1d pressure derivative."""
eq = load(load_from=str(DSHAPE_current["desc_h5_path"]))[-1]
eq = get("DSHAPE_current")
fig, ax, data = plot_1d(eq, "p_r", figsize=(4, 4), return_data=True)
assert "p_r" in data.keys()
return fig
Expand All @@ -114,9 +125,9 @@ def test_1d_dpdr(DSHAPE_current):
@pytest.mark.unit
@pytest.mark.solve
@pytest.mark.mpl_image_compare(remove_text=True, tolerance=tol_1d)
def test_1d_iota(DSHAPE_current):
def test_1d_iota():
"""Test plotting 1d rotational transform."""
eq = load(load_from=str(DSHAPE_current["desc_h5_path"]))[-1]
eq = get("DSHAPE_current")
grid = LinearGrid(rho=0.5, theta=100, zeta=0.0)
fig, ax, data = plot_1d(eq, "iota", grid=grid, figsize=(4, 4), return_data=True)
assert "theta" in data.keys()
Expand All @@ -125,23 +136,21 @@ def test_1d_iota(DSHAPE_current):


@pytest.mark.unit
@pytest.mark.solve
@pytest.mark.mpl_image_compare(remove_text=True, tolerance=tol_1d)
def test_1d_iota_radial(DSHAPE_current):
def test_1d_iota_radial():
"""Test plotting 1d rotational transform."""
eq = load(load_from=str(DSHAPE_current["desc_h5_path"]))[-1]
eq = get("DSHAPE_current")
fig, ax, data = plot_1d(eq, "iota", figsize=(4, 4), return_data=True)
assert "rho" in data.keys()
assert "iota" in data.keys()
return fig


@pytest.mark.unit
@pytest.mark.solve
@pytest.mark.mpl_image_compare(remove_text=True, tolerance=tol_1d)
def test_1d_logpsi(DSHAPE_current):
def test_1d_logpsi():
"""Test plotting 1d flux function with log scale."""
eq = load(load_from=str(DSHAPE_current["desc_h5_path"]))[-1]
eq = get("DSHAPE_current")
fig, ax, data = plot_1d(eq, "psi", log=True, figsize=(4, 4), return_data=True)
ax.set_ylim([1e-5, 1e0])
assert "rho" in data.keys()
Expand Down