diff --git a/src/xarray_regrid/methods/conservative.py b/src/xarray_regrid/methods/conservative.py index deec777..54946a8 100644 --- a/src/xarray_regrid/methods/conservative.py +++ b/src/xarray_regrid/methods/conservative.py @@ -66,7 +66,7 @@ def conservative_regrid( # Attempt to infer the latitude coordinate if latitude_coord is None: for coord in data.coords: - if str(coord).lower().startswith("lat"): + if str(coord).lower() in ["lat", "latitude"]: latitude_coord = coord break diff --git a/src/xarray_regrid/utils.py b/src/xarray_regrid/utils.py index acf89cf..9c5a4b7 100644 --- a/src/xarray_regrid/utils.py +++ b/src/xarray_regrid/utils.py @@ -1,6 +1,6 @@ from collections.abc import Callable, Hashable from dataclasses import dataclass -from typing import Any, overload +from typing import Any, TypedDict, overload import numpy as np import pandas as pd @@ -10,6 +10,11 @@ class InvalidBoundsError(Exception): ... +class CoordHandler(TypedDict): + names: list[str] + func: Callable + + @dataclass class Grid: """Object storing grid information.""" @@ -241,112 +246,119 @@ def format_for_regrid( obj: xr.DataArray | xr.Dataset, target: xr.Dataset ) -> xr.DataArray | xr.Dataset: """Apply any pre-formatting to the input dataset to prepare for regridding. - Currently handles padding of spherical geometry if appropriate coordinate names - can be inferred containing 'lat' and 'lon'. + Currently handles padding of spherical geometry if lat/lon coordinates can + be inferred and the domain size requires boundary padding. """ - lat_coord = None - lon_coord = None - - for coord in obj.coords.keys(): - if str(coord).lower().startswith("lat"): - lat_coord = coord - elif str(coord).lower().startswith("lon"): - lon_coord = coord + orig_chunksizes = obj.chunksizes - if lon_coord is not None or lat_coord is not None: - obj = format_spherical(obj, target, lat_coord, lon_coord) + # Special-cased coordinates with accepted names and formatting function + coord_handlers: dict[str, CoordHandler] = { + "lat": {"names": ["lat", "latitude"], "func": format_lat}, + "lon": {"names": ["lon", "longitude"], "func": format_lon}, + } + # Identify coordinates that need to be formatted + formatted_coords = {} + for coord_type, handler in coord_handlers.items(): + for coord in obj.coords.keys(): + if str(coord).lower() in handler["names"]: + formatted_coords[coord_type] = str(coord) + + # Apply formatting + for coord_type, coord in formatted_coords.items(): + # Make sure formatted coords are sorted + obj = obj.sortby(coord) + target = target.sortby(coord) + obj = coord_handlers[coord_type]["func"](obj, target, formatted_coords) + # Coerce back to a single chunk if that's what was passed + if len(orig_chunksizes.get(coord, [])) == 1: + obj = obj.chunk({coord: -1}) return obj -def format_spherical( +def format_lat( obj: xr.DataArray | xr.Dataset, - target: xr.Dataset, - lat_coord: Hashable, - lon_coord: Hashable, + target: xr.Dataset, # noqa ARG001 + formatted_coords: dict[str, str], ) -> xr.DataArray | xr.Dataset: - """Infer whether a lat/lon source grid represents a global domain and - automatically apply spherical padding to improve edge effects. - - For longitude, shift the coordinate to line up with the target values, then - add a single wraparound padding column if the domain is global and the east - or west edges of the target lie outside the source grid centers. - - For latitude, add a single value at each pole computed as the mean of the last + """For latitude, add a single value at each pole computed as the mean of the last row for global source grids where the first or last point lie equatorward of 90. """ + lat_coord = formatted_coords["lat"] + lon_coord = formatted_coords.get("lon") + + # Concat a padded value representing the mean of the first/last lat bands + # This should match the Pole="all" option of ESMF + # TODO: with cos(90) = 0 weighting, these weights might be 0? + + polar_lat = 90 + dy = obj.coords[lat_coord].diff(lat_coord).max().values.item() + + # Only pad if global but don't have edge values directly at poles + # South pole + if dy - polar_lat >= obj.coords[lat_coord].values[0] > -polar_lat: + south_pole = obj.isel({lat_coord: 0}) + if lon_coord is not None: + south_pole = south_pole.mean(lon_coord) + obj = xr.concat([south_pole, obj], dim=lat_coord) # type: ignore + obj.coords[lat_coord].values[0] = -polar_lat + + # North pole + if polar_lat - dy <= obj.coords[lat_coord].values[-1] < polar_lat: + north_pole = obj.isel({lat_coord: -1}) + if lon_coord is not None: + north_pole = north_pole.mean(lon_coord) + obj = xr.concat([obj, north_pole], dim=lat_coord) # type: ignore + obj.coords[lat_coord].values[-1] = polar_lat - orig_chunksizes = obj.chunksizes + return obj - # If the source coord fully covers the target, don't modify them - if lat_coord and not coord_is_covered(obj, target, lat_coord): - obj = obj.sortby(lat_coord) - target = target.sortby(lat_coord) - - # Only pad if global but don't have edge values directly at poles - polar_lat = 90 - dy = obj[lat_coord].diff(lat_coord).max().values - - # South pole - if dy - polar_lat >= obj[lat_coord][0] > -polar_lat: - south_pole = obj.isel({lat_coord: 0}) - # This should match the Pole="all" option of ESMF - if lon_coord is not None: - south_pole = south_pole.mean(lon_coord) - obj = xr.concat([south_pole, obj], dim=lat_coord) - obj[lat_coord].values[0] = -polar_lat - - # North pole - if polar_lat - dy <= obj[lat_coord][-1] < polar_lat: - north_pole = obj.isel({lat_coord: -1}) - if lon_coord is not None: - north_pole = north_pole.mean(lon_coord) - obj = xr.concat([obj, north_pole], dim=lat_coord) - obj[lat_coord].values[-1] = polar_lat - # Coerce back to a single chunk if that's what was passed - if len(orig_chunksizes.get(lat_coord, [])) == 1: - obj = obj.chunk({lat_coord: -1}) - - if lon_coord and not coord_is_covered(obj, target, lon_coord): - obj = obj.sortby(lon_coord) - target = target.sortby(lon_coord) - - target_lon = target[lon_coord].values - # Find a wrap point outside of the left and right bounds of the target - # This ensures we have coverage on the target and handles global > regional - wrap_point = (target_lon[-1] + target_lon[0] + 360) / 2 - lon = obj[lon_coord].values - lon = np.where(lon < wrap_point - 360, lon + 360, lon) - lon = np.where(lon > wrap_point, lon - 360, lon) - obj[lon_coord].values[:] = lon - - # Shift operations can produce duplicates - # Simplest solution is to drop them and add back when padding - obj = obj.sortby(lon_coord).drop_duplicates(lon_coord) - - # Only pad if domain is global in lon - dx_s = obj[lon_coord].diff(lon_coord).max().values - dx_t = target[lon_coord].diff(lon_coord).max().values - is_global_lon = obj[lon_coord].max() - obj[lon_coord].min() >= 360 - dx_s - - if is_global_lon: - left_pad = (obj[lon_coord][0] - target[lon_coord][0] + dx_t / 2) / dx_s - right_pad = (target[lon_coord][-1] - obj[lon_coord][-1] + dx_t / 2) / dx_s - left_pad = int(np.ceil(np.max([left_pad, 0]))) - right_pad = int(np.ceil(np.max([right_pad, 0]))) - lon = obj[lon_coord].values - obj = obj.pad( - {lon_coord: (left_pad, right_pad)}, mode="wrap", keep_attrs=True +def format_lon( + obj: xr.DataArray | xr.Dataset, target: xr.Dataset, formatted_coords: dict[str, str] +) -> xr.DataArray | xr.Dataset: + """For longitude, shift the coordinate to line up with the target values, then + add a single wraparound padding column if the domain is global and the east + or west edges of the target lie outside the source grid centers. + """ + lon_coord = formatted_coords["lon"] + + # Find a wrap point outside of the left and right bounds of the target + # This ensures we have coverage on the target and handles global > regional + source_vals = obj.coords[lon_coord].values + target_vals = target.coords[lon_coord].values + wrap_point = (target_vals[-1] + target_vals[0] + 360) / 2 + source_vals = np.where( + source_vals < wrap_point - 360, source_vals + 360, source_vals + ) + source_vals = np.where(source_vals > wrap_point, source_vals - 360, source_vals) + obj.coords[lon_coord].values[:] = source_vals + + # Shift operations can produce duplicates + # Simplest solution is to drop them and add back when padding + obj = obj.sortby(lon_coord).drop_duplicates(lon_coord) + + # Only pad if domain is global in lon + source_lon = obj.coords[lon_coord] + target_lon = target.coords[lon_coord] + dx_s = source_lon.diff(lon_coord).max().values.item() + dx_t = target_lon.diff(lon_coord).max().values.item() + is_global_lon = source_lon.max().values - source_lon.min().values >= 360 - dx_s + + if is_global_lon: + left_pad = (source_lon.values[0] - target_lon.values[0] + dx_t / 2) / dx_s + right_pad = (target_lon.values[-1] - source_lon.values[-1] + dx_t / 2) / dx_s + left_pad = int(np.ceil(np.max([left_pad, 0]))) + right_pad = int(np.ceil(np.max([right_pad, 0]))) + obj = obj.pad({lon_coord: (left_pad, right_pad)}, mode="wrap", keep_attrs=True) + if left_pad: + obj.coords[lon_coord].values[:left_pad] = ( + source_lon.values[-left_pad:] - 360 + ) + if right_pad: + obj.coords[lon_coord].values[-right_pad:] = ( + source_lon.values[:right_pad] + 360 ) - if left_pad: - obj[lon_coord].values[:left_pad] = lon[-left_pad:] - 360 - if right_pad: - obj[lon_coord].values[-right_pad:] = lon[:right_pad] + 360 - - # Coerce back to a single chunk if that's what was passed - if len(orig_chunksizes.get(lon_coord, [])) == 1: - obj = obj.chunk({lon_coord: -1}) return obj