diff --git a/tests/test_v3.py b/tests/test_v3.py index cc984ad..697b7c5 100644 --- a/tests/test_v3.py +++ b/tests/test_v3.py @@ -11,7 +11,7 @@ from zarrita import Array, Group, LocalStore, Store, codecs, runtime_configuration from zarrita.indexing import morton_order_iter -from zarrita.metadata import CodecMetadata +from zarrita.metadata import CodecMetadata, ShardingCodecIndexLocation @fixture @@ -32,7 +32,12 @@ def store() -> Iterator[Store]: pass -def test_sharding(store: Store, l4_sample_data: np.ndarray): +@pytest.mark.parametrize( + "index_location", [ShardingCodecIndexLocation.start, ShardingCodecIndexLocation.end] +) +def test_sharding( + store: Store, l4_sample_data: np.ndarray, index_location: ShardingCodecIndexLocation +): data = l4_sample_data a = Array.create( @@ -49,6 +54,7 @@ def test_sharding(store: Store, l4_sample_data: np.ndarray): codecs.bytes_codec(), codecs.blosc_codec(typesize=data.dtype.itemsize, cname="lz4"), ], + index_location=index_location, ) ], ) @@ -60,7 +66,12 @@ def test_sharding(store: Store, l4_sample_data: np.ndarray): assert np.array_equal(data, read_data) -def test_sharding_partial(store: Store, l4_sample_data: np.ndarray): +@pytest.mark.parametrize( + "index_location", [ShardingCodecIndexLocation.start, ShardingCodecIndexLocation.end] +) +def test_sharding_partial( + store: Store, l4_sample_data: np.ndarray, index_location: ShardingCodecIndexLocation +): data = l4_sample_data a = Array.create( @@ -77,6 +88,7 @@ def test_sharding_partial(store: Store, l4_sample_data: np.ndarray): codecs.bytes_codec(), codecs.blosc_codec(typesize=data.dtype.itemsize, cname="lz4"), ], + index_location=index_location, ) ], ) @@ -91,7 +103,12 @@ def test_sharding_partial(store: Store, l4_sample_data: np.ndarray): assert np.array_equal(data, read_data) -def test_sharding_partial_read(store: Store, l4_sample_data: np.ndarray): +@pytest.mark.parametrize( + "index_location", [ShardingCodecIndexLocation.start, ShardingCodecIndexLocation.end] +) +def test_sharding_partial_read( + store: Store, l4_sample_data: np.ndarray, index_location: ShardingCodecIndexLocation +): data = l4_sample_data a = Array.create( @@ -108,6 +125,7 @@ def test_sharding_partial_read(store: Store, l4_sample_data: np.ndarray): codecs.bytes_codec(), codecs.blosc_codec(typesize=data.dtype.itemsize, cname="lz4"), ], + index_location=index_location, ) ], ) @@ -116,7 +134,12 @@ def test_sharding_partial_read(store: Store, l4_sample_data: np.ndarray): assert np.all(read_data == 1) -def test_sharding_partial_overwrite(store: Store, l4_sample_data: np.ndarray): +@pytest.mark.parametrize( + "index_location", [ShardingCodecIndexLocation.start, ShardingCodecIndexLocation.end] +) +def test_sharding_partial_overwrite( + store: Store, l4_sample_data: np.ndarray, index_location: ShardingCodecIndexLocation +): data = l4_sample_data[:10, :10, :10] a = Array.create( @@ -133,6 +156,7 @@ def test_sharding_partial_overwrite(store: Store, l4_sample_data: np.ndarray): codecs.bytes_codec(), codecs.blosc_codec(typesize=data.dtype.itemsize, cname="lz4"), ], + index_location=index_location, ) ], ) @@ -148,7 +172,20 @@ def test_sharding_partial_overwrite(store: Store, l4_sample_data: np.ndarray): assert np.array_equal(data, read_data) -def test_nested_sharding(store: Store, l4_sample_data: np.ndarray): +@pytest.mark.parametrize( + "outer_index_location", + [ShardingCodecIndexLocation.start, ShardingCodecIndexLocation.end], +) +@pytest.mark.parametrize( + "inner_index_location", + [ShardingCodecIndexLocation.start, ShardingCodecIndexLocation.end], +) +def test_nested_sharding( + store: Store, + l4_sample_data: np.ndarray, + outer_index_location: ShardingCodecIndexLocation, + inner_index_location: ShardingCodecIndexLocation, +): data = l4_sample_data a = Array.create( @@ -160,7 +197,12 @@ def test_nested_sharding(store: Store, l4_sample_data: np.ndarray): codecs=[ codecs.sharding_codec( (32, 32, 32), - [codecs.sharding_codec((16, 16, 16))], + [ + codecs.sharding_codec( + (16, 16, 16), index_location=inner_index_location + ) + ], + index_location=outer_index_location, ) ], ) diff --git a/zarrita/codecs.py b/zarrita/codecs.py index 56f99e3..8eefb55 100644 --- a/zarrita/codecs.py +++ b/zarrita/codecs.py @@ -24,6 +24,7 @@ GzipCodecConfigurationMetadata, GzipCodecMetadata, ShardingCodecConfigurationMetadata, + ShardingCodecIndexLocation, ShardingCodecMetadata, TransposeCodecConfigurationMetadata, TransposeCodecMetadata, @@ -595,11 +596,12 @@ def sharding_codec( chunk_shape: Tuple[int, ...], codecs: Optional[List[CodecMetadata]] = None, index_codecs: Optional[List[CodecMetadata]] = None, + index_location: ShardingCodecIndexLocation = ShardingCodecIndexLocation.end, ) -> ShardingCodecMetadata: codecs = codecs or [bytes_codec()] index_codecs = index_codecs or [bytes_codec(), crc32c_codec()] return ShardingCodecMetadata( configuration=ShardingCodecConfigurationMetadata( - chunk_shape, codecs, index_codecs + chunk_shape, codecs, index_codecs, index_location ) ) diff --git a/zarrita/metadata.py b/zarrita/metadata.py index 45922e1..34eacdd 100644 --- a/zarrita/metadata.py +++ b/zarrita/metadata.py @@ -219,11 +219,17 @@ class Crc32cCodecMetadata: name: Literal["crc32c"] = "crc32c" +class ShardingCodecIndexLocation(Enum): + start = "start" + end = "end" + + @frozen class ShardingCodecConfigurationMetadata: chunk_shape: ChunkCoords codecs: List["CodecMetadata"] index_codecs: List["CodecMetadata"] + index_location: ShardingCodecIndexLocation = ShardingCodecIndexLocation.end @frozen @@ -296,6 +302,8 @@ def to_bytes(self) -> bytes: def _json_convert(o): if isinstance(o, DataType): return o.name + if isinstance(o, ShardingCodecIndexLocation): + return o.name raise TypeError return json.dumps( diff --git a/zarrita/sharding.py b/zarrita/sharding.py index 283f06e..4a9f7c1 100644 --- a/zarrita/sharding.py +++ b/zarrita/sharding.py @@ -1,6 +1,16 @@ from __future__ import annotations -from typing import Iterator, List, Mapping, NamedTuple, Optional, Set, Tuple +from typing import ( + Awaitable, + Callable, + Iterator, + List, + Mapping, + NamedTuple, + Optional, + Set, + Tuple, +) import numpy as np from attrs import frozen @@ -23,6 +33,7 @@ CoreArrayMetadata, DataType, ShardingCodecConfigurationMetadata, + ShardingCodecIndexLocation, ShardingCodecMetadata, ) from zarrita.store import StorePath @@ -49,7 +60,7 @@ def get_chunk_slice(self, chunk_coords: ChunkCoords) -> Optional[Tuple[int, int] if (chunk_start, chunk_len) == (MAX_UINT_64, MAX_UINT_64): return None else: - return (int(chunk_start), int(chunk_start + chunk_len)) + return (int(chunk_start), int(chunk_start) + int(chunk_len)) def set_chunk_slice( self, chunk_coords: ChunkCoords, chunk_slice: Optional[slice] @@ -101,11 +112,15 @@ class _ShardProxy(Mapping): @classmethod async def from_bytes(cls, buf: BytesLike, codec: ShardingCodec) -> _ShardProxy: + shard_index_size = codec._shard_index_size() obj = cls() obj.buf = memoryview(buf) - obj.index = await codec._decode_shard_index( - obj.buf[-codec._shard_index_size() :] - ) + if codec.configuration.index_location == ShardingCodecIndexLocation.start: + shard_index_bytes = obj.buf[:shard_index_size] + else: + shard_index_bytes = obj.buf[-shard_index_size:] + + obj.index = await codec._decode_shard_index(shard_index_bytes) return obj @classmethod @@ -152,7 +167,10 @@ def merge_with_morton_order( return obj @classmethod - def create_empty(cls, chunks_per_shard: ChunkCoords) -> _ShardBuilder: + def create_empty( + cls, + chunks_per_shard: ChunkCoords, + ) -> _ShardBuilder: obj = cls() obj.buf = bytearray() obj.index = _ShardIndex.create_empty(chunks_per_shard) @@ -166,9 +184,23 @@ def append(self, chunk_coords: ChunkCoords, value: BytesLike): chunk_coords, slice(chunk_start, chunk_start + chunk_length) ) - def finalize(self, index_bytes: BytesLike) -> BytesLike: - self.buf.extend(index_bytes) - return self.buf + async def finalize( + self, + index_location: ShardingCodecIndexLocation, + index_encoder: Callable[[_ShardIndex], Awaitable[BytesLike]], + ) -> BytesLike: + index_bytes = await index_encoder(self.index) + if index_location == ShardingCodecIndexLocation.start: + self.index.offsets_and_lengths[..., 0] += len(index_bytes) + index_bytes = await index_encoder( + self.index + ) # encode again with corrected offsets + out_buf = bytearray(index_bytes) + out_buf.extend(self.buf) + else: + out_buf = self.buf + out_buf.extend(index_bytes) + return out_buf @frozen @@ -402,8 +434,8 @@ async def _write_chunk( if chunk_bytes is not None: shard_builder.append(chunk_coords, chunk_bytes) - return shard_builder.finalize( - await self._encode_shard_index(shard_builder.index) + return await shard_builder.finalize( + self.configuration.index_location, self._encode_shard_index ) async def encode_partial( @@ -486,15 +518,19 @@ async def _write_chunk( tombstones.add(chunk_coords) shard_builder = _ShardBuilder.merge_with_morton_order( - self.chunks_per_shard, tombstones, new_shard_builder, old_shard_dict + self.chunks_per_shard, + tombstones, + new_shard_builder, + old_shard_dict, ) if shard_builder.index.is_all_empty(): await store_path.delete_async() else: await store_path.set_async( - shard_builder.finalize( - await self._encode_shard_index(shard_builder.index) + await shard_builder.finalize( + self.configuration.index_location, + self._encode_shard_index, ) ) @@ -520,7 +556,12 @@ def _shard_index_size(self) -> int: async def _load_shard_index_maybe( self, store_path: StorePath ) -> Optional[_ShardIndex]: - index_bytes = await store_path.get_async((-self._shard_index_size(), None)) + shard_index_size = self._shard_index_size() + if self.configuration.index_location == ShardingCodecIndexLocation.start: + index_bytes = await store_path.get_async((0, shard_index_size)) + else: + index_bytes = await store_path.get_async((-shard_index_size, None)) + if index_bytes is not None: return await self._decode_shard_index(index_bytes) return None