Sharding Prototype #40

Closed · wants to merge 4 commits

88 changes: 88 additions & 0 deletions sharding_test.py
@@ -0,0 +1,88 @@
import json
import os
import shutil

import zarrita


shutil.rmtree("sharding_test.zr3", ignore_errors=True)
h = zarrita.create_hierarchy("sharding_test.zr3")
a = h.create_array(
path="testarray",
shape=(20, 3),
dtype="float64",
chunk_shape=(3, 2),
sharding={"chunks_per_shard": (2, 2)},
)

a[:10, :] = 42
a[15, 1] = 389
a[19, 2] = 1
a[0, 1] = -4.2

assert a.store._chunks_per_shard == (2, 2)
assert a[15, 1] == 389
assert a[19, 2] == 1
assert a[0, 1] == -4.2
assert a[0, 0] == 42

array_json = a.store["meta/root/testarray.array.json"].decode()

print(array_json)
# {
# "shape": [
# 20,
# 3
# ],
# "data_type": "<f8",
# "chunk_grid": {
# "type": "regular",
# "chunk_shape": [
# 3,
# 2
# ],
# "separator": "/"
# },
# "chunk_memory_layout": "C",
# "fill_value": null,
# "extensions": [],
# "attributes": {},
# "sharding": {
# "chunks_per_shard": [
# 2,
# 2
# ],
# "format": "indexed"
# }
# }

assert json.loads(array_json)["sharding"]["chunks_per_shard"] == [2, 2]

print("ONDISK")
for root, dirs, files in os.walk("sharding_test.zr3"):
dirs.sort()
if len(files) > 0:
print(" ", root.ljust(40), *sorted(files))
print("UNDERLYING STORE", sorted(i.rsplit("c")[-1] for i in a.store._store if i.startswith("data")))
print("STORE", sorted(i.rsplit("c")[-1] for i in a.store if i.startswith("data")))
# ONDISK
# sharding_test.zr3 zarr.json
# sharding_test.zr3/data/root/testarray/c0 0
# sharding_test.zr3/data/root/testarray/c1 0
# sharding_test.zr3/data/root/testarray/c2 0
# sharding_test.zr3/data/root/testarray/c3 0
# sharding_test.zr3/meta/root testarray.array.json
# UNDERLYING STORE ['0/0', '1/0', '2/0', '3/0']
# STORE ['0/0', '0/1', '1/0', '1/1', '2/0', '2/1', '3/0', '3/1', '5/0', '6/1']

index_bytes = a.store._store["data/root/testarray/c0/0"][-2*2*16:]
print("INDEX 0.0", [int.from_bytes(index_bytes[i:i+8], byteorder="little") for i in range(0, len(index_bytes), 8)])
# INDEX 0.0 [0, 48, 48, 48, 96, 48, 144, 48]
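# The index stores one (offset, length) uint64 pair per chunk of the shard.
# Each chunk holds 3*2 float64 values = 48 bytes, so the values above decode
# to the pairs (0, 48), (48, 48), (96, 48), (144, 48).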


a_reopened = zarrita.get_hierarchy("sharding_test.zr3").get_array("testarray")
assert a_reopened.store._chunks_per_shard == (2, 2)
assert a_reopened[15, 1] == 389
assert a_reopened[19, 2] == 1
assert a_reopened[0, 1] == -4.2
assert a_reopened[0, 0] == 42
212 changes: 206 additions & 6 deletions zarrita.py
@@ -3,10 +3,11 @@
import json
import numbers
import itertools
import functools
import math
import re
from collections.abc import Mapping, MutableMapping
from typing import Iterator, Union, Optional, Tuple, Any, List, Dict, NamedTuple
from typing import Iterator, Union, Optional, Tuple, Any, List, Dict, NamedTuple, Iterable, Type

# third-party dependencies

@@ -170,6 +171,18 @@ def _check_compressor(compressor: Optional[Codec]) -> None:
assert compressor is None or isinstance(compressor, Codec)


def _check_sharding(sharding: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
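    # e.g. {"chunks_per_shard": (2, 2)} is normalized to
    # {"chunks_per_shard": (2, 2), "format": "indexed"}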
if sharding is None:
return None
if "format" not in sharding:
sharding["format"] = "indexed"
assert sharding["format"] in SHARDED_STORES, (
f"Shard format {sharding['format']} is not supported, "
+ f"use one of {list(SHARDED_STORES)}"
)
return sharding


def _encode_codec_metadata(codec: Codec) -> Optional[Mapping]:
if codec is None:
return None
@@ -265,7 +278,8 @@ def create_array(self,
chunk_separator: str = "/",
compressor: Optional[Codec] = None,
fill_value: Any = None,
- attrs: Optional[Mapping] = None) -> Array:
+ attrs: Optional[Mapping] = None,
+ sharding: Optional[Dict[str, Any]] = None) -> Array:

# sanity checks
path = _check_path(path)
@@ -274,6 +288,7 @@ def create_array(self,
chunk_shape = _check_chunk_shape(chunk_shape, shape)
_check_compressor(compressor)
attrs = _check_attrs(attrs)
sharding = _check_sharding(sharding)

# encode data type
if dtype == np.bool_:
@@ -297,6 +312,8 @@
)
if compressor is not None:
meta["compressor"] = _encode_codec_metadata(compressor)
if sharding is not None:
meta["sharding"] = sharding

# serialise and store metadata document
meta_doc = _json_encode_object(meta)
@@ -307,7 +324,8 @@
array = Array(store=self.store, path=path, owner=self,
shape=shape, dtype=dtype, chunk_shape=chunk_shape,
chunk_separator=chunk_separator, compressor=compressor,
- fill_value=fill_value, attrs=attrs)
+ fill_value=fill_value, attrs=attrs,
+ sharding=sharding)

return array

@@ -341,12 +359,13 @@ def get_array(self, path: str) -> Array:
if spec["must_understand"]:
raise NotImplementedError(spec)
attrs = meta["attributes"]
sharding = meta.get("sharding", None)

# instantiate array
a = Array(store=self.store, path=path, owner=self, shape=shape,
dtype=dtype, chunk_shape=chunk_shape,
chunk_separator=chunk_separator, compressor=compressor,
- fill_value=fill_value, attrs=attrs)
+ fill_value=fill_value, attrs=attrs, sharding=sharding)

return a

@@ -587,7 +606,15 @@ def __init__(self,
chunk_separator: str,
compressor: Optional[Codec],
fill_value: Any = None,
- attrs: Optional[Mapping] = None):
+ attrs: Optional[Mapping] = None,
+ sharding: Optional[Dict[str, Any]] = None,
+ ):
if sharding is not None:
store = SHARDED_STORES[sharding["format"]]( # type: ignore
store=store,
chunk_separator=chunk_separator,
**sharding,
)
super().__init__(store=store, path=path, owner=owner)
self.shape = shape
self.dtype = dtype
@@ -771,7 +798,7 @@ def _chunk_setitem(self, chunk_coords, chunk_selection, value):
encoded_chunk_data = self._encode_chunk(chunk)

# store
- self.store[chunk_key] = encoded_chunk_data
+ self.store[chunk_key] = encoded_chunk_data.tobytes()

def _encode_chunk(self, chunk):

@@ -1146,3 +1173,176 @@ def __repr__(self) -> str:
if isinstance(protocol, tuple):
protocol = protocol[-1]
return f"{protocol}://{self.root}"


MAX_UINT_64 = 2 ** 64 - 1


def _is_data_key(key: str) -> bool:
return key.startswith("data/root")


class _ShardIndex(NamedTuple):
store: "IndexedShardedStore"
offsets_and_lengths: np.ndarray # dtype uint64, shape (chunks_per_shard_0, chunks_per_shard_1, ..., 2)

def __localize_chunk__(self, chunk: Tuple[int, ...]) -> Tuple[int, ...]:
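        # e.g. global chunk (5, 1) with chunks_per_shard (2, 2) localizes to (1, 1) within its shard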
return tuple(chunk_i % shard_i for chunk_i, shard_i in zip(chunk, self.store._chunks_per_shard))

def get_chunk_slice(self, chunk: Tuple[int, ...]) -> Optional[slice]:
localized_chunk = self.__localize_chunk__(chunk)
chunk_start, chunk_len = self.offsets_and_lengths[localized_chunk]
if (chunk_start, chunk_len) == (MAX_UINT_64, MAX_UINT_64):
return None
else:
return slice(chunk_start, chunk_start + chunk_len)

def set_chunk_slice(self, chunk: Tuple[int, ...], chunk_slice: Optional[slice]) -> None:
localized_chunk = self.__localize_chunk__(chunk)
if chunk_slice is None:
self.offsets_and_lengths[localized_chunk] = (MAX_UINT_64, MAX_UINT_64)
else:
self.offsets_and_lengths[localized_chunk] = (
chunk_slice.start,
chunk_slice.stop - chunk_slice.start
)

def to_bytes(self) -> bytes:
return self.offsets_and_lengths.tobytes(order='C')

@classmethod
def from_bytes(
cls, buffer: Union[bytes, bytearray], store: "IndexedShardedStore"
) -> "_ShardIndex":
return cls(
store=store,
offsets_and_lengths=np.frombuffer(
bytearray(buffer), dtype="<u8"
).reshape(*store._chunks_per_shard, 2, order="C")
)

@classmethod
def create_empty(cls, store: "IndexedShardedStore"):
# reserving 2*64bit per chunk for offset and length:
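        # The sentinel pair (MAX_UINT_64, MAX_UINT_64) marks a chunk as missing.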
return cls.from_bytes(
MAX_UINT_64.to_bytes(8, byteorder="little") * (2 * store._num_chunks_per_shard),
store=store
)


class IndexedShardedStore(Store):
"""This class should not be used directly,
but is added to an Array as a wrapper when needed automatically."""

def __init__(
self,
store: Store,
chunk_separator: str,
chunks_per_shard: Iterable[int],
**kwargs: Any,
) -> None:
self._store = store
self._num_chunks_per_shard = functools.reduce(lambda x, y: x*y, chunks_per_shard, 1)
self._chunk_separator = chunk_separator
assert all(isinstance(s, int) for s in chunks_per_shard)
self._chunks_per_shard = tuple(chunks_per_shard)

def _key_to_shard(
self, chunk_key: str
) -> Tuple[str, Tuple[int, ...]]:
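        # Maps a chunk key to the key of the shard containing it plus the chunk's
        # global coordinates, e.g. ".../c5/1" with chunks_per_shard (2, 2)
        # maps to (".../c2/0", (5, 1)).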
prefix, _, chunk_string = chunk_key.rpartition("c")
chunk_subkeys = tuple(map(int, chunk_string.split(self._chunk_separator)))
shard_key_tuple = (
subkey // shard_i for subkey, shard_i in zip(chunk_subkeys, self._chunks_per_shard)
)
shard_key = prefix + "c" + self._chunk_separator.join(map(str, shard_key_tuple))
return shard_key, chunk_subkeys

def _get_index(self, buffer: Union[bytes, bytearray]) -> _ShardIndex:
        # The index at the end of each shard holds 2 uint64 values (offset and length) per chunk:
return _ShardIndex.from_bytes(buffer[-16 * self._num_chunks_per_shard:], self)

def _get_chunks_in_shard(self, shard_key: str) -> Iterator[Tuple[int, ...]]:
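        # Yields the global coordinates of every chunk in the given shard, e.g.
        # shard ".../c1/0" with chunks_per_shard (2, 2) yields (2, 0), (2, 1), (3, 0), (3, 1).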
_, _, chunk_string = shard_key.rpartition("c")
shard_key_tuple = tuple(map(int, chunk_string.split(self._chunk_separator)))
for chunk_offset in itertools.product(*(range(i) for i in self._chunks_per_shard)):
yield tuple(
shard_key_i * shards_i + offset_i
for shard_key_i, offset_i, shards_i
in zip(shard_key_tuple, chunk_offset, self._chunks_per_shard)
)

def __getitem__(self, key: str, default: Optional[bytes] = None) -> bytes:
if _is_data_key(key):
shard_key, chunk_subkeys = self._key_to_shard(key)
full_shard_value = self._store[shard_key]
index = self._get_index(full_shard_value)
chunk_slice = index.get_chunk_slice(chunk_subkeys)
if chunk_slice is not None:
return full_shard_value[chunk_slice]
else:
if default is not None:
return default
raise KeyError(key)
else:
return self._store.__getitem__(key, default)

def __setitem__(self, key: str, value: bytes) -> None:
if _is_data_key(key):
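            # Writing a single chunk is a read-modify-write of the whole shard:
            # all other chunks are read back from the existing shard (if present)
            # and rewritten together with the new chunk and a fresh index.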
shard_key, chunk_subkeys = self._key_to_shard(key)
chunks_to_read = set(self._get_chunks_in_shard(shard_key))
chunks_to_read.remove(chunk_subkeys)
new_content = {chunk_subkeys: value}
try:
full_shard_value = self._store[shard_key]
except KeyError:
index = _ShardIndex.create_empty(self)
else:
index = self._get_index(full_shard_value)
for chunk_to_read in chunks_to_read:
chunk_slice = index.get_chunk_slice(chunk_to_read)
if chunk_slice is not None:
new_content[chunk_to_read] = full_shard_value[chunk_slice]

shard_content = b""
for chunk_subkeys, chunk_content in new_content.items():
chunk_slice = slice(len(shard_content), len(shard_content) + len(chunk_content))
index.set_chunk_slice(chunk_subkeys, chunk_slice)
shard_content += chunk_content
# Appending the index at the end of the shard:
shard_content += index.to_bytes()
self._store[shard_key] = shard_content
else:
self._store[key] = value

def _shard_key_to_original_keys(self, key: str) -> Iterator[str]:
if not _is_data_key(key):
# Special keys such as meta-keys are passed on as-is
yield key
else:
index = self._get_index(self._store[key])
prefix, _, _ = key.rpartition("c")
for chunk_tuple in self._get_chunks_in_shard(key):
if index.get_chunk_slice(chunk_tuple) is not None:
yield prefix + "c" + self._chunk_separator.join(map(str, chunk_tuple))

def __iter__(self) -> Iterator[str]:
for key in self._store:
yield from self._shard_key_to_original_keys(key)

def list_prefix(self, prefix: str) -> List[str]:
if _is_data_key(prefix):
# Needs translation of the prefix to shard_key
raise NotImplementedError
return self._store.list_prefix(prefix)

def list_dir(self, prefix: str) -> ListDirResult:
if _is_data_key(prefix):
# Needs translation of the prefix to shard_key
raise NotImplementedError
return self._store.list_dir(prefix)


SHARDED_STORES: Dict[str, Type[Store]] = {
"indexed": IndexedShardedStore,
}
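

For reference, a shard written in the "indexed" format can also be decoded without zarrita. Below is a minimal sketch; the helper read_chunk_from_shard is hypothetical, and the usage line assumes the uncompressed float64 array written by sharding_test.py above.

import numpy as np

def read_chunk_from_shard(shard_path, localized_chunk, chunks_per_shard):
    # The index is a C-ordered uint64 array of shape (*chunks_per_shard, 2)
    # holding (offset, length) pairs, appended at the end of the shard file.
    with open(shard_path, "rb") as f:
        data = f.read()
    num_chunks = int(np.prod(chunks_per_shard))
    index = np.frombuffer(data[-16 * num_chunks:], dtype="<u8").reshape(*chunks_per_shard, 2)
    offset, length = index[localized_chunk]
    if offset == 2 ** 64 - 1 and length == 2 ** 64 - 1:
        raise KeyError("chunk is missing from this shard")
    return data[offset:offset + length]

# First chunk of the first shard written by the test script:
raw = read_chunk_from_shard("sharding_test.zr3/data/root/testarray/c0/0", (0, 0), (2, 2))
print(np.frombuffer(raw, dtype="<f8").reshape(3, 2))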