diff --git a/CHANGELOG.md b/CHANGELOG.md index 4a36cec..49f23c3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/). Changed: - the "most common" routine has been overhauled, thanks to [@dcherian](https://github.com/dcherian). It is now much more efficient, and can operate fully lazily on dask arrays. Users do need to provide the expected groups (i.e., unique labels in the data), and the regridder is only available for `xr.DataArray` currently ([#46](https://github.com/xarray-contrib/xarray-regrid/pull/46)). - you can now use `None` as input to the `time_dim` kwarg in the regridding methods to force regridding over the time dimension (as long as it's numeric) ([#46](https://github.com/xarray-contrib/xarray-regrid/pull/46)). + - Performance of the conservative method has been improved by simultaneously aggregating over all regridding dimensions. Conservative regridding now also produces outputs with the same grid chunks as the inputs, unless explicit chunksizes are passed via the `output_chunks` argument. ([#51](https://github.com/xarray-contrib/xarray-regrid/pull/51)). Added: - `.regrid.stat` for reducing datasets using statistical methods such as the variance or median ([#46](https://github.com/xarray-contrib/xarray-regrid/pull/46)). diff --git a/docs/notebooks/benchmarks/benchmarking_xesmf.ipynb b/docs/notebooks/benchmarks/benchmarking_xesmf.ipynb new file mode 100644 index 0000000..83ac72d --- /dev/null +++ b/docs/notebooks/benchmarks/benchmarking_xesmf.ipynb @@ -0,0 +1,285 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Performance of `xesmf` vs `xarray-regrid`\n", + "\n", + "Compare the two conservative methods using a moderately-sized synthetic dask dataset of about 4GB." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "import dask.array as da\n", + "import xarray as xr\n", + "import xesmf\n", + "\n", + "import xarray_regrid\n", + "\n", + "bounds = dict(south=-90, north=90, west=-180, east=180)\n", + "\n", + "source = xarray_regrid.Grid(\n", + " resolution_lat=0.25,\n", + " resolution_lon=0.25,\n", + " **bounds,\n", + ").create_regridding_dataset()\n", + "\n", + "target = xarray_regrid.Grid(\n", + " resolution_lat=1,\n", + " resolution_lon=1,\n", + " **bounds,\n", + ").create_regridding_dataset()\n", + "\n", + "\n", + "def source_data(source, chunks, n_times=1000):\n", + " data = da.random.random(\n", + " size=(n_times, source.latitude.size, source.longitude.size),\n", + " chunks=chunks,\n", + " ).astype(\"float32\")\n", + "\n", + " data = xr.DataArray(\n", + " data,\n", + " dims=[\"time\", \"latitude\", \"longitude\"],\n", + " coords={\n", + " \"time\": xr.date_range(\"2000-01-01\", periods=n_times, freq=\"D\"),\n", + " \"latitude\": source.latitude,\n", + " \"longitude\": source.longitude,\n", + " }\n", + " )\n", + "\n", + " return data\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Chunking\n", + "\n", + "Test \"pancake\" (chunked in time) and \"churro\" (chunked in space) chunks of different sizes. The \"small\" versions are about 4 MB, and the \"large\" are about 100 MB." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "chunk_schemes = {\n", + " \"pancake_small\": (1, -1, -1),\n", + " \"pancake_large\": (25, -1, -1),\n", + " \"churro_small\": (-1, 32, 32),\n", + " \"churro_large\": (-1, 160, 160),\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/slevang/miniconda3/envs/xarray-regrid/lib/python3.12/site-packages/xesmf/backend.py:56: UserWarning: Latitude is outside of [-90, 90]\n", + " warnings.warn('Latitude is outside of [-90, 90]')\n", + "/home/slevang/miniconda3/envs/xarray-regrid/lib/python3.12/site-packages/xesmf/backend.py:56: UserWarning: Latitude is outside of [-90, 90]\n", + " warnings.warn('Latitude is outside of [-90, 90]')\n" + ] + } + ], + "source": [ + "# For larger grids, generating weights is quite expensive\n", + "xesmf_regridder = xesmf.Regridder(source, target, \"conservative\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Timings\n", + "\n", + "Run timings for different chunkings schemes and with NaN skipping enabled and disabled, across both libraries. Compare the ratio of `xesmf / xarray-regrid` to see the speedup factor of using this library." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/slevang/miniconda3/envs/xarray-regrid/lib/python3.12/site-packages/xarray/core/computation.py:320: PerformanceWarning: Regridding is increasing the number of chunks by a factor of 72.0, you might want to specify sizes in `output_chunks` in the regridder call. Default behaviour is to preserve the chunk sizes from the input (32, 32).\n", + " result_var = func(*data_vars)\n", + "/home/slevang/miniconda3/envs/xarray-regrid/lib/python3.12/site-packages/xarray/core/computation.py:320: PerformanceWarning: Regridding is increasing the number of chunks by a factor of 72.0, you might want to specify sizes in `output_chunks` in the regridder call. Default behaviour is to preserve the chunk sizes from the input (32, 32).\n", + " result_var = func(*data_vars)\n", + "/home/slevang/miniconda3/envs/xarray-regrid/lib/python3.12/site-packages/xarray/core/computation.py:320: PerformanceWarning: Regridding is increasing the number of chunks by a factor of 72.0, you might want to specify sizes in `output_chunks` in the regridder call. Default behaviour is to preserve the chunk sizes from the input (32, 32).\n", + " result_var = func(*data_vars)\n", + "/home/slevang/miniconda3/envs/xarray-regrid/lib/python3.12/site-packages/xarray/core/computation.py:320: PerformanceWarning: Regridding is increasing the number of chunks by a factor of 72.0, you might want to specify sizes in `output_chunks` in the regridder call. Default behaviour is to preserve the chunk sizes from the input (32, 32).\n", + " result_var = func(*data_vars)\n", + "/home/slevang/miniconda3/envs/xarray-regrid/lib/python3.12/site-packages/xarray/core/computation.py:320: PerformanceWarning: Regridding is increasing the number of chunks by a factor of 6.0, you might want to specify sizes in `output_chunks` in the regridder call. Default behaviour is to preserve the chunk sizes from the input (160, 160).\n", + " result_var = func(*data_vars)\n", + "/home/slevang/miniconda3/envs/xarray-regrid/lib/python3.12/site-packages/xarray/core/computation.py:320: PerformanceWarning: Regridding is increasing the number of chunks by a factor of 6.0, you might want to specify sizes in `output_chunks` in the regridder call. Default behaviour is to preserve the chunk sizes from the input (160, 160).\n", + " result_var = func(*data_vars)\n", + "/home/slevang/miniconda3/envs/xarray-regrid/lib/python3.12/site-packages/xarray/core/computation.py:320: PerformanceWarning: Regridding is increasing the number of chunks by a factor of 6.0, you might want to specify sizes in `output_chunks` in the regridder call. Default behaviour is to preserve the chunk sizes from the input (160, 160).\n", + " result_var = func(*data_vars)\n", + "/home/slevang/miniconda3/envs/xarray-regrid/lib/python3.12/site-packages/xarray/core/computation.py:320: PerformanceWarning: Regridding is increasing the number of chunks by a factor of 6.0, you might want to specify sizes in `output_chunks` in the regridder call. Default behaviour is to preserve the chunk sizes from the input (160, 160).\n", + " result_var = func(*data_vars)\n" + ] + } + ], + "source": [ + "import time\n", + "\n", + "import pandas as pd\n", + "\n", + "pd.options.display.precision = 1\n", + "\n", + "\n", + "def do_regrid(data, target, skipna):\n", + " data.regrid.conservative(target, skipna=skipna).compute()\n", + "\n", + "\n", + "def do_xesmf(data, target, skipna):\n", + " xesmf_regridder(data, skipna=skipna).compute()\n", + "\n", + "\n", + "def timing_grid(func, repeats=2):\n", + " times = pd.DataFrame(\n", + " index=chunk_schemes.keys(),\n", + " columns=[\"skipna=False\", \"skipna=True\"],\n", + " )\n", + " for name, chunks in chunk_schemes.items():\n", + " data = source_data(source, chunks)\n", + " for skipna in [False, True]:\n", + " execution_times = []\n", + " for _ in range(repeats):\n", + " start = time.perf_counter()\n", + " func(data, target, skipna)\n", + " end = time.perf_counter()\n", + " execution_times.append(end - start)\n", + " # Sometimes the first execution is a little slower\n", + " times.loc[name, f\"skipna={skipna}\"] = min(execution_times)\n", + "\n", + " return times\n", + "\n", + "\n", + "regrid_times = timing_grid(do_regrid)\n", + "xesmf_times = timing_grid(do_xesmf)\n", + "ratio = xesmf_times / regrid_times\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Results\n", + "\n", + "With current implementations, `xesmf` is slightly faster for large pancake-style chunks. `xarray-regrid` is much faster for small chunks, especially churro-style.\n", + "\n", + "These tests were run on an 8-core Intel i7 Ubuntu desktop:" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
skipna=Falseskipna=True
pancake_small3.77.2
pancake_large0.61.1
churro_small14.216.9
churro_large1.82.4
\n", + "
" + ], + "text/plain": [ + " skipna=False skipna=True\n", + "pancake_small 3.7 7.2\n", + "pancake_large 0.6 1.1\n", + "churro_small 14.2 16.9\n", + "churro_large 1.8 2.4" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ratio" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "xarray-regrid", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/notebooks/index.rst b/docs/notebooks/index.rst index 81a632f..c923b05 100644 --- a/docs/notebooks/index.rst +++ b/docs/notebooks/index.rst @@ -13,6 +13,7 @@ Most notebooks compare the methods implemented in xarray-regrid against more sta benchmarks/benchmarking_bilinear.ipynb benchmarks/benchmarking_conservative.ipynb benchmarks/benchmarking_nearest.ipynb + benchmarks/benchmarking_xesmf.ipynb .. toctree:: :maxdepth: 1 diff --git a/src/xarray_regrid/methods/conservative.py b/src/xarray_regrid/methods/conservative.py index 8c384b6..8d7c3e2 100644 --- a/src/xarray_regrid/methods/conservative.py +++ b/src/xarray_regrid/methods/conservative.py @@ -13,8 +13,6 @@ from xarray_regrid import utils -EMPTY_DA_NAME = "FRAC_EMPTY" - @overload def conservative_regrid( @@ -23,6 +21,7 @@ def conservative_regrid( latitude_coord: str | None, skipna: bool = True, nan_threshold: float = 1.0, + output_chunks: dict[Hashable, int] | None = None, ) -> xr.DataArray: ... @@ -33,6 +32,7 @@ def conservative_regrid( latitude_coord: str | None, skipna: bool = True, nan_threshold: float = 1.0, + output_chunks: dict[Hashable, int] | None = None, ) -> xr.Dataset: ... @@ -42,6 +42,7 @@ def conservative_regrid( latitude_coord: str | Hashable | None, skipna: bool = True, nan_threshold: float = 1.0, + output_chunks: dict[Hashable, int] | None = None, ) -> xr.DataArray | xr.Dataset: """Refine a dataset using conservative regridding. @@ -64,6 +65,8 @@ def conservative_regrid( which will keep output points containing any non-null inputs. The threshold is applied sequentially to each dimension, and may produce different results than a threshold applied concurrently to all regridding dimensions. + output_chunks: Optional dictionary of explicit chunk sizes for the output data. + If not provided, the output will be chunked the same as the input data. Returns: Regridded input dataset @@ -90,6 +93,7 @@ def conservative_regrid( latitude_coord, skipna, nan_threshold, + output_chunks, ) regridded_data = regridded_data.reindex_like(target_ds, copy=False) @@ -103,17 +107,20 @@ def conservative_regrid_dataset( latitude_coord: Hashable, skipna: bool, nan_threshold: float, + output_chunks: dict[Hashable, int] | None = None, ) -> xr.Dataset: """Dataset implementation of the conservative regridding method.""" data_vars = dict(data.data_vars) data_coords = dict(data.coords) - valid_fracs = {v: xr.DataArray(name=EMPTY_DA_NAME) for v in data_vars} data_attrs = {v: data_vars[v].attrs for v in data_vars} coord_attrs = {c: data_coords[c].attrs for c in data_coords} ds_attrs = data.attrs + # Create weights array and coverage mask for each regridding dim + weights = {} + covered = {} for coord in coords: - covered_grid = (coords[coord] <= data[coord].max()) & ( + covered[coord] = (coords[coord] <= data[coord].max()) & ( coords[coord] >= data[coord].min() ) @@ -121,36 +128,42 @@ def conservative_regrid_dataset( source_coords = data[coord].to_numpy() nd_weights = get_weights(source_coords, target_coords) - # Modify weights to correct for latitude distortion - weights = utils.create_dot_dataarray( + da_weights = utils.create_dot_dataarray( nd_weights, str(coord), target_coords, source_coords ) + # Modify weights to correct for latitude distortion if coord == latitude_coord: - weights = apply_spherical_correction(weights, latitude_coord) - - for array in data_vars.keys(): - if coord in data_vars[array].dims: - if sparse is not None: - var_weights = sparsify_weights(weights, data_vars[array]) - else: - var_weights = weights - - data_vars[array], valid_fracs[array] = apply_weights( - da=data_vars[array], - weights=var_weights, - coord=coord, - valid_frac=valid_fracs[array], - skipna=skipna, - ) - # Mask out any regridded points outside the original domain - data_vars[array] = data_vars[array].where(covered_grid) - - if skipna: - # Mask out any points that don't meet the nan threshold - valid_threshold = get_valid_threshold(nan_threshold) - for array, da in data_vars.items(): - data_vars[array] = da.where(valid_fracs[array] >= valid_threshold) + da_weights = apply_spherical_correction(da_weights, latitude_coord) + weights[coord] = da_weights + + # Apply the weights, using a unique set that matches chunking of each array + for array in data_vars.keys(): + var_weights = {} + for coord, weight_array in weights.items(): + var_input_chunks = data_vars[array].chunksizes.get(coord) + var_output_chunks = output_chunks.get(coord) if output_chunks else None + var_weights[coord] = format_weights( + weight_array, + coord, + data_vars[array].dtype, + var_input_chunks, + var_output_chunks, + ) + data_vars[array] = apply_weights( + da=data_vars[array], + weights=var_weights, + skipna=skipna, + nan_threshold=nan_threshold, + ) + # Mask out any regridded points outside the original domain + # Limit to dims present on this array otherwise .where broadcasts + var_covered = xr.DataArray(True) + for coord in var_weights.keys(): + var_covered = var_covered & covered[coord] + data_vars[array] = data_vars[array].where(var_covered) + + # Rebuild the results ensuring we preserve attributes and other coordinates for array, attrs in data_attrs.items(): data_vars[array].attrs = attrs @@ -167,58 +180,32 @@ def conservative_regrid_dataset( def apply_weights( da: xr.DataArray, - weights: xr.DataArray, - coord: Hashable, - valid_frac: xr.DataArray, + weights: dict[Hashable, xr.DataArray], skipna: bool, -) -> tuple[xr.DataArray, xr.DataArray]: - """Apply the weights to convert data to the new coordinates.""" - coord_map = {f"target_{coord}": coord} - weights_norm = weights.copy() + nan_threshold: float, +) -> xr.DataArray: + """Apply the weights over all regridding dimensions simultaneously with `xr.dot`.""" + coords = list(weights.keys()) + weight_arrays = list(weights.values()) if skipna: - notnull = da.notnull() - # Renormalize the weights along this dim by the accumulated valid_frac - # along previous dimensions - if valid_frac.name != EMPTY_DA_NAME: - weights_norm = weights * (valid_frac / valid_frac.mean(dim=[coord])).fillna( - 0 - ) + valid_frac = xr.dot( + da.notnull(), *weight_arrays, dim=list(weights.keys()), optimize=True + ) - da_reduced: xr.DataArray = xr.dot( - da.fillna(0), weights_norm, dim=[coord], optimize=True + da_regrid: xr.DataArray = xr.dot( + da.fillna(0), *weight_arrays, dim=list(weights.keys()), optimize=True ) - da_reduced = da_reduced.rename(coord_map).transpose(*da.dims) if skipna: - weights_valid_sum: xr.DataArray = xr.dot( - notnull, weights_norm, dim=[coord], optimize=True - ) - weights_valid_sum = weights_valid_sum.rename(coord_map) - da_reduced /= weights_valid_sum.clip(1e-6, None) - - if valid_frac.name == EMPTY_DA_NAME: - # Begin tracking the valid fraction - valid_frac = weights_valid_sum - - else: - # Update the valid points on this dimension - valid_frac = xr.dot(valid_frac, weights, dim=[coord], optimize=True) - valid_frac = valid_frac.rename(coord_map) - valid_frac = valid_frac.clip(0, 1) - - # In some cases, dot product of dask data and sparse weights fails - # to automatically densify, which prevents future conversion to numpy - if ( - sparse is not None - and da_reduced.chunks - and isinstance(da_reduced.data._meta, sparse.COO) - ): - da_reduced.data = da_reduced.data.map_blocks( - lambda x: x.todense(), dtype=da_reduced.dtype - ) + da_regrid /= valid_frac + da_regrid = da_regrid.where(valid_frac >= get_valid_threshold(nan_threshold)) - return da_reduced, valid_frac + # Rename temporary coordinates and ensure original dimension order + coord_map = {f"target_{coord}": coord for coord in coords} + da_regrid = da_regrid.rename(coord_map).transpose(*da.dims) + + return da_regrid def get_valid_threshold(nan_threshold: float) -> float: @@ -273,15 +260,43 @@ def lat_weight(latitude: np.ndarray, latitude_res: float) -> np.ndarray: return h * dlat / (np.pi * 4) # type: ignore -def sparsify_weights(weights: xr.DataArray, da: xr.DataArray) -> xr.DataArray: - """Create a sparse version of the weights that matches the dtype and chunks - of the array to be regridded. Even though the weights can be constructed as - dense arrays, contraction is more efficient with sparse operations.""" - new_weights = weights.copy().astype(da.dtype) - if da.chunks: - chunks = {k: v for k, v in da.chunksizes.items() if k in weights.dims} - new_weights.data = new_weights.chunk(chunks).data.map_blocks(sparse.COO) - else: +def format_weights( + weights: xr.DataArray, + coord: Hashable, + input_dtype: np.dtype, + input_chunks: tuple[int, ...] | None, + output_chunks: tuple[int, ...] | int | None, +) -> xr.DataArray: + """Format the raw weights array such that: + + 1. Weights match the dtype of the input data + 1. Weights are chunked 1:1 with the source data + 2. Weights are chunked as requested in the target grid. If no chunks are + provided, the same chunksize as the source grid will be used. + See: https://github.com/dask/dask/issues/2225 + 3. Weights are converted to a sparse representation (on a per chunk basis) + if the `sparse` package is available. + """ + # Use single precision weights at minimum, double if input is double + weights_dtype = np.result_type(np.float32, input_dtype) + new_weights = weights.copy().astype(weights_dtype) + + chunks: dict[Hashable, tuple[int, ...] | int] = {} + if input_chunks is not None: + chunks[coord] = input_chunks + if output_chunks is None: + # Set output chunking to match input, but precise chunks won't match shape, + # so take the max in case of uneven chunks + output_chunks = max(input_chunks) + + if output_chunks is not None: + chunks[f"target_{coord}"] = output_chunks + + if chunks: + new_weights = new_weights.chunk(chunks) + if sparse is not None: + new_weights.data = new_weights.data.map_blocks(sparse.COO) + elif sparse is not None: new_weights.data = sparse.COO(weights.data) return new_weights diff --git a/src/xarray_regrid/regrid.py b/src/xarray_regrid/regrid.py index 7a581e4..b424f07 100644 --- a/src/xarray_regrid/regrid.py +++ b/src/xarray_regrid/regrid.py @@ -1,3 +1,4 @@ +from collections.abc import Hashable from typing import overload import numpy as np @@ -86,7 +87,8 @@ def conservative( latitude_coord: str | None = None, time_dim: str | None = "time", skipna: bool = True, - nan_threshold: float = 0.0, + nan_threshold: float = 1.0, + output_chunks: dict[Hashable, int] | None = None, ) -> xr.DataArray | xr.Dataset: """Regrid to the coords of the target dataset with a conservative scheme. @@ -94,21 +96,20 @@ def conservative( ds_target_grid: Dataset containing the target coordinates. latitude_coord: Name of the latitude coord, to be used for applying the spherical correction. By default, attempt to infer a latitude coordinate - as anything starting with "lat". + as either "latitude" or "lat". time_dim: Name of the time dimension. Defaults to "time". Use `None` to force regridding over the time dimension. - skipna: If True, enable handling for NaN values. This adds some overhead, - so can be disabled for optimal performance on data without any NaNs. - With `skipna=True, chunking is recommended in the non-grid dimensions, - otherwise the intermediate arrays that track the fraction of valid data - can become very large and consume excessive memory. - Warning: with `skipna=False`, isolated NaNs will propagate throughout - the dataset due to the sequential regridding scheme over each dimension. + skipna: If True, enable handling for NaN values. This adds only a small + amount of overhead, but can be disabled for optimal performance on data + without any NaNs. nan_threshold: Threshold value that will retain any output points containing at least this many non-null input points. The default value is 1.0, which will keep output points containing any non-null inputs, while a value of 0.0 will only keep output points where all inputs are non-null. + output_chunks: Optional dictionary of explicit chunk sizes for the output + data. If not provided, the output will be chunked the same as the input + data. Returns: Data regridded to the target dataset coordinates. @@ -120,7 +121,12 @@ def conservative( ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim) ds_formatted = format_for_regrid(self._obj, ds_target_grid) return conservative.conservative_regrid( - ds_formatted, ds_target_grid, latitude_coord, skipna, nan_threshold + ds_formatted, + ds_target_grid, + latitude_coord, + skipna, + nan_threshold, + output_chunks, ) def most_common( diff --git a/tests/test_regrid.py b/tests/test_regrid.py index d58fd14..a966057 100644 --- a/tests/test_regrid.py +++ b/tests/test_regrid.py @@ -208,6 +208,29 @@ def test_conservative_nan_thresholds_against_coarsen(nan_threshold): xr.testing.assert_allclose(da_coarsen, da_regrid) +def test_conservative_output_chunks(): + data = xr.DataArray( + np.ones((8, 8)), + coords={"x": np.linspace(0, 1, 8), "y": np.linspace(0, 1, 8)}, + dims=("x", "y"), + ) + target = xr.Dataset(coords={"x": np.linspace(0, 1, 4), "y": np.linspace(0, 1, 4)}) + + # Non-dask input should return non-dask output + result = data.regrid.conservative(target) + assert result.chunks is None + + # Dask input with unspecified output chunks should match the input chunks + result = data.chunk({"x": 2, "y": 2}).regrid.conservative(target) + assert result.chunks == ((2, 2), (2, 2)) + + # Specified output chunks should be respected + result = data.chunk({"x": 4, "y": 4}).regrid.conservative( + target, output_chunks={"x": 2, "y": 2} + ) + assert result.chunks == ((2, 2), (2, 2)) + + @pytest.mark.skipif(xesmf is None, reason="xesmf required") def test_conservative_nan_thresholds_against_xesmf(): ds = xr.tutorial.open_dataset("ersstv5").sst.isel(time=[0]).compute()