Skip to content

Commit

Permalink
(wip): support for new zarr version
Browse files Browse the repository at this point in the history
  • Loading branch information
ilan-gold committed Oct 21, 2024
1 parent 3260222 commit 0840150
Show file tree
Hide file tree
Showing 12 changed files with 96 additions and 46 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ test = [
"loompy>=3.0.5",
"pytest>=8.2",
"pytest-cov>=2.10",
"zarr<3.0.0a0",
"zarr>=3.0.0b0",
"matplotlib",
"scikit-learn",
"openpyxl",
Expand Down
2 changes: 1 addition & 1 deletion src/anndata/_core/sparse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ def append(self, sparse_matrix: ss.csr_matrix | ss.csc_matrix | SpArray) -> None
f"Matrices must have same format. Currently are "
f"{self.format!r} and {sparse_matrix.format!r}"
)
indptr_offset = len(self.group["indices"])
[indptr_offset] = self.group["indices"].shape
if self.group["indptr"].dtype == np.int32:
new_nnz = indptr_offset + len(sparse_matrix.indices)
if new_nnz >= np.iinfo(np.int32).max:
Expand Down
92 changes: 65 additions & 27 deletions src/anndata/_io/specs/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def read_basic(
# Backwards compat sparse arrays
if "h5sparse_format" in elem.attrs:
return sparse_dataset(elem).to_memory()
return {k: _reader.read_elem(v) for k, v in elem.items()}
return {k: _reader.read_elem(v) for k, v in dict(elem).items()}
elif isinstance(elem, h5py.Dataset):
return h5ad.read_dataset(elem) # TODO: Handle legacy

Expand All @@ -161,7 +161,7 @@ def read_basic_zarr(
# Backwards compat sparse arrays
if "h5sparse_format" in elem.attrs:
return sparse_dataset(elem).to_memory()
return {k: _reader.read_elem(v) for k, v in elem.items()}
return {k: _reader.read_elem(v) for k, v in dict(elem).items()}

Check warning on line 164 in src/anndata/_io/specs/methods.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_io/specs/methods.py#L164

Added line #L164 was not covered by tests
elif isinstance(elem, ZarrArray):
return zarr.read_dataset(elem) # TODO: Handle legacy

Expand Down Expand Up @@ -334,7 +334,7 @@ def write_raw(
@_REGISTRY.register_read(H5Group, IOSpec("dict", "0.1.0"))
@_REGISTRY.register_read(ZarrGroup, IOSpec("dict", "0.1.0"))
def read_mapping(elem: GroupStorageType, *, _reader: Reader) -> dict[str, AxisStorable]:
    """Read a stored group encoded as a ``dict``, recursively reading each child.

    Parameters
    ----------
    elem
        The on-disk group (h5py or zarr) whose members are read.
    _reader
        Reader used to dispatch each child element to its registered read method.

    Returns
    -------
    A plain dict mapping each member key to its decoded value.
    """
    # NOTE(review): the scraped diff showed both the pre-change (`elem.items()`)
    # and post-change lines; this reconstructs the post-commit version.
    # `dict(elem)` materializes the group's members first, which works with both
    # the zarr v2 and v3 group APIs — presumably the reason for the change; confirm
    # against zarr-python 3.x Group iteration semantics.
    return {k: _reader.read_elem(v) for k, v in dict(elem).items()}


@_REGISTRY.register_write(H5Group, dict, IOSpec("dict", "0.1.0"))
Expand Down Expand Up @@ -390,7 +390,7 @@ def write_basic(
dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
):
"""Write methods which underlying library handles natively."""
f.create_dataset(k, data=elem, **dataset_kwargs)
f.create_dataset(k, data=elem, shape=elem.shape, dtype=elem.dtype, **dataset_kwargs)


_REGISTRY.register_write(H5Group, CupyArray, IOSpec("array", "0.2.0"))(
Expand All @@ -411,8 +411,12 @@ def write_basic_dask_zarr(
dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
):
import dask.array as da
import zarr

g = f.require_dataset(k, shape=elem.shape, dtype=elem.dtype, **dataset_kwargs)
if Version(zarr.__version__) >= Version("3.0.0b0"):
g = f.require_array(k, shape=elem.shape, dtype=elem.dtype, **dataset_kwargs)
else:
g = f.require_dataset(k, shape=elem.shape, dtype=elem.dtype, **dataset_kwargs)

Check warning on line 419 in src/anndata/_io/specs/methods.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_io/specs/methods.py#L419

Added line #L419 was not covered by tests
da.store(elem, g, lock=GLOBAL_LOCK)


Expand Down Expand Up @@ -505,23 +509,37 @@ def write_vlen_string_array_zarr(
_writer: Writer,
dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
):
import numcodecs

if Version(numcodecs.__version__) < Version("0.13"):
msg = "Old numcodecs version detected. Please update for improved performance and stability."
warnings.warn(msg)
# Workaround for https://github.com/zarr-developers/numcodecs/issues/514
if hasattr(elem, "flags") and not elem.flags.writeable:
elem = elem.copy()

f.create_dataset(
k,
shape=elem.shape,
dtype=object,
object_codec=numcodecs.VLenUTF8(),
**dataset_kwargs,
)
f[k][:] = elem
import zarr

if Version(zarr.__version__) < Version("3.0.0b0"):
import numcodecs

Check warning on line 515 in src/anndata/_io/specs/methods.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_io/specs/methods.py#L515

Added line #L515 was not covered by tests

if Version(numcodecs.__version__) < Version("0.13"):
msg = "Old numcodecs version detected. Please update for improved performance and stability."
warnings.warn(msg)

Check warning on line 519 in src/anndata/_io/specs/methods.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_io/specs/methods.py#L517-L519

Added lines #L517 - L519 were not covered by tests
# Workaround for https://github.com/zarr-developers/numcodecs/issues/514
if hasattr(elem, "flags") and not elem.flags.writeable:
elem = elem.copy()

Check warning on line 522 in src/anndata/_io/specs/methods.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_io/specs/methods.py#L521-L522

Added lines #L521 - L522 were not covered by tests

f.create_dataset(

Check warning on line 524 in src/anndata/_io/specs/methods.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_io/specs/methods.py#L524

Added line #L524 was not covered by tests
k,
shape=elem.shape,
dtype=object,
object_codec=numcodecs.VLenUTF8(),
**dataset_kwargs,
)
f[k][:] = elem

Check warning on line 531 in src/anndata/_io/specs/methods.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_io/specs/methods.py#L531

Added line #L531 was not covered by tests
else:
from zarr.codecs import VLenUTF8Codec

f.create_array(
k,
shape=elem.shape,
dtype=str,
codecs=[VLenUTF8Codec()],
**dataset_kwargs,
)
f[k][:] = elem


###############
Expand Down Expand Up @@ -576,7 +594,9 @@ def write_recarray_zarr(
):
from anndata.compat import _to_fixed_length_strings

f.create_dataset(k, data=_to_fixed_length_strings(elem), **dataset_kwargs)
f.create_dataset(
k, data=_to_fixed_length_strings(elem), shape=elem.shape, **dataset_kwargs
)


#################
Expand All @@ -602,9 +622,27 @@ def write_sparse_compressed(
if isinstance(f, H5Group) and "maxshape" not in dataset_kwargs:
dataset_kwargs = dict(maxshape=(None,), **dataset_kwargs)

g.create_dataset("data", data=value.data, **dataset_kwargs)
g.create_dataset("indices", data=value.indices, **dataset_kwargs)
g.create_dataset("indptr", data=value.indptr, dtype=indptr_dtype, **dataset_kwargs)
g.create_dataset(
"data",
data=value.data,
shape=value.data.shape,
dtype=value.data.dtype,
**dataset_kwargs,
)
g.create_dataset(
"indices",
data=value.indices,
shape=value.indices.shape,
dtype=value.indices.dtype,
**dataset_kwargs,
)
g.create_dataset(
"indptr",
data=value.indptr,
shape=value.indptr.shape,
dtype=indptr_dtype,
**dataset_kwargs,
)


write_csr = partial(write_sparse_compressed, fmt="csr")
Expand Down Expand Up @@ -1117,7 +1155,7 @@ def write_scalar(
_writer: Writer,
dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
):
return f.create_dataset(key, data=np.array(value), **dataset_kwargs)
return f.create_dataset(key, data=np.array(value), shape=(), **dataset_kwargs)


def write_hdf5_scalar(
Expand Down
17 changes: 13 additions & 4 deletions src/anndata/_io/specs/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
from types import MappingProxyType
from typing import TYPE_CHECKING, Generic, TypeVar

from packaging.version import Version

from anndata._io.utils import report_read_key_on_error, report_write_key_on_error
from anndata._types import Read, ReadDask, _ReadDaskInternal, _ReadInternal
from anndata.compat import DaskArray, _read_attr
from anndata.compat import DaskArray, ZarrGroup, _read_attr

if TYPE_CHECKING:
from collections.abc import Callable, Generator, Iterable
Expand Down Expand Up @@ -341,11 +343,18 @@ def write_elem(
return lambda *_, **__: None

# Normalize k to absolute path
if not PurePosixPath(k).is_absolute():
k = str(PurePosixPath(store.name) / k)
if isinstance(store, ZarrGroup):
import zarr

if Version(zarr.__version__) < Version("3.0.0b0"):
if not PurePosixPath(k).is_absolute():
k = str(PurePosixPath(store.name) / k)

Check warning on line 351 in src/anndata/_io/specs/registry.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_io/specs/registry.py#L350-L351

Added lines #L350 - L351 were not covered by tests

if k == "/":
store.clear()
if isinstance(store, ZarrGroup):
store.store.clear()
else:
store.clear()
elif k in store:
del store[k]

Expand Down
4 changes: 2 additions & 2 deletions src/anndata/_io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def report_read_key_on_error(func):
>>> @report_read_key_on_error
... def read_arr(group):
... raise NotImplementedError()
>>> z = zarr.open("tmp.zarr")
>>> z = zarr.open("tmp.zarr", mode="w")
>>> z["X"] = [1, 2, 3]
>>> read_arr(z["X"]) # doctest: +SKIP
"""
Expand Down Expand Up @@ -228,7 +228,7 @@ def report_write_key_on_error(func):
>>> @report_write_key_on_error
... def write_arr(group, key, val):
... raise NotImplementedError()
>>> z = zarr.open("tmp.zarr")
>>> z = zarr.open("tmp.zarr", mode="w")
>>> X = [1, 2, 3]
>>> write_arr(z, "X", X) # doctest: +SKIP
"""
Expand Down
4 changes: 2 additions & 2 deletions src/anndata/_io/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def write_zarr(
f.attrs.setdefault("encoding-version", "0.1.0")

def callback(func, s, k, elem, dataset_kwargs, iospec):
if chunks is not None and not isinstance(elem, sparse.spmatrix) and k == "/X":
if chunks is not None and not isinstance(elem, sparse.spmatrix) and k == "X":
dataset_kwargs = dict(dataset_kwargs, chunks=chunks)
func(s, k, elem, dataset_kwargs=dataset_kwargs)

Expand Down Expand Up @@ -73,7 +73,7 @@ def callback(func, elem_name: str, elem, iospec):
return AnnData(
**{
k: read_dispatched(v, callback)
for k, v in elem.items()
for k, v in dict(elem).items()
if not k.startswith("raw.")
}
)
Expand Down
4 changes: 2 additions & 2 deletions src/anndata/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ def __exit__(self, *_exc_info) -> None:
#############################

if find_spec("zarr") or TYPE_CHECKING:
from zarr.core import Array as ZarrArray
from zarr.hierarchy import Group as ZarrGroup
from zarr import Array as ZarrArray
from zarr import Group as ZarrGroup
else:

class ZarrArray:
Expand Down
2 changes: 1 addition & 1 deletion src/anndata/experimental/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def callback(func, elem_name: str, elem, iospec):
elif iospec.encoding_type == "array":
return elem
elif iospec.encoding_type == "dict":
return {k: read_as_backed(v) for k, v in elem.items()}
return {k: read_as_backed(v) for k, v in dict(elem).items()}
else:
return func(elem)

Expand Down
6 changes: 3 additions & 3 deletions src/anndata/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,17 +1040,17 @@ def shares_memory_sparse(x, y):
]

if find_spec("zarr") or TYPE_CHECKING:
from zarr import DirectoryStore
from zarr.storage import LocalStore
else:

class DirectoryStore:
class LocalStore:

Check warning on line 1046 in src/anndata/tests/helpers.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/tests/helpers.py#L1046

Added line #L1046 was not covered by tests
def __init__(self, *_args, **_kwargs) -> None:
cls_name = type(self).__name__
msg = f"zarr must be imported to create a {cls_name} instance."
raise ImportError(msg)


class AccessTrackingStore(DirectoryStore):
class AccessTrackingStore(LocalStore):
_access_count: Counter[str]
_accessed_keys: dict[str, list[str]]

Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def tokenize_anndata(adata: ad.AnnData):
res.extend([tokenize(adata.obs), tokenize(adata.var)])
for attr in ["obsm", "varm", "obsp", "varp", "layers"]:
elem = getattr(adata, attr)
res.append(tokenize(list(elem.items())))
res.append(tokenize(list(dict(elem).items())))
res.append(joblib.hash(adata.uns))
if adata.raw is not None:
res.append(tokenize(adata.raw.to_adata()))
Expand Down
5 changes: 4 additions & 1 deletion tests/test_backed_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,10 @@ def read_zarr_backed(path):
def callback(func, elem_name, elem, iospec):
if iospec.encoding_type == "anndata" or elem_name.endswith("/"):
return AnnData(
**{k: read_dispatched(v, callback) for k, v in elem.items()}
**{
k: read_dispatched(v, callback)
for k, v in dict(elem).items()
}
)
if iospec.encoding_type in {"csc_matrix", "csr_matrix"}:
return sparse_dataset(elem)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_io_elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ def test_read_zarr_from_group(tmp_path, consolidated):
write_elem(z, "table/table", adata)

if consolidated:
zarr.convenience.consolidate_metadata(z.store)
zarr.consolidate_metadata(z.store)

if consolidated:
read_func = zarr.open_consolidated
Expand Down

0 comments on commit 0840150

Please sign in to comment.