Adds index_location parameter to sharding codec #13

Merged
merged 4 commits on Nov 16, 2023
56 changes: 49 additions & 7 deletions tests/test_v3.py
@@ -11,7 +11,7 @@

from zarrita import Array, Group, LocalStore, Store, codecs, runtime_configuration
from zarrita.indexing import morton_order_iter
from zarrita.metadata import CodecMetadata
from zarrita.metadata import CodecMetadata, ShardingCodecIndexLocation


@fixture
@@ -32,7 +32,12 @@ def store() -> Iterator[Store]:
pass


def test_sharding(store: Store, l4_sample_data: np.ndarray):
@pytest.mark.parametrize(
"index_location", [ShardingCodecIndexLocation.start, ShardingCodecIndexLocation.end]
)
def test_sharding(
store: Store, l4_sample_data: np.ndarray, index_location: ShardingCodecIndexLocation
):
data = l4_sample_data

a = Array.create(
@@ -49,6 +54,7 @@ def test_sharding(store: Store, l4_sample_data: np.ndarray):
codecs.bytes_codec(),
codecs.blosc_codec(typesize=data.dtype.itemsize, cname="lz4"),
],
index_location=index_location,
)
],
)
@@ -60,7 +66,12 @@ def test_sharding(store: Store, l4_sample_data: np.ndarray):
assert np.array_equal(data, read_data)


def test_sharding_partial(store: Store, l4_sample_data: np.ndarray):
@pytest.mark.parametrize(
"index_location", [ShardingCodecIndexLocation.start, ShardingCodecIndexLocation.end]
)
def test_sharding_partial(
store: Store, l4_sample_data: np.ndarray, index_location: ShardingCodecIndexLocation
):
data = l4_sample_data

a = Array.create(
@@ -77,6 +88,7 @@ def test_sharding_partial(store: Store, l4_sample_data: np.ndarray):
codecs.bytes_codec(),
codecs.blosc_codec(typesize=data.dtype.itemsize, cname="lz4"),
],
index_location=index_location,
)
],
)
@@ -91,7 +103,12 @@ def test_sharding_partial(store: Store, l4_sample_data: np.ndarray):
assert np.array_equal(data, read_data)


def test_sharding_partial_read(store: Store, l4_sample_data: np.ndarray):
@pytest.mark.parametrize(
"index_location", [ShardingCodecIndexLocation.start, ShardingCodecIndexLocation.end]
)
def test_sharding_partial_read(
store: Store, l4_sample_data: np.ndarray, index_location: ShardingCodecIndexLocation
):
data = l4_sample_data

a = Array.create(
@@ -108,6 +125,7 @@ def test_sharding_partial_read(store: Store, l4_sample_data: np.ndarray):
codecs.bytes_codec(),
codecs.blosc_codec(typesize=data.dtype.itemsize, cname="lz4"),
],
index_location=index_location,
)
],
)
@@ -116,7 +134,12 @@ def test_sharding_partial_read(store: Store, l4_sample_data: np.ndarray):
assert np.all(read_data == 1)


def test_sharding_partial_overwrite(store: Store, l4_sample_data: np.ndarray):
@pytest.mark.parametrize(
"index_location", [ShardingCodecIndexLocation.start, ShardingCodecIndexLocation.end]
)
def test_sharding_partial_overwrite(
store: Store, l4_sample_data: np.ndarray, index_location: ShardingCodecIndexLocation
):
data = l4_sample_data[:10, :10, :10]

a = Array.create(
@@ -133,6 +156,7 @@ def test_sharding_partial_overwrite(store: Store, l4_sample_data: np.ndarray):
codecs.bytes_codec(),
codecs.blosc_codec(typesize=data.dtype.itemsize, cname="lz4"),
],
index_location=index_location,
)
],
)
@@ -148,7 +172,20 @@ def test_sharding_partial_overwrite(store: Store, l4_sample_data: np.ndarray):
assert np.array_equal(data, read_data)


def test_nested_sharding(store: Store, l4_sample_data: np.ndarray):
@pytest.mark.parametrize(
"outer_index_location",
[ShardingCodecIndexLocation.start, ShardingCodecIndexLocation.end],
)
@pytest.mark.parametrize(
"inner_index_location",
[ShardingCodecIndexLocation.start, ShardingCodecIndexLocation.end],
)
def test_nested_sharding(
store: Store,
l4_sample_data: np.ndarray,
outer_index_location: ShardingCodecIndexLocation,
inner_index_location: ShardingCodecIndexLocation,
):
data = l4_sample_data

a = Array.create(
@@ -160,7 +197,12 @@ def test_nested_sharding(store: Store, l4_sample_data: np.ndarray):
codecs=[
codecs.sharding_codec(
(32, 32, 32),
[codecs.sharding_codec((16, 16, 16))],
[
codecs.sharding_codec(
(16, 16, 16), index_location=inner_index_location
)
],
index_location=outer_index_location,
)
],
)
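Note on the test changes above: the stacked pytest.mark.parametrize decorators on test_nested_sharding run the test for every combination of outer and inner index location. A small illustrative snippet (not part of this PR) of the resulting matrix:

import itertools

from zarrita.metadata import ShardingCodecIndexLocation

# Two values for the outer shard index times two for the inner shard index
# gives four test cases; the single-parameter tests above each run twice.
locations = [ShardingCodecIndexLocation.start, ShardingCodecIndexLocation.end]
cases = list(itertools.product(locations, repeat=2))
assert len(cases) == 4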
4 changes: 3 additions & 1 deletion zarrita/codecs.py
@@ -24,6 +24,7 @@
GzipCodecConfigurationMetadata,
GzipCodecMetadata,
ShardingCodecConfigurationMetadata,
ShardingCodecIndexLocation,
ShardingCodecMetadata,
TransposeCodecConfigurationMetadata,
TransposeCodecMetadata,
@@ -595,11 +596,12 @@ def sharding_codec(
chunk_shape: Tuple[int, ...],
codecs: Optional[List[CodecMetadata]] = None,
index_codecs: Optional[List[CodecMetadata]] = None,
index_location: ShardingCodecIndexLocation = ShardingCodecIndexLocation.end,
) -> ShardingCodecMetadata:
codecs = codecs or [bytes_codec()]
index_codecs = index_codecs or [bytes_codec(), crc32c_codec()]
return ShardingCodecMetadata(
configuration=ShardingCodecConfigurationMetadata(
chunk_shape, codecs, index_codecs
chunk_shape, codecs, index_codecs, index_location
)
)
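A minimal usage sketch of the updated helper (not part of the diff), assuming only the signature shown above. The default keeps the index at the end of the shard, so existing callers get the previous behavior:

from zarrita import codecs
from zarrita.metadata import ShardingCodecIndexLocation

# Default: the shard index is appended after the chunk data, as before.
default_codec = codecs.sharding_codec(chunk_shape=(32, 32, 32))
assert default_codec.configuration.index_location == ShardingCodecIndexLocation.end

# Opt in to a leading index by passing the new parameter.
leading_codec = codecs.sharding_codec(
    (32, 32, 32),
    codecs=[codecs.bytes_codec()],
    index_location=ShardingCodecIndexLocation.start,
)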
8 changes: 8 additions & 0 deletions zarrita/metadata.py
@@ -219,11 +219,17 @@ class Crc32cCodecMetadata:
name: Literal["crc32c"] = "crc32c"


class ShardingCodecIndexLocation(Enum):
start = "start"
end = "end"


@frozen
class ShardingCodecConfigurationMetadata:
chunk_shape: ChunkCoords
codecs: List["CodecMetadata"]
index_codecs: List["CodecMetadata"]
index_location: ShardingCodecIndexLocation = ShardingCodecIndexLocation.end


@frozen
@@ -296,6 +302,8 @@ def to_bytes(self) -> bytes:
def _json_convert(o):
if isinstance(o, DataType):
return o.name
if isinstance(o, ShardingCodecIndexLocation):
return o.name
raise TypeError

return json.dumps(
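For reference, a rough sketch (not part of the PR) of the new configuration field: the default value of end preserves the previous layout for configurations constructed without the argument, and _json_convert serializes the enum by name, so the stored metadata carries the plain strings "start" or "end". The keyword arguments below assume the attrs-generated constructor matches the field names shown in the class:

from zarrita import codecs
from zarrita.metadata import (
    ShardingCodecConfigurationMetadata,
    ShardingCodecIndexLocation,
)

# Constructed without index_location, the configuration falls back to "end".
config = ShardingCodecConfigurationMetadata(
    chunk_shape=(32, 32, 32),
    codecs=[codecs.bytes_codec()],
    index_codecs=[codecs.bytes_codec(), codecs.crc32c_codec()],
)
assert config.index_location == ShardingCodecIndexLocation.end

# _json_convert maps the enum to its name for serialization.
assert ShardingCodecIndexLocation.start.name == "start"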
71 changes: 56 additions & 15 deletions zarrita/sharding.py
@@ -1,6 +1,16 @@
from __future__ import annotations

from typing import Iterator, List, Mapping, NamedTuple, Optional, Set, Tuple
from typing import (
Awaitable,
Callable,
Iterator,
List,
Mapping,
NamedTuple,
Optional,
Set,
Tuple,
)

import numpy as np
from attrs import frozen
@@ -23,6 +33,7 @@
CoreArrayMetadata,
DataType,
ShardingCodecConfigurationMetadata,
ShardingCodecIndexLocation,
ShardingCodecMetadata,
)
from zarrita.store import StorePath
@@ -49,7 +60,7 @@ def get_chunk_slice(self, chunk_coords: ChunkCoords) -> Optional[Tuple[int, int]
if (chunk_start, chunk_len) == (MAX_UINT_64, MAX_UINT_64):
return None
else:
return (int(chunk_start), int(chunk_start + chunk_len))
return (int(chunk_start), int(chunk_start) + int(chunk_len))

def set_chunk_slice(
self, chunk_coords: ChunkCoords, chunk_slice: Optional[slice]
@@ -101,11 +112,15 @@ class _ShardProxy(Mapping):

@classmethod
async def from_bytes(cls, buf: BytesLike, codec: ShardingCodec) -> _ShardProxy:
shard_index_size = codec._shard_index_size()
obj = cls()
obj.buf = memoryview(buf)
obj.index = await codec._decode_shard_index(
obj.buf[-codec._shard_index_size() :]
)
if codec.configuration.index_location == ShardingCodecIndexLocation.start:
shard_index_bytes = obj.buf[:shard_index_size]
else:
shard_index_bytes = obj.buf[-shard_index_size:]

obj.index = await codec._decode_shard_index(shard_index_bytes)
return obj

@classmethod
@@ -152,7 +167,10 @@ def merge_with_morton_order(
return obj

@classmethod
def create_empty(cls, chunks_per_shard: ChunkCoords) -> _ShardBuilder:
def create_empty(
cls,
chunks_per_shard: ChunkCoords,
) -> _ShardBuilder:
obj = cls()
obj.buf = bytearray()
obj.index = _ShardIndex.create_empty(chunks_per_shard)
@@ -166,9 +184,23 @@ def append(self, chunk_coords: ChunkCoords, value: BytesLike):
chunk_coords, slice(chunk_start, chunk_start + chunk_length)
)

def finalize(self, index_bytes: BytesLike) -> BytesLike:
self.buf.extend(index_bytes)
return self.buf
async def finalize(
self,
index_location: ShardingCodecIndexLocation,
index_encoder: Callable[[_ShardIndex], Awaitable[BytesLike]],
) -> BytesLike:
index_bytes = await index_encoder(self.index)
if index_location == ShardingCodecIndexLocation.start:
self.index.offsets_and_lengths[..., 0] += len(index_bytes)
index_bytes = await index_encoder(
self.index
) # encode again with corrected offsets
out_buf = bytearray(index_bytes)
out_buf.extend(self.buf)
else:
out_buf = self.buf
out_buf.extend(index_bytes)
return out_buf


@frozen
@@ -402,8 +434,8 @@ async def _write_chunk(
if chunk_bytes is not None:
shard_builder.append(chunk_coords, chunk_bytes)

return shard_builder.finalize(
await self._encode_shard_index(shard_builder.index)
return await shard_builder.finalize(
self.configuration.index_location, self._encode_shard_index
)

async def encode_partial(
Expand Down Expand Up @@ -486,15 +518,19 @@ async def _write_chunk(
tombstones.add(chunk_coords)

shard_builder = _ShardBuilder.merge_with_morton_order(
self.chunks_per_shard, tombstones, new_shard_builder, old_shard_dict
self.chunks_per_shard,
tombstones,
new_shard_builder,
old_shard_dict,
)

if shard_builder.index.is_all_empty():
await store_path.delete_async()
else:
await store_path.set_async(
shard_builder.finalize(
await self._encode_shard_index(shard_builder.index)
await shard_builder.finalize(
self.configuration.index_location,
self._encode_shard_index,
)
)

@@ -520,7 +556,12 @@ def _shard_index_size(self) -> int:
async def _load_shard_index_maybe(
self, store_path: StorePath
) -> Optional[_ShardIndex]:
index_bytes = await store_path.get_async((-self._shard_index_size(), None))
shard_index_size = self._shard_index_size()
if self.configuration.index_location == ShardingCodecIndexLocation.start:
index_bytes = await store_path.get_async((0, shard_index_size))
else:
index_bytes = await store_path.get_async((-shard_index_size, None))

if index_bytes is not None:
return await self._decode_shard_index(index_bytes)
return None
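To summarize the layout change implemented in finalize() and _load_shard_index_maybe(): with index_location=end the chunks are written first and the encoded index is appended; with index_location=start the index is written first and every recorded chunk offset is shifted by the index size, which is why the index is encoded a second time with corrected offsets. A standalone sketch (not zarrita code) of the two layouts, using a simplified fixed-size index encoder in place of the real index codec chain (bytes + crc32c by default):

import struct
from typing import List, Tuple

def encode_index(offsets_and_lengths: List[Tuple[int, int]]) -> bytes:
    # Simplified stand-in for the real index codecs: each (offset, length)
    # pair as little-endian uint64s, 16 bytes per chunk.
    return b"".join(struct.pack("<QQ", o, n) for o, n in offsets_and_lengths)

def assemble_shard(chunks: List[bytes], index_at_start: bool) -> bytes:
    index_size = 16 * len(chunks)
    # Offsets are absolute positions within the shard, so a leading index
    # pushes every chunk forward by index_size bytes.
    pos = index_size if index_at_start else 0
    offsets = []
    for chunk in chunks:
        offsets.append((pos, len(chunk)))
        pos += len(chunk)
    body = b"".join(chunks)
    index = encode_index(offsets)
    return index + body if index_at_start else body + index

shard = assemble_shard([b"aaaa", b"bb"], index_at_start=True)
# The first chunk starts right after the 32-byte index.
assert shard[:16] == struct.pack("<QQ", 32, 4)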