From 3a3f3830ea6f4c467ae1408948c13cded4ba3893 Mon Sep 17 00:00:00 2001 From: cmdlineluser <99486669+cmdlineluser@users.noreply.github.com> Date: Mon, 23 Sep 2024 12:04:55 +0100 Subject: [PATCH 1/3] debugging --- .../src/series/arithmetic/list_borrowed.rs | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/crates/polars-core/src/series/arithmetic/list_borrowed.rs b/crates/polars-core/src/series/arithmetic/list_borrowed.rs index 1628780d7b0e..05d85e9f9ba5 100644 --- a/crates/polars-core/src/series/arithmetic/list_borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/list_borrowed.rs @@ -53,6 +53,25 @@ fn lists_same_shapes(left: &ArrayRef, right: &ArrayRef) -> bool { } } +fn broadcast_list(lhs: &ListChunked, rhs: &Series) -> PolarsResult<(ListChunked, Series)> { + let out = match (lhs.len(), rhs.len()) { + (1, _) => (lhs.new_from_index(0, rhs.len()), rhs.clone()), + (_, 1) => { + // Numeric scalars will be broadcasted implicitly without intermediate allocation. + if rhs.dtype().is_numeric() { + (lhs.clone(), rhs.clone()) + } else { + (lhs.clone(), rhs.new_from_index(0, lhs.len())) + } + }, + (a, b) if a == b => (lhs.clone(), rhs.clone()), + _ => { + polars_bail!(InvalidOperation: "can only do arithmetic operations on lists of the same size; got {} and {}", lhs.dtype(), rhs.dtype()) + }, + }; + Ok(out) +} + impl ListChunked { /// Helper function for NumOpsDispatchInner implementation for ListChunked. /// @@ -63,6 +82,7 @@ impl ListChunked { op: &dyn Fn(&Series, &Series) -> PolarsResult, has_nulls: Option, ) -> PolarsResult { + let (lhs, rhs) = broadcast_list(self, rhs)?; polars_ensure!( self.len() == rhs.len(), InvalidOperation: "can only do arithmetic operations on Series of the same size; got {} and {}", From 006fa33eb69aec9bb13c140c57c6ff70c99a71ff Mon Sep 17 00:00:00 2001 From: cmdlineluser <99486669+cmdlineluser@users.noreply.github.com> Date: Mon, 23 Sep 2024 12:41:11 +0100 Subject: [PATCH 2/3] fix: Properly broadcast array arithmetic --- .../src/series/arithmetic/list_borrowed.rs | 22 ++++++++++--------- .../operations/arithmetic/test_arithmetic.py | 10 --------- 2 files changed, 12 insertions(+), 20 deletions(-) diff --git a/crates/polars-core/src/series/arithmetic/list_borrowed.rs b/crates/polars-core/src/series/arithmetic/list_borrowed.rs index 05d85e9f9ba5..d44e181d961b 100644 --- a/crates/polars-core/src/series/arithmetic/list_borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/list_borrowed.rs @@ -84,15 +84,15 @@ impl ListChunked { ) -> PolarsResult { let (lhs, rhs) = broadcast_list(self, rhs)?; polars_ensure!( - self.len() == rhs.len(), + lhs.len() == rhs.len(), InvalidOperation: "can only do arithmetic operations on Series of the same size; got {} and {}", - self.len(), + lhs.len(), rhs.len() ); let mut has_nulls = has_nulls.unwrap_or(false); if !has_nulls { - for chunk in self.chunks().iter() { + for chunk in lhs.chunks().iter() { if does_list_have_nulls(chunk) { has_nulls = true; break; @@ -112,11 +112,11 @@ impl ListChunked { // values Arrow arrays. Given nulls, the two values arrays might not // line up the way we expect. let mut result = AnonymousListBuilder::new( - self.name().clone(), - self.len(), - Some(self.inner_dtype().clone()), + lhs.name().clone(), + lhs.len(), + Some(lhs.inner_dtype().clone()), ); - let combined = self.amortized_iter().zip(rhs.list()?.amortized_iter()).map(|(a, b)| { + let combined = lhs.amortized_iter().zip(rhs.list()?.amortized_iter()).map(|(a, b)| { let (Some(a_owner), Some(b_owner)) = (a, b) else { // Operations with nulls always result in nulls: return Ok(None); @@ -151,12 +151,14 @@ impl ListChunked { } return Ok(result.finish().into()); } - let l_rechunked = self.clone().rechunk().into_series(); + let l_rechunked = lhs.clone().rechunk().into_series(); let l_leaf_array = l_rechunked.get_leaf_array(); let r_leaf_array = rhs.rechunk().get_leaf_array(); polars_ensure!( lists_same_shapes(&l_leaf_array.chunks()[0], &r_leaf_array.chunks()[0]), - InvalidOperation: "can only do arithmetic operations on lists of the same size" + InvalidOperation: "can only do arithmetic operations on lists of the same size; got {} and {}", + &l_leaf_array.chunks()[0].len(), + &r_leaf_array.chunks()[0].len() ); let result = op(&l_leaf_array, &r_leaf_array)?; @@ -171,7 +173,7 @@ impl ListChunked { unsafe { let mut result = - ListChunked::new_with_dims(self.field.clone(), vec![result_chunk], 0, 0); + ListChunked::new_with_dims(lhs.field.clone(), vec![result_chunk], 0, 0); result.compute_len(); Ok(result.into()) } diff --git a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py index 360def065ca1..aa2c3bb9099e 100644 --- a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py +++ b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py @@ -688,16 +688,6 @@ def test_list_arithmetic_nulls(a: list[Any], b: list[Any], expected: list[Any]) def test_list_arithmetic_error_cases() -> None: - # Different series length: - with pytest.raises( - InvalidOperationError, match="Series of the same size; got 1 and 2" - ): - _ = pl.Series("a", [[1, 2]]) / pl.Series("b", [[1, 2], [3, 4]]) - with pytest.raises( - InvalidOperationError, match="Series of the same size; got 1 and 2" - ): - _ = pl.Series("a", [[1, 2]]) / pl.Series("b", [[1, 2], None]) - # Different list length: with pytest.raises(InvalidOperationError, match="lists of the same size"): _ = pl.Series("a", [[1, 2]]) / pl.Series("b", [[1]]) From bc8029ea5e03d5e100b15dec323ef6ec46d702f2 Mon Sep 17 00:00:00 2001 From: cmdlineluser <99486669+cmdlineluser@users.noreply.github.com> Date: Mon, 23 Sep 2024 12:42:55 +0100 Subject: [PATCH 3/3] add test --- .../unit/operations/arithmetic/test_list.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 py-polars/tests/unit/operations/arithmetic/test_list.py diff --git a/py-polars/tests/unit/operations/arithmetic/test_list.py b/py-polars/tests/unit/operations/arithmetic/test_list.py new file mode 100644 index 000000000000..07e105fb43f6 --- /dev/null +++ b/py-polars/tests/unit/operations/arithmetic/test_list.py @@ -0,0 +1,26 @@ +import polars as pl + + +def test_literal_broadcast_list() -> None: + df = pl.DataFrame({"A": [[0.1, 0.2], [0.3, 0.4]]}) + + lit = pl.lit([3.0, 5.0]) + assert df.select( + mul=pl.all() * lit, + div=pl.all() / lit, + add=pl.all() + lit, + sub=pl.all() - lit, + div_=lit / pl.all(), + add_=lit + pl.all(), + sub_=lit - pl.all(), + mul_=lit * pl.all(), + ).to_dict(as_series=False) == { + "mul": [[0.30000000000000004, 1.0], [0.8999999999999999, 2.0]], + "div": [[0.03333333333333333, 0.04], [0.09999999999999999, 0.08]], + "add": [[3.1, 5.2], [3.3, 5.4]], + "sub": [[-2.9, -4.8], [-2.7, -4.6]], + "div_": [[30.0, 25.0], [10.0, 12.5]], + "add_": [[3.1, 5.2], [3.3, 5.4]], + "sub_": [[2.9, 4.8], [2.7, 4.6]], + "mul_": [[0.30000000000000004, 1.0], [0.8999999999999999, 2.0]], + }