Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Properly broadcast list arithmetic #18858

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions crates/polars-core/src/series/arithmetic/list_borrowed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,25 @@ fn lists_same_shapes(left: &ArrayRef, right: &ArrayRef) -> bool {
}
}

/// Brings `lhs` and `rhs` to compatible lengths before list arithmetic.
///
/// Length combinations are handled as follows (first matching rule wins):
/// - `lhs` has length 1: its single element is repeated `rhs.len()` times.
/// - `rhs` has length 1: a numeric `rhs` is returned unchanged, any other
///   dtype is repeated `lhs.len()` times.
/// - Both have the same length: both are returned unchanged (cloned).
///
/// # Errors
///
/// Returns `InvalidOperation` when the lengths differ and neither side has
/// length 1. The message reports both Series lengths.
fn broadcast_list(lhs: &ListChunked, rhs: &Series) -> PolarsResult<(ListChunked, Series)> {
    let out = match (lhs.len(), rhs.len()) {
        (1, _) => (lhs.new_from_index(0, rhs.len()), rhs.clone()),
        (_, 1) => {
            // Numeric scalars are broadcast implicitly without an intermediate allocation.
            if rhs.dtype().is_numeric() {
                (lhs.clone(), rhs.clone())
            } else {
                (lhs.clone(), rhs.new_from_index(0, lhs.len()))
            }
        },
        (a, b) if a == b => (lhs.clone(), rhs.clone()),
        _ => {
            // This arm is a Series *length* mismatch, so report the lengths
            // (not the dtypes) and use the same wording as the length check
            // performed by the caller.
            polars_bail!(
                InvalidOperation: "can only do arithmetic operations on Series of the same size; got {} and {}",
                lhs.len(),
                rhs.len()
            )
        },
    };
    Ok(out)
}

impl ListChunked {
/// Helper function for NumOpsDispatchInner implementation for ListChunked.
///
Expand All @@ -63,16 +82,17 @@ impl ListChunked {
op: &dyn Fn(&Series, &Series) -> PolarsResult<Series>,
has_nulls: Option<bool>,
) -> PolarsResult<Series> {
let (lhs, rhs) = broadcast_list(self, rhs)?;
polars_ensure!(
self.len() == rhs.len(),
lhs.len() == rhs.len(),
InvalidOperation: "can only do arithmetic operations on Series of the same size; got {} and {}",
self.len(),
lhs.len(),
rhs.len()
);

let mut has_nulls = has_nulls.unwrap_or(false);
if !has_nulls {
for chunk in self.chunks().iter() {
for chunk in lhs.chunks().iter() {
if does_list_have_nulls(chunk) {
has_nulls = true;
break;
Expand All @@ -92,11 +112,11 @@ impl ListChunked {
// values Arrow arrays. Given nulls, the two values arrays might not
// line up the way we expect.
let mut result = AnonymousListBuilder::new(
self.name().clone(),
self.len(),
Some(self.inner_dtype().clone()),
lhs.name().clone(),
lhs.len(),
Some(lhs.inner_dtype().clone()),
);
let combined = self.amortized_iter().zip(rhs.list()?.amortized_iter()).map(|(a, b)| {
let combined = lhs.amortized_iter().zip(rhs.list()?.amortized_iter()).map(|(a, b)| {
let (Some(a_owner), Some(b_owner)) = (a, b) else {
// Operations with nulls always result in nulls:
return Ok(None);
Expand Down Expand Up @@ -131,12 +151,14 @@ impl ListChunked {
}
return Ok(result.finish().into());
}
let l_rechunked = self.clone().rechunk().into_series();
let l_rechunked = lhs.clone().rechunk().into_series();
let l_leaf_array = l_rechunked.get_leaf_array();
let r_leaf_array = rhs.rechunk().get_leaf_array();
polars_ensure!(
lists_same_shapes(&l_leaf_array.chunks()[0], &r_leaf_array.chunks()[0]),
InvalidOperation: "can only do arithmetic operations on lists of the same size"
InvalidOperation: "can only do arithmetic operations on lists of the same size; got {} and {}",
&l_leaf_array.chunks()[0].len(),
&r_leaf_array.chunks()[0].len()
);

let result = op(&l_leaf_array, &r_leaf_array)?;
Expand All @@ -151,7 +173,7 @@ impl ListChunked {

unsafe {
let mut result =
ListChunked::new_with_dims(self.field.clone(), vec![result_chunk], 0, 0);
ListChunked::new_with_dims(lhs.field.clone(), vec![result_chunk], 0, 0);
result.compute_len();
Ok(result.into())
}
Expand Down
10 changes: 0 additions & 10 deletions py-polars/tests/unit/operations/arithmetic/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,16 +688,6 @@ def test_list_arithmetic_nulls(a: list[Any], b: list[Any], expected: list[Any])


def test_list_arithmetic_error_cases() -> None:
# Different series length:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still think some of these tests are necessary? E.g. pl.Series([[1, 2], [3, 4]]) + pl.Series([[1, 1], [2, 2], [3, 4]]) should complain.

(Separately, I am honestly not super-excited about the semantics of single-item Series working this way, but I guess that's how literals work? So if that design decision has already been made, oh well.)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or rather, some variations of existing tests should still be there, by tweaking existing assertions instead of deleting them.

with pytest.raises(
InvalidOperationError, match="Series of the same size; got 1 and 2"
):
_ = pl.Series("a", [[1, 2]]) / pl.Series("b", [[1, 2], [3, 4]])
with pytest.raises(
InvalidOperationError, match="Series of the same size; got 1 and 2"
):
_ = pl.Series("a", [[1, 2]]) / pl.Series("b", [[1, 2], None])

# Different list length:
with pytest.raises(InvalidOperationError, match="lists of the same size"):
_ = pl.Series("a", [[1, 2]]) / pl.Series("b", [[1]])
Expand Down
26 changes: 26 additions & 0 deletions py-polars/tests/unit/operations/arithmetic/test_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import polars as pl


def test_literal_broadcast_list() -> None:
Copy link
Contributor

@itamarst itamarst Sep 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should also test eager operations in addition to lazy operations, e.g. df.get_column("A") + lit; arithmetic sometimes hits different code paths for each, so you get different bugs.

I would also suggest adding a test where the literal is a different type than the list, e.g. an Int64, and again testing both lazy and eager (casting definitely works differently).

df = pl.DataFrame({"A": [[0.1, 0.2], [0.3, 0.4]]})

lit = pl.lit([3.0, 5.0])
assert df.select(
mul=pl.all() * lit,
div=pl.all() / lit,
add=pl.all() + lit,
sub=pl.all() - lit,
div_=lit / pl.all(),
add_=lit + pl.all(),
sub_=lit - pl.all(),
mul_=lit * pl.all(),
).to_dict(as_series=False) == {
"mul": [[0.30000000000000004, 1.0], [0.8999999999999999, 2.0]],
"div": [[0.03333333333333333, 0.04], [0.09999999999999999, 0.08]],
"add": [[3.1, 5.2], [3.3, 5.4]],
"sub": [[-2.9, -4.8], [-2.7, -4.6]],
"div_": [[30.0, 25.0], [10.0, 12.5]],
"add_": [[3.1, 5.2], [3.3, 5.4]],
"sub_": [[2.9, 4.8], [2.7, 4.6]],
"mul_": [[0.30000000000000004, 1.0], [0.8999999999999999, 2.0]],
}
Loading