From 4df5916b184300ad569b711ff880fa7f260040c0 Mon Sep 17 00:00:00 2001 From: Zolisa Bleki Date: Thu, 3 Oct 2024 07:55:20 +0200 Subject: [PATCH] Add support for Zstd `bytes->bytes` Codec. --- .github/workflows/build-and-test.yml | 1 + README.md | 2 +- zarr-sync/test/test_sync.ml | 2 +- zarr.opam | 3 ++ zarr.opam.template | 3 ++ zarr/src/codecs/array_to_bytes.ml | 2 +- zarr/src/codecs/bytes_to_bytes.ml | 40 +++++++++++++++ zarr/src/codecs/bytes_to_bytes.mli | 1 + zarr/src/codecs/codecs.ml | 1 + zarr/src/codecs/codecs.mli | 2 + zarr/src/codecs/codecs_intf.ml | 10 +++- zarr/src/dune | 1 + zarr/test/test_codecs.ml | 73 +++++++++++++++++++++++----- 13 files changed, 123 insertions(+), 18 deletions(-) create mode 100644 zarr.opam.template diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index 6256ff6..2608ab6 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -44,6 +44,7 @@ jobs: - name: setup run: | opam install --deps-only --with-test --with-doc --yes zarr + opam install bytesrw conf-zlib conf-zstd --yes opam install lwt --yes opam exec -- dune build zarr zarr-sync zarr-lwt diff --git a/README.md b/README.md index d784c1a..fdb1d93 100644 --- a/README.md +++ b/README.md @@ -92,7 +92,7 @@ assert (Ndarray.equal x' y);; ```ocaml let config = {chunk_shape = [|5; 3; 5|] - ;codecs = [`Transpose [|2; 0; 1|]; `Bytes LE; `Gzip L5] + ;codecs = [`Transpose [|2; 0; 1|]; `Bytes LE; `Zstd (0, true)] ;index_codecs = [`Bytes BE; `Crc32c] ;index_location = Start};; diff --git a/zarr-sync/test/test_sync.ml b/zarr-sync/test/test_sync.ml index 61656a2..a1b2635 100644 --- a/zarr-sync/test/test_sync.ml +++ b/zarr-sync/test/test_sync.ml @@ -45,7 +45,7 @@ let test_storage {chunk_shape = [|2; 5; 5|] ;index_location = End ;index_codecs = [`Bytes LE; `Crc32c] - ;codecs = [`Transpose [|2; 0; 1|]; `Bytes BE; `Gzip L5]} in + ;codecs = [`Transpose [|2; 0; 1|]; `Bytes BE; `Zstd (0, false)]} in let cfg2 = {chunk_shape = [|2; 5; 5|] ;index_location = Start diff --git a/zarr.opam b/zarr.opam index 7f49394..e38c4ff 100644 --- a/zarr.opam +++ b/zarr.opam @@ -40,3 +40,6 @@ build: [ ] ] dev-repo: "git+https://github.com/zoj613/zarr-ml.git" +pin-depends: [ + ["bytesrw.dev" "git+https://erratique.ch/repos/bytesrw.git"] +] diff --git a/zarr.opam.template b/zarr.opam.template new file mode 100644 index 0000000..4e82fcf --- /dev/null +++ b/zarr.opam.template @@ -0,0 +1,3 @@ +pin-depends: [ + ["bytesrw.dev" "git+https://erratique.ch/repos/bytesrw.git"] +] diff --git a/zarr/src/codecs/array_to_bytes.ml b/zarr/src/codecs/array_to_bytes.ml index f4f8131..f71163e 100644 --- a/zarr/src/codecs/array_to_bytes.ml +++ b/zarr/src/codecs/array_to_bytes.ml @@ -375,7 +375,7 @@ end = struct let* l = acc in match c with | `Crc32c -> Ok (`Crc32c :: l) - | `Gzip _ -> Error msg) ic.b2b (Ok []) + | `Gzip _ | `Zstd _ -> Error msg) ic.b2b (Ok []) in let+ a2b = match ic.a2b with | `Bytes e -> Ok (`Bytes e) diff --git a/zarr/src/codecs/bytes_to_bytes.ml b/zarr/src/codecs/bytes_to_bytes.ml index bfe8709..b90545d 100644 --- a/zarr/src/codecs/bytes_to_bytes.ml +++ b/zarr/src/codecs/bytes_to_bytes.ml @@ -1,4 +1,5 @@ open Codecs_intf +open Bytesrw (* https://zarr-specs.readthedocs.io/en/latest/v3/codecs/gzip/v1.0.html *) module GzipCodec = struct @@ -52,27 +53,66 @@ module Crc32cCodec = struct Ok `Crc32c end +(* https://github.com/zarr-developers/zarr-specs/pull/256 *) +module ZstdCodec = struct + let min_clevel = -131072 and max_clevel = 22 + + let parse_clevel l = + if l < min_clevel || max_clevel < l then (raise Invalid_zstd_level) + + let encode clevel checksum x = + let params = Bytesrw_zstd.Cctx_params.make ~checksum ~clevel () in + Bytes.Reader.to_string @@ + Bytesrw_zstd.compress_reads ~params () @@ Bytes.Reader.of_string x + + let decode x = + let params = Bytesrw_zstd.Dctx_params.default in + Bytes.Reader.to_string @@ + Bytesrw_zstd.decompress_reads ~params () @@ Bytes.Reader.of_string x + + let to_yojson l c = + `Assoc + [("name", `String "zstd") + ;("configuration", `Assoc [("level", `Int l); ("checksum", `Bool c)])] + + let of_yojson x = + match Yojson.Safe.Util.(member "configuration" x |> to_assoc) with + | [("level", `Int l); ("checksum", `Bool c)] -> + begin match parse_clevel l with + | exception Invalid_zstd_level -> Error "Invalid_zstd_level" + | () -> Result.ok @@ `Zstd (l, c) end + | _ -> Error "Invalid Zstd configuration." +end + module BytesToBytes = struct let encoded_size : int -> fixed_bytestobytes -> int = fun input_size -> function | `Crc32c -> Crc32cCodec.encoded_size input_size + let parse = function + | `Gzip _ | `Crc32c -> () + | `Zstd (l, _) -> ZstdCodec.parse_clevel l + let encode x = function | `Gzip l -> GzipCodec.encode l x | `Crc32c -> Crc32cCodec.encode x + | `Zstd (l, c) -> ZstdCodec.encode l c x let decode t x = match t with | `Gzip _ -> GzipCodec.decode x | `Crc32c -> Crc32cCodec.decode x + | `Zstd _ -> ZstdCodec.decode x let to_yojson = function | `Gzip l -> GzipCodec.to_yojson l | `Crc32c -> Crc32cCodec.to_yojson + | `Zstd (l, c) -> ZstdCodec.to_yojson l c let of_yojson x = match Util.get_name x with | "gzip" -> GzipCodec.of_yojson x | "crc32c" -> Crc32cCodec.of_yojson x + | "zstd" -> ZstdCodec.of_yojson x | s -> Error (Printf.sprintf "codec %s is not supported." s) end diff --git a/zarr/src/codecs/bytes_to_bytes.mli b/zarr/src/codecs/bytes_to_bytes.mli index 41d7f39..ccefd60 100644 --- a/zarr/src/codecs/bytes_to_bytes.mli +++ b/zarr/src/codecs/bytes_to_bytes.mli @@ -1,6 +1,7 @@ open Codecs_intf module BytesToBytes : sig + val parse : bytestobytes -> unit val encoded_size : int -> fixed_bytestobytes -> int val encode : string -> bytestobytes -> string val decode : bytestobytes -> string -> string diff --git a/zarr/src/codecs/codecs.ml b/zarr/src/codecs/codecs.ml index e7c052d..28dd9c6 100644 --- a/zarr/src/codecs/codecs.ml +++ b/zarr/src/codecs/codecs.ml @@ -76,6 +76,7 @@ module Chain = struct | x :: _ as xs -> ArrayToArray.parse x shape; List.fold_left ArrayToArray.encoded_repr shape xs); + List.fold_left (fun _ v -> BytesToBytes.parse v) () b2b; {a2a; a2b; b2b} let encode t x = diff --git a/zarr/src/codecs/codecs.mli b/zarr/src/codecs/codecs.mli index 4967286..6e91d9c 100644 --- a/zarr/src/codecs/codecs.mli +++ b/zarr/src/codecs/codecs.mli @@ -19,6 +19,8 @@ module Chain : sig if [c] contains more than one bytes->bytes codec. @raise Invalid_transpose_order if [c] contains a transpose codec with invalid order array. + @raise Invalid_zstd_level + if [c] contains a Zstd codec whose compression level is invalid. @raise Invalid_sharding_chunk_shape if [c] contains a shardingindexed codec with an incorrect inner chunk shape. *) diff --git a/zarr/src/codecs/codecs_intf.ml b/zarr/src/codecs/codecs_intf.ml index 70b63e7..c8d8ff0 100644 --- a/zarr/src/codecs/codecs_intf.ml +++ b/zarr/src/codecs/codecs_intf.ml @@ -2,6 +2,7 @@ exception Array_to_bytes_invariant exception Invalid_transpose_order exception Invalid_sharding_chunk_shape exception Invalid_codec_ordering +exception Invalid_zstd_level type arraytoarray = [ `Transpose of int array ] @@ -13,7 +14,8 @@ type fixed_bytestobytes = [ `Crc32c ] type variable_bytestobytes = - [ `Gzip of compression_level ] + [ `Gzip of compression_level + | `Zstd of int * bool ] type bytestobytes = [ fixed_bytestobytes | variable_bytestobytes ] @@ -62,6 +64,9 @@ module type Interface = sig (** raised when a codec chain has incorrect ordering of codecs. i.e if the ordering is not [arraytoarray list -> 1 arraytobytes -> bytestobytes list]. *) + exception Invalid_zstd_level + (** raised when a codec chain contains a Zstd codec with an incorrect compression value.*) + (** The type of [array -> array] codecs. *) type arraytoarray = [ `Transpose of int array ] @@ -78,7 +83,8 @@ module type Interface = sig (** A type representing [bytes -> bytes] codecs that produce variable sized encoded strings. *) type variable_bytestobytes = - [ `Gzip of compression_level ] + [ `Gzip of compression_level + | `Zstd of int * bool ] (** The type of [bytes -> bytes] codecs. *) type bytestobytes = diff --git a/zarr/src/dune b/zarr/src/dune index 0d038b1..88301b0 100644 --- a/zarr/src/dune +++ b/zarr/src/dune @@ -4,6 +4,7 @@ (libraries yojson ezgzip + bytesrw.zstd stdint checkseum) (ocamlopt_flags diff --git a/zarr/test/test_codecs.ml b/zarr/test/test_codecs.ml index 47434a8..969b13f 100644 --- a/zarr/test/test_codecs.ml +++ b/zarr/test/test_codecs.ml @@ -238,19 +238,22 @@ let tests = [ [{"name": "bytes", "configuration": {"endian": "big"}}]}}]|} ~msg:"Must be exactly one array->bytes codec."; (* test violation of index_codec invariant when it contains variable-sized codecs. *) - decode_chain - ~shape:[|5; 5; 5|] - ~str:{|[ - {"name": "sharding_indexed", - "configuration": - {"index_location": "start", - "chunk_shape": [5, 5, 5], - "index_codecs": - [{"name": "bytes", "configuration": {"endian": "big"}}, - {"name": "gzip", "configuration": {"level": 1}}], - "codecs": - [{"name": "bytes", "configuration": {"endian": "big"}}]}}]|} - ~msg:"Must be exactly one array->bytes codec."; + List.iter + (fun c -> + decode_chain + ~shape:[|5; 5; 5|] + ~str:(Format.sprintf {|[ + {"name": "sharding_indexed", + "configuration": + {"index_location": "start", + "chunk_shape": [5, 5, 5], + "index_codecs": + [{"name": "bytes", "configuration": {"endian": "big"}}, %s], + "codecs": + [{"name": "bytes", "configuration": {"endian": "big"}}]}}]|} c) + ~msg:"Must be exactly one array->bytes codec.") + [{|{"name": "zstd", "configuration": {"level": 0, "checksum": false}}|} + ;{|{"name": "gzip", "configuration": {"level": 1}}|}]; let shape = [|10; 15; 10|] in let kind = Ndarray.Float64 in @@ -365,6 +368,50 @@ let tests = [ assert_equal arr @@ Chain.decode c {shape; kind} encoded) [L0; L1; L2; L3; L4; L5; L6; L7; L8; L9]) ; + +"test zstd codec" >:: (fun _ -> + (* test wrong compression level *) + List.iter + (fun l -> + decode_chain + ~shape:[||] + ~str:(Format.sprintf {|[{"name": "bytes", "configuration": {"endian": "little"}}, + {"name": "zstd", "configuration": {"level": %d, "checksum": false}}]|} l) + ~msg:"zstd codec is unsupported or has invalid configuration.") + [50; -500_000]; + (* test incorrect configuration *) + decode_chain + ~shape:[||] + ~str:{|[{"name": "bytes", "configuration": {"endian": "little"}}, + {"name": "zstd", "configuration": {"something": -1}}]|} + ~msg:"zstd codec is unsupported or has invalid configuration."; + + (* test correct deserialization of zstd compression level *) + let shape = [|10; 15; 10|] in + List.iter + (fun level -> + let str = + Format.sprintf + {|[{"name": "bytes", "configuration": {"endian": "little"}}, + {"name": "zstd", "configuration": {"level": %d, "checksum": false}}]|} level + in + let r = Chain.of_yojson shape @@ Yojson.Safe.from_string str in + assert_bool "Encoding this chain should not fail" @@ Result.is_ok r) + [-131072; 0]; + + (* test encoding/decoding for various compression levels *) + let kind = Ndarray.Int in + let fill_value = Int.max_int in + let arr = Ndarray.create kind shape fill_value in + let chain = [`Bytes LE] in + List.iter + (fun (level, checksum) -> + let c = Chain.create shape @@ chain @ [`Zstd (level, checksum)] in + let encoded = Chain.encode c arr in + assert_equal arr @@ Chain.decode c {shape; kind} encoded) + [(-131072, false); (-131072, true); (0, false); (0, true)]) +; + "test bytes codec" >:: (fun _ -> let shape = [|2; 2; 2|] in (* test decoding of chain with invalid endianness name *)