Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-41810: [C++] Support cast kernel from (dense or sparse) union to (large) string #41827

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
56 changes: 56 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_cast_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,60 @@ void AddBinaryToFixedSizeBinaryCast(CastFunction* func) {
AddBinaryToFixedSizeBinaryCast<FixedSizeBinaryType>(func);
}

// ----------------------------------------------------------------------
// Union to String

template <typename O>
struct UnionToStringCastFunctor {
using BuilderType = typename TypeTraits<O>::BuilderType;

static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
const ArraySpan& input = batch[0].array;
const auto& union_type = checked_cast<const UnionType&>(*input.type);
const auto type_ids = input.GetValues<int8_t>(1);
const auto& offsets = input.GetValues<int32_t>(2);

BuilderType builder(input.type->GetSharedPtr(), ctx->memory_pool());
RETURN_NOT_OK(builder.Reserve(input.length));

for (int64_t i = 0; i < input.length; ++i) {
Copy link
Collaborator

@ZhangHuiGui ZhangHuiGui May 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it's a better way to implement this by expand StringFormatter (include other nested types)? @felipecrv
In this way, we can unify it with other type's implementations and shield the logic of converting strings in the current file.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the review and the insightful suggestions.

That makes sense, and I'll consider applying that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, this issue is related: #41831.

Copy link
Contributor

@felipecrv felipecrv Jun 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this way, we can unify it with other type's implementations and shield the logic of converting strings in the current file.

True, but that can prevent optimizations in the future. The approach of taking a scalar function and turning it into an array function by mapping —array::map(scalar_function: scalar -> scalar) -> array — is appealing but prevents vectorization techniques.

UPDATE: that's what we will do here because the set of unions and their parametrizations is infinite, but StringFormatter<MonthIntervalType> is not the way to go because it would have to switch on the type for every invocation of the formatter.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A string formatter for unions should allocate a vector of string formatters that can do virtual dispatching (and deal with nesting themselves as well). StringFormatter<T> performs static dispatch which allows loop-specialization for the non-nested types. But for nested types we will need to setup a vector of VirtualStringFormatter (which is actually a tree) so that all the "switching on the type" happens at construction time (beginning of the loop) and invocations inside the loop are following the same function pointers from the vtables.

// in header
class VirtualStringFormatter {
  virtual ... = 0;
};

Result<std::unique_ptr<VirtualStringFormatter>> MakeFormatter(
  const std::shared_ptr<DataType>& type);

// in an anon namespace of the .cpp
// one sub-class per `Type::type`
class <T>StringFormatter : public VirtualStringFormatter {
}
// you can use templates to cover most cases delegating to StringFormatter<T>

class UnionStringFormatter : ...

This hierarchy would be similar to the builder class hierarchy.

if (input.IsNull(i)) {
RETURN_NOT_OK(builder.AppendNull());
continue;
}

const int8_t type_id = type_ids[i];
const auto& field = union_type.field(union_type.child_ids()[type_id]);
const ArraySpan& child_span = input.child_data[union_type.child_ids()[type_id]];

std::shared_ptr<Scalar> child_scalar;
auto child_index = union_type.mode() == UnionMode::DENSE ? offsets[i] : i;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be a template parameter (a UnionMode::type) so we can specialize with if constexpr checks in the loop.

RETURN_NOT_OK(child_span.ToArray()->GetScalar(child_index).Value(&child_scalar));

std::string str = "union{" + field->name() + ": " + field->type()->ToString() +
" = " + child_scalar->ToString() + "}";
RETURN_NOT_OK(builder.Append(str));
}

std::shared_ptr<Array> output_array;
RETURN_NOT_OK(builder.Finish(&output_array));
out->value = output_array->data();
return Status::OK();
}
};

template <typename OutType>
void AddUnionToStringCast(CastFunction* func) {
auto out_ty = TypeTraits<OutType>::type_singleton();

DCHECK_OK(func->AddKernel(Type::DENSE_UNION, {InputType(Type::DENSE_UNION)}, out_ty,
UnionToStringCastFunctor<OutType>::Exec,
NullHandling::COMPUTED_NO_PREALLOCATE));
DCHECK_OK(func->AddKernel(Type::SPARSE_UNION, {InputType(Type::SPARSE_UNION)}, out_ty,
UnionToStringCastFunctor<OutType>::Exec,
NullHandling::COMPUTED_NO_PREALLOCATE));
}

} // namespace

