Skip to content

Commit

Permalink
Merge pull request #974 from UXARRAY/zedwick/remap-overhaul
Browse files Browse the repository at this point in the history
Remapping Code Coverage, Modularize Reused Code, and Documentation Updates
  • Loading branch information
rajeeja authored Oct 3, 2024
2 parents 510e299 + 16c0af6 commit 86284d6
Show file tree
Hide file tree
Showing 5 changed files with 343 additions and 206 deletions.
1 change: 1 addition & 0 deletions docs/internal_api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ Remapping
remap.inverse_distance_weighted._inverse_distance_weighted_remap
remap.inverse_distance_weighted._inverse_distance_weighted_remap_uxda
remap.inverse_distance_weighted._inverse_distance_weighted_remap_uxds
remap.utils._remap_grid_parse


Grid Parsing and Encoding
Expand Down
203 changes: 180 additions & 23 deletions test/test_remap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from unittest import TestCase
from pathlib import Path
import numpy.testing as nt

import uxarray as ux

Expand Down Expand Up @@ -104,16 +105,15 @@ def test_remap_return_types(self):
dsfile_v1_geoflow, dsfile_v2_geoflow, dsfile_v3_geoflow
]
source_uxds = ux.open_mfdataset(gridfile_geoflow, source_data_paths)
destination_uxds = ux.open_dataset(gridfile_CSne30,
dsfile_vortex_CSne30)
destination_grid = ux.open_grid(gridfile_CSne30)

remap_uxda_to_grid = source_uxds['v1'].remap.nearest_neighbor(
destination_uxds.uxgrid)
destination_grid)

assert isinstance(remap_uxda_to_grid, UxDataArray)

remap_uxds_to_grid = source_uxds.remap.nearest_neighbor(
destination_uxds.uxgrid)
destination_grid)

# Dataset with three vars: remapped "v1, v2, v3"
assert isinstance(remap_uxds_to_grid, UxDataset)
Expand All @@ -125,13 +125,17 @@ def test_edge_centers_remapping(self):

# Open source and destination datasets to remap to
source_grid = ux.open_dataset(gridfile_geoflow, dsfile_v1_geoflow)
destination_grid = ux.open_dataset(mpasfile_QU, mpasfile_QU)
destination_grid = ux.open_grid(mpasfile_QU)

remap_to_edge_centers = source_grid['v1'].remap.nearest_neighbor(destination_grid=destination_grid.uxgrid,
remap_to="edge centers")
remap_to_edge_centers_spherical = source_grid['v1'].remap.nearest_neighbor(destination_grid=destination_grid,
remap_to="edge centers", coord_type='spherical')

remap_to_edge_centers_cartesian = source_grid['v1'].remap.nearest_neighbor(destination_grid=destination_grid,
remap_to="edge centers", coord_type='cartesian')

# Assert the data variable lies on the "edge centers"
self.assertTrue(remap_to_edge_centers._edge_centered())
self.assertTrue(remap_to_edge_centers_spherical._edge_centered())
self.assertTrue(remap_to_edge_centers_cartesian._edge_centered())

def test_overwrite(self):
"""Tests that the remapping no longer overwrites the dataset."""
Expand All @@ -142,11 +146,74 @@ def test_overwrite(self):

# Perform remapping
remap_to_edge_centers = source_grid['v1'].remap.nearest_neighbor(destination_grid=destination_dataset.uxgrid,
remap_to="nodes")
remap_to="face centers", coord_type='cartesian')

# Assert the remapped data is different from the original data
assert not np.array_equal(destination_dataset['v1'], remap_to_edge_centers)

def test_source_data_remap(self):
"""Test the remapping of all source data positions."""

# Open source and destination datasets to remap to
source_uxds = ux.open_dataset(mpasfile_QU, mpasfile_QU)
destination_grid = ux.open_grid(gridfile_geoflow)

# Remap from `face_centers`
face_centers = source_uxds['latCell'].remap.nearest_neighbor(
destination_grid=destination_grid,
remap_to="nodes"
)

# Remap from `nodes`
nodes = source_uxds['latVertex'].remap.nearest_neighbor(
destination_grid=destination_grid,
remap_to="nodes"
)

# Remap from `edges`
edges = source_uxds['angleEdge'].remap.nearest_neighbor(
destination_grid=destination_grid,
remap_to="nodes"
)

