Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DRAFT: API for accessing partitioned connectivity, reworked topological aggregations #978

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 96 additions & 77 deletions uxarray/core/aggregation.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,19 @@
import numpy as np
import dask
import dask.array as da

from uxarray.grid.connectivity import get_face_node_partitions

import uxarray.core.dataarray


NUMPY_AGGREGATIONS = {
"mean": np.mean,
"max": np.max,
"min": np.min,
"prod": np.prod,
"sum": np.sum,
"std": np.std,
"var": np.var,
"median": np.median,
"all": np.all,
"any": np.any,
}
from uxarray.core.numba_aggregation import NUMBA_NODE_FACE_AGGS, NUMBA_NODE_EDGE_AGGS


def result_array(arr):
if isinstance(arr, np.ndarray):
return np.empty
if isinstance(arr, dask.array.core.Array):
return da.empty


def _uxda_grid_aggregate(uxda, destination, aggregation, **kwargs):
Expand Down Expand Up @@ -79,18 +75,10 @@ def _node_to_face_aggregation(uxda, aggregation, aggregation_func_kwargs):
f"{uxda.uxgrid.n_face}."
)

if isinstance(uxda.data, np.ndarray):
# apply aggregation using numpy
aggregated_var = _apply_node_to_face_aggregation_numpy(
uxda, NUMPY_AGGREGATIONS[aggregation], aggregation_func_kwargs
)
elif isinstance(uxda.data, da.Array):
# apply aggregation on dask array, TODO:
aggregated_var = _apply_node_to_face_aggregation_numpy(
uxda, NUMPY_AGGREGATIONS[aggregation], aggregation_func_kwargs
)
else:
raise ValueError
# TODO:
aggregated_var = _apply_node_to_face_aggregation(
uxda, aggregation, aggregation_func_kwargs
)

return uxarray.core.dataarray.UxDataArray(
uxgrid=uxda.uxgrid,
Expand All @@ -100,41 +88,58 @@ def _node_to_face_aggregation(uxda, aggregation, aggregation_func_kwargs):
).rename({"n_node": "n_face"})


def _apply_node_to_face_aggregation_numpy(
uxda, aggregation_func, aggregation_func_kwargs
def _apply_node_to_face_aggregation(
uxda, aggregation_func, aggregation_func_kwargs, result_array_kwargs=None
):
"""Applies a Node to Face Topological aggregation on a Numpy array."""
data = uxda.values
face_node_conn = uxda.uxgrid.face_node_connectivity.values
n_nodes_per_face = uxda.uxgrid.n_nodes_per_face.values

(
change_ind,
n_nodes_per_face_sorted_ind,
element_sizes,
size_counts,
) = get_face_node_partitions(n_nodes_per_face)

result = np.empty(shape=(data.shape[:-1]) + (uxda.uxgrid.n_face,))

for e, start, end in zip(element_sizes, change_ind[:-1], change_ind[1:]):
face_inds = n_nodes_per_face_sorted_ind[start:end]
face_nodes_par = face_node_conn[face_inds, 0:e]

# apply aggregation function to current face node partition
aggregation_par = aggregation_func(
data[..., face_nodes_par], axis=-1, **aggregation_func_kwargs
)
"""TODO:"""

# store current aggregation
result[..., face_inds] = aggregation_par
if isinstance(uxda.data, np.ndarray):
_numba_agg_func = NUMBA_NODE_FACE_AGGS[aggregation_func]
result = _numba_agg_func(
uxda.data,
uxda.uxgrid.face_node_connectivity.values,
uxda.uxgrid.n_nodes_per_face.values,
uxda.uxgrid.n_face,
)
elif isinstance(uxda.data, dask.array.core.Array):
result = _node_to_face_aggregation_dask(
uxda,
aggregation_func,
aggregation_func_kwargs,
)
else:
raise ValueError("TODO")

return result


def _apply_node_to_face_aggregation_dask(*args, **kwargs):
"""Applies a Node to Face Topological aggregation on a Dask array."""
pass
def _node_to_face_aggregation_dask(uxda, aggregation_func, aggregation_func_kwargs):
# shape [..., n_face] since data is being aggregated onto the faces
result = result_array(uxda.data)(
shape=(uxda.data.shape[:-1]) + (uxda.uxgrid.n_face,)
)

for (
cur_face_node_partition,
cur_original_face_indices,
) in uxda.uxgrid.partitioned_face_node_connectivity:
# index array using flattened connectivity (to avoid Dask errors)
data_flat = uxda.data[..., cur_face_node_partition.flatten()]

# reshape index data back to desired shape [..., n_face_geom, geom_size]
data_reshaped = data_flat.reshape(
(uxda.data.shape[:-1]) + cur_face_node_partition.shape
)

# apply aggregation on current partition of elements
aggregation_par = getattr(data_reshaped, aggregation_func)(
axis=-1, **aggregation_func_kwargs
)

# store computed aggregation using original face indices
result[..., cur_original_face_indices] = aggregation_par

return result


def _node_to_edge_aggregation(uxda, aggregation, aggregation_func_kwargs):
Expand All @@ -145,39 +150,53 @@ def _node_to_edge_aggregation(uxda, aggregation, aggregation_func_kwargs):
f"{uxda.uxgrid.n_face}."
)

