From a756cc296bfa357f3402e7f80218c3e136fd314c Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Wed, 25 Sep 2024 13:01:58 -0400 Subject: [PATCH 1/6] dot product over all dims simultaneously --- src/xarray_regrid/methods/conservative.py | 127 +++++++++------------- src/xarray_regrid/regrid.py | 14 +-- 2 files changed, 56 insertions(+), 85 deletions(-) diff --git a/src/xarray_regrid/methods/conservative.py b/src/xarray_regrid/methods/conservative.py index 8c384b6..7a361b2 100644 --- a/src/xarray_regrid/methods/conservative.py +++ b/src/xarray_regrid/methods/conservative.py @@ -13,8 +13,6 @@ from xarray_regrid import utils -EMPTY_DA_NAME = "FRAC_EMPTY" - @overload def conservative_regrid( @@ -107,13 +105,15 @@ def conservative_regrid_dataset( """Dataset implementation of the conservative regridding method.""" data_vars = dict(data.data_vars) data_coords = dict(data.coords) - valid_fracs = {v: xr.DataArray(name=EMPTY_DA_NAME) for v in data_vars} data_attrs = {v: data_vars[v].attrs for v in data_vars} coord_attrs = {c: data_coords[c].attrs for c in data_coords} ds_attrs = data.attrs + # Create weights array and coverage mask for each regridding dim + weights = {} + covered = {} for coord in coords: - covered_grid = (coords[coord] <= data[coord].max()) & ( + covered[coord] = (coords[coord] <= data[coord].max()) & ( coords[coord] >= data[coord].min() ) @@ -121,36 +121,37 @@ def conservative_regrid_dataset( source_coords = data[coord].to_numpy() nd_weights = get_weights(source_coords, target_coords) - # Modify weights to correct for latitude distortion - weights = utils.create_dot_dataarray( + da_weights = utils.create_dot_dataarray( nd_weights, str(coord), target_coords, source_coords ) + # Modify weights to correct for latitude distortion if coord == latitude_coord: - weights = apply_spherical_correction(weights, latitude_coord) - - for array in data_vars.keys(): - if coord in data_vars[array].dims: - if sparse is not None: - var_weights = sparsify_weights(weights, data_vars[array]) - else: - var_weights = weights - - data_vars[array], valid_fracs[array] = apply_weights( - da=data_vars[array], - weights=var_weights, - coord=coord, - valid_frac=valid_fracs[array], - skipna=skipna, - ) - # Mask out any regridded points outside the original domain - data_vars[array] = data_vars[array].where(covered_grid) - - if skipna: - # Mask out any points that don't meet the nan threshold - valid_threshold = get_valid_threshold(nan_threshold) - for array, da in data_vars.items(): - data_vars[array] = da.where(valid_fracs[array] >= valid_threshold) - + da_weights = apply_spherical_correction(da_weights, latitude_coord) + weights[coord] = da_weights + + # Apply the weights, using a unique set that matches chunking of each array + for array in data_vars.keys(): + var_weights = {} + for coord, weight_array in weights.items(): + if sparse is not None: + var_weights[coord] = sparsify_weights(weight_array, data_vars[array]) + else: + var_weights[coord] = weight_array + + data_vars[array] = apply_weights( + da=data_vars[array], + weights=var_weights, + skipna=skipna, + nan_threshold=nan_threshold, + ) + # Mask out any regridded points outside the original domain + # Limit to dims present on this array otherwise .where broadcasts + var_covered = xr.DataArray(True) + for coord in var_weights.keys(): + var_covered = var_covered & covered[coord] + data_vars[array] = data_vars[array].where(var_covered) + + # Rebuild the results ensuring we preserve attributes and other coordinates for array, attrs in data_attrs.items(): data_vars[array].attrs = attrs @@ -167,58 +168,32 @@ def conservative_regrid_dataset( def apply_weights( da: xr.DataArray, - weights: xr.DataArray, - coord: Hashable, - valid_frac: xr.DataArray, + weights: dict[Hashable, xr.DataArray], skipna: bool, -) -> tuple[xr.DataArray, xr.DataArray]: - """Apply the weights to convert data to the new coordinates.""" - coord_map = {f"target_{coord}": coord} - weights_norm = weights.copy() + nan_threshold: float, +) -> xr.DataArray: + """Apply the weights over all regridding dimensions simultaneously with `xr.dot`.""" + coords = list(weights.keys()) + weight_arrays = list(weights.values()) if skipna: - notnull = da.notnull() - # Renormalize the weights along this dim by the accumulated valid_frac - # along previous dimensions - if valid_frac.name != EMPTY_DA_NAME: - weights_norm = weights * (valid_frac / valid_frac.mean(dim=[coord])).fillna( - 0 - ) - - da_reduced: xr.DataArray = xr.dot( - da.fillna(0), weights_norm, dim=[coord], optimize=True + valid_frac = xr.dot( + da.notnull(), *weight_arrays, dim=list(weights.keys()), optimize=True + ) + + da_regrid: xr.DataArray = xr.dot( + da.fillna(0), *weight_arrays, dim=list(weights.keys()), optimize=True ) - da_reduced = da_reduced.rename(coord_map).transpose(*da.dims) if skipna: - weights_valid_sum: xr.DataArray = xr.dot( - notnull, weights_norm, dim=[coord], optimize=True - ) - weights_valid_sum = weights_valid_sum.rename(coord_map) - da_reduced /= weights_valid_sum.clip(1e-6, None) - - if valid_frac.name == EMPTY_DA_NAME: - # Begin tracking the valid fraction - valid_frac = weights_valid_sum - - else: - # Update the valid points on this dimension - valid_frac = xr.dot(valid_frac, weights, dim=[coord], optimize=True) - valid_frac = valid_frac.rename(coord_map) - valid_frac = valid_frac.clip(0, 1) - - # In some cases, dot product of dask data and sparse weights fails - # to automatically densify, which prevents future conversion to numpy - if ( - sparse is not None - and da_reduced.chunks - and isinstance(da_reduced.data._meta, sparse.COO) - ): - da_reduced.data = da_reduced.data.map_blocks( - lambda x: x.todense(), dtype=da_reduced.dtype - ) + da_regrid /= valid_frac + da_regrid = da_regrid.where(valid_frac >= get_valid_threshold(nan_threshold)) + + # Rename temporary coordinates and ensure original dimension order + coord_map = {f"target_{coord}": coord for coord in coords} + da_regrid = da_regrid.rename(coord_map).transpose(*da.dims) - return da_reduced, valid_frac + return da_regrid def get_valid_threshold(nan_threshold: float) -> float: diff --git a/src/xarray_regrid/regrid.py b/src/xarray_regrid/regrid.py index 7a581e4..f0f2d44 100644 --- a/src/xarray_regrid/regrid.py +++ b/src/xarray_regrid/regrid.py @@ -86,7 +86,7 @@ def conservative( latitude_coord: str | None = None, time_dim: str | None = "time", skipna: bool = True, - nan_threshold: float = 0.0, + nan_threshold: float = 1.0, ) -> xr.DataArray | xr.Dataset: """Regrid to the coords of the target dataset with a conservative scheme. @@ -94,16 +94,12 @@ def conservative( ds_target_grid: Dataset containing the target coordinates. latitude_coord: Name of the latitude coord, to be used for applying the spherical correction. By default, attempt to infer a latitude coordinate - as anything starting with "lat". + as either "latitude" or "lat". time_dim: Name of the time dimension. Defaults to "time". Use `None` to force regridding over the time dimension. - skipna: If True, enable handling for NaN values. This adds some overhead, - so can be disabled for optimal performance on data without any NaNs. - With `skipna=True, chunking is recommended in the non-grid dimensions, - otherwise the intermediate arrays that track the fraction of valid data - can become very large and consume excessive memory. - Warning: with `skipna=False`, isolated NaNs will propagate throughout - the dataset due to the sequential regridding scheme over each dimension. + skipna: If True, enable handling for NaN values. This adds only a small + amount of overhead, but can be disabled for optimal performance on data + without any NaNs. nan_threshold: Threshold value that will retain any output points containing at least this many non-null input points. The default value is 1.0, which will keep output points containing any non-null inputs, From c23d2a46327df68c7959b9dacaa37a11507300d5 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Wed, 25 Sep 2024 16:39:13 -0400 Subject: [PATCH 2/6] add benchmarking against xesmf --- .../benchmarks/benchmarking_xesmf.ipynb | 285 ++++++++++++++++++ 1 file changed, 285 insertions(+) create mode 100644 docs/notebooks/benchmarks/benchmarking_xesmf.ipynb diff --git a/docs/notebooks/benchmarks/benchmarking_xesmf.ipynb b/docs/notebooks/benchmarks/benchmarking_xesmf.ipynb new file mode 100644 index 0000000..83ac72d --- /dev/null +++ b/docs/notebooks/benchmarks/benchmarking_xesmf.ipynb @@ -0,0 +1,285 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Performance of `xesmf` vs `xarray-regrid`\n", + "\n", + "Compare the two conservative methods using a moderately-sized synthetic dask dataset of about 4GB." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "import dask.array as da\n", + "import xarray as xr\n", + "import xesmf\n", + "\n", + "import xarray_regrid\n", + "\n", + "bounds = dict(south=-90, north=90, west=-180, east=180)\n", + "\n", + "source = xarray_regrid.Grid(\n", + " resolution_lat=0.25,\n", + " resolution_lon=0.25,\n", + " **bounds,\n", + ").create_regridding_dataset()\n", + "\n", + "target = xarray_regrid.Grid(\n", + " resolution_lat=1,\n", + " resolution_lon=1,\n", + " **bounds,\n", + ").create_regridding_dataset()\n", + "\n", + "\n", + "def source_data(source, chunks, n_times=1000):\n", + " data = da.random.random(\n", + " size=(n_times, source.latitude.size, source.longitude.size),\n", + " chunks=chunks,\n", + " ).astype(\"float32\")\n", + "\n", + " data = xr.DataArray(\n", + " data,\n", + " dims=[\"time\", \"latitude\", \"longitude\"],\n", + " coords={\n", + " \"time\": xr.date_range(\"2000-01-01\", periods=n_times, freq=\"D\"),\n", + " \"latitude\": source.latitude,\n", + " \"longitude\": source.longitude,\n", + " }\n", + " )\n", + "\n", + " return data\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Chunking\n", + "\n", + "Test \"pancake\" (chunked in time) and \"churro\" (chunked in space) chunks of different sizes. The \"small\" versions are about 4 MB, and the \"large\" are about 100 MB." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "chunk_schemes = {\n", + " \"pancake_small\": (1, -1, -1),\n", + " \"pancake_large\": (25, -1, -1),\n", + " \"churro_small\": (-1, 32, 32),\n", + " \"churro_large\": (-1, 160, 160),\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/slevang/miniconda3/envs/xarray-regrid/lib/python3.12/site-packages/xesmf/backend.py:56: UserWarning: Latitude is outside of [-90, 90]\n", + " warnings.warn('Latitude is outside of [-90, 90]')\n", + "/home/slevang/miniconda3/envs/xarray-regrid/lib/python3.12/site-packages/xesmf/backend.py:56: UserWarning: Latitude is outside of [-90, 90]\n", + " warnings.warn('Latitude is outside of [-90, 90]')\n" + ] + } + ], + "source": [ + "# For larger grids, generating weights is quite expensive\n", + "xesmf_regridder = xesmf.Regridder(source, target, \"conservative\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Timings\n", + "\n", + "Run timings for different chunkings schemes and with NaN skipping enabled and disabled, across both libraries. Compare the ratio of `xesmf / xarray-regrid` to see the speedup factor of using this library." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/slevang/miniconda3/envs/xarray-regrid/lib/python3.12/site-packages/xarray/core/computation.py:320: PerformanceWarning: Regridding is increasing the number of chunks by a factor of 72.0, you might want to specify sizes in `output_chunks` in the regridder call. Default behaviour is to preserve the chunk sizes from the input (32, 32).\n", + " result_var = func(*data_vars)\n", + "/home/slevang/miniconda3/envs/xarray-regrid/lib/python3.12/site-packages/xarray/core/computation.py:320: PerformanceWarning: Regridding is increasing the number of chunks by a factor of 72.0, you might want to specify sizes in `output_chunks` in the regridder call. Default behaviour is to preserve the chunk sizes from the input (32, 32).\n", + " result_var = func(*data_vars)\n", + "/home/slevang/miniconda3/envs/xarray-regrid/lib/python3.12/site-packages/xarray/core/computation.py:320: PerformanceWarning: Regridding is increasing the number of chunks by a factor of 72.0, you might want to specify sizes in `output_chunks` in the regridder call. Default behaviour is to preserve the chunk sizes from the input (32, 32).\n", + " result_var = func(*data_vars)\n", + "/home/slevang/miniconda3/envs/xarray-regrid/lib/python3.12/site-packages/xarray/core/computation.py:320: PerformanceWarning: Regridding is increasing the number of chunks by a factor of 72.0, you might want to specify sizes in `output_chunks` in the regridder call. Default behaviour is to preserve the chunk sizes from the input (32, 32).\n", + " result_var = func(*data_vars)\n", + "/home/slevang/miniconda3/envs/xarray-regrid/lib/python3.12/site-packages/xarray/core/computation.py:320: PerformanceWarning: Regridding is increasing the number of chunks by a factor of 6.0, you might want to specify sizes in `output_chunks` in the regridder call. Default behaviour is to preserve the chunk sizes from the input (160, 160).\n", + " result_var = func(*data_vars)\n", + "/home/slevang/miniconda3/envs/xarray-regrid/lib/python3.12/site-packages/xarray/core/computation.py:320: PerformanceWarning: Regridding is increasing the number of chunks by a factor of 6.0, you might want to specify sizes in `output_chunks` in the regridder call. Default behaviour is to preserve the chunk sizes from the input (160, 160).\n", + " result_var = func(*data_vars)\n", + "/home/slevang/miniconda3/envs/xarray-regrid/lib/python3.12/site-packages/xarray/core/computation.py:320: PerformanceWarning: Regridding is increasing the number of chunks by a factor of 6.0, you might want to specify sizes in `output_chunks` in the regridder call. Default behaviour is to preserve the chunk sizes from the input (160, 160).\n", + " result_var = func(*data_vars)\n", + "/home/slevang/miniconda3/envs/xarray-regrid/lib/python3.12/site-packages/xarray/core/computation.py:320: PerformanceWarning: Regridding is increasing the number of chunks by a factor of 6.0, you might want to specify sizes in `output_chunks` in the regridder call. Default behaviour is to preserve the chunk sizes from the input (160, 160).\n", + " result_var = func(*data_vars)\n" + ] + } + ], + "source": [ + "import time\n", + "\n", + "import pandas as pd\n", + "\n", + "pd.options.display.precision = 1\n", + "\n", + "\n", + "def do_regrid(data, target, skipna):\n", + " data.regrid.conservative(target, skipna=skipna).compute()\n", + "\n", + "\n", + "def do_xesmf(data, target, skipna):\n", + " xesmf_regridder(data, skipna=skipna).compute()\n", + "\n", + "\n", + "def timing_grid(func, repeats=2):\n", + " times = pd.DataFrame(\n", + " index=chunk_schemes.keys(),\n", + " columns=[\"skipna=False\", \"skipna=True\"],\n", + " )\n", + " for name, chunks in chunk_schemes.items():\n", + " data = source_data(source, chunks)\n", + " for skipna in [False, True]:\n", + " execution_times = []\n", + " for _ in range(repeats):\n", + " start = time.perf_counter()\n", + " func(data, target, skipna)\n", + " end = time.perf_counter()\n", + " execution_times.append(end - start)\n", + " # Sometimes the first execution is a little slower\n", + " times.loc[name, f\"skipna={skipna}\"] = min(execution_times)\n", + "\n", + " return times\n", + "\n", + "\n", + "regrid_times = timing_grid(do_regrid)\n", + "xesmf_times = timing_grid(do_xesmf)\n", + "ratio = xesmf_times / regrid_times\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Results\n", + "\n", + "With current implementations, `xesmf` is slightly faster for large pancake-style chunks. `xarray-regrid` is much faster for small chunks, especially churro-style.\n", + "\n", + "These tests were run on an 8-core Intel i7 Ubuntu desktop:" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
skipna=Falseskipna=True
pancake_small3.77.2
pancake_large0.61.1
churro_small14.216.9
churro_large1.82.4
\n", + "
" + ], + "text/plain": [ + " skipna=False skipna=True\n", + "pancake_small 3.7 7.2\n", + "pancake_large 0.6 1.1\n", + "churro_small 14.2 16.9\n", + "churro_large 1.8 2.4" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ratio" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "xarray-regrid", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 4cdac14abc729d79262cf3df44bd321fdf10f26e Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Wed, 25 Sep 2024 16:42:47 -0400 Subject: [PATCH 3/6] add changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4a36cec..601d8d3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/). Changed: - the "most common" routine has been overhauled, thanks to [@dcherian](https://github.com/dcherian). It is now much more efficient, and can operate fully lazily on dask arrays. Users do need to provide the expected groups (i.e., unique labels in the data), and the regridder is only available for `xr.DataArray` currently ([#46](https://github.com/xarray-contrib/xarray-regrid/pull/46)). - you can now use `None` as input to the `time_dim` kwarg in the regridding methods to force regridding over the time dimension (as long as it's numeric) ([#46](https://github.com/xarray-contrib/xarray-regrid/pull/46)). + - performance of the conservative method has been improved by simultaneously aggregating over all regridding dimensions ([#51](https://github.com/xarray-contrib/xarray-regrid/pull/51)). Added: - `.regrid.stat` for reducing datasets using statistical methods such as the variance or median ([#46](https://github.com/xarray-contrib/xarray-regrid/pull/46)). From 9de8ba0fd333342618e84bd49c5dd32390093394 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Wed, 25 Sep 2024 22:16:09 -0400 Subject: [PATCH 4/6] add output_chunks arg, match input chunks otherwise --- src/xarray_regrid/methods/conservative.py | 66 ++++++++++++++++++----- src/xarray_regrid/regrid.py | 12 ++++- tests/test_regrid.py | 23 ++++++++ 3 files changed, 87 insertions(+), 14 deletions(-) diff --git a/src/xarray_regrid/methods/conservative.py b/src/xarray_regrid/methods/conservative.py index 7a361b2..8d7c3e2 100644 --- a/src/xarray_regrid/methods/conservative.py +++ b/src/xarray_regrid/methods/conservative.py @@ -21,6 +21,7 @@ def conservative_regrid( latitude_coord: str | None, skipna: bool = True, nan_threshold: float = 1.0, + output_chunks: dict[Hashable, int] | None = None, ) -> xr.DataArray: ... @@ -31,6 +32,7 @@ def conservative_regrid( latitude_coord: str | None, skipna: bool = True, nan_threshold: float = 1.0, + output_chunks: dict[Hashable, int] | None = None, ) -> xr.Dataset: ... @@ -40,6 +42,7 @@ def conservative_regrid( latitude_coord: str | Hashable | None, skipna: bool = True, nan_threshold: float = 1.0, + output_chunks: dict[Hashable, int] | None = None, ) -> xr.DataArray | xr.Dataset: """Refine a dataset using conservative regridding. @@ -62,6 +65,8 @@ def conservative_regrid( which will keep output points containing any non-null inputs. The threshold is applied sequentially to each dimension, and may produce different results than a threshold applied concurrently to all regridding dimensions. + output_chunks: Optional dictionary of explicit chunk sizes for the output data. + If not provided, the output will be chunked the same as the input data. Returns: Regridded input dataset @@ -88,6 +93,7 @@ def conservative_regrid( latitude_coord, skipna, nan_threshold, + output_chunks, ) regridded_data = regridded_data.reindex_like(target_ds, copy=False) @@ -101,6 +107,7 @@ def conservative_regrid_dataset( latitude_coord: Hashable, skipna: bool, nan_threshold: float, + output_chunks: dict[Hashable, int] | None = None, ) -> xr.Dataset: """Dataset implementation of the conservative regridding method.""" data_vars = dict(data.data_vars) @@ -133,10 +140,15 @@ def conservative_regrid_dataset( for array in data_vars.keys(): var_weights = {} for coord, weight_array in weights.items(): - if sparse is not None: - var_weights[coord] = sparsify_weights(weight_array, data_vars[array]) - else: - var_weights[coord] = weight_array + var_input_chunks = data_vars[array].chunksizes.get(coord) + var_output_chunks = output_chunks.get(coord) if output_chunks else None + var_weights[coord] = format_weights( + weight_array, + coord, + data_vars[array].dtype, + var_input_chunks, + var_output_chunks, + ) data_vars[array] = apply_weights( da=data_vars[array], @@ -248,15 +260,43 @@ def lat_weight(latitude: np.ndarray, latitude_res: float) -> np.ndarray: return h * dlat / (np.pi * 4) # type: ignore -def sparsify_weights(weights: xr.DataArray, da: xr.DataArray) -> xr.DataArray: - """Create a sparse version of the weights that matches the dtype and chunks - of the array to be regridded. Even though the weights can be constructed as - dense arrays, contraction is more efficient with sparse operations.""" - new_weights = weights.copy().astype(da.dtype) - if da.chunks: - chunks = {k: v for k, v in da.chunksizes.items() if k in weights.dims} - new_weights.data = new_weights.chunk(chunks).data.map_blocks(sparse.COO) - else: +def format_weights( + weights: xr.DataArray, + coord: Hashable, + input_dtype: np.dtype, + input_chunks: tuple[int, ...] | None, + output_chunks: tuple[int, ...] | int | None, +) -> xr.DataArray: + """Format the raw weights array such that: + + 1. Weights match the dtype of the input data + 1. Weights are chunked 1:1 with the source data + 2. Weights are chunked as requested in the target grid. If no chunks are + provided, the same chunksize as the source grid will be used. + See: https://github.com/dask/dask/issues/2225 + 3. Weights are converted to a sparse representation (on a per chunk basis) + if the `sparse` package is available. + """ + # Use single precision weights at minimum, double if input is double + weights_dtype = np.result_type(np.float32, input_dtype) + new_weights = weights.copy().astype(weights_dtype) + + chunks: dict[Hashable, tuple[int, ...] | int] = {} + if input_chunks is not None: + chunks[coord] = input_chunks + if output_chunks is None: + # Set output chunking to match input, but precise chunks won't match shape, + # so take the max in case of uneven chunks + output_chunks = max(input_chunks) + + if output_chunks is not None: + chunks[f"target_{coord}"] = output_chunks + + if chunks: + new_weights = new_weights.chunk(chunks) + if sparse is not None: + new_weights.data = new_weights.data.map_blocks(sparse.COO) + elif sparse is not None: new_weights.data = sparse.COO(weights.data) return new_weights diff --git a/src/xarray_regrid/regrid.py b/src/xarray_regrid/regrid.py index f0f2d44..b424f07 100644 --- a/src/xarray_regrid/regrid.py +++ b/src/xarray_regrid/regrid.py @@ -1,3 +1,4 @@ +from collections.abc import Hashable from typing import overload import numpy as np @@ -87,6 +88,7 @@ def conservative( time_dim: str | None = "time", skipna: bool = True, nan_threshold: float = 1.0, + output_chunks: dict[Hashable, int] | None = None, ) -> xr.DataArray | xr.Dataset: """Regrid to the coords of the target dataset with a conservative scheme. @@ -105,6 +107,9 @@ def conservative( is 1.0, which will keep output points containing any non-null inputs, while a value of 0.0 will only keep output points where all inputs are non-null. + output_chunks: Optional dictionary of explicit chunk sizes for the output + data. If not provided, the output will be chunked the same as the input + data. Returns: Data regridded to the target dataset coordinates. @@ -116,7 +121,12 @@ def conservative( ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim) ds_formatted = format_for_regrid(self._obj, ds_target_grid) return conservative.conservative_regrid( - ds_formatted, ds_target_grid, latitude_coord, skipna, nan_threshold + ds_formatted, + ds_target_grid, + latitude_coord, + skipna, + nan_threshold, + output_chunks, ) def most_common( diff --git a/tests/test_regrid.py b/tests/test_regrid.py index d58fd14..a966057 100644 --- a/tests/test_regrid.py +++ b/tests/test_regrid.py @@ -208,6 +208,29 @@ def test_conservative_nan_thresholds_against_coarsen(nan_threshold): xr.testing.assert_allclose(da_coarsen, da_regrid) +def test_conservative_output_chunks(): + data = xr.DataArray( + np.ones((8, 8)), + coords={"x": np.linspace(0, 1, 8), "y": np.linspace(0, 1, 8)}, + dims=("x", "y"), + ) + target = xr.Dataset(coords={"x": np.linspace(0, 1, 4), "y": np.linspace(0, 1, 4)}) + + # Non-dask input should return non-dask output + result = data.regrid.conservative(target) + assert result.chunks is None + + # Dask input with unspecified output chunks should match the input chunks + result = data.chunk({"x": 2, "y": 2}).regrid.conservative(target) + assert result.chunks == ((2, 2), (2, 2)) + + # Specified output chunks should be respected + result = data.chunk({"x": 4, "y": 4}).regrid.conservative( + target, output_chunks={"x": 2, "y": 2} + ) + assert result.chunks == ((2, 2), (2, 2)) + + @pytest.mark.skipif(xesmf is None, reason="xesmf required") def test_conservative_nan_thresholds_against_xesmf(): ds = xr.tutorial.open_dataset("ersstv5").sst.isel(time=[0]).compute() From a9d875d6f12b12d1424c4ebac64ea6354ec22468 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Wed, 25 Sep 2024 22:27:12 -0400 Subject: [PATCH 5/6] update changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 601d8d3..49f23c3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/). Changed: - the "most common" routine has been overhauled, thanks to [@dcherian](https://github.com/dcherian). It is now much more efficient, and can operate fully lazily on dask arrays. Users do need to provide the expected groups (i.e., unique labels in the data), and the regridder is only available for `xr.DataArray` currently ([#46](https://github.com/xarray-contrib/xarray-regrid/pull/46)). - you can now use `None` as input to the `time_dim` kwarg in the regridding methods to force regridding over the time dimension (as long as it's numeric) ([#46](https://github.com/xarray-contrib/xarray-regrid/pull/46)). - - performance of the conservative method has been improved by simultaneously aggregating over all regridding dimensions ([#51](https://github.com/xarray-contrib/xarray-regrid/pull/51)). + - Performance of the conservative method has been improved by simultaneously aggregating over all regridding dimensions. Conservative regridding now also produces outputs with the same grid chunks as the inputs, unless explicit chunksizes are passed via the `output_chunks` argument. ([#51](https://github.com/xarray-contrib/xarray-regrid/pull/51)). Added: - `.regrid.stat` for reducing datasets using statistical methods such as the variance or median ([#46](https://github.com/xarray-contrib/xarray-regrid/pull/46)). From 7fbdb079d0d8a97145edd1a657387abf4d408719 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Thu, 26 Sep 2024 07:38:24 -0400 Subject: [PATCH 6/6] add notebook to docs --- docs/notebooks/index.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/notebooks/index.rst b/docs/notebooks/index.rst index 81a632f..c923b05 100644 --- a/docs/notebooks/index.rst +++ b/docs/notebooks/index.rst @@ -13,6 +13,7 @@ Most notebooks compare the methods implemented in xarray-regrid against more sta benchmarks/benchmarking_bilinear.ipynb benchmarks/benchmarking_conservative.ipynb benchmarks/benchmarking_nearest.ipynb + benchmarks/benchmarking_xesmf.ipynb .. toctree:: :maxdepth: 1