self.assertTrue(len(face_centers.values) != 0)
self.assertTrue(len(nodes.values) != 0)
self.assertTrue(len(edges.values) != 0)

def test_value_errors(self):
"""Tests the raising of value errors and warnings in the function."""

# Open source and destination datasets to remap to
source_uxds = ux.open_dataset(gridfile_geoflow, dsfile_v1_geoflow)
source_uxds_2 = ux.open_dataset(mpasfile_QU, mpasfile_QU)
destination_grid = ux.open_grid(gridfile_geoflow)

# Raise ValueError when `remap_to` is invalid
with nt.assert_raises(ValueError):
source_uxds['v1'].remap.nearest_neighbor(
destination_grid=destination_grid,
remap_to="test", coord_type='spherical'
)
with nt.assert_raises(ValueError):
source_uxds['v1'].remap.nearest_neighbor(
destination_grid=destination_grid,
remap_to="test", coord_type="cartesian"
)

# Raise ValueError when `coord_type` is invalid
with nt.assert_raises(ValueError):
source_uxds['v1'].remap.nearest_neighbor(
destination_grid=destination_grid,
remap_to="nodes", coord_type="test"
)

# Raise ValueError when the source data is invalid
with nt.assert_raises(ValueError):
source_uxds_2['cellsOnCell'].remap.nearest_neighbor(
destination_grid=destination_grid,
remap_to="nodes"
)


class TestInverseDistanceWeightedRemapping(TestCase):
"""Testing for inverse distance weighted remapping."""
Expand All @@ -156,10 +223,10 @@ def test_remap_center_nodes(self):

# datasets to use for remap
dataset = ux.open_dataset(gridfile_geoflow, dsfile_v1_geoflow)
destination_uxds = ux.open_dataset(gridfile_geoflow, dsfile_v1_geoflow)
destination_grid = ux.open_grid(gridfile_geoflow)

data_on_face_centers = dataset['v1'].remap.inverse_distance_weighted(
destination_uxds.uxgrid, remap_to="face centers")
destination_grid, remap_to="face centers", power=6)

assert not np.array_equal(dataset['v1'], data_on_face_centers)

Expand All @@ -168,10 +235,10 @@ def test_remap_nodes(self):

# datasets to use for remap
dataset = ux.open_dataset(gridfile_geoflow, dsfile_v1_geoflow)
destination_uxds = ux.open_dataset(gridfile_geoflow, dsfile_v1_geoflow)
destination_grid = ux.open_grid(gridfile_geoflow)

data_on_nodes = dataset['v1'].remap.inverse_distance_weighted(
destination_uxds.uxgrid, remap_to="nodes")
destination_grid, remap_to="nodes")

assert not np.array_equal(dataset['v1'], data_on_nodes)

Expand Down Expand Up @@ -217,17 +284,16 @@ def test_remap_return_types(self):
dsfile_v1_geoflow, dsfile_v2_geoflow, dsfile_v3_geoflow
]
source_uxds = ux.open_mfdataset(gridfile_geoflow, source_data_paths)
destination_uxds = ux.open_dataset(gridfile_CSne30,
dsfile_vortex_CSne30)
destination_grid = ux.open_grid(gridfile_CSne30)

remap_uxda_to_grid = source_uxds['v1'].remap.inverse_distance_weighted(
destination_uxds.uxgrid, power=3, k=10)
destination_grid, power=3, k=10)

assert isinstance(remap_uxda_to_grid, UxDataArray)
assert len(remap_uxda_to_grid) == 1

remap_uxds_to_grid = source_uxds.remap.inverse_distance_weighted(
destination_uxds.uxgrid)
destination_grid)

# Dataset with three vars: remapped "v1, v2, v3"
assert isinstance(remap_uxds_to_grid, UxDataset)
Expand All @@ -239,15 +305,21 @@ def test_edge_remapping(self):

# Open source and destination datasets to remap to
source_grid = ux.open_dataset(gridfile_geoflow, dsfile_v1_geoflow)
destination_grid = ux.open_dataset(mpasfile_QU, mpasfile_QU)
destination_grid = ux.open_grid(mpasfile_QU)

# Perform remapping to the edge centers of the dataset

remap_to_edge_centers = source_grid['v1'].remap.inverse_distance_weighted(destination_grid=destination_grid.uxgrid,
remap_to="edge centers")
remap_to_edge_centers_spherical = source_grid['v1'].remap.inverse_distance_weighted(
destination_grid=destination_grid,
remap_to="edge centers", coord_type='spherical')

