Skip to content

Commit

Permalink
support cast from union to string
Browse files Browse the repository at this point in the history
  • Loading branch information
llama90 committed May 25, 2024
1 parent 283f66f commit 1966c31
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 5 deletions.
56 changes: 56 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_cast_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,60 @@ void AddBinaryToFixedSizeBinaryCast(CastFunction* func) {
AddBinaryToFixedSizeBinaryCast<FixedSizeBinaryType>(func);
}

// ----------------------------------------------------------------------
// Union to String

/// \brief Cast kernel that renders every slot of a union array as a string of the
/// form `union{<field name>: <field type> = <value>}`.
///
/// \tparam O output string type (its builder is taken from TypeTraits<O>).
///
/// Slots reported as null by ArraySpan::IsNull() produce a null output slot.
/// Returns Status::Invalid if a slot carries a type id that does not map to a
/// child of the union type.
template <typename O>
struct UnionToStringCastFunctor {
  using BuilderType = typename TypeTraits<O>::BuilderType;

  static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
    const ArraySpan& input = batch[0].array;
    const auto& union_type = checked_cast<const UnionType&>(*input.type);
    const bool is_dense = input.type->id() == Type::DENSE_UNION;

    // GetValues() already accounts for the parent offset of `input`.
    const int8_t* type_ids = input.GetValues<int8_t>(1);
    // Only dense unions carry a value-offsets buffer (buffer index 2).
    const int32_t* dense_offsets = is_dense ? input.GetValues<int32_t>(2) : nullptr;
    const auto& child_ids = union_type.child_ids();

    // Materialize each child array once instead of once per row.
    std::vector<std::shared_ptr<Array>> children;
    children.reserve(input.child_data.size());
    for (const ArraySpan& child_span : input.child_data) {
      children.push_back(child_span.ToArray());
    }

    // The builder produces the output string type; the input union type must not
    // be handed to it.
    BuilderType builder(ctx->memory_pool());
    RETURN_NOT_OK(builder.Reserve(input.length));

    for (int64_t i = 0; i < input.length; ++i) {
      if (input.IsNull(i)) {
        RETURN_NOT_OK(builder.AppendNull());
        continue;
      }

      const int8_t type_id = type_ids[i];
      // Negative type ids would index child_ids out of range; unmapped ids
      // resolve to a negative child id.
      const int child_id = type_id < 0 ? -1 : child_ids[type_id];
      if (child_id < 0 || child_id >= static_cast<int>(children.size())) {
        return Status::Invalid("Union value at position ", i, " has invalid type id ",
                               static_cast<int>(type_id));
      }
      const auto& field = union_type.field(child_id);

      // Dense unions store the child position explicitly per slot. Sparse union
      // children are parallel to the parent, so the parent offset has to be
      // applied to reach the matching child slot.
      const int64_t child_index =
          is_dense ? static_cast<int64_t>(dense_offsets[i]) : input.offset + i;

      std::shared_ptr<Scalar> child_scalar;
      RETURN_NOT_OK(children[child_id]->GetScalar(child_index).Value(&child_scalar));

      std::stringstream ss;
      ss << "union{" << field->name() << ": " << field->type()->ToString() << " = "
         << child_scalar->ToString() << "}";

      RETURN_NOT_OK(builder.Append(ss.str()));
    }

    std::shared_ptr<Array> output_array;
    RETURN_NOT_OK(builder.Finish(&output_array));
    out->value = output_array->data();
    return Status::OK();
  }
};

// Registers the union -> OutType cast kernels on `func`, one for dense unions and
// one for sparse unions. Both use the same functor and compute their own nulls
// without preallocated output buffers.
template <typename OutType>
void AddUnionToStringCast(CastFunction* func) {
  // Registration order matches the original: dense first, then sparse.
  constexpr Type::type kUnionTypeIds[] = {Type::DENSE_UNION, Type::SPARSE_UNION};
  const auto string_type = TypeTraits<OutType>::type_singleton();

  for (const Type::type union_id : kUnionTypeIds) {
    DCHECK_OK(func->AddKernel(union_id, {InputType(union_id)}, string_type,
                              UnionToStringCastFunctor<OutType>::Exec,
                              NullHandling::COMPUTED_NO_PREALLOCATE));
  }
}

} // namespace

std::vector<std::shared_ptr<CastFunction>> GetBinaryLikeCasts() {
Expand All @@ -528,6 +582,7 @@ std::vector<std::shared_ptr<CastFunction>> GetBinaryLikeCasts() {
AddDecimalToStringCasts<StringType>(cast_string.get());
AddTemporalToStringCasts<StringType>(cast_string.get());
AddBinaryToBinaryCast<StringType>(cast_string.get());
AddUnionToStringCast<StringType>(cast_string.get());

auto cast_large_string =
std::make_shared<CastFunction>("cast_large_string", Type::LARGE_STRING);
Expand All @@ -536,6 +591,7 @@ std::vector<std::shared_ptr<CastFunction>> GetBinaryLikeCasts() {
AddDecimalToStringCasts<LargeStringType>(cast_large_string.get());
AddTemporalToStringCasts<LargeStringType>(cast_large_string.get());
AddBinaryToBinaryCast<LargeStringType>(cast_large_string.get());
AddUnionToStringCast<LargeStringType>(cast_large_string.get());

auto cast_fsb =
std::make_shared<CastFunction>("cast_fixed_size_binary", Type::FIXED_SIZE_BINARY);
Expand Down
20 changes: 15 additions & 5 deletions cpp/src/arrow/scalar_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1855,11 +1855,21 @@ class TestUnionScalar : public ::testing::Test {
}

// Verifies that union scalars cast to utf8 produce the expected string rendering,
// both through Scalar::CastTo() and through Cast().
void TestCast() {
// Scalar::CastTo() path: the result is compared against a StringScalar holding
// the expected rendering.
ASSERT_OK_AND_ASSIGN(auto casted, union_alpha_->CastTo(utf8()));
ASSERT_TRUE(casted->Equals(StringScalar(R"(union{string: string = alpha})")))
<< casted->ToString();
// Cast() path: each entry pairs a union scalar fixture with the text expected
// from ToString() on the cast result. The fixtures are defined outside this
// method; the two *_null_ fixtures are expected to render as "null".
std::vector<std::pair<std::shared_ptr<Scalar>, std::string>> test_cases = {
{union_alpha_, R"(union{string: string = alpha})"},
{union_beta_, R"(union{string: string = beta})"},
{union_two_, R"(union{number: uint64 = 2})"},
{union_three_, R"(union{number: uint64 = 3})"},
{union_other_two_, R"(union{other_number: uint64 = 2})"},
{union_string_null_, "null"},
{union_number_null_, "null"}
};

for (const auto& [scalar, expected] : test_cases) {
// The cast itself must succeed; the comparison is on the string form of the
// resulting scalar.
ASSERT_OK_AND_ASSIGN(auto casted, Cast(scalar, utf8()));
ASSERT_EQ(casted.scalar()->ToString(), expected)
<< "Failed to cast " << scalar->ToString() << " to " << expected;
}
}

protected:
Expand Down

0 comments on commit 1966c31

Please sign in to comment.