Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HermiteSplineProfile to improve accuracy of coordinate mapping #1199

Merged
merged 12 commits into from
Aug 20, 2024
2 changes: 1 addition & 1 deletion desc/compute/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def get_data_deps(keys, obj, has_axis=False, basis="rpz", data=None):

Returns
-------
deps : list of str
deps : list[str]
Names of quantities needed to compute key.

"""
Expand Down
17 changes: 11 additions & 6 deletions desc/equilibrium/coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,19 +138,19 @@

# do surface average to get iota once
if "iota" in profiles and profiles["iota"] is None:
profiles["iota"] = eq.get_profile("iota", params=params)
profiles["iota"] = eq.get_profile(["iota", "iota_r"], params=params)
params["i_l"] = profiles["iota"].params

@functools.partial(jit, static_argnums=1)
def compute(y, basis):
grid = Grid(y, sort=False, jitable=True)
data = {}
if "iota" in deps:
data["iota"] = profiles["iota"](grid, params=params["i_l"])
data["iota"] = profiles["iota"].compute(grid, params=params["i_l"])
unalmis marked this conversation as resolved.
Show resolved Hide resolved
if "iota_r" in deps:
data["iota_r"] = profiles["iota"](grid, dr=1, params=params["i_l"])
data["iota_r"] = profiles["iota"].compute(grid, dr=1, params=params["i_l"])
if "iota_rr" in deps:
data["iota_rr"] = profiles["iota"](grid, dr=2, params=params["i_l"])
data["iota_rr"] = profiles["iota"].compute(grid, dr=2, params=params["i_l"])

Check warning on line 153 in desc/equilibrium/coords.py

View check run for this annotation

Codecov / codecov/patch

desc/equilibrium/coords.py#L153

Added line #L153 was not covered by tests
transforms = get_transforms(basis, eq, grid, jitable=True)
data = compute_fun(eq, basis, params, transforms, profiles, data)
x = jnp.array([data[k] for k in basis]).T
Expand Down Expand Up @@ -243,7 +243,10 @@
theta = coords[:, inbasis.index(poloidal)]
elif poloidal == "alpha":
alpha = coords[:, inbasis.index("alpha")]
iota = profiles["iota"](rho)
rho = jnp.atleast_1d(rho)
zero = jnp.zeros_like(rho)
grid = Grid(nodes=jnp.column_stack([rho, zero, zero]), sort=False, jitable=True)
iota = profiles["iota"].compute(grid)
theta = (alpha + iota * zeta) % (2 * jnp.pi)

yk = jnp.column_stack([rho, theta, zeta])
Expand Down Expand Up @@ -677,7 +680,7 @@
rtz : rho, theta, zeta
period : tuple of float
Assumed periodicity for each quantity in inbasis.
Use np.inf to denote no periodicity.
Use ``np.inf`` to denote no periodicity.
jitable : bool, optional
If false the returned grid has additional attributes.
Required to be false to retain nodes at magnetic axis.
Expand All @@ -691,6 +694,8 @@
grid = Grid.create_meshgrid(
[radial, poloidal, toroidal], coordinates=coordinates, period=period
)
if "iota" in kwargs:
kwargs["iota"] = grid.expand(kwargs["iota"])

Check warning on line 698 in desc/equilibrium/coords.py

View check run for this annotation

Codecov / codecov/patch

desc/equilibrium/coords.py#L697-L698

Added lines #L697 - L698 were not covered by tests
inbasis = {
"r": "rho",
"t": "theta",
Expand Down
21 changes: 12 additions & 9 deletions desc/equilibrium/equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from desc.optimizable import Optimizable, optimizable_parameter
from desc.optimize import Optimizer
from desc.perturbations import perturb
from desc.profiles import PowerSeriesProfile, SplineProfile
from desc.profiles import HermiteSplineProfile, PowerSeriesProfile, SplineProfile
from desc.transform import Transform
from desc.utils import (
ResolutionWarning,
Expand Down Expand Up @@ -732,6 +732,8 @@ def get_profile(self, name, grid=None, kind="spline", **kwargs):
----------
name : str
Name of the quantity to compute.
If list is given, then two names are expected: the quantity to spline
and its radial derivative.
grid : Grid, optional
Grid of coordinates to evaluate at. Defaults to the quadrature grid.
Note profile will only be a function of the radial coordinate.
Expand All @@ -748,14 +750,17 @@ def get_profile(self, name, grid=None, kind="spline", **kwargs):
if grid is None:
grid = QuadratureGrid(self.L_grid, self.M_grid, self.N_grid, self.NFP)
data = self.compute(name, grid=grid, **kwargs)
f = data[name]
f = grid.compress(f, surface_label="rho")
x = grid.nodes[grid.unique_rho_idx, 0]
p = SplineProfile(f, x, name=name)
knots = grid.compress(grid.nodes[:, 0])
if isinstance(name, str):
f = grid.compress(data[name])
p = SplineProfile(f, knots, name=name)
else:
f, df = map(grid.compress, (data[name[0]], data[name[1]]))
p = HermiteSplineProfile(f, df, knots, name=name)
if kind == "power_series":
p = p.to_powerseries(order=min(self.L, len(x)), xs=x, sym=True)
p = p.to_powerseries(order=min(self.L, grid.num_rho), xs=knots, sym=True)
if kind == "fourier_zernike":
p = p.to_fourierzernike(L=min(self.L, len(x)), xs=x)
p = p.to_fourierzernike(L=min(self.L, grid.num_rho), xs=knots)
return p

def get_axis(self):
Expand Down Expand Up @@ -1161,8 +1166,6 @@ def map_coordinates(

Parameters
----------
eq : Equilibrium
Equilibrium to use.
coords : ndarray
Shape (k, 3).
2D array of input coordinates. Each row is a different point in space.
Expand Down
159 changes: 129 additions & 30 deletions desc/profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
copy_coeffs,
errorif,
multinomial_coefficients,
setdefault,
warnif,
)

Expand Down Expand Up @@ -613,7 +614,7 @@

def set_params(self, l, a=None):
"""Set specific power series coefficients."""
l, a = np.atleast_1d(l), np.atleast_1d(a)
l, a = np.atleast_1d(l, a)
a = np.broadcast_to(a, l.shape)
for ll, aa in zip(l, a):
idx = self.basis.get_idx(ll, 0, 0)
Expand Down Expand Up @@ -793,24 +794,25 @@


class SplineProfile(_Profile):
"""Profile represented by a piecewise cubic spline.
"""Radial profile represented by a piecewise cubic spline.

Parameters
----------
values: array-like
Values of the function at knot locations.
knots : int or ndarray
x locations to use for spline. If an integer, uses that many points linearly
spaced between 0,1
1-D array containing values of the dependent variable.
knots : array-like
1-D array containing values of the independent variable.
Must be real, finite, and in strictly increasing order in [0, 1].
If ``None``, assumes ``values`` is given on knots uniformly spaced in [0, 1].
method : str
method of interpolation
Method of interpolation. Default is cubic2.
- `'nearest'`: nearest neighbor interpolation
- `'linear'`: linear interpolation
- `'cubic'`: C1 cubic splines (aka local splines)
- `'cubic2'`: C2 cubic splines (aka natural splines)
- `'catmull-rom'`: C1 cubic centripetal "tension" splines
name : str
name of the profile
Optional name of the profile.

"""

