Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-35166: [C++] Increase precision of decimals in aggregate functions #44184

Open
wants to merge 16 commits into
base: main
Choose a base branch
from

Conversation

khwilson
Copy link
Contributor

@khwilson khwilson commented Sep 23, 2024

Rationale for this change

As documented in #35166, when Arrow performs a sum, product, or mean of an array of type decimalX(P, S), it returns a scalar of type decimalX(P, S). This is true even if the aggregate does not fit in the specified precision. For instance, a sum of two decimal128(1, 0)'s such as 1 + 9 is a decimal128(2, 0). But (in Python):

import pyarrow as pa
from decimal import Decimal

arr = pa.array([Decimal("1"), Decimal("9")], type=pa.decimal128(1, 0))
assert arr.sum().type == pa.decimal128(1, 0)

This is recognized in the rules for binary addition and multiplication of decimals (see footnote 1 in this section), but this does not apply to array aggregates.

In #35166 I did a bit of research following a question from @westonpace , and it seems that there's no standard approach to this across DBMS's, but a common solution is to set the precision of the result of a sum to the maximum possible precision of the underlying type. That is, a sum of decimal128(1, 0)'s becomes a decimal128(38, 0).

However, products and means differ further. For instance, in both instances, duckdb converts a decimal to a double, which makes sense as the precision of the product of an array of decimals would likely be huge, e.g., an array of size N with precision 2 decimals would have precision at least 2^N.

This PR implements the minimum possible change: replace all return types of a product, sum, or mean aggregate of decimal128(P, S) to decimal128(38, S) and decimal256(P, S) to decimal256(76, S).

Please note, this PR is not done (see the checklist below), as it would be good to get feedback on the following before going through the whole checklist:

  • Is this the correct change to make (especially for products and means); and
  • The implementation relies on overriding "out types," which is how the current mean implementation works, but perhaps there's a better way to approach this.

What changes are included in this PR?

  • Update C++ kernels to support the change
  • Update docs to reflect the change
  • Fix tests in the languages that depend on the C++ engine
  • Determine if there are other languages which do not depend on the C++ engine which should also be updated

Are these changes tested?

They are tested in the following implementations:

  • Python
  • C++
  • Java Java does not currently implement any dependent tests

Are there any user-facing changes?

Yes. This changes the return type of a scalar aggregate of decimals.

This PR includes breaking changes to public APIs.

Specifically, the return type of a scalar aggregate of decimals changes. This is unlikely to break downstream applications as the underlying data has not changed, but if an application relies on the (incorrect!) type information for some reason, it would break.

Copy link

⚠️ GitHub issue #35166 has been automatically assigned in GitHub to PR creator.

@khwilson
Copy link
Contributor Author

@zeroshade I noticed in #43957 that you were adding in Decimal32/64 types, which I think will have the same problem that this PR addresses. I was curious if you might have interest in reviewing this PR?

@zeroshade
Copy link
Member

@khwilson Sure thing, i'll try to take a look at this in the next day or so

@khwilson
Copy link
Contributor Author

khwilson commented Oct 8, 2024

Hi @zeroshade just checking in! Thanks again for taking a look

Copy link
Member

@mapleFU mapleFU left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method is interesting, however, before doing that, do you think a user-side "cast" is ok?
Like:

cast(origin to decimal(large-enough)) then avg

cpp/src/arrow/compute/kernels/codegen_internal.h Outdated Show resolved Hide resolved
Comment on lines 93 to 94
return Status::TypeError(
"A call to MaxPrecisionDecimalType was made with a non-DecimalType");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would you mind point out the type here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simplified to just call the common WidenDecimalToMaxPrecision function which is just the identity on non-Decimal types

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean, TypeError and add the type here. Since we've supported decimal32 and decimal64

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

@@ -55,6 +55,7 @@ using internal::BinaryBitBlockCounter;
using internal::BitBlockCount;
using internal::BitmapReader;
using internal::checked_cast;
using internal::checked_pointer_cast;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why in header?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was following the pattern of the internal::checked_cast on the line above. But I'm not wedded to including it.

@github-actions github-actions bot added awaiting committer review Awaiting committer review and removed awaiting review Awaiting review labels Oct 9, 2024
@khwilson
Copy link
Contributor Author

khwilson commented Oct 9, 2024

Thanks for the review!

By a user-side cast, do you mean that users should essentially do:

select avg(cast(blah as decimal(big-precision)))

instead of

select avg(blah)

or do you mean that this code should "inject" a cast on the "user" side?

If you mean putting the cast onto the user, then I would think you'd want to add an error if the answer can't fit into the default precision, but that seems like it would be more disruptive (and out of step with how other systems handle decimal aggregates).

If you mean "injecting" the cast on the user side, would that end up creating a copy of the array?

@zeroshade
Copy link
Member

@khwilson Hey sorry for the delay from me here, I've been traveling a lot lately for work and have been at ASF Community Over Code this week. I promise i'll get to this soon. In the meantime, you're in the very capable hands of @mapleFU

@khwilson
Copy link
Contributor Author

khwilson commented Oct 9, 2024

No problem! Hope your travels were fun!

@mapleFU
Copy link
Member

mapleFU commented Oct 11, 2024

Generally this method is ok for me, but I'm not so familiar with the "common solutions" here. I'll dive into Presto/ClickHouse to see the common pattern here

@khwilson
Copy link
Contributor Author

I enumerated several here: #35166 (comment)

Clickhouse for instance just ignores precision.

@mapleFU
Copy link
Member

