Skip to content

Commit

Permalink
Try cleaning up some expected_groups logic (#175)
Browse files (browse the repository at this point in the history)
* Try cleaning up some expected_groups logic

* Fix _extract_unknown_groups

* Fixes
  • Loading branch information
dcherian authored Nov 27, 2022
1 parent a70c5dd commit 27a4e9a
Showing 1 changed file with 8 additions and 15 deletions.
23 changes: 8 additions & 15 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1164,7 +1164,7 @@ def subset_to_blocks(
return dask.array.Array(graph, name, chunks, meta=array)


def _extract_unknown_groups(reduced, group_chunks, dtype) -> tuple[DaskArray]:
def _extract_unknown_groups(reduced, dtype) -> tuple[DaskArray]:
import dask.array
from dask.highlevelgraph import HighLevelGraph

Expand All @@ -1180,7 +1180,7 @@ def _extract_unknown_groups(reduced, group_chunks, dtype) -> tuple[DaskArray]:
dask.array.Array(
HighLevelGraph.from_collections(groups_token, layer, dependencies=[reduced]),
groups_token,
chunks=group_chunks,
chunks=((np.nan,),),
meta=np.array([], dtype=dtype),
),
)
Expand Down Expand Up @@ -1293,14 +1293,7 @@ def dask_groupby_agg(
name=f"{name}-chunk-{token}",
)

if expected_groups is None:
if is_duck_dask_array(by_input):
expected_groups = None
else:
expected_groups = _get_expected_groups(by_input, sort=sort)
group_chunks: tuple[tuple[Union[int, float], ...]] = (
(len(expected_groups),) if expected_groups is not None else (np.nan,),
)
group_chunks: tuple[tuple[Union[int, float], ...]]

if method in ["map-reduce", "cohorts"]:
combine: Callable[..., IntermediateDict]
Expand Down Expand Up @@ -1333,13 +1326,13 @@ def dask_groupby_agg(
aggregate=partial(aggregate, expected_groups=expected_groups, reindex=reindex),
)
if is_duck_dask_array(by_input) and expected_groups is None:
groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype)
groups = _extract_unknown_groups(reduced, dtype=by.dtype)
group_chunks = ((np.nan,),)
else:
if expected_groups is None:
expected_groups_ = _get_expected_groups(by_input, sort=sort)
else:
expected_groups_ = expected_groups
groups = (expected_groups_.to_numpy(),)
expected_groups = _get_expected_groups(by_input, sort=sort)
groups = (expected_groups.to_numpy(),)
group_chunks = ((len(expected_groups),),)

elif method == "cohorts":
chunks_cohorts = find_group_cohorts(
Expand Down

0 comments on commit 27a4e9a

Please sign in to comment.