Expand All @@ -821,11 +823,12 @@

if values is None:
values = [0, 0, 0]
values = np.atleast_1d(values)
values = jnp.atleast_1d(values)
if knots is None:
knots = np.linspace(0, 1, values.size)
else:
knots = np.atleast_1d(knots)
knots = jnp.linspace(0, 1, values.size)
knots = jnp.atleast_1d(knots)
errorif(values.shape[-1] != knots.shape[-1])
errorif(not (values.ndim == knots.ndim == 1), NotImplementedError)
self._knots = knots
self._params = values
self._method = method
Expand All @@ -834,7 +837,7 @@
"""Get the string form of the object."""
s = super().__repr__()
s = s[:-1]
s += ", method={}, num_knots={})".format(self._method, len(self._knots))
s += ", method={}, num_knots={})".format(self._method, self._knots.size)
return s

@property
Expand All @@ -849,24 +852,23 @@

@params.setter
def params(self, new):
if len(new) == len(self._knots):
self._params = jnp.asarray(new)
else:
raise ValueError(
"params should have the same size as the knots, "
+ f"got {len(new)} values for {len(self._knots)} knots"
)
errorif(
len(new) != self._knots.size,
msg="params should have the same size as the knots, "
+ f"got {len(new)} values for {self._knots.size} knots",
)
self._params = jnp.asarray(new)

def compute(self, grid, params=None, dr=0, dt=0, dz=0):
"""Compute values of profile at specified nodes.

Parameters
----------
grid : Grid
locations to compute values at.
Locations to compute values at.
params : array-like
spline values to use. If not given, uses the
values given by the params attribute
Values of the function at ``self.knots``.
If not given, uses ``self.params``.
dr, dt, dz : int
derivative order in rho, theta, zeta

Expand All @@ -876,15 +878,112 @@
values of the profile or its derivative at the points specified

"""
if params is None:
params = self.params
if dt != 0 or dz != 0:
return jnp.zeros_like(grid.nodes[:, 0])
x = self.knots
f = params
xq = grid.nodes[:, 0]
fq = interp1d(xq, x, f, method=self._method, derivative=dr, extrap=True)
return fq
params = setdefault(params, self._params)
return interp1d(
xq=grid.nodes[:, 0],
x=self._knots,
f=params,
method=self._method,
derivative=dr,
extrap=True,
)