mapleFU commented Oct 11, 2024

Would you mind making this Ready for review?

@khwilson
Copy link
Contributor Author

Sure!

This is an initial pass whereby a scalar aggregate of a Decimal type
increases its precision to the maximum. That is, a sum of an
array of decimal128(3, 2)'s becomes a decimal128(38, 2).

Previously, the exact decimal type was preserved (e.g., a sum of
decimal128(3, 2)'s was a decimal128(3, 2)) *regardless* of whether
that was enough precision to capture the full decimal value.
@khwilson khwilson force-pushed the increase-precision-of-decimals branch from 89f1ae9 to 01e53b3 Compare October 13, 2024 17:03
@khwilson khwilson marked this pull request as ready for review October 13, 2024 20:23
@khwilson
Copy link
Contributor Author

@mapleFU I believe this is done now. Some notes on the diff:

  • The hash aggregates had to be updated as well (missed them in the first pass)
  • I've also added in Decimal32/64 support the basic aggregates (sum, product, mean, min/max, index). However, there's quite a lot of missing support for these types still in compute (most notably in casts)
  • Docs are updated to reflect the change

And a note that quite a few tests are failing for what appears to be the same reason as #41390. Happy to address them if you'd like.

Copy link
Member

@mapleFU mapleFU left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also cc @pitrou @bkietz

cpp/src/arrow/compute/kernels/codegen_internal.h Outdated Show resolved Hide resolved
@pitrou
Copy link
Member

pitrou commented Oct 15, 2024

I'm lukewarm about the approach here. Silently casting to the max precision discards metadata about the input; it also risks producing errors further down the line (if e.g. the max precision is deemed too large for other operations). It also doesn't automatically eliminate any potential overflow, for example:

>>> a = pa.array([789.3] * 20).cast(pa.decimal128(38, 35))
>>> a
<pyarrow.lib.Decimal128Array object at 0x7f0f103ca7a0>
[
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440
]
>>> pc.sum(a)
<pyarrow.Decimal128Scalar: Decimal('-1228.11834604692408266343214451664848480')>

We should instead check that the result of an aggregate fits into the resulting Decimal type, while overflows currently pass silently:

>>> a = pa.array([123., 456., 789.]).cast(pa.decimal128(4, 1))
>>> a
<pyarrow.lib.Decimal128Array object at 0x7f0ed06261a0>
[
  123.0,
  456.0,
  789.0
]
>>> pc.sum(a)
<pyarrow.Decimal128Scalar: Decimal('1368.0')>
>>> pc.sum(a).validate(full=True)
Traceback (most recent call last):
  ...
ArrowInvalid: Decimal value 13680 does not fit in precision of decimal128(4, 1)

@khwilson
Copy link
Contributor Author

Two problems with just validating afterward: First, I'd expect in reasonable cases for the validation to fail. A sum of 1m decimals of approximately the same size you'd expect to have 6 more digits of precision. I assume this is why all the DBMSs I looked at increase the precision by default.

Second, just checking for overflow doesn't solve the underlying problem. Consider:

a = pa.array([789.3] * 18).cast(pa.decimal128(38, 35))
print(pc.sum(a))
pc.sum(a).validate(full=True)  # passes

In duckdb, they implement an intermediate check to make sure that there's not an internal overflow:

tab = pa.Table.from_pydict({"a": a})
duckdb.query("select sum(a) from tab")
# Traceback (most recent call last):
#   File "<stdin>", line 1, in <module>
# duckdb.duckdb.OutOfRangeException: Out of Range Error: Overflow in HUGEINT addition: 
# 157859999999999990905052982270717620880 + 78929999999999995452526491135358810440

Notably, this lack of overflow checking also applies to integer sums in arrow:

>>> pa.array([9223372036854775800] * 2, type=pa.int64())
<pyarrow.lib.Int64Array object at 0x10c1d8b80>
[
  9223372036854775800,
  9223372036854775800
]
>>> pc.sum(pa.array([9223372036854775800] * 2, type=pa.int64()))
<pyarrow.Int64Scalar: -16>
>>> pc.sum(pa.array([9223372036854775800] * 2, type=pa.int64())).validate(full=True)

@pitrou
Copy link
Member

pitrou commented Oct 15, 2024

Two problems with just validating afterward: First, I'd expect in reasonable cases for the validation to fail. A sum of 1m decimals of approximately the same size you'd expect to have 6 more digits of precision.

It depends obviously if all decimals are of the same sign, and what their actual magnitude is.

Second, just checking for overflow doesn't solve the underlying problem.

In the example above, I used a validate call simply to show that the result was indeed erroneous. I didn't mean we should actually call validation afterwards. We should instead check for overflow at each individual aggregation step (for each add or multiply, for example). This is required even if we were to bump the result's precision to the max.

@pitrou
Copy link
Member

pitrou commented Oct 15, 2024

Notably, this lack of overflow checking also applies to integer sums in arrow:

Yes, and there's already a bug open for it: #37090

@khwilson
Copy link
Contributor Author

Nice! I'm excited for the checked variants of sum and product!

With the integer overflow example, I only meant to point out that the compute module currently allows overflows, so I think it would be unexpected for sum to complain about an overflow only if the underlying type was a decimal. But if the goal with #37090 is to replace sum with a checked version, then the solution of erroring makes a lot of sense, and I'd be happy to implement it when #37536 gets merged. :-)

Still, I do think that users would find it unexpected to get an error if the sum fit in the underlying storage since this is how all the databases I've used (and the four I surveyed in #35166) have operated.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants