From 74874deed9b15ad936f31f09a9c88e094fa7c6c5 Mon Sep 17 00:00:00 2001 From: gmweaver Date: Wed, 6 Mar 2024 14:43:51 -0800 Subject: [PATCH 1/3] allow arbritrary number of args --- daft/udf.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/daft/udf.py b/daft/udf.py index b670c2ad03..a20b404a97 100644 --- a/daft/udf.py +++ b/daft/udf.py @@ -30,11 +30,33 @@ class PartialUDF: bound_args: inspect.BoundArguments def expressions(self) -> dict[str, Expression]: - return {key: val for key, val in self.bound_args.arguments.items() if isinstance(val, Expression)} + parsed_expressions = {} + + for key, val in self.bound_args.arguments.items(): + if isinstance(val, Expression): + parsed_expressions[key] = val + elif isinstance(val, tuple) and all(isinstance(x, Expression) for x in val): + for idx, x in enumerate(val): + parsed_expressions[f"{key}{idx}"] = x + return parsed_expressions + + def arg_keys(self) -> list[str]: + """Gets argument keys.""" + parsed_arg_keys = [] + for key, value in self.bound_args.arguments.items(): + if key in self.bound_args.kwargs: + continue + if isinstance(value, tuple): + for idx, _ in enumerate(value): + parsed_arg_keys.append(f"{key}{idx}") + else: + parsed_arg_keys.append(key) + + return parsed_arg_keys def __call__(self, evaluated_expressions: list[Series]) -> PySeries: kwarg_keys = list(self.bound_args.kwargs.keys()) - arg_keys = [k for k in self.bound_args.arguments.keys() if k not in self.bound_args.kwargs.keys()] + arg_keys = self.arg_keys() pyvalues = {key: val for key, val in self.bound_args.arguments.items() if not isinstance(val, Expression)} expressions = self.expressions() assert len(evaluated_expressions) == len( From 4a364f8aab883694e9970a34f0770ef1b8e473aa Mon Sep 17 00:00:00 2001 From: gmweaver Date: Wed, 6 Mar 2024 15:05:27 -0800 Subject: [PATCH 2/3] add tests --- daft/udf.py | 3 +-- tests/expressions/test_udf.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/daft/udf.py b/daft/udf.py index a20b404a97..6f34e38d57 100644 --- a/daft/udf.py +++ b/daft/udf.py @@ -31,17 +31,16 @@ class PartialUDF: def expressions(self) -> dict[str, Expression]: parsed_expressions = {} - for key, val in self.bound_args.arguments.items(): if isinstance(val, Expression): parsed_expressions[key] = val elif isinstance(val, tuple) and all(isinstance(x, Expression) for x in val): for idx, x in enumerate(val): parsed_expressions[f"{key}{idx}"] = x + return parsed_expressions def arg_keys(self) -> list[str]: - """Gets argument keys.""" parsed_arg_keys = [] for key, value in self.bound_args.arguments.items(): if key in self.bound_args.kwargs: diff --git a/tests/expressions/test_udf.py b/tests/expressions/test_udf.py index ff5463c2d1..8123003951 100644 --- a/tests/expressions/test_udf.py +++ b/tests/expressions/test_udf.py @@ -168,3 +168,35 @@ def single_arg_udf(x, y, z=10): pass assert single_arg_udf(col("x"), "y", z=20).__repr__() == "@udf[single_arg_udf](col('x'), 'y', z=20)" + + +def test_udf_arbitrary_number_of_args(): + table = MicroPartition.from_pydict({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]}) + + @udf(return_dtype=DataType.int64()) + def add_cols_elementwise(*args): + return Series.from_pylist([sum(x) for x in zip(*[arg.to_pylist() for arg in args])]) + + expr = add_cols_elementwise(col("a"), col("b"), col("c")) + field = expr._to_field(table.schema()) + assert field.name == "a" + assert field.dtype == DataType.int64() + + result = table.eval_expression_list([expr]) + assert result.to_pydict() == {"a": [3, 6, 9]} + + +def test_udf_arbitrary_number_of_args_with_kwargs(): + table = MicroPartition.from_pydict({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]}) + + @udf(return_dtype=DataType.int64()) + def add_cols_elementwise(*args, multiplier: float): + return Series.from_pylist([multiplier * sum(x) for x in zip(*[arg.to_pylist() for arg in args])]) + + expr = add_cols_elementwise(col("a"), col("b"), col("c"), multiplier=2) + field = expr._to_field(table.schema()) + assert field.name == "a" + assert field.dtype == DataType.int64() + + result = table.eval_expression_list([expr]) + assert result.to_pydict() == {"a": [6, 12, 18]} From 5d00bb0f0f931d5b05c18a2cdd1ec7040dae7b50 Mon Sep 17 00:00:00 2001 From: gmweaver Date: Wed, 6 Mar 2024 15:48:56 -0800 Subject: [PATCH 3/3] address comments --- daft/udf.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/daft/udf.py b/daft/udf.py index 6f34e38d57..1566d24159 100644 --- a/daft/udf.py +++ b/daft/udf.py @@ -34,9 +34,10 @@ def expressions(self) -> dict[str, Expression]: for key, val in self.bound_args.arguments.items(): if isinstance(val, Expression): parsed_expressions[key] = val - elif isinstance(val, tuple) and all(isinstance(x, Expression) for x in val): + elif isinstance(val, tuple): for idx, x in enumerate(val): - parsed_expressions[f"{key}{idx}"] = x + if isinstance(x, Expression): + parsed_expressions[f"{key}-{idx}"] = x return parsed_expressions @@ -47,7 +48,7 @@ def arg_keys(self) -> list[str]: continue if isinstance(value, tuple): for idx, _ in enumerate(value): - parsed_arg_keys.append(f"{key}{idx}") + parsed_arg_keys.append(f"{key}-{idx}") else: parsed_arg_keys.append(key)