remap_to_edge_centers_cartesian = source_grid['v1'].remap.inverse_distance_weighted(
destination_grid=destination_grid,
remap_to="edge centers", coord_type='cartesian')

# Assert the data variable lies on the "edge centers"
self.assertTrue(remap_to_edge_centers._edge_centered())
self.assertTrue(remap_to_edge_centers_spherical._edge_centered())
self.assertTrue(remap_to_edge_centers_cartesian._edge_centered())

def test_overwrite(self):
"""Tests that the remapping no longer overwrites the dataset."""
Expand All @@ -257,8 +329,93 @@ def test_overwrite(self):
destination_dataset = ux.open_dataset(gridfile_geoflow, dsfile_v1_geoflow)

# Perform Remapping
remap_to_edge_centers = source_grid['v1'].remap.inverse_distance_weighted(destination_grid=destination_dataset.uxgrid,
remap_to="nodes")
remap_to_edge_centers = source_grid['v1'].remap.inverse_distance_weighted(
destination_grid=destination_dataset.uxgrid,
remap_to="face centers", coord_type='cartesian')

# Assert the remapped data is different from the original data
assert not np.array_equal(destination_dataset['v1'], remap_to_edge_centers)

def test_source_data_remap(self):
"""Test the remapping of all source data positions."""

# Open source and destination datasets to remap to
source_uxds = ux.open_dataset(mpasfile_QU, mpasfile_QU)
destination_grid = ux.open_grid(gridfile_geoflow)

# Remap from `face_centers`
face_centers = source_uxds['latCell'].remap.inverse_distance_weighted(
destination_grid=destination_grid,
remap_to="nodes"
)

# Remap from `nodes`
nodes = source_uxds['latVertex'].remap.inverse_distance_weighted(
destination_grid=destination_grid,
remap_to="nodes"
)

# Remap from `edges`
edges = source_uxds['angleEdge'].remap.inverse_distance_weighted(
destination_grid=destination_grid,
remap_to="nodes"
)

self.assertTrue(len(face_centers.values) != 0)
self.assertTrue(len(nodes.values) != 0)
self.assertTrue(len(edges.values) != 0)

def test_value_errors(self):
"""Tests the raising of value errors and warnings in the function."""

# Open source and destination datasets to remap to
source_uxds = ux.open_dataset(gridfile_geoflow, dsfile_v1_geoflow)
source_uxds_2 = ux.open_dataset(mpasfile_QU, mpasfile_QU)
destination_grid = ux.open_grid(gridfile_geoflow)

# Raise ValueError when `k` =< 1
with nt.assert_raises(ValueError):
source_uxds['v1'].remap.inverse_distance_weighted(
destination_grid=destination_grid,
remap_to="nodes", k=1
)

# Raise ValueError when k is larger than `n_node`
with nt.assert_raises(ValueError):
source_uxds['v1'].remap.inverse_distance_weighted(
destination_grid=destination_grid,
remap_to="nodes", k=source_uxds.uxgrid.n_node + 1
)

# Raise ValueError when `remap_to` is invalid
with nt.assert_raises(ValueError):
source_uxds['v1'].remap.inverse_distance_weighted(
destination_grid=destination_grid,
remap_to="test", k=2, coord_type='spherical'
)
with nt.assert_raises(ValueError):
source_uxds['v1'].remap.inverse_distance_weighted(
destination_grid=destination_grid,
remap_to="test", k=2, coord_type="cartesian"
)

# Raise ValueError when `coord_type` is invalid
with nt.assert_raises(ValueError):
source_uxds['v1'].remap.inverse_distance_weighted(
destination_grid=destination_grid,
remap_to="nodes", k=2, coord_type="test"
)

# Raise ValueError when the source data is invalid
with nt.assert_raises(ValueError):
source_uxds_2['cellsOnCell'].remap.inverse_distance_weighted(
destination_grid=destination_grid,
remap_to="nodes"
)

# Raise UserWarning when `power` > 5
with nt.assert_warns(UserWarning):
source_uxds['v1'].remap.inverse_distance_weighted(
destination_grid=destination_grid,
remap_to="nodes", power=6
)
Loading

0 comments on commit 86284d6

Please sign in to comment.