Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate datatree mapping.py #8948

Merged
merged 2 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ Bug fixes

Internal Changes
~~~~~~~~~~~~~~~~
- Migrates ``datatree_mapping`` functionality into ``xarray/core`` (:pull:`8948`)
By `Matt Savoie <https://github.com/flamingbear>`_ `Owen Littlejohns
<https://github.com/owenlittlejohns>` and `Tom Nicholas <https://github.com/TomNicholas>`_.


.. _whats-new.2024.03.0:
Expand Down
10 changes: 5 additions & 5 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
from xarray.core.coordinates import DatasetCoordinates
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset, DataVariables
from xarray.core.datatree_mapping import (
TreeIsomorphismError,
check_isomorphic,
map_over_subtree,
)
from xarray.core.indexes import Index, Indexes
from xarray.core.merge import dataset_update_method
from xarray.core.options import OPTIONS as XR_OPTS
Expand All @@ -36,11 +41,6 @@
from xarray.datatree_.datatree.formatting_html import (
datatree_repr as datatree_repr_html,
)
from xarray.datatree_.datatree.mapping import (
TreeIsomorphismError,
check_isomorphic,
map_over_subtree,
)
from xarray.datatree_.datatree.ops import (
DataTreeArithmeticMixin,
MappedDatasetMethodsMixin,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
import sys
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@TomNicholas - there are no significant changes to the way map_over_subtree works in this PR. As discussed today, I've opened the PR, and am happy to make changes based on conversation here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay looking at this again I can't see an obvious way to refactor it. I'm mostly just concerned that it actually makes sense to you @owenlittlejohns ?

If so then I suggest we spend the time thinking about #8949 instead

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@TomNicholas - yeah. @flamingbear and I have both spent a fair bit of time trying to grok this module. I don't claim that I've got a perfect understanding, but I feel fairly good on the whole. (The conversation yesterday that led to #8949 helped a bit, too)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay great. So I guess we merge this??

FYI if you want to chat about any of this over zoom outside of the weekly meeting I would be happy to.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we're good to merge!

from itertools import repeat
from textwrap import dedent
from typing import TYPE_CHECKING, Callable, Tuple
from typing import TYPE_CHECKING, Callable

from xarray import DataArray, Dataset

from xarray.core.iterators import LevelOrderIter
from xarray.core.treenode import NodePath, TreeNode

Expand Down Expand Up @@ -84,14 +83,13 @@ def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> s
for node_a, node_b in zip(LevelOrderIter(a), LevelOrderIter(b)):
path_a, path_b = node_a.path, node_b.path

if require_names_equal:
if node_a.name != node_b.name:
diff = dedent(
f"""\
if require_names_equal and node_a.name != node_b.name:
diff = dedent(
f"""\
Node '{path_a}' in the left object has name '{node_a.name}'
Node '{path_b}' in the right object has name '{node_b.name}'"""
)
return diff
)
return diff

if len(node_a.children) != len(node_b.children):
diff = dedent(
Expand Down Expand Up @@ -125,7 +123,7 @@ def map_over_subtree(func: Callable) -> Callable:
func : callable
Function to apply to datasets with signature:

`func(*args, **kwargs) -> Union[Dataset, Iterable[Dataset]]`.
`func(*args, **kwargs) -> Union[DataTree, Iterable[DataTree]]`.

(i.e. func must accept at least one Dataset and return at least one Dataset.)
Function will not be applied to any nodes without datasets.
Expand Down Expand Up @@ -154,7 +152,7 @@ def map_over_subtree(func: Callable) -> Callable:
# TODO inspect function to work out immediately if the wrong number of arguments were passed for it?

@functools.wraps(func)
def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
def _map_over_subtree(*args, **kwargs) -> DataTree | tuple[DataTree, ...]:
"""Internal function which maps func over every node in tree, returning a tree of the results."""
from xarray.core.datatree import DataTree

Expand Down Expand Up @@ -259,19 +257,18 @@ def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
return _map_over_subtree


def _handle_errors_with_path_context(path):
def _handle_errors_with_path_context(path: str):
"""Wraps given function so that if it fails it also raises path to node on which it failed."""

def decorator(func):
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
if sys.version_info >= (3, 11):
# Add the context information to the error message
e.add_note(
f"Raised whilst mapping function over node with path {path}"
)
# Add the context information to the error message
add_note(
e, f"Raised whilst mapping function over node with path {path}"
)
raise

return wrapper
Expand All @@ -287,7 +284,9 @@ def add_note(err: BaseException, msg: str) -> None:
err.add_note(msg)


def _check_single_set_return_values(path_to_node, obj):
def _check_single_set_return_values(
path_to_node: str, obj: Dataset | DataArray | tuple[Dataset | DataArray]
):
"""Check types returned from single evaluation of func, and return number of return values received from func."""
if isinstance(obj, (Dataset, DataArray)):
return 1
Expand Down
3 changes: 1 addition & 2 deletions xarray/core/iterators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from collections import abc
from collections.abc import Iterator
from typing import Callable

Expand All @@ -9,7 +8,7 @@
"""These iterators are copied from anytree.iterators, with minor modifications."""


class LevelOrderIter(abc.Iterator):
class LevelOrderIter(Iterator):
"""Iterate over tree applying level-order strategy starting at `node`.
This is the iterator used by `DataTree` to traverse nodes.

Expand Down
3 changes: 0 additions & 3 deletions xarray/datatree_/datatree/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
# import public API
from .mapping import TreeIsomorphismError, map_over_subtree
from xarray.core.treenode import InvalidTreeError, NotFoundInTreeError


__all__ = (
"TreeIsomorphismError",
"InvalidTreeError",
"NotFoundInTreeError",
"map_over_subtree",
)
2 changes: 1 addition & 1 deletion xarray/datatree_/datatree/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from xarray.core.formatting import _compat_to_str, diff_dataset_repr

from xarray.datatree_.datatree.mapping import diff_treestructure
from xarray.core.datatree_mapping import diff_treestructure
from xarray.datatree_.datatree.render import RenderTree

if TYPE_CHECKING:
Expand Down
2 changes: 1 addition & 1 deletion xarray/datatree_/datatree/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from xarray import Dataset

from .mapping import map_over_subtree
from xarray.core.datatree_mapping import map_over_subtree

"""
Module which specifies the subset of xarray.Dataset's API which we wish to copy onto DataTree.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import numpy as np
import pytest
import xarray as xr

import xarray as xr
from xarray.core.datatree import DataTree
from xarray.datatree_.datatree.mapping import TreeIsomorphismError, check_isomorphic, map_over_subtree
from xarray.core.datatree_mapping import (
TreeIsomorphismError,
check_isomorphic,
map_over_subtree,
)
from xarray.datatree_.datatree.testing import assert_equal

empty = xr.Dataset()
Expand All @@ -12,7 +16,7 @@
class TestCheckTreesIsomorphic:
def test_not_a_tree(self):
with pytest.raises(TypeError, match="not a tree"):
check_isomorphic("s", 1)
check_isomorphic("s", 1) # type: ignore[arg-type]

def test_different_widths(self):
dt1 = DataTree.from_dict(d={"a": empty})
Expand Down Expand Up @@ -69,7 +73,7 @@ def test_not_isomorphic_complex_tree(self, create_test_datatree):
def test_checking_from_root(self, create_test_datatree):
dt1 = create_test_datatree()
dt2 = create_test_datatree()
real_root = DataTree(name="real root")
real_root: DataTree = DataTree(name="real root")
dt2.name = "not_real_root"
dt2.parent = real_root
with pytest.raises(TreeIsomorphismError):
Expand Down
Loading