Skip to content

Commit

Permalink
fixes for transpose
Browse files Browse the repository at this point in the history
  • Loading branch information
normanrz committed Oct 9, 2023
1 parent 6bb951e commit 482cb29
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 24 deletions.
3 changes: 2 additions & 1 deletion tests/test_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,8 @@ def test_convert_to_v3_array(store: Store):
assert a3.metadata.chunk_key_encoding.configuration.separator == "/"
assert a3.metadata.attributes["hello"] == "world"
assert any(
isinstance(c, TransposeCodec) and c.configuration.order == "F"
isinstance(c, TransposeCodec)
and c.order == tuple(a.metadata.ndim - x - 1 for x in range(a.metadata.ndim))
for c in a3.codec_pipeline.codecs
)
assert any(
Expand Down
101 changes: 101 additions & 0 deletions tests/test_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,95 @@ def test_order_implicit(
assert read_data.flags["C_CONTIGUOUS"]


@pytest.mark.parametrize("input_order", ["F", "C"])
@pytest.mark.parametrize("runtime_write_order", ["F", "C"])
@pytest.mark.parametrize("runtime_read_order", ["F", "C"])
@pytest.mark.parametrize("with_sharding", [True, False])
@pytest.mark.asyncio
async def test_transpose(
store: Store,
input_order: Literal["F", "C"],
runtime_write_order: Literal["F", "C"],
runtime_read_order: Literal["F", "C"],
with_sharding: bool,
):
data = np.arange(0, 256, dtype="uint16").reshape((1, 32, 8), order=input_order)

codecs_: List[CodecMetadata] = (
[
codecs.sharding_codec(
(1, 16, 8),
codecs=[codecs.transpose_codec((2, 1, 0)), codecs.bytes_codec()],
)
]
if with_sharding
else [codecs.transpose_codec((2, 1, 0)), codecs.bytes_codec()]
)

a = await Array.create_async(
store / "transpose",
shape=data.shape,
chunk_shape=(1, 32, 8),
dtype=data.dtype,
fill_value=0,
chunk_key_encoding=("v2", "."),
codecs=codecs_,
runtime_configuration=runtime_configuration(runtime_write_order),
)

await a.async_[:, :].set(data)
read_data = await a.async_[:, :].get()
assert np.array_equal(data, read_data)

a = await Array.open_async(
store / "transpose",
runtime_configuration=runtime_configuration(runtime_read_order),
)
read_data = await a.async_[:, :].get()
assert np.array_equal(data, read_data)

if runtime_read_order == "F":
assert read_data.flags["F_CONTIGUOUS"]
assert not read_data.flags["C_CONTIGUOUS"]
else:
assert not read_data.flags["F_CONTIGUOUS"]
assert read_data.flags["C_CONTIGUOUS"]

if not with_sharding:
# Compare with zarr-python
z = zarr.create(
shape=data.shape,
chunks=(1, 32, 8),
dtype="<u2",
order="F",
compressor=None,
fill_value=1,
store="testdata/transpose_zarr",
)
z[:, :] = data
assert await store.get_async("transpose/0.0") == await store.get_async(
"transpose_zarr/0.0"
)


def test_transpose_invalid(
store: Store,
):
data = np.arange(0, 256, dtype="uint16").reshape((1, 32, 8))

for order in [(1, 0), (3, 2, 1), (3, 3, 1)]:
with pytest.raises(AssertionError):
Array.create(
store / "transpose_invalid",
shape=data.shape,
chunk_shape=(1, 32, 8),
dtype=data.dtype,
fill_value=0,
chunk_key_encoding=("v2", "."),
codecs=[codecs.transpose_codec(order), codecs.bytes_codec()],
)


def test_open(store: Store):
a = Array.create(
store / "open",
Expand Down Expand Up @@ -774,6 +863,18 @@ def test_invalid_metadata(store: Store):
],
)

with pytest.raises(AssertionError):
Array.create(
store / "invalid",
shape=(16, 16),
chunk_shape=(16, 16),
dtype=np.dtype("uint8"),
fill_value=0,
codecs=[
codecs.transpose_codec("F"),
],
)

with pytest.raises(AssertionError):
Array.create(
store / "invalid",
Expand Down
86 changes: 63 additions & 23 deletions zarrita/codecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ class Codec(ABC):
def compute_encoded_size(self, input_byte_length: int) -> int:
pass

def resolve_metadata(self) -> CoreArrayMetadata:
return self.array_metadata


class ArrayArrayCodec(Codec):
@abstractmethod
Expand Down Expand Up @@ -112,24 +115,28 @@ def from_metadata(
if codec_metadata.name == "endian":
codec_metadata = evolve(codec_metadata, name="bytes") # type: ignore

codec: Codec
if codec_metadata.name == "blosc":
out.append(BloscCodec.from_metadata(codec_metadata, array_metadata))
codec = BloscCodec.from_metadata(codec_metadata, array_metadata)
elif codec_metadata.name == "gzip":
out.append(GzipCodec.from_metadata(codec_metadata, array_metadata))
codec = GzipCodec.from_metadata(codec_metadata, array_metadata)
elif codec_metadata.name == "zstd":
out.append(ZstdCodec.from_metadata(codec_metadata, array_metadata))
codec = ZstdCodec.from_metadata(codec_metadata, array_metadata)
elif codec_metadata.name == "transpose":
out.append(TransposeCodec.from_metadata(codec_metadata, array_metadata))
codec = TransposeCodec.from_metadata(codec_metadata, array_metadata)
elif codec_metadata.name == "bytes":
out.append(BytesCodec.from_metadata(codec_metadata, array_metadata))
codec = BytesCodec.from_metadata(codec_metadata, array_metadata)
elif codec_metadata.name == "crc32c":
out.append(Crc32cCodec.from_metadata(codec_metadata, array_metadata))
codec = Crc32cCodec.from_metadata(codec_metadata, array_metadata)
elif codec_metadata.name == "sharding_indexed":
from zarrita.sharding import ShardingCodec

out.append(ShardingCodec.from_metadata(codec_metadata, array_metadata))
codec = ShardingCodec.from_metadata(codec_metadata, array_metadata)
else:
raise RuntimeError(f"Unsupported codec: {codec_metadata}")

out.append(codec)
array_metadata = codec.resolve_metadata()
CodecPipeline._validate_codecs(out, array_metadata)
return cls(out)

Expand Down Expand Up @@ -361,42 +368,75 @@ def compute_encoded_size(self, input_byte_length: int) -> int:
@frozen
class TransposeCodec(ArrayArrayCodec):
array_metadata: CoreArrayMetadata
configuration: TransposeCodecConfigurationMetadata
order: Tuple[int, ...]
is_fixed_size = True

@classmethod
def from_metadata(
cls, codec_metadata: TransposeCodecMetadata, array_metadata: CoreArrayMetadata
) -> TransposeCodec:
configuration = codec_metadata.configuration
if configuration.order == "F":
order = tuple(
array_metadata.ndim - x - 1 for x in range(array_metadata.ndim)
)

elif configuration.order == "C":
order = tuple(range(array_metadata.ndim))

else:
assert len(configuration.order) == array_metadata.ndim, (
"The `order` tuple needs have as many entries as "
+ f"there are dimensions in the array. Got: {configuration.order}"
)
assert len(configuration.order) == len(set(configuration.order)), (
"There must not be duplicates in the `order` tuple. "
+ f"Got: {configuration.order}"
)
assert all(0 <= x < array_metadata.ndim for x in configuration.order), (
"All entries in the `order` tuple must be between 0 and "
+ f"the number of dimensions in the array. Got: {configuration.order}"
)
order = tuple(configuration.order)

return cls(
array_metadata=array_metadata,
configuration=codec_metadata.configuration,
order=order,
)

def resolve_metadata(self) -> CoreArrayMetadata:
from zarrita.metadata import CoreArrayMetadata

return CoreArrayMetadata(
shape=tuple(
self.array_metadata.shape[self.order[i]]
for i in range(self.array_metadata.ndim)
),
chunk_shape=tuple(
self.array_metadata.chunk_shape[self.order[i]]
for i in range(self.array_metadata.ndim)
),
data_type=self.array_metadata.data_type,
fill_value=self.array_metadata.fill_value,
runtime_configuration=self.array_metadata.runtime_configuration,
)

async def decode(
self,
chunk_array: np.ndarray,
) -> np.ndarray:
new_order = self.configuration.order
chunk_array = chunk_array.view(np.dtype(self.array_metadata.data_type.value))
if isinstance(new_order, tuple):
chunk_array = chunk_array.transpose(new_order)
elif new_order == "F":
chunk_array = chunk_array.ravel().reshape(
self.array_metadata.chunk_shape, order="F"
)
inverse_order = [0 for _ in range(self.array_metadata.ndim)]
for x, i in enumerate(self.order):
inverse_order[x] = i
chunk_array = chunk_array.transpose(inverse_order)
return chunk_array

async def encode(
self,
chunk_array: np.ndarray,
) -> Optional[np.ndarray]:
new_order = self.configuration.order
if isinstance(new_order, tuple):
chunk_array = chunk_array.transpose(new_order)
elif new_order == "F":
chunk_array = chunk_array.T
return chunk_array.reshape(-1, order="C")
chunk_array = chunk_array.transpose(self.order)
return chunk_array

def compute_encoded_size(self, input_byte_length: int) -> int:
return input_byte_length
Expand Down
12 changes: 12 additions & 0 deletions zarrita/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,10 @@ class CoreArrayMetadata:
def dtype(self) -> np.dtype:
return np.dtype(self.data_type.value)

@property
def ndim(self) -> int:
return len(self.shape)


@frozen
class ArrayMetadata:
Expand All @@ -273,6 +277,10 @@ class ArrayMetadata:
def dtype(self) -> np.dtype:
return np.dtype(self.data_type.value)

@property
def ndim(self) -> int:
return len(self.shape)

def get_core_metadata(
self, runtime_configuration: RuntimeConfiguration
) -> CoreArrayMetadata:
Expand Down Expand Up @@ -316,6 +324,10 @@ class ArrayV2Metadata:
compressor: Optional[Dict[str, Any]] = None
zarr_format: Literal[2] = 2

@property
def ndim(self) -> int:
return len(self.shape)

def to_bytes(self) -> bytes:
def _json_convert(o):
if isinstance(o, np.dtype):
Expand Down

0 comments on commit 482cb29

Please sign in to comment.