Skip to content

Commit

Permalink
Better alignment check + error
Browse files Browse the repository at this point in the history
xref #191
  • Loading branch information
dcherian committed Nov 27, 2022
1 parent 27a4e9a commit 9b8e27a
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 3 deletions.
8 changes: 5 additions & 3 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1465,9 +1465,11 @@ def _assert_by_is_aligned(shape, by):
for idx, b in enumerate(by):
if not all(j in [i, 1] for i, j in zip(shape[-b.ndim :], b.shape)):
raise ValueError(
"`array` and `by` arrays must be aligned "
"i.e. array.shape[-by.ndim :] == by.shape. "
"for every array in `by`."
"`array` and `by` arrays must be 'aligned' "
"so that such that by_ is broadcastable to array.shape[-by.ndim:] "
"for every array `by_` in `by`. "
"Either array.shape[-by_.ndim :] == by_.shape or the only differences "
"should be size-1 dimensions in by_."
f"Received array of shape {shape} but "
f"array {idx} in `by` has shape {b.shape}."
)
Expand Down
8 changes: 8 additions & 0 deletions flox/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,14 @@ def xarray_reduce(

# broadcast to make sure grouper dimensions are present in the array.
exclude_dims = tuple(d for d in ds.dims if d not in grouper_dims and d not in dim_tuple)

try:
xr.align(ds, *by_da, join="exact")
except ValueError as e:
raise ValueError(
"Object being grouped must be exactly aligned with every array in `by`."
) from e

ds_broad = xr.broadcast(ds, *by_da, exclude=exclude_dims)[0]

if any(d not in grouper_dims and d not in obj.dims for d in dim_tuple):
Expand Down
6 changes: 6 additions & 0 deletions tests/test_xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,12 @@ def test_mixed_grouping(chunk):
assert (r.sel(v1=[3, 4, 5]) == 0).all().data


def test_alignment_error():
da = xr.DataArray(np.arange(10), dims="x", coords={"x": np.arange(10)})
with pytest.raises(ValueError):
xarray_reduce(da, da.x.sel(x=slice(5)), func="count")


@pytest.mark.parametrize("add_nan", [True, False])
@pytest.mark.parametrize("dtype_out", [np.float64, "float64", np.dtype("float64")])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
Expand Down

0 comments on commit 9b8e27a

Please sign in to comment.