diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 41b39bf2b1a..f809590d351 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -34,6 +34,9 @@ Bug fixes Internal Changes ~~~~~~~~~~~~~~~~ +- Migrates ``datatree_mapping`` functionality into ``xarray/core`` (:pull:`8948`) + By `Matt Savoie <https://github.com/flamingbear>`_, `Owen Littlejohns + <https://github.com/owenlittlejohns>`_ and `Tom Nicholas <https://github.com/TomNicholas>`_. .. _whats-new.2024.03.0: diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 1b06d87c9fb..e24dfb504b1 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -18,6 +18,11 @@ from xarray.core.coordinates import DatasetCoordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset, DataVariables +from xarray.core.datatree_mapping import ( + TreeIsomorphismError, + check_isomorphic, + map_over_subtree, +) from xarray.core.indexes import Index, Indexes from xarray.core.merge import dataset_update_method from xarray.core.options import OPTIONS as XR_OPTS @@ -36,11 +41,6 @@ from xarray.datatree_.datatree.formatting_html import ( datatree_repr as datatree_repr_html, ) -from xarray.datatree_.datatree.mapping import ( - TreeIsomorphismError, - check_isomorphic, - map_over_subtree, -) from xarray.datatree_.datatree.ops import ( DataTreeArithmeticMixin, MappedDatasetMethodsMixin, diff --git a/xarray/datatree_/datatree/mapping.py b/xarray/core/datatree_mapping.py similarity index 94% rename from xarray/datatree_/datatree/mapping.py rename to xarray/core/datatree_mapping.py index 9546905e1ac..714921d2a90 100644 --- a/xarray/datatree_/datatree/mapping.py +++ b/xarray/core/datatree_mapping.py @@ -4,10 +4,9 @@ import sys from itertools import repeat from textwrap import dedent -from typing import TYPE_CHECKING, Callable, Tuple +from typing import TYPE_CHECKING, Callable from xarray import DataArray, Dataset - from xarray.core.iterators import LevelOrderIter from xarray.core.treenode import NodePath, TreeNode @@ -84,14 +83,13 @@ def diff_treestructure(a: DataTree, b: DataTree, 
require_names_equal: bool) -> s for node_a, node_b in zip(LevelOrderIter(a), LevelOrderIter(b)): path_a, path_b = node_a.path, node_b.path - if require_names_equal: - if node_a.name != node_b.name: - diff = dedent( - f"""\ + if require_names_equal and node_a.name != node_b.name: + diff = dedent( + f"""\ Node '{path_a}' in the left object has name '{node_a.name}' Node '{path_b}' in the right object has name '{node_b.name}'""" - ) - return diff + ) + return diff if len(node_a.children) != len(node_b.children): diff = dedent( @@ -125,7 +123,7 @@ def map_over_subtree(func: Callable) -> Callable: func : callable Function to apply to datasets with signature: - `func(*args, **kwargs) -> Union[Dataset, Iterable[Dataset]]`. + `func(*args, **kwargs) -> Dataset | Iterable[Dataset]`. (i.e. func must accept at least one Dataset and return at least one Dataset.) Function will not be applied to any nodes without datasets. @@ -154,7 +152,7 @@ def map_over_subtree(func: Callable) -> Callable: # TODO inspect function to work out immediately if the wrong number of arguments were passed for it? 
@functools.wraps(func) - def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]: + def _map_over_subtree(*args, **kwargs) -> DataTree | tuple[DataTree, ...]: """Internal function which maps func over every node in tree, returning a tree of the results.""" from xarray.core.datatree import DataTree @@ -259,7 +257,7 @@ def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]: return _map_over_subtree -def _handle_errors_with_path_context(path): +def _handle_errors_with_path_context(path: str): """Wraps given function so that if it fails it also raises path to node on which it failed.""" def decorator(func): @@ -267,11 +265,10 @@ def wrapper(*args, **kwargs): try: return func(*args, **kwargs) except Exception as e: - if sys.version_info >= (3, 11): - # Add the context information to the error message - e.add_note( - f"Raised whilst mapping function over node with path {path}" - ) + # Add the context information to the error message + add_note( + e, f"Raised whilst mapping function over node with path {path}" + ) raise return wrapper @@ -287,7 +284,9 @@ def add_note(err: BaseException, msg: str) -> None: err.add_note(msg) -def _check_single_set_return_values(path_to_node, obj): +def _check_single_set_return_values( + path_to_node: str, obj: Dataset | DataArray | tuple[Dataset | DataArray] +): """Check types returned from single evaluation of func, and return number of return values received from func.""" if isinstance(obj, (Dataset, DataArray)): return 1 diff --git a/xarray/core/iterators.py b/xarray/core/iterators.py index 5c0c0f652e8..dd5fa7ee97a 100644 --- a/xarray/core/iterators.py +++ b/xarray/core/iterators.py @@ -1,6 +1,5 @@ from __future__ import annotations -from collections import abc from collections.abc import Iterator from typing import Callable @@ -9,7 +8,7 @@ """These iterators are copied from anytree.iterators, with minor modifications.""" -class LevelOrderIter(abc.Iterator): +class LevelOrderIter(Iterator): """Iterate 
over tree applying level-order strategy starting at `node`. This is the iterator used by `DataTree` to traverse nodes. diff --git a/xarray/datatree_/datatree/__init__.py b/xarray/datatree_/datatree/__init__.py index f2603b64641..3159d612913 100644 --- a/xarray/datatree_/datatree/__init__.py +++ b/xarray/datatree_/datatree/__init__.py @@ -1,11 +1,8 @@ # import public API -from .mapping import TreeIsomorphismError, map_over_subtree from xarray.core.treenode import InvalidTreeError, NotFoundInTreeError __all__ = ( - "TreeIsomorphismError", "InvalidTreeError", "NotFoundInTreeError", - "map_over_subtree", ) diff --git a/xarray/datatree_/datatree/formatting.py b/xarray/datatree_/datatree/formatting.py index 9ebee72d4ef..fdd23933ae6 100644 --- a/xarray/datatree_/datatree/formatting.py +++ b/xarray/datatree_/datatree/formatting.py @@ -2,7 +2,7 @@ from xarray.core.formatting import _compat_to_str, diff_dataset_repr -from xarray.datatree_.datatree.mapping import diff_treestructure +from xarray.core.datatree_mapping import diff_treestructure from xarray.datatree_.datatree.render import RenderTree if TYPE_CHECKING: diff --git a/xarray/datatree_/datatree/ops.py b/xarray/datatree_/datatree/ops.py index d6ac4f83e7c..83b9d1b275a 100644 --- a/xarray/datatree_/datatree/ops.py +++ b/xarray/datatree_/datatree/ops.py @@ -2,7 +2,7 @@ from xarray import Dataset -from .mapping import map_over_subtree +from xarray.core.datatree_mapping import map_over_subtree """ Module which specifies the subset of xarray.Dataset's API which we wish to copy onto DataTree. 
diff --git a/xarray/datatree_/datatree/tests/test_mapping.py b/xarray/tests/test_datatree_mapping.py similarity index 98% rename from xarray/datatree_/datatree/tests/test_mapping.py rename to xarray/tests/test_datatree_mapping.py index c6cd04887c0..16ca726759d 100644 --- a/xarray/datatree_/datatree/tests/test_mapping.py +++ b/xarray/tests/test_datatree_mapping.py @@ -1,9 +1,13 @@ import numpy as np import pytest -import xarray as xr +import xarray as xr from xarray.core.datatree import DataTree -from xarray.datatree_.datatree.mapping import TreeIsomorphismError, check_isomorphic, map_over_subtree +from xarray.core.datatree_mapping import ( + TreeIsomorphismError, + check_isomorphic, + map_over_subtree, +) from xarray.datatree_.datatree.testing import assert_equal empty = xr.Dataset() @@ -12,7 +16,7 @@ class TestCheckTreesIsomorphic: def test_not_a_tree(self): with pytest.raises(TypeError, match="not a tree"): - check_isomorphic("s", 1) + check_isomorphic("s", 1) # type: ignore[arg-type] def test_different_widths(self): dt1 = DataTree.from_dict(d={"a": empty}) @@ -69,7 +73,7 @@ def test_not_isomorphic_complex_tree(self, create_test_datatree): def test_checking_from_root(self, create_test_datatree): dt1 = create_test_datatree() dt2 = create_test_datatree() - real_root = DataTree(name="real root") + real_root: DataTree = DataTree(name="real root") dt2.name = "not_real_root" dt2.parent = real_root with pytest.raises(TreeIsomorphismError):