Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial Sharding Prototype #1

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions chunking_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import json
import os

import zarr

# Location of the on-disk array exercised by this smoke test.
ZARR_PATH = "data/chunking_test.zarr"

# Create a small uncompressed, sharded array and write a few values into it.
store = zarr.DirectoryStore(ZARR_PATH)
z = zarr.zeros(
    (20, 3),
    chunks=(3, 2),
    shards=(2, 2),
    store=store,
    overwrite=True,
    compressor=None,
)
z[:10, :] = 42
z[15, 1] = 389
z[19, 2] = 1
z[0, 1] = -4.2

# Show the persisted array metadata and the raw directory listing, and make
# sure the shard shape was written to the metadata document.
array_meta_json = store[".zarray"].decode()
print(array_meta_json)
print("ONDISK", sorted(os.listdir(ZARR_PATH)))
assert json.loads(array_meta_json)["shards"] == [2, 2]

# Keys as seen by the plain store versus the array's chunk store.
print("STORE", sorted(store))
print("CHUNKSTORE (SHARDED)", sorted(z.chunk_store))

# Re-open the array from disk and check that the written values round-trip.
z_reopened = zarr.open(ZARR_PATH)
assert z_reopened.shards == (2, 2)
expected_values = {(15, 1): 389, (19, 2): 1, (0, 1): -4.2, (0, 0): 42}
for position, expected in expected_values.items():
    assert z_reopened[position] == expected
152 changes: 152 additions & 0 deletions zarr/_storage/sharded_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
from collections import defaultdict
from functools import reduce
import math
from typing import Any, Dict, Iterable, Iterator, List, Tuple, Union

import numpy as np

from zarr._storage.store import BaseStore, Store
from zarr.storage import StoreLike, array_meta_key, attrs_key, group_meta_key


def _cum_prod(x: Iterable[int]) -> Iterable[int]:
prod = 1
yield prod
for i in x[:-1]:
prod *= i
yield prod


