Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow opening selected groups only #338

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions datatree/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def _get_nc_dataset_class(engine):
return Dataset


def open_datatree(filename_or_obj, engine=None, **kwargs) -> DataTree:
def open_datatree(filename_or_obj, engine=None, groups=None, **kwargs) -> DataTree:
"""
Open and decode a dataset from a file or file-like object, creating one Tree node for each group in the file.

Expand All @@ -44,6 +44,8 @@ def open_datatree(filename_or_obj, engine=None, **kwargs) -> DataTree:
Strings and Path objects are interpreted as a path to a netCDF file or Zarr store.
engine : str, optional
Xarray backend engine to use. Valid options include `{"netcdf4", "h5netcdf", "zarr"}`.
groups : sequence of str, optional
The sequence of groups to read from the file. By default groups is None, which means reading all the groups.
kwargs :
Additional keyword arguments passed to ``xarray.open_dataset`` for each group.

Expand All @@ -53,20 +55,24 @@ def open_datatree(filename_or_obj, engine=None, **kwargs) -> DataTree:
"""

if engine == "zarr":
return _open_datatree_zarr(filename_or_obj, **kwargs)
return _open_datatree_zarr(filename_or_obj, groups=groups, **kwargs)
elif engine in [None, "netcdf4", "h5netcdf"]:
return _open_datatree_netcdf(filename_or_obj, engine=engine, **kwargs)
return _open_datatree_netcdf(
filename_or_obj, engine=engine, groups=groups, **kwargs
)
else:
raise ValueError("Unsupported engine")


def _open_datatree_netcdf(filename: str, **kwargs) -> DataTree:
def _open_datatree_netcdf(filename: str, groups=None, **kwargs) -> DataTree:
ncDataset = _get_nc_dataset_class(kwargs.get("engine", None))

ds = open_dataset(filename, **kwargs)
tree_root = DataTree.from_dict({"/": ds})
with ncDataset(filename, mode="r") as ncds:
for path in _iter_nc_groups(ncds):
if groups is None:
groups = _iter_nc_groups(ncds)
for path in groups:
subgroup_ds = open_dataset(filename, group=path, **kwargs)

# TODO refactor to use __setitem__ once creation of new nodes by assigning Dataset works again
Expand All @@ -81,13 +87,15 @@ def _open_datatree_netcdf(filename: str, **kwargs) -> DataTree:
return tree_root


def _open_datatree_zarr(store, **kwargs) -> DataTree:
def _open_datatree_zarr(store, groups=None, **kwargs) -> DataTree:
import zarr # type: ignore

zds = zarr.open_group(store, mode="r")
ds = open_dataset(store, engine="zarr", **kwargs)
tree_root = DataTree.from_dict({"/": ds})
for path in _iter_zarr_groups(zds):
if groups is None:
groups = _iter_zarr_groups(zds)
for path in groups:
try:
subgroup_ds = open_dataset(store, engine="zarr", group=path, **kwargs)
except zarr.errors.PathNotFoundError:
Expand Down
40 changes: 40 additions & 0 deletions datatree/tests/test_io.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import zarr.errors

from datatree.datatree import DataTree
from datatree.io import open_datatree
from datatree.testing import assert_equal
from datatree.tests import requires_h5netcdf, requires_netCDF4, requires_zarr
Expand Down Expand Up @@ -39,6 +40,19 @@ def test_netcdf_encoding(self, tmpdir, simple_datatree):
with pytest.raises(ValueError, match="unexpected encoding group.*"):
original_dt.to_netcdf(filepath, encoding=enc, engine="netcdf4")

@requires_netCDF4
def test_to_netcdf_selected_groups(self, tmpdir, simple_datatree):
    """Reopen a netCDF4-written tree with ``groups=["set1"]`` and check the child nodes.

    The tree is written with the ``netcdf4`` engine, reopened through
    ``open_datatree`` with only ``"set1"`` requested, and the set of keys
    whose values are ``DataTree`` instances must equal ``{"set1"}``.
    """
    # casting to str avoids a pathlib bug in xarray
    filepath = str(tmpdir / "test.nc")
    simple_datatree.to_netcdf(filepath, engine="netcdf4")

    roundtrip_dt = open_datatree(filepath, groups=["set1"])

    # Collect only the keys that map to child tree nodes.
    child_nodes = set()
    for name in roundtrip_dt.keys():
        if isinstance(roundtrip_dt[name], DataTree):
            child_nodes.add(name)
    assert child_nodes == {"set1"}

@requires_h5netcdf
def test_to_h5netcdf(self, tmpdir, simple_datatree):
filepath = str(
Expand All @@ -50,6 +64,19 @@ def test_to_h5netcdf(self, tmpdir, simple_datatree):
roundtrip_dt = open_datatree(filepath)
assert_equal(original_dt, roundtrip_dt)

@requires_h5netcdf
def test_to_h5netcdf_selected_groups(self, tmpdir, simple_datatree):
    """Reopen an h5netcdf-written tree with ``groups=["set1"]`` and check the child nodes.

    The tree is written with the ``h5netcdf`` engine, reopened through
    ``open_datatree`` with only ``"set1"`` requested, and the set of keys
    whose values are ``DataTree`` instances must equal ``{"set1"}``.
    """
    # casting to str avoids a pathlib bug in xarray
    filepath = str(tmpdir / "test.nc")
    simple_datatree.to_netcdf(filepath, engine="h5netcdf")

    roundtrip_dt = open_datatree(filepath, groups=["set1"])

    # Collect only the keys that map to child tree nodes.
    child_nodes = set()
    for name in roundtrip_dt.keys():
        if isinstance(roundtrip_dt[name], DataTree):
            child_nodes.add(name)
    assert child_nodes == {"set1"}

@requires_zarr
def test_to_zarr(self, tmpdir, simple_datatree):
filepath = str(
Expand All @@ -61,6 +88,19 @@ def test_to_zarr(self, tmpdir, simple_datatree):
roundtrip_dt = open_datatree(filepath, engine="zarr")
assert_equal(original_dt, roundtrip_dt)

@requires_zarr
def test_to_zarr_selected_groups(self, tmpdir, simple_datatree):
    """Reopen a zarr-written tree with ``groups=["set1"]`` and check the child nodes.

    The tree is written with ``to_zarr``, reopened through ``open_datatree``
    with ``engine="zarr"`` and only ``"set1"`` requested, and the set of keys
    whose values are ``DataTree`` instances must equal ``{"set1"}``.
    """
    # casting to str avoids a pathlib bug in xarray
    filepath = str(tmpdir / "test.zarr")
    simple_datatree.to_zarr(filepath)

    roundtrip_dt = open_datatree(filepath, engine="zarr", groups=["set1"])

    # Collect only the keys that map to child tree nodes.
    child_nodes = set()
    for name in roundtrip_dt.keys():
        if isinstance(roundtrip_dt[name], DataTree):
            child_nodes.add(name)
    assert child_nodes == {"set1"}

@requires_zarr
def test_zarr_encoding(self, tmpdir, simple_datatree):
import zarr
Expand Down
3 changes: 3 additions & 0 deletions docs/source/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ v0.0.14 (unreleased)
New Features
~~~~~~~~~~~~

- Allow opening datatrees from file for just some selected groups (:pull:`338`)
By `Martin Raspaud <https://github.com/mraspaud>`_ and `Pouria Khalaj <https://github.com/pkhalaj>`_.

Breaking changes
~~~~~~~~~~~~~~~~

Expand Down
Loading