From c4aacb247a2914ff3dd979db7e11e5d74267a6f9 Mon Sep 17 00:00:00 2001 From: Lukas Weidenholzer <17790923+LukeWeidenwalker@users.noreply.github.com> Date: Fri, 18 Aug 2023 12:04:49 +0200 Subject: [PATCH] fix: support axis keyword in array_contains (#152) * support axis keyword in array_contains * add testcase back in * add test for reduce * add assert to test --- .../process_implementations/arrays.py | 11 ++++---- tests/test_arrays.py | 26 +++++++++++++++++++ 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/openeo_processes_dask/process_implementations/arrays.py b/openeo_processes_dask/process_implementations/arrays.py index 58734514..3f8876e4 100644 --- a/openeo_processes_dask/process_implementations/arrays.py +++ b/openeo_processes_dask/process_implementations/arrays.py @@ -126,7 +126,7 @@ def array_concat(array1: ArrayLike, array2: ArrayLike) -> ArrayLike: return concat -def array_contains(data: ArrayLike, value: Any) -> bool: +def array_contains(data: ArrayLike, value: Any, axis=None) -> bool: # TODO: Contrary to the process spec, our implementation does interpret temporal strings before checking them here # This is somewhat implicit in how we currently parse parameters, so cannot be easily changed. @@ -135,15 +135,14 @@ def array_contains(data: ArrayLike, value: Any) -> bool: for dtype in valid_dtypes: if np.issubdtype(type(value), dtype): value_is_valid = True - if not value_is_valid: + if len(np.shape(data)) != 1 and axis is None: return False - - if len(np.shape(data)) != 1: + if not value_is_valid: return False if pd.isnull(value): - return np.isnan(data).any() + return np.isnan(data).any(axis=axis) else: - return np.isin(data, value).any() + return np.isin(data, value).any(axis=axis) def array_find( diff --git a/tests/test_arrays.py b/tests/test_arrays.py index 439572d1..e052547f 100644 --- a/tests/test_arrays.py +++ b/tests/test_arrays.py @@ -159,6 +159,16 @@ def test_array_contains(data, value, expected): assert dask_result == expected or dask_result.compute() == expected +def test_array_contains_axis(): + data = np.array([[4, 5, 6], [5, 7, 9]]) + + result_0 = array_contains(data, 5, axis=0) + np.testing.assert_array_equal(result_0, np.array([True, True, False])) + + result_1 = array_contains(data, 5, axis=1) + np.testing.assert_array_equal(result_1, np.array([True, True])) + + def test_array_contains_object_dtype(): assert not array_contains([{"a": "b"}, {"c": "d"}], {"a": "b"}) assert not array_contains(np.array([{"a": "b"}, {"c": "d"}]), {"a": "b"}) @@ -397,3 +407,19 @@ def test_reduce_dimension( ) assert output_cube.dims == ("x", "y", "t") xr.testing.assert_equal(output_cube, xr.ones_like(output_cube)) + + input_cube[0, 0, 0, 0] = 99999 + _process = partial( + process_registry["array_contains"].implementation, + data=ParameterReference(from_parameter="data"), + value=99999, + ) + output_cube = reduce_dimension(data=input_cube, reducer=_process, dimension="bands") + general_output_checks( + input_cube=input_cube, + output_cube=output_cube, + verify_attrs=False, + verify_crs=True, + ) + assert output_cube[0, 0, 0].data.compute().item() is True + assert not output_cube[slice(1, None), :, :].data.compute().any()