Skip to content

Commit

Permalink
fix: support axis keyword in array_contains (#152)
Browse files Browse the repository at this point in the history
* support axis keyword in array_contains

* add testcase back in

* add test for reduce

* add assert to test
  • Loading branch information
LukeWeidenwalker authored Aug 18, 2023
1 parent 7490bc3 commit c4aacb2
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 6 deletions.
11 changes: 5 additions & 6 deletions openeo_processes_dask/process_implementations/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def array_concat(array1: ArrayLike, array2: ArrayLike) -> ArrayLike:
return concat


def array_contains(data: ArrayLike, value: Any) -> bool:
def array_contains(data: ArrayLike, value: Any, axis=None) -> bool:
# TODO: Contrary to the process spec, our implementation does interpret temporal strings before checking them here
# This is somewhat implicit in how we currently parse parameters, so cannot be easily changed.

Expand All @@ -135,15 +135,14 @@ def array_contains(data: ArrayLike, value: Any) -> bool:
for dtype in valid_dtypes:
if np.issubdtype(type(value), dtype):
value_is_valid = True
if not value_is_valid:
if len(np.shape(data)) != 1 and axis is None:
return False

if len(np.shape(data)) != 1:
if not value_is_valid:
return False
if pd.isnull(value):
return np.isnan(data).any()
return np.isnan(data).any(axis=axis)
else:
return np.isin(data, value).any()
return np.isin(data, value).any(axis=axis)


def array_find(
Expand Down
26 changes: 26 additions & 0 deletions tests/test_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,16 @@ def test_array_contains(data, value, expected):
assert dask_result == expected or dask_result.compute() == expected


def test_array_contains_axis():
    """Check that ``array_contains`` reduces along the axis it is given.

    The 2-D fixture contains the searched value in both rows and in the
    first two columns, so each axis yields a distinct boolean vector.
    """
    needle = 5
    grid = np.array([[4, 5, 6], [5, 7, 9]])

    # Expected membership of ``needle`` when reducing over each axis.
    expected_by_axis = {
        0: np.array([True, True, False]),
        1: np.array([True, True]),
    }

    for axis, expected in expected_by_axis.items():
        outcome = array_contains(grid, needle, axis=axis)
        np.testing.assert_array_equal(outcome, expected)


def test_array_contains_object_dtype():
assert not array_contains([{"a": "b"}, {"c": "d"}], {"a": "b"})
assert not array_contains(np.array([{"a": "b"}, {"c": "d"}]), {"a": "b"})
Expand Down Expand Up @@ -397,3 +407,19 @@ def test_reduce_dimension(
)
assert output_cube.dims == ("x", "y", "t")
xr.testing.assert_equal(output_cube, xr.ones_like(output_cube))

input_cube[0, 0, 0, 0] = 99999
_process = partial(
process_registry["array_contains"].implementation,
data=ParameterReference(from_parameter="data"),
value=99999,
)
output_cube = reduce_dimension(data=input_cube, reducer=_process, dimension="bands")
general_output_checks(
input_cube=input_cube,
output_cube=output_cube,
verify_attrs=False,
verify_crs=True,
)
assert output_cube[0, 0, 0].data.compute().item() is True
assert not output_cube[slice(1, None), :, :].data.compute().any()

0 comments on commit c4aacb2

Please sign in to comment.