Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Matt711 committed Aug 10, 2024
1 parent 9ec34ad commit 68d8007
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 20 deletions.
12 changes: 8 additions & 4 deletions python/cudf/cudf/core/column_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import itertools
import sys
import warnings
from collections import abc
from functools import cached_property, reduce
from typing import TYPE_CHECKING, Any, Callable, Mapping
Expand Down Expand Up @@ -606,7 +607,7 @@ def _pad_key(self, key: Any, pad_value="") -> Any:
return key + (pad_value,) * (self.nlevels - len(key))

def rename_levels(
self, mapper: Mapping[Any, Any] | Callable, level: int | None
self, mapper: Mapping[Any, Any] | Callable, level: int | None = None
) -> ColumnAccessor:
"""
Rename the specified levels of the given ColumnAccessor
Expand Down Expand Up @@ -649,10 +650,13 @@ def rename_column(x):
return x

if level is None:
raise NotImplementedError(
"Renaming columns with a MultiIndex and level=None is"
"not supported"
warnings.warn(
"Renaming columns with MultiIndex assuming level=0. "
"Specify the level keyword argument to rename using "
"a different level eg. df.rename(..., level=1)",
UserWarning,
)
level = 0
new_col_names = (rename_column(k) for k in self.keys())

else:
Expand Down
39 changes: 23 additions & 16 deletions python/cudf/cudf/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ def _groupby(self):
)

@_performance_tracking
def agg(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
def agg(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
"""
Apply aggregation(s) to the groups.
Expand Down Expand Up @@ -648,11 +648,10 @@ def agg(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
raise NotImplementedError(
"Passing args to func is currently not supported."
)
if kwargs:
raise NotImplementedError(
"Passing kwargs to func is currently not supported."
)
column_names, columns, normalized_aggs = self._normalize_aggs(func)

column_names, columns, normalized_aggs = self._normalize_aggs(
func, **kwargs
)
orig_dtypes = tuple(c.dtype for c in columns)

# Note: When there are no key columns, the below produces
Expand Down Expand Up @@ -1266,11 +1265,11 @@ def _grouped(self, *, include_groups: bool = True):
return (group_names, offsets, grouped_keys, grouped_values)

def _normalize_aggs(
self, aggs: MultiColumnAggType
self, aggs: MultiColumnAggType, **kwargs
) -> tuple[Iterable[Any], tuple[ColumnBase, ...], list[list[AggType]]]:
"""
Normalize aggs to a list of list of aggregations, where `out[i]`
is a list of aggregations for column `self.obj[i]`. We support three
is a list of aggregations for column `self.obj[i]`. We support four
different form of `aggs` input here:
- A single agg, such as "sum". This agg is applied to all value
columns.
Expand All @@ -1279,18 +1278,26 @@ def _normalize_aggs(
- A mapping of column name to aggs, such as
{"a": ["sum"], "b": ["mean"]}, the aggs are applied to specified
column.
- Pairs of column name and agg tuples passed as kwargs
eg. col1=("a", "sum"), col2=("b", "prod"). The output column names are
the keys. The aggs are applied to the corresponding column in the tuple.
Each agg can be string or lambda functions.
"""

aggs_per_column: Iterable[AggType | Iterable[AggType]]
if isinstance(aggs, dict):
column_names, aggs_per_column = aggs.keys(), aggs.values()
columns = tuple(self.obj._data[col] for col in column_names)
else:
values = self.grouping.values
column_names = values._column_names
columns = values._columns
aggs_per_column = (aggs,) * len(columns)
if aggs:
if isinstance(aggs, dict):
column_names, aggs_per_column = aggs.keys(), aggs.values()
columns = tuple(self.obj._data[col] for col in column_names)
else:
values = self.grouping.values
column_names = values._column_names
columns = values._columns
aggs_per_column = (aggs,) * len(columns)
elif not aggs and kwargs:
column_names, aggs_per_column = kwargs.keys(), kwargs.values()
columns = tuple(self.obj._data[x[1][0]] for x in kwargs.items())
aggs_per_column = [x[1] for x in kwargs.values()]

# is_list_like performs type narrowing but type-checkers don't
# know it. One could add a TypeGuard annotation to
Expand Down
30 changes: 30 additions & 0 deletions python/cudf/cudf/tests/groupby/test_agg.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
import numpy as np
import pandas as pd
import pytest

import cudf
Expand All @@ -26,3 +27,32 @@ def test_series_agg(attr):
pd_agg = getattr(pdf.groupby(["a"])["a"], attr)("count")

assert agg.ndim == pd_agg.ndim


@pytest.mark.parametrize("func", ["sum", "prod", "mean", "count"])
@pytest.mark.parametrize("attr", ["agg", "aggregate"])
def test_dataframe_agg(attr, func):
    """Compare cudf and pandas DataFrame groupby aggregation results.

    For each aggregation method name (``attr``) and aggregation (``func``),
    the same call is issued against a cudf groupby and a pandas groupby
    and the resulting frames are required to be equal. Four call forms
    are exercised, in order: a single aggregation name, a mapping of
    column name to aggregation, a list of aggregations, and keyword
    arguments of the form ``out_name=(column, aggregation)``.
    """
    gdf = cudf.DataFrame({"a": [1, 2, 1, 2], "b": [0, 0, 0, 0]})
    pdf = gdf.to_pandas()

    def check(*call_args, **call_kwargs):
        # Issue the identical aggregation call on both libraries and
        # compare after converting the cudf result to pandas.
        got = getattr(gdf.groupby("a"), attr)(*call_args, **call_kwargs)
        expected = getattr(pdf.groupby(["a"]), attr)(
            *call_args, **call_kwargs
        )
        pd.testing.assert_frame_equal(got.to_pandas(), expected)

    # Single aggregation applied to every value column.
    check(func)
    # Mapping of column name to aggregation.
    check({"b": func})
    # List of aggregations.
    check([func])
    # Keyword form: output column name -> (source column, aggregation).
    check(foo=("b", func), bar=("a", func))
4 changes: 4 additions & 0 deletions python/cudf/cudf/tests/test_column_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,10 @@ def test_replace_level_values_MultiColumn():
got = ca.rename_levels(mapper={"a": "f"}, level=0)
check_ca_equal(expect, got)

with pytest.raises(UserWarning):
got = ca.rename_levels(mapper={"a": "f"})
check_ca_equal(expect, got)


def test_clear_nrows_empty_before():
ca = ColumnAccessor({})
Expand Down

0 comments on commit 68d8007

Please sign in to comment.