Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PR]: Improving regrid2 performance #533

Merged
merged 27 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
5b1f3e5
Fixes regrid2 operating on numpy arrays
jasonb5 Aug 16, 2023
6ed8990
Merge branch 'main' into regrid2_performance
jasonb5 Oct 11, 2023
f50cf24
Removes regrid output mask
jasonb5 Oct 26, 2023
2e58cd7
Fixes extracting input data variable
jasonb5 Oct 26, 2023
df2d23a
Updates regrid2
jasonb5 Nov 7, 2023
3d5d0d9
Overrides dtype to match input
jasonb5 Nov 8, 2023
5f5e1be
Fixes wrapping when using np.take
jasonb5 Nov 8, 2023
dd671b4
Fixes copying variable attributes
jasonb5 Nov 8, 2023
c6e65b7
Fixes typing errors
jasonb5 Nov 8, 2023
a68bf9f
Fixes correcting dtype before mapping/regridding
jasonb5 Dec 7, 2023
53989b7
Fixes tests
jasonb5 Dec 7, 2023
e6fe7c9
Fixes latitude that was wrapping
jasonb5 Dec 7, 2023
934e028
Fixes copying coordinates when name is missmatched
jasonb5 Dec 7, 2023
62fe454
Fixes shifting longitude
jasonb5 Dec 7, 2023
d5d7f49
Merge branch 'main' into regrid2_performance
jasonb5 Dec 7, 2023
a2eb927
Fixes missmatched coordinate names
jasonb5 Dec 7, 2023
ea19fb0
Fixes reshape/ordering output data
jasonb5 Dec 9, 2023
b67fe3a
Adds more optimizations
jasonb5 Dec 9, 2023
bcdc652
Fixes masking
jasonb5 Jan 19, 2024
3312be3
Fixes failing tests
jasonb5 Jan 19, 2024
9377a88
Merge branch 'main' into regrid2_performance
jasonb5 Jan 19, 2024
9d49dbe
Fixes tests
jasonb5 Jan 20, 2024
ad61067
Merge branch 'main' into regrid2_performance
jasonb5 Feb 12, 2024
96eaae2
Fixes docstrings and adds comments
jasonb5 Feb 23, 2024
7695e2b
Adds comments
jasonb5 Feb 28, 2024
a5bf57c
Merge branch 'main' into regrid2_performance
jasonb5 Feb 28, 2024
e2125f9
Replaces cf.axes with get_dim_keys
jasonb5 Mar 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 18 additions & 65 deletions tests/test_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,23 +384,6 @@ def test_vertical_placeholder(self):
with pytest.raises(NotImplementedError, match=""):
regridder.vertical("so", ds)

def test_missing_dimension(self):
ds = fixtures.generate_dataset(
decode_times=True, cf_compliant=False, has_bounds=True
)

del ds.lat.attrs["axis"]

output_grid = grid.create_gaussian_grid(32)

regridder = regrid2.Regrid2Regridder(ds, output_grid)

with pytest.raises(
RuntimeError,
match="Could not find axis 'lat', ensure 'lat' exists and the attributes are correct.",
):
regridder.horizontal("ts", ds)

@pytest.mark.filterwarnings("ignore:.*invalid value.*divide.*:RuntimeWarning")
def test_output_bounds(self):
ds = fixtures.generate_dataset(
Expand Down Expand Up @@ -499,45 +482,15 @@ def test_regrid_input_mask(self):

expected_output = np.array(
[
[0.0, 0.0, 0.0, 0.0],
[0.70710677, 0.70710677, 0.70710677, 0.70710677],
[0.70710677, 0.70710677, 0.70710677, 0.70710677],
[0.0, 0.0, 0.0, 0.0],
],
dtype=np.float32,
[0.0] * 4,
[0.7071067811865476] * 4,
[0.7071067811865476] * 4,
[0.0] * 4,
]
)

