From 6cff6f2c0f78f804dabe1bdda8c8aa76da95ec64 Mon Sep 17 00:00:00 2001
From: ZhangHuiGui
Date: Tue, 26 Mar 2024 22:26:20 +0800
Subject: [PATCH] fix comments

---
 cpp/src/arrow/acero/asof_join_node.cc  |  2 +-
 cpp/src/arrow/compute/key_hash.cc      | 63 +++++++++++++++++---------
 cpp/src/arrow/compute/key_hash.h       | 23 +---------
 cpp/src/arrow/compute/key_hash_test.cc | 14 ++++--
 cpp/src/arrow/compute/row/grouper.cc   |  4 +-
 cpp/src/arrow/compute/util.cc          | 18 ++++----
 cpp/src/arrow/compute/util.h           |  5 +-
 7 files changed, 69 insertions(+), 60 deletions(-)

diff --git a/cpp/src/arrow/acero/asof_join_node.cc b/cpp/src/arrow/acero/asof_join_node.cc
index cf0d475c1d770..e5211b63d2c53 100644
--- a/cpp/src/arrow/acero/asof_join_node.cc
+++ b/cpp/src/arrow/acero/asof_join_node.cc
@@ -436,7 +436,7 @@ class KeyHasher {
             ColumnArrayFromArrayDataAndMetadata(array_data, metadata_[k], i, length);
       }
       // write directly to the cache
-      Hashing64::HashMultiColumn(column_arrays_, &ctx_, hashes_.data() + i);
+      DCHECK_OK(Hashing64::HashMultiColumn(column_arrays_, &ctx_, hashes_.data() + i));
     }
     DEBUG_SYNC(node_, "key hasher ", index_, " got hashes ",
                compute::internal::GenericToString(hashes_), DEBUG_MANIP(std::endl));
diff --git a/cpp/src/arrow/compute/key_hash.cc b/cpp/src/arrow/compute/key_hash.cc
index f8fcc1be82cad..493df15c8fcd1 100644
--- a/cpp/src/arrow/compute/key_hash.cc
+++ b/cpp/src/arrow/compute/key_hash.cc
@@ -378,20 +378,28 @@ void Hashing32::HashFixed(int64_t hardware_flags, bool combine_hashes, uint32_t
   }
 }
 
-void Hashing32::HashMultiColumn(const std::vector<KeyColumnArray>& cols,
-                                LightContext* ctx, uint32_t* hashes) {
+Status Hashing32::HashMultiColumn(const std::vector<KeyColumnArray>& cols,
+                                  LightContext* ctx, uint32_t* hashes) {
   uint32_t num_rows = static_cast<uint32_t>(cols[0].length());
 
   constexpr uint32_t max_batch_size = util::MiniBatch::kMiniBatchLength;
-  const uint32_t alloc_batch_size = std::min(num_rows, max_batch_size);
-  const int64_t estimate_alloc_size = EstimateBatchStackSize<uint32_t>(alloc_batch_size);
+  const auto alloc_batch_size = std::min(num_rows, max_batch_size);
 
-  util::TempVectorStack temp_stack;
+  // Pre-calculate the allocation size needed in the TempVectorStack for
+  // hash_temp_buf, null_hash_temp_buf and null_indices_buf.
+  const auto alloc_hash_temp_buf =
+      util::TempVectorStack::EstimateAllocSize(alloc_batch_size * sizeof(uint32_t));
+  const auto alloc_for_null_indices_buf =
+      util::TempVectorStack::EstimateAllocSize(alloc_batch_size * sizeof(uint16_t));
+  const auto alloc_size = alloc_hash_temp_buf * 2 + alloc_for_null_indices_buf;
+
+  std::shared_ptr<util::TempVectorStack> temp_stack(nullptr);
   if (!ctx->stack) {
-    ARROW_CHECK_OK(temp_stack.Init(default_memory_pool(), estimate_alloc_size));
-    ctx->stack = &temp_stack;
+    temp_stack = std::make_shared<util::TempVectorStack>();
+    RETURN_NOT_OK(temp_stack->Init(default_memory_pool(), alloc_size));
+    ctx->stack = temp_stack.get();
   } else {
-    ctx->stack->CheckAllocSizeValid(estimate_alloc_size);
+    RETURN_NOT_OK(ctx->stack->CheckAllocOverflow(alloc_size));
   }
 
   auto hash_temp_buf = util::TempVectorHolder<uint32_t>(ctx->stack, alloc_batch_size);
@@ -471,6 +479,11 @@ void Hashing32::HashMultiColumn(const std::vector<KeyColumnArray>& cols,
 
     first_row += batch_size_next;
   }
+
+  if (temp_stack) {
+    ctx->stack = nullptr;
+  }
+  return Status::OK();
 }
 
 Status Hashing32::HashBatch(const ExecBatch& key_batch, uint32_t* hashes,
@@ -483,9 +496,7 @@ Status Hashing32::HashBatch(const ExecBatch& key_batch, uint32_t* hashes,
   LightContext ctx;
   ctx.hardware_flags = hardware_flags;
   ctx.stack = temp_stack;
-
-  HashMultiColumn(column_arrays, &ctx, hashes);
-  return Status::OK();
+  return HashMultiColumn(column_arrays, &ctx, hashes);
 }
 
 inline uint64_t Hashing64::Avalanche(uint64_t acc) {
@@ -833,20 +844,27 @@ void Hashing64::HashFixed(bool combine_hashes, uint32_t num_keys, uint64_t key_l
   }
 }
 
-void Hashing64::HashMultiColumn(const std::vector<KeyColumnArray>& cols,
-                                LightContext* ctx, uint64_t* hashes) {
+Status Hashing64::HashMultiColumn(const std::vector<KeyColumnArray>& cols,
+                                  LightContext* ctx, uint64_t* hashes) {
   uint32_t num_rows = static_cast<uint32_t>(cols[0].length());
 
   constexpr uint32_t max_batch_size = util::MiniBatch::kMiniBatchLength;
-  const uint32_t alloc_batch_size = std::min(num_rows, max_batch_size);
-  const uint64_t estimate_alloc_size = EstimateBatchStackSize<uint64_t>(alloc_batch_size);
+  const auto alloc_batch_size = std::min(num_rows, max_batch_size);
 
-  util::TempVectorStack temp_stack;
+  // Pre-calculate the allocation size needed in the TempVectorStack for
+  // null_hash_temp_buf and null_indices_buf.
+  const auto alloc_for_null_hash_temp_buf =
+      util::TempVectorStack::EstimateAllocSize(alloc_batch_size * sizeof(uint64_t));
+  const auto alloc_for_null_indices_buf =
+      util::TempVectorStack::EstimateAllocSize(alloc_batch_size * sizeof(uint16_t));
+  const auto alloc_size = alloc_for_null_hash_temp_buf + alloc_for_null_indices_buf;
+
+  std::shared_ptr<util::TempVectorStack> temp_stack(nullptr);
   if (!ctx->stack) {
-    ARROW_CHECK_OK(temp_stack.Init(default_memory_pool(), estimate_alloc_size));
-    ctx->stack = &temp_stack;
+    temp_stack = std::make_shared<util::TempVectorStack>();
+    RETURN_NOT_OK(temp_stack->Init(default_memory_pool(), alloc_size));
+    ctx->stack = temp_stack.get();
   } else {
-    ctx->stack->CheckAllocSizeValid(estimate_alloc_size);
+    RETURN_NOT_OK(ctx->stack->CheckAllocOverflow(alloc_size));
  }
 
   auto null_indices_buf = util::TempVectorHolder<uint16_t>(ctx->stack, alloc_batch_size);
@@ -920,6 +938,11 @@ void Hashing64::HashMultiColumn(const std::vector<KeyColumnArray>& cols,
 
     first_row += batch_size_next;
   }
+
+  if (temp_stack) {
+    ctx->stack = nullptr;
+  }
+  return Status::OK();
 }
 
 Status Hashing64::HashBatch(const ExecBatch& key_batch, uint64_t* hashes,
@@ -932,9 +955,7 @@ Status Hashing64::HashBatch(const ExecBatch& key_batch, uint64_t* hashes,
   LightContext ctx;
   ctx.hardware_flags = hardware_flags;
   ctx.stack = temp_stack;
-
-  HashMultiColumn(column_arrays, &ctx, hashes);
-  return Status::OK();
+  return HashMultiColumn(column_arrays, &ctx, hashes);
 }
 
 }  // namespace compute
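
The key_hash.cc hunks above change HashMultiColumn's contract in two ways: failure to reserve scratch space now surfaces as a Status instead of aborting via ARROW_CHECK_OK, and a null LightContext::stack makes the function allocate a right-sized TempVectorStack of its own (detached again before returning). A minimal caller sketch under those assumptions, not part of the patch; the wrapper name HashWithOwnedStack is hypothetical and hardware_flags is left at 0 (no SIMD) for brevity:

    #include <cstdint>
    #include <vector>

    #include "arrow/compute/key_hash.h"
    #include "arrow/compute/light_array.h"
    #include "arrow/compute/util.h"
    #include "arrow/status.h"

    namespace cp = arrow::compute;

    arrow::Status HashWithOwnedStack(const std::vector<cp::KeyColumnArray>& cols,
                                     std::vector<uint64_t>* hashes) {
      if (cols.empty()) return arrow::Status::Invalid("no key columns");
      cp::LightContext ctx;
      ctx.hardware_flags = 0;  // or CpuInfo::GetInstance()->hardware_flags()
      ctx.stack = nullptr;     // null stack: HashMultiColumn allocates its own
      hashes->resize(static_cast<size_t>(cols[0].length()));
      // An overflow of a caller-provided stack now comes back as a Status
      // rather than crashing, so plain propagation is enough.
      return cp::Hashing64::HashMultiColumn(cols, &ctx, hashes->data());
    }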
diff --git a/cpp/src/arrow/compute/key_hash.h b/cpp/src/arrow/compute/key_hash.h
index dcb3f867980f9..be4809c6992f1 100644
--- a/cpp/src/arrow/compute/key_hash.h
+++ b/cpp/src/arrow/compute/key_hash.h
@@ -45,7 +45,7 @@ class ARROW_EXPORT Hashing32 {
   friend void TestBloomSmall(BloomFilterBuildStrategy, int64_t, int, bool, bool);
 
  public:
-  static void HashMultiColumn(const std::vector<KeyColumnArray>& cols, LightContext* ctx,
+  static Status HashMultiColumn(const std::vector<KeyColumnArray>& cols, LightContext* ctx,
                               uint32_t* out_hash);
 
   static Status HashBatch(const ExecBatch& key_batch, uint32_t* hashes,
@@ -158,7 +158,7 @@ class ARROW_EXPORT Hashing64 {
   friend void TestBloomSmall(BloomFilterBuildStrategy, int64_t, int, bool, bool);
 
  public:
-  static void HashMultiColumn(const std::vector<KeyColumnArray>& cols, LightContext* ctx,
+  static Status HashMultiColumn(const std::vector<KeyColumnArray>& cols, LightContext* ctx,
                               uint64_t* hashes);
 
   static Status HashBatch(const ExecBatch& key_batch, uint64_t* hashes,
@@ -219,24 +219,5 @@ class ARROW_EXPORT Hashing64 {
                             const uint8_t* keys, uint64_t* hashes);
 };
 
-template <typename T>
-static int64_t EstimateBatchStackSize(int32_t batch_size) {
-  if (sizeof(T) == sizeof(uint32_t)) {
-    const int64_t alloc_for_hash_temp_buf =
-        util::TempVectorStack::EstimateAllocSize(batch_size * sizeof(uint32_t));
-    const int64_t alloc_for_null_hash_temp_buf = alloc_for_hash_temp_buf;
-    const int64_t alloc_for_null_indices_buf =
-        util::TempVectorStack::EstimateAllocSize(batch_size * sizeof(uint16_t));
-    return alloc_for_hash_temp_buf + alloc_for_null_hash_temp_buf +
-           alloc_for_null_indices_buf;
-  } else {
-    const int64_t alloc_for_null_hash_temp_buf =
-        util::TempVectorStack::EstimateAllocSize(batch_size * sizeof(uint64_t));
-    const int64_t alloc_for_null_indices_buf =
-        util::TempVectorStack::EstimateAllocSize(batch_size * sizeof(uint16_t));
-    return alloc_for_null_hash_temp_buf + alloc_for_null_indices_buf;
-  }
-}
-
 }  // namespace compute
 }  // namespace arrow
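
Since EstimateBatchStackSize<T> is removed from the header, a caller that pre-sizes its own TempVectorStack can rebuild the same figure from TempVectorStack::EstimateAllocSize, one term per temporary buffer, mirroring the arithmetic now inlined in the two HashMultiColumn bodies. A sketch; the helper name EstimateHash32StackSize is illustrative, not part of the patch:

    #include <cstdint>

    #include "arrow/compute/util.h"

    // Scratch space Hashing32::HashMultiColumn reserves per mini-batch:
    // hash_temp_buf and null_hash_temp_buf (uint32_t each) plus
    // null_indices_buf (uint16_t).
    int64_t EstimateHash32StackSize(int32_t batch_size) {
      const int64_t hash_temp =
          arrow::util::TempVectorStack::EstimateAllocSize(batch_size * sizeof(uint32_t));
      const int64_t null_indices =
          arrow::util::TempVectorStack::EstimateAllocSize(batch_size * sizeof(uint16_t));
      return 2 * hash_temp + null_indices;
    }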
diff --git a/cpp/src/arrow/compute/key_hash_test.cc b/cpp/src/arrow/compute/key_hash_test.cc
index 5584d4497396d..ee60e9c984b31 100644
--- a/cpp/src/arrow/compute/key_hash_test.cc
+++ b/cpp/src/arrow/compute/key_hash_test.cc
@@ -313,7 +313,7 @@ TEST(VectorHash, FixedLengthTailByteSafety) {
 
 TEST(HashBatch, AllocTempStackAsNeeded) {
   auto arr = arrow::ArrayFromJSON(arrow::int32(), "[9,2,6]");
-  const int32_t batch_size = static_cast<int32_t>(arr->length());
+  const auto batch_size = static_cast<int32_t>(arr->length());
   arrow::compute::ExecBatch exec_batch({arr}, batch_size);
   auto ctx = arrow::compute::default_exec_context();
   std::vector<arrow::compute::KeyColumnArray> temp_column_arrays;
@@ -324,11 +324,17 @@
       exec_batch, h1.data(), temp_column_arrays, ctx->cpu_info()->hardware_flags(),
       nullptr, 0, batch_size));
 
-  // alloc stack as HashBatch needed.
   util::TempVectorStack stack;
-  ASSERT_OK(
-      stack.Init(default_memory_pool(), EstimateBatchStackSize<uint32_t>(batch_size)));
   std::vector<uint32_t> h2(batch_size);
+
+  // An undersized caller-provided stack must make HashBatch fail cleanly.
+  ASSERT_OK(stack.Init(default_memory_pool(), batch_size));
+  ASSERT_NOT_OK(arrow::compute::Hashing32::HashBatch(
+      exec_batch, h2.data(), temp_column_arrays, ctx->cpu_info()->hardware_flags(),
+      &stack, 0, batch_size));
+
+  // With a sufficiently large stack, HashBatch succeeds.
+  ASSERT_OK(stack.Init(default_memory_pool(), 1024));
   ASSERT_OK(arrow::compute::Hashing32::HashBatch(
       exec_batch, h2.data(), temp_column_arrays, ctx->cpu_info()->hardware_flags(),
       &stack, 0, batch_size));
diff --git a/cpp/src/arrow/compute/row/grouper.cc b/cpp/src/arrow/compute/row/grouper.cc
index 5e23eda16fda2..d5ba6f9fa402b 100644
--- a/cpp/src/arrow/compute/row/grouper.cc
+++ b/cpp/src/arrow/compute/row/grouper.cc
@@ -680,8 +680,8 @@ struct GrouperFastImpl : public Grouper {
       encoder_.PrepareEncodeSelected(start_row, batch_size_next, cols_);
 
       // Compute hash
-      Hashing32::HashMultiColumn(encoder_.batch_all_cols(), &encode_ctx_,
-                                 minibatch_hashes_.data());
+      RETURN_NOT_OK(Hashing32::HashMultiColumn(encoder_.batch_all_cols(), &encode_ctx_,
+                                               minibatch_hashes_.data()));
 
       // Map
       auto match_bitvector =
diff --git a/cpp/src/arrow/compute/util.cc b/cpp/src/arrow/compute/util.cc
index 078292e7827f5..199e0ddace45c 100644
--- a/cpp/src/arrow/compute/util.cc
+++ b/cpp/src/arrow/compute/util.cc
@@ -32,10 +32,10 @@ using internal::CpuInfo;
 namespace util {
 
 void TempVectorStack::alloc(uint32_t num_bytes, uint8_t** data, int* id) {
-  int64_t new_top = top_ + PaddedAllocationSize(num_bytes) + 2 * sizeof(uint64_t);
-  // Stack overflow check (see GH-39582).
+  const auto estimate_size = EstimateAllocSize(num_bytes);
   // XXX cannot return a regular Status because most consumers do not either.
-  CheckAllocSizeValid(new_top);
+  ARROW_DCHECK_OK(CheckAllocOverflow(estimate_size));
+  int64_t new_top = top_ + estimate_size;
   *data = buffer_->mutable_data() + top_ + sizeof(uint64_t);
   // We set 8 bytes before the beginning of the allocated range and
   // 8 bytes after the end to check for stack overflow (which would
   // result in those known bytes being corrupted).
@@ -58,11 +58,13 @@ void TempVectorStack::release(int id, uint32_t num_bytes) {
   --num_vectors_;
 }
 
-void TempVectorStack::CheckAllocSizeValid(int64_t estimate_alloc_size) {
-  ARROW_DCHECK_LE(estimate_alloc_size, buffer_size_)
-      << "TempVectorStack alloc overflow."
-         "(Actual "
-      << buffer_size_ << "Bytes, expect " << estimate_alloc_size << "Bytes)";
+Status TempVectorStack::CheckAllocOverflow(int64_t alloc_size) {
+  // Stack overflow check (see GH-39582).
+  if ((alloc_size + top_) > buffer_size_) {
+    return Status::Invalid("TempVectorStack alloc overflow. (Buffer size ",
+                           buffer_size_, " bytes, requested ", alloc_size, " bytes)");
+  }
+  return Status::OK();
 }
 
 namespace bit_util {
diff --git a/cpp/src/arrow/compute/util.h b/cpp/src/arrow/compute/util.h
index c2334db032884..fab86f44dcd34 100644
--- a/cpp/src/arrow/compute/util.h
+++ b/cpp/src/arrow/compute/util.h
@@ -89,7 +89,7 @@ class ARROW_EXPORT TempVectorStack {
   Status Init(MemoryPool* pool, int64_t size) {
     num_vectors_ = 0;
     top_ = 0;
-    buffer_size_ = PaddedAllocationSize(size) + 2 * sizeof(uint64_t);
+    buffer_size_ = EstimateAllocSize(size);
    ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateResizableBuffer(size, pool));
     // Ensure later operations don't accidentally read uninitialized memory.
     std::memset(buffer->mutable_data(), 0xFF, size);
@@ -101,8 +101,7 @@ class ARROW_EXPORT TempVectorStack {
     return PaddedAllocationSize(size) + 2 * sizeof(uint64_t);
   }
 
-  int64_t StackBufferSize() const { return buffer_size_; }
-  void CheckAllocSizeValid(int64_t estimate_alloc_size);
+  Status CheckAllocOverflow(int64_t alloc_size);
 
  private:
   static int64_t PaddedAllocationSize(int64_t num_bytes) {
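
With the util.cc/util.h changes, stack overflow becomes a recoverable error: CheckAllocOverflow compares the padded request against the buffer capacity and returns Status::Invalid, while alloc() keeps only a debug-build assertion on the same condition. An end-to-end sketch of the resulting behavior, closely following the updated TEST(HashBatch, AllocTempStackAsNeeded); arrow/testing/gtest_util.h is assumed for ArrayFromJSON, and the free function name HashThreeInts is illustrative:

    #include <cstdint>
    #include <vector>

    #include "arrow/compute/exec.h"
    #include "arrow/compute/key_hash.h"
    #include "arrow/compute/light_array.h"
    #include "arrow/compute/util.h"
    #include "arrow/memory_pool.h"
    #include "arrow/status.h"
    #include "arrow/testing/gtest_util.h"

    arrow::Status HashThreeInts() {
      auto arr = arrow::ArrayFromJSON(arrow::int32(), "[9,2,6]");
      const auto batch_size = static_cast<int32_t>(arr->length());
      arrow::compute::ExecBatch batch({arr}, batch_size);
      std::vector<arrow::compute::KeyColumnArray> temp_column_arrays;
      std::vector<uint32_t> hashes(batch_size);

      // An undersized caller-provided stack now yields Status::Invalid from
      // CheckAllocOverflow instead of a debug-build crash.
      arrow::util::TempVectorStack small_stack;
      ARROW_RETURN_NOT_OK(small_stack.Init(arrow::default_memory_pool(), batch_size));
      arrow::Status st = arrow::compute::Hashing32::HashBatch(
          batch, hashes.data(), temp_column_arrays, /*hardware_flags=*/0,
          &small_stack, /*start_row=*/0, batch_size);
      if (!st.IsInvalid()) {
        return arrow::Status::UnknownError("expected TempVectorStack overflow");
      }

      // Passing a null stack makes HashBatch size and own one internally.
      return arrow::compute::Hashing32::HashBatch(
          batch, hashes.data(), temp_column_arrays, /*hardware_flags=*/0,
          /*temp_stack=*/nullptr, /*start_row=*/0, batch_size);
    }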