Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sharding Prototype I: implementation as translating Store #876

Closed
wants to merge 11 commits into from
55 changes: 55 additions & 0 deletions chunking_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import json
import os

import zarr

store = zarr.DirectoryStore("data/chunking_test.zarr")
z = zarr.zeros((20, 3), chunks=(3, 2), shards=(2, 2), store=store, overwrite=True, compressor=None)
z[:10, :] = 42
z[15, 1] = 389
z[19, 2] = 1
z[0, 1] = -4.2

print(store[".zarray"].decode())
# {
# "chunks": [
# 3,
# 2
# ],
# "compressor": null,
# "dtype": "<f8",
# "fill_value": 0.0,
# "filters": null,
# "order": "C",
# "shape": [
# 20,
# 3
# ],
# "shard_format": "indexed",
# "shards": [
# 2,
# 2
# ],
# "zarr_format": 2
# }

assert json.loads(store[".zarray"].decode()) ["shards"] == [2, 2]

print("ONDISK", sorted(os.listdir("data/chunking_test.zarr")))
print("STORE", sorted(store))
print("CHUNKSTORE (SHARDED)", sorted(z.chunk_store))

# ONDISK ['.zarray', '0.0', '1.0', '2.0', '3.0']
# STORE ['.zarray', '0.0', '1.0', '2.0', '3.0']
# CHUNKSTORE (SHARDED) ['.zarray', '0.0', '0.1', '1.0', '1.1', '2.0', '2.1', '3.0', '3.1', '5.0', '6.1']

index_bytes = z.store["0.0"][-2*2*16:]
print("INDEX 0.0", [int.from_bytes(index_bytes[i:i+8], byteorder="little") for i in range(0, len(index_bytes), 8)])
# INDEX 0.0 [0, 48, 48, 48, 96, 48, 144, 48]

z_reopened = zarr.open("data/chunking_test.zarr")
assert z_reopened.shards == (2, 2)
assert z_reopened[15, 1] == 389
assert z_reopened[19, 2] == 1
assert z_reopened[0, 1] == -4.2
assert z_reopened[0, 0] == 42
184 changes: 184 additions & 0 deletions zarr/_storage/sharded_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
from collections import defaultdict
from functools import reduce
from itertools import product
from typing import Dict, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union

import numpy as np

from zarr._storage.store import BaseStore, Store
from zarr.storage import StoreLike, array_meta_key, attrs_key, group_meta_key


MAX_UINT_64 = 2 ** 64 - 1


class _ShardIndex(NamedTuple):
store: "IndexedShardedStore"
offsets_and_lengths: np.ndarray # dtype uint64, shape (shards_0, _shards_1, ..., 2)

def __localize_chunk__(self, chunk: Tuple[int, ...]) -> Tuple[int, ...]:
return tuple(chunk_i % shard_i for chunk_i, shard_i in zip(chunk, self.store._shards))

def get_chunk_slice(self, chunk: Tuple[int, ...]) -> Optional[slice]:
localized_chunk = self.__localize_chunk__(chunk)
chunk_start, chunk_len = self.offsets_and_lengths[localized_chunk]
if (chunk_start, chunk_len) == (MAX_UINT_64, MAX_UINT_64):
return None
else:
return slice(chunk_start, chunk_start + chunk_len)

def set_chunk_slice(self, chunk: Tuple[int, ...], chunk_slice: Optional[slice]) -> None:
localized_chunk = self.__localize_chunk__(chunk)
if chunk_slice is None:
self.offsets_and_lengths[localized_chunk] = (MAX_UINT_64, MAX_UINT_64)
else:
self.offsets_and_lengths[localized_chunk] = (
chunk_slice.start,
chunk_slice.stop - chunk_slice.start
)

def to_bytes(self) -> bytes:
return self.offsets_and_lengths.tobytes(order='C')

@classmethod
def from_bytes(
cls, buffer: Union[bytes, bytearray], store: "IndexedShardedStore"
) -> "_ShardIndex":
return cls(
store=store,
offsets_and_lengths=np.frombuffer(
bytearray(buffer), dtype="<u8"
).reshape(*store._shards, 2, order="C")
)

@classmethod
def create_empty(cls, store: "IndexedShardedStore"):
# reserving 2*64bit per chunk for offset and length:
return cls.from_bytes(
MAX_UINT_64.to_bytes(8, byteorder="little") * (2 * store._num_chunks_per_shard),
store=store
)


class IndexedShardedStore(Store):
"""This class should not be used directly,
but is added to an Array as a wrapper when needed automatically."""

def __init__(
self,
store: StoreLike,
shards: Tuple[int, ...],
dimension_separator: str,
) -> None:
self._store: BaseStore = BaseStore._ensure_store(store)
self._shards = shards
self._num_chunks_per_shard = reduce(lambda x, y: x*y, shards, 1)
self._dimension_separator = dimension_separator

# TODO: add warnings for ineffective reads/writes:
# * warn if partial reads are not available
# * optionally warn on unaligned writes if no partial writes are available