class MortonOrderShardedStore(Store):
    """Store wrapper that packs several chunks into a single "shard" value.

    This class should not be used directly,
    but is added to an Array as a wrapper when needed automatically.

    Layout of one shard value in the wrapped store: ``num_chunks_per_shard``
    fixed-size chunk slots back to back, followed by a bitmask of
    ``ceil(num_chunks_per_shard / 8)`` bytes (big-endian integer) in which
    bit ``i`` is set when slot ``i`` holds a chunk that has been written.

    NOTE: despite the class name, slots are currently ordered by plain
    strides over the in-shard chunk coordinates; Morton order is still a TODO.
    """

    def __init__(
        self,
        store: StoreLike,
        shards: Tuple[int, ...],
        dimension_separator: str,
        are_chunks_compressed: bool,
        dtype: np.dtype,
        fill_value: Any,
        chunk_size: int,
    ) -> None:
        """Wrap ``store`` so that chunk keys are mapped onto shard keys.

        Parameters
        ----------
        store : the underlying store holding one value per shard.
        shards : number of chunks per shard along each dimension.
        dimension_separator : separator between the integer parts of a key.
        are_chunks_compressed : must be False; compressed chunks are rejected.
        dtype : array dtype; object dtype is rejected.
        fill_value : value used to pre-fill chunk slots (falsy values become 0).
        chunk_size : number of elements in one chunk.

        Raises
        ------
        AssertionError
            If the chunks are compressed or ``dtype`` is object.
        """
        self._store: BaseStore = BaseStore._ensure_store(store)
        self._shards = shards
        self._num_chunks_per_shard = reduce(lambda x, y: x*y, shards, 1)
        self._dimension_separator = dimension_separator

        chunk_has_constant_size = not are_chunks_compressed and not dtype == object
        assert chunk_has_constant_size, "Currently only uncompressed, fixed-length data can be used."
        self._chunk_has_constant_size = chunk_has_constant_size
        if chunk_has_constant_size:
            # Bytes of a single fill element, repeated once per chunk element.
            binary_fill_value = np.full(1, fill_value=fill_value or 0, dtype=dtype).tobytes()
            self._fill_chunk = binary_fill_value * chunk_size
        # All-zero presence bitmask, one bit per chunk slot, rounded up to
        # whole bytes (attribute name kept as-is, including its spelling).
        self._emtpy_meta = b"\x00" * math.ceil(self._num_chunks_per_shard / 8)

        # unused when using Morton order
        self._shard_strides = tuple(_cum_prod(shards))

        # TODO: add warnings for ineffective reads/writes:
        # * warn if partial reads are not available
        # * optionally warn on unaligned writes if no partial writes are available

    def __get_meta__(self, shard_content: Union[bytes, bytearray]) -> int:
        """Return the presence bitmask stored in the trailing bytes of a shard."""
        return int.from_bytes(shard_content[-len(self._emtpy_meta):], byteorder="big")

    def __set_meta__(self, shard_content: bytearray, meta: int) -> None:
        """Overwrite the trailing presence bitmask of ``shard_content`` in place."""
        shard_content[-len(self._emtpy_meta):] = meta.to_bytes(len(self._emtpy_meta), byteorder="big")

    # The following two methods define the order of the chunks in a shard
    # TODO use morton order
    def __chunk_key_to_shard_key_and_index__(self, chunk_key: str) -> Tuple[str, int]:
        """Map a chunk key to ``(shard key, slot index inside that shard)``."""
        # TODO: allow to be in a group (aka only use last parts for dimensions)
        chunk_subkeys = map(int, chunk_key.split(self._dimension_separator))

        # Per dimension: quotient selects the shard, remainder the position
        # of the chunk inside the shard.
        shard_tuple, index_tuple = zip(
            *((subkey // shard_i, subkey % shard_i)
              for subkey, shard_i in zip(chunk_subkeys, self._shards))
        )
        shard_key = self._dimension_separator.join(map(str, shard_tuple))
        index = sum(i * j for i, j in zip(index_tuple, self._shard_strides))
        return shard_key, index

    def __shard_key_and_index_to_chunk_key__(self, shard_key_tuple: Tuple[int, ...], shard_index: int) -> str:
        """Inverse of ``__chunk_key_to_shard_key_and_index__``; returns the chunk key."""
        # Decompose the flat slot index back into per-dimension offsets using
        # each stride and the next larger one as modulus.
        next_strides = self._shard_strides[1:] + (self._num_chunks_per_shard,)
        offset = tuple(
            shard_index % s2 // s1
            for s1, s2 in zip(self._shard_strides, next_strides)
        )
        original_key = (
            shard_key_i * shards_i + offset_i
            for shard_key_i, offset_i, shards_i in zip(shard_key_tuple, offset, self._shards)
        )
        return self._dimension_separator.join(map(str, original_key))

    def __keys_to_shard_groups__(self, keys: Iterable[str]) -> Dict[str, List[Tuple[int, str]]]:
        """Group chunk keys by shard key as lists of ``(slot index, chunk key)``."""
        shard_indices_per_shard_key = defaultdict(list)
        for chunk_key in keys:
            shard_key, shard_index = self.__chunk_key_to_shard_key_and_index__(chunk_key)
            shard_indices_per_shard_key[shard_key].append((shard_index, chunk_key))
        return shard_indices_per_shard_key

    def __get_chunk_slice__(self, shard_index: int) -> slice:
        """Return the byte range of slot ``shard_index`` within a shard value."""
        start = shard_index * len(self._fill_chunk)
        return slice(start, start + len(self._fill_chunk))

    def __getitem__(self, key: str) -> bytes:
        """Return the bytes of chunk ``key``; raise ``KeyError`` if it was never written."""
        return self.getitems([key])[key]

    def getitems(self, keys: Iterable[str], **kwargs) -> Dict[str, bytes]:
        """Return a mapping of the requested chunk keys to their bytes.

        Keys whose shard is missing, or whose presence bit is not set in the
        shard's bitmask, are omitted from the result. This keeps lookups
        consistent with ``__iter__``, which only reports chunks whose bit is set.
        Extra keyword arguments are accepted and ignored.
        """
        result = {}
        for shard_key, chunks_in_shard in self.__keys_to_shard_groups__(keys).items():
            # TODO use partial reads if available
            try:
                full_shard_value = self._store[shard_key]
            except KeyError:
                # The shard was never written, so none of its chunks exist.
                continue
            chunk_mask = self.__get_meta__(full_shard_value)
            for shard_index, chunk_key in chunks_in_shard:
                if chunk_mask & (1 << shard_index):
                    result[chunk_key] = full_shard_value[self.__get_chunk_slice__(shard_index)]
        return result

    def __setitem__(self, key: str, value: bytes) -> None:
        """Store a single chunk; see ``setitems``."""
        self.setitems({key: value})

    def setitems(self, values: Dict[str, bytes]) -> None:
        """Write several chunks, touching each affected shard only once."""
        for shard_key, chunks_in_shard in self.__keys_to_shard_groups__(values.keys()).items():
            if len(chunks_in_shard) == self._num_chunks_per_shard:
                # Every slot of the shard is supplied: build the value from
                # scratch (slots in index order) and mark all chunks present.
                # TODO shards at a non-dataset-size aligned surface are not captured here yet
                full_shard_value = b"".join(
                    values[chunk_key] for _, chunk_key in sorted(chunks_in_shard)
                ) + b"\xff" * len(self._emtpy_meta)
                self._store[shard_key] = full_shard_value
            else:
                # Read-modify-write of the whole shard.
                # TODO use partial writes if available
                try:
                    full_shard_value = bytearray(self._store[shard_key])
                except KeyError:
                    # New shard: every slot pre-filled, empty presence bitmask.
                    full_shard_value = bytearray(
                        self._fill_chunk * self._num_chunks_per_shard + self._emtpy_meta
                    )
                chunk_mask = self.__get_meta__(full_shard_value)
                for shard_index, chunk_key in chunks_in_shard:
                    chunk_mask |= 1 << shard_index
                    full_shard_value[self.__get_chunk_slice__(shard_index)] = values[chunk_key]
                self.__set_meta__(full_shard_value, chunk_mask)
                self._store[shard_key] = full_shard_value

    def __delitem__(self, key) -> None:
        """Deletion is not supported yet; always raises ``NotImplementedError``."""
        # TODO not implemented yet, also delitems
        # Deleting the "last" chunk in a shard needs to remove the whole shard
        raise NotImplementedError("Deletion is not yet implemented")

    def __iter__(self) -> Iterator[str]:
        """Iterate over metadata keys and over the keys of all written chunks."""
        for shard_key in self._store.__iter__():
            if any(shard_key.endswith(i) for i in (array_meta_key, group_meta_key, attrs_key)):
                # Special keys such as ".zarray" are passed on as-is
                yield shard_key
            else:
                # For each shard key in the wrapped store, all corresponding chunks are yielded.
                # TODO: allow to be in a group (aka only use last parts for dimensions)
                shard_key_tuple = tuple(map(int, shard_key.split(self._dimension_separator)))
                mask = self.__get_meta__(self._store[shard_key])
                for i in range(self._num_chunks_per_shard):
                    if mask == 0:
                        # No further presence bits are set.
                        break
                    if mask & 1:
                        yield self.__shard_key_and_index_to_chunk_key__(shard_key_tuple, i)
                    mask >>= 1

    def __len__(self) -> int:
        """Return the number of keys produced by iterating over this store."""
        return sum(1 for _ in self.keys())


# Registry of the available sharded-store implementations, keyed by the
# shard format name; each value is the store class implementing that format.
SHARDED_STORES = {
    "morton_order": MortonOrderShardedStore,
}
1 change: 1 addition & 0 deletions zarr/_storage/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def _ensure_store(store: Any):


class Store(BaseStore):
# TODO: document methods which allow optimizations, e.g. delitems, setitems, getitems, listdir, …
"""Abstract store class used by implementations following the Zarr v2 spec.

Adds public `listdir`, `rename`, and `rmdir` methods on top of BaseStore.
Expand Down
50 changes: 41 additions & 9 deletions zarr/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import operator
import re
from functools import reduce
from typing import Optional, Tuple

import numpy as np
from numcodecs.compat import ensure_bytes, ensure_ndarray

from collections.abc import MutableMapping
from zarr._storage.sharded_store import SHARDED_STORES

from zarr.attrs import Attributes
from zarr.codecs import AsType, get_codec
Expand Down Expand Up @@ -191,6 +193,9 @@ def __init__(
self._oindex = OIndex(self)
self._vindex = VIndex(self)

# the sharded store is only initialized when needed
self._cached_sharded_store = None

def _load_metadata(self):
"""(Re)load metadata from store."""
if self._synchronizer is None:
Expand All @@ -213,6 +218,8 @@ def _load_metadata_nosync(self):
self._meta = meta
self._shape = meta['shape']
self._chunks = meta['chunks']
self._shards = meta.get('shards')
self._shard_format = meta.get('shard_format')
self._dtype = meta['dtype']
self._fill_value = meta['fill_value']
self._order = meta['order']
Expand Down Expand Up @@ -262,9 +269,12 @@ def _flush_metadata_nosync(self):
filters_config = [f.get_config() for f in self._filters]
else:
filters_config = None
# Possible (unrelated) bug:
# should the dimension_separator also be included in this dict?
meta = dict(shape=self._shape, chunks=self._chunks, dtype=self._dtype,
compressor=compressor_config, fill_value=self._fill_value,
order=self._order, filters=filters_config)
order=self._order, filters=filters_config,
shards=self._shards, shard_format=self._shard_format)
mkey = self._key_prefix + array_meta_key
self._store[mkey] = self._store._metadata_class.encode_array_metadata(meta)

Expand Down Expand Up @@ -309,9 +319,23 @@ def read_only(self, value):
def chunk_store(self):
"""A MutableMapping providing the underlying storage for array chunks."""
if self._chunk_store is None:
return self._store
chunk_store = self._store
else:
chunk_store = self._chunk_store
if self._shards is None:
return chunk_store
else:
return self._chunk_store
if self._cached_sharded_store is None:
self._cached_sharded_store = SHARDED_STORES[self._shard_format](
chunk_store,
shards=self._shards,
dimension_separator=self._dimension_separator,
are_chunks_compressed=self._compressor is not None,
dtype=self._dtype,
fill_value=self._fill_value or 0,
chunk_size=reduce(operator.mul, self._chunks, 1),
)
return self._cached_sharded_store

@property
def shape(self):
Expand All @@ -327,11 +351,17 @@ def shape(self, value):
self.resize(value)

@property
def chunks(self):
def chunks(self) -> Optional[Tuple[int, ...]]:
"""A tuple of integers describing the length of each dimension of a
chunk of the array."""
chunk of the array, or None."""
return self._chunks

@property
def shards(self):
"""A tuple of integers describing the number of chunks in each shard
jstriebel marked this conversation as resolved.
Show resolved Hide resolved
of the array."""
return self._shards

@property
def dtype(self):
"""The NumPy data type."""
Expand Down Expand Up @@ -1703,7 +1733,7 @@ def _set_selection(self, indexer, value, fields=None):
check_array_shape('value', value, sel_shape)

# iterate over chunks in range
if not hasattr(self.store, "setitems") or self._synchronizer is not None \
if not hasattr(self.chunk_store, "setitems") or self._synchronizer is not None \
or any(map(lambda x: x == 0, self.shape)):
# iterative approach
for chunk_coords, chunk_selection, out_selection in indexer:
Expand Down Expand Up @@ -1899,7 +1929,7 @@ def _chunk_getitems(self, lchunk_coords, lchunk_selection, out, lout_selection,
and hasattr(self._compressor, "decode_partial")
and not fields
and self.dtype != object
and hasattr(self.chunk_store, "getitems")
and hasattr(self.chunk_store, "getitems") # TODO: this should rather check for read_block or similar
):
partial_read_decode = True
cdatas = {
Expand Down Expand Up @@ -1946,8 +1976,8 @@ def _chunk_setitems(self, lchunk_coords, lchunk_selection, values, fields=None):
self.chunk_store.setitems(to_store)

def _chunk_delitems(self, ckeys):
if hasattr(self.store, "delitems"):
self.store.delitems(ckeys)
if hasattr(self.chunk_store, "delitems"):
self.chunk_store.delitems(ckeys)
else: # pragma: no cover
# exempting this branch from coverage as there are no extant stores
# that will trigger this condition, but it's possible that they
Expand Down Expand Up @@ -2236,6 +2266,7 @@ def digest(self, hashname="sha1"):

h = hashlib.new(hashname)

# TODO: operate on shards here if available:
for i in itertools.product(*[range(s) for s in self.cdata_shape]):
h.update(self.chunk_store.get(self._chunk_key(i), b""))

Expand Down Expand Up @@ -2362,6 +2393,7 @@ def _resize_nosync(self, *args):
except KeyError:
# chunk not initialized
pass
# TODO: collect all chunks to delete and use _chunk_delitems

def append(self, data, axis=0):
"""Append `data` to `axis`.
Expand Down
6 changes: 4 additions & 2 deletions zarr/creation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Optional, Tuple, Union
from warnings import warn

import numpy as np
Expand All @@ -19,7 +20,8 @@ def create(shape, chunks=True, dtype=None, compressor='default',
fill_value=0, order='C', store=None, synchronizer=None,
overwrite=False, path=None, chunk_store=None, filters=None,
cache_metadata=True, cache_attrs=True, read_only=False,
object_codec=None, dimension_separator=None, write_empty_chunks=True, **kwargs):
object_codec=None, dimension_separator=None, write_empty_chunks=True,
shards: Union[int, Tuple[int, ...], None]=None, shard_format: str="morton_order", **kwargs):
"""Create an array.

Parameters
Expand Down Expand Up @@ -145,7 +147,7 @@ def create(shape, chunks=True, dtype=None, compressor='default',
init_array(store, shape=shape, chunks=chunks, dtype=dtype, compressor=compressor,
fill_value=fill_value, order=order, overwrite=overwrite, path=path,
chunk_store=chunk_store, filters=filters, object_codec=object_codec,
dimension_separator=dimension_separator)
dimension_separator=dimension_separator, shards=shards, shard_format=shard_format)

# instantiate array
z = Array(store, path=path, chunk_store=chunk_store, synchronizer=synchronizer,
Expand Down
15 changes: 12 additions & 3 deletions zarr/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def decode_array_metadata(cls, s: Union[MappingType, str]) -> MappingType[str, A
object_codec = None

dimension_separator = meta.get("dimension_separator", None)
shards = meta.get("shards", None)
shard_format = meta.get("shard_format", None)
fill_value = cls.decode_fill_value(meta['fill_value'], dtype, object_codec)
meta = dict(
zarr_format=meta["zarr_format"],
Expand All @@ -64,6 +66,10 @@ def decode_array_metadata(cls, s: Union[MappingType, str]) -> MappingType[str, A
)
if dimension_separator:
meta['dimension_separator'] = dimension_separator
if shards:
meta['shards'] = tuple(shards)
assert shard_format is not None
meta['shard_format'] = shard_format
except Exception as e:
raise MetadataError("error decoding metadata") from e
else:
Expand All @@ -77,6 +83,8 @@ def encode_array_metadata(cls, meta: MappingType[str, Any]) -> bytes:
dtype, sdshape = dtype.subdtype

dimension_separator = meta.get("dimension_separator")
shards = meta.get("shards")
shard_format = meta.get("shard_format")
if dtype.hasobject:
import numcodecs
object_codec = numcodecs.get_codec(meta['filters'][0])
Expand All @@ -95,9 +103,10 @@ def encode_array_metadata(cls, meta: MappingType[str, Any]) -> bytes:
)
if dimension_separator:
meta['dimension_separator'] = dimension_separator

if dimension_separator:
meta["dimension_separator"] = dimension_separator
if shards:
meta['shards'] = shards
assert shard_format is not None
meta['shard_format'] = shard_format

return json_dumps(meta)

Expand Down
Loading