From 6835cff92cbad32f4116e09aa17813101bd7e48c Mon Sep 17 00:00:00 2001 From: Jay Chia <17691182+jaychia@users.noreply.github.com> Date: Wed, 17 Jul 2024 19:17:56 -0700 Subject: [PATCH] [BUG] Fix bug with map_groups UDFs that return more than 1 output row for empty partitions (#2532) --- src/daft-dsl/src/functions/python/mod.rs | 2 +- src/daft-table/src/ops/agg.rs | 16 +++++++++++----- tests/dataframe/test_map_groups.py | 5 +++-- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/daft-dsl/src/functions/python/mod.rs b/src/daft-dsl/src/functions/python/mod.rs index b0b510ba25..71f00a2037 100644 --- a/src/daft-dsl/src/functions/python/mod.rs +++ b/src/daft-dsl/src/functions/python/mod.rs @@ -11,7 +11,7 @@ use crate::{Expr, ExprRef}; pub struct PythonUDF { func: partial_udf::PartialUDF, num_expressions: usize, - return_dtype: DataType, + pub return_dtype: DataType, } pub fn udf( diff --git a/src/daft-table/src/ops/agg.rs b/src/daft-table/src/ops/agg.rs index bff7e07910..1ae0ff0ee5 100644 --- a/src/daft-table/src/ops/agg.rs +++ b/src/daft-table/src/ops/agg.rs @@ -96,7 +96,16 @@ impl Table { .collect::>>()?; // Take fast path short circuit if there is only 1 group - let (groupkeys_table, grouped_col) = if groupvals_indices.len() <= 1 { + let (groupkeys_table, grouped_col) = if groupvals_indices.is_empty() { + let empty_groupkeys_table = Table::empty(Some(groupby_table.schema.clone()))?; + let empty_udf_output_col = Series::empty( + evaluated_inputs + .first() + .map_or_else(|| "output", |s| s.name()), + &udf.return_dtype, + ); + (empty_groupkeys_table, empty_udf_output_col) + } else if groupvals_indices.len() == 1 { let grouped_col = udf.call_udf(evaluated_inputs.as_slice())?; let groupkeys_table = { let indices_as_series = UInt64Array::from(("", groupkey_indices)).into_series(); @@ -156,10 +165,7 @@ impl Table { }; // Broadcast either the keys or the grouped_cols, depending on which is unit-length - let final_len = [groupkeys_table.len(), grouped_col.len()] - .into_iter() - .find(|&l| l != 1) - .unwrap_or(1); + let final_len = grouped_col.len(); let final_columns = [&groupkeys_table.columns[..], &[grouped_col]].concat(); let final_schema = Schema::new(final_columns.iter().map(|s| s.field().clone()).collect())?; Self::new_with_broadcast(final_schema, final_columns, final_len) diff --git a/tests/dataframe/test_map_groups.py b/tests/dataframe/test_map_groups.py index 0bcedf8043..0d5fe7ef86 100644 --- a/tests/dataframe/test_map_groups.py +++ b/tests/dataframe/test_map_groups.py @@ -37,7 +37,8 @@ def udf(a, b): @pytest.mark.parametrize("repartition_nparts", [1, 2, 3]) -def test_map_groups_more_than_one_output_row(make_df, repartition_nparts): +@pytest.mark.parametrize("output_when_empty", [[], [1], [1, 2]]) +def test_map_groups_more_than_one_output_row(make_df, repartition_nparts, output_when_empty): daft_df = make_df( { "group": [1, 2], @@ -50,7 +51,7 @@ def test_map_groups_more_than_one_output_row(make_df, repartition_nparts): def udf(a): a = a.to_pylist() if len(a) == 0: - return [] + return output_when_empty return [a[0]] * 3 daft_df = daft_df.groupby("group").map_groups(udf(daft_df["a"])).sort("group", desc=False)