def __keys_to_shard_groups__(
self, keys: Iterable[str]
) -> Dict[str, List[Tuple[str, Tuple[int, ...]]]]:
shard_indices_per_shard_key = defaultdict(list)
for chunk_key in keys:
# TODO: allow to be in a group (aka only use last parts for dimensions)
chunk_subkeys = tuple(map(int, chunk_key.split(self._dimension_separator)))
shard_key_tuple = (
subkey // shard_i for subkey, shard_i in zip(chunk_subkeys, self._shards)
)
shard_key = self._dimension_separator.join(map(str, shard_key_tuple))
shard_indices_per_shard_key[shard_key].append((chunk_key, chunk_subkeys))
return shard_indices_per_shard_key

def __get_index__(self, buffer: Union[bytes, bytearray]) -> _ShardIndex:
# At the end of each shard 2*64bit per chunk for offset and length define the index:
return _ShardIndex.from_bytes(buffer[-16 * self._num_chunks_per_shard:], self)

def __get_chunks_in_shard(self, shard_key: str) -> Iterator[Tuple[int, ...]]:
# TODO: allow to be in a group (aka only use last parts for dimensions)
shard_key_tuple = tuple(map(int, shard_key.split(self._dimension_separator)))
for chunk_offset in product(*(range(i) for i in self._shards)):
yield tuple(
shard_key_i * shards_i + offset_i
for shard_key_i, offset_i, shards_i
in zip(shard_key_tuple, chunk_offset, self._shards)
)

def __getitem__(self, key: str) -> bytes:
return self.getitems([key])[key]

def getitems(self, keys: Iterable[str], **kwargs) -> Dict[str, bytes]:
result = {}
for shard_key, chunks_in_shard in self.__keys_to_shard_groups__(keys).items():
# TODO use partial read if available
full_shard_value = self._store[shard_key]
index = self.__get_index__(full_shard_value)
for chunk_key, chunk_subkeys in chunks_in_shard:
chunk_slice = index.get_chunk_slice(chunk_subkeys)
if chunk_slice is not None:
result[chunk_key] = full_shard_value[chunk_slice]
return result

def __setitem__(self, key: str, value: bytes) -> None:
self.setitems({key: value})

def setitems(self, values: Dict[str, bytes]) -> None:
for shard_key, chunks_in_shard in self.__keys_to_shard_groups__(values.keys()).items():
all_chunks = set(self.__get_chunks_in_shard(shard_key))
chunks_to_set = set(chunk_subkeys for _chunk_key, chunk_subkeys in chunks_in_shard)
chunks_to_read = all_chunks - chunks_to_set
new_content = {
chunk_subkeys: values[chunk_key] for chunk_key, chunk_subkeys in chunks_in_shard
}
try:
# TODO use partial read if available
full_shard_value = self._store[shard_key]
except KeyError:
index = _ShardIndex.create_empty(self)
else:
index = self.__get_index__(full_shard_value)
for chunk_to_read in chunks_to_read:
chunk_slice = index.get_chunk_slice(chunk_to_read)
if chunk_slice is not None:
new_content[chunk_to_read] = full_shard_value[chunk_slice]

# TODO use partial write if available and possible (e.g. at the end)
shard_content = b""
# TODO: order the chunks in the shard:
for chunk_subkeys, chunk_content in new_content.items():
chunk_slice = slice(len(shard_content), len(shard_content) + len(chunk_content))
index.set_chunk_slice(chunk_subkeys, chunk_slice)
shard_content += chunk_content
# Appending the index at the end of the shard:
shard_content += index.to_bytes()
self._store[shard_key] = shard_content

def __delitem__(self, key) -> None:
# TODO not implemented yet, also delitems
# Deleting the "last" chunk in a shard needs to remove the whole shard
raise NotImplementedError("Deletion is not yet implemented")

def __iter__(self) -> Iterator[str]:
for shard_key in self._store.__iter__():
if any(shard_key.endswith(i) for i in (array_meta_key, group_meta_key, attrs_key)):
# Special keys such as ".zarray" are passed on as-is
yield shard_key
else:
# For each shard key in the wrapped store, all corresponding chunks are yielded.
# TODO: use partial read if available:
index = self.__get_index__(self._store[shard_key])
for chunk_tuple in self.__get_chunks_in_shard(shard_key):
if index.get_chunk_slice(chunk_tuple) is not None:
# TODO: if shard is in a group, prepend group-prefix to chunk
yield self._dimension_separator.join(map(str, chunk_tuple))

def __len__(self) -> int:
return sum(1 for _ in self.keys())


SHARDED_STORES = {
"indexed": IndexedShardedStore,
}
2 changes: 2 additions & 0 deletions zarr/_storage/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ def _ensure_store(store: Any):


