make sparse optional
slevang committed Sep 21, 2024
1 parent ddc7a42 commit 28fead0
Showing 3 changed files with 25 additions and 10 deletions.
9 changes: 6 additions & 3 deletions pyproject.toml
@@ -26,8 +26,6 @@ dependencies = [
"xarray",
"flox",
"scipy",
"sparse",
"opt-einsum",
]

[tool.hatch.build]
@@ -41,10 +39,15 @@ Issues = "https://github.com/EXCITED-CO2/xarray-regrid/issues"
Source = "https://github.com/EXCITED-CO2/xarray-regrid"

[project.optional-dependencies]
accel = [
"sparse",
"opt-einsum",
]
benchmarking = [
"dask[distributed]",
"matplotlib",
"zarr",
"h5netcdf",
"requests",
"aiohttp",
]
Expand All @@ -71,7 +74,7 @@ docs = [ # Required for ReadTheDocs
path = "src/xarray_regrid/__init__.py"

[tool.hatch.envs.default]
features = ["dev", "benchmarking"]
features = ["accel", "dev", "benchmarking"]

[tool.hatch.envs.default.scripts]
lint = [
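With this change, sparse and opt-einsum are no longer hard requirements of xarray-regrid; they move to a new "accel" extra, so an install like pip install "xarray-regrid[accel]" pulls them in. A minimal, illustrative way for downstream code to check whether the optional packages are importable (the variable names here are hypothetical, not part of the package):

import importlib.util

# Both checks only inspect the import machinery; nothing is imported yet.
HAS_SPARSE = importlib.util.find_spec("sparse") is not None
HAS_OPT_EINSUM = importlib.util.find_spec("opt_einsum") is not None
print(f"sparse available: {HAS_SPARSE}, opt-einsum available: {HAS_OPT_EINSUM}")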
24 changes: 18 additions & 6 deletions src/xarray_regrid/methods/conservative.py
@@ -5,7 +5,11 @@

import numpy as np
import xarray as xr
from sparse import COO # type: ignore

try:
import sparse # type: ignore
except ImportError:
sparse = None

from xarray_regrid import utils

@@ -126,7 +130,11 @@ def conservative_regrid_dataset(

for array in data_vars.keys():
if coord in data_vars[array].dims:
var_weights = sparsify_weights(weights, data_vars[array])
if sparse is not None:
var_weights = sparsify_weights(weights, data_vars[array])
else:
var_weights = weights

data_vars[array], valid_fracs[array] = apply_weights(
da=data_vars[array],
weights=var_weights,
@@ -200,8 +208,12 @@ def apply_weights(
valid_frac = valid_frac.clip(0, 1)

# In some cases, dot product of dask data and sparse weights fails
# to densify, which prevents future conversion to numpy
if da_reduced.chunks and isinstance(da_reduced.data._meta, COO):
# to automatically densify, which prevents future conversion to numpy
if (
sparse is not None
and da_reduced.chunks
and isinstance(da_reduced.data._meta, sparse.COO)
):
da_reduced.data = da_reduced.data.map_blocks(
lambda x: x.todense(), dtype=da_reduced.dtype
)
@@ -268,8 +280,8 @@ def sparsify_weights(weights: xr.DataArray, da: xr.DataArray) -> xr.DataArray:
new_weights = weights.copy().astype(da.dtype)
if da.chunks:
chunks = {k: v for k, v in da.chunksizes.items() if k in weights.dims}
new_weights.data = new_weights.chunk(chunks).data.map_blocks(COO)
new_weights.data = new_weights.chunk(chunks).data.map_blocks(sparse.COO)
else:
new_weights.data = COO(weights.data)
new_weights.data = sparse.COO(weights.data)

return new_weights
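The changes above make the sparse acceleration purely opt-in: the module-level import is wrapped in try/except, and each call site falls back to dense weights when sparse is None. A self-contained sketch of the same guard pattern, using assumed names that are not part of xarray-regrid's API:

import numpy as np

try:
    import sparse  # type: ignore
except ImportError:
    sparse = None

def maybe_sparsify(weights):
    """Return COO weights when sparse is installed, dense numpy otherwise."""
    if sparse is not None:
        return sparse.COO(weights)
    return weights

w = maybe_sparsify(np.eye(4))
print(type(w))  # sparse.COO if installed, numpy.ndarray otherwise

Either branch feeds the same downstream dot product; only memory use and speed differ.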
2 changes: 1 addition & 1 deletion tests/test_regrid.py
@@ -208,7 +208,7 @@ def test_conservative_nan_thresholds_against_coarsen(nan_threshold):

@pytest.mark.skipif(xesmf is None, reason="xesmf required")
def test_conservative_nan_thresholds_against_xesmf():
ds = xr.tutorial.open_dataset("ersstv5").sst.compute().isel(time=[0])
ds = xr.tutorial.open_dataset("ersstv5").sst.isel(time=[0]).compute()
ds = ds.rename(lon="longitude", lat="latitude")
new_grid = xarray_regrid.Grid(
north=90,
