Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply function to points within circular neighborhood #941

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
118 changes: 115 additions & 3 deletions uxarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import numpy as np


from typing import TYPE_CHECKING, Optional, Union, Hashable, Literal
from typing import TYPE_CHECKING, Callable, Optional, Union, Hashable, Literal

from uxarray.constants import GRID_DIMS
from uxarray.formatting_html import array_repr

from html import escape
Expand Down Expand Up @@ -1046,8 +1047,6 @@ def isel(self, ignore_grid=False, *args, **kwargs):
> uxda.subset(n_node=[1, 2, 3])
"""

from uxarray.constants import GRID_DIMS

if any(grid_dim in kwargs for grid_dim in GRID_DIMS) and not ignore_grid:
# slicing a grid-dimension through Grid object

Expand Down Expand Up @@ -1104,3 +1103,116 @@ def _slice_from_grid(self, sliced_grid):
dims=self.dims,
attrs=self.attrs,
)

def neighborhood_filter(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This implementation looks great! May we move the bulk of the logic into the uxarray.grid.neighbors module and call that helper from here?

We can keep the data-mapping checks here, and anything related to constructing and returining the final data array but the bulk of the computations would go inside a helper in the module mentioned above.

Copy link
Collaborator Author

@ahijevyc ahijevyc Sep 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have to think about how to do that, but I am happy to defer to you.

self,
func: Callable = np.mean,
r: float = 1.0,
) -> UxDataArray:
"""Apply neighborhood filter
Parameters:
-----------
func: Callable, default=np.mean
Apply this function to neighborhood
r : float, default=1.
Radius of neighborhood. For spherical coordinates, the radius is in units of degrees,
and for cartesian coordinates, the radius is in meters.
Returns:
--------
destination_data : np.ndarray
Filtered data.
"""

if self._face_centered():
data_mapping = "face centers"
elif self._node_centered():
data_mapping = "nodes"
elif self._edge_centered():
data_mapping = "edge centers"
else:
raise ValueError(
"Data_mapping is not face, node, or edge. Could not define data_mapping."
)

# reconstruct because the cached tree could be built from
# face centers, edge centers or nodes.
tree = self.uxgrid.get_ball_tree(coordinates=data_mapping, reconstruct=True)
Comment on lines +1063 to +1064
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@aaronzedwick

We should probably fix this logic in get_ball_tree(), since we shouldn't need to manually set reconstruct=False

        if self._ball_tree is None or reconstruct:
            self._ball_tree = BallTree(
                self,
                coordinates=coordinates,
                distance_metric=distance_metric,
                coordinate_system=coordinate_system,
                reconstruct=reconstruct,
            )
        else:
            if coordinates != self._ball_tree._coordinates:
                self._ball_tree.coordinates = coordinates

The coordinates != self._ball_tree._coordinates check should be included in the first if

Copy link
Collaborator Author

@ahijevyc ahijevyc Sep 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense. So, move the coordinates check to the if-clause like this?

                if (
                    self._ball_tree is None
                    or coordinates != self._ball_tree._coordinates
                    or reconstruct
                ):

                    self._ball_tree = BallTree(
                        self,
                        coordinates=coordinates,
                        distance_metric=distance_metric,
                        coordinate_system=coordinate_system,
                        reconstruct=reconstruct,
                    )

What if the coordinate_system is different? Would that also require a newly constructed tree?

Copy link
Collaborator Author

@ahijevyc ahijevyc Sep 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whatever logic is fixed in Grid.get_ball_tree should also be applied to Grid.get_kdtree.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

checking coordinate system also (coordinate_system is not a hidden variable of _ball_tree; it has no underscore):

                if (
                    self._ball_tree is None
                    or coordinates != self._ball_tree._coordinates
                    or coordinate_system != self._ball_tree.coordinate_system
                    or reconstruct
                ):

                    self._ball_tree = BallTree(
                        self,
                        coordinates=coordinates,
                        distance_metric=distance_metric,
                        coordinate_system=coordinate_system,
                        reconstruct=reconstruct,
                    )


coordinate_system = tree.coordinate_system

if coordinate_system == "spherical":
if data_mapping == "nodes":
lon, lat = (
self.uxgrid.node_lon.values,
self.uxgrid.node_lat.values,
)
elif data_mapping == "face centers":
lon, lat = (
self.uxgrid.face_lon.values,
self.uxgrid.face_lat.values,
)
elif data_mapping == "edge centers":
lon, lat = (
self.uxgrid.edge_lon.values,
self.uxgrid.edge_lat.values,
)
else:
raise ValueError(
f"Invalid data_mapping. Expected 'nodes', 'edge centers', or 'face centers', "
f"but received: {data_mapping}"
)

dest_coords = np.c_[lon, lat]

elif coordinate_system == "cartesian":
if data_mapping == "nodes":
x, y, z = (
self.uxgrid.node_x.values,
self.uxgrid.node_y.values,
self.uxgrid.node_z.values,
)
elif data_mapping == "face centers":
x, y, z = (
self.uxgrid.face_x.values,
self.uxgrid.face_y.values,
self.uxgrid.face_z.values,
)
elif data_mapping == "edge centers":
x, y, z = (
self.uxgrid.edge_x.values,
self.uxgrid.edge_y.values,
self.uxgrid.edge_z.values,
)
else:
raise ValueError(
f"Invalid data_mapping. Expected 'nodes', 'edge centers', or 'face centers', "
f"but received: {data_mapping}"
)

dest_coords = np.c_[x, y, z]
ahijevyc marked this conversation as resolved.
Show resolved Hide resolved

else:
raise ValueError(
f"Invalid coordinate_system. Expected either 'spherical' or 'cartesian', but received {coordinate_system}"
)

neighbor_indices = tree.query_radius(dest_coords, r=r)

destination_data = np.empty(self.data.shape)

# assert last dimension is a GRID dimension.
assert self.dims[-1] in GRID_DIMS, (
f"expected last dimension of uxDataArray {self.data.dims[-1]} "
f"to be one of {GRID_DIMS}"
)
# Apply function to indices on last axis.
for i, idx in enumerate(neighbor_indices):
if len(idx):
destination_data[..., i] = func(self.data[..., idx])

# construct data array for filtered variable
uxda_filter = self._copy()

uxda_filter.data = destination_data

return uxda_filter
37 changes: 36 additions & 1 deletion uxarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

import sys

from typing import Optional, IO, Union
from typing import Callable, Optional, IO, Union

from uxarray.constants import GRID_DIMS
from uxarray.grid import Grid
from uxarray.core.dataarray import UxDataArray

Expand Down Expand Up @@ -338,6 +339,40 @@ def to_array(self) -> UxDataArray:
xarr = super().to_array()
return UxDataArray(xarr, uxgrid=self.uxgrid)

def neighborhood_filter(
self,
func: Callable = np.mean,
r: float = 1.0,
):
"""Neighborhood function implementation for ``UxDataset``.
Parameters
---------
func : Callable = np.mean
Apply this function to neighborhood
r : float, default=1.
Radius of neighborhood
"""
ahijevyc marked this conversation as resolved.
Show resolved Hide resolved

destination_uxds = self._copy()
# Loop through uxDataArrays in uxDataset
for var_name in self.data_vars:
uxda = self[var_name]

# Skip if uxDataArray has no GRID dimension.
grid_dims = [dim for dim in uxda.dims if dim in GRID_DIMS]
if len(grid_dims) == 0:
continue

# Put GRID dimension last for UxDataArray.neighborhood_filter.
remember_dim_order = uxda.dims
uxda = uxda.transpose(..., grid_dims[0])
# Filter uxDataArray.
uxda = uxda.neighborhood_filter(func, r)
# Restore old dimension order.
destination_uxds[var_name] = uxda.transpose(*remember_dim_order)

return destination_uxds

def nearest_neighbor_remap(
self,
destination_obj: Union[Grid, UxDataArray, UxDataset],
Expand Down
Loading