class Store(BaseStore):
# TODO: document methods which allow optimizations,
# e.g. delitems, setitems, getitems, listdir, …
"""Abstract store class used by implementations following the Zarr v2 spec.

Adds public `listdir`, `rename`, and `rmdir` methods on top of BaseStore.
Expand Down
45 changes: 37 additions & 8 deletions zarr/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import operator
import re
from functools import reduce
from typing import Optional, Tuple

import numpy as np
from numcodecs.compat import ensure_bytes, ensure_ndarray

from collections.abc import MutableMapping
from zarr._storage.sharded_store import SHARDED_STORES

from zarr.attrs import Attributes
from zarr.codecs import AsType, get_codec
Expand Down Expand Up @@ -191,6 +193,9 @@ def __init__(
self._oindex = OIndex(self)
self._vindex = VIndex(self)

# the sharded store is only initialized when needed
self._cached_sharded_store = None

def _load_metadata(self):
"""(Re)load metadata from store."""
if self._synchronizer is None:
Expand All @@ -213,6 +218,8 @@ def _load_metadata_nosync(self):
self._meta = meta
self._shape = meta['shape']
self._chunks = meta['chunks']
self._shards = meta.get('shards')
self._shard_format = meta.get('shard_format')
self._dtype = meta['dtype']
self._fill_value = meta['fill_value']
self._order = meta['order']
Expand Down Expand Up @@ -262,9 +269,12 @@ def _flush_metadata_nosync(self):
filters_config = [f.get_config() for f in self._filters]
else:
filters_config = None
# Possible (unrelated) bug:
# should the dimension_separator also be included in this dict?
meta = dict(shape=self._shape, chunks=self._chunks, dtype=self._dtype,
compressor=compressor_config, fill_value=self._fill_value,
order=self._order, filters=filters_config)
order=self._order, filters=filters_config,
shards=self._shards, shard_format=self._shard_format)
mkey = self._key_prefix + array_meta_key
self._store[mkey] = self._store._metadata_class.encode_array_metadata(meta)

Expand Down Expand Up @@ -309,9 +319,19 @@ def read_only(self, value):
def chunk_store(self):
"""A MutableMapping providing the underlying storage for array chunks."""
if self._chunk_store is None:
return self._store
chunk_store = self._store
else:
chunk_store = self._chunk_store
if self._shards is None:
return chunk_store
else:
return self._chunk_store
if self._cached_sharded_store is None:
self._cached_sharded_store = SHARDED_STORES[self._shard_format](
chunk_store,
shards=self._shards,
dimension_separator=self._dimension_separator,
)
return self._cached_sharded_store

@property
def shape(self):
Expand All @@ -327,11 +347,17 @@ def shape(self, value):
self.resize(value)

@property
def chunks(self):
def chunks(self) -> Optional[Tuple[int, ...]]:
"""A tuple of integers describing the length of each dimension of a
chunk of the array."""
chunk of the array, or None."""
return self._chunks

@property
def shards(self):
"""A tuple of integers describing the number of chunks in each shard
of the array."""
return self._shards

@property
def dtype(self):
"""The NumPy data type."""
Expand Down Expand Up @@ -1708,7 +1734,7 @@ def _set_selection(self, indexer, value, fields=None):
check_array_shape('value', value, sel_shape)

# iterate over chunks in range
if not hasattr(self.store, "setitems") or self._synchronizer is not None \
if not hasattr(self.chunk_store, "setitems") or self._synchronizer is not None \
or any(map(lambda x: x == 0, self.shape)):
# iterative approach
for chunk_coords, chunk_selection, out_selection in indexer:
Expand Down Expand Up @@ -1904,6 +1930,7 @@ def _chunk_getitems(self, lchunk_coords, lchunk_selection, out, lout_selection,
and hasattr(self._compressor, "decode_partial")
and not fields
and self.dtype != object
# TODO: this should rather check for read_block or similar
and hasattr(self.chunk_store, "getitems")
):
partial_read_decode = True
Expand Down Expand Up @@ -1951,8 +1978,8 @@ def _chunk_setitems(self, lchunk_coords, lchunk_selection, values, fields=None):
self.chunk_store.setitems(to_store)

def _chunk_delitems(self, ckeys):
if hasattr(self.store, "delitems"):
self.store.delitems(ckeys)
if hasattr(self.chunk_store, "delitems"):
self.chunk_store.delitems(ckeys)
else: # pragma: no cover
# exempting this branch from coverage as there are no extant stores
# that will trigger this condition, but it's possible that they
Expand Down Expand Up @@ -2239,6 +2266,7 @@ def digest(self, hashname="sha1"):

h = hashlib.new(hashname)

# TODO: operate on shards here if available:
for i in itertools.product(*[range(s) for s in self.cdata_shape]):
h.update(self.chunk_store.get(self._chunk_key(i), b""))

Expand Down Expand Up @@ -2365,6 +2393,7 @@ def _resize_nosync(self, *args):
except KeyError:
# chunk not initialized
pass
# TODO: collect all chunks do delete and use _chunk_delitems

def append(self, data, axis=0):
"""Append `data` to `axis`.
Expand Down
Loading