assert np.all(output_data.ts.values == expected_output)

def test_regrid_output_mask(self):
output_mask = [
[0, 0, 0, 0],
[1, 1, 1, 1],
[1, 1, 1, 1],
[0, 0, 0, 0],
]

self.fine_2d_ds["mask"] = (("lat", "lon"), output_mask)

regridder = regrid2.Regrid2Regridder(self.coarse_2d_ds, self.fine_2d_ds)

output_data = regridder.horizontal("ts", self.coarse_2d_ds)

expected_output = np.array(
[
[1.0, 1.0, 1.0, 1.0],
[1e20, 1e20, 1e20, 1e20],
[1e20, 1e20, 1e20, 1e20],
[1.0, 1.0, 1.0, 1.0],
],
dtype=np.float32,
)

# need to replace nans since nan != nan
output_data["ts"] = output_data.ts.fillna(1e20)

assert np.all(output_data.ts.values == expected_output)

def test_preserve_attrs(self):
regridder = regrid2.Regrid2Regridder(self.coarse_2d_ds, self.fine_2d_ds)

Expand All @@ -547,7 +500,7 @@ def test_preserve_attrs(self):
assert output_data["ts"].attrs == self.da_attrs

for x in output_data.coords:
assert output_data[x].attrs == self.coarse_2d_ds[x].attrs
assert output_data[x].attrs == self.coarse_2d_ds[x].attrs, f"{x}"

def test_regrid_2d(self):
regridder = regrid2.Regrid2Regridder(self.coarse_2d_ds, self.fine_2d_ds)
Expand Down Expand Up @@ -582,7 +535,7 @@ def test_regrid_4d(self):

def test_map_longitude_coarse_to_fine(self):
mapping, weights = regrid2._map_longitude(
self.coarse_lon_bnds, self.fine_lon_bnds
self.coarse_lon_bnds.values, self.fine_lon_bnds.values
)

expected_mapping = [
Expand All @@ -593,33 +546,33 @@ def test_map_longitude_coarse_to_fine(self):
]

expected_weigths = [
[[90]],
[[90]],
[[90]],
[[90]],
[90],
[90],
[90],
[90],
]

np.testing.assert_allclose(mapping, expected_mapping)
np.testing.assert_allclose(weights, expected_weigths)

def test_map_longitude_fine_to_coarse(self):
mapping, weights = regrid2._map_longitude(
self.fine_lon_bnds, self.coarse_lon_bnds
self.fine_lon_bnds.values, self.coarse_lon_bnds.values
)

expected_mapping = [
[0, 1],
[2, 3],
]

expected_weigths = [[[90, 90]], [[90, 90]]]
expected_weigths = [[90, 90], [90, 90]]

np.testing.assert_allclose(mapping, expected_mapping)
np.testing.assert_allclose(weights, expected_weigths)

def test_map_latitude_coarse_to_fine(self):
mapping, weights = regrid2._map_latitude(
self.coarse_lat_bnds, self.fine_lat_bnds
self.coarse_lat_bnds.values, self.fine_lat_bnds.values
)

expected_mapping = [
Expand Down Expand Up @@ -648,7 +601,7 @@ def test_map_latitude_coarse_to_fine(self):

def test_map_latitude_fine_to_coarse(self):
mapping, weights = regrid2._map_latitude(
self.fine_lat_bnds, self.coarse_lat_bnds
self.fine_lat_bnds.values, self.coarse_lat_bnds.values
)

expected_mapping = [
Expand All @@ -658,9 +611,9 @@ def test_map_latitude_fine_to_coarse(self):
]

expected_weigths = [
[[0.29289322], [0.20710678]],
[[0.5], [0.5]],
[[0.20710678], [0.29289322]],
[0.29289322, 0.20710678],
[0.5, 0.5],
[0.20710678, 0.29289322],
]

np.testing.assert_allclose(mapping, expected_mapping)
Expand Down
Loading
Loading