Skip to content

Commit

Permalink
Merge pull request #878 from UXARRAY/philipc2/norm-coords
Browse files Browse the repository at this point in the history
Normalization of Parsed Cartesian Coordinates
  • Loading branch information
rajeeja authored Sep 18, 2024
2 parents f159bb1 + 7308b1d commit 0f7282c
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 16 deletions.
15 changes: 14 additions & 1 deletion benchmarks/mpas_ocean.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,20 @@ def time_nearest_neighbor_remapping(self):
def time_inverse_distance_weighted_remapping(self):
self.uxds_480["bottomDepth"].remap.inverse_distance_weighted(self.uxds_120.uxgrid)


class HoleEdgeIndices(DatasetBenchmark):
def time_construct_hole_edge_indices(self, resolution):
ux.grid.geometry._construct_hole_edge_indices(self.uxds.uxgrid.edge_face_connectivity)

class CheckNorm:
param_names = ['resolution']
params = ['480km', '120km']

def setup(self, resolution):
self.uxgrid = ux.open_grid(file_path_dict[resolution][0])

def teardown(self, resolution):
del self.uxgrid

def time_check_norm(self, resolution):
from uxarray.grid.validation import _check_normalization
_check_normalization(self.uxgrid)
38 changes: 32 additions & 6 deletions test/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,6 @@ def test_read_scrip(self):

# Test read from scrip and from ugrid for grid class
grid_CSne8 = ux.open_grid(gridfile_CSne8) # tests from scrip
pass


class TestOperators(TestCase):
Expand Down Expand Up @@ -926,7 +925,6 @@ def test_from_dataset(self):
xrds = xr.open_dataset(self.gridfile_scrip)
uxgrid = ux.Grid.from_dataset(xrds)

pass

def test_from_face_vertices(self):
single_face_latlon = [(0.0, 90.0), (-180, 0.0), (0.0, -90)]
Expand Down Expand Up @@ -961,7 +959,35 @@ def test_populate_bounds_GCA_mix(self):
face_bounds = bounds_xarray.values
nt.assert_allclose(grid.bounds.values, expected_bounds, atol=ERROR_TOLERANCE)

def test_opti_bounds(self):
import uxarray
uxgrid = ux.open_grid(gridfile_CSne8)
bounds = uxgrid.bounds
def test_populate_bounds_MPAS(self):
uxgrid = ux.open_grid(self.gridfile_mpas)
bounds_xarray = uxgrid.bounds


class TestNormalizeExistingCoordinates(TestCase):
gridfile_mpas = current_path / "meshfiles" / "mpas" / "QU" / "mesh.QU.1920km.151026.nc"
gridfile_CSne30 = current_path / "meshfiles" / "ugrid" / "outCSne30" / "outCSne30.ug"

def test_non_norm_initial(self):
"""Check the normalization of coordinates that were initially parsed as
non-normalized."""
from uxarray.grid.validation import _check_normalization
uxgrid = ux.open_grid(self.gridfile_mpas)

# Make the coordinates not normalized
uxgrid.node_x.data = 5 * uxgrid.node_x.data
uxgrid.node_y.data = 5 * uxgrid.node_y.data
uxgrid.node_z.data = 5 * uxgrid.node_z.data
assert not _check_normalization(uxgrid)

uxgrid.normalize_cartesian_coordinates()

assert _check_normalization(uxgrid)

def test_norm_initial(self):
"""Coordinates should be normalized for grids that we construct
them."""
from uxarray.grid.validation import _check_normalization
uxgrid = ux.open_grid(self.gridfile_CSne30)

assert _check_normalization(uxgrid)
37 changes: 37 additions & 0 deletions uxarray/grid/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
_set_desired_longitude_range,
_populate_node_latlon,
_populate_node_xyz,
_normalize_xyz,
)
from uxarray.grid.connectivity import (
_populate_edge_node_connectivity,
Expand Down Expand Up @@ -72,6 +73,7 @@
_check_connectivity,
_check_duplicate_nodes,
_check_area,
_check_normalization,
)

