diff --git a/chunking_test.py b/chunking_test.py
new file mode 100644
index 000000000..43491f677
--- /dev/null
+++ b/chunking_test.py
@@ -0,0 +1,25 @@
+import json
+import os
+
+import zarr
+
+store = zarr.DirectoryStore("data/chunking_test.zarr")
+z = zarr.zeros((20, 3), chunks=(3, 2), shards=(2, 2), store=store, overwrite=True, compressor=None)
+z[:10, :] = 42
+z[15, 1] = 389
+z[19, 2] = 1
+z[0, 1] = -4.2
+
+print(store[".zarray"].decode())
+print("ONDISK", sorted(os.listdir("data/chunking_test.zarr")))
+assert json.loads(store[".zarray"].decode())["shards"] == [2, 2]
+
+print("STORE", sorted(store))
+print("CHUNKSTORE (SHARDED)", sorted(z.chunk_store))
+
+z_reopened = zarr.open("data/chunking_test.zarr")
+assert z_reopened.shards == (2, 2)
+assert z_reopened[15, 1] == 389
+assert z_reopened[19, 2] == 1
+assert z_reopened[0, 1] == -4.2
+assert z_reopened[0, 0] == 42
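The test above exercises the chunk-to-shard mapping implemented by the store that follows: with shards=(2, 2), chunk (i, j) lands in shard (i // 2, j // 2). A minimal standalone sketch of that mapping (the helper name chunk_to_shard is ours, not part of this patch):

# Standalone sketch of the chunk-to-shard mapping used by the store below;
# `shards` is the number of chunks per shard along each dimension.
def chunk_to_shard(chunk_coords, shards):
    shard = tuple(c // s for c, s in zip(chunk_coords, shards))
    within = tuple(c % s for c, s in zip(chunk_coords, shards))
    return shard, within

# chunk (5, 1) of the test array lives in shard (2, 0), at position (1, 1) inside it
assert chunk_to_shard((5, 1), (2, 2)) == ((2, 0), (1, 1))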
diff --git a/zarr/_storage/sharded_store.py b/zarr/_storage/sharded_store.py
new file mode 100644
index 000000000..2857e738c
--- /dev/null
+++ b/zarr/_storage/sharded_store.py
@@ -0,0 +1,156 @@
+from collections import defaultdict
+from functools import reduce
+import math
+from typing import Any, Dict, Iterable, Iterator, List, Sequence, Tuple, Union
+
+import numpy as np
+
+from zarr._storage.store import BaseStore, Store
+from zarr.storage import StoreLike, array_meta_key, attrs_key, group_meta_key
+
+
+def _cum_prod(x: Sequence[int]) -> Iterator[int]:
+    prod = 1
+    yield prod
+    for i in x[:-1]:
+        prod *= i
+        yield prod
+
+
+class MortonOrderShardedStore(Store):
+    """This class should not be used directly,
+    but is added to an Array as a wrapper when needed automatically."""
+
+    def __init__(
+        self,
+        store: StoreLike,
+        shards: Tuple[int, ...],
+        dimension_separator: str,
+        are_chunks_compressed: bool,
+        dtype: np.dtype,
+        fill_value: Any,
+        chunk_size: int,
+    ) -> None:
+        self._store: BaseStore = BaseStore._ensure_store(store)
+        self._shards = shards
+        self._num_chunks_per_shard = reduce(lambda x, y: x * y, shards, 1)
+        self._dimension_separator = dimension_separator
+
+        chunk_has_constant_size = not are_chunks_compressed and dtype != object
+        assert chunk_has_constant_size, "Currently only uncompressed, fixed-length data can be used."
+        self._chunk_has_constant_size = chunk_has_constant_size
+        if chunk_has_constant_size:
+            binary_fill_value = np.full(1, fill_value=fill_value or 0, dtype=dtype).tobytes()
+            self._fill_chunk = binary_fill_value * chunk_size
+        self._empty_meta = b"\x00" * math.ceil(self._num_chunks_per_shard / 8)
+
+        # unused when using Morton order
+        self._shard_strides = tuple(_cum_prod(shards))
+
+        # TODO: add warnings for ineffective reads/writes:
+        # * warn if partial reads are not available
+        # * optionally warn on unaligned writes if no partial writes are available
+
+    def __get_meta__(self, shard_content: Union[bytes, bytearray]) -> int:
+        return int.from_bytes(shard_content[-len(self._empty_meta):], byteorder="big")
+
+    def __set_meta__(self, shard_content: bytearray, meta: int) -> None:
+        shard_content[-len(self._empty_meta):] = meta.to_bytes(len(self._empty_meta), byteorder="big")
+
+    # The following two methods define the order of the chunks in a shard
+    # TODO use morton order
+    def __chunk_key_to_shard_key_and_index__(self, chunk_key: str) -> Tuple[str, int]:
+        # TODO: allow to be in a group (aka only use last parts for dimensions)
+        chunk_subkeys = map(int, chunk_key.split(self._dimension_separator))
+
+        shard_tuple, index_tuple = zip(*((subkey // shard_i, subkey % shard_i) for subkey, shard_i in zip(chunk_subkeys, self._shards)))
+        shard_key = self._dimension_separator.join(map(str, shard_tuple))
+        index = sum(i * j for i, j in zip(index_tuple, self._shard_strides))
+        return shard_key, index
+
+    def __shard_key_and_index_to_chunk_key__(self, shard_key_tuple: Tuple[int, ...], shard_index: int) -> str:
+        offset = tuple(shard_index % s2 // s1 for s1, s2 in zip(self._shard_strides, self._shard_strides[1:] + (self._num_chunks_per_shard,)))
+        original_key = (shard_key_i * shards_i + offset_i for shard_key_i, offset_i, shards_i in zip(shard_key_tuple, offset, self._shards))
+        return self._dimension_separator.join(map(str, original_key))
+
+    def __keys_to_shard_groups__(self, keys: Iterable[str]) -> Dict[str, List[Tuple[int, str]]]:
+        shard_indices_per_shard_key = defaultdict(list)
+        for chunk_key in keys:
+            shard_key, shard_index = self.__chunk_key_to_shard_key_and_index__(chunk_key)
+            shard_indices_per_shard_key[shard_key].append((shard_index, chunk_key))
+        return shard_indices_per_shard_key
+
+    def __get_chunk_slice__(self, shard_index: int) -> slice:
+        start = shard_index * len(self._fill_chunk)
+        return slice(start, start + len(self._fill_chunk))
+
+    def __getitem__(self, key: str) -> bytes:
+        return self.getitems([key])[key]
+
+    def getitems(self, keys: Iterable[str], **kwargs) -> Dict[str, bytes]:
+        result = {}
+        for shard_key, chunks_in_shard in self.__keys_to_shard_groups__(keys).items():
+            # TODO use partial reads if available
+            try:
+                full_shard_value = self._store[shard_key]
+            except KeyError:
+                # the shard does not exist yet, so none of its chunks do either
+                continue
+            # TODO also omit chunks that are missing per the shard's chunk mask
+            for shard_index, chunk_key in chunks_in_shard:
+                result[chunk_key] = full_shard_value[self.__get_chunk_slice__(shard_index)]
+        return result
+
+    def __setitem__(self, key: str, value: bytes) -> None:
+        self.setitems({key: value})
+
+    def setitems(self, values: Dict[str, bytes]) -> None:
+        for shard_key, chunks_in_shard in self.__keys_to_shard_groups__(values.keys()).items():
+            if len(chunks_in_shard) == self._num_chunks_per_shard:
+                # TODO shards at a non-dataset-size aligned surface are not captured here yet
+                full_shard_value = b"".join(
+                    values[chunk_key] for _, chunk_key in sorted(chunks_in_shard)
+                ) + b"\xff" * len(self._empty_meta)
+                self._store[shard_key] = full_shard_value
+            else:
+                # TODO use partial writes if available
+                try:
+                    full_shard_value = bytearray(self._store[shard_key])
+                except KeyError:
+                    full_shard_value = bytearray(self._fill_chunk * self._num_chunks_per_shard + self._empty_meta)
+                chunk_mask = self.__get_meta__(full_shard_value)
+                for shard_index, chunk_key in chunks_in_shard:
+                    chunk_mask |= 1 << shard_index
+                    full_shard_value[self.__get_chunk_slice__(shard_index)] = values[chunk_key]
+                self.__set_meta__(full_shard_value, chunk_mask)
+                self._store[shard_key] = full_shard_value
+
+    def __delitem__(self, key) -> None:
+        # TODO not implemented yet, also delitems
+        # Deleting the "last" chunk in a shard needs to remove the whole shard
+        raise NotImplementedError("Deletion is not yet implemented")
+
+    def __iter__(self) -> Iterator[str]:
+        for shard_key in self._store.__iter__():
+            if any(shard_key.endswith(i) for i in (array_meta_key, group_meta_key, attrs_key)):
+                # Special keys such as ".zarray" are passed on as-is
+                yield shard_key
+            else:
+                # For each shard key in the wrapped store, all corresponding chunks are yielded.
+                # TODO: allow to be in a group (aka only use last parts for dimensions)
+                shard_key_tuple = tuple(map(int, shard_key.split(self._dimension_separator)))
+                mask = self.__get_meta__(self._store[shard_key])
+                for i in range(self._num_chunks_per_shard):
+                    if mask == 0:
+                        break
+                    if mask & 1:
+                        yield self.__shard_key_and_index_to_chunk_key__(shard_key_tuple, i)
+                    mask >>= 1
+
+    def __len__(self) -> int:
+        return sum(1 for _ in self.keys())
+
+
+SHARDED_STORES = {
+    "morton_order": MortonOrderShardedStore,
+}
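Note on the shard layout implemented above: each shard value is the concatenation of its fixed-size chunks, followed by ceil(N / 8) mask bytes whose bit i records whether chunk i has been written. A minimal standalone sketch of that layout (the names are ours; the mask logic mirrors __get_meta__/__set_meta__ under the fixed-chunk-size assumption):

import math

# num_chunks fixed-size chunks concatenated, plus ceil(num_chunks / 8)
# mask bytes; bit i of the big-endian mask marks chunk i as written.
num_chunks, chunk_nbytes = 4, 6
meta_nbytes = math.ceil(num_chunks / 8)
shard = bytearray(num_chunks * chunk_nbytes + meta_nbytes)

def write_chunk(shard, index, payload):
    shard[index * chunk_nbytes:(index + 1) * chunk_nbytes] = payload
    mask = int.from_bytes(shard[-meta_nbytes:], "big") | (1 << index)
    shard[-meta_nbytes:] = mask.to_bytes(meta_nbytes, "big")

write_chunk(shard, 2, b"\x01" * chunk_nbytes)
assert int.from_bytes(shard[-meta_nbytes:], "big") == 0b100  # only chunk 2 present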
diff --git a/zarr/_storage/store.py b/zarr/_storage/store.py
index 6f5bf78e2..6714e729f 100644
--- a/zarr/_storage/store.py
+++ b/zarr/_storage/store.py
@@ -110,6 +110,7 @@ def _ensure_store(store: Any):
 
 
 class Store(BaseStore):
+    # TODO: document methods which allow optimizations, e.g. delitems, setitems, getitems, listdir, …
    """Abstract store class used by implementations following the Zarr v2 spec.
 
     Adds public `listdir`, `rename`, and `rmdir` methods on top of BaseStore.
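The TODO above concerns the optional batch methods that Array probes for with hasattr (see the _set_selection and _chunk_delitems hunks below). A minimal sketch of that dispatch pattern (class and function names here are illustrative, not part of this patch):

# Batch methods are optional; callers fall back to per-key access when absent.
class BatchedStore(dict):
    def setitems(self, values):
        print("batched write of", len(values), "keys")
        self.update(values)

def write_all(store, values):
    if hasattr(store, "setitems"):  # same probe as Array._set_selection below
        store.setitems(values)
    else:
        for key, value in values.items():  # per-key fallback
            store[key] = value

write_all(BatchedStore(), {"0.0": b"a", "0.1": b"b"})  # -> batched write of 2 keys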
diff --git a/zarr/core.py b/zarr/core.py
index d36613942..2c5505079 100644
--- a/zarr/core.py
+++ b/zarr/core.py
@@ -5,11 +5,13 @@
 import operator
 import re
 from functools import reduce
+from typing import Optional, Tuple
 
 import numpy as np
 from numcodecs.compat import ensure_bytes, ensure_ndarray
 
 from collections.abc import MutableMapping
 
+from zarr._storage.sharded_store import SHARDED_STORES
 from zarr.attrs import Attributes
 from zarr.codecs import AsType, get_codec
@@ -191,6 +193,9 @@ def __init__(
         self._oindex = OIndex(self)
         self._vindex = VIndex(self)
 
+        # the sharded store is only initialized when needed
+        self._cached_sharded_store = None
+
     def _load_metadata(self):
         """(Re)load metadata from store."""
         if self._synchronizer is None:
@@ -213,6 +218,8 @@ def _load_metadata_nosync(self):
         self._meta = meta
         self._shape = meta['shape']
         self._chunks = meta['chunks']
+        self._shards = meta.get('shards')
+        self._shard_format = meta.get('shard_format')
         self._dtype = meta['dtype']
         self._fill_value = meta['fill_value']
         self._order = meta['order']
@@ -262,9 +269,12 @@ def _flush_metadata_nosync(self):
             filters_config = [f.get_config() for f in self._filters]
         else:
             filters_config = None
+        # Possible (unrelated) bug:
+        # should the dimension_separator also be included in this dict?
         meta = dict(shape=self._shape, chunks=self._chunks, dtype=self._dtype,
                     compressor=compressor_config, fill_value=self._fill_value,
-                    order=self._order, filters=filters_config)
+                    order=self._order, filters=filters_config,
+                    shards=self._shards, shard_format=self._shard_format)
         mkey = self._key_prefix + array_meta_key
         self._store[mkey] = self._store._metadata_class.encode_array_metadata(meta)
 
@@ -309,9 +319,23 @@ def read_only(self, value):
     def chunk_store(self):
         """A MutableMapping providing the underlying storage for array chunks."""
         if self._chunk_store is None:
-            return self._store
+            chunk_store = self._store
+        else:
+            chunk_store = self._chunk_store
+        if self._shards is None:
+            return chunk_store
         else:
-            return self._chunk_store
+            if self._cached_sharded_store is None:
+                self._cached_sharded_store = SHARDED_STORES[self._shard_format](
+                    chunk_store,
+                    shards=self._shards,
+                    dimension_separator=self._dimension_separator,
+                    are_chunks_compressed=self._compressor is not None,
+                    dtype=self._dtype,
+                    fill_value=self._fill_value or 0,
+                    chunk_size=reduce(operator.mul, self._chunks, 1),
+                )
+            return self._cached_sharded_store
 
     @property
     def shape(self):
@@ -327,11 +351,17 @@ def shape(self, value):
         self.resize(value)
 
     @property
-    def chunks(self):
+    def chunks(self) -> Optional[Tuple[int, ...]]:
         """A tuple of integers describing the length of each dimension of a
-        chunk of the array."""
+        chunk of the array, or None."""
         return self._chunks
 
+    @property
+    def shards(self):
+        """A tuple of integers describing the number of chunks in each shard
+        of the array, or None if sharding is not used."""
+        return self._shards
+
     @property
     def dtype(self):
         """The NumPy data type."""
@@ -1703,7 +1733,7 @@ def _set_selection(self, indexer, value, fields=None):
         check_array_shape('value', value, sel_shape)
 
         # iterate over chunks in range
-        if not hasattr(self.store, "setitems") or self._synchronizer is not None \
+        if not hasattr(self.chunk_store, "setitems") or self._synchronizer is not None \
                 or any(map(lambda x: x == 0, self.shape)):
             # iterative approach
             for chunk_coords, chunk_selection, out_selection in indexer:
@@ -1899,7 +1929,7 @@ def _chunk_getitems(self, lchunk_coords, lchunk_selection, out, lout_selection,
             and hasattr(self._compressor, "decode_partial")
             and not fields
             and self.dtype != object
-            and hasattr(self.chunk_store, "getitems")
+            and hasattr(self.chunk_store, "getitems")  # TODO: this should rather check for read_block or similar
         ):
             partial_read_decode = True
             cdatas = {
@@ -1946,8 +1976,8 @@ def _chunk_setitems(self, lchunk_coords, lchunk_selection, values, fields=None):
         self.chunk_store.setitems(to_store)
 
     def _chunk_delitems(self, ckeys):
-        if hasattr(self.store, "delitems"):
-            self.store.delitems(ckeys)
+        if hasattr(self.chunk_store, "delitems"):
+            self.chunk_store.delitems(ckeys)
         else:  # pragma: no cover
             # exempting this branch from coverage as there are no extant stores
             # that will trigger this condition, but it's possible that they
@@ -2236,6 +2266,7 @@ def digest(self, hashname="sha1"):
 
         h = hashlib.new(hashname)
 
+        # TODO: operate on shards here if available:
         for i in itertools.product(*[range(s) for s in self.cdata_shape]):
             h.update(self.chunk_store.get(self._chunk_key(i), b""))
 
@@ -2362,6 +2393,7 @@ def _resize_nosync(self, *args):
             except KeyError:
                 # chunk not initialized
                 pass
+        # TODO: collect all chunks to delete and use _chunk_delitems
 
     def append(self, data, axis=0):
         """Append `data` to `axis`.
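With this patch applied, the wrapping should be observable through the two store views of an Array while staying transparent to indexing; a short usage sketch (key names follow the test at the top of this diff):

import zarr

# `z.store` holds one value per shard; `z.chunk_store` is the sharded
# wrapper presenting per-chunk keys on top of it.
z = zarr.zeros((20, 3), chunks=(3, 2), shards=(2, 2), compressor=None)
z[:10, :] = 42

print(sorted(z.store))        # shard keys plus '.zarray', e.g. ['.zarray', '0.0', '1.0']
print(sorted(z.chunk_store))  # per-chunk view, e.g. ['.zarray', '0.0', '0.1', '1.0', ...]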
diff --git a/zarr/creation.py b/zarr/creation.py
index 244a9b080..d31860164 100644
--- a/zarr/creation.py
+++ b/zarr/creation.py
@@ -1,3 +1,4 @@
+from typing import Tuple, Union
 from warnings import warn
 
 import numpy as np
@@ -19,7 +20,8 @@ def create(shape, chunks=True, dtype=None, compressor='default',
            fill_value=0, order='C', store=None, synchronizer=None,
            overwrite=False, path=None, chunk_store=None, filters=None,
            cache_metadata=True, cache_attrs=True, read_only=False,
-           object_codec=None, dimension_separator=None, write_empty_chunks=True, **kwargs):
+           object_codec=None, dimension_separator=None, write_empty_chunks=True,
+           shards: Union[int, Tuple[int, ...], None] = None, shard_format: str = "morton_order", **kwargs):
     """Create an array.
 
     Parameters
@@ -145,7 +147,7 @@ def create(shape, chunks=True, dtype=None, compressor='default',
     init_array(store, shape=shape, chunks=chunks, dtype=dtype, compressor=compressor,
                fill_value=fill_value, order=order, overwrite=overwrite, path=path,
                chunk_store=chunk_store, filters=filters, object_codec=object_codec,
-               dimension_separator=dimension_separator)
+               dimension_separator=dimension_separator, shards=shards, shard_format=shard_format)
 
     # instantiate array
     z = Array(store, path=path, chunk_store=chunk_store, synchronizer=synchronizer,
diff --git a/zarr/meta.py b/zarr/meta.py
index c292b09a1..d63be624d 100644
--- a/zarr/meta.py
+++ b/zarr/meta.py
@@ -51,6 +51,8 @@ def decode_array_metadata(cls, s: Union[MappingType, str]) -> MappingType[str, A
                 object_codec = None
 
             dimension_separator = meta.get("dimension_separator", None)
+            shards = meta.get("shards", None)
+            shard_format = meta.get("shard_format", None)
             fill_value = cls.decode_fill_value(meta['fill_value'], dtype, object_codec)
             meta = dict(
                 zarr_format=meta["zarr_format"],
@@ -64,6 +66,10 @@ def decode_array_metadata(cls, s: Union[MappingType, str]) -> MappingType[str, A
             )
             if dimension_separator:
                 meta['dimension_separator'] = dimension_separator
+            if shards:
+                meta['shards'] = tuple(shards)
+                assert shard_format is not None
+                meta['shard_format'] = shard_format
         except Exception as e:
             raise MetadataError("error decoding metadata") from e
         else:
@@ -77,6 +83,8 @@ def encode_array_metadata(cls, meta: MappingType[str, Any]) -> bytes:
             dtype, sdshape = dtype.subdtype
 
         dimension_separator = meta.get("dimension_separator")
+        shards = meta.get("shards")
+        shard_format = meta.get("shard_format")
         if dtype.hasobject:
             import numcodecs
             object_codec = numcodecs.get_codec(meta['filters'][0])
@@ -95,9 +103,10 @@ def encode_array_metadata(cls, meta: MappingType[str, Any]) -> bytes:
         )
         if dimension_separator:
             meta['dimension_separator'] = dimension_separator
-
-        if dimension_separator:
-            meta["dimension_separator"] = dimension_separator
+        if shards:
+            meta['shards'] = shards
+            assert shard_format is not None
+            meta['shard_format'] = shard_format
 
         return json_dumps(meta)
 
diff --git a/zarr/storage.py b/zarr/storage.py
index 7170eeaf2..19709cc11 100644
--- a/zarr/storage.py
+++ b/zarr/storage.py
@@ -54,7 +54,7 @@
 from zarr.util import (buffer_size, json_loads, nolock, normalize_chunks,
                        normalize_dimension_separator,
                        normalize_dtype, normalize_fill_value, normalize_order,
-                       normalize_shape, normalize_storage_path, retry_call)
+                       normalize_shape, normalize_shards, normalize_storage_path, retry_call)
 
 from zarr._storage.absstore import ABSStore  # noqa: F401
 from zarr._storage.store import (_listdir_from_keys,
@@ -236,6 +236,8 @@ def init_array(
     filters=None,
     object_codec=None,
     dimension_separator=None,
+    shards: Union[int, Tuple[int, ...], None] = None,
+    shard_format: Optional[str] = None,
 ):
     """Initialize an array store with the given configuration. Note that this is a low-level
     function and there should be no need to call this directly from user code.
@@ -353,7 +355,8 @@ def init_array(
                          order=order, overwrite=overwrite, path=path,
                          chunk_store=chunk_store, filters=filters,
                          object_codec=object_codec,
-                         dimension_separator=dimension_separator)
+                         dimension_separator=dimension_separator,
+                         shards=shards, shard_format=shard_format)
 
 
 def _init_array_metadata(
@@ -370,6 +373,8 @@ def _init_array_metadata(
     filters=None,
     object_codec=None,
     dimension_separator=None,
+    shards: Union[int, Tuple[int, ...], None] = None,
+    shard_format: Optional[str] = None,
 ):
 
     # guard conditions
@@ -388,6 +393,8 @@ def _init_array_metadata(
         shape = normalize_shape(shape) + dtype.shape
         dtype = dtype.base
     chunks = normalize_chunks(chunks, shape, dtype.itemsize)
+    shards = normalize_shards(shards, shape)
+    shard_format = shard_format or "morton_order"
     order = normalize_order(order)
     fill_value = normalize_fill_value(fill_value, dtype)
 
@@ -445,6 +452,9 @@ def _init_array_metadata(
                 compressor=compressor_config, fill_value=fill_value,
                 order=order, filters=filters_config,
                 dimension_separator=dimension_separator)
+    if shards is not None:
+        meta["shards"] = shards
+        meta["shard_format"] = shard_format
     key = _path_to_prefix(path) + array_meta_key
     if hasattr(store, '_metadata_class'):
         store[key] = store._metadata_class.encode_array_metadata(meta)  # type: ignore
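A round-trip sketch of the metadata handling above, assuming this patch is applied (MemoryStore stands in for the test's DirectoryStore; the "morton_order" default comes from _init_array_metadata):

import json

import zarr

# The shards tuple should survive an encode/decode round-trip through .zarray,
# mirroring the assertion in chunking_test.py at the top of this diff.
store = zarr.MemoryStore()
zarr.create((20, 3), chunks=(3, 2), shards=(2, 2), compressor=None, store=store)
meta = json.loads(store[".zarray"])
assert meta["shards"] == [2, 2]
assert meta["shard_format"] == "morton_order"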
diff --git a/zarr/util.py b/zarr/util.py
index d092ffe0d..220e49cbd 100644
--- a/zarr/util.py
+++ b/zarr/util.py
@@ -149,6 +149,38 @@ def normalize_chunks(
     return tuple(chunks)
 
 
+def normalize_shards(
+        shards: Union[int, Tuple[Optional[int], ...], None], shape: Tuple[int, ...],
+) -> Optional[Tuple[int, ...]]:
+    """Convenience function to normalize the `shards` argument for an array
+    with the given `shape`."""
+
+    # N.B., expect shape already normalized
+
+    if shards is None:
+        return None
+
+    # handle 1D convenience form
+    if isinstance(shards, numbers.Integral):
+        shards = tuple(int(shards) for _ in shape)
+
+    # handle bad dimensionality
+    if len(shards) > len(shape):
+        raise ValueError('too many dimensions in shards')
+
+    # handle underspecified shards
+    if len(shards) < len(shape):
+        # assume single shards across remaining dimensions
+        shards += (1, ) * (len(shape) - len(shards))
+
+    # handle None or -1 in shards
+    if -1 in shards or None in shards:
+        shards = tuple(s if c == -1 or c is None else int(c)
+                       for s, c in zip(shape, shards))
+
+    return tuple(shards)
+
+
 def normalize_dtype(dtype: Union[str, np.dtype], object_codec) -> Tuple[np.dtype, Any]:
 
     # convenience API for object arrays
@@ -560,6 +592,7 @@ def __init__(self, store_key, chunk_store):
         # is it fsstore or an actual fsspec map object
         assert hasattr(self.chunk_store, "map")
         self.map = self.chunk_store.map
+        # TODO maybe use partial_read here also
         self.fs = self.chunk_store.fs
         self.store_key = store_key
         self.buff = None
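For reference, normalize_shards is intended to mirror the semantics of normalize_chunks; the hypothetical calls below assume this patch is applied:

from zarr.util import normalize_shards

assert normalize_shards(2, (20, 3)) == (2, 2)            # int is broadcast to all dims
assert normalize_shards((2,), (20, 3)) == (2, 1)         # missing dims padded with 1
assert normalize_shards((None, 2), (20, 3)) == (20, 2)   # None/-1 -> full shape extent
assert normalize_shards(None, (20, 3)) is None           # sharding disabled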