diff --git a/desc/backend.py b/desc/backend.py index c26213b045..5538c79a8c 100644 --- a/desc/backend.py +++ b/desc/backend.py @@ -66,19 +66,23 @@ ) if use_jax: # noqa: C901 - FIXME: simplify this, define globally and then assign? - jit = jax.jit - fori_loop = jax.lax.fori_loop - cond = jax.lax.cond - switch = jax.lax.switch - while_loop = jax.lax.while_loop - vmap = jax.vmap - bincount = jnp.bincount - repeat = jnp.repeat - take = jnp.take - scan = jax.lax.scan - from jax import custom_jvp + from jax import custom_jvp, jit, vmap + + imap = jax.lax.map from jax.experimental.ode import odeint - from jax.scipy.linalg import block_diag, cho_factor, cho_solve, qr, solve_triangular + from jax.lax import cond, fori_loop, scan, switch, while_loop + from jax.nn import softmax as softargmax + from jax.numpy import bincount, flatnonzero, repeat, take + from jax.numpy.fft import irfft, rfft, rfft2 + from jax.scipy.fft import dct, idct + from jax.scipy.linalg import ( + block_diag, + cho_factor, + cho_solve, + eigh_tridiagonal, + qr, + solve_triangular, + ) from jax.scipy.special import gammaln, logsumexp from jax.tree_util import ( register_pytree_node, @@ -90,6 +94,10 @@ treedef_is_leaf, ) + trapezoid = ( + jnp.trapezoid if hasattr(jnp, "trapezoid") else jax.scipy.integrate.trapezoid + ) + def put(arr, inds, vals): """Functional interface for array "fancy indexing". @@ -328,6 +336,8 @@ def root( This routine may be used on over or under-determined systems, in which case it will solve it in a least squares / least norm sense. """ + from desc.compute.utils import safenorm + if fixup is None: fixup = lambda x, *args: x if jac is None: @@ -392,7 +402,7 @@ def tangent_solve(g, y): x, (res, niter) = jax.lax.custom_root( res, x0, solve, tangent_solve, has_aux=True ) - return x, (jnp.linalg.norm(res), niter) + return x, (safenorm(res), niter) # we can't really test the numpy backend stuff in automated testing, so we ignore it @@ -401,15 +411,54 @@ def tangent_solve(g, y): jit = lambda func, *args, **kwargs: func execute_on_cpu = lambda func: func import scipy.optimize + from numpy.fft import irfft, rfft, rfft2 # noqa: F401 + from scipy.fft import dct, idct # noqa: F401 from scipy.integrate import odeint # noqa: F401 from scipy.linalg import ( # noqa: F401 block_diag, cho_factor, cho_solve, + eigh_tridiagonal, qr, solve_triangular, ) from scipy.special import gammaln, logsumexp # noqa: F401 + from scipy.special import softmax as softargmax # noqa: F401 + + trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz + + def imap(f, xs, batch_size=None, in_axes=0, out_axes=0): + """Generalizes jax.lax.map; uses numpy.""" + if not isinstance(xs, np.ndarray): + raise NotImplementedError( + "Require numpy array input, or install jax to support pytrees." + ) + xs = np.moveaxis(xs, source=in_axes, destination=0) + return np.stack([f(x) for x in xs], axis=out_axes) + + def vmap(fun, in_axes=0, out_axes=0): + """A numpy implementation of jax.lax.map whose API is a subset of jax.vmap. + + Like Python's builtin map, + except inputs and outputs are in the form of stacked arrays, + and the returned object is a vectorized version of the input function. + + Parameters + ---------- + fun: callable + Function (A -> B) + in_axes: int + Axis to map over. + out_axes: int + An integer indicating where the mapped axis should appear in the output. + + Returns + ------- + fun_vmap: callable + Vectorized version of fun. + + """ + return lambda xs: imap(fun, xs, in_axes=in_axes, out_axes=out_axes) def tree_stack(*args, **kwargs): """Stack pytree for numpy backend.""" @@ -592,32 +641,6 @@ def while_loop(cond_fun, body_fun, init_val): val = body_fun(val) return val - def vmap(fun, out_axes=0): - """A numpy implementation of jax.lax.map whose API is a subset of jax.vmap. - - Like Python's builtin map, - except inputs and outputs are in the form of stacked arrays, - and the returned object is a vectorized version of the input function. - - Parameters - ---------- - fun: callable - Function (A -> B) - out_axes: int - An integer indicating where the mapped axis should appear in the output. - - Returns - ------- - fun_vmap: callable - Vectorized version of fun. - - """ - - def fun_vmap(fun_inputs): - return np.stack([fun(fun_input) for fun_input in fun_inputs], axis=out_axes) - - return fun_vmap - def scan(f, init, xs, length=None, reverse=False, unroll=1): """Scan a function over leading array axes while carrying along state. @@ -657,9 +680,14 @@ def scan(f, init, xs, length=None, reverse=False, unroll=1): ys.append(y) return carry, np.stack(ys) - def bincount(x, weights=None, minlength=None, length=None): - """Same as np.bincount but with a dummy parameter to match jnp.bincount API.""" - return np.bincount(x, weights, minlength) + def bincount(x, weights=None, minlength=0, length=None): + """A numpy implementation of jnp.bincount.""" + x = np.clip(x, 0, None) + if length is None: + length = max(minlength, x.max() + 1) + else: + minlength = max(minlength, length) + return np.bincount(x, weights, minlength)[:length] def repeat(a, repeats, axis=None, total_repeat_length=None): """A numpy implementation of jnp.repeat.""" @@ -778,6 +806,13 @@ def root( out = scipy.optimize.root(fun, x0, args, jac=jac, tol=tol) return out.x, out + def flatnonzero(a, size=None, fill_value=0): + """A numpy implementation of jnp.flatnonzero.""" + nz = np.flatnonzero(a) + if size is not None: + nz = np.pad(nz, (0, max(size - nz.size, 0)), constant_values=fill_value) + return nz + def take( a, indices, diff --git a/desc/compute/_bootstrap.py b/desc/compute/_bootstrap.py index 48af83b4e5..2329682c06 100644 --- a/desc/compute/_bootstrap.py +++ b/desc/compute/_bootstrap.py @@ -13,7 +13,7 @@ from scipy.special import roots_legendre from ..backend import fori_loop, jnp -from ..integrals import surface_averages_map +from ..integrals.surface_integral import surface_averages_map from .data_index import register_compute_fun diff --git a/desc/compute/_equil.py b/desc/compute/_equil.py index 93b2c5232b..4e8a10413d 100644 --- a/desc/compute/_equil.py +++ b/desc/compute/_equil.py @@ -14,7 +14,7 @@ from desc.backend import jnp -from ..integrals import surface_averages +from ..integrals.surface_integral import surface_averages from .data_index import register_compute_fun from .utils import cross, dot, safediv, safenorm diff --git a/desc/compute/_field.py b/desc/compute/_field.py index 37732b1cdf..a5728d17ef 100644 --- a/desc/compute/_field.py +++ b/desc/compute/_field.py @@ -13,7 +13,7 @@ from desc.backend import jnp -from ..integrals import ( +from ..integrals.surface_integral import ( surface_averages, surface_integrals_map, surface_max, diff --git a/desc/compute/_metric.py b/desc/compute/_metric.py index 536bd05bb7..ceb6703386 100644 --- a/desc/compute/_metric.py +++ b/desc/compute/_metric.py @@ -13,7 +13,7 @@ from desc.backend import jnp -from ..integrals import surface_averages +from ..integrals.surface_integral import surface_averages from .data_index import register_compute_fun from .utils import cross, dot, safediv, safenorm diff --git a/desc/compute/_profiles.py b/desc/compute/_profiles.py index 940a463951..81604c9868 100644 --- a/desc/compute/_profiles.py +++ b/desc/compute/_profiles.py @@ -13,7 +13,7 @@ from desc.backend import cond, jnp -from ..integrals import surface_averages, surface_integrals +from ..integrals.surface_integral import surface_averages, surface_integrals from .data_index import register_compute_fun from .utils import cumtrapz, dot, safediv diff --git a/desc/compute/_stability.py b/desc/compute/_stability.py index 4a985a4dc5..3b820f83b0 100644 --- a/desc/compute/_stability.py +++ b/desc/compute/_stability.py @@ -13,7 +13,7 @@ from desc.backend import jnp -from ..integrals import surface_integrals_map +from ..integrals.surface_integral import surface_integrals_map from .data_index import register_compute_fun from .utils import dot diff --git a/desc/equilibrium/coords.py b/desc/equilibrium/coords.py index bb9b5b8be9..c7b51b24ab 100644 --- a/desc/equilibrium/coords.py +++ b/desc/equilibrium/coords.py @@ -685,11 +685,14 @@ def get_rtz_grid( rvp : rho, theta_PEST, phi rtz : rho, theta, zeta period : tuple of float - Assumed periodicity for each quantity in inbasis. + Assumed periodicity for functions of the given coordinates. Use ``np.inf`` to denote no periodicity. jitable : bool, optional If false the returned grid has additional attributes. Required to be false to retain nodes at magnetic axis. + kwargs + Additional parameters to supply to the coordinate mapping function. + See ``desc.equilibrium.coords.map_coordinates``. Returns ------- @@ -701,7 +704,7 @@ def get_rtz_grid( [radial, poloidal, toroidal], coordinates=coordinates, period=period ) if "iota" in kwargs: - kwargs["iota"] = grid.expand(kwargs["iota"]) + kwargs["iota"] = grid.expand(jnp.atleast_1d(kwargs["iota"])) inbasis = { "r": "rho", "t": "theta", diff --git a/desc/equilibrium/equilibrium.py b/desc/equilibrium/equilibrium.py index 8d09d5f64b..a13164dbe6 100644 --- a/desc/equilibrium/equilibrium.py +++ b/desc/equilibrium/equilibrium.py @@ -1255,7 +1255,11 @@ def compute_theta_coords( point. Only returned if ``full_output`` is True. """ - warnif(True, DeprecationWarning, msg="Use map_coordinates instead.") + warnif( + True, + DeprecationWarning, + "Use map_coordinates instead of compute_theta_coords.", + ) return map_coordinates( self, flux_coords, diff --git a/desc/grid.py b/desc/grid.py index 4f318afcaf..6a8ab78fe3 100644 --- a/desc/grid.py +++ b/desc/grid.py @@ -619,6 +619,7 @@ def meshgrid_reshape(self, x, order): ------- x : ndarray Data reshaped to align with grid nodes. + """ errorif( not self.is_meshgrid, @@ -637,7 +638,8 @@ def meshgrid_reshape(self, x, order): vec = True shape += (-1,) x = x.reshape(shape, order="F") - x = jnp.moveaxis(x, 1, 0) # now shape rtz/raz etc + # swap to change shape from trz/arz to rtz/raz etc. + x = jnp.swapaxes(x, 1, 0) newax = tuple(self.coordinates.index(c) for c in order) if vec: newax += (3,) @@ -788,10 +790,11 @@ def create_meshgrid( rtz : rho, theta, zeta period : tuple of float Assumed periodicity for each coordinate. - Use np.inf to denote no periodicity. + Use ``np.inf`` to denote no periodicity. NFP : int Number of field periods (Default = 1). - Only makes sense to change from 1 if ``period[2]==2π``. + Only makes sense to change from 1 if last coordinate is periodic + with some constant divided by ``NFP``. Returns ------- @@ -1885,8 +1888,13 @@ def _periodic_spacing(x, period=2 * jnp.pi, sort=False, jnp=jnp): x = jnp.sort(x, axis=0) # choose dx to be half the distance between its neighbors if x.size > 1: - dx_0 = x[1] + (period - x[-1]) % period - dx_1 = x[0] + (period - x[-2]) % period + if np.isfinite(period): + dx_0 = x[1] + (period - x[-1]) % period + dx_1 = x[0] + (period - x[-2]) % period + else: + # just set to 0 to stop nan gradient, even though above gives expected value + dx_0 = 0 + dx_1 = 0 if x.size == 2: # then dx[0] == period and dx[-1] == 0, so fix this dx_1 = dx_0 diff --git a/desc/integrals/__init__.py b/desc/integrals/__init__.py index f223e39606..88cc3001ca 100644 --- a/desc/integrals/__init__.py +++ b/desc/integrals/__init__.py @@ -1,5 +1,6 @@ """Classes for function integration.""" +from .bounce_integral import Bounce1D from .singularities import ( DFTInterpolator, FFTInterpolator, diff --git a/desc/integrals/basis.py b/desc/integrals/basis.py new file mode 100644 index 0000000000..91a31edf60 --- /dev/null +++ b/desc/integrals/basis.py @@ -0,0 +1,109 @@ +"""Fast transformable basis.""" + +from functools import partial + +from desc.backend import flatnonzero, jnp, put +from desc.utils import setdefault + + +@partial(jnp.vectorize, signature="(m),(m)->(m)") +def _in_epigraph_and(is_intersect, df_dy_sign, /): + """Set and epigraph of function f with the given set of points. + + Used to return only intersects where the straight line path between + adjacent intersects resides in the epigraph of a continuous map ``f``. + + Parameters + ---------- + is_intersect : jnp.ndarray + Boolean array indicating whether index corresponds to an intersect. + df_dy_sign : jnp.ndarray + Shape ``is_intersect.shape``. + Sign of ∂f/∂y (yᵢ) for f(yᵢ) = 0. + + Returns + ------- + is_intersect : jnp.ndarray + Boolean array indicating whether element is an intersect + and satisfies the stated condition. + + Examples + -------- + See ``desc/integrals/bounce_utils.py::bounce_points``. + This is used there to ensure the domains of integration are magnetic wells. + + """ + # The pairs ``y1`` and ``y2`` are boundaries of an integral only if ``y1 <= y2``. + # For the integrals to be over wells, it is required that the first intersect + # has a non-positive derivative. Now, by continuity, + # ``df_dy_sign[...,k]<=0`` implies ``df_dy_sign[...,k+1]>=0``, + # so there can be at most one inversion, and if it exists, the inversion + # must be at the first pair. To correct the inversion, it suffices to disqualify the + # first intersect as a right boundary, except under an edge case of a series of + # inflection points. + idx = flatnonzero(is_intersect, size=2, fill_value=-1) + edge_case = ( + (df_dy_sign[idx[0]] == 0) + & (df_dy_sign[idx[1]] < 0) + & is_intersect[idx[0]] + & is_intersect[idx[1]] + # In theory, we need to keep propagating this edge case, e.g. + # (df_dy_sign[..., 1] < 0) | ( + # (df_dy_sign[..., 1] == 0) & (df_dy_sign[..., 2] < 0)... + # ). + # At each step, the likelihood that an intersection has already been lost + # due to floating point errors grows, so the real solution is to pick a less + # degenerate pitch value - one that does not ride the global extrema of f. + ) + return put(is_intersect, idx[0], edge_case) + + +def _add2legend(legend, lines): + """Add lines to legend if it's not already in it.""" + for line in setdefault(lines, [lines], hasattr(lines, "__iter__")): + label = line.get_label() + if label not in legend: + legend[label] = line + + +def _plot_intersect(ax, legend, z1, z2, k, k_transparency, klabel): + """Plot intersects on ``ax``.""" + if k is None: + return + + k = jnp.atleast_1d(jnp.squeeze(k)) + assert k.ndim == 1 + z1, z2 = jnp.atleast_2d(z1, z2) + assert z1.ndim == z2.ndim >= 2 + assert k.shape[0] == z1.shape[0] == z2.shape[0] + for p in k: + _add2legend( + legend, + ax.axhline(p, color="tab:purple", alpha=k_transparency, label=klabel), + ) + for i in range(k.size): + _z1, _z2 = z1[i], z2[i] + if _z1.size == _z2.size: + mask = (_z1 - _z2) != 0.0 + _z1 = _z1[mask] + _z2 = _z2[mask] + _add2legend( + legend, + ax.scatter( + _z1, + jnp.full_like(_z1, k[i]), + marker="v", + color="tab:red", + label=r"$z_1$", + ), + ) + _add2legend( + legend, + ax.scatter( + _z2, + jnp.full_like(_z2, k[i]), + marker="^", + color="tab:green", + label=r"$z_2$", + ), + ) diff --git a/desc/integrals/bounce_integral.py b/desc/integrals/bounce_integral.py new file mode 100644 index 0000000000..dff4db396c --- /dev/null +++ b/desc/integrals/bounce_integral.py @@ -0,0 +1,428 @@ +"""Methods for computing bounce integrals (singular or otherwise).""" + +from interpax import CubicHermiteSpline, PPoly +from orthax.legendre import leggauss + +from desc.backend import jnp +from desc.integrals.bounce_utils import ( + _bounce_quadrature, + _check_bounce_points, + _set_default_plot_kwargs, + bounce_points, + get_pitch_inv, + interp_to_argmin, + plot_ppoly, +) +from desc.integrals.interp_utils import polyder_vec +from desc.integrals.quad_utils import ( + automorphism_sin, + get_quadrature, + grad_automorphism_sin, +) +from desc.io import IOAble +from desc.utils import errorif, setdefault, warnif + + +class Bounce1D(IOAble): + """Computes bounce integrals using one-dimensional local spline methods. + + The bounce integral is defined as ∫ f(λ, ℓ) dℓ, where + dℓ parameterizes the distance along the field line in meters, + f(λ, ℓ) is the quantity to integrate along the field line, + and the boundaries of the integral are bounce points ℓ₁, ℓ₂ s.t. λ|B|(ℓᵢ) = 1, + where λ is a constant defining the integral proportional to the magnetic moment + over energy and |B| is the norm of the magnetic field. + + For a particle with fixed λ, bounce points are defined to be the location on the + field line such that the particle's velocity parallel to the magnetic field is zero. + The bounce integral is defined up to a sign. We choose the sign that corresponds to + the particle's guiding center trajectory traveling in the direction of increasing + field-line-following coordinate ζ. + + Notes + ----- + Brief description of algorithm for developers. + + For applications which reduce to computing a nonlinear function of distance + along field lines between bounce points, it is required to identify these + points with field-line-following coordinates. (In the special case of a linear + function summing integrals between bounce points over a flux surface, arbitrary + coordinate systems may be used as this operation reduces to a surface integral, + which is invariant to the order of summation). + + The DESC coordinate system is related to field-line-following coordinate + systems by a relation whose solution is best found with Newton iteration. + There is a unique real solution to this equation, so Newton iteration is a + globally convergent root-finding algorithm here. For the task of finding + bounce points, even if the inverse map: θ(α, ζ) was known, Newton iteration + is not a globally convergent algorithm to find the real roots of + f : ζ ↦ |B|(ζ) − 1/λ where ζ is a field-line-following coordinate. + For this, function approximation of |B| is necessary. + + The function approximation in ``Bounce1D`` is ignorant that the objects to + approximate are defined on a bounded subset of ℝ². Instead, the domain is + projected to ℝ, where information sampled about the function at infinity + cannot support reconstruction of the function near the origin. As the + functions of interest do not vanish at infinity, pseudo-spectral techniques + are not used. Instead, function approximation is done with local splines. + This is useful if one can efficiently obtain data along field lines and + most efficient if the number of toroidal transits to follow a field line is + not too large. + + After computing the bounce points, the supplied quadrature is performed. + By default, this is a Gauss quadrature after removing the singularity. + Local splines interpolate functions in the integrand to the quadrature nodes. + + See Also + -------- + Bounce2D : Uses two-dimensional pseudo-spectral techniques for the same task. + + Examples + -------- + See ``tests/test_integrals.py::TestBounce1D::test_bounce1d_checks``. + + Attributes + ---------- + required_names : list + Names in ``data_index`` required to compute bounce integrals. + B : jnp.ndarray + Shape (M, L, N - 1, B.shape[-1]). + Polynomial coefficients of the spline of |B| in local power basis. + Last axis enumerates the coefficients of power series. For a polynomial + given by ∑ᵢⁿ cᵢ xⁱ, coefficient cᵢ is stored at ``B[...,n-i]``. + Third axis enumerates the polynomials that compose a particular spline. + Second axis enumerates flux surfaces. + First axis enumerates field lines of a particular flux surface. + + """ + + required_names = ["B^zeta", "B^zeta_z|r,a", "|B|", "|B|_z|r,a"] + get_pitch_inv = staticmethod(get_pitch_inv) + + def __init__( + self, + grid, + data, + quad=leggauss(32), + automorphism=(automorphism_sin, grad_automorphism_sin), + Bref=1.0, + Lref=1.0, + *, + is_reshaped=False, + check=False, + **kwargs, + ): + """Returns an object to compute bounce integrals. + + Parameters + ---------- + grid : Grid + Clebsch coordinate (ρ, α, ζ) tensor-product grid. + The ζ coordinates (the unique values prior to taking the tensor-product) + must be strictly increasing and preferably uniformly spaced. These are used + as knots to construct splines. A reference knot density is 100 knots per + toroidal transit. Note that below shape notation defines + L = ``grid.num_rho``, M = ``grid.num_alpha``, and N = ``grid.num_zeta``. + data : dict[str, jnp.ndarray] + Data evaluated on ``grid``. + Must include names in ``Bounce1D.required_names``. + quad : (jnp.ndarray, jnp.ndarray) + Quadrature points xₖ and weights wₖ for the approximate evaluation of an + integral ∫₋₁¹ g(x) dx = ∑ₖ wₖ g(xₖ). Default is 32 points. + automorphism : (Callable, Callable) or None + The first callable should be an automorphism of the real interval [-1, 1]. + The second callable should be the derivative of the first. This map defines + a change of variable for the bounce integral. The choice made for the + automorphism will affect the performance of the quadrature method. + Bref : float + Optional. Reference magnetic field strength for normalization. + Lref : float + Optional. Reference length scale for normalization. + is_reshaped : bool + Whether the arrays in ``data`` are already reshaped to the expected form of + shape (..., N) or (..., L, N) or (M, L, N). This option can be used to + iteratively compute bounce integrals one field line or one flux surface + at a time, respectively, potentially reducing memory usage. To do so, + set to true and provide only those axes of the reshaped data. + Default is false. + check : bool + Flag for debugging. Must be false for JAX transformations. + + """ + # Strictly increasing zeta knots enforces dζ > 0. + # To retain dℓ = (|B|/B^ζ) dζ > 0 after fixing dζ > 0, we require + # B^ζ = B⋅∇ζ > 0. This is equivalent to changing the sign of ∇ζ or [∂ℓ/∂ζ]|ρ,a. + # Recall dζ = ∇ζ⋅dR, implying 1 = ∇ζ⋅(e_ζ|ρ,a). Hence, a sign change in ∇ζ + # requires the same sign change in e_ζ|ρ,a to retain the metric identity. + warnif( + check and kwargs.pop("warn", True) and jnp.any(data["B^zeta"] <= 0), + msg="(∂ℓ/∂ζ)|ρ,a > 0 is required. Enforcing positive B^ζ.", + ) + data = { + "B^zeta": jnp.abs(data["B^zeta"]) * Lref / Bref, + "B^zeta_z|r,a": data["B^zeta_z|r,a"] + * jnp.sign(data["B^zeta"]) + * Lref + / Bref, + "|B|": data["|B|"] / Bref, + "|B|_z|r,a": data["|B|_z|r,a"] / Bref, # This is already the correct sign. + } + self._data = ( + data + if is_reshaped + else dict(zip(data.keys(), Bounce1D.reshape_data(grid, *data.values()))) + ) + self._x, self._w = get_quadrature(quad, automorphism) + + # Compute local splines. + self._zeta = grid.compress(grid.nodes[:, 2], surface_label="zeta") + self.B = jnp.moveaxis( + CubicHermiteSpline( + x=self._zeta, + y=self._data["|B|"], + dydx=self._data["|B|_z|r,a"], + axis=-1, + check=check, + ).c, + source=(0, 1), + destination=(-1, -2), + ) + self._dB_dz = polyder_vec(self.B) + + # Add axis here instead of in ``_bounce_quadrature``. + for name in self._data: + self._data[name] = self._data[name][..., jnp.newaxis, :] + + @staticmethod + def reshape_data(grid, *arys): + """Reshape arrays for acceptable input to ``integrate``. + + Parameters + ---------- + grid : Grid + Clebsch coordinate (ρ, α, ζ) tensor-product grid. + arys : jnp.ndarray + Data evaluated on grid. + + Returns + ------- + f : jnp.ndarray + Shape (M, L, N). + Reshaped data which may be given to ``integrate``. + + """ + f = [grid.meshgrid_reshape(d, "arz") for d in arys] + return f if len(f) > 1 else f[0] + + def points(self, pitch_inv, *, num_well=None): + """Compute bounce points. + + Parameters + ---------- + pitch_inv : jnp.ndarray + Shape (M, L, P). + 1/λ values to compute the bounce points at each field line. 1/λ(α,ρ) is + specified by ``pitch_inv[α,ρ]`` where in the latter the labels + are interpreted as the indices that correspond to that field line. + num_well : int or None + Specify to return the first ``num_well`` pairs of bounce points for each + pitch along each field line. This is useful if ``num_well`` tightly + bounds the actual number. As a reference, there are typically 20 wells + per toroidal transit for a given pitch. You can check this by plotting + the field lines with the ``check_points`` method. + + If not specified, then all bounce points are returned. If there were fewer + wells detected along a field line than the size of the last axis of the + returned arrays, then that axis is padded with zero. + + Returns + ------- + z1, z2 : (jnp.ndarray, jnp.ndarray) + Shape (M, L, P, num_well). + ζ coordinates of bounce points. The points are ordered and grouped such + that the straight line path between ``z1`` and ``z2`` resides in the + epigraph of |B|. + + If there were less than ``num_well`` wells detected along a field line, + then the last axis, which enumerates bounce points for a particular field + line and pitch, is padded with zero. + + """ + return bounce_points(pitch_inv, self._zeta, self.B, self._dB_dz, num_well) + + def check_points(self, z1, z2, pitch_inv, *, plot=True, **kwargs): + """Check that bounce points are computed correctly. + + Parameters + ---------- + z1, z2 : (jnp.ndarray, jnp.ndarray) + Shape (M, L, P, num_well). + ζ coordinates of bounce points. The points are ordered and grouped such + that the straight line path between ``z1`` and ``z2`` resides in the + epigraph of |B|. + pitch_inv : jnp.ndarray + Shape (M, L, P). + 1/λ values to compute the bounce points at each field line. 1/λ(α,ρ) is + specified by ``pitch_inv[α,ρ]`` where in the latter the labels + are interpreted as the indices that correspond to that field line. + plot : bool + Whether to plot the field lines and bounce points of the given pitch angles. + kwargs + Keyword arguments into ``desc/integrals/bounce_utils.py::plot_ppoly``. + + Returns + ------- + plots : list + Matplotlib (fig, ax) tuples for the 1D plot of each field line. + + """ + return _check_bounce_points( + z1=z1, + z2=z2, + pitch_inv=pitch_inv, + knots=self._zeta, + B=self.B, + plot=plot, + **kwargs, + ) + + def integrate( + self, + integrand, + pitch_inv, + f=None, + weight=None, + *, + num_well=None, + method="cubic", + batch=True, + check=False, + plot=False, + ): + """Bounce integrate ∫ f(λ, ℓ) dℓ. + + Computes the bounce integral ∫ f(λ, ℓ) dℓ for every field line and pitch. + + Parameters + ---------- + integrand : callable + The composition operator on the set of functions in ``f`` that maps the + functions in ``f`` to the integrand f(λ, ℓ) in ∫ f(λ, ℓ) dℓ. It should + accept the arrays in ``f`` as arguments as well as the additional keyword + arguments: ``B`` and ``pitch``. A quadrature will be performed to + approximate the bounce integral of ``integrand(*f,B=B,pitch=pitch)``. + pitch_inv : jnp.ndarray + Shape (M, L, P). + 1/λ values to compute the bounce integrals. 1/λ(α,ρ) is specified by + ``pitch_inv[α,ρ]`` where in the latter the labels are interpreted + as the indices that correspond to that field line. + f : list[jnp.ndarray] or jnp.ndarray + Shape (M, L, N). + Real scalar-valued functions evaluated on the ``grid`` supplied to + construct this object. These functions should be arguments to the callable + ``integrand``. Use the method ``self.reshape_data`` to reshape the data + into the expected shape. + weight : jnp.ndarray + Shape (M, L, N). + If supplied, the bounce integral labeled by well j is weighted such that + the returned value is w(j) ∫ f(λ, ℓ) dℓ, where w(j) is ``weight`` + interpolated to the deepest point in that magnetic well. Use the method + ``self.reshape_data`` to reshape the data into the expected shape. + num_well : int or None + Specify to return the first ``num_well`` pairs of bounce points for each + pitch along each field line. This is useful if ``num_well`` tightly + bounds the actual number. As a reference, there are typically 20 wells + per toroidal transit for a given pitch. You can check this by plotting + the field lines with the ``check_points`` method. + + If not specified, then all bounce points are returned. If there were fewer + wells detected along a field line than the size of the last axis of the + returned arrays, then that axis is padded with zero. + method : str + Method of interpolation. + See https://interpax.readthedocs.io/en/latest/_api/interpax.interp1d.html. + Default is cubic C1 local spline. + batch : bool + Whether to perform computation in a batched manner. Default is true. + check : bool + Flag for debugging. Must be false for JAX transformations. + plot : bool + Whether to plot the quantities in the integrand interpolated to the + quadrature points of each integral. Ignored if ``check`` is false. + + Returns + ------- + result : jnp.ndarray + Shape (M, L, P, num_well). + Last axis enumerates the bounce integrals for a given field line, + flux surface, and pitch value. + + """ + z1, z2 = self.points(pitch_inv, num_well=num_well) + result = _bounce_quadrature( + x=self._x, + w=self._w, + z1=z1, + z2=z2, + integrand=integrand, + pitch_inv=pitch_inv, + f=setdefault(f, []), + data=self._data, + knots=self._zeta, + method=method, + batch=batch, + check=check, + plot=plot, + ) + if weight is not None: + result *= interp_to_argmin( + weight, + z1, + z2, + self._zeta, + self.B, + self._dB_dz, + method, + ) + assert result.shape == z1.shape + return result + + def plot(self, m, l, pitch_inv=None, /, **kwargs): + """Plot the field line and bounce points of the given pitch angles. + + Parameters + ---------- + m, l : int, int + Indices into the nodes of the grid supplied to make this object. + ``alpha,rho=grid.meshgrid_reshape(grid.nodes[:,:2],"arz")[m,l,0]``. + pitch_inv : jnp.ndarray + Shape (P, ). + Optional, 1/λ values whose corresponding bounce points on the field line + specified by Clebsch coordinate α(m), ρ(l) will be plotted. + kwargs + Keyword arguments into ``desc/integrals/bounce_utils.py::plot_ppoly``. + + Returns + ------- + fig, ax + Matplotlib (fig, ax) tuple. + + """ + B, dB_dz = self.B, self._dB_dz + if B.ndim == 4: + B = B[m] + dB_dz = dB_dz[m] + if B.ndim == 3: + B = B[l] + dB_dz = dB_dz[l] + if pitch_inv is not None: + errorif( + pitch_inv.ndim > 1, + msg=f"Got pitch_inv.ndim={pitch_inv.ndim}, but expected 1.", + ) + z1, z2 = bounce_points(pitch_inv, self._zeta, B, dB_dz) + kwargs["z1"] = z1 + kwargs["z2"] = z2 + kwargs["k"] = pitch_inv + fig, ax = plot_ppoly(PPoly(B.T, self._zeta), **_set_default_plot_kwargs(kwargs)) + return fig, ax diff --git a/desc/integrals/bounce_utils.py b/desc/integrals/bounce_utils.py new file mode 100644 index 0000000000..c63477c0cc --- /dev/null +++ b/desc/integrals/bounce_utils.py @@ -0,0 +1,809 @@ +"""Utilities and functional programming interface for bounce integrals.""" + +import numpy as np +from interpax import PPoly +from matplotlib import pyplot as plt + +from desc.backend import imap, jnp, softargmax +from desc.integrals.basis import _add2legend, _in_epigraph_and, _plot_intersect +from desc.integrals.interp_utils import ( + interp1d_Hermite_vec, + interp1d_vec, + polyroot_vec, + polyval_vec, +) +from desc.integrals.quad_utils import ( + bijection_from_disc, + composite_linspace, + grad_bijection_from_disc, +) +from desc.utils import ( + atleast_nd, + errorif, + flatten_matrix, + is_broadcastable, + setdefault, + take_mask, +) + + +def get_pitch_inv(min_B, max_B, num, relative_shift=1e-6): + """Return 1/λ values for quadrature between ``min_B`` and ``max_B``. + + Parameters + ---------- + min_B : jnp.ndarray + Minimum |B| value. + max_B : jnp.ndarray + Maximum |B| value. + num : int + Number of values, not including endpoints. + relative_shift : float + Relative amount to shift maxima down and minima up to avoid floating point + errors in downstream routines. + + Returns + ------- + pitch_inv : jnp.ndarray + Shape (*min_B.shape, num + 2). + 1/λ values. + + """ + # Floating point error impedes consistent detection of bounce points riding + # extrema. Shift values slightly to resolve this issue. + min_B = (1 + relative_shift) * min_B + max_B = (1 - relative_shift) * max_B + # Samples should be uniformly spaced in |B| and not λ (GitHub issue #1228). + pitch_inv = jnp.moveaxis(composite_linspace(jnp.stack([min_B, max_B]), num), 0, -1) + assert pitch_inv.shape == (*min_B.shape, num + 2) + return pitch_inv + + +def _check_spline_shape(knots, g, dg_dz, pitch_inv=None): + """Ensure inputs have compatible shape. + + Parameters + ---------- + knots : jnp.ndarray + Shape (N, ). + ζ coordinates of spline knots. Must be strictly increasing. + g : jnp.ndarray + Shape (..., N - 1, g.shape[-1]). + Polynomial coefficients of the spline of g in local power basis. + Last axis enumerates the coefficients of power series. Second to + last axis enumerates the polynomials that compose a particular spline. + dg_dz : jnp.ndarray + Shape (..., N - 1, g.shape[-1] - 1). + Polynomial coefficients of the spline of ∂g/∂ζ in local power basis. + Last axis enumerates the coefficients of power series. Second to + last axis enumerates the polynomials that compose a particular spline. + pitch_inv : jnp.ndarray + Shape (..., P). + 1/λ values. 1/λ(α,ρ) is specified by ``pitch_inv[α,ρ]`` where in + the latter the labels are interpreted as the indices that correspond + to that field line. + + """ + errorif(knots.ndim != 1, msg=f"knots should be 1d; got shape {knots.shape}.") + errorif( + g.shape[-2] != (knots.size - 1), + msg=( + "Second to last axis does not enumerate polynomials of spline. " + f"Spline shape {g.shape}. Knots shape {knots.shape}." + ), + ) + errorif( + not (g.ndim == dg_dz.ndim < 5) + or g.shape != (*dg_dz.shape[:-1], dg_dz.shape[-1] + 1), + msg=f"Invalid shape {g.shape} for spline and derivative {dg_dz.shape}.", + ) + g, dg_dz = jnp.atleast_2d(g, dg_dz) + if pitch_inv is not None: + pitch_inv = jnp.atleast_1d(pitch_inv) + errorif( + pitch_inv.ndim > 3 + or not is_broadcastable(pitch_inv.shape[:-1], g.shape[:-2]), + msg=f"Invalid shape {pitch_inv.shape} for pitch angles.", + ) + return g, dg_dz, pitch_inv + + +def bounce_points( + pitch_inv, knots, B, dB_dz, num_well=None, check=False, plot=True, **kwargs +): + """Compute the bounce points given spline of |B| and pitch λ. + + Parameters + ---------- + pitch_inv : jnp.ndarray + Shape (..., P). + 1/λ values to compute the bounce points. + knots : jnp.ndarray + Shape (N, ). + ζ coordinates of spline knots. Must be strictly increasing. + B : jnp.ndarray + Shape (..., N - 1, B.shape[-1]). + Polynomial coefficients of the spline of |B| in local power basis. + Last axis enumerates the coefficients of power series. Second to + last axis enumerates the polynomials that compose a particular spline. + dB_dz : jnp.ndarray + Shape (..., N - 1, B.shape[-1] - 1). + Polynomial coefficients of the spline of (∂|B|/∂ζ)|(ρ,α) in local power basis. + Last axis enumerates the coefficients of power series. Second to + last axis enumerates the polynomials that compose a particular spline. + num_well : int or None + Specify to return the first ``num_well`` pairs of bounce points for each + pitch along each field line. This is useful if ``num_well`` tightly + bounds the actual number. As a reference, there are typically 20 wells + per toroidal transit for a given pitch. You can check this by plotting + the field lines with the ``_check_bounce_points`` method. + + If not specified, then all bounce points are returned. If there were fewer + wells detected along a field line than the size of the last axis of the + returned arrays, then that axis is padded with zero. + check : bool + Flag for debugging. Must be false for JAX transformations. + plot : bool + Whether to plot some things if check is true. Default is true. + kwargs + Keyword arguments into ``plot_ppoly``. + + Returns + ------- + z1, z2 : (jnp.ndarray, jnp.ndarray) + Shape (..., P, num_well). + ζ coordinates of bounce points. The points are ordered and grouped such + that the straight line path between ``z1`` and ``z2`` resides in the + epigraph of |B|. + + If there were less than ``num_well`` wells detected along a field line, + then the last axis, which enumerates bounce points for a particular field + line and pitch, is padded with zero. + + """ + B, dB_dz, pitch_inv = _check_spline_shape(knots, B, dB_dz, pitch_inv) + intersect = polyroot_vec( + c=B[..., jnp.newaxis, :, :], # Add P axis + k=pitch_inv[..., jnp.newaxis], # Add N axis + a_min=jnp.array([0.0]), + a_max=jnp.diff(knots), + sort=True, + sentinel=-1.0, + distinct=True, + ) + assert intersect.shape[-3:] == ( + pitch_inv.shape[-1], + knots.size - 1, + B.shape[-1] - 1, + ) + + # Reshape so that last axis enumerates intersects of a pitch along a field line. + dB_sign = flatten_matrix( + jnp.sign(polyval_vec(x=intersect, c=dB_dz[..., jnp.newaxis, :, jnp.newaxis, :])) + ) + # Only consider intersect if it is within knots that bound that polynomial. + is_intersect = flatten_matrix(intersect) >= 0 + # Following discussion on page 3 and 5 of https://doi.org/10.1063/1.873749, + # we ignore the bounce points of particles only assigned to a class that are + # trapped outside this snapshot of the field line. + is_z1 = (dB_sign <= 0) & is_intersect + is_z2 = (dB_sign >= 0) & _in_epigraph_and(is_intersect, dB_sign) + + # Transform out of local power basis expansion. + intersect = flatten_matrix(intersect + knots[:-1, jnp.newaxis]) + # New versions of JAX only like static sentinels. + sentinel = -10000000.0 # instead of knots[0] - 1 + z1 = take_mask(intersect, is_z1, size=num_well, fill_value=sentinel) + z2 = take_mask(intersect, is_z2, size=num_well, fill_value=sentinel) + + mask = (z1 > sentinel) & (z2 > sentinel) + # Set outside mask to same value so integration is over set of measure zero. + z1 = jnp.where(mask, z1, 0.0) + z2 = jnp.where(mask, z2, 0.0) + + if check: + _check_bounce_points(z1, z2, pitch_inv, knots, B, plot, **kwargs) + + return z1, z2 + + +def _set_default_plot_kwargs(kwargs): + kwargs.setdefault( + "title", + r"Intersects $\zeta$ in epigraph($\vert B \vert$) s.t. " + r"$\vert B \vert(\zeta) = 1/\lambda$", + ) + kwargs.setdefault("klabel", r"$1/\lambda$") + kwargs.setdefault("hlabel", r"$\zeta$") + kwargs.setdefault("vlabel", r"$\vert B \vert$") + return kwargs + + +def _check_bounce_points(z1, z2, pitch_inv, knots, B, plot=True, **kwargs): + """Check that bounce points are computed correctly.""" + z1 = atleast_nd(4, z1) + z2 = atleast_nd(4, z2) + pitch_inv = atleast_nd(3, pitch_inv) + B = atleast_nd(4, B) + + kwargs = _set_default_plot_kwargs(kwargs) + plots = [] + + assert z1.shape == z2.shape + mask = (z1 - z2) != 0.0 + z1 = jnp.where(mask, z1, jnp.nan) + z2 = jnp.where(mask, z2, jnp.nan) + + err_1 = jnp.any(z1 > z2, axis=-1) + err_2 = jnp.any(z1[..., 1:] < z2[..., :-1], axis=-1) + + eps = kwargs.pop("eps", jnp.finfo(jnp.array(1.0).dtype).eps * 10) + for ml in np.ndindex(B.shape[:-2]): + ppoly = PPoly(B[ml].T, knots) + for p in range(pitch_inv.shape[-1]): + idx = (*ml, p) + B_midpoint = ppoly((z1[idx] + z2[idx]) / 2) + err_3 = jnp.any(B_midpoint > pitch_inv[idx] + eps) + if not (err_1[idx] or err_2[idx] or err_3): + continue + _z1 = z1[idx][mask[idx]] + _z2 = z2[idx][mask[idx]] + if plot: + plot_ppoly( + ppoly=ppoly, + z1=_z1, + z2=_z2, + k=pitch_inv[idx], + title=kwargs.pop("title") + f", (m,l,p)={idx}", + **kwargs, + ) + + print(" z1 | z2") + print(jnp.column_stack([_z1, _z2])) + assert not err_1[idx], "Intersects have an inversion.\n" + assert not err_2[idx], "Detected discontinuity.\n" + assert not err_3, ( + f"Detected |B| = {B_midpoint[mask[idx]]} > {pitch_inv[idx] + eps} " + "= 1/λ in well, implying the straight line path between " + "bounce points is in hypograph(|B|). Use more knots.\n" + ) + if plot: + plots.append( + plot_ppoly( + ppoly=ppoly, + z1=z1[ml], + z2=z2[ml], + k=pitch_inv[ml], + **kwargs, + ) + ) + return plots + + +def _bounce_quadrature( + x, + w, + z1, + z2, + integrand, + pitch_inv, + f, + data, + knots, + method="cubic", + batch=True, + check=False, + plot=False, +): + """Bounce integrate ∫ f(λ, ℓ) dℓ. + + Parameters + ---------- + x : jnp.ndarray + Shape (w.size, ). + Quadrature points in [-1, 1]. + w : jnp.ndarray + Shape (w.size, ). + Quadrature weights. + z1, z2 : jnp.ndarray + Shape (..., P, num_well). + ζ coordinates of bounce points. The points are ordered and grouped such + that the straight line path between ``z1`` and ``z2`` resides in the + epigraph of |B|. + integrand : callable + The composition operator on the set of functions in ``f`` that maps the + functions in ``f`` to the integrand f(λ, ℓ) in ∫ f(λ, ℓ) dℓ. It should + accept the arrays in ``f`` as arguments as well as the additional keyword + arguments: ``B`` and ``pitch``. A quadrature will be performed to + approximate the bounce integral of ``integrand(*f,B=B,pitch=pitch)``. + pitch_inv : jnp.ndarray + Shape (..., P). + 1/λ values to compute the bounce integrals. + f : list[jnp.ndarray] + Shape (..., N). + Real scalar-valued functions evaluated on the ``knots``. + These functions should be arguments to the callable ``integrand``. + data : dict[str, jnp.ndarray] + Shape (..., 1, N). + Required data evaluated on ``grid`` and reshaped with ``Bounce1D.reshape_data``. + Must include names in ``Bounce1D.required_names``. + knots : jnp.ndarray + Shape (N, ). + Unique ζ coordinates where the arrays in ``data`` and ``f`` were evaluated. + method : str + Method of interpolation. + See https://interpax.readthedocs.io/en/latest/_api/interpax.interp1d.html. + Default is cubic C1 local spline. + batch : bool + Whether to perform computation in a batched manner. Default is true. + check : bool + Flag for debugging. Must be false for JAX transformations. + Ignored if ``batch`` is false. + plot : bool + Whether to plot the quantities in the integrand interpolated to the + quadrature points of each integral. Ignored if ``check`` is false. + + Returns + ------- + result : jnp.ndarray + Shape (..., P, num_well). + Last axis enumerates the bounce integrals for a field line, + flux surface, and pitch. + + """ + errorif(x.ndim != 1 or x.shape != w.shape) + errorif(z1.ndim < 2 or z1.shape != z2.shape) + pitch_inv = jnp.atleast_1d(pitch_inv) + if not isinstance(f, (list, tuple)): + f = [f] if isinstance(f, (jnp.ndarray, np.ndarray)) else list(f) + + # Integrate and complete the change of variable. + if batch: + result = _interpolate_and_integrate( + w=w, + Q=bijection_from_disc(x, z1[..., jnp.newaxis], z2[..., jnp.newaxis]), + pitch_inv=pitch_inv, + integrand=integrand, + f=f, + data=data, + knots=knots, + method=method, + check=check, + plot=plot, + ) + else: + # TODO: Use batched vmap. + def loop(z): # over num well axis + z1, z2 = z + # Need to return tuple because input was tuple; artifact of JAX map. + return None, _interpolate_and_integrate( + w=w, + Q=bijection_from_disc(x, z1[..., jnp.newaxis], z2[..., jnp.newaxis]), + pitch_inv=pitch_inv, + integrand=integrand, + f=f, + data=data, + knots=knots, + method=method, + check=False, + plot=False, + batch=True, + ) + + result = jnp.moveaxis( + imap(loop, (jnp.moveaxis(z1, -1, 0), jnp.moveaxis(z2, -1, 0)))[1], + source=0, + destination=-1, + ) + + return result * grad_bijection_from_disc(z1, z2) + + +def _interpolate_and_integrate( + w, + Q, + pitch_inv, + integrand, + f, + data, + knots, + method, + check, + plot, + batch=False, +): + """Interpolate given functions to points ``Q`` and perform quadrature. + + Parameters + ---------- + w : jnp.ndarray + Shape (w.size, ). + Quadrature weights. + Q : jnp.ndarray + Shape (..., P, Q.shape[-2], w.size). + Quadrature points in ζ coordinates. + + Returns + ------- + result : jnp.ndarray + Shape Q.shape[:-1]. + Quadrature result. + + """ + assert w.ndim == 1 and Q.shape[-1] == w.size + assert Q.shape[-3 + batch] == pitch_inv.shape[-1] + assert data["|B|"].shape[-1] == knots.size + + shape = Q.shape + if not batch: + Q = flatten_matrix(Q) + b_sup_z = interp1d_Hermite_vec( + Q, + knots, + data["B^zeta"] / data["|B|"], + data["B^zeta_z|r,a"] / data["|B|"] + - data["B^zeta"] * data["|B|_z|r,a"] / data["|B|"] ** 2, + ) + B = interp1d_Hermite_vec(Q, knots, data["|B|"], data["|B|_z|r,a"]) + # Spline each function separately so that operations in the integrand + # that do not preserve smoothness can be captured. + f = [interp1d_vec(Q, knots, f_i[..., jnp.newaxis, :], method=method) for f_i in f] + result = ( + (integrand(*f, B=B, pitch=1 / pitch_inv[..., jnp.newaxis]) / b_sup_z) + .reshape(shape) + .dot(w) + ) + if check: + _check_interp(shape, Q, f, b_sup_z, B, result, plot) + + return result + + +def _check_interp(shape, Q, f, b_sup_z, B, result, plot): + """Check for interpolation failures and floating point issues. + + Parameters + ---------- + shape : tuple + (..., P, Q.shape[-2], w.size). + Q : jnp.ndarray + Quadrature points in ζ coordinates. + f : list[jnp.ndarray] + Arguments to the integrand, interpolated to Q. + b_sup_z : jnp.ndarray + Contravariant toroidal component of magnetic field, interpolated to Q. + B : jnp.ndarray + Norm of magnetic field, interpolated to Q. + result : jnp.ndarray + Output of ``_interpolate_and_integrate``. + plot : bool + Whether to plot stuff. + + """ + assert jnp.isfinite(Q).all(), "NaN interpolation point." + assert not ( + jnp.isclose(B, 0).any() or jnp.isclose(b_sup_z, 0).any() + ), "|B| has vanished, violating the hairy ball theorem." + + # Integrals that we should be computing. + marked = jnp.any(Q.reshape(shape) != 0.0, axis=-1) + goal = marked.sum() + + assert goal == (marked & jnp.isfinite(b_sup_z).reshape(shape).all(axis=-1)).sum() + assert goal == (marked & jnp.isfinite(B).reshape(shape).all(axis=-1)).sum() + for f_i in f: + assert goal == (marked & jnp.isfinite(f_i).reshape(shape).all(axis=-1)).sum() + + # Number of those integrals that were computed. + actual = (marked & jnp.isfinite(result)).sum() + assert goal == actual, ( + f"Lost {goal - actual} integrals from NaN generation in the integrand. This " + "is caused by floating point error, usually due to a poor quadrature choice." + ) + if plot: + Q = Q.reshape(shape) + _plot_check_interp(Q, B.reshape(shape), name=r"$\vert B \vert$") + _plot_check_interp( + Q, b_sup_z.reshape(shape), name=r"$ (B / \vert B \vert) \cdot e^{\zeta}$" + ) + + +def _plot_check_interp(Q, V, name=""): + """Plot V[..., λ, (ζ₁, ζ₂)](Q).""" + for idx in np.ndindex(Q.shape[:3]): + marked = jnp.nonzero(jnp.any(Q[idx] != 0.0, axis=-1))[0] + if marked.size == 0: + continue + fig, ax = plt.subplots() + ax.set_xlabel(r"$\zeta$") + ax.set_ylabel(name) + ax.set_title(f"Interpolation of {name} to quadrature points, (m,l,p)={idx}") + for i in marked: + ax.plot(Q[(*idx, i)], V[(*idx, i)], marker="o") + fig.text(0.01, 0.01, "Each color specifies a particular integral.") + plt.tight_layout() + plt.show() + + +def _get_extrema(knots, g, dg_dz, sentinel=jnp.nan): + """Return extrema (z*, g(z*)). + + Parameters + ---------- + knots : jnp.ndarray + Shape (N, ). + ζ coordinates of spline knots. Must be strictly increasing. + g : jnp.ndarray + Shape (..., N - 1, g.shape[-1]). + Polynomial coefficients of the spline of g in local power basis. + Last axis enumerates the coefficients of power series. Second to + last axis enumerates the polynomials that compose a particular spline. + dg_dz : jnp.ndarray + Shape (..., N - 1, g.shape[-1] - 1). + Polynomial coefficients of the spline of ∂g/∂z in local power basis. + Last axis enumerates the coefficients of power series. Second to + last axis enumerates the polynomials that compose a particular spline. + sentinel : float + Value with which to pad array to return fixed shape. + + Returns + ------- + ext, g_ext : jnp.ndarray + Shape (..., (N - 1) * (g.shape[-1] - 2)). + First array enumerates z*. Second array enumerates g(z*) + Sorting order of extrema is arbitrary. + + """ + g, dg_dz, _ = _check_spline_shape(knots, g, dg_dz) + ext = polyroot_vec( + c=dg_dz, a_min=jnp.array([0.0]), a_max=jnp.diff(knots), sentinel=sentinel + ) + g_ext = flatten_matrix(polyval_vec(x=ext, c=g[..., jnp.newaxis, :])) + # Transform out of local power basis expansion. + ext = flatten_matrix(ext + knots[:-1, jnp.newaxis]) + assert ext.shape == g_ext.shape and ext.shape[-1] == g.shape[-2] * (g.shape[-1] - 2) + return ext, g_ext + + +def _where_for_argmin(z1, z2, ext, g_ext, upper_sentinel): + return jnp.where( + (z1[..., jnp.newaxis] < ext[..., jnp.newaxis, jnp.newaxis, :]) + & (ext[..., jnp.newaxis, jnp.newaxis, :] < z2[..., jnp.newaxis]), + g_ext[..., jnp.newaxis, jnp.newaxis, :], + upper_sentinel, + ) + + +def interp_to_argmin( + h, z1, z2, knots, g, dg_dz, method="cubic", beta=-100, upper_sentinel=1e2 +): + """Interpolate ``h`` to the deepest point of ``g`` between ``z1`` and ``z2``. + + Let E = {ζ ∣ ζ₁ < ζ < ζ₂} and A = argmin_E g(ζ). Returns mean_A h(ζ). + + Parameters + ---------- + h : jnp.ndarray + Shape (..., N). + Values evaluated on ``knots`` to interpolate. + z1, z2 : jnp.ndarray + Shape (..., P, W). + Boundaries to detect argmin between. + knots : jnp.ndarray + Shape (N, ). + z coordinates of spline knots. Must be strictly increasing. + g : jnp.ndarray + Shape (..., N - 1, g.shape[-1]). + Polynomial coefficients of the spline of g in local power basis. + Last axis enumerates the coefficients of power series. Second to + last axis enumerates the polynomials that compose a particular spline. + dg_dz : jnp.ndarray + Shape (..., N - 1, g.shape[-1] - 1). + Polynomial coefficients of the spline of ∂g/∂z in local power basis. + Last axis enumerates the coefficients of power series. Second to + last axis enumerates the polynomials that compose a particular spline. + method : str + Method of interpolation. + See https://interpax.readthedocs.io/en/latest/_api/interpax.interp1d.html. + Default is cubic C1 local spline. + beta : float + More negative gives exponentially better approximation at the + expense of noisier gradients - noisier in the physics sense (unrelated + to the automatic differentiation). + upper_sentinel : float + Something larger than g. Choose value such that + exp(max(g)) << exp(``upper_sentinel``). Don't make too large or numerical + resolution is lost. + + Warnings + -------- + Recall that if g is small then the effect of β is reduced. + If the intention is to use this function as argmax, be sure to supply + a lower sentinel for ``upper_sentinel``. + + Returns + ------- + h : jnp.ndarray + Shape (..., P, W). + + """ + assert z1.ndim == z2.ndim >= 2 and z1.shape == z2.shape + ext, g_ext = _get_extrema(knots, g, dg_dz, sentinel=0) + # Our softargmax(x) does the proper shift to compute softargmax(x - max(x)), + # but it's still not a good idea to compute over a large length scale, so we + # warn in docstring to choose upper sentinel properly. + argmin = softargmax( + beta * _where_for_argmin(z1, z2, ext, g_ext, upper_sentinel), + axis=-1, + ) + h = jnp.linalg.vecdot( + argmin, + interp1d_vec(ext, knots, h, method=method)[..., jnp.newaxis, jnp.newaxis, :], + ) + assert h.shape == z1.shape + return h + + +def interp_to_argmin_hard(h, z1, z2, knots, g, dg_dz, method="cubic"): + """Interpolate ``h`` to the deepest point of ``g`` between ``z1`` and ``z2``. + + Let E = {ζ ∣ ζ₁ < ζ < ζ₂} and A ∈ argmin_E g(ζ). Returns h(A). + + See Also + -------- + interp_to_argmin + Accomplishes the same task, but handles the case of non-unique global minima + more correctly. It is also more efficient if P >> 1. + + Parameters + ---------- + h : jnp.ndarray + Shape (..., N). + Values evaluated on ``knots`` to interpolate. + z1, z2 : jnp.ndarray + Shape (..., P, W). + Boundaries to detect argmin between. + knots : jnp.ndarray + Shape (N, ). + z coordinates of spline knots. Must be strictly increasing. + g : jnp.ndarray + Shape (..., N - 1, g.shape[-1]). + Polynomial coefficients of the spline of g in local power basis. + Last axis enumerates the coefficients of power series. Second to + last axis enumerates the polynomials that compose a particular spline. + dg_dz : jnp.ndarray + Shape (..., N - 1, g.shape[-1] - 1). + Polynomial coefficients of the spline of ∂g/∂z in local power basis. + Last axis enumerates the coefficients of power series. Second to + last axis enumerates the polynomials that compose a particular spline. + method : str + Method of interpolation. + See https://interpax.readthedocs.io/en/latest/_api/interpax.interp1d.html. + Default is cubic C1 local spline. + + Returns + ------- + h : jnp.ndarray + Shape (..., P, W). + + """ + assert z1.ndim == z2.ndim >= 2 and z1.shape == z2.shape + ext, g_ext = _get_extrema(knots, g, dg_dz, sentinel=0) + # We can use the non-differentiable max because we actually want the gradients + # to accumulate through only the minimum since we are differentiating how our + # physics objective changes wrt equilibrium perturbations not wrt which of the + # extrema get interpolated to. + argmin = jnp.argmin( + _where_for_argmin(z1, z2, ext, g_ext, jnp.max(g_ext) + 1), + axis=-1, + ) + h = interp1d_vec( + jnp.take_along_axis(ext[jnp.newaxis], argmin, axis=-1), + knots, + h[..., jnp.newaxis, :], + method=method, + ) + assert h.shape == z1.shape, h.shape + return h + + +def plot_ppoly( + ppoly, + num=1000, + z1=None, + z2=None, + k=None, + k_transparency=0.5, + klabel=r"$k$", + title=r"Intersects $z$ in epigraph($f$) s.t. $f(z) = k$", + hlabel=r"$z$", + vlabel=r"$f$", + show=True, + start=None, + stop=None, + include_knots=False, + knot_transparency=0.2, + include_legend=True, +): + """Plot the piecewise polynomial ``ppoly``. + + Parameters + ---------- + ppoly : PPoly + Piecewise polynomial f. + num : int + Number of points to evaluate for plot. + z1 : jnp.ndarray + Shape (k.shape[0], W). + Optional, intersects with ∂f/∂z <= 0. + z2 : jnp.ndarray + Shape (k.shape[0], W). + Optional, intersects with ∂f/∂z >= 0. + k : jnp.ndarray + Shape (k.shape[0], ). + Optional, k such that f(z) = k. + k_transparency : float + Transparency of intersect lines. + klabel : str + Label of intersect lines. + title : str + Plot title. + hlabel : str + Horizontal axis label. + vlabel : str + Vertical axis label. + show : bool + Whether to show the plot. Default is true. + start : float + Minimum z on plot. + stop : float + Maximum z on plot. + include_knots : bool + Whether to plot vertical lines at the knots. + knot_transparency : float + Transparency of knot lines. + include_legend : bool + Whether to include the legend in the plot. Default is true. + + Returns + ------- + fig, ax + Matplotlib (fig, ax) tuple. + + """ + fig, ax = plt.subplots() + legend = {} + if include_knots: + for knot in ppoly.x: + _add2legend( + legend, + ax.axvline( + x=knot, color="tab:blue", alpha=knot_transparency, label="knot" + ), + ) + + z = jnp.linspace( + start=setdefault(start, ppoly.x[0]), + stop=setdefault(stop, ppoly.x[-1]), + num=num, + ) + _add2legend(legend, ax.plot(z, ppoly(z), label=vlabel)) + _plot_intersect( + ax=ax, + legend=legend, + z1=z1, + z2=z2, + k=k, + k_transparency=k_transparency, + klabel=klabel, + ) + ax.set_xlabel(hlabel) + ax.set_ylabel(vlabel) + if include_legend: + ax.legend(legend.values(), legend.keys(), loc="lower right") + ax.set_title(title) + plt.tight_layout() + if show: + plt.show() + plt.close() + return fig, ax diff --git a/desc/integrals/interp_utils.py b/desc/integrals/interp_utils.py new file mode 100644 index 0000000000..4943be509c --- /dev/null +++ b/desc/integrals/interp_utils.py @@ -0,0 +1,292 @@ +"""Fast interpolation utilities. + +Notes +----- +These polynomial utilities are chosen for performance on gpu among +methods that have the best (asymptotic) algorithmic complexity. +For example, we prefer to not use Horner's method. +""" + +from functools import partial + +from interpax import interp1d + +from desc.backend import jnp +from desc.compute.utils import safediv + +# Warning: method must be specified as keyword argument. +interp1d_vec = jnp.vectorize( + interp1d, signature="(m),(n),(n)->(m)", excluded={"method"} +) + + +@partial(jnp.vectorize, signature="(m),(n),(n),(n)->(m)") +def interp1d_Hermite_vec(xq, x, f, fx, /): + """Vectorized cubic Hermite spline.""" + return interp1d(xq, x, f, method="cubic", fx=fx) + + +def polyder_vec(c): + """Coefficients for the derivatives of the given set of polynomials. + + Parameters + ---------- + c : jnp.ndarray + Last axis should store coefficients of a polynomial. For a polynomial given by + ∑ᵢⁿ cᵢ xⁱ, where n is ``c.shape[-1]-1``, coefficient cᵢ should be stored at + ``c[...,n-i]``. + + Returns + ------- + poly : jnp.ndarray + Coefficients of polynomial derivative, ignoring the arbitrary constant. That is, + ``poly[...,i]`` stores the coefficient of the monomial xⁿ⁻ⁱ⁻¹, where n is + ``c.shape[-1]-1``. + + """ + return c[..., :-1] * jnp.arange(c.shape[-1] - 1, 0, -1) + + +def polyval_vec(*, x, c): + """Evaluate the set of polynomials ``c`` at the points ``x``. + + Parameters + ---------- + x : jnp.ndarray + Coordinates at which to evaluate the set of polynomials. + c : jnp.ndarray + Last axis should store coefficients of a polynomial. For a polynomial given by + ∑ᵢⁿ cᵢ xⁱ, where n is ``c.shape[-1]-1``, coefficient cᵢ should be stored at + ``c[...,n-i]``. + + Returns + ------- + val : jnp.ndarray + Polynomial with given coefficients evaluated at given points. + + Examples + -------- + .. code-block:: python + + np.testing.assert_allclose( + polyval_vec(x=x, c=c), + np.sum(polyvander(x, c.shape[-1] - 1) * c[..., ::-1], axis=-1), + ) + + """ + # Better than Horner's method as we expect to evaluate low order polynomials. + # No need to use fast multipoint evaluation techniques for the same reason. + return jnp.sum( + c * x[..., jnp.newaxis] ** jnp.arange(c.shape[-1] - 1, -1, -1), + axis=-1, + ) + + +# TODO: Eventually do a PR to move this stuff into interpax. + + +def _subtract_last(c, k): + """Subtract ``k`` from last index of last axis of ``c``. + + Semantically same as ``return c.copy().at[...,-1].add(-k)``, + but allows dimension to increase. + """ + c_1 = c[..., -1] - k + c = jnp.concatenate( + [ + jnp.broadcast_to(c[..., :-1], (*c_1.shape, c.shape[-1] - 1)), + c_1[..., jnp.newaxis], + ], + axis=-1, + ) + return c + + +def _filter_distinct(r, sentinel, eps): + """Set all but one of matching adjacent elements in ``r`` to ``sentinel``.""" + # eps needs to be low enough that close distinct roots do not get removed. + # Otherwise, algorithms relying on continuity will fail. + mask = jnp.isclose(jnp.diff(r, axis=-1, prepend=sentinel), 0, atol=eps) + r = jnp.where(mask, sentinel, r) + return r + + +_roots = jnp.vectorize(partial(jnp.roots, strip_zeros=False), signature="(m)->(n)") + + +def polyroot_vec( + c, + k=0, + a_min=None, + a_max=None, + sort=False, + sentinel=jnp.nan, + eps=max(jnp.finfo(jnp.array(1.0).dtype).eps, 2.5e-12), + distinct=False, +): + """Roots of polynomial with given coefficients. + + Parameters + ---------- + c : jnp.ndarray + Last axis should store coefficients of a polynomial. For a polynomial given by + ∑ᵢⁿ cᵢ xⁱ, where n is ``c.shape[-1]-1``, coefficient cᵢ should be stored at + ``c[...,n-i]``. + k : jnp.ndarray + Shape (..., *c.shape[:-1]). + Specify to find solutions to ∑ᵢⁿ cᵢ xⁱ = ``k``. + a_min : jnp.ndarray + Shape (..., *c.shape[:-1]). + Minimum ``a_min`` and maximum ``a_max`` value to return roots between. + If specified only real roots are returned, otherwise returns all complex roots. + a_max : jnp.ndarray + Shape (..., *c.shape[:-1]). + Minimum ``a_min`` and maximum ``a_max`` value to return roots between. + If specified only real roots are returned, otherwise returns all complex roots. + sort : bool + Whether to sort the roots. + sentinel : float + Value with which to pad array in place of filtered elements. + Anything less than ``a_min`` or greater than ``a_max`` plus some floating point + error buffer will work just like nan while avoiding ``nan`` gradient. + eps : float + Absolute tolerance with which to consider value as zero. + distinct : bool + Whether to only return the distinct roots. If true, when the multiplicity is + greater than one, the repeated roots are set to ``sentinel``. + + Returns + ------- + r : jnp.ndarray + Shape (..., *c.shape[:-1], c.shape[-1] - 1). + The roots of the polynomial, iterated over the last axis. + + """ + get_only_real_roots = not (a_min is None and a_max is None) + num_coef = c.shape[-1] + c = _subtract_last(c, k) + func = {2: _root_linear, 3: _root_quadratic, 4: _root_cubic} + + if ( + num_coef in func + and get_only_real_roots + and not (jnp.iscomplexobj(c) or jnp.iscomplexobj(k)) + ): + # Compute from analytic formula to avoid the issue of complex roots with small + # imaginary parts and to avoid nan in gradient. + r = func[num_coef](C=c, sentinel=sentinel, eps=eps, distinct=distinct) + # We already filtered distinct roots for quadratics. + distinct = distinct and num_coef > 3 + else: + # Compute from eigenvalues of polynomial companion matrix. + r = _roots(c) + + if get_only_real_roots: + a_min = -jnp.inf if a_min is None else a_min[..., jnp.newaxis] + a_max = +jnp.inf if a_max is None else a_max[..., jnp.newaxis] + r = jnp.where( + (jnp.abs(r.imag) <= eps) & (a_min <= r.real) & (r.real <= a_max), + r.real, + sentinel, + ) + + if sort or distinct: + r = jnp.sort(r, axis=-1) + r = _filter_distinct(r, sentinel, eps) if distinct else r + assert r.shape[-1] == num_coef - 1 + return r + + +def _root_cubic(C, sentinel, eps, distinct): + """Return real cubic root assuming real coefficients.""" + # numerical.recipes/book.html, page 228 + + def irreducible(Q, R, b, mask): + # Three irrational real roots. + theta = jnp.arccos(R / jnp.sqrt(jnp.where(mask, Q**3, R**2 + 1))) + return jnp.moveaxis( + -2 + * jnp.sqrt(Q) + * jnp.stack( + [ + jnp.cos(theta / 3), + jnp.cos((theta + 2 * jnp.pi) / 3), + jnp.cos((theta - 2 * jnp.pi) / 3), + ] + ) + - b / 3, + source=0, + destination=-1, + ) + + def reducible(Q, R, b): + # One real and two complex roots. + A = -jnp.sign(R) * (jnp.abs(R) + jnp.sqrt(jnp.abs(R**2 - Q**3))) ** (1 / 3) + B = safediv(Q, A) + r1 = (A + B) - b / 3 + return _concat_sentinel(r1[..., jnp.newaxis], sentinel, num=2) + + def root(b, c, d): + b = safediv(b, a) + c = safediv(c, a) + d = safediv(d, a) + Q = (b**2 - 3 * c) / 9 + R = (2 * b**3 - 9 * b * c + 27 * d) / 54 + mask = R**2 < Q**3 + return jnp.where( + mask[..., jnp.newaxis], + irreducible(jnp.abs(Q), R, b, mask), + reducible(Q, R, b), + ) + + a = C[..., 0] + b = C[..., 1] + c = C[..., 2] + d = C[..., 3] + return jnp.where( + # Tests catch failure here if eps < 1e-12 for 64 bit precision. + jnp.expand_dims(jnp.abs(a) <= eps, axis=-1), + _concat_sentinel( + _root_quadratic( + C=C[..., 1:], sentinel=sentinel, eps=eps, distinct=distinct + ), + sentinel, + ), + root(b, c, d), + ) + + +def _root_quadratic(C, sentinel, eps, distinct): + """Return real quadratic root assuming real coefficients.""" + # numerical.recipes/book.html, page 227 + a = C[..., 0] + b = C[..., 1] + c = C[..., 2] + + discriminant = b**2 - 4 * a * c + q = -0.5 * (b + jnp.sign(b) * jnp.sqrt(jnp.abs(discriminant))) + r1 = jnp.where( + discriminant < 0, + sentinel, + safediv(q, a, _root_linear(C=C[..., 1:], sentinel=sentinel, eps=eps)), + ) + r2 = jnp.where( + # more robust to remove repeated roots with discriminant + (discriminant < 0) | (distinct & (discriminant <= eps)), + sentinel, + safediv(c, q, sentinel), + ) + return jnp.stack([r1, r2], axis=-1) + + +def _root_linear(C, sentinel, eps, distinct=False): + """Return real linear root assuming real coefficients.""" + a = C[..., 0] + b = C[..., 1] + return safediv(-b, a, jnp.where(jnp.abs(b) <= eps, 0, sentinel)) + + +def _concat_sentinel(r, sentinel, num=1): + """Append ``sentinel`` ``num`` times to ``r`` on last axis.""" + sent = jnp.broadcast_to(sentinel, (*r.shape[:-1], num)) + return jnp.append(r, sent, axis=-1) diff --git a/desc/integrals/quad_utils.py b/desc/integrals/quad_utils.py new file mode 100644 index 0000000000..692149e84e --- /dev/null +++ b/desc/integrals/quad_utils.py @@ -0,0 +1,246 @@ +"""Utilities for quadratures.""" + +from orthax.legendre import legder, legval + +from desc.backend import eigh_tridiagonal, jnp, put +from desc.utils import errorif + + +def bijection_to_disc(x, a, b): + """[a, b] ∋ x ↦ y ∈ [−1, 1].""" + y = 2.0 * (x - a) / (b - a) - 1.0 + return y + + +def bijection_from_disc(x, a, b): + """[−1, 1] ∋ x ↦ y ∈ [a, b].""" + y = 0.5 * (b - a) * (x + 1.0) + a + return y + + +def grad_bijection_from_disc(a, b): + """Gradient wrt ``x`` of ``bijection_from_disc``.""" + dy_dx = 0.5 * (b - a) + return dy_dx + + +def automorphism_arcsin(x): + """[-1, 1] ∋ x ↦ y ∈ [−1, 1]. + + The arcsin transformation introduces a singularity that augments the singularity + in the bounce integral, so the quadrature scheme used to evaluate the integral must + work well on functions with large derivative near the boundary. + + Parameters + ---------- + x : jnp.ndarray + Points to transform. + + Returns + ------- + y : jnp.ndarray + Transformed points. + + """ + y = 2.0 * jnp.arcsin(x) / jnp.pi + return y + + +def grad_automorphism_arcsin(x): + """Gradient of arcsin automorphism.""" + dy_dx = 2.0 / (jnp.sqrt(1.0 - x**2) * jnp.pi) + return dy_dx + + +grad_automorphism_arcsin.__doc__ += "\n" + automorphism_arcsin.__doc__ + + +def automorphism_sin(x, s=0, m=10): + """[-1, 1] ∋ x ↦ y ∈ [−1, 1]. + + When used as the change of variable map for the bounce integral, the Lipschitzness + of the sin transformation prevents generation of new singularities. Furthermore, + its derivative vanishes to zero slowly near the boundary, which will suppress the + large derivatives near the boundary of singular integrals. + + In effect, this map pulls the mass of the integral away from the singularities, + which should improve convergence if the quadrature performs better on less singular + integrands. Pairs well with Gauss-Legendre quadrature. + + Parameters + ---------- + x : jnp.ndarray + Points to transform. + s : float + Strength of derivative suppression, s ∈ [0, 1]. + m : float + Number of machine epsilons used for floating point error buffer. + + Returns + ------- + y : jnp.ndarray + Transformed points. + + """ + errorif(not (0 <= s <= 1)) + # s = 0 -> derivative vanishes like cosine. + # s = 1 -> derivative vanishes like cosine^k. + y0 = jnp.sin(0.5 * jnp.pi * x) + y1 = x + jnp.sin(jnp.pi * x) / jnp.pi # k = 2 + y = (1 - s) * y0 + s * y1 + # y is an expansion, so y(x) > x near x ∈ {−1, 1} and there is a tendency + # for floating point error to overshoot the true value. + eps = m * jnp.finfo(jnp.array(1.0).dtype).eps + return jnp.clip(y, -1 + eps, 1 - eps) + + +def grad_automorphism_sin(x, s=0): + """Gradient of sin automorphism.""" + dy0_dx = 0.5 * jnp.pi * jnp.cos(0.5 * jnp.pi * x) + dy1_dx = 1.0 + jnp.cos(jnp.pi * x) + dy_dx = (1 - s) * dy0_dx + s * dy1_dx + return dy_dx + + +grad_automorphism_sin.__doc__ += "\n" + automorphism_sin.__doc__ + + +def tanh_sinh(deg, m=10): + """Tanh-Sinh quadrature. + + Returns quadrature points xₖ and weights wₖ for the approximate evaluation of the + integral ∫₋₁¹ f(x) dx ≈ ∑ₖ wₖ f(xₖ). + + Parameters + ---------- + deg : int + Number of quadrature points. + m : float + Number of machine epsilons used for floating point error buffer. Larger implies + less floating point error, but increases the minimum achievable error. + + Returns + ------- + x, w : (jnp.ndarray, jnp.ndarray) + Shape (deg, ). + Quadrature points and weights. + + """ + # buffer to avoid numerical instability + x_max = jnp.array(1.0) + x_max = x_max - m * jnp.finfo(x_max.dtype).eps + t_max = jnp.arcsinh(2 * jnp.arctanh(x_max) / jnp.pi) + # maximal-spacing scheme, doi.org/10.48550/arXiv.2007.15057 + t = jnp.linspace(-t_max, t_max, deg) + dt = 2 * t_max / (deg - 1) + arg = 0.5 * jnp.pi * jnp.sinh(t) + x = jnp.tanh(arg) # x = g(t) + w = 0.5 * jnp.pi * jnp.cosh(t) / jnp.cosh(arg) ** 2 * dt # w = (dg/dt) dt + return x, w + + +def leggauss_lob(deg, interior_only=False): + """Lobatto-Gauss-Legendre quadrature. + + Returns quadrature points xₖ and weights wₖ for the approximate evaluation of the + integral ∫₋₁¹ f(x) dx ≈ ∑ₖ wₖ f(xₖ). + + Parameters + ---------- + deg : int + Number of quadrature points. + interior_only : bool + Whether to exclude the points and weights at -1 and 1; + useful if f(-1) = f(1) = 0. If true, then ``deg`` points are still + returned; these are the interior points for lobatto quadrature of ``deg+2``. + + Returns + ------- + x, w : (jnp.ndarray, jnp.ndarray) + Shape (deg, ). + Quadrature points and weights. + + """ + N = deg + 2 * bool(interior_only) + errorif(N < 2) + + # Golub-Welsh algorithm + n = jnp.arange(2, N - 1) + x = eigh_tridiagonal( + jnp.zeros(N - 2), + jnp.sqrt((n**2 - 1) / (4 * n**2 - 1)), + eigvals_only=True, + ) + c0 = put(jnp.zeros(N), -1, 1) + + # improve (single multiplicity) roots by one application of Newton + c = legder(c0) + dy = legval(x=x, c=c) + df = legval(x=x, c=legder(c)) + x -= dy / df + + w = 2 / (N * (N - 1) * legval(x=x, c=c0) ** 2) + + if not interior_only: + x = jnp.hstack([-1.0, x, 1.0]) + w_end = 2 / (deg * (deg - 1)) + w = jnp.hstack([w_end, w, w_end]) + + assert x.size == w.size == deg + return x, w + + +def get_quadrature(quad, automorphism): + """Apply automorphism to given quadrature. + + Parameters + ---------- + quad : (jnp.ndarray, jnp.ndarray) + Quadrature points xₖ and weights wₖ for the approximate evaluation of an + integral ∫₋₁¹ g(x) dx = ∑ₖ wₖ g(xₖ). + automorphism : (Callable, Callable) or None + The first callable should be an automorphism of the real interval [-1, 1]. + The second callable should be the derivative of the first. This map defines + a change of variable for the bounce integral. The choice made for the + automorphism will affect the performance of the quadrature method. + + Returns + ------- + x, w : (jnp.ndarray, jnp.ndarray) + Quadrature points and weights. + + """ + x, w = quad + assert x.ndim == w.ndim == 1 + if automorphism is not None: + auto, grad_auto = automorphism + w = w * grad_auto(x) + # Recall bijection_from_disc(auto(x), ζ₁, ζ₂) = ζ. + x = auto(x) + return x, w + + +def composite_linspace(x, num): + """Returns linearly spaced values between every pair of values in ``x``. + + Parameters + ---------- + x : jnp.ndarray + First axis has values to return linearly spaced values between. The remaining + axes are batch axes. Assumes input is sorted along first axis. + num : int + Number of values between every pair of values in ``x``. + + Returns + ------- + vals : jnp.ndarray + Shape ((x.shape[0] - 1) * num + x.shape[0], *x.shape[1:]). + Linearly spaced values between ``x``. + + """ + x = jnp.atleast_1d(x) + vals = jnp.linspace(x[:-1], x[1:], num + 1, endpoint=False) + vals = jnp.swapaxes(vals, 0, 1).reshape(-1, *x.shape[1:]) + vals = jnp.append(vals, x[jnp.newaxis, -1], axis=0) + assert vals.shape == ((x.shape[0] - 1) * num + x.shape[0], *x.shape[1:]) + return vals diff --git a/desc/integrals/surface_integral.py b/desc/integrals/surface_integral.py index acc1e6c1b9..944a711904 100644 --- a/desc/integrals/surface_integral.py +++ b/desc/integrals/surface_integral.py @@ -100,7 +100,7 @@ def line_integrals( The coordinate curve to compute the integration over. To clarify, a theta (poloidal) curve is the intersection of a rho surface (flux surface) and zeta (toroidal) surface. - fix_surface : str, float + fix_surface : (str, float) A tuple of the form: label, value. ``fix_surface`` label should differ from ``line_label``. By default, ``fix_surface`` is chosen to be the flux surface at rho=1. diff --git a/desc/io/optimizable_io.py b/desc/io/optimizable_io.py index 554cdac070..e15a21756e 100644 --- a/desc/io/optimizable_io.py +++ b/desc/io/optimizable_io.py @@ -169,16 +169,17 @@ class IOAble(ABC, metaclass=_CombinedMeta): """Abstract Base Class for savable and loadable objects. Objects inheriting from this class can be saved and loaded via hdf5 or pickle. - To save properly, each object should have an attribute `_io_attrs_` which + To save properly, each object should have an attribute ``_io_attrs_`` which is a list of strings of the object attributes or properties that should be saved and loaded. - For saved objects to be loaded correctly, the __init__ method of any custom - types being saved should only assign attributes that are listed in `_io_attrs_`. + For saved objects to be loaded correctly, the ``__init__`` method of any custom + types being saved should only assign attributes that are listed in ``_io_attrs_``. Other attributes or other initialization should be done in a separate - `set_up()` method that can be called during __init__. The loading process - will involve creating an empty object, bypassing init, then setting any `_io_attrs_` - of the object, then calling `_set_up()` without any arguments, if it exists. + ``set_up()`` method that can be called during ``__init__``. The loading process + will involve creating an empty object, bypassing init, then setting any + ``_io_attrs_`` of the object, then calling ``_set_up()`` without any arguments, + if it exists. """ diff --git a/desc/utils.py b/desc/utils.py index 44f744dcb6..41b32677ea 100644 --- a/desc/utils.py +++ b/desc/utils.py @@ -2,13 +2,14 @@ import operator import warnings +from functools import partial from itertools import combinations_with_replacement, permutations import numpy as np from scipy.special import factorial from termcolor import colored -from desc.backend import fori_loop, jit, jnp +from desc.backend import flatnonzero, fori_loop, jit, jnp, take class Timer: @@ -184,6 +185,13 @@ class _Indexable: def __getitem__(self, index): return index + @staticmethod + def get(stuff, axis, ndim): + slices = [slice(None)] * ndim + slices[axis] = stuff + slices = tuple(slices) + return slices + """ Helper object for building indexes for indexed update functions. @@ -684,4 +692,54 @@ def broadcast_tree(tree_in, tree_out, dtype=int): raise ValueError("trees must be nested lists of dicts") +@partial(jnp.vectorize, signature="(m),(m)->(n)", excluded={"size", "fill_value"}) +def take_mask(a, mask, /, *, size=None, fill_value=None): + """JIT compilable method to return ``a[mask][:size]`` padded by ``fill_value``. + + Parameters + ---------- + a : jnp.ndarray + The source array. + mask : jnp.ndarray + Boolean mask to index into ``a``. Should have same shape as ``a``. + size : int + Elements of ``a`` at the first size True indices of ``mask`` will be returned. + If there are fewer elements than size indicates, the returned array will be + padded with ``fill_value``. The size default is ``mask.size``. + fill_value : Any + When there are fewer than the indicated number of elements, the remaining + elements will be filled with ``fill_value``. Defaults to NaN for inexact types, + the largest negative value for signed types, the largest positive value for + unsigned types, and True for booleans. + + Returns + ------- + result : jnp.ndarray + Shape (size, ). + + """ + assert a.shape == mask.shape + idx = flatnonzero(mask, size=setdefault(size, mask.size), fill_value=mask.size) + return take( + a, + idx, + mode="fill", + fill_value=fill_value, + unique_indices=True, + indices_are_sorted=True, + ) + + +def flatten_matrix(y): + """Flatten matrix to vector.""" + return y.reshape(*y.shape[:-2], -1) + + +# TODO: Eventually remove and use numpy's stuff. +# https://github.com/numpy/numpy/issues/25805 +def atleast_nd(ndmin, ary): + """Adds dimensions to front if necessary.""" + return jnp.array(ary, ndmin=ndmin) if jnp.ndim(ary) < ndmin else ary + + PRINT_WIDTH = 60 # current longest name is BootstrapRedlConsistency with pre-text diff --git a/devtools/dev-requirements_conda.yml b/devtools/dev-requirements_conda.yml index 5f5076a57e..5aa77689dd 100644 --- a/devtools/dev-requirements_conda.yml +++ b/devtools/dev-requirements_conda.yml @@ -15,9 +15,10 @@ dependencies: - pip: # Conda only parses a single list of pip requirements. # If two pip lists are given, all but the last list is skipped. - - interpax + - interpax >= 0.3.3 - jax[cpu] >= 0.3.2, < 0.5.0 - nvgpu + - orthax - plotly >= 5.16, < 6.0 - pylatexenc >= 2.0, < 3.0 # building the docs diff --git a/requirements.txt b/requirements.txt index a667a2a2db..fa5b86bba9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,13 @@ colorama h5py >= 3.0.0, < 4.0 -interpax +interpax >= 0.3.3 jax[cpu] >= 0.3.2, < 0.5.0 matplotlib >= 3.5.0, < 4.0.0 mpmath >= 1.0.0, < 2.0 netcdf4 >= 1.5.4, < 2.0 numpy >= 1.20.0, < 2.0.0 nvgpu +orthax plotly >= 5.16, < 6.0 psutil pylatexenc >= 2.0, < 3.0 diff --git a/requirements_conda.yml b/requirements_conda.yml index a151388648..da2996429a 100644 --- a/requirements_conda.yml +++ b/requirements_conda.yml @@ -14,8 +14,9 @@ dependencies: - pip: # Conda only parses a single list of pip requirements. # If two pip lists are given, all but the last list is skipped. - - interpax + - interpax >= 0.3.3 - jax[cpu] >= 0.3.2, < 0.5.0 - nvgpu + - orthax - plotly >= 5.16, < 6.0 - pylatexenc >= 2.0, < 3.0 diff --git a/tests/baseline/test_binormal_drift_bounce1d.png b/tests/baseline/test_binormal_drift_bounce1d.png new file mode 100644 index 0000000000..95339623df Binary files /dev/null and b/tests/baseline/test_binormal_drift_bounce1d.png differ diff --git a/tests/baseline/test_bounce1d_checks.png b/tests/baseline/test_bounce1d_checks.png new file mode 100644 index 0000000000..51e5a4d94f Binary files /dev/null and b/tests/baseline/test_bounce1d_checks.png differ diff --git a/tests/inputs/low-beta-shifted-circle.h5 b/tests/inputs/low-beta-shifted-circle.h5 index 31f4fab80b..dd75392a09 100644 Binary files a/tests/inputs/low-beta-shifted-circle.h5 and b/tests/inputs/low-beta-shifted-circle.h5 differ diff --git a/tests/test_axis_limits.py b/tests/test_axis_limits.py index 8c847ef3a0..fc3eebeb5d 100644 --- a/tests/test_axis_limits.py +++ b/tests/test_axis_limits.py @@ -63,7 +63,6 @@ "gbdrift", "cvdrift", "grad(alpha)", - "cvdrift0", "|e^helical|", "|grad(theta)|", " Redl", # may not exist for all configurations @@ -94,7 +93,6 @@ "K_vc", # only defined on surface "iota_num_rrr", "iota_den_rrr", - "cvdrift0", } @@ -135,6 +133,14 @@ def _skip_this(eq, name): or (eq.anisotropy is None and "beta_a" in name) or (eq.pressure is not None and " Redl" in name) or (eq.current is None and "iota_num" in name) + # These quantities require a coordinate mapping to compute and special grids, so + # it's not economical to test their axis limits here. Instead, a grid that + # includes the axis should be used in existing unit tests for these quantities. + or bool( + data_index["desc.equilibrium.equilibrium.Equilibrium"][name][ + "source_grid_requirement" + ] + ) ) @@ -388,3 +394,4 @@ def test_reverse_mode_ad_axis(name): obj.build(verbose=0) g = obj.grad(obj.x()) assert not np.any(np.isnan(g)) + print(np.count_nonzero(g), name) diff --git a/tests/test_grid.py b/tests/test_grid.py index 051ba1b89f..929a1bbe57 100644 --- a/tests/test_grid.py +++ b/tests/test_grid.py @@ -793,26 +793,23 @@ def test_meshgrid_reshape(self): zeta = np.linspace(0, 6 * np.pi, 5) grid = Grid.create_meshgrid([rho, alpha, zeta], coordinates="raz") r, a, z = grid.nodes.T - r = grid.meshgrid_reshape(r, "raz") - a = grid.meshgrid_reshape(a, "raz") - z = grid.meshgrid_reshape(z, "raz") # functions of zeta should separate along first two axes # since those are contiguous, this should work - f = z.reshape(-1, zeta.size) + f = grid.meshgrid_reshape(z, "raz").reshape(-1, zeta.size) for i in range(1, f.shape[0]): np.testing.assert_allclose(f[i - 1], f[i]) # likewise for rho - f = r.reshape(rho.size, -1) + f = grid.meshgrid_reshape(r, "raz").reshape(rho.size, -1) for i in range(1, f.shape[-1]): np.testing.assert_allclose(f[:, i - 1], f[:, i]) # test reshaping result won't mix data - f = (a**2 + z).reshape(rho.size, alpha.size, zeta.size) + f = grid.meshgrid_reshape(a**2 + z, "raz") for i in range(1, f.shape[0]): np.testing.assert_allclose(f[i - 1], f[i]) - f = (r**2 + z).reshape(rho.size, alpha.size, zeta.size) + f = grid.meshgrid_reshape(r**2 + z, "raz") for i in range(1, f.shape[1]): np.testing.assert_allclose(f[:, i - 1], f[:, i]) - f = (r**2 + a).reshape(rho.size, alpha.size, zeta.size) + f = grid.meshgrid_reshape(r**2 + a, "raz") for i in range(1, f.shape[-1]): np.testing.assert_allclose(f[..., i - 1], f[..., i]) diff --git a/tests/test_integrals.py b/tests/test_integrals.py index b15b019283..26798f3fbc 100644 --- a/tests/test_integrals.py +++ b/tests/test_integrals.py @@ -1,13 +1,27 @@ """Test integration algorithms.""" +from functools import partial + import numpy as np import pytest - +from jax import grad +from matplotlib import pyplot as plt +from numpy.polynomial.chebyshev import chebgauss, chebweight +from numpy.polynomial.legendre import leggauss +from scipy import integrate +from scipy.interpolate import CubicHermiteSpline +from scipy.special import ellipe, ellipkm1, roots_chebyu +from tests.test_plotting import tol_1d + +from desc.backend import jnp from desc.basis import FourierZernikeBasis +from desc.compute.utils import dot, safediv from desc.equilibrium import Equilibrium +from desc.equilibrium.coords import get_rtz_grid from desc.examples import get -from desc.grid import ConcentricGrid, LinearGrid, QuadratureGrid +from desc.grid import ConcentricGrid, Grid, LinearGrid, QuadratureGrid from desc.integrals import ( + Bounce1D, DFTInterpolator, FFTInterpolator, line_integrals, @@ -20,6 +34,22 @@ surface_variance, virtual_casing_biot_savart, ) +from desc.integrals.bounce_utils import ( + _get_extrema, + bounce_points, + get_pitch_inv, + interp_to_argmin, + interp_to_argmin_hard, +) +from desc.integrals.quad_utils import ( + automorphism_sin, + bijection_from_disc, + get_quadrature, + grad_automorphism_sin, + grad_bijection_from_disc, + leggauss_lob, + tanh_sinh, +) from desc.integrals.singularities import _get_quadrature_nodes from desc.integrals.surface_integral import _get_grid_surface from desc.transform import Transform @@ -688,3 +718,746 @@ def test_biest_interpolators(self): g2 = interp2(f(source_theta, source_zeta), i) np.testing.assert_allclose(g1, g2) np.testing.assert_allclose(g1, ff) + + +class TestBounce1DPoints: + """Test that bounce points are computed correctly.""" + + @staticmethod + def filter(z1, z2): + """Remove bounce points whose integrals have zero measure.""" + mask = (z1 - z2) != 0.0 + return z1[mask], z2[mask] + + @pytest.mark.unit + def test_z1_first(self): + """Case where straight line through first two intersects is in epigraph.""" + start = np.pi / 3 + end = 6 * np.pi + knots = np.linspace(start, end, 5) + B = CubicHermiteSpline(knots, np.cos(knots), -np.sin(knots)) + pitch_inv = 0.5 + intersect = B.solve(pitch_inv, extrapolate=False) + z1, z2 = bounce_points( + pitch_inv, knots, B.c.T, B.derivative().c.T, check=True, include_knots=True + ) + z1, z2 = TestBounce1DPoints.filter(z1, z2) + assert z1.size and z2.size + np.testing.assert_allclose(z1, intersect[0::2]) + np.testing.assert_allclose(z2, intersect[1::2]) + + @pytest.mark.unit + def test_z2_first(self): + """Case where straight line through first two intersects is in hypograph.""" + start = -3 * np.pi + end = -start + k = np.linspace(start, end, 5) + B = CubicHermiteSpline(k, np.cos(k), -np.sin(k)) + pitch_inv = 0.5 + intersect = B.solve(pitch_inv, extrapolate=False) + z1, z2 = bounce_points( + pitch_inv, k, B.c.T, B.derivative().c.T, check=True, include_knots=True + ) + z1, z2 = TestBounce1DPoints.filter(z1, z2) + assert z1.size and z2.size + np.testing.assert_allclose(z1, intersect[1:-1:2]) + np.testing.assert_allclose(z2, intersect[0::2][1:]) + + @pytest.mark.unit + def test_z1_before_extrema(self): + """Case where local maximum is the shared intersect between two wells.""" + # To make sure both regions in epigraph left and right of extrema are + # integrated over. + start = -np.pi + end = -2 * start + k = np.linspace(start, end, 5) + B = CubicHermiteSpline( + k, np.cos(k) + 2 * np.sin(-2 * k), -np.sin(k) - 4 * np.cos(-2 * k) + ) + dB_dz = B.derivative() + pitch_inv = B(dB_dz.roots(extrapolate=False))[3] - 1e-13 + z1, z2 = bounce_points( + pitch_inv, k, B.c.T, dB_dz.c.T, check=True, include_knots=True + ) + z1, z2 = TestBounce1DPoints.filter(z1, z2) + assert z1.size and z2.size + intersect = B.solve(pitch_inv, extrapolate=False) + np.testing.assert_allclose(z1[1], 1.982767, rtol=1e-6) + np.testing.assert_allclose(z1, intersect[[1, 2]], rtol=1e-6) + # intersect array could not resolve double root as single at index 2,3 + np.testing.assert_allclose(intersect[2], intersect[3], rtol=1e-6) + np.testing.assert_allclose(z2, intersect[[3, 4]], rtol=1e-6) + + @pytest.mark.unit + def test_z2_before_extrema(self): + """Case where local minimum is the shared intersect between two wells.""" + # To make sure both regions in hypograph left and right of extrema are not + # integrated over. + start = -1.2 * np.pi + end = -2 * start + k = np.linspace(start, end, 7) + B = CubicHermiteSpline( + k, + np.cos(k) + 2 * np.sin(-2 * k) + k / 4, + -np.sin(k) - 4 * np.cos(-2 * k) + 1 / 4, + ) + dB_dz = B.derivative() + pitch_inv = B(dB_dz.roots(extrapolate=False))[2] + z1, z2 = bounce_points( + pitch_inv, k, B.c.T, dB_dz.c.T, check=True, include_knots=True + ) + z1, z2 = TestBounce1DPoints.filter(z1, z2) + assert z1.size and z2.size + intersect = B.solve(pitch_inv, extrapolate=False) + np.testing.assert_allclose(z1, intersect[[0, -2]]) + np.testing.assert_allclose(z2, intersect[[1, -1]]) + + @pytest.mark.unit + def test_extrema_first_and_before_z1(self): + """Case where first intersect is extrema and second enters epigraph.""" + # To make sure we don't perform integral between first pair of intersects. + start = -1.2 * np.pi + end = -2 * start + k = np.linspace(start, end, 7) + B = CubicHermiteSpline( + k, + np.cos(k) + 2 * np.sin(-2 * k) + k / 20, + -np.sin(k) - 4 * np.cos(-2 * k) + 1 / 20, + ) + dB_dz = B.derivative() + pitch_inv = B(dB_dz.roots(extrapolate=False))[2] + 1e-13 + z1, z2 = bounce_points( + pitch_inv, + k[2:], + B.c[:, 2:].T, + dB_dz.c[:, 2:].T, + check=True, + start=k[2], + include_knots=True, + ) + z1, z2 = TestBounce1DPoints.filter(z1, z2) + assert z1.size and z2.size + intersect = B.solve(pitch_inv, extrapolate=False) + np.testing.assert_allclose(z1[0], 0.835319, rtol=1e-6) + intersect = intersect[intersect >= k[2]] + np.testing.assert_allclose(z1, intersect[[0, 2, 4]], rtol=1e-6) + np.testing.assert_allclose(z2, intersect[[0, 3, 5]], rtol=1e-6) + + @pytest.mark.unit + def test_extrema_first_and_before_z2(self): + """Case where first intersect is extrema and second exits epigraph.""" + # To make sure we do perform integral between first pair of intersects. + start = -1.2 * np.pi + end = -2 * start + 1 + k = np.linspace(start, end, 7) + B = CubicHermiteSpline( + k, + np.cos(k) + 2 * np.sin(-2 * k) + k / 10, + -np.sin(k) - 4 * np.cos(-2 * k) + 1 / 10, + ) + dB_dz = B.derivative() + pitch_inv = B(dB_dz.roots(extrapolate=False))[1] - 1e-13 + z1, z2 = bounce_points( + pitch_inv, k, B.c.T, dB_dz.c.T, check=True, include_knots=True + ) + z1, z2 = TestBounce1DPoints.filter(z1, z2) + assert z1.size and z2.size + # Our routine correctly detects intersection, while scipy, jnp.root fails. + intersect = B.solve(pitch_inv, extrapolate=False) + np.testing.assert_allclose(z1[0], -0.671904, rtol=1e-6) + np.testing.assert_allclose(z1, intersect[[0, 3, 5]], rtol=1e-5) + # intersect array could not resolve double root as single at index 0,1 + np.testing.assert_allclose(intersect[0], intersect[1], rtol=1e-5) + np.testing.assert_allclose(z2, intersect[[2, 4, 6]], rtol=1e-5) + + @pytest.mark.unit + def test_get_extrema(self): + """Test computation of extrema of |B|.""" + start = -np.pi + end = -2 * start + k = np.linspace(start, end, 5) + B = CubicHermiteSpline( + k, np.cos(k) + 2 * np.sin(-2 * k), -np.sin(k) - 4 * np.cos(-2 * k) + ) + dB_dz = B.derivative() + ext, B_ext = _get_extrema(k, B.c.T, dB_dz.c.T) + mask = ~np.isnan(ext) + ext, B_ext = ext[mask], B_ext[mask] + idx = np.argsort(ext) + + ext_scipy = np.sort(dB_dz.roots(extrapolate=False)) + B_ext_scipy = B(ext_scipy) + assert ext.size == ext_scipy.size + np.testing.assert_allclose(ext[idx], ext_scipy) + np.testing.assert_allclose(B_ext[idx], B_ext_scipy) + + +def _mod_cheb_gauss(deg): + x, w = chebgauss(deg) + w /= chebweight(x) + return x, w + + +def _mod_chebu_gauss(deg): + x, w = roots_chebyu(deg) + w *= chebweight(x) + return x, w + + +class TestBounce1DQuadrature: + """Test bounce quadrature.""" + + @pytest.mark.unit + @pytest.mark.parametrize( + "is_strong, quad, automorphism", + [ + (True, tanh_sinh(40), None), + (True, leggauss(25), "default"), + (False, tanh_sinh(20), None), + (False, leggauss_lob(10), "default"), + # sin automorphism still helps out chebyshev quadrature + (True, _mod_cheb_gauss(30), "default"), + (False, _mod_chebu_gauss(10), "default"), + ], + ) + def test_bounce_quadrature(self, is_strong, quad, automorphism): + """Test quadrature matches singular (strong and weak) elliptic integrals.""" + p = 1e-4 + m = 1 - p + # Some prime number that doesn't appear anywhere in calculation. + # Ensures no lucky cancellation occurs from ζ₂ − ζ₁ / π = π / (ζ₂ − ζ₁) + # which could mask errors since π appears often in transformations. + v = 7 + z1 = -np.pi / 2 * v + z2 = -z1 + knots = np.linspace(z1, z2, 50) + pitch_inv = 1 - 50 * jnp.finfo(jnp.array(1.0).dtype).eps + b = np.clip(np.sin(knots / v) ** 2, 1e-7, 1) + db = np.sin(2 * knots / v) / v + data = {"B^zeta": b, "B^zeta_z|r,a": db, "|B|": b, "|B|_z|r,a": db} + + if is_strong: + integrand = lambda B, pitch: 1 / jnp.sqrt(1 - m * pitch * B) + truth = v * 2 * ellipkm1(p) + else: + integrand = lambda B, pitch: jnp.sqrt(1 - m * pitch * B) + truth = v * 2 * ellipe(m) + kwargs = {} + if automorphism != "default": + kwargs["automorphism"] = automorphism + bounce = Bounce1D( + Grid.create_meshgrid([1, 0, knots], coordinates="raz"), + data, + quad, + check=True, + **kwargs, + ) + result = bounce.integrate(integrand, pitch_inv, check=True, plot=True) + assert np.count_nonzero(result) == 1 + np.testing.assert_allclose(result.sum(), truth, rtol=1e-4) + + @staticmethod + @partial(np.vectorize, excluded={0}) + def _adaptive_elliptic(integrand, k): + a = 0 + b = 2 * np.arcsin(k) + return integrate.quad(integrand, a, b, args=(k,), points=b)[0] + + @staticmethod + def _fixed_elliptic(integrand, k, deg): + k = np.atleast_1d(k) + a = np.zeros_like(k) + b = 2 * np.arcsin(k) + x, w = get_quadrature(leggauss(deg), (automorphism_sin, grad_automorphism_sin)) + Z = bijection_from_disc(x, a[..., np.newaxis], b[..., np.newaxis]) + k = k[..., np.newaxis] + quad = integrand(Z, k).dot(w) * grad_bijection_from_disc(a, b) + return quad + + # TODO: add the analytical test that converts incomplete elliptic integrals to + # complete ones using the Reciprocal Modulus transformation + # https://dlmf.nist.gov/19.7#E4. + @staticmethod + def elliptic_incomplete(k2): + """Calculate elliptic integrals for bounce averaged binormal drift. + + The test is nice because it is independent of all the bounce integrals + and splines. One can test performance of different quadrature methods + by using that method in the ``_fixed_elliptic`` method above. + + """ + K_integrand = lambda Z, k: 2 / np.sqrt(k**2 - np.sin(Z / 2) ** 2) * (k / 4) + E_integrand = lambda Z, k: 2 * np.sqrt(k**2 - np.sin(Z / 2) ** 2) / (k * 4) + # Scipy's elliptic integrals are broken. + # https://github.com/scipy/scipy/issues/20525. + k = np.sqrt(k2) + K = TestBounce1DQuadrature._adaptive_elliptic(K_integrand, k) + E = TestBounce1DQuadrature._adaptive_elliptic(E_integrand, k) + # Make sure scipy's adaptive quadrature is not broken. + np.testing.assert_allclose( + K, TestBounce1DQuadrature._fixed_elliptic(K_integrand, k, 10) + ) + np.testing.assert_allclose( + E, TestBounce1DQuadrature._fixed_elliptic(E_integrand, k, 10) + ) + + I_0 = 4 / k * K + I_1 = 4 * k * E + I_2 = 16 * k * E + I_3 = 16 * k / 9 * (2 * (-1 + 2 * k2) * E - (-1 + k2) * K) + I_4 = 16 * k / 3 * ((-1 + 2 * k2) * E - 2 * (-1 + k2) * K) + I_5 = 32 * k / 30 * (2 * (1 - k2 + k2**2) * E - (1 - 3 * k2 + 2 * k2**2) * K) + I_6 = 4 / k * (2 * k2 * E + (1 - 2 * k2) * K) + I_7 = 2 * k / 3 * ((-2 + 4 * k2) * E - 4 * (-1 + k2) * K) + # Check for math mistakes. + np.testing.assert_allclose( + I_2, + TestBounce1DQuadrature._adaptive_elliptic( + lambda Z, k: 2 / np.sqrt(k**2 - np.sin(Z / 2) ** 2) * Z * np.sin(Z), k + ), + ) + np.testing.assert_allclose( + I_3, + TestBounce1DQuadrature._adaptive_elliptic( + lambda Z, k: 2 * np.sqrt(k**2 - np.sin(Z / 2) ** 2) * Z * np.sin(Z), k + ), + ) + np.testing.assert_allclose( + I_4, + TestBounce1DQuadrature._adaptive_elliptic( + lambda Z, k: 2 / np.sqrt(k**2 - np.sin(Z / 2) ** 2) * np.sin(Z) ** 2, k + ), + ) + np.testing.assert_allclose( + I_5, + TestBounce1DQuadrature._adaptive_elliptic( + lambda Z, k: 2 * np.sqrt(k**2 - np.sin(Z / 2) ** 2) * np.sin(Z) ** 2, k + ), + ) + # scipy fails + np.testing.assert_allclose( + I_6, + TestBounce1DQuadrature._fixed_elliptic( + lambda Z, k: 2 / np.sqrt(k**2 - np.sin(Z / 2) ** 2) * np.cos(Z), + k, + deg=11, + ), + ) + np.testing.assert_allclose( + I_7, + TestBounce1DQuadrature._adaptive_elliptic( + lambda Z, k: 2 * np.sqrt(k**2 - np.sin(Z / 2) ** 2) * np.cos(Z), k + ), + ) + return I_0, I_1, I_2, I_3, I_4, I_5, I_6, I_7 + + +class TestBounce1D: + """Test bounce integration with one-dimensional local spline methods.""" + + @staticmethod + def _example_numerator(g_zz, B, pitch): + f = (1 - 0.5 * pitch * B) * g_zz + return safediv(f, jnp.sqrt(jnp.abs(1 - pitch * B))) + + @staticmethod + def _example_denominator(B, pitch): + return safediv(1, jnp.sqrt(jnp.abs(1 - pitch * B))) + + @pytest.mark.unit + @pytest.mark.mpl_image_compare(remove_text=True, tolerance=tol_1d * 4) + def test_bounce1d_checks(self): + """Test that all the internal correctness checks pass for real example.""" + # noqa: D202 + # Suppose we want to compute a bounce average of the function + # f(ℓ) = (1 − λ|B|/2) * g_zz, where g_zz is the squared norm of the + # toroidal basis vector on some set of field lines specified by (ρ, α) + # coordinates. This is defined as + # [∫ f(ℓ) / √(1 − λ|B|) dℓ] / [∫ 1 / √(1 − λ|B|) dℓ] + + # 1. Define python functions for the integrands. We do that above. + # 2. Pick flux surfaces, field lines, and how far to follow the field + # line in Clebsch coordinates ρ, α, ζ. + rho = np.linspace(0.1, 1, 6) + alpha = np.array([0, 0.5]) + zeta = np.linspace(-2 * np.pi, 2 * np.pi, 200) + + eq = get("HELIOTRON") + # 3. Convert above coordinates to DESC computational coordinates. + grid = get_rtz_grid( + eq, rho, alpha, zeta, coordinates="raz", period=(np.inf, 2 * np.pi, np.inf) + ) + # 4. Compute input data. + data = eq.compute( + Bounce1D.required_names + ["min_tz |B|", "max_tz |B|", "g_zz"], grid=grid + ) + # 5. Make the bounce integration operator. + bounce = Bounce1D( + grid.source_grid, + data, + quad=leggauss(3), # not checking quadrature accuracy in this test + check=True, + ) + pitch_inv = bounce.get_pitch_inv( + grid.compress(data["min_tz |B|"]), grid.compress(data["max_tz |B|"]), 10 + ) + num = bounce.integrate( + integrand=TestBounce1D._example_numerator, + pitch_inv=pitch_inv, + f=Bounce1D.reshape_data(grid.source_grid, data["g_zz"]), + check=True, + ) + den = bounce.integrate( + integrand=TestBounce1D._example_denominator, + pitch_inv=pitch_inv, + check=True, + batch=False, + ) + avg = safediv(num, den) + assert np.isfinite(avg).all() and np.count_nonzero(avg) + + # 6. Basic manipulation of the output. + # Sum all bounce averages across a particular field line, for every field line. + result = avg.sum(axis=-1) + # Group the result by pitch and flux surface. + result = result.reshape(alpha.size, rho.size, pitch_inv.shape[-1]) + # The result stored at + m, l, p = 0, 1, 3 + print("Result(α, ρ, λ):", result[m, l, p]) + # corresponds to the 1/λ value + print("1/λ(α, ρ):", pitch_inv[l, p]) + # for the Clebsch-type field line coordinates + nodes = grid.source_grid.meshgrid_reshape(grid.source_grid.nodes[:, :2], "arz") + print("(α, ρ):", nodes[m, l, 0]) + + # 7. Optionally check for correctness of bounce points + bounce.check_points(*bounce.points(pitch_inv), pitch_inv, plot=False) + + # 8. Plotting + fig, ax = bounce.plot(m, l, pitch_inv[l], include_legend=False, show=False) + return fig + + @pytest.mark.unit + @pytest.mark.parametrize("func", [interp_to_argmin, interp_to_argmin_hard]) + def test_interp_to_argmin(self, func): + """Test argmin interpolation.""" # noqa: D202 + + # Test functions chosen with purpose; don't change unless plotted and compared. + def h(z): + return np.cos(3 * z) * np.sin(2 * np.cos(z)) + np.cos(1.2 * z) + + def g(z): + return np.sin(3 * z) * np.cos(1 / (1 + z)) * np.cos(z**2) * z + + def dg_dz(z): + return ( + 3 * z * np.cos(3 * z) * np.cos(z**2) * np.cos(1 / (1 + z)) + - 2 * z**2 * np.sin(3 * z) * np.sin(z**2) * np.cos(1 / (1 + z)) + + z * np.sin(3 * z) * np.sin(1 / (1 + z)) * np.cos(z**2) / (1 + z) ** 2 + + np.sin(3 * z) * np.cos(z**2) * np.cos(1 / (1 + z)) + ) + + zeta = np.linspace(0, 3 * np.pi, 175) + bounce = Bounce1D( + Grid.create_meshgrid([1, 0, zeta], coordinates="raz"), + { + "B^zeta": np.ones_like(zeta), + "B^zeta_z|r,a": np.ones_like(zeta), + "|B|": g(zeta), + "|B|_z|r,a": dg_dz(zeta), + }, + ) + z1 = np.array(0, ndmin=4) + z2 = np.array(2 * np.pi, ndmin=4) + argmin = 5.61719 + h_min = h(argmin) + result = func( + h=h(zeta), + z1=z1, + z2=z2, + knots=zeta, + g=bounce.B, + dg_dz=bounce._dB_dz, + ) + assert result.shape == z1.shape + np.testing.assert_allclose(h_min, result, rtol=1e-3) + + # TODO: stellarator geometry test with ripples + @staticmethod + def drift_analytic(data): + """Compute analytic approximation for bounce-averaged binormal drift. + + Returns + ------- + drift_analytic : jnp.ndarray + Analytic approximation for the true result that the numerical computation + should attempt to match. + cvdrift, gbdrift : jnp.ndarray + Numerically computed ``data["cvdrift"]` and ``data["gbdrift"]`` normalized + by some scale factors for this unit test. These should be fed to the bounce + integration as input. + pitch_inv : jnp.ndarray + Shape (P, ). + 1/λ values used. + + """ + B = data["|B|"] / data["Bref"] + B0 = np.mean(B) + # epsilon should be changed to dimensionless, and computed in a way that + # is independent of normalization length scales, like "effective r/R0". + epsilon = data["a"] * data["rho"] # Aspect ratio of the flux surface. + np.testing.assert_allclose(epsilon, 0.05) + theta_PEST = data["alpha"] + data["iota"] * data["zeta"] + # same as 1 / (1 + epsilon cos(theta)) assuming epsilon << 1 + B_analytic = B0 * (1 - epsilon * np.cos(theta_PEST)) + np.testing.assert_allclose(B, B_analytic, atol=3e-3) + + gradpar = data["a"] * data["B^zeta"] / data["|B|"] + # This method of computing G0 suggests a fixed point iteration. + G0 = data["a"] + gradpar_analytic = G0 * (1 - epsilon * np.cos(theta_PEST)) + gradpar_theta_analytic = data["iota"] * gradpar_analytic + G0 = np.mean(gradpar_theta_analytic) + np.testing.assert_allclose(gradpar, gradpar_analytic, atol=5e-3) + + # Comparing coefficient calculation here with coefficients from compute/_metric + normalization = -np.sign(data["psi"]) * data["Bref"] * data["a"] ** 2 + cvdrift = data["cvdrift"] * normalization + gbdrift = data["gbdrift"] * normalization + dPdrho = np.mean(-0.5 * (cvdrift - gbdrift) * data["|B|"] ** 2) + alpha_MHD = -0.5 * dPdrho / data["iota"] ** 2 + gds21 = ( + -np.sign(data["iota"]) + * data["shear"] + * dot(data["grad(psi)"], data["grad(alpha)"]) + / data["Bref"] + ) + gds21_analytic = -data["shear"] * ( + data["shear"] * theta_PEST - alpha_MHD / B**4 * np.sin(theta_PEST) + ) + gds21_analytic_low_order = -data["shear"] * ( + data["shear"] * theta_PEST - alpha_MHD / B0**4 * np.sin(theta_PEST) + ) + np.testing.assert_allclose(gds21, gds21_analytic, atol=2e-2) + np.testing.assert_allclose(gds21, gds21_analytic_low_order, atol=2.7e-2) + + fudge_1 = 0.19 + gbdrift_analytic = fudge_1 * ( + -data["shear"] + + np.cos(theta_PEST) + - gds21_analytic / data["shear"] * np.sin(theta_PEST) + ) + gbdrift_analytic_low_order = fudge_1 * ( + -data["shear"] + + np.cos(theta_PEST) + - gds21_analytic_low_order / data["shear"] * np.sin(theta_PEST) + ) + fudge_2 = 0.07 + cvdrift_analytic = gbdrift_analytic + fudge_2 * alpha_MHD / B**2 + cvdrift_analytic_low_order = ( + gbdrift_analytic_low_order + fudge_2 * alpha_MHD / B0**2 + ) + np.testing.assert_allclose(gbdrift, gbdrift_analytic, atol=1e-2) + np.testing.assert_allclose(cvdrift, cvdrift_analytic, atol=2e-2) + np.testing.assert_allclose(gbdrift, gbdrift_analytic_low_order, atol=1e-2) + np.testing.assert_allclose(cvdrift, cvdrift_analytic_low_order, atol=2e-2) + + # Exclude singularity not captured by analytic approximation for pitch near + # the maximum |B|. (This is captured by the numerical integration). + pitch_inv = get_pitch_inv(np.min(B), np.max(B), 100)[:-1] + k2 = 0.5 * ((1 - B0 / pitch_inv) / (epsilon * B0 / pitch_inv) + 1) + I_0, I_1, I_2, I_3, I_4, I_5, I_6, I_7 = ( + TestBounce1DQuadrature.elliptic_incomplete(k2) + ) + y = np.sqrt(2 * epsilon * B0 / pitch_inv) + I_0, I_2, I_4, I_6 = map(lambda I: I / y, (I_0, I_2, I_4, I_6)) + I_1, I_3, I_5, I_7 = map(lambda I: I * y, (I_1, I_3, I_5, I_7)) + + drift_analytic_num = ( + fudge_2 * alpha_MHD / B0**2 * I_1 + - 0.5 + * fudge_1 + * ( + data["shear"] * (I_0 + I_1 - I_2 - I_3) + + alpha_MHD / B0**4 * (I_4 + I_5) + - (I_6 + I_7) + ) + ) / G0 + drift_analytic_den = I_0 / G0 + drift_analytic = drift_analytic_num / drift_analytic_den + return drift_analytic, cvdrift, gbdrift, pitch_inv + + @staticmethod + def drift_num_integrand(cvdrift, gbdrift, B, pitch): + """Integrand of numerator of bounce averaged binormal drift.""" + g = jnp.sqrt(1 - pitch * B) + return (cvdrift * g) - (0.5 * g * gbdrift) + (0.5 * gbdrift / g) + + @staticmethod + def drift_den_integrand(B, pitch): + """Integrand of denominator of bounce averaged binormal drift.""" + return 1 / jnp.sqrt(1 - pitch * B) + + @pytest.mark.unit + @pytest.mark.mpl_image_compare(remove_text=True, tolerance=tol_1d) + def test_binormal_drift_bounce1d(self): + """Test bounce-averaged drift with analytical expressions.""" + eq = Equilibrium.load(".//tests//inputs//low-beta-shifted-circle.h5") + psi_boundary = eq.Psi / (2 * np.pi) + psi = 0.25 * psi_boundary + rho = np.sqrt(psi / psi_boundary) + np.testing.assert_allclose(rho, 0.5) + + # Make a set of nodes along a single fieldline. + grid_fsa = LinearGrid(rho=rho, M=eq.M_grid, N=eq.N_grid, sym=eq.sym, NFP=eq.NFP) + data = eq.compute(["iota"], grid=grid_fsa) + iota = grid_fsa.compress(data["iota"]).item() + alpha = 0 + zeta = np.linspace(-np.pi / iota, np.pi / iota, (2 * eq.M_grid) * 4 + 1) + grid = get_rtz_grid( + eq, + rho, + alpha, + zeta, + coordinates="raz", + period=(np.inf, 2 * np.pi, np.inf), + iota=iota, + ) + data = eq.compute( + Bounce1D.required_names + + [ + "cvdrift", + "gbdrift", + "grad(psi)", + "grad(alpha)", + "shear", + "iota", + "psi", + "a", + ], + grid=grid, + ) + np.testing.assert_allclose(data["psi"], psi) + np.testing.assert_allclose(data["iota"], iota) + assert np.all(data["B^zeta"] > 0) + data["Bref"] = 2 * np.abs(psi_boundary) / data["a"] ** 2 + data["rho"] = rho + data["alpha"] = alpha + data["zeta"] = zeta + data["psi"] = grid.compress(data["psi"]) + data["iota"] = grid.compress(data["iota"]) + data["shear"] = grid.compress(data["shear"]) + + # Compute analytic approximation. + drift_analytic, cvdrift, gbdrift, pitch_inv = TestBounce1D.drift_analytic(data) + # Compute numerical result. + bounce = Bounce1D( + grid.source_grid, + data, + quad=leggauss(28), # converges to absolute and relative tolerance of 1e-7 + Bref=data["Bref"], + Lref=data["a"], + check=True, + ) + bounce.check_points(*bounce.points(pitch_inv), pitch_inv, plot=False) + + f = Bounce1D.reshape_data(grid.source_grid, cvdrift, gbdrift) + drift_numerical_num = bounce.integrate( + integrand=TestBounce1D.drift_num_integrand, + pitch_inv=pitch_inv, + f=f, + num_well=1, + check=True, + ) + drift_numerical_den = bounce.integrate( + integrand=TestBounce1D.drift_den_integrand, + pitch_inv=pitch_inv, + num_well=1, + weight=np.ones(zeta.size), + check=True, + ) + drift_numerical = np.squeeze(drift_numerical_num / drift_numerical_den) + msg = "There should be one bounce integral per pitch in this example." + assert drift_numerical.size == drift_analytic.size, msg + np.testing.assert_allclose( + drift_numerical, drift_analytic, atol=5e-3, rtol=5e-2 + ) + + TestBounce1D._test_bounce_autodiff( + bounce, + TestBounce1D.drift_num_integrand, + f=f, + weight=np.ones(zeta.size), + ) + + fig, ax = plt.subplots() + ax.plot(pitch_inv, drift_analytic) + ax.plot(pitch_inv, drift_numerical) + return fig + + @staticmethod + def _test_bounce_autodiff(bounce, integrand, **kwargs): + """Make sure reverse mode AD works correctly on this algorithm. + + Non-differentiable operations (e.g. ``take_mask``) are used in computation. + See https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html + and https://jax.readthedocs.io/en/latest/faq.html# + why-are-gradients-zero-for-functions-based-on-sort-order. + + If the AD tool works properly, then these operations should be assigned + zero gradients while the gradients wrt parameters of our physics computations + accumulate correctly. Less mature AD tools may have subtle bugs that cause + the gradients to not accumulate correctly. (There's a few + GitHub issues that JAX has fixed related to this in the past.) + + This test first confirms the gradients computed by reverse mode AD matches + the analytic approximation of the true gradient. Then we confirm that the + partial gradients wrt the integrand and bounce points are correct. + + Apply the Leibniz integral rule + https://en.wikipedia.org/wiki/Leibniz_integral_rule, with + the label w summing over the magnetic wells: + + ∂_λ ∑_w ∫_ζ₁^ζ₂ f dζ (λ) = ∑_w [ + ∫_ζ₁^ζ₂ (∂f/∂λ)(λ) dζ + + f(λ,ζ₂) (∂ζ₂/∂λ)(λ) + - f(λ,ζ₁) (∂ζ₁/∂λ)(λ) + ] + where (∂ζ₁/∂λ)(λ) = -λ² / (∂|B|/∂ζ|ρ,α)(ζ₁) + (∂ζ₂/∂λ)(λ) = -λ² / (∂|B|/∂ζ|ρ,α)(ζ₂) + + All terms in these expressions are known analytically. + If we wanted, it's simple to check explicitly that AD takes each derivative + correctly because |w| = 1 is constant and our tokamak has symmetry + (∂|B|/∂ζ|ρ,α)(ζ₁) = - (∂|B|/∂ζ|ρ,α)(ζ₂). + + After confirming the left hand side is correct, we just check that derivative + wrt bounce points of the right hand side doesn't vanish due to some zero + gradient issue mentioned above. + + """ + + def integrand_grad(*args, **kwargs2): + grad_fun = jnp.vectorize( + grad(integrand, -1), signature="()," * len(kwargs["f"]) + "(),()->()" + ) + return grad_fun(*args, *kwargs2.values()) + + def fun1(pitch): + return bounce.integrate(integrand, 1 / pitch, check=False, **kwargs).sum() + + def fun2(pitch): + return bounce.integrate( + integrand_grad, 1 / pitch, check=True, **kwargs + ).sum() + + pitch = 1.0 + # can easily obtain from math or just extrapolate from analytic expression plot + analytic_approximation_of_gradient = 650 + np.testing.assert_allclose( + grad(fun1)(pitch), analytic_approximation_of_gradient, rtol=1e-3 + ) + # It is expected that this is much larger because the integrand is singular + # wrt λ but the boundary derivative: f(λ,ζ₂) (∂ζ₂/∂λ)(λ) - f(λ,ζ₁) (∂ζ₁/∂λ)(λ). + # smooths out because the bounce points ζ₁ and ζ₂ are smooth functions of λ. + np.testing.assert_allclose(fun2(pitch), -131750, rtol=1e-1) diff --git a/tests/test_interp_utils.py b/tests/test_interp_utils.py new file mode 100644 index 0000000000..606b0fe090 --- /dev/null +++ b/tests/test_interp_utils.py @@ -0,0 +1,103 @@ +"""Test interpolation utilities.""" + +import numpy as np +import pytest +from numpy.polynomial.polynomial import polyvander + +from desc.integrals.interp_utils import polyder_vec, polyroot_vec, polyval_vec + + +class TestPolyUtils: + """Test polynomial utilities used for local spline interpolation in integrals.""" + + @pytest.mark.unit + def test_polyroot_vec(self): + """Test vectorized computation of cubic polynomial exact roots.""" + c = np.arange(-24, 24).reshape(4, 6, -1).transpose(-1, 1, 0) + # Ensure broadcasting won't hide error in implementation. + assert np.unique(c.shape).size == c.ndim + + k = np.broadcast_to(np.arange(c.shape[-2]), c.shape[:-1]) + # Now increase dimension so that shapes still broadcast, but stuff like + # ``c[...,-1]-=k`` is not allowed because it grows the dimension of ``c``. + # This is needed functionality in ``polyroot_vec`` that requires an awkward + # loop to obtain if using jnp.vectorize. + k = np.stack([k, k * 2 + 1]) + r = polyroot_vec(c, k, sort=True) + + for i in range(k.shape[0]): + d = c.copy() + d[..., -1] -= k[i] + # np.roots cannot be vectorized because it strips leading zeros and + # output shape is therefore dynamic. + for idx in np.ndindex(d.shape[:-1]): + np.testing.assert_allclose( + r[(i, *idx)], + np.sort(np.roots(d[idx])), + err_msg=f"Eigenvalue branch of polyroot_vec failed at {i, *idx}.", + ) + + # Now test analytic formula branch, Ensure it filters distinct roots, + # and ensure zero coefficients don't bust computation due to singularities + # in analytic formulae which are not present in iterative eigenvalue scheme. + c = np.array( + [ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1], + [1, -1, -8, 12], + [1, -6, 11, -6], + [0, -6, 11, -2], + ] + ) + r = polyroot_vec(c, sort=True, distinct=True) + for j in range(c.shape[0]): + root = r[j][~np.isnan(r[j])] + unique_root = np.unique(np.roots(c[j])) + assert root.size == unique_root.size + np.testing.assert_allclose( + root, + unique_root, + err_msg=f"Analytic branch of polyroot_vec failed at {j}.", + ) + c = np.array([0, 1, -1, -8, 12]) + r = polyroot_vec(c, sort=True, distinct=True) + r = r[~np.isnan(r)] + unique_r = np.unique(np.roots(c)) + assert r.size == unique_r.size + np.testing.assert_allclose(r, unique_r) + + @pytest.mark.unit + def test_polyder_vec(self): + """Test vectorized computation of polynomial derivative.""" + c = np.arange(-18, 18).reshape(3, -1, 6) + # Ensure broadcasting won't hide error in implementation. + assert np.unique(c.shape).size == c.ndim + np.testing.assert_allclose( + polyder_vec(c), + np.vectorize(np.polyder, signature="(m)->(n)")(c), + ) + + @pytest.mark.unit + def test_polyval_vec(self): + """Test vectorized computation of polynomial evaluation.""" + + def test(x, c): + # Ensure broadcasting won't hide error in implementation. + assert np.unique(x.shape).size == x.ndim + assert np.unique(c.shape).size == c.ndim + np.testing.assert_allclose( + polyval_vec(x=x, c=c), + np.sum(polyvander(x, c.shape[-1] - 1) * c[..., ::-1], axis=-1), + ) + + c = np.arange(-60, 60).reshape(-1, 5, 3) + x = np.linspace(0, 20, np.prod(c.shape[:-1])).reshape(c.shape[:-1]) + test(x, c) + + x = np.stack([x, x * 2], axis=0) + x = np.stack([x, x * 2, x * 3, x * 4], axis=0) + assert c.shape[:-1] == x.shape[x.ndim - (c.ndim - 1) :] + assert np.unique((c.shape[-1],) + x.shape[c.ndim - 1 :]).size == x.ndim - 1 + test(x, c) diff --git a/tests/test_quad_utils.py b/tests/test_quad_utils.py new file mode 100644 index 0000000000..5a7c3d00e7 --- /dev/null +++ b/tests/test_quad_utils.py @@ -0,0 +1,103 @@ +"""Tests for quadrature utilities.""" + +import numpy as np +import pytest +from jax import grad + +from desc.backend import jnp +from desc.integrals.quad_utils import ( + automorphism_arcsin, + automorphism_sin, + bijection_from_disc, + bijection_to_disc, + composite_linspace, + grad_automorphism_arcsin, + grad_automorphism_sin, + grad_bijection_from_disc, + leggauss_lob, + tanh_sinh, +) +from desc.utils import only1 + + +@pytest.mark.unit +def test_composite_linspace(): + """Test this utility function which is used for integration over pitch.""" + B_min_tz = np.array([0.1, 0.2]) + B_max_tz = np.array([1, 3]) + breaks = np.linspace(B_min_tz, B_max_tz, num=5) + b = composite_linspace(breaks, num=3) + for i in range(breaks.shape[0]): + for j in range(breaks.shape[1]): + assert only1(np.isclose(breaks[i, j], b[:, j]).tolist()) + + +@pytest.mark.unit +def test_automorphism(): + """Test automorphisms.""" + a, b = -312, 786 + x = np.linspace(a, b, 10) + y = bijection_to_disc(x, a, b) + x_1 = bijection_from_disc(y, a, b) + np.testing.assert_allclose(x_1, x) + np.testing.assert_allclose(bijection_to_disc(x_1, a, b), y) + np.testing.assert_allclose(automorphism_arcsin(automorphism_sin(y)), y, atol=5e-7) + np.testing.assert_allclose(automorphism_sin(automorphism_arcsin(y)), y, atol=5e-7) + + np.testing.assert_allclose(grad_bijection_from_disc(a, b), 1 / (2 / (b - a))) + np.testing.assert_allclose( + grad_automorphism_sin(y), + 1 / grad_automorphism_arcsin(automorphism_sin(y)), + atol=2e-6, + ) + np.testing.assert_allclose( + 1 / grad_automorphism_arcsin(y), + grad_automorphism_sin(automorphism_arcsin(y)), + atol=2e-6, + ) + + # test that floating point error is acceptable + x = tanh_sinh(19)[0] + assert np.all(np.abs(x) < 1) + y = 1 / np.sqrt(1 - np.abs(x)) + assert np.isfinite(y).all() + y = 1 / np.sqrt(1 - np.abs(automorphism_sin(x))) + assert np.isfinite(y).all() + y = 1 / np.sqrt(1 - np.abs(automorphism_arcsin(x))) + assert np.isfinite(y).all() + + +@pytest.mark.unit +def test_leggauss_lobatto(): + """Test quadrature points and weights against known values.""" + with pytest.raises(ValueError): + x, w = leggauss_lob(1) + x, w = leggauss_lob(0, True) + assert x.size == w.size == 0 + + x, w = leggauss_lob(2) + np.testing.assert_allclose(x, [-1, 1]) + np.testing.assert_allclose(w, [1, 1]) + + x, w = leggauss_lob(3) + np.testing.assert_allclose(x, [-1, 0, 1]) + np.testing.assert_allclose(w, [1 / 3, 4 / 3, 1 / 3]) + np.testing.assert_allclose(leggauss_lob(x.size - 2, True), (x[1:-1], w[1:-1])) + + x, w = leggauss_lob(4) + np.testing.assert_allclose(x, [-1, -np.sqrt(1 / 5), np.sqrt(1 / 5), 1]) + np.testing.assert_allclose(w, [1 / 6, 5 / 6, 5 / 6, 1 / 6]) + np.testing.assert_allclose(leggauss_lob(x.size - 2, True), (x[1:-1], w[1:-1])) + + x, w = leggauss_lob(5) + np.testing.assert_allclose(x, [-1, -np.sqrt(3 / 7), 0, np.sqrt(3 / 7), 1]) + np.testing.assert_allclose(w, [1 / 10, 49 / 90, 32 / 45, 49 / 90, 1 / 10]) + np.testing.assert_allclose(leggauss_lob(x.size - 2, True), (x[1:-1], w[1:-1])) + + def fun(a): + x, w = leggauss_lob(a.size) + return jnp.dot(x * a, w) + + # make sure differentiable + # https://github.com/PlasmaControl/DESC/pull/854#discussion_r1733323161 + assert np.isfinite(grad(fun)(jnp.arange(10) * np.pi)).all() diff --git a/tests/test_utils.py b/tests/test_utils.py index 6bfadb4008..2812e8a01b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,11 +1,13 @@ """Tests for utility functions.""" +from functools import partial + import numpy as np import pytest -from desc.backend import tree_leaves, tree_structure +from desc.backend import flatnonzero, jnp, tree_leaves, tree_structure from desc.grid import LinearGrid -from desc.utils import broadcast_tree, isalmostequal, islinspaced +from desc.utils import broadcast_tree, isalmostequal, islinspaced, take_mask @pytest.mark.unit @@ -197,3 +199,35 @@ def test_broadcast_tree(): ] for leaf, leaf_correct in zip(tree_leaves(tree), tree_leaves(tree_correct)): np.testing.assert_allclose(leaf, leaf_correct) + + +@partial(jnp.vectorize, signature="(m)->()") +def _last_value(a): + """Return the last non-nan value in ``a``.""" + a = a[::-1] + idx = jnp.squeeze(flatnonzero(~jnp.isnan(a), size=1, fill_value=0)) + return a[idx] + + +@pytest.mark.unit +def test_take_mask(): + """Test custom masked array operation.""" + rows = 5 + cols = 7 + a = np.random.rand(rows, cols) + nan_idx = np.random.choice(rows * cols, size=(rows * cols) // 2, replace=False) + a.ravel()[nan_idx] = np.nan + taken = take_mask(a, ~np.isnan(a)) + last = _last_value(taken) + for i in range(rows): + desired = a[i, ~np.isnan(a[i])] + assert np.array_equal( + taken[i], + np.pad(desired, (0, cols - desired.size), constant_values=np.nan), + equal_nan=True, + ) + assert np.array_equal( + last[i], + desired[-1] if desired.size else np.nan, + equal_nan=True, + )