from xarray.core.utils import UncachedAccessor
Expand Down Expand Up @@ -175,6 +177,9 @@ def __init__(
self._ball_tree = None
self._kd_tree = None

# flag to track if coordinates are normalized
self._normalized = None

# set desired longitude range to [-180, 180]
_set_desired_longitude_range(self._ds)

Expand Down Expand Up @@ -1420,6 +1425,38 @@ def compute_face_areas(

return self._face_areas, self._face_jacobian

def normalize_cartesian_coordinates(self):
"""Normalizes Cartesian coordinates."""

if _check_normalization(self):
# check if coordinates are already normalized
return

if "node_x" in self._ds:
# normalize node coordinates
node_x, node_y, node_z = _normalize_xyz(
self.node_x.values, self.node_y.values, self.node_z.values
)
self.node_x.data = node_x
self.node_y.data = node_y
self.node_z.data = node_z
if "edge_x" in self._ds:
# normalize edge coordinates
edge_x, edge_y, edge_z = _normalize_xyz(
self.edge_x.values, self.edge_y.values, self.edge_z.values
)
self.edge_x.data = edge_x
self.edge_y.data = edge_y
self.edge_z.data = edge_z
if "face_x" in self._ds:
# normalize face coordinates
face_x, face_y, face_z = _normalize_xyz(
self.face_x.values, self.face_y.values, self.face_z.values
)
self.face_x.data = face_x
self.face_y.data = face_y
self.face_z.data = face_z

def to_xarray(self, grid_format: Optional[str] = "ugrid"):
"""Returns a xarray Dataset representation in a specific grid format
from the Grid object.
Expand Down
61 changes: 52 additions & 9 deletions uxarray/grid/validation.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import numpy as np
from warnings import warn


from uxarray.constants import ERROR_TOLERANCE


# validation helper functions
def _check_connectivity(self):
def _check_connectivity(grid):
"""Check if all nodes are referenced by at least one element.
If not, the mesh may have hanging nodes and may not a valid UGRID
Expand All @@ -15,28 +14,28 @@ def _check_connectivity(self):

# Check if all nodes are referenced by at least one element
# get unique nodes in connectivity
nodes_in_conn = np.unique(self.face_node_connectivity.values.flatten())
nodes_in_conn = np.unique(grid.face_node_connectivity.values.flatten())
# remove negative indices/fill values from the list
nodes_in_conn = nodes_in_conn[nodes_in_conn >= 0]

# check if the size of unique nodes in connectivity is equal to the number of nodes
if nodes_in_conn.size == self.n_node:
if nodes_in_conn.size == grid.n_node:
print("-All nodes are referenced by at least one element.")
return True
else:
warn(
"Some nodes may not be referenced by any element. {0} and {1}".format(
nodes_in_conn.size, self.n_node
nodes_in_conn.size, grid.n_node
),
RuntimeWarning,
)
return False


def _check_duplicate_nodes(self):
def _check_duplicate_nodes(grid):
"""Check if there are duplicate nodes in the mesh."""

coords1 = np.column_stack((np.vstack(self.node_lon), np.vstack(self.node_lat)))
coords1 = np.column_stack((np.vstack(grid.node_lon), np.vstack(grid.node_lat)))
unique_nodes, indices = np.unique(coords1, axis=0, return_index=True)
duplicate_indices = np.setdiff1d(np.arange(len(coords1)), indices)

Expand All @@ -53,9 +52,9 @@ def _check_duplicate_nodes(self):
return True


def _check_area(self):
def _check_area(grid):
"""Check if each face area is greater than our constant ERROR_TOLERANCE."""
areas = self.face_areas
areas = grid.face_areas
# Check if area of any face is close to zero
if np.any(np.isclose(areas, 0, atol=ERROR_TOLERANCE)):
warn(
Expand All @@ -66,3 +65,47 @@ def _check_area(self):
else:
print("-No face area is close to zero.")
return True


def _check_normalization(grid):
"""Checks whether all the cartesiain coordinates are normalized."""

if grid._normalized is True:
# grid is already normalized, no need to run extra checks
return grid._normalized

if "node_x" in grid._ds:
if not (
np.isclose(
(grid.node_x**2 + grid.node_y**2 + grid.node_z**2),
1.0,
atol=ERROR_TOLERANCE,
)
).all():
grid._normalized = False
return False
if "edge_x" in grid._ds:
if not (
np.isclose(
(grid.node_x**2 + grid.node_y**2 + grid.node_z**2),
1.0,
atol=ERROR_TOLERANCE,
)
).all():
grid._normalized = False
return False
if "face_x" in grid._ds:
if not (
np.isclose(
(grid.node_x**2 + grid.node_y**2 + grid.node_z**2),
1.0,
atol=ERROR_TOLERANCE,
)
).all():
grid._normalized = False
return False

# set the grid as normalized
grid._normalized = True

return True

0 comments on commit 0f7282c

Please sign in to comment.