if isinstance(uxda.data, np.ndarray):
# apply aggregation using numpy
aggregation_var = _apply_node_to_edge_aggregation_numpy(
uxda, NUMPY_AGGREGATIONS[aggregation], aggregation_func_kwargs
)
elif isinstance(uxda.data, da.Array):
# apply aggregation on dask array, TODO:
aggregation_var = _apply_node_to_edge_aggregation_numpy(
uxda, NUMPY_AGGREGATIONS[aggregation], aggregation_func_kwargs
)
else:
raise ValueError
# TODO:
aggregated_var = _apply_node_to_edge_aggregation_(
uxda, aggregation, aggregation_func_kwargs
)

return uxarray.core.dataarray.UxDataArray(
uxgrid=uxda.uxgrid,
data=aggregation_var,
data=aggregated_var,
dims=uxda.dims,
name=uxda.name,
).rename({"n_node": "n_edge"})


def _apply_node_to_edge_aggregation_numpy(
uxda, aggregation_func, aggregation_func_kwargs
def _apply_node_to_edge_aggregation_(
uxda, aggregation_func, aggregation_func_kwargs, result_array_kwargs=None
):
"""Applies a Node to Edge topological aggregation on a numpy array."""
data = uxda.values
edge_node_conn = uxda.uxgrid.edge_node_connectivity.values
result = aggregation_func(
data[..., edge_node_conn], axis=-1, **aggregation_func_kwargs
"""TODO:"""
if isinstance(uxda.data, np.ndarray):
_numba_agg_func = NUMBA_NODE_EDGE_AGGS[aggregation_func]
result = _numba_agg_func(
uxda.data, uxda.uxgrid.edge_node_connectivity.values, uxda.uxgrid.n_face
)
elif isinstance(uxda.data, dask.array.core.Array):
result = _node_to_edge_aggregation_dask(
uxda,
aggregation_func,
aggregation_func_kwargs,
)
else:
raise ValueError("TODO")

return result

data_flat = uxda.data[..., uxda.uxgrid.edge_node_connectivity.data.flatten()]
data_reshaped = data_flat.reshape((uxda.data.shape[:-1]) + (uxda.uxgrid.n_edge, 2))
result = getattr(data_reshaped, aggregation_func)(
axis=-1, **aggregation_func_kwargs
)
return result


def _apply_node_to_edge_aggregation_dask(*args, **kwargs):
"""Applies a Node to Edge topological aggregation on a dask array."""
pass
def _node_to_edge_aggregation_dask(uxda, aggregation_func, aggregation_func_kwargs):
"""TODO:"""

data_flat = uxda.data[..., uxda.uxgrid.edge_node_connectivity.data.flatten()]
data_reshaped = data_flat.reshape((uxda.data.shape[:-1]) + (uxda.uxgrid.n_edge, 2))
result = getattr(data_reshaped, aggregation_func)(
axis=-1, **aggregation_func_kwargs
)
return result
Loading
Loading