Skip to content

Commit

Permalink
apacheGH-34909: [C++] Avoid mean overflow on large integer inputs (ap…
Browse files Browse the repository at this point in the history
…ache#37243)

### Rationale for this change

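The `mean` aggregate function overflowed when the sum of the input array's values exceeded the range of its 64-bit integer accumulator (for example, a sum larger than `int64_max` for signed integer inputs).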

### What changes are included in this PR?

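For non-decimal inputs, store the intermediate sum as a `double` instead of a 64-bit integer, so that large integer inputs no longer overflow. Decimal inputs keep their existing accumulator type.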

### Are these changes tested?

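Yes. A new `Overflow` test in `aggregate_test.cc` covers both `int64` and `uint64` inputs whose sum exceeds the range of a 64-bit integer.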

### Are there any user-facing changes?

No.

* Closes: apache#34909

Authored-by: Jin Shang <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
  • Loading branch information
js8544 authored Aug 30, 2023
1 parent c9012a0 commit e0ee82c
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 19 deletions.
57 changes: 38 additions & 19 deletions cpp/src/arrow/compute/kernels/aggregate_basic_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,21 @@
#pragma once

#include <cmath>
#include <type_traits>
#include <utility>

#include "arrow/compute/api_aggregate.h"
#include "arrow/compute/kernels/aggregate_internal.h"
#include "arrow/compute/kernels/codegen_internal.h"
#include "arrow/compute/kernels/common_internal.h"
#include "arrow/compute/kernels/util_internal.h"
#include "arrow/type.h"
#include "arrow/type_traits.h"
#include "arrow/util/align_util.h"
#include "arrow/util/bit_block_counter.h"
#include "arrow/util/decimal.h"

namespace arrow {
namespace compute {
namespace internal {
namespace arrow::compute::internal {

void AddBasicAggKernels(KernelInit init,
const std::vector<std::shared_ptr<DataType>>& types,
Expand All @@ -58,16 +59,17 @@ void AddMinMaxAvx512AggKernels(ScalarAggregateFunction* func);
// ----------------------------------------------------------------------
// Sum implementation

template <typename ArrowType, SimdLevel::type SimdLevel>
template <typename ArrowType, SimdLevel::type SimdLevel,
typename ResultType = typename FindAccumulatorType<ArrowType>::Type>
struct SumImpl : public ScalarAggregator {
using ThisType = SumImpl<ArrowType, SimdLevel>;
using ThisType = SumImpl<ArrowType, SimdLevel, ResultType>;
using CType = typename TypeTraits<ArrowType>::CType;
using SumType = typename FindAccumulatorType<ArrowType>::Type;
using SumType = ResultType;
using SumCType = typename TypeTraits<SumType>::CType;
using OutputType = typename TypeTraits<SumType>::ScalarType;

SumImpl(std::shared_ptr<DataType> out_type, const ScalarAggregateOptions& options_)
: out_type(out_type), options(options_) {}
SumImpl(std::shared_ptr<DataType> out_type, ScalarAggregateOptions options_)
: out_type(std::move(out_type)), options(std::move(options_)) {}

Status Consume(KernelContext*, const ExecSpan& batch) override {
if (batch[0].is_array()) {
Expand Down Expand Up @@ -169,14 +171,19 @@ struct NullSumImpl : public NullImpl<ArrowType> {
}
};

template <typename ArrowType, SimdLevel::type SimdLevel, typename Enable = void>
struct MeanImpl;

template <typename ArrowType, SimdLevel::type SimdLevel>
struct MeanImpl : public SumImpl<ArrowType, SimdLevel> {
struct MeanImpl<ArrowType, SimdLevel, enable_if_decimal<ArrowType>>
: public SumImpl<ArrowType, SimdLevel> {
using SumImpl<ArrowType, SimdLevel>::SumImpl;
using SumImpl<ArrowType, SimdLevel>::options;
using SumCType = typename SumImpl<ArrowType, SimdLevel>::SumCType;
using OutputType = typename SumImpl<ArrowType, SimdLevel>::OutputType;

template <typename T = ArrowType>
enable_if_decimal<T, Status> FinalizeImpl(Datum* out) {
using SumCType = typename SumImpl<ArrowType, SimdLevel>::SumCType;
using OutputType = typename SumImpl<ArrowType, SimdLevel>::OutputType;
Status FinalizeImpl(Datum* out) {
if ((!options.skip_nulls && this->nulls_observed) ||
(this->count < options.min_count) || (this->count == 0)) {
out->value = std::make_shared<OutputType>(this->out_type);
Expand All @@ -196,20 +203,34 @@ struct MeanImpl : public SumImpl<ArrowType, SimdLevel> {
}
return Status::OK();
}

Status Finalize(KernelContext*, Datum* out) override { return FinalizeImpl(out); }
};

template <typename ArrowType, SimdLevel::type SimdLevel>
struct MeanImpl<ArrowType, SimdLevel,
std::enable_if_t<!is_decimal_type<ArrowType>::value>>
// Override the ResultType of SumImpl because we need to use double for intermediate
// sum to prevent integer overflows
: public SumImpl<ArrowType, SimdLevel, DoubleType> {
using SumImpl<ArrowType, SimdLevel, DoubleType>::SumImpl;
using SumImpl<ArrowType, SimdLevel, DoubleType>::options;

template <typename T = ArrowType>
enable_if_t<!is_decimal_type<T>::value, Status> FinalizeImpl(Datum* out) {
Status FinalizeImpl(Datum* out) {
if ((!options.skip_nulls && this->nulls_observed) ||
(this->count < options.min_count)) {
out->value = std::make_shared<DoubleScalar>();
} else {
const double mean = static_cast<double>(this->sum) / this->count;
static_assert(std::is_same_v<decltype(this->sum), double>,
"SumCType must be double for numeric inputs");
const double mean = this->sum / this->count;
out->value = std::make_shared<DoubleScalar>(mean);
}
return Status::OK();
}
Status Finalize(KernelContext*, Datum* out) override { return FinalizeImpl(out); }

using SumImpl<ArrowType, SimdLevel>::options;
Status Finalize(KernelContext*, Datum* out) override { return FinalizeImpl(out); }
};

template <template <typename> class KernelClass>
Expand Down Expand Up @@ -1012,6 +1033,4 @@ struct MinMaxInitState {
}
};

} // namespace internal
} // namespace compute
} // namespace arrow
} // namespace arrow::compute::internal
14 changes: 14 additions & 0 deletions cpp/src/arrow/compute/kernels/aggregate_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1274,6 +1274,20 @@ TYPED_TEST(TestNumericMeanKernel, ScalarAggregateOptions) {
ResultWith(Datum(MakeNullScalar(float64()))));
}

// Regression test for mean overflow: the mean of three values close to the
// int64 maximum must still come out right even though their exact sum does
// not fit in a 64-bit integer.
TEST(TestNumericMeanKernel, Overflow) {
// will overflow if intermediate sum is int64_t
// (the three inputs are int64_max - 2, int64_max - 1 and int64_max, so their
// exact sum is roughly 3 * int64_max)
EXPECT_THAT(
Mean(ArrayFromJSON(
int64(), "[9223372036854775805, 9223372036854775806, 9223372036854775807]")),
ResultWith(ScalarFromJSON(float64(), "9223372036854775806")));

// will overflow if intermediate sum is uint64_t
// (same values as above; their exact sum, 27670116110564327418, is larger
// than the uint64 maximum 18446744073709551615)
EXPECT_THAT(
Mean(ArrayFromJSON(
uint64(), "[9223372036854775805, 9223372036854775806, 9223372036854775807]")),
ResultWith(ScalarFromJSON(float64(), "9223372036854775806")));
}

template <typename ArrowType>
class TestRandomNumericMeanKernel : public ::testing::Test {};

Expand Down

0 comments on commit e0ee82c

Please sign in to comment.