class HermiteSplineProfile(_Profile):
"""Radial profile represented by a piecewise cubic Hermite spline.

Parameters
----------
f: array-like
1-D array containing values of the dependent variable.
df: array-like
1-D array containing derivatives of the dependent variable.
knots : array-like
1-D array containing values of the independent variable.
Must be real, finite, and in strictly increasing order in [0, 1].
If ``None``, assumes ``f`` and ``df`` are given on knots uniformly
spaced in [0, 1].
name : str
Optional name of the profile.

"""

_io_attrs_ = _Profile._io_attrs_ + ["_knots", "_params"]

def __init__(self, f, df, knots=None, name=""):
super().__init__(name)

f, df = jnp.atleast_1d(f, df)
if knots is None:
knots = jnp.linspace(0, 1, f.size)

Check warning on line 920 in desc/profiles.py

View check run for this annotation

Codecov / codecov/patch

desc/profiles.py#L920

Added line #L920 was not covered by tests
knots = jnp.atleast_1d(knots)
errorif(not (f.shape[-1] == df.shape[-1] == knots.shape[-1]))
errorif(not (f.ndim == df.ndim == knots.ndim == 1), NotImplementedError)
self._knots = knots
self._params = jnp.concatenate([f, df])

def __repr__(self):
"""Get the string form of the object."""
s = super().__repr__()
s = s[:-1]
s += ", num_knots={})".format(self._knots.size)
return s

Check warning on line 932 in desc/profiles.py

View check run for this annotation

Codecov / codecov/patch

desc/profiles.py#L929-L932

Added lines #L929 - L932 were not covered by tests

@property
def knots(self):
"""ndarray: Knot locations."""
return self._knots

Check warning on line 937 in desc/profiles.py

View check run for this annotation

Codecov / codecov/patch

desc/profiles.py#L937

Added line #L937 was not covered by tests

@property
def params(self):
"""ndarray: Parameters for computation.

First (second) half stores function (derivative) values at ``knots``.
"""
return self._params

@params.setter
def params(self, new):
new = jnp.asarray(new)
errorif(
new.ndim != 1 or new.size != 2 * self._knots.size,
msg="Params should be 1D with size twice number of knots. "
f"Got {new.shape} params for {self._knots.size} knots.",
)
self._params = new

def compute(self, grid, params=None, dr=0, dt=0, dz=0):
"""Compute values of profile at specified nodes.

Parameters
----------
grid : Grid
Locations to compute values at.
params : array-like
First (second) half stores function (derivative) values at ``knots``.
If not given, uses ``self.params``.
dr, dt, dz : int
derivative order in rho, theta, zeta

Returns
-------
f : ndarray
Array containing values of the dependent variable at the points specified.

"""
if dt != 0 or dz != 0:
return jnp.zeros_like(grid.nodes[:, 0])

Check warning on line 977 in desc/profiles.py

View check run for this annotation

Codecov / codecov/patch

desc/profiles.py#L977

Added line #L977 was not covered by tests
params = setdefault(params, self._params)
return interp1d(
xq=grid.nodes[:, 0],
x=self._knots,
f=params[: self._knots.size],
fx=params[self._knots.size :],
derivative=dr,
extrap=True,
)


class MTanhProfile(_Profile):
Expand Down
13 changes: 13 additions & 0 deletions tests/test_profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from scipy.interpolate import interp1d

from desc.equilibrium import Equilibrium
from desc.examples import get
from desc.grid import LinearGrid
from desc.io import InputReader
from desc.objectives import (
Expand All @@ -15,6 +16,7 @@
)
from desc.profiles import (
FourierZernikeProfile,
HermiteSplineProfile,
MTanhProfile,
PowerSeriesProfile,
SplineProfile,
Expand Down Expand Up @@ -507,3 +509,14 @@ def test_kinetic_pressure(self):
assert np.all(data2["Te_r"] == data2["Ti_r"])
np.testing.assert_allclose(data1["p"], data2["p"])
np.testing.assert_allclose(data1["p_r"], data2["p_r"])

@pytest.mark.unit
def test_hermite_spline_solve(self):
"""Test that spline with double number of parameters is optimized."""
eq = get("DSHAPE")
rho = np.linspace(0, 1.0, 20, endpoint=True)
eq.pressure = HermiteSplineProfile(
eq.pressure(rho), eq.pressure(rho, dr=1), rho
)
eq.solve()
assert eq.is_nested()
Loading