Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ku/fourier bounce part1 #1259

Merged
merged 11 commits into from
Sep 24, 2024
58 changes: 27 additions & 31 deletions desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,7 @@
from jax.numpy import bincount, flatnonzero, repeat, take
from jax.numpy.fft import irfft, rfft, rfft2
from jax.scipy.fft import dct, idct
from jax.scipy.linalg import (
block_diag,
cho_factor,
cho_solve,
eigh_tridiagonal,
qr,
solve_triangular,
)
from jax.scipy.linalg import block_diag, cho_factor, cho_solve, qr, solve_triangular
from jax.scipy.special import gammaln, logsumexp
from jax.tree_util import (
register_pytree_node,
Expand All @@ -98,6 +91,31 @@
jnp.trapezoid if hasattr(jnp, "trapezoid") else jax.scipy.integrate.trapezoid
)

def execute_on_cpu(func):
"""Decorator to set default device to CPU for a function.

Parameters
----------
func : callable
Function to decorate

Returns
-------
wrapper : callable
Decorated function that will always run on CPU even if
there are available GPUs.
"""

@functools.wraps(func)
def wrapper(*args, **kwargs):
with jax.default_device(jax.devices("cpu")[0]):
return func(*args, **kwargs)

return wrapper

# JAX implementation is not differentiable on gpu.
eigh_tridiagonal = execute_on_cpu(jax.scipy.linalg.eigh_tridiagonal)

def put(arr, inds, vals):
"""Functional interface for array "fancy indexing".

Expand All @@ -123,28 +141,6 @@ def put(arr, inds, vals):
return arr
return jnp.asarray(arr).at[inds].set(vals)

def execute_on_cpu(func):
"""Decorator to set default device to CPU for a function.

Parameters
----------
func : callable
Function to decorate

Returns
-------
wrapper : callable
Decorated function that will run always on CPU even if
there are available GPUs.
"""

@functools.wraps(func)
def wrapper(*args, **kwargs):
with jax.default_device(jax.devices("cpu")[0]):
return func(*args, **kwargs)

return wrapper

def sign(x):
"""Sign function, but returns 1 for x==0.

Expand Down Expand Up @@ -427,7 +423,7 @@ def tangent_solve(g, y):

trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz

def imap(f, xs, batch_size=None, in_axes=0, out_axes=0):
def imap(f, xs, *, batch_size=None, in_axes=0, out_axes=0):
"""Generalizes jax.lax.map; uses numpy."""
if not isinstance(xs, np.ndarray):
raise NotImplementedError(
Expand Down
7 changes: 3 additions & 4 deletions desc/equilibrium/coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from desc.compute import data_index, get_data_deps, get_profiles, get_transforms
from desc.grid import ConcentricGrid, Grid, LinearGrid, QuadratureGrid
from desc.transform import Transform
from desc.utils import check_posint, errorif, setdefault, warnif
from desc.utils import check_posint, errorif, safenorm, setdefault, warnif


def _periodic(x, period):
Expand Down Expand Up @@ -272,9 +272,8 @@ def _initial_guess_nn_search(coords, inbasis, eq, period, compute):
coords = jnp.asarray(coords)

def _distance_body(i, idx):
d = _periodic(coords[i], period) - _periodic(xg, period)
d = jnp.where((d > period / 2) & jnp.isfinite(period), period - d, d)
distance = jnp.linalg.norm(d, axis=-1)
d = _fixup_residual(coords[i] - xg, period)
distance = safenorm(d, axis=-1)
k = jnp.argmin(distance)
idx = put(idx, i, k)
return idx
Expand Down
Loading