Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Enable UDF to handle arbitrary number of Daft series #1984

Merged
merged 3 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions daft/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,32 @@ class PartialUDF:
bound_args: inspect.BoundArguments

def expressions(self) -> dict[str, Expression]:
return {key: val for key, val in self.bound_args.arguments.items() if isinstance(val, Expression)}
parsed_expressions = {}
for key, val in self.bound_args.arguments.items():
if isinstance(val, Expression):
parsed_expressions[key] = val
elif isinstance(val, tuple) and all(isinstance(x, Expression) for x in val):
gmweaver marked this conversation as resolved.
Show resolved Hide resolved
for idx, x in enumerate(val):
parsed_expressions[f"{key}{idx}"] = x
gmweaver marked this conversation as resolved.
Show resolved Hide resolved

return parsed_expressions

def arg_keys(self) -> list[str]:
parsed_arg_keys = []
for key, value in self.bound_args.arguments.items():
if key in self.bound_args.kwargs:
continue
if isinstance(value, tuple):
for idx, _ in enumerate(value):
parsed_arg_keys.append(f"{key}{idx}")
else:
parsed_arg_keys.append(key)

return parsed_arg_keys

def __call__(self, evaluated_expressions: list[Series]) -> PySeries:
kwarg_keys = list(self.bound_args.kwargs.keys())
arg_keys = [k for k in self.bound_args.arguments.keys() if k not in self.bound_args.kwargs.keys()]
arg_keys = self.arg_keys()
pyvalues = {key: val for key, val in self.bound_args.arguments.items() if not isinstance(val, Expression)}
expressions = self.expressions()
assert len(evaluated_expressions) == len(
Expand Down
32 changes: 32 additions & 0 deletions tests/expressions/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,35 @@ def single_arg_udf(x, y, z=10):
pass

assert single_arg_udf(col("x"), "y", z=20).__repr__() == "@udf[single_arg_udf](col('x'), 'y', z=20)"


def test_udf_arbitrary_number_of_args():
table = MicroPartition.from_pydict({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]})

@udf(return_dtype=DataType.int64())
def add_cols_elementwise(*args):
return Series.from_pylist([sum(x) for x in zip(*[arg.to_pylist() for arg in args])])

expr = add_cols_elementwise(col("a"), col("b"), col("c"))
field = expr._to_field(table.schema())
assert field.name == "a"
assert field.dtype == DataType.int64()

result = table.eval_expression_list([expr])
assert result.to_pydict() == {"a": [3, 6, 9]}


def test_udf_arbitrary_number_of_args_with_kwargs():
table = MicroPartition.from_pydict({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]})

@udf(return_dtype=DataType.int64())
def add_cols_elementwise(*args, multiplier: float):
return Series.from_pylist([multiplier * sum(x) for x in zip(*[arg.to_pylist() for arg in args])])

expr = add_cols_elementwise(col("a"), col("b"), col("c"), multiplier=2)
field = expr._to_field(table.schema())
assert field.name == "a"
assert field.dtype == DataType.int64()

result = table.eval_expression_list([expr])
assert result.to_pydict() == {"a": [6, 12, 18]}
Loading