From 96fbe653fb7b66e3bdd9bd78988eeb26ccf5c29d Mon Sep 17 00:00:00 2001 From: Zolisa Bleki Date: Mon, 1 Jul 2024 16:52:21 +0200 Subject: [PATCH] Add unit tests for `ArrayMetadata` module. This also improves the correctness of serializing/deserializing the ArrayMetadata type. It also removes unnecessary functions exposed in the module signature. --- lib/codecs/array_to_array.ml | 3 +- lib/codecs/array_to_array.mli | 3 + lib/codecs/array_to_bytes.ml | 21 ++++-- lib/codecs/array_to_bytes.mli | 3 + lib/codecs/bytes_to_bytes.ml | 2 + lib/codecs/bytes_to_bytes.mli | 3 + lib/codecs/codecs.ml | 4 ++ lib/codecs/codecs.mli | 4 ++ lib/metadata.ml | 30 +++++---- lib/metadata.mli | 49 +++++++------- lib/storage/interface.ml | 26 +++++--- test/dune | 6 +- test/test_metadata.ml | 121 +++++++++++++++++++++++++++++++++- test/test_zarr.ml | 2 +- 14 files changed, 222 insertions(+), 55 deletions(-) diff --git a/lib/codecs/array_to_array.ml b/lib/codecs/array_to_array.ml index 5962880..9a27635 100644 --- a/lib/codecs/array_to_array.ml +++ b/lib/codecs/array_to_array.ml @@ -1,9 +1,10 @@ module Ndarray = Owl.Dense.Ndarray.Generic -type dimension_order = int array +type dimension_order = int array [@@deriving show] type array_to_array = | Transpose of dimension_order + [@@deriving show] type error = [ `Invalid_transpose_order of dimension_order * string ] diff --git a/lib/codecs/array_to_array.mli b/lib/codecs/array_to_array.mli index b7c395d..2d93127 100644 --- a/lib/codecs/array_to_array.mli +++ b/lib/codecs/array_to_array.mli @@ -8,6 +8,9 @@ type array_to_array = type error = [ `Invalid_transpose_order of dimension_order * string ] +val pp_array_to_array : Format.formatter -> array_to_array -> unit +val show_array_to_array : array_to_array -> string + module ArrayToArray : sig val parse : ('a, 'b) Util.array_repr -> diff --git a/lib/codecs/array_to_bytes.ml b/lib/codecs/array_to_bytes.ml index 9e9d398..9f6ab90 100644 --- a/lib/codecs/array_to_bytes.ml +++ b/lib/codecs/array_to_bytes.ml 
@@ -4,13 +4,20 @@ open Util.Result_syntax module Ndarray = Owl.Dense.Ndarray.Generic -type endianness = Little | Big +type endianness = + | Little + | Big + [@@deriving show] -type loc = Start | End +type loc = + | Start + | End + [@@deriving show] type array_to_bytes = | Bytes of endianness | ShardingIndexed of shard_config + [@@deriving show] and shard_config = {chunk_shape : int array @@ -18,11 +25,11 @@ and shard_config = ;index_codecs : chain ;index_location : loc} -and chain = { - a2a: array_to_array list; - a2b: array_to_bytes; - b2b: bytes_to_bytes list; -} +and chain = + {a2a: array_to_array list + ;a2b: array_to_bytes + ;b2b: bytes_to_bytes list} + [@@deriving show] type error = [ `Bytes_encode_error of string diff --git a/lib/codecs/array_to_bytes.mli b/lib/codecs/array_to_bytes.mli index 67b7962..c86749c 100644 --- a/lib/codecs/array_to_bytes.mli +++ b/lib/codecs/array_to_bytes.mli @@ -20,6 +20,9 @@ and chain = { b2b: Bytes_to_bytes.bytes_to_bytes list; } +val pp_chain : Format.formatter -> chain -> unit +val show_chain : chain -> string + type error = [ `Bytes_encode_error of string | `Bytes_decode_error of string diff --git a/lib/codecs/bytes_to_bytes.ml b/lib/codecs/bytes_to_bytes.ml index 67bfdd5..8ce1b41 100644 --- a/lib/codecs/bytes_to_bytes.ml +++ b/lib/codecs/bytes_to_bytes.ml @@ -2,10 +2,12 @@ module Ndarray = Owl.Dense.Ndarray.Generic type compression_level = | L0 | L1 | L2 | L3 | L4 | L5 | L6 | L7 | L8 | L9 + [@@deriving show] type bytes_to_bytes = | Crc32c | Gzip of compression_level + [@@deriving show] type error = [ `Gzip of Ezgzip.error ] diff --git a/lib/codecs/bytes_to_bytes.mli b/lib/codecs/bytes_to_bytes.mli index fa5a0f5..bb9e568 100644 --- a/lib/codecs/bytes_to_bytes.mli +++ b/lib/codecs/bytes_to_bytes.mli @@ -10,6 +10,9 @@ type bytes_to_bytes = type error = [ `Gzip of Ezgzip.error ] +val pp_bytes_to_bytes : Format.formatter -> bytes_to_bytes -> unit +val show_bytes_to_bytes : bytes_to_bytes -> string + module BytesToBytes : sig val 
compute_encoded_size : int -> bytes_to_bytes -> int val encode : bytes_to_bytes -> string -> (string, [> error]) result diff --git a/lib/codecs/codecs.ml b/lib/codecs/codecs.ml index 5fcc5d4..6aac8f7 100644 --- a/lib/codecs/codecs.ml +++ b/lib/codecs/codecs.ml @@ -8,6 +8,10 @@ module Ndarray = Owl.Dense.Ndarray.Generic module Chain = struct type t = chain + let pp = pp_chain + + let show = show_chain + let create repr {a2a; a2b; b2b} = List.fold_left (fun acc c -> diff --git a/lib/codecs/codecs.mli b/lib/codecs/codecs.mli index bdc30ca..d59bd64 100644 --- a/lib/codecs/codecs.mli +++ b/lib/codecs/codecs.mli @@ -58,4 +58,8 @@ module Chain : sig val of_yojson : Yojson.Safe.t -> (t, string) result val to_yojson : t -> Yojson.Safe.t + + val pp : Format.formatter -> t -> unit + + val show : t -> string end diff --git a/lib/metadata.ml b/lib/metadata.ml index 08512e1..a24f2c8 100644 --- a/lib/metadata.ml +++ b/lib/metadata.ml @@ -15,6 +15,8 @@ module FillValue = struct | BFComplex of Complex.t | BBComplex of Complex.t + let equal x y = x = y + let of_kind : type a b. 
(a, b) Bigarray.kind -> a -> t = fun kind a -> @@ -109,14 +111,16 @@ module ArrayMetadata = struct ;fill_value : FillValue.t ;chunk_grid : Extensions.RegularGrid.t ;chunk_key_encoding : Extensions.ChunkKeyEncoding.t - ;attributes : Yojson.Safe.t option [@yojson.option] - ;dimension_names : string option list option [@yojson.option] - ;storage_transformers : Yojson.Safe.t Util.ext_point list option [@yojson.option]} - [@@deriving yojson] + ;attributes : Yojson.Safe.t [@default `Null] + ;dimension_names : string option list [@default []] + ;storage_transformers : Yojson.Safe.t Util.ExtPoint.t list [@default []]} + [@@deriving yojson, eq] let create ?(sep=Extensions.Slash) ?(codecs=Codecs.Chain.default) + ?(dimension_names=[]) + ?(attributes=`Null) ~shape kind fv @@ -130,23 +134,23 @@ module ArrayMetadata = struct ;chunk_key_encoding = Extensions.ChunkKeyEncoding.create sep ;zarr_format = 3 ;node_type = "array" - ;storage_transformers = None - ;dimension_names = None - ;attributes = None} + ;attributes + ;dimension_names + ;storage_transformers = []} let shape t = t.shape let codecs t = t.codecs - let dtype t = t.data_type - - let fill_value t = t.fill_value + let data_type t = + Yojson.Safe.to_string @@ + Extensions.Datatype.to_yojson t.data_type let ndim t = Array.length @@ shape t let dimension_names t = t.dimension_names - let attributes t : Yojson.Safe.t option = t.attributes + let attributes t = t.attributes let chunk_shape t = Extensions.RegularGrid.chunk_shape t.chunk_grid @@ -171,8 +175,8 @@ module ArrayMetadata = struct of_yojson @@ Yojson.Safe.from_string b >>? 
fun s -> `Json_decode_error s - let update_attributes attrs t = - {t with attributes = Some attrs} + let update_attributes t attrs = + {t with attributes = attrs} let update_shape t shape = {t with shape} diff --git a/lib/metadata.mli b/lib/metadata.mli index b0feb80..d3142c6 100644 --- a/lib/metadata.mli +++ b/lib/metadata.mli @@ -35,7 +35,9 @@ module ArrayMetadata : sig val create : ?sep:Extensions.separator -> - ?codecs:Codecs.Chain.t -> + ?codecs:Codecs.Chain.t -> + ?dimension_names:string option list -> + ?attributes:Yojson.Safe.t -> shape:int array -> ('a, 'b) Bigarray.kind -> 'a -> @@ -50,40 +52,36 @@ module ArrayMetadata : sig val decode : string -> (t, [> error]) result (** [decode s] decodes a bytes string [s] into a {!ArrayMetadata.t} - type, and returns an {!error} error if the decoding process fails. *) + type, and returns an error if the decoding process fails. *) val shape : t -> int array (** [shape t] returns the shape of the zarr array represented by metadata type [t]. *) - val fill_value : t -> FillValue.t - (** [fill_value t] returns the fill value of the zarra array represented by [t]. *) - val ndim : t -> int (** [ndim t] returns the number of dimension in a Zarr array. *) val chunk_shape : t -> int array (** [chunk_shape t] returns the shape a chunk in this zarr array. *) - val dtype : t -> Extensions.Datatype.t - (** [dtype t] returns the data type as specified in the array metadata. *) + val data_type : t -> string + (** [data_type t] returns the data type as specified in the array metadata.*) val is_valid_kind : t -> ('a, 'b) Bigarray.kind -> bool - (** [is_valid_kind t kind] checks if [kind] is a valid {!Bigarray.kind} that + (** [is_valid_kind t kind] checks if [kind] is a valid Bigarray kind that matches the data type of the zarr array represented by this metadata type. 
*) val fillvalue_of_kind : t -> ('a, 'b) Bigarray.kind -> 'a (** [fillvalue_of_kind t kind] returns the fill value of uninitialized - chunks in this zarr array given [kind]. - - @raises [Failure] if the kind is not compatible with this array's fill value. *) + chunks in this zarr array given [kind]. Raises Failure if the kind + is not compatible with this array's fill value. *) - val attributes : t -> Yojson.Safe.t option + val attributes : t -> Yojson.Safe.t (** [attributes t] Returns a Yojson type containing user attributes assigned to the zarr array represented by [t]. *) - val dimension_names : t -> string option list option - (** [dimension_name t] returns a list of dimension names, if any are - defined in the array's JSON metadata document. *) + val dimension_names : t -> string option list + (** [dimension_names t] returns a list of dimension names. If none are + defined then an empty list is returned. *) val codecs : t -> Codecs.Chain.t (** [codecs t] Returns a type representing the chain of codecs applied @@ -105,20 +103,25 @@ module ArrayMetadata : sig val chunk_key : t -> int array -> string (** [chunk_key t idx] returns a key encoding of a the chunk index [idx]. *) - val update_attributes : Yojson.Safe.t -> t -> t - (** [update_attributes json t] returns a new metadata type with an updated + val update_attributes : t -> Yojson.Safe.t -> t + (** [update_attributes t json] returns a new metadata type with an updated attribute field containing contents in [json] *) val update_shape : t -> int array -> t (** [update_shape t new_shp] returns a new metadata type containing shape [new_shp]. *) + val equal : t -> t -> bool + (** [equal a b] returns true if [a] and [b] are equal array metadata documents + and false otherwise. 
*) + val of_yojson : Yojson.Safe.t -> (t, string) result - (** [of_yojson json] converts a {!Yojson.Safe.t} object into a {!ArrayMetadata.t} + (** [of_yojson json] converts a [Yojson.Safe.t] object into a {!ArrayMetadata.t} and returns an error message upon failure. *) val to_yojson : t -> Yojson.Safe.t - (** [to_yojson t] serializes an array metadata type into a {!Yojson.Safe.t} object. *) + (** [to_yojson t] serializes an array metadata type into a [Yojson.Safe.t] + object. *) end module GroupMetadata : sig @@ -135,19 +138,19 @@ module GroupMetadata : sig (** [encode t] returns a byte string representing a JSON Zarr group metadata. *) val decode : string -> (t, [> error]) result - (** [decode s] decodes a bytes string [s] into a {!GroupMetadata.t} - type, and returns an {!Metadata.error} error if the decoding process fails. *) + (** [decode s] decodes a bytes string [s] into a {!t} type, and returns + an error if the decoding process fails. *) val update_attributes : t -> Yojson.Safe.t -> t (** [update_attributes t json] returns a new metadata type with an updated attribute field containing contents in [json]. *) val of_yojson : Yojson.Safe.t -> (t, string) result - (** [of_yojson json] converts a {!Yojson.Safe.t} object into a {!GroupMetadata.t} + (** [of_yojson json] converts a [Yojson.Safe.t] object into a {!GroupMetadata.t} and returns an error message upon failure. *) val to_yojson : t -> Yojson.Safe.t - (** [to_yojson t] serializes a group metadata type into a {!Yojson.Safe.t} object. *) + (** [to_yojson t] serializes a group metadata type into a [Yojson.Safe.t] object. *) val show : t -> string (** [show t] pretty-prints the contents of the group metadata type t. 
*) diff --git a/lib/storage/interface.ml b/lib/storage/interface.ml index dde6ed1..267d191 100644 --- a/lib/storage/interface.ml +++ b/lib/storage/interface.ml @@ -38,6 +38,8 @@ module type S = sig val create_array : ?sep:Extensions.separator -> + ?dimension_names:string option list -> + ?attributes:Yojson.Safe.t -> ?codecs:Codecs.chain -> shape:int array -> chunks:int array -> @@ -100,19 +102,27 @@ module Make (M : STORE) : S with type t = M.t = struct | Error _ -> create_group t n) @@ Node.ancestors node let create_array - ?(sep=Extensions.Slash) ?codecs ~shape ~chunks kind fill_value node t = + ?(sep=Extensions.Slash) + ?(dimension_names=[]) + ?(attributes=`Null) + ?codecs + ~shape + ~chunks + kind + fill_value + node + t + = let open Util in - let repr = - {kind - ;fill_value - ;shape = chunks} - in + let repr = {kind; fill_value; shape = chunks} in (match codecs with | Some c -> Codecs.Chain.create repr c | None -> Ok Codecs.Chain.default) - >>= fun codecs' -> + >>= fun codecs -> let meta = - AM.create ~sep ~codecs:codecs' ~shape kind fill_value chunks in + AM.create + ~sep ~codecs ~dimension_names ~attributes ~shape kind fill_value chunks + in set t (Node.to_metakey node) (AM.encode meta); Ok (make_implicit_groups_explicit t node) diff --git a/test/dune b/test/dune index 75a5792..62b6cb6 100644 --- a/test/dune +++ b/test/dune @@ -1,3 +1,7 @@ (test (name test_zarr) - (libraries ounit2 zarr)) + (libraries + zarr + ounit2) + (preprocess + (pps ppx_deriving.show))) diff --git a/test/test_metadata.ml b/test/test_metadata.ml index 715a712..b74d75a 100644 --- a/test/test_metadata.ml +++ b/test/test_metadata.ml @@ -13,7 +13,7 @@ let group = [ | Ok v -> assert_equal ~printer:GroupMetadata.show meta v; | Error _ -> - assert_bool "Decoding well formed metadata should not fail" false); + assert_failure "Decoding well formed metadata should not fail"); assert_bool "" (Result.is_error @@ GroupMetadata.decode {|{"bad_json":0}|}); let meta' = @@ -25,3 +25,122 @@ let group = 
[ in assert_equal expected @@ GroupMetadata.encode meta') ] + +let array = [ + +"array metadata" >:: (fun _ -> + let shape = [|10; 10; 10|] in + let chunks = [|5; 2; 6|] in + let grid_shape = [|2; 5; 2|] in + let dimension_names = [Some "x"; None; Some "z"] in + + let meta = + ArrayMetadata.create + ~shape ~dimension_names Bigarray.Float32 32.0 chunks + in + (match ArrayMetadata.encode meta |> ArrayMetadata.decode with + | Ok v -> + assert_bool "should not fail" @@ ArrayMetadata.equal v meta; + | Error _ -> + assert_failure "Decoding well formed metadata should not fail"); + + assert_bool + "" (Result.is_error @@ ArrayMetadata.decode {|{"bad_json":0}|}); + + let show_int_array = [%show: int array] in + assert_equal + ~printer:show_int_array shape @@ ArrayMetadata.shape meta; + + assert_equal + ~printer:Codecs.Chain.show + Codecs.Chain.default @@ + ArrayMetadata.codecs meta; + + assert_equal + ~printer:string_of_int + (Array.length shape) + (ArrayMetadata.ndim meta); + + assert_equal + ~printer:show_int_array + chunks @@ + ArrayMetadata.chunk_shape meta; + + assert_equal + ~printer:show_int_array + grid_shape @@ + ArrayMetadata.grid_shape meta shape; + + let show_int_array_tuple = + [%show: int array * int array] + in + assert_equal + ~printer:show_int_array_tuple + ([|1; 3; 1|], [|3; 1; 0|]) @@ + ArrayMetadata.index_coord_pair meta [|8; 7; 6|]; + + assert_equal + ~printer:show_int_array_tuple + ([|2; 5; 1|], [|0; 0; 4|]) @@ + ArrayMetadata.index_coord_pair meta [|10; 10; 10|]; + + assert_equal + ~printer:Fun.id + "c/2/5/1" @@ + ArrayMetadata.chunk_key meta [|2; 5; 1|]; + + let indices = + [[|0; 0; 0|]; [|0; 0; 1|]; [|0; 1; 0|]; [|0; 1; 1|] + ;[|1; 0; 0|]; [|1; 0; 1|]; [|1; 1; 0|]; [|1; 1; 1|]] + in + assert_equal + ~printer:[%show: int array list] + indices @@ + ArrayMetadata.chunk_indices meta [|10; 4; 10|]; + + assert_equal + ~printer:Fun.id + {|"float32"|} @@ + ArrayMetadata.data_type meta; + + assert_equal + ~printer:[%show: string option list] + 
dimension_names @@ + ArrayMetadata.dimension_names meta; + + assert_equal + ~printer:Yojson.Safe.show + `Null @@ + ArrayMetadata.attributes meta; + + let attrs = `Assoc [("questions", `String "answer")] in + assert_equal + ~printer:Yojson.Safe.show + attrs + ArrayMetadata.(attributes @@ update_attributes meta attrs); + + let new_shape = [|20; 10; 6|] in + assert_equal + ~printer:show_int_array + new_shape @@ + ArrayMetadata.(shape @@ update_shape meta new_shape); + + assert_bool + "" @@ ArrayMetadata.is_valid_kind meta Bigarray.Float32; + + assert_bool + "Float32 is the only valid kind for this metadata" + (not @@ ArrayMetadata.is_valid_kind meta Bigarray.Int8_signed); + + assert_equal + ~printer:string_of_float + 32. @@ + ArrayMetadata.fillvalue_of_kind meta Bigarray.Float32; + + assert_raises + ~msg:"Wrong kind used to extract fill value." + (Failure "kind is not compatible with node's fill value.") + (fun () -> ArrayMetadata.fillvalue_of_kind meta Bigarray.Complex32)) +] + +let tests = group @ array diff --git a/test/test_zarr.ml b/test/test_zarr.ml index 3e65564..a63c48b 100644 --- a/test/test_zarr.ml +++ b/test/test_zarr.ml @@ -5,6 +5,6 @@ let () = let suite = "Run All tests" >::: Test_node.tests @ Test_indexing.tests @ - Test_metadata.group + Test_metadata.tests in run_test_tt_main suite