Skip to content

Commit

Permalink
[BUG] Fix bug with map_groups UDFs that return more than 1 output row…
Browse files Browse the repository at this point in the history
… for empty partitions (#2532)
  • Loading branch information
jaychia authored Jul 18, 2024
1 parent afcfecd commit 6835cff
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/daft-dsl/src/functions/python/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{Expr, ExprRef};
pub struct PythonUDF {
func: partial_udf::PartialUDF,
num_expressions: usize,
return_dtype: DataType,
pub return_dtype: DataType,
}

pub fn udf(
Expand Down
16 changes: 11 additions & 5 deletions src/daft-table/src/ops/agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,16 @@ impl Table {
.collect::<DaftResult<Vec<_>>>()?;

// Take fast path short circuit if there is only 1 group
let (groupkeys_table, grouped_col) = if groupvals_indices.len() <= 1 {
let (groupkeys_table, grouped_col) = if groupvals_indices.is_empty() {
let empty_groupkeys_table = Table::empty(Some(groupby_table.schema.clone()))?;
let empty_udf_output_col = Series::empty(
evaluated_inputs
.first()
.map_or_else(|| "output", |s| s.name()),
&udf.return_dtype,
);
(empty_groupkeys_table, empty_udf_output_col)
} else if groupvals_indices.len() == 1 {
let grouped_col = udf.call_udf(evaluated_inputs.as_slice())?;
let groupkeys_table = {
let indices_as_series = UInt64Array::from(("", groupkey_indices)).into_series();
Expand Down Expand Up @@ -156,10 +165,7 @@ impl Table {
};

// Broadcast either the keys or the grouped_cols, depending on which is unit-length
let final_len = [groupkeys_table.len(), grouped_col.len()]
.into_iter()
.find(|&l| l != 1)
.unwrap_or(1);
let final_len = grouped_col.len();
let final_columns = [&groupkeys_table.columns[..], &[grouped_col]].concat();
let final_schema = Schema::new(final_columns.iter().map(|s| s.field().clone()).collect())?;
Self::new_with_broadcast(final_schema, final_columns, final_len)
Expand Down
5 changes: 3 additions & 2 deletions tests/dataframe/test_map_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def udf(a, b):


@pytest.mark.parametrize("repartition_nparts", [1, 2, 3])
def test_map_groups_more_than_one_output_row(make_df, repartition_nparts):
@pytest.mark.parametrize("output_when_empty", [[], [1], [1, 2]])
def test_map_groups_more_than_one_output_row(make_df, repartition_nparts, output_when_empty):
daft_df = make_df(
{
"group": [1, 2],
Expand All @@ -50,7 +51,7 @@ def test_map_groups_more_than_one_output_row(make_df, repartition_nparts):
def udf(a):
a = a.to_pylist()
if len(a) == 0:
return []
return output_when_empty
return [a[0]] * 3

daft_df = daft_df.groupby("group").map_groups(udf(daft_df["a"])).sort("group", desc=False)
Expand Down

0 comments on commit 6835cff

Please sign in to comment.