diff --git a/cpp/src/arrow/compute/kernels/vector_selection_filter_internal.cc b/cpp/src/arrow/compute/kernels/vector_selection_filter_internal.cc
index bf67a474f31e2..b7c70c1a8bf24 100644
--- a/cpp/src/arrow/compute/kernels/vector_selection_filter_internal.cc
+++ b/cpp/src/arrow/compute/kernels/vector_selection_filter_internal.cc
@@ -894,18 +894,23 @@ Status ExtensionFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult
 }
 
 // Transform filter to selection indices and then use Take.
-Status FilterWithTakeExec(const ArrayKernelExec& take_exec, KernelContext* ctx,
+Status FilterWithTakeExec(TakeKernelExec take_aaa_exec, KernelContext* ctx,
                           const ExecSpan& batch, ExecResult* out) {
-  std::shared_ptr<ArrayData> indices;
+  std::shared_ptr<ArrayData> indices_data;
   RETURN_NOT_OK(GetTakeIndices(batch[1].array,
                                FilterState::Get(ctx).null_selection_behavior,
                                ctx->memory_pool())
-                    .Value(&indices));
+                    .Value(&indices_data));
+
   KernelContext take_ctx(*ctx);
   TakeState state{TakeOptions::NoBoundsCheck()};
   take_ctx.SetState(&state);
-  ExecSpan take_batch({batch[0], ArraySpan(*indices)}, batch.length);
-  return take_exec(&take_ctx, take_batch, out);
+
+  ValuesSpan values(batch[0].array);
+  std::shared_ptr<ArrayData> out_data = out->array_data();
+  RETURN_NOT_OK(take_aaa_exec(&take_ctx, values, *indices_data, &out_data));
+  out->value = std::move(out_data);
+  return Status::OK();
 }
 
 // Due to the special treatment with their Take kernels, we filter Struct and SparseUnion
diff --git a/cpp/src/arrow/compute/kernels/vector_selection_internal.cc b/cpp/src/arrow/compute/kernels/vector_selection_internal.cc
index 04e725496c34f..4a2d5ece192b4 100644
--- a/cpp/src/arrow/compute/kernels/vector_selection_internal.cc
+++ b/cpp/src/arrow/compute/kernels/vector_selection_internal.cc
@@ -968,69 +968,80 @@ Status MapFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out)
 
 namespace {
 
-template <typename Impl>
-Status TakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
+template <typename SelectionImpl>
+Status TakeAAAExec(KernelContext* ctx, const ValuesSpan& values, const ArraySpan& indices,
+                   std::shared_ptr<ArrayData>* out) {
+  DCHECK(!values.is_chunked())
+      << "TakeAAAExec kernels can't be called with chunked array values";
   if (TakeState::Get(ctx).boundscheck) {
-    RETURN_NOT_OK(CheckIndexBounds(batch[1].array, batch[0].length()));
+    RETURN_NOT_OK(CheckIndexBounds(indices, values.length()));
   }
-  Impl kernel(ctx, batch, /*output_length=*/batch[1].length(), out);
+  SelectionImpl kernel(ctx, values, indices, out);
   return kernel.ExecTake();
 }
 
 }  // namespace
 
-Status VarBinaryTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
-  return TakeExec<VarBinaryImpl<BinaryType>>(ctx, batch, out);
+Status VarBinaryTakeExec(KernelContext* ctx, const ValuesSpan& values,
+                         const ArraySpan& indices, std::shared_ptr<ArrayData>* out) {
+  return TakeAAAExec<VarBinaryImpl<BinaryType>>(ctx, values, indices, out);
 }
 
-Status LargeVarBinaryTakeExec(KernelContext* ctx, const ExecSpan& batch,
-                              ExecResult* out) {
-  return TakeExec<VarBinaryImpl<LargeBinaryType>>(ctx, batch, out);
+Status LargeVarBinaryTakeExec(KernelContext* ctx, const ValuesSpan& values,
+                              const ArraySpan& indices, std::shared_ptr<ArrayData>* out) {
+  return TakeAAAExec<VarBinaryImpl<LargeBinaryType>>(ctx, values, indices, out);
 }
 
-Status ListTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
-  return TakeExec<ListImpl<ListType>>(ctx, batch, out);
+Status ListTakeExec(KernelContext* ctx, const ValuesSpan& values,
+                    const ArraySpan& indices, std::shared_ptr<ArrayData>* out) {
+  return TakeAAAExec<ListImpl<ListType>>(ctx, values, indices, out);
 }
 
-Status LargeListTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
-  return TakeExec<ListImpl<LargeListType>>(ctx, batch, out);
+Status LargeListTakeExec(KernelContext* ctx, const ValuesSpan& values,
+                         const ArraySpan& indices, std::shared_ptr<ArrayData>* out) {
+  return TakeAAAExec<ListImpl<LargeListType>>(ctx, values, indices, out);
 }
 
-Status ListViewTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
-  return TakeExec<ListViewImpl<ListViewType>>(ctx, batch, out);
+Status ListViewTakeExec(KernelContext* ctx, const ValuesSpan& values,
+                        const ArraySpan& indices, std::shared_ptr<ArrayData>* out) {
+  return TakeAAAExec<ListViewImpl<ListViewType>>(ctx, values, indices, out);
 }
 
-Status LargeListViewTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
-  return TakeExec<ListViewImpl<LargeListViewType>>(ctx, batch, out);
+Status LargeListViewTakeExec(KernelContext* ctx, const ValuesSpan& values,
+                             const ArraySpan& indices, std::shared_ptr<ArrayData>* out) {
+  return TakeAAAExec<ListViewImpl<LargeListViewType>>(ctx, values, indices, out);
 }
 
-Status FSLTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
-  const ArraySpan& values = batch[0].array;
-
+Status FSLTakeExec(KernelContext* ctx, const ValuesSpan& values, const ArraySpan& indices,
+                   std::shared_ptr<ArrayData>* out) {
   // If a FixedSizeList wraps a fixed-width type we can, in some cases, use
   // FixedWidthTakeExec for a fixed-size list array.
-  if (util::IsFixedWidthLike(values,
+  if (util::IsFixedWidthLike(values.array(),
                              /*force_null_count=*/true,
                              /*exclude_bool_and_dictionary=*/true)) {
-    return FixedWidthTakeExec(ctx, batch, out);
+    return FixedWidthTakeExec(ctx, values, indices, out);
   }
-  return TakeExec<FSLImpl>(ctx, batch, out);
+  return TakeAAAExec<FSLImpl>(ctx, values, indices, out);
 }
 
-Status DenseUnionTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
-  return TakeExec<DenseUnionImpl>(ctx, batch, out);
+Status DenseUnionTakeExec(KernelContext* ctx, const ValuesSpan& values,
+                          const ArraySpan& indices, std::shared_ptr<ArrayData>* out) {
+  return TakeAAAExec<DenseUnionImpl>(ctx, values, indices, out);
 }
 
-Status SparseUnionTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
-  return TakeExec<SparseUnionImpl>(ctx, batch, out);
+Status SparseUnionTakeExec(KernelContext* ctx, const ValuesSpan& values,
+                           const ArraySpan& indices, std::shared_ptr<ArrayData>* out) {
+  return TakeAAAExec<SparseUnionImpl>(ctx, values, indices, out);
 }
 
-Status StructTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
-  return TakeExec<StructImpl>(ctx, batch, out);
+Status StructTakeExec(KernelContext* ctx, const ValuesSpan& values,
+                      const ArraySpan& indices, std::shared_ptr<ArrayData>* out) {
+  return TakeAAAExec<StructImpl>(ctx, values, indices, out);
 }
 
-Status MapTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
-  return TakeExec<ListImpl<MapType>>(ctx, batch, out);
+Status MapTakeExec(KernelContext* ctx, const ValuesSpan& values, const ArraySpan& indices,
+                   std::shared_ptr<ArrayData>* out) {
+  return TakeAAAExec<ListImpl<MapType>>(ctx, values, indices, out);
 }
 
 }  // namespace compute::internal
diff --git a/cpp/src/arrow/compute/kernels/vector_selection_internal.h b/cpp/src/arrow/compute/kernels/vector_selection_internal.h
index 99278f0046589..6a4a4fe868649 100644
--- a/cpp/src/arrow/compute/kernels/vector_selection_internal.h
+++ b/cpp/src/arrow/compute/kernels/vector_selection_internal.h
@@ -100,6 +100,20 @@ class ValuesSpan {
   }
 };
 
+/// \brief Type for a single "array_take" kernel function.
+///
+/// Instead of implementing both `ArrayKernelExec` and `ChunkedExec` typed
+/// functions for each configuration of `array_take` parameters, we use
+/// templates wrapping `TakeKernelExec` functions to expose exec functions
+/// that can be registered in the kernel registry.
+///
+/// A `TakeKernelExec` always returns a single array, which is the result of
+/// taking values from a single array (AA->A) or multiple arrays (CA->A). The
+/// wrappers take care of converting the output of a CA call to C or calling
+/// the kernel multiple times to process a CC call.
+using TakeKernelExec = Status (*)(KernelContext*, const ValuesSpan&, const ArraySpan&,
+                                  std::shared_ptr<ArrayData>*);
+
 struct SelectionKernelData {
   SelectionKernelData(InputType value_type, InputType selection_type,
                       ArrayKernelExec exec,
@@ -149,19 +163,33 @@ Status FSLFilterExec(KernelContext*, const ExecSpan&, ExecResult*);
 Status DenseUnionFilterExec(KernelContext*, const ExecSpan&, ExecResult*);
 Status MapFilterExec(KernelContext*, const ExecSpan&, ExecResult*);
 
-Status VarBinaryTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
-Status LargeVarBinaryTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
-Status FixedWidthTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
-Status FixedWidthTakeChunkedExec(KernelContext*, const ExecBatch&, Datum*);
-Status ListTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
-Status LargeListTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
-Status ListViewTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
-Status LargeListViewTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
-Status FSLTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
-Status DenseUnionTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
-Status SparseUnionTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
-Status StructTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
-Status MapTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
+// Take kernels compatible with the TakeKernelExec signature
+Status VarBinaryTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&,
+                         std::shared_ptr<ArrayData>*);
+Status LargeVarBinaryTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&,
+                              std::shared_ptr<ArrayData>*);
+Status FixedWidthTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&,
+                          std::shared_ptr<ArrayData>*);
+Status FixedWidthTakeChunkedExec(KernelContext*, const ValuesSpan&, const ArraySpan&,
+                                 std::shared_ptr<ArrayData>*);
+Status ListTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&,
+                    std::shared_ptr<ArrayData>*);
+Status LargeListTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&,
+                         std::shared_ptr<ArrayData>*);
+Status ListViewTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&,
+                        std::shared_ptr<ArrayData>*);
+Status LargeListViewTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&,
+                             std::shared_ptr<ArrayData>*);
+Status FSLTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&,
+                   std::shared_ptr<ArrayData>*);
+Status DenseUnionTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&,
+                          std::shared_ptr<ArrayData>*);
+Status SparseUnionTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&,
+                           std::shared_ptr<ArrayData>*);
+Status StructTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&,
+                      std::shared_ptr<ArrayData>*);
+Status MapTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&,
+                   std::shared_ptr<ArrayData>*);
 
 }  // namespace compute::internal
 }  // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc b/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc
index b3623a1baaca9..e4cc24fcd5f34 100644
--- a/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc
+++ b/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc
@@ -604,8 +604,10 @@ Status TakeIndexDispatch(KernelContext* ctx, const ValuesSpan& values,
   }
 }
 
-Status FixedWidthTakeExecImpl(KernelContext* ctx, const ValuesSpan& values,
-                              const ArraySpan& indices, ArrayData* out_arr) {
+}  // namespace
+
+Status FixedWidthTakeExec(KernelContext* ctx, const ValuesSpan& values,
+                          const ArraySpan& indices, std::shared_ptr<ArrayData>* out_arr) {
   if (TakeState::Get(ctx).boundscheck) {
     RETURN_NOT_OK(CheckIndexBounds(indices, values.length()));
   }
@@ -615,42 +617,43 @@ Status FixedWidthTakeExecImpl(KernelContext* ctx, const ValuesSpan& values,
   // allocating the validity bitmap altogether and save time and space.
   const bool allocate_validity = values.MayHaveNulls() || indices.MayHaveNulls();
   RETURN_NOT_OK(util::internal::PreallocateFixedWidthArrayData(
-      ctx, indices.length, /*source=*/values.chunk0(), allocate_validity, out_arr));
+      ctx, indices.length, /*source=*/values.chunk0(), allocate_validity,
+      out_arr->get()));
   switch (util::FixedWidthInBits(*values.type())) {
     case 0:
       DCHECK(values.type()->id() == Type::FIXED_SIZE_BINARY ||
              values.type()->id() == Type::FIXED_SIZE_LIST);
       return TakeIndexDispatch<FixedWidthTakeImpl, std::integral_constant<int, 0>>(
-          ctx, values, indices, out_arr);
+          ctx, values, indices, out_arr->get());
     case 1:
       // Zero-initialize the data buffer for the output array when the bit-width is 1
       // (e.g. Boolean array) to avoid having to ClearBit on every null element.
       // This might be profitable for other types as well, but we take the most
      // conservative approach for now.
-      memset(out_arr->buffers[1]->mutable_data(), 0, out_arr->buffers[1]->size());
+      memset((*out_arr)->buffers[1]->mutable_data(), 0, (*out_arr)->buffers[1]->size());
       return TakeIndexDispatch<
          FixedWidthTakeImpl, std::integral_constant<int, 1>,
          /*OutputIsZeroInitialized=*/
-          std::true_type>(ctx, values, indices, out_arr);
+          std::true_type>(ctx, values, indices, out_arr->get());
    case 8:
      return TakeIndexDispatch<FixedWidthTakeImpl, std::integral_constant<int, 8>>(
-          ctx, values, indices, out_arr);
+          ctx, values, indices, out_arr->get());
    case 16:
      return TakeIndexDispatch<FixedWidthTakeImpl, std::integral_constant<int, 16>>(
-          ctx, values, indices, out_arr);
+          ctx, values, indices, out_arr->get());
    case 32:
      return TakeIndexDispatch<FixedWidthTakeImpl, std::integral_constant<int, 32>>(
-          ctx, values, indices, out_arr);
+          ctx, values, indices, out_arr->get());
    case 64:
      return TakeIndexDispatch<FixedWidthTakeImpl, std::integral_constant<int, 64>>(
-          ctx, values, indices, out_arr);
+          ctx, values, indices, out_arr->get());
    case 128:
      // For INTERVAL_MONTH_DAY_NANO, DECIMAL128
      return TakeIndexDispatch<FixedWidthTakeImpl, std::integral_constant<int, 128>>(
-          ctx, values, indices, out_arr);
+          ctx, values, indices, out_arr->get());
    case 256:
      // For DECIMAL256
      return TakeIndexDispatch<FixedWidthTakeImpl, std::integral_constant<int, 256>>(
-          ctx, values, indices, out_arr);
+          ctx, values, indices, out_arr->get());
  }
  if (ARROW_PREDICT_TRUE(values.type()->id() == Type::FIXED_SIZE_BINARY ||
                         values.type()->id() == Type::FIXED_SIZE_LIST)) {
@@ -660,67 +663,57 @@ Status FixedWidthTakeExecImpl(KernelContext* ctx, const ValuesSpan& values,
    return TakeIndexDispatch<FixedWidthTakeImpl, std::integral_constant<int, 8>,
                             /*OutputIsZeroInitialized=*/std::false_type,
-                            /*WithFactor=*/std::true_type>(ctx, values, indices, out_arr,
+                            /*WithFactor=*/std::true_type>(ctx, values, indices,
+                                                           out_arr->get(),
                                                            /*factor=*/byte_width);
  }
  return Status::NotImplemented("Unsupported primitive type for take: ", *values.type());
 }
 
-}  // namespace
-
-Status FixedWidthTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
-  ValuesSpan values{batch[0].array};
-  auto* out_arr = out->array_data().get();
-  return FixedWidthTakeExecImpl(ctx, values, batch[1].array, out_arr);
-}
-
-Status FixedWidthTakeChunkedExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
-  ValuesSpan values{batch[0].chunked_array()};
-  auto& indices = batch[1].array();
-  auto* out_arr = out->mutable_array();
-  return FixedWidthTakeExecImpl(ctx, values, *indices, out_arr);
-}
-
 namespace {
 
 // ----------------------------------------------------------------------
 // Null take
 
-Status NullTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
+Status NullTakeExec(KernelContext* ctx, const ValuesSpan& values,
+                    const ArraySpan& indices, std::shared_ptr<ArrayData>* out) {
   if (TakeState::Get(ctx).boundscheck) {
-    RETURN_NOT_OK(CheckIndexBounds(batch[1].array, batch[0].length()));
+    RETURN_NOT_OK(CheckIndexBounds(indices, values.length()));
   }
   // batch.length doesn't take into account the take indices
-  auto new_length = batch[1].array.length;
-  out->value = std::make_shared<NullArray>(new_length)->data();
+  auto new_length = indices.length;
+  *out = std::make_shared<NullArray>(new_length)->data();
   return Status::OK();
 }
 
 // ----------------------------------------------------------------------
 // Dictionary take
 
-Status DictionaryTake(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
-  DictionaryArray values(batch[0].array.ToArrayData());
+Status DictionaryTake(KernelContext* ctx, const ValuesSpan& values,
+                      const ArraySpan& indices, std::shared_ptr<ArrayData>* out) {
+  DictionaryArray values_dict(values.array().ToArrayData());
   Datum result;
-  RETURN_NOT_OK(Take(Datum(values.indices()), batch[1].array.ToArrayData(),
+  RETURN_NOT_OK(Take(Datum(values_dict.indices()), indices.ToArrayData(),
                      TakeState::Get(ctx), ctx->exec_context())
                     .Value(&result));
-  DictionaryArray taken_values(values.type(), result.make_array(), values.dictionary());
-  out->value = taken_values.data();
+  DictionaryArray taken_values(values_dict.type(), result.make_array(),
+                               values_dict.dictionary());
+  *out = taken_values.data();
   return Status::OK();
 }
 
 // ----------------------------------------------------------------------
 // Extension take
 
-Status ExtensionTake(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
-  ExtensionArray values(batch[0].array.ToArrayData());
+Status ExtensionTake(KernelContext* ctx, const ValuesSpan& values,
+                     const ArraySpan& indices, std::shared_ptr<ArrayData>* out) {
+  ExtensionArray values_data(values.array().ToArrayData());
   Datum result;
-  RETURN_NOT_OK(Take(Datum(values.storage()), batch[1].array.ToArrayData(),
+  RETURN_NOT_OK(Take(Datum(values_data.storage()), indices.ToArrayData(),
                      TakeState::Get(ctx), ctx->exec_context())
                     .Value(&result));
-  ExtensionArray taken_values(values.type(), result.make_array());
-  out->value = taken_values.data();
+  ExtensionArray taken_values(values_data.type(), result.make_array());
+  *out = taken_values.data();
   return Status::OK();
 }
 
@@ -899,49 +892,51 @@ Result<std::shared_ptr<Array>> ChunkedArrayAsArray(
   }
 }
 
-Status CallAAAKernel(ArrayKernelExec take_aaa_exec, KernelContext* ctx,
-                     std::shared_ptr<ArrayData> values,
-                     std::shared_ptr<ArrayData> indices, Datum* out) {
-  int64_t batch_length = values->length;
-  std::vector<Datum> args = {std::move(values), std::move(indices)};
-  ExecBatch array_array_batch(std::move(args), batch_length);
+Status CallXAAKernel(TakeKernelExec take_xaa_exec, KernelContext* ctx,
+                     const ValuesSpan& values, const ArraySpan& indices, Datum* out) {
   DCHECK_EQ(out->kind(), Datum::ARRAY);
-  ExecSpan exec_span{array_array_batch};
-  ExecResult result;
-  result.value = out->array();
-  RETURN_NOT_OK(take_aaa_exec(ctx, exec_span, &result));
-  DCHECK(result.is_array_data());
-  out->value = result.array_data();
+  auto out_arr = out->array();
+  RETURN_NOT_OK(take_xaa_exec(ctx, values, indices, &out_arr));
+  out->value = std::move(out_arr);
   return Status::OK();
 }
 
-Status CallCAAKernel(VectorKernel::ChunkedExec take_caa_exec, KernelContext* ctx,
-                     std::shared_ptr<ChunkedArray> values,
-                     std::shared_ptr<ArrayData> indices, Datum* out) {
-  int64_t batch_length = values->length();
-  std::vector<Datum> args = {std::move(values), std::move(indices)};
-  ExecBatch chunked_array_array_batch(std::move(args), batch_length);
-  DCHECK_EQ(out->kind(), Datum::ARRAY);
-  return take_caa_exec(ctx, chunked_array_array_batch, out);
-}
-
-Status TakeACCChunkedExec(ArrayKernelExec take_aaa_exec, KernelContext* ctx,
+Status TakeACCChunkedExec(TakeKernelExec take_aaa_exec, KernelContext* ctx,
                           const ExecBatch& batch, Datum* out) {
-  auto& values = batch.values[0].array();
+  ValuesSpan values{*batch.values[0].array()};
   auto& indices = batch.values[1].chunked_array();
   auto num_chunks = indices->num_chunks();
   std::vector<std::shared_ptr<Array>> new_chunks(num_chunks);
   for (int i = 0; i < num_chunks; i++) {
     // Take with that indices chunk
-    auto& indices_chunk = indices->chunk(i)->data();
-    Datum result = PrepareOutput(batch, values->length);
-    RETURN_NOT_OK(CallAAAKernel(take_aaa_exec, ctx, values, indices_chunk, &result));
+    auto& indices_chunk_data = indices->chunk(i)->data();
+    ArraySpan indices_chunk{*indices_chunk_data};
+    Datum result = PrepareOutput(batch, values.length());
+    RETURN_NOT_OK(CallXAAKernel(take_aaa_exec, ctx, values, indices_chunk, &result));
     new_chunks[i] = MakeArray(result.array());
   }
-  out->value = std::make_shared<ChunkedArray>(std::move(new_chunks), values->type);
+  out->value = std::make_shared<ChunkedArray>(std::move(new_chunks),
+                                              values.type()->GetSharedPtr());
   return Status::OK();
 }
 
+Status ArrayTakeExec(TakeKernelExec take_aaa_exec, KernelContext* ctx,
+                     const ExecSpan& span, ExecResult* out) {
+  ValuesSpan values{span[0].array};
+  auto& indices = span[1].array;
+  std::shared_ptr<ArrayData> out_arr = out->array_data();
+  RETURN_NOT_OK(take_aaa_exec(ctx, values, indices, &out_arr));
+  out->value = std::move(out_arr);
+  return Status::OK();
+}
+
+template <TakeKernelExec kTakeAAAExec>
+struct ArrayTakeExecFunctor {
+  static Status Exec(KernelContext* ctx, const ExecSpan& span, ExecResult* out) {
+    return ArrayTakeExec(kTakeAAAExec, ctx, span, out);
+  }
+};
+
 /// \brief Generic (slower) VectorKernel::exec_chunked (`CA->C`, `CC->C`, and `AC->C`).
 ///
 /// This function concatenates the chunks of the values and then calls the `AA->A` take
@@ -955,34 +950,34 @@ Status TakeACCChunkedExec(ArrayKernelExec take_aaa_exec, KernelContext* ctx,
 /// `AC->C` cases are trivially delegated to TakeACCChunkedExec without any concatenation.
 ///
 /// \param take_aaa_exec The `AA->A` take kernel to use.
-Status GenericTakeChunkedExec(ArrayKernelExec take_aaa_exec, KernelContext* ctx,
+Status GenericTakeChunkedExec(TakeKernelExec take_aaa_exec, KernelContext* ctx,
                               const ExecBatch& batch, Datum* out) {
   const auto& args = batch.values;
   if (args[0].kind() == Datum::CHUNKED_ARRAY) {
     const auto& values_chunked = args[0].chunked_array();
     ARROW_ASSIGN_OR_RAISE(auto values_array,
                           ChunkedArrayAsArray(values_chunked, ctx->memory_pool()));
+    const ValuesSpan values(*values_array->data());
     if (args[1].kind() == Datum::ARRAY) {
       // CA->C
-      const auto& indices = args[1].array();
+      const auto& indices_data = args[1].array();
+      const ArraySpan indices(*indices_data);
       DCHECK_EQ(values_array->length(), batch.length);
       {
         // AA->A
-        RETURN_NOT_OK(
-            CallAAAKernel(take_aaa_exec, ctx, values_array->data(), indices, out));
+        RETURN_NOT_OK(CallXAAKernel(take_aaa_exec, ctx, values, indices, out));
         out->value = std::make_shared<ChunkedArray>(MakeArray(out->array()));
       }
       return Status::OK();
     } else if (args[1].kind() == Datum::CHUNKED_ARRAY) {
       // CC->C
-      const auto& indices = args[1].chunked_array();
+      const auto& chunked_indices = args[1].chunked_array();
       std::vector<std::shared_ptr<Array>> new_chunks;
-      for (int i = 0; i < indices->num_chunks(); i++) {
+      for (int i = 0; i < chunked_indices->num_chunks(); i++) {
         // AA->A
-        const auto& indices_chunk = indices->chunk(i)->data();
+        const ArraySpan indices_chunk{*chunked_indices->chunk(i)->data()};
         Datum result = PrepareOutput(batch, values_array->length());
-        RETURN_NOT_OK(CallAAAKernel(take_aaa_exec, ctx, values_array->data(),
-                                    indices_chunk, &result));
+        RETURN_NOT_OK(CallXAAKernel(take_aaa_exec, ctx, values, indices_chunk, &result));
         new_chunks.push_back(MakeArray(result.array()));
       }
       DCHECK(out->is_array());
@@ -1008,7 +1003,7 @@ Status GenericTakeChunkedExec(ArrayKernelExec take_aaa_exec, KernelContext* ctx,
                                 args[0].ToString(), ", indices=", args[1].ToString());
 }
 
-template <ArrayKernelExec kTakeAAAExec>
+template <TakeKernelExec kTakeAAAExec>
 struct GenericTakeChunkedExecFunctor {
   static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
     return GenericTakeChunkedExec(kTakeAAAExec, ctx, batch, out);
@@ -1029,11 +1024,9 @@ struct GenericTakeChunkedExecFunctor {
 ///
 /// `AC->C` cases are trivially delegated to TakeACCChunkedExec.
 ///
-/// \param take_aaa_exec The `AA->A` take kernel to use.
-Status SpecialTakeChunkedExec(const ArrayKernelExec take_aaa_exec,
-                              VectorKernel::ChunkedExec take_caa_exec, KernelContext* ctx,
+/// \param take_xaa_exec The `AA->A` and `CA->A` take kernel to use.
+Status SpecialTakeChunkedExec(TakeKernelExec take_xaa_exec, KernelContext* ctx,
                               const ExecBatch& batch, Datum* out) {
-  Datum result = PrepareOutput(batch, batch.length);
   auto* pool = ctx->memory_pool();
   const auto& args = batch.values;
   if (args[0].kind() == Datum::CHUNKED_ARRAY) {
@@ -1048,14 +1041,15 @@ Status SpecialTakeChunkedExec(const ArrayKernelExec take_aaa_exec,
     if (args[1].kind() == Datum::ARRAY) {
       // CA->C
-      const auto& indices = args[1].array();
+      ArraySpan indices(*args[1].array());
       if (single_chunk) {
         // AA->A
         DCHECK_EQ(single_chunk->length(), batch.length);
+        ValuesSpan single_chunk_values(*single_chunk->data());
         // If the ChunkedArray was cheaply converted to a single chunk,
         // we can use the AA->A take kernel directly.
         RETURN_NOT_OK(
-            CallAAAKernel(take_aaa_exec, ctx, single_chunk->data(), indices, out));
+            CallXAAKernel(take_xaa_exec, ctx, single_chunk_values, indices, out));
         out->value = std::make_shared<ChunkedArray>(MakeArray(out->array()));
         return Status::OK();
       }
@@ -1063,24 +1057,29 @@ Status SpecialTakeChunkedExec(const ArrayKernelExec take_aaa_exec,
       // which has a more efficient implementation for this case. At this point,
       // that implementation doesn't have to care about empty or single-chunk
       // ChunkedArrays.
-      RETURN_NOT_OK(take_caa_exec(ctx, batch, &result));
-      out->value = std::make_shared<ChunkedArray>(MakeArray(result.array()));
+      ValuesSpan values(values_chunked);
+      auto out_arr = out->array();
+      RETURN_NOT_OK(take_xaa_exec(ctx, values, indices, &out_arr));
+      out->value = std::make_shared<ChunkedArray>(MakeArray(std::move(out_arr)));
       return Status::OK();
     } else {
+      Datum result;
       // CC->C
       const auto& indices = args[1].chunked_array();
       std::vector<std::shared_ptr<Array>> new_chunks;
       for (int i = 0; i < indices->num_chunks(); i++) {
-        const auto& indices_chunk = indices->chunk(i)->data();
+        const ArraySpan indices_chunk{*indices->chunk(i)->data()};
         result = PrepareOutput(batch, values_chunked->length());
         if (single_chunk) {
+          ValuesSpan single_chunk_values(*single_chunk->data());
          // If the ChunkedArray was cheaply converted to a single chunk,
          // we can use the AA->A take kernel directly.
-          RETURN_NOT_OK(CallAAAKernel(take_aaa_exec, ctx, single_chunk->data(),
+          RETURN_NOT_OK(CallXAAKernel(take_xaa_exec, ctx, single_chunk_values,
                                       indices_chunk, &result));
         } else {
+          ValuesSpan values(values_chunked);
           RETURN_NOT_OK(
-              CallCAAKernel(take_caa_exec, ctx, values_chunked, indices_chunk, &result));
+              CallXAAKernel(take_xaa_exec, ctx, values, indices_chunk, &result));
         }
         new_chunks.push_back(MakeArray(result.array()));
       }
@@ -1095,7 +1094,7 @@ Status SpecialTakeChunkedExec(const ArrayKernelExec take_aaa_exec,
   // everything is wired up correctly.
   if (args[1].kind() == Datum::CHUNKED_ARRAY) {
     // AC->C
-    return TakeACCChunkedExec(take_aaa_exec, ctx, batch, out);
+    return TakeACCChunkedExec(take_xaa_exec, ctx, batch, out);
   } else {
     DCHECK(false) << "Unexpected kind for array_take's exec_chunked kernel: values="
                   << args[0].ToString() << ", indices=" << args[1].ToString();
@@ -1107,10 +1106,10 @@ Status SpecialTakeChunkedExec(const ArrayKernelExec take_aaa_exec,
                                 args[0].ToString(), ", indices=", args[1].ToString());
 }
 
-template <ArrayKernelExec kTakeAAAExec, VectorKernel::ChunkedExec kTakeCAAExec>
+template <TakeKernelExec kTakeXAAExec>
 struct SpecialTakeChunkedExecFunctor {
   static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
-    return SpecialTakeChunkedExec(kTakeAAAExec, kTakeCAAExec, ctx, batch, out);
+    return SpecialTakeChunkedExec(kTakeXAAExec, ctx, batch, out);
   }
 };
 
@@ -1129,40 +1128,54 @@ void PopulateTakeKernels(std::vector<SelectionKernelData>* out) {
   auto take_indices = match::Integer();
 
   *out = {
-      {InputType(match::Primitive()), take_indices, FixedWidthTakeExec,
-       SpecialTakeChunkedExecFunctor<FixedWidthTakeExec, FixedWidthTakeChunkedExec>::Exec},
-      {InputType(match::BinaryLike()), take_indices, VarBinaryTakeExec,
-       GenericTakeChunkedExecFunctor<VarBinaryTakeExec>::Exec},
-      {InputType(match::LargeBinaryLike()), take_indices, LargeVarBinaryTakeExec,
-       GenericTakeChunkedExecFunctor<LargeVarBinaryTakeExec>::Exec},
-      {InputType(match::FixedSizeBinaryLike()), take_indices, FixedWidthTakeExec,
-       SpecialTakeChunkedExecFunctor<FixedWidthTakeExec, FixedWidthTakeChunkedExec>::Exec},
-      {InputType(null()), take_indices, NullTakeExec,
-       GenericTakeChunkedExecFunctor<NullTakeExec>::Exec},
-      {InputType(Type::DICTIONARY), take_indices, DictionaryTake,
-       GenericTakeChunkedExecFunctor<DictionaryTake>::Exec},
-      {InputType(Type::EXTENSION), take_indices, ExtensionTake,
-       GenericTakeChunkedExecFunctor<ExtensionTake>::Exec},
-      {InputType(Type::LIST), take_indices, ListTakeExec,
-       GenericTakeChunkedExecFunctor<ListTakeExec>::Exec},
-      {InputType(Type::LARGE_LIST), take_indices, LargeListTakeExec,
-       GenericTakeChunkedExecFunctor<LargeListTakeExec>::Exec},
-      {InputType(Type::LIST_VIEW), take_indices, ListViewTakeExec,
-       GenericTakeChunkedExecFunctor<ListViewTakeExec>::Exec},
-      {InputType(Type::LARGE_LIST_VIEW), take_indices, LargeListViewTakeExec,
-       GenericTakeChunkedExecFunctor<LargeListViewTakeExec>::Exec},
-      {InputType(Type::FIXED_SIZE_LIST), take_indices, FSLTakeExec,
-       GenericTakeChunkedExecFunctor<FSLTakeExec>::Exec},
-      {InputType(Type::DENSE_UNION), take_indices, DenseUnionTakeExec,
-       GenericTakeChunkedExecFunctor<DenseUnionTakeExec>::Exec},
-      {InputType(Type::SPARSE_UNION), take_indices, SparseUnionTakeExec,
-       GenericTakeChunkedExecFunctor<SparseUnionTakeExec>::Exec},
-      {InputType(Type::STRUCT), take_indices, StructTakeExec,
-       GenericTakeChunkedExecFunctor<StructTakeExec>::Exec},
-      {InputType(Type::MAP), take_indices, MapTakeExec,
-       GenericTakeChunkedExecFunctor<MapTakeExec>::Exec},
+      SelectionKernelData{InputType(match::Primitive()), take_indices,
+                          ArrayTakeExecFunctor<FixedWidthTakeExec>::Exec,
+                          SpecialTakeChunkedExecFunctor<FixedWidthTakeExec>::Exec},
+      SelectionKernelData{InputType(match::BinaryLike()), take_indices,
+                          ArrayTakeExecFunctor<VarBinaryTakeExec>::Exec,
+                          GenericTakeChunkedExecFunctor<VarBinaryTakeExec>::Exec},
+      SelectionKernelData{InputType(match::LargeBinaryLike()), take_indices,
+                          ArrayTakeExecFunctor<LargeVarBinaryTakeExec>::Exec,
+                          GenericTakeChunkedExecFunctor<LargeVarBinaryTakeExec>::Exec},
+      SelectionKernelData{InputType(match::FixedSizeBinaryLike()), take_indices,
+                          ArrayTakeExecFunctor<FixedWidthTakeExec>::Exec,
+                          SpecialTakeChunkedExecFunctor<FixedWidthTakeExec>::Exec},
+      SelectionKernelData{InputType(null()), take_indices,
+                          ArrayTakeExecFunctor<NullTakeExec>::Exec,
+                          GenericTakeChunkedExecFunctor<NullTakeExec>::Exec},
+      SelectionKernelData{InputType(Type::DICTIONARY), take_indices,
+                          ArrayTakeExecFunctor<DictionaryTake>::Exec,
+                          GenericTakeChunkedExecFunctor<DictionaryTake>::Exec},
+      SelectionKernelData{InputType(Type::EXTENSION), take_indices,
+                          ArrayTakeExecFunctor<ExtensionTake>::Exec,
+                          GenericTakeChunkedExecFunctor<ExtensionTake>::Exec},
+      SelectionKernelData{InputType(Type::LIST), take_indices,
+                          ArrayTakeExecFunctor<ListTakeExec>::Exec,
+                          GenericTakeChunkedExecFunctor<ListTakeExec>::Exec},
+      SelectionKernelData{InputType(Type::LARGE_LIST), take_indices,
+                          ArrayTakeExecFunctor<LargeListTakeExec>::Exec,
+                          GenericTakeChunkedExecFunctor<LargeListTakeExec>::Exec},
+      SelectionKernelData{InputType(Type::LIST_VIEW), take_indices,
+                          ArrayTakeExecFunctor<ListViewTakeExec>::Exec,
+                          GenericTakeChunkedExecFunctor<ListViewTakeExec>::Exec},
+      SelectionKernelData{InputType(Type::LARGE_LIST_VIEW), take_indices,
+                          ArrayTakeExecFunctor<LargeListViewTakeExec>::Exec,
+                          GenericTakeChunkedExecFunctor<LargeListViewTakeExec>::Exec},
+      SelectionKernelData{InputType(Type::FIXED_SIZE_LIST), take_indices,
+                          ArrayTakeExecFunctor<FSLTakeExec>::Exec,
+                          GenericTakeChunkedExecFunctor<FSLTakeExec>::Exec},
+      SelectionKernelData{InputType(Type::DENSE_UNION), take_indices,
+                          ArrayTakeExecFunctor<DenseUnionTakeExec>::Exec,
+                          GenericTakeChunkedExecFunctor<DenseUnionTakeExec>::Exec},
+      SelectionKernelData{InputType(Type::SPARSE_UNION), take_indices,
+                          ArrayTakeExecFunctor<SparseUnionTakeExec>::Exec,
+                          GenericTakeChunkedExecFunctor<SparseUnionTakeExec>::Exec},
+      SelectionKernelData{InputType(Type::STRUCT), take_indices,
+                          ArrayTakeExecFunctor<StructTakeExec>::Exec,
+                          GenericTakeChunkedExecFunctor<StructTakeExec>::Exec},
+      SelectionKernelData{InputType(Type::MAP), take_indices,
+                          ArrayTakeExecFunctor<MapTakeExec>::Exec,
+                          GenericTakeChunkedExecFunctor<MapTakeExec>::Exec},
   };
 }
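Illustrative usage note (not part of the patch): the wrappers and chunked exec functions above are what let the public `arrow::compute::Take` entry point handle the shapes named in the comments (`CA->C`, `CC->C`, `AC->C`), producing a ChunkedArray whenever either input is chunked. The sketch below exercises the `CA->C` case through the public builder and compute APIs; the helper name `TakeFromChunkedValues` and the sample data are made up for the example, and the exact chunking of the result is an implementation detail (only the logical contents are asserted).

// Minimal sketch, assuming the public Arrow C++ API (arrow/api.h and
// arrow/compute/api_vector.h). All names below are illustrative only.
#include <cassert>
#include <memory>

#include "arrow/api.h"
#include "arrow/compute/api_vector.h"

arrow::Status TakeFromChunkedValues() {
  // Build chunked values: [1, 2, 3] and [4, 5] as two chunks.
  arrow::Int32Builder values_builder;
  ARROW_RETURN_NOT_OK(values_builder.AppendValues({1, 2, 3}));
  ARROW_ASSIGN_OR_RAISE(auto chunk0, values_builder.Finish());
  ARROW_RETURN_NOT_OK(values_builder.AppendValues({4, 5}));
  ARROW_ASSIGN_OR_RAISE(auto chunk1, values_builder.Finish());
  auto values =
      std::make_shared<arrow::ChunkedArray>(arrow::ArrayVector{chunk0, chunk1});

  // Indices as a plain array: logical positions into the chunked values.
  arrow::Int64Builder indices_builder;
  ARROW_RETURN_NOT_OK(indices_builder.AppendValues({4, 0, 3}));
  ARROW_ASSIGN_OR_RAISE(auto indices, indices_builder.Finish());

  // CA->C: chunked values + array indices -> chunked result, logically [5, 1, 4].
  ARROW_ASSIGN_OR_RAISE(arrow::Datum taken, arrow::compute::Take(values, indices));
  assert(taken.kind() == arrow::Datum::CHUNKED_ARRAY);
  return arrow::Status::OK();
}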