Skip to content

Commit

Permalink
refactor to separate coord handling functions, appease mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
slevang committed Sep 13, 2024
1 parent 6446048 commit 283c310
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 94 deletions.
2 changes: 1 addition & 1 deletion src/xarray_regrid/methods/conservative.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def conservative_regrid(
# Attempt to infer the latitude coordinate
if latitude_coord is None:
for coord in data.coords:
if str(coord).lower().startswith("lat"):
if str(coord).lower() in ["lat", "latitude"]:
latitude_coord = coord
break

Expand Down
198 changes: 105 additions & 93 deletions src/xarray_regrid/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Callable, Hashable
from dataclasses import dataclass
from typing import Any, overload
from typing import Any, TypedDict, overload

import numpy as np
import pandas as pd
Expand All @@ -10,6 +10,11 @@
class InvalidBoundsError(Exception): ...


class CoordHandler(TypedDict):
names: list[str]
func: Callable


@dataclass
class Grid:
"""Object storing grid information."""
Expand Down Expand Up @@ -241,112 +246,119 @@ def format_for_regrid(
obj: xr.DataArray | xr.Dataset, target: xr.Dataset
) -> xr.DataArray | xr.Dataset:
"""Apply any pre-formatting to the input dataset to prepare for regridding.
Currently handles padding of spherical geometry if appropriate coordinate names
can be inferred containing 'lat' and 'lon'.
Currently handles padding of spherical geometry if lat/lon coordinates can
be inferred and the domain size requires boundary padding.
"""
lat_coord = None
lon_coord = None

for coord in obj.coords.keys():
if str(coord).lower().startswith("lat"):
lat_coord = coord
elif str(coord).lower().startswith("lon"):
lon_coord = coord
orig_chunksizes = obj.chunksizes

if lon_coord is not None or lat_coord is not None:
obj = format_spherical(obj, target, lat_coord, lon_coord)
# Special-cased coordinates with accepted names and formatting function
coord_handlers: dict[str, CoordHandler] = {
"lat": {"names": ["lat", "latitude"], "func": format_lat},
"lon": {"names": ["lon", "longitude"], "func": format_lon},
}
# Identify coordinates that need to be formatted
formatted_coords = {}
for coord_type, handler in coord_handlers.items():
for coord in obj.coords.keys():
if str(coord).lower() in handler["names"]:
formatted_coords[coord_type] = str(coord)

# Apply formatting
for coord_type, coord in formatted_coords.items():
# Make sure formatted coords are sorted
obj = obj.sortby(coord)
target = target.sortby(coord)
obj = coord_handlers[coord_type]["func"](obj, target, formatted_coords)
# Coerce back to a single chunk if that's what was passed
if len(orig_chunksizes.get(coord, [])) == 1:
obj = obj.chunk({coord: -1})

return obj


def format_spherical(
def format_lat(
obj: xr.DataArray | xr.Dataset,
target: xr.Dataset,
lat_coord: Hashable,
lon_coord: Hashable,
target: xr.Dataset, # noqa ARG001
formatted_coords: dict[str, str],
) -> xr.DataArray | xr.Dataset:
"""Infer whether a lat/lon source grid represents a global domain and
automatically apply spherical padding to improve edge effects.
For longitude, shift the coordinate to line up with the target values, then
add a single wraparound padding column if the domain is global and the east
or west edges of the target lie outside the source grid centers.
For latitude, add a single value at each pole computed as the mean of the last
"""For latitude, add a single value at each pole computed as the mean of the last
row for global source grids where the first or last point lie equatorward of 90.
"""
lat_coord = formatted_coords["lat"]
lon_coord = formatted_coords.get("lon")

# Concat a padded value representing the mean of the first/last lat bands
# This should match the Pole="all" option of ESMF
# TODO: with cos(90) = 0 weighting, these weights might be 0?

polar_lat = 90
dy = obj.coords[lat_coord].diff(lat_coord).max().values.item()

# Only pad if global but don't have edge values directly at poles
# South pole
if dy - polar_lat >= obj.coords[lat_coord].values[0] > -polar_lat:
south_pole = obj.isel({lat_coord: 0})
if lon_coord is not None:
south_pole = south_pole.mean(lon_coord)
obj = xr.concat([south_pole, obj], dim=lat_coord) # type: ignore
obj.coords[lat_coord].values[0] = -polar_lat

# North pole
if polar_lat - dy <= obj.coords[lat_coord].values[-1] < polar_lat:
north_pole = obj.isel({lat_coord: -1})
if lon_coord is not None:
north_pole = north_pole.mean(lon_coord)
obj = xr.concat([obj, north_pole], dim=lat_coord) # type: ignore
obj.coords[lat_coord].values[-1] = polar_lat

orig_chunksizes = obj.chunksizes
return obj

# If the source coord fully covers the target, don't modify them
if lat_coord and not coord_is_covered(obj, target, lat_coord):
obj = obj.sortby(lat_coord)
target = target.sortby(lat_coord)

# Only pad if global but don't have edge values directly at poles
polar_lat = 90
dy = obj[lat_coord].diff(lat_coord).max().values

# South pole
if dy - polar_lat >= obj[lat_coord][0] > -polar_lat:
south_pole = obj.isel({lat_coord: 0})
# This should match the Pole="all" option of ESMF
if lon_coord is not None:
south_pole = south_pole.mean(lon_coord)
obj = xr.concat([south_pole, obj], dim=lat_coord)
obj[lat_coord].values[0] = -polar_lat

# North pole
if polar_lat - dy <= obj[lat_coord][-1] < polar_lat:
north_pole = obj.isel({lat_coord: -1})
if lon_coord is not None:
north_pole = north_pole.mean(lon_coord)
obj = xr.concat([obj, north_pole], dim=lat_coord)
obj[lat_coord].values[-1] = polar_lat

# Coerce back to a single chunk if that's what was passed
if len(orig_chunksizes.get(lat_coord, [])) == 1:
obj = obj.chunk({lat_coord: -1})

if lon_coord and not coord_is_covered(obj, target, lon_coord):
obj = obj.sortby(lon_coord)
target = target.sortby(lon_coord)

target_lon = target[lon_coord].values
# Find a wrap point outside of the left and right bounds of the target
# This ensures we have coverage on the target and handles global > regional
wrap_point = (target_lon[-1] + target_lon[0] + 360) / 2
lon = obj[lon_coord].values
lon = np.where(lon < wrap_point - 360, lon + 360, lon)
lon = np.where(lon > wrap_point, lon - 360, lon)
obj[lon_coord].values[:] = lon

# Shift operations can produce duplicates
# Simplest solution is to drop them and add back when padding
obj = obj.sortby(lon_coord).drop_duplicates(lon_coord)

# Only pad if domain is global in lon
dx_s = obj[lon_coord].diff(lon_coord).max().values
dx_t = target[lon_coord].diff(lon_coord).max().values
is_global_lon = obj[lon_coord].max() - obj[lon_coord].min() >= 360 - dx_s

if is_global_lon:
left_pad = (obj[lon_coord][0] - target[lon_coord][0] + dx_t / 2) / dx_s
right_pad = (target[lon_coord][-1] - obj[lon_coord][-1] + dx_t / 2) / dx_s
left_pad = int(np.ceil(np.max([left_pad, 0])))
right_pad = int(np.ceil(np.max([right_pad, 0])))
lon = obj[lon_coord].values
obj = obj.pad(
{lon_coord: (left_pad, right_pad)}, mode="wrap", keep_attrs=True
def format_lon(
obj: xr.DataArray | xr.Dataset, target: xr.Dataset, formatted_coords: dict[str, str]
) -> xr.DataArray | xr.Dataset:
"""For longitude, shift the coordinate to line up with the target values, then
add a single wraparound padding column if the domain is global and the east
or west edges of the target lie outside the source grid centers.
"""
lon_coord = formatted_coords["lon"]

# Find a wrap point outside of the left and right bounds of the target
# This ensures we have coverage on the target and handles global > regional
source_vals = obj.coords[lon_coord].values
target_vals = target.coords[lon_coord].values
wrap_point = (target_vals[-1] + target_vals[0] + 360) / 2
source_vals = np.where(
source_vals < wrap_point - 360, source_vals + 360, source_vals
)
source_vals = np.where(source_vals > wrap_point, source_vals - 360, source_vals)
obj.coords[lon_coord].values[:] = source_vals

# Shift operations can produce duplicates
# Simplest solution is to drop them and add back when padding
obj = obj.sortby(lon_coord).drop_duplicates(lon_coord)

# Only pad if domain is global in lon
source_lon = obj.coords[lon_coord]
target_lon = target.coords[lon_coord]
dx_s = source_lon.diff(lon_coord).max().values.item()
dx_t = target_lon.diff(lon_coord).max().values.item()
is_global_lon = source_lon.max().values - source_lon.min().values >= 360 - dx_s

if is_global_lon:
left_pad = (source_lon.values[0] - target_lon.values[0] + dx_t / 2) / dx_s
right_pad = (target_lon.values[-1] - source_lon.values[-1] + dx_t / 2) / dx_s
left_pad = int(np.ceil(np.max([left_pad, 0])))
right_pad = int(np.ceil(np.max([right_pad, 0])))
obj = obj.pad({lon_coord: (left_pad, right_pad)}, mode="wrap", keep_attrs=True)
if left_pad:
obj.coords[lon_coord].values[:left_pad] = (
source_lon.values[-left_pad:] - 360
)
if right_pad:
obj.coords[lon_coord].values[-right_pad:] = (
source_lon.values[:right_pad] + 360
)
if left_pad:
obj[lon_coord].values[:left_pad] = lon[-left_pad:] - 360
if right_pad:
obj[lon_coord].values[-right_pad:] = lon[:right_pad] + 360

# Coerce back to a single chunk if that's what was passed
if len(orig_chunksizes.get(lon_coord, [])) == 1:
obj = obj.chunk({lon_coord: -1})

return obj

Expand Down

0 comments on commit 283c310

Please sign in to comment.