diff --git a/lib/util.ml b/lib/util.ml index 29e92e0..9aea033 100644 --- a/lib/util.ml +++ b/lib/util.ml @@ -24,9 +24,6 @@ module ArraySet = Set.Make (ComparableArray) module Arraytbl = Hashtbl.Make (HashableArray) module Result_syntax = struct - let ( let* ) = Result.bind - let ( let+ ) = Result.map - let ( >>= ) = Result.bind let ( >>| ) x f = (* infix map *) @@ -38,11 +35,6 @@ module Result_syntax = struct match x with | Ok _ as k -> k | Error e -> Error (f e) - - let ( and+ ) x y = (* product *) - match x, y with - | Ok a, Ok b -> Ok (a, b) - | Error e, Ok _ | Ok _, Error e | Error e, Error _ -> Error e end module Indexing = struct @@ -71,8 +63,11 @@ module Indexing = struct | _ -> failwith "Invalid slice index." let reformat_slice slice shape = - Owl_slicing.check_slice_definition - (Owl_slicing.sdarray_to_sdarray slice) shape + match slice with + | [||] -> [||] + | xs -> + Owl_slicing.check_slice_definition + (Owl_slicing.sdarray_to_sdarray xs) shape let coords_of_slice slice shape = (Array.map indices_of_slice @@ diff --git a/lib/util.mli b/lib/util.mli index 2231731..465f496 100644 --- a/lib/util.mli +++ b/lib/util.mli @@ -19,18 +19,12 @@ module ArraySet : sig include Set.S with type elt = int array end module Result_syntax : sig (** Result monad operator syntax. *) - val ( let* ) - : ('a, 'e) result -> ('a -> ('b, 'e) result ) -> ('b, 'e) result val ( >>= ) : ('a, 'e) result -> ('a -> ('b, 'e) result ) -> ('b, 'e) result - val ( let+ ) - : ('a -> 'b) -> ('a, 'e) result -> ('b, 'e) result val ( >>| ) : ('a, 'e) result -> ('a -> 'b) -> ('b, 'e) result val ( >>? ) : ('a, 'e) result -> ('e -> 'f) -> ('a, 'f) result - val ( and+ ) - : ('a, 'e) result -> ('b, 'e) result -> (('a * 'b), 'e) result end module Indexing : sig diff --git a/test/test_indexing.ml b/test/test_indexing.ml new file mode 100644 index 0000000..6cc06e1 --- /dev/null +++ b/test/test_indexing.ml @@ -0,0 +1,55 @@ +open OUnit2 +open Zarr + + +let tests = [ + +"slice from coords" >:: (fun _ -> + let coords = + [[|0; 1; 2; 3|] + ;[|9; 8; 7; 6|] + ;[|5; 4; 3; 2|]] + in + let expected = + Owl_types.[| + L [0; 9; 5]; + L [1; 8; 4]; + L [2; 7; 3]; + L [3; 6; 2] + |] + in + assert_equal expected @@ Indexing.slice_of_coords coords; + assert_equal [||] @@ Indexing.slice_of_coords []) +; +"coords from slice" >:: (fun _ -> + let shape = [|10; 10; 11|] in + let slice = + Owl_types.[|L [0; 9; 5]; I 1; R [3; 10; 3]|] + in + let excepted = + [|[|0; 1; 3|]; [|0; 1; 6|]; [|0; 1; 9|] + ;[|9; 1; 3|]; [|9; 1; 6|]; [|9; 1; 9|] + ;[|5; 1; 3|]; [|5; 1; 6|]; [|5; 1; 9|]|] + in + assert_equal excepted @@ Indexing.coords_of_slice slice shape; + assert_equal [|[||]|] @@ Indexing.coords_of_slice [||] shape) +; +"compute slice shape" >:: (fun _ -> + let shape = [|10; 10; 10|] in + let slice = + Owl_types.[|L [0; 9; 5]; I 1; R [2; 9; 1]|] + in + assert_equal [|3; 1; 8|] @@ Indexing.slice_shape slice shape; + assert_equal [||] @@ Indexing.slice_shape [||] shape) +; +"cartesian product" >:: (fun _ -> + let ll = [[1; 2]; [3; 8]; [9; 4]] in + let expected = + [[1; 3; 9]; [1; 3; 4]; [1; 8; 9]; [1; 8; 4] + ;[2; 3; 9]; [2; 3; 4]; [2; 8; 9]; [2; 8; 4]] + in + assert_equal expected @@ Indexing.cartesian_prod ll; + let ll = [['a'; 'b']; ['z'; 'o']] in + let expected = [['a'; 'z']; ['a'; 'o']; ['b'; 'z']; ['b'; 'o']] in + assert_equal expected @@ Indexing.cartesian_prod ll) +] diff --git a/test/test_zarr.ml b/test/test_zarr.ml index d939747..f287238 100644 --- a/test/test_zarr.ml +++ b/test/test_zarr.ml @@ -3,6 +3,7 @@ open OUnit2 let () = let suite = "Run All tests" >::: - Test_node.tests + Test_node.tests @ + Test_indexing.tests in run_test_tt_main suite