From a99a601f702256f727aed6499018afdfcac5be72 Mon Sep 17 00:00:00 2001 From: Zolisa Bleki Date: Tue, 17 Sep 2024 00:53:00 +0200 Subject: [PATCH] Add support for the `bool` data type. --- zarr/src/codecs/array_to_bytes.ml | 12 +++++++----- zarr/src/codecs/ebuffer.ml | 6 ++++++ zarr/src/codecs/ebuffer.mli | 2 ++ zarr/src/extensions.ml | 4 ++++ zarr/src/extensions.mli | 1 + zarr/src/metadata.ml | 3 +++ zarr/src/ndarray.ml | 2 ++ zarr/src/ndarray.mli | 1 + zarr/test/test_codecs.ml | 4 ++++ zarr/test/test_metadata.ml | 8 ++++++++ zarr/test/test_ndarray.ml | 2 ++ 11 files changed, 40 insertions(+), 5 deletions(-) diff --git a/zarr/src/codecs/array_to_bytes.ml b/zarr/src/codecs/array_to_bytes.ml index b05df74..f4f8131 100644 --- a/zarr/src/codecs/array_to_bytes.ml +++ b/zarr/src/codecs/array_to_bytes.ml @@ -18,11 +18,12 @@ module BytesCodec = struct let open (val endian_module t) in let buf = Bytes.create @@ Ndarray.byte_size x in match Ndarray.data_type x with - | Char-> Ndarray.iteri (set_char buf) x; Bytes.unsafe_to_string buf - | Uint8-> Ndarray.iteri (set_uint8 buf) x; Bytes.unsafe_to_string buf - | Int8-> Ndarray.iteri (set_int8 buf) x; Bytes.unsafe_to_string buf - | Int16-> Ndarray.iteri (set_int16 buf) x; Bytes.unsafe_to_string buf - | Uint16-> Ndarray.iteri (set_uint16 buf) x; Bytes.unsafe_to_string buf + | Char -> Ndarray.iteri (set_char buf) x; Bytes.unsafe_to_string buf + | Bool -> Ndarray.iteri (set_bool buf) x; Bytes.unsafe_to_string buf + | Uint8 -> Ndarray.iteri (set_uint8 buf) x; Bytes.unsafe_to_string buf + | Int8 -> Ndarray.iteri (set_int8 buf) x; Bytes.unsafe_to_string buf + | Int16 -> Ndarray.iteri (set_int16 buf) x; Bytes.unsafe_to_string buf + | Uint16 -> Ndarray.iteri (set_uint16 buf) x; Bytes.unsafe_to_string buf | Int32 -> Ndarray.iteri (set_int32 buf) x; Bytes.unsafe_to_string buf | Int64 -> Ndarray.iteri (set_int64 buf) x; Bytes.unsafe_to_string buf | Uint64 -> Ndarray.iteri (set_uint64 buf) x; Bytes.unsafe_to_string buf @@ -41,6 +42,7 @@ module BytesCodec = struct let buf = Bytes.unsafe_of_string str in match k, Ndarray.dtype_size k with | Char, _ -> Ndarray.init k shp @@ get_char buf + | Bool, _ -> Ndarray.init k shp @@ get_bool buf | Uint8, _ -> Ndarray.init k shp @@ get_int8 buf | Int8, _ -> Ndarray.init k shp @@ get_uint8 buf | Int16, s -> Ndarray.init k shp @@ fun i -> get_int16 buf (i*s) diff --git a/zarr/src/codecs/ebuffer.ml b/zarr/src/codecs/ebuffer.ml index 9e92176..7e2242b 100644 --- a/zarr/src/codecs/ebuffer.ml +++ b/zarr/src/codecs/ebuffer.ml @@ -1,5 +1,6 @@ module type S = sig val set_char : bytes -> int -> char -> unit + val set_bool : bytes -> int -> bool -> unit val set_int8 : bytes -> int -> int -> unit val set_uint8 : bytes -> int -> int -> unit val set_int16 : bytes -> int -> int -> unit @@ -15,6 +16,7 @@ module type S = sig val set_nativeint : bytes -> int -> nativeint -> unit val get_char : bytes -> int -> char + val get_bool : bytes -> int -> bool val get_int8 : bytes -> int -> int val get_uint8 : bytes -> int -> int val get_int16 : bytes -> int -> int @@ -34,6 +36,7 @@ module Little = struct let set_int8 = Bytes.set_int8 let set_uint8 = Bytes.set_uint8 let set_char buf i v = Char.code v |> set_uint8 buf i + let set_bool buf i v = Bool.to_int v |> set_uint8 buf i let set_int16 buf i v = Bytes.set_int16_le buf (2*i) v let set_uint16 buf i v = Bytes.set_uint16_le buf (2*i) v let set_int32 buf i v = Bytes.set_int32_le buf (4*i) v @@ -53,6 +56,7 @@ module Little = struct let get_int8 = Bytes.get_int8 let get_uint8 = Bytes.get_uint8 let get_char buf i = get_uint8 buf i |> Char.chr + let get_bool buf i = match get_uint8 buf i with | 0 -> false | _ -> true let get_int16 = Bytes.get_int16_le let get_uint16 = Bytes.get_uint16_le let get_int32 = Bytes.get_int32_le @@ -74,6 +78,7 @@ module Big = struct let set_int8 = Bytes.set_int8 let set_uint8 = Bytes.set_uint8 let set_char buf i v = Char.code v |> set_uint8 buf i + let set_bool buf i v = Bool.to_int v |> set_uint8 buf i let set_int16 buf i v = Bytes.set_int16_be buf (i * 2) v let set_uint16 buf i v = Bytes.set_uint16_be buf (i * 2) v let set_int32 buf i v = Bytes.set_int32_be buf (i * 4) v @@ -93,6 +98,7 @@ module Big = struct let get_int8 = Bytes.get_int8 let get_uint8 = Bytes.get_uint8 let get_char buf i = get_uint8 buf i |> Char.chr + let get_bool buf i = match get_uint8 buf i with | 0 -> false | _ -> true let get_int16 = Bytes.get_int16_be let get_uint16 = Bytes.get_uint16_be let get_int32 = Bytes.get_int32_be diff --git a/zarr/src/codecs/ebuffer.mli b/zarr/src/codecs/ebuffer.mli index e89d4dd..65cdd77 100644 --- a/zarr/src/codecs/ebuffer.mli +++ b/zarr/src/codecs/ebuffer.mli @@ -1,5 +1,6 @@ module type S = sig val set_char : bytes -> int -> char -> unit + val set_bool : bytes -> int -> bool -> unit val set_int8 : bytes -> int -> int -> unit val set_uint8 : bytes -> int -> int -> unit val set_int16 : bytes -> int -> int -> unit @@ -15,6 +16,7 @@ module type S = sig val set_nativeint : bytes -> int -> nativeint -> unit val get_char : bytes -> int -> char + val get_bool : bytes -> int -> bool val get_int8 : bytes -> int -> int val get_uint8 : bytes -> int -> int val get_int16 : bytes -> int -> int diff --git a/zarr/src/extensions.ml b/zarr/src/extensions.ml index 052aecc..6b94877 100644 --- a/zarr/src/extensions.ml +++ b/zarr/src/extensions.ml @@ -105,6 +105,7 @@ end module Datatype = struct type t = | Char + | Bool | Int8 | Uint8 | Int16 @@ -123,6 +124,7 @@ module Datatype = struct let of_kind : type a. a Ndarray.dtype -> t = function | Ndarray.Char -> Char + | Ndarray.Bool -> Bool | Ndarray.Int8 -> Int8 | Ndarray.Uint8 -> Uint8 | Ndarray.Int16 -> Int16 @@ -139,6 +141,7 @@ module Datatype = struct let to_yojson = function | Char -> `String "char" + | Bool -> `String "bool" | Int8 -> `String "int8" | Uint8 -> `String "uint8" | Int16 -> `String "int16" @@ -155,6 +158,7 @@ module Datatype = struct let of_yojson = function | `String "char" -> Ok Char + | `String "bool" -> Ok Bool | `String "int8" -> Ok Int8 | `String "uint8" -> Ok Uint8 | `String "int16" -> Ok Int16 diff --git a/zarr/src/extensions.mli b/zarr/src/extensions.mli index a8b4be7..390727f 100644 --- a/zarr/src/extensions.mli +++ b/zarr/src/extensions.mli @@ -23,6 +23,7 @@ module Datatype : sig type t = | Char + | Bool | Int8 | Uint8 | Int16 diff --git a/zarr/src/metadata.ml b/zarr/src/metadata.ml index 8e9463f..7f23b08 100644 --- a/zarr/src/metadata.ml +++ b/zarr/src/metadata.ml @@ -23,6 +23,7 @@ module FillValue = struct = fun kind a -> match kind with | Ndarray.Char -> Char a + | Ndarray.Bool -> Bool a | Ndarray.Int8 -> Int (Stdint.Uint64.of_int a) | Ndarray.Uint8 -> Int (Stdint.Uint64.of_int a) | Ndarray.Int16 -> Int (Stdint.Uint64.of_int a) @@ -299,6 +300,7 @@ module Array = struct = fun t kind -> match kind, t.data_type with | Ndarray.Char, Datatype.Char + | Ndarray.Bool, Datatype.Bool | Ndarray.Int8, Datatype.Int8 | Ndarray.Uint8, Datatype.Uint8 | Ndarray.Int16, Datatype.Int16 @@ -319,6 +321,7 @@ module Array = struct = fun t kind -> match kind, t.fill_value with | Ndarray.Char, FillValue.Char c -> c + | Ndarray.Bool, FillValue.Bool b -> b | Ndarray.Int8, FillValue.Int i -> Stdint.Uint64.to_int i | Ndarray.Uint8, FillValue.Int i -> Stdint.Uint64.to_int i | Ndarray.Int16, FillValue.Int i -> Stdint.Uint64.to_int i diff --git a/zarr/src/ndarray.ml b/zarr/src/ndarray.ml index 894c3c8..8fbbefb 100644 --- a/zarr/src/ndarray.ml +++ b/zarr/src/ndarray.ml @@ -1,5 +1,6 @@ type _ dtype = | Char : char dtype + | Bool : bool dtype | Int8 : int dtype | Uint8 : int dtype | Int16 : int dtype @@ -22,6 +23,7 @@ type 'a t = let dtype_size : type a. a dtype -> int = function | Char -> 1 + | Bool -> 1 | Int8 -> 1 | Uint8 -> 1 | Int16 -> 2 diff --git a/zarr/src/ndarray.mli b/zarr/src/ndarray.mli index e8d1f5a..024541b 100644 --- a/zarr/src/ndarray.mli +++ b/zarr/src/ndarray.mli @@ -1,6 +1,7 @@ (** Supported data types for a Zarr array. *) type _ dtype = | Char : char dtype + | Bool : bool dtype | Int8 : int dtype | Uint8 : int dtype | Int16 : int dtype diff --git a/zarr/test/test_codecs.ml b/zarr/test/test_codecs.ml index 321dfce..47434a8 100644 --- a/zarr/test/test_codecs.ml +++ b/zarr/test/test_codecs.ml @@ -381,6 +381,10 @@ let tests = [ (* test encoding/decoding of Char *) bytes_encode_decode {shape; kind = Ndarray.Char} '?'; + (* test encoding/decoding of Bool *) + bytes_encode_decode {shape; kind = Ndarray.Bool} false; + bytes_encode_decode {shape; kind = Ndarray.Bool} true; + (* test encoding/decoding of int8 *) bytes_encode_decode {shape; kind = Ndarray.Int8} 0; diff --git a/zarr/test/test_metadata.ml b/zarr/test/test_metadata.ml index 1da55b9..e84677b 100644 --- a/zarr/test/test_metadata.ml +++ b/zarr/test/test_metadata.ml @@ -586,6 +586,14 @@ let array = [ let chunks = [|5; 2; 6|] in let dimension_names = [Some "x"; None; Some "z"] in + (* tests using bool data type. *) + test_array_metadata + ~shape + ~chunks + Ndarray.Bool + Ndarray.Float32 + false; + (* tests using char data type. *) test_array_metadata ~shape diff --git a/zarr/test/test_ndarray.ml b/zarr/test/test_ndarray.ml index 238d191..b6b5bc4 100644 --- a/zarr/test/test_ndarray.ml +++ b/zarr/test/test_ndarray.ml @@ -27,6 +27,8 @@ let tests = [ let shape = [|2; 5; 3|] in run_test {shape; kind = M.Char} '?' 1; + run_test {shape; kind = M.Bool} false 1; + run_test {shape; kind = M.Int8} 0 1; run_test {shape; kind = M.Uint8} 0 1;