std::vector<std::shared_ptr<CastFunction>> GetBinaryLikeCasts() {
Expand All @@ -528,6 +582,7 @@ std::vector<std::shared_ptr<CastFunction>> GetBinaryLikeCasts() {
AddDecimalToStringCasts<StringType>(cast_string.get());
AddTemporalToStringCasts<StringType>(cast_string.get());
AddBinaryToBinaryCast<StringType>(cast_string.get());
AddUnionToStringCast<StringType>(cast_string.get());

auto cast_large_string =
std::make_shared<CastFunction>("cast_large_string", Type::LARGE_STRING);
Expand All @@ -536,6 +591,7 @@ std::vector<std::shared_ptr<CastFunction>> GetBinaryLikeCasts() {
AddDecimalToStringCasts<LargeStringType>(cast_large_string.get());
AddTemporalToStringCasts<LargeStringType>(cast_large_string.get());
AddBinaryToBinaryCast<LargeStringType>(cast_large_string.get());
AddUnionToStringCast<LargeStringType>(cast_large_string.get());

auto cast_fsb =
std::make_shared<CastFunction>("cast_fixed_size_binary", Type::FIXED_SIZE_BINARY);
Expand Down
121 changes: 88 additions & 33 deletions cpp/src/arrow/scalar_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1855,11 +1855,22 @@ class TestUnionScalar : public ::testing::Test {
}

void TestCast() {
// Cast() function doesn't support casting union to string, use Scalar::CastTo()
// instead.
ASSERT_OK_AND_ASSIGN(auto casted, union_alpha_->CastTo(utf8()));
ASSERT_TRUE(casted->Equals(StringScalar(R"(union{string: string = alpha})")))
<< casted->ToString();
std::vector<std::pair<std::shared_ptr<Scalar>, std::string>> test_cases = {
{union_alpha_, R"(union{string: string = alpha})"},
{union_beta_, R"(union{string: string = beta})"},
{union_two_, R"(union{number: uint64 = 2})"},
{union_three_, R"(union{number: uint64 = 3})"},
{union_other_two_, R"(union{other_number: uint64 = 2})"},
{union_string_null_, "null"},
{union_number_null_, "null"}};

for (const auto& out_ty : {utf8(), large_utf8()}) {
for (const auto& [scalar, expected] : test_cases) {
ASSERT_OK_AND_ASSIGN(auto casted, Cast(scalar, out_ty));
ASSERT_EQ(casted.scalar()->ToString(), expected)
<< "Failed to cast " << scalar->ToString() << " to " << expected;
}
}
}

protected:
Expand All @@ -1882,41 +1893,85 @@ TYPED_TEST(TestUnionScalar, MakeNullScalar) { this->TestMakeNullScalar(); }

TYPED_TEST(TestUnionScalar, Cast) { this->TestCast(); }

class TestSparseUnionScalar : public TestUnionScalar<SparseUnionType> {};
class TestSparseUnionScalar : public TestUnionScalar<SparseUnionType> {
void SetUp() override {
TestUnionScalar::SetUp();

TEST_F(TestSparseUnionScalar, GetScalar) {
ArrayVector children{ArrayFromJSON(utf8(), R"(["alpha", "", "beta", null, "gamma"])"),
ArrayFromJSON(uint64(), "[1, 2, 11, 22, null]"),
ArrayFromJSON(uint64(), "[100, 101, 102, 103, 104]")};
children = {ArrayFromJSON(utf8(), R"(["alpha", "", "beta", null, "gamma"])"),
ArrayFromJSON(uint64(), "[1, 2, 11, 22, null]"),
ArrayFromJSON(uint64(), "[100, 101, 102, 103, 104]")};

type_ids = ArrayFromJSON(int8(), "[3, 42, 3, 3, 42]");
arr = std::make_shared<SparseUnionArray>(type_, 5, children,
type_ids->data()->buffers[1]);
ASSERT_OK(arr->ValidateFull());
}

auto type_ids = ArrayFromJSON(int8(), "[3, 42, 3, 3, 42]");
SparseUnionArray arr(type_, 5, children, type_ids->data()->buffers[1]);
ASSERT_OK(arr.ValidateFull());
protected:
ArrayVector children;
std::shared_ptr<Array> type_ids;
std::shared_ptr<SparseUnionArray> arr;
};

CheckGetValidUnionScalar(arr, 0, *union_alpha_, *alpha_);
CheckGetValidUnionScalar(arr, 1, *union_two_, *two_);
CheckGetValidUnionScalar(arr, 2, *union_beta_, *beta_);
CheckGetNullUnionScalar(arr, 3);
CheckGetNullUnionScalar(arr, 4);
TEST_F(TestSparseUnionScalar, GetScalar) {
CheckGetValidUnionScalar(*arr, 0, *union_alpha_, *alpha_);
CheckGetValidUnionScalar(*arr, 1, *union_two_, *two_);
CheckGetValidUnionScalar(*arr, 2, *union_beta_, *beta_);
CheckGetNullUnionScalar(*arr, 3);
CheckGetNullUnionScalar(*arr, 4);
}

class TestDenseUnionScalar : public TestUnionScalar<DenseUnionType> {};
TEST_F(TestSparseUnionScalar, CastToString) {
for (const auto& out_ty : {utf8(), large_utf8()}) {
auto expected = ArrayFromJSON(out_ty, R"(["union{string: string = alpha}",
"union{number: uint64 = 2}",
"union{string: string = beta}",
null,
null])");
ASSERT_OK_AND_ASSIGN(auto casted, Cast(*arr, out_ty));
ASSERT_TRUE(casted->Equals(*expected));
}
}

class TestDenseUnionScalar : public TestUnionScalar<DenseUnionType> {
void SetUp() override {
TestUnionScalar::SetUp();

children = {ArrayFromJSON(utf8(), R"(["alpha", "beta", null])"),
ArrayFromJSON(uint64(), "[2, 3]"), ArrayFromJSON(uint64(), "[]")};

type_ids = ArrayFromJSON(int8(), "[3, 42, 3, 3, 42]");
offsets = ArrayFromJSON(int32(), "[0, 0, 1, 2, 1]");
arr = std::make_shared<DenseUnionArray>(
type_, 5, children, type_ids->data()->buffers[1], offsets->data()->buffers[1]);
ASSERT_OK(arr->ValidateFull());
}

protected:
ArrayVector children;
std::shared_ptr<Array> type_ids;
std::shared_ptr<Array> offsets;
std::shared_ptr<DenseUnionArray> arr;
};

TEST_F(TestDenseUnionScalar, GetScalar) {
ArrayVector children{ArrayFromJSON(utf8(), R"(["alpha", "beta", null])"),
ArrayFromJSON(uint64(), "[2, 3]"), ArrayFromJSON(uint64(), "[]")};

auto type_ids = ArrayFromJSON(int8(), "[3, 42, 3, 3, 42]");
auto offsets = ArrayFromJSON(int32(), "[0, 0, 1, 2, 1]");
DenseUnionArray arr(type_, 5, children, type_ids->data()->buffers[1],
offsets->data()->buffers[1]);
ASSERT_OK(arr.ValidateFull());

CheckGetValidUnionScalar(arr, 0, *union_alpha_, *alpha_);
CheckGetValidUnionScalar(arr, 1, *union_two_, *two_);
CheckGetValidUnionScalar(arr, 2, *union_beta_, *beta_);
CheckGetNullUnionScalar(arr, 3);
CheckGetValidUnionScalar(arr, 4, *union_three_, *three_);
CheckGetValidUnionScalar(*arr, 0, *union_alpha_, *alpha_);
CheckGetValidUnionScalar(*arr, 1, *union_two_, *two_);
CheckGetValidUnionScalar(*arr, 2, *union_beta_, *beta_);
CheckGetNullUnionScalar(*arr, 3);
CheckGetValidUnionScalar(*arr, 4, *union_three_, *three_);
}

TEST_F(TestDenseUnionScalar, CastToString) {
for (const auto& out_ty : {utf8(), large_utf8()}) {
auto expected = ArrayFromJSON(out_ty, R"(["union{string: string = alpha}",
"union{number: uint64 = 2}",
"union{string: string = beta}",
null,
"union{number: uint64 = 3}"])");
ASSERT_OK_AND_ASSIGN(auto casted, Cast(*arr, out_ty));
ASSERT_TRUE(casted->Equals(*expected));
}
}

template <typename RunEndType>
Expand Down
Loading