Skip to content

Commit

Permalink
apacheGH-20213: [C++] Implement cast to/from halffloat (apache#40067)
Browse files Browse the repository at this point in the history
### Rationale for this change

### What changes are included in this PR?

This PR implements casting to and from float16 types using the vendored float16 library included in arrow at `cpp/arrrow/util/float16.*`.

### Are these changes tested?

Unit tests are included in this PR.

### Are there any user-facing changes?

In that casts to and from float16 will now work, yes.

* Closes: apache#20213

### TODO

- [x] Add casts to/from float64.
- [x] String <-> float16 casts.
- [x] Integer <-> float16 casts.
- [x] Tests.
- [x] Update https://github.com/apache/arrow/blob/main/docs/source/status.rst about half float.
- [x] Rebase.
- [x] Run clang format over this PR.
* GitHub Issue: apache#20213

Authored-by: Clif Houck <[email protected]>
Signed-off-by: Sutou Kouhei <[email protected]>
  • Loading branch information
ClifHouck authored Apr 4, 2024
1 parent b99b00d commit 72d20ad
Show file tree
Hide file tree
Showing 15 changed files with 325 additions and 40 deletions.
2 changes: 1 addition & 1 deletion c_glib/test/test-half-float-scalar.rb
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_equal
end

def test_to_s
assert_equal("[\n #{@half_float}\n]", @scalar.to_s)
assert_equal("1.0009765625", @scalar.to_s)
end

def test_value
Expand Down
30 changes: 30 additions & 0 deletions cpp/src/arrow/compare.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#include "arrow/util/bitmap_ops.h"
#include "arrow/util/bitmap_reader.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/float16.h"
#include "arrow/util/key_value_metadata.h"
#include "arrow/util/logging.h"
#include "arrow/util/macros.h"
Expand All @@ -59,6 +60,7 @@ using internal::BitmapReader;
using internal::BitmapUInt64Reader;
using internal::checked_cast;
using internal::OptionalBitmapEquals;
using util::Float16;

// ----------------------------------------------------------------------
// Public method implementations
Expand Down Expand Up @@ -95,6 +97,30 @@ struct FloatingEquality {
const T epsilon;
};

// For half-float equality.
template <typename Flags>
struct FloatingEquality<uint16_t, Flags> {
explicit FloatingEquality(const EqualOptions& options)
: epsilon(static_cast<float>(options.atol())) {}

bool operator()(uint16_t x, uint16_t y) const {
Float16 f_x = Float16::FromBits(x);
Float16 f_y = Float16::FromBits(y);
if (x == y) {
return Flags::signed_zeros_equal || (f_x.signbit() == f_y.signbit());
}
if (Flags::nans_equal && f_x.is_nan() && f_y.is_nan()) {
return true;
}
if (Flags::approximate && (fabs(f_x.ToFloat() - f_y.ToFloat()) <= epsilon)) {
return true;
}
return false;
}

const float epsilon;
};

template <typename T, typename Visitor>
struct FloatingEqualityDispatcher {
const EqualOptions& options;
Expand Down Expand Up @@ -259,6 +285,8 @@ class RangeDataEqualsImpl {

Status Visit(const DoubleType& type) { return CompareFloating(type); }

Status Visit(const HalfFloatType& type) { return CompareFloating(type); }

// Also matches StringType
Status Visit(const BinaryType& type) { return CompareBinary(type); }

Expand Down Expand Up @@ -863,6 +891,8 @@ class ScalarEqualsVisitor {

Status Visit(const DoubleScalar& left) { return CompareFloating(left); }

Status Visit(const HalfFloatScalar& left) { return CompareFloating(left); }

template <typename T>
enable_if_t<std::is_base_of<BaseBinaryScalar, T>::value, Status> Visit(const T& left) {
const auto& right = checked_cast<const BaseBinaryScalar&>(right_);
Expand Down
70 changes: 70 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_cast_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@
#include "arrow/compute/cast_internal.h"
#include "arrow/compute/kernels/common_internal.h"
#include "arrow/extension_type.h"
#include "arrow/type_traits.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/float16.h"

namespace arrow {

using arrow::util::Float16;
using internal::checked_cast;
using internal::PrimitiveScalarBase;

Expand All @@ -47,6 +50,42 @@ struct CastPrimitive {
}
};

// Converting floating types to half float.
template <typename InType>
struct CastPrimitive<HalfFloatType, InType, enable_if_physical_floating_point<InType>> {
static void Exec(const ArraySpan& arr, ArraySpan* out) {
using InT = typename InType::c_type;
const InT* in_values = arr.GetValues<InT>(1);
uint16_t* out_values = out->GetValues<uint16_t>(1);
for (int64_t i = 0; i < arr.length; ++i) {
*out_values++ = Float16(*in_values++).bits();
}
}
};

// Converting from half float to other floating types.
template <>
struct CastPrimitive<FloatType, HalfFloatType, enable_if_t<true>> {
static void Exec(const ArraySpan& arr, ArraySpan* out) {
const uint16_t* in_values = arr.GetValues<uint16_t>(1);
float* out_values = out->GetValues<float>(1);
for (int64_t i = 0; i < arr.length; ++i) {
*out_values++ = Float16::FromBits(*in_values++).ToFloat();
}
}
};

template <>
struct CastPrimitive<DoubleType, HalfFloatType, enable_if_t<true>> {
static void Exec(const ArraySpan& arr, ArraySpan* out) {
const uint16_t* in_values = arr.GetValues<uint16_t>(1);
double* out_values = out->GetValues<double>(1);
for (int64_t i = 0; i < arr.length; ++i) {
*out_values++ = Float16::FromBits(*in_values++).ToDouble();
}
}
};

template <typename OutType, typename InType>
struct CastPrimitive<OutType, InType, enable_if_t<std::is_same<OutType, InType>::value>> {
// memcpy output
Expand All @@ -56,6 +95,33 @@ struct CastPrimitive<OutType, InType, enable_if_t<std::is_same<OutType, InType>:
}
};

// Cast int to half float
template <typename InType>
struct CastPrimitive<HalfFloatType, InType, enable_if_integer<InType>> {
static void Exec(const ArraySpan& arr, ArraySpan* out) {
using InT = typename InType::c_type;
const InT* in_values = arr.GetValues<InT>(1);
uint16_t* out_values = out->GetValues<uint16_t>(1);
for (int64_t i = 0; i < arr.length; ++i) {
float temp = static_cast<float>(*in_values++);
*out_values++ = Float16(temp).bits();
}
}
};

// Cast half float to int
template <typename OutType>
struct CastPrimitive<OutType, HalfFloatType, enable_if_integer<OutType>> {
static void Exec(const ArraySpan& arr, ArraySpan* out) {
using OutT = typename OutType::c_type;
const uint16_t* in_values = arr.GetValues<uint16_t>(1);
OutT* out_values = out->GetValues<OutT>(1);
for (int64_t i = 0; i < arr.length; ++i) {
*out_values++ = static_cast<OutT>(Float16::FromBits(*in_values++).ToFloat());
}
}
};

template <typename InType>
void CastNumberImpl(Type::type out_type, const ArraySpan& input, ArraySpan* out) {
switch (out_type) {
Expand All @@ -79,6 +145,8 @@ void CastNumberImpl(Type::type out_type, const ArraySpan& input, ArraySpan* out)
return CastPrimitive<FloatType, InType>::Exec(input, out);
case Type::DOUBLE:
return CastPrimitive<DoubleType, InType>::Exec(input, out);
case Type::HALF_FLOAT:
return CastPrimitive<HalfFloatType, InType>::Exec(input, out);
default:
break;
}
Expand Down Expand Up @@ -109,6 +177,8 @@ void CastNumberToNumberUnsafe(Type::type in_type, Type::type out_type,
return CastNumberImpl<FloatType>(out_type, input, out);
case Type::DOUBLE:
return CastNumberImpl<DoubleType>(out_type, input, out);
case Type::HALF_FLOAT:
return CastNumberImpl<HalfFloatType>(out_type, input, out);
default:
DCHECK(false);
break;
Expand Down
103 changes: 87 additions & 16 deletions cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "arrow/compute/kernels/util_internal.h"
#include "arrow/scalar.h"
#include "arrow/util/bit_block_counter.h"
#include "arrow/util/float16.h"
#include "arrow/util/int_util.h"
#include "arrow/util/value_parsing.h"

Expand All @@ -34,6 +35,7 @@ using internal::IntegersCanFit;
using internal::OptionalBitBlockCounter;
using internal::ParseValue;
using internal::PrimitiveScalarBase;
using util::Float16;

namespace compute {
namespace internal {
Expand All @@ -56,18 +58,37 @@ Status CastFloatingToFloating(KernelContext*, const ExecSpan& batch, ExecResult*

// ----------------------------------------------------------------------
// Implement fast safe floating point to integer cast
//
template <typename InType, typename OutType, typename InT = typename InType::c_type,
typename OutT = typename OutType::c_type>
struct WasTruncated {
static bool Check(OutT out_val, InT in_val) {
return static_cast<InT>(out_val) != in_val;
}

static bool CheckMaybeNull(OutT out_val, InT in_val, bool is_valid) {
return is_valid && static_cast<InT>(out_val) != in_val;
}
};

// Half float to int
template <typename OutType>
struct WasTruncated<HalfFloatType, OutType> {
using OutT = typename OutType::c_type;
static bool Check(OutT out_val, uint16_t in_val) {
return static_cast<float>(out_val) != Float16::FromBits(in_val).ToFloat();
}

static bool CheckMaybeNull(OutT out_val, uint16_t in_val, bool is_valid) {
return is_valid && static_cast<float>(out_val) != Float16::FromBits(in_val).ToFloat();
}
};

// InType is a floating point type we are planning to cast to integer
template <typename InType, typename OutType, typename InT = typename InType::c_type,
typename OutT = typename OutType::c_type>
ARROW_DISABLE_UBSAN("float-cast-overflow")
Status CheckFloatTruncation(const ArraySpan& input, const ArraySpan& output) {
auto WasTruncated = [&](OutT out_val, InT in_val) -> bool {
return static_cast<InT>(out_val) != in_val;
};
auto WasTruncatedMaybeNull = [&](OutT out_val, InT in_val, bool is_valid) -> bool {
return is_valid && static_cast<InT>(out_val) != in_val;
};
auto GetErrorMessage = [&](InT val) {
return Status::Invalid("Float value ", val, " was truncated converting to ",
*output.type);
Expand All @@ -86,26 +107,28 @@ Status CheckFloatTruncation(const ArraySpan& input, const ArraySpan& output) {
if (block.popcount == block.length) {
// Fast path: branchless
for (int64_t i = 0; i < block.length; ++i) {
block_out_of_bounds |= WasTruncated(out_data[i], in_data[i]);
block_out_of_bounds |=
WasTruncated<InType, OutType>::Check(out_data[i], in_data[i]);
}
} else if (block.popcount > 0) {
// Indices have nulls, must only boundscheck non-null values
for (int64_t i = 0; i < block.length; ++i) {
block_out_of_bounds |= WasTruncatedMaybeNull(
block_out_of_bounds |= WasTruncated<InType, OutType>::CheckMaybeNull(
out_data[i], in_data[i], bit_util::GetBit(bitmap, offset_position + i));
}
}
if (ARROW_PREDICT_FALSE(block_out_of_bounds)) {
if (input.GetNullCount() > 0) {
for (int64_t i = 0; i < block.length; ++i) {
if (WasTruncatedMaybeNull(out_data[i], in_data[i],
bit_util::GetBit(bitmap, offset_position + i))) {
if (WasTruncated<InType, OutType>::CheckMaybeNull(
out_data[i], in_data[i],
bit_util::GetBit(bitmap, offset_position + i))) {
return GetErrorMessage(in_data[i]);
}
}
} else {
for (int64_t i = 0; i < block.length; ++i) {
if (WasTruncated(out_data[i], in_data[i])) {
if (WasTruncated<InType, OutType>::Check(out_data[i], in_data[i])) {
return GetErrorMessage(in_data[i]);
}
}
Expand Down Expand Up @@ -151,6 +174,9 @@ Status CheckFloatToIntTruncation(const ExecValue& input, const ExecResult& outpu
return CheckFloatToIntTruncationImpl<FloatType>(input.array, *output.array_span());
case Type::DOUBLE:
return CheckFloatToIntTruncationImpl<DoubleType>(input.array, *output.array_span());
case Type::HALF_FLOAT:
return CheckFloatToIntTruncationImpl<HalfFloatType>(input.array,
*output.array_span());
default:
break;
}
Expand Down Expand Up @@ -293,6 +319,15 @@ struct CastFunctor<
}
};

template <>
struct CastFunctor<HalfFloatType, StringType, enable_if_t<true>> {
static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
return applicator::ScalarUnaryNotNull<HalfFloatType, StringType,
ParseString<HalfFloatType>>::Exec(ctx, batch,
out);
}
};

// ----------------------------------------------------------------------
// Decimal to integer

Expand Down Expand Up @@ -689,6 +724,10 @@ std::shared_ptr<CastFunction> GetCastToInteger(std::string name) {
DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, CastFloatingToInteger));
}

// Cast from half-float
DCHECK_OK(func->AddKernel(Type::HALF_FLOAT, {InputType(Type::HALF_FLOAT)}, out_ty,
CastFloatingToInteger));

// From other numbers to integer
AddCommonNumberCasts<OutType>(out_ty, func.get());

Expand All @@ -715,6 +754,10 @@ std::shared_ptr<CastFunction> GetCastToFloating(std::string name) {
DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, CastFloatingToFloating));
}

// From half-float to float/double
DCHECK_OK(func->AddKernel(Type::HALF_FLOAT, {InputType(Type::HALF_FLOAT)}, out_ty,
CastFloatingToFloating));

// From other numbers to floating point
AddCommonNumberCasts<OutType>(out_ty, func.get());

Expand All @@ -723,6 +766,7 @@ std::shared_ptr<CastFunction> GetCastToFloating(std::string name) {
CastFunctor<OutType, Decimal128Type>::Exec));
DCHECK_OK(func->AddKernel(Type::DECIMAL256, {InputType(Type::DECIMAL256)}, out_ty,
CastFunctor<OutType, Decimal256Type>::Exec));

return func;
}

Expand Down Expand Up @@ -795,6 +839,32 @@ std::shared_ptr<CastFunction> GetCastToDecimal256() {
return func;
}

std::shared_ptr<CastFunction> GetCastToHalfFloat() {
// HalfFloat is a bit brain-damaged for now
auto func = std::make_shared<CastFunction>("func", Type::HALF_FLOAT);
AddCommonCasts(Type::HALF_FLOAT, float16(), func.get());

// Casts from integer to floating point
for (const std::shared_ptr<DataType>& in_ty : IntTypes()) {
DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty},
TypeTraits<HalfFloatType>::type_singleton(),
CastIntegerToFloating));
}

// Cast from other strings to half float.
for (const std::shared_ptr<DataType>& in_ty : BaseBinaryTypes()) {
auto exec = GenerateVarBinaryBase<CastFunctor, HalfFloatType>(*in_ty);
DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty},
TypeTraits<HalfFloatType>::type_singleton(), exec));
}

DCHECK_OK(func.get()->AddKernel(Type::FLOAT, {InputType(Type::FLOAT)}, float16(),
CastFloatingToFloating));
DCHECK_OK(func.get()->AddKernel(Type::DOUBLE, {InputType(Type::DOUBLE)}, float16(),
CastFloatingToFloating));
return func;
}

} // namespace

std::vector<std::shared_ptr<CastFunction>> GetNumericCasts() {
Expand Down Expand Up @@ -830,13 +900,14 @@ std::vector<std::shared_ptr<CastFunction>> GetNumericCasts() {
functions.push_back(GetCastToInteger<UInt64Type>("cast_uint64"));

// HalfFloat is a bit brain-damaged for now
auto cast_half_float =
std::make_shared<CastFunction>("cast_half_float", Type::HALF_FLOAT);
AddCommonCasts(Type::HALF_FLOAT, float16(), cast_half_float.get());
auto cast_half_float = GetCastToHalfFloat();
functions.push_back(cast_half_float);

functions.push_back(GetCastToFloating<FloatType>("cast_float"));
functions.push_back(GetCastToFloating<DoubleType>("cast_double"));
auto cast_float = GetCastToFloating<FloatType>("cast_float");
functions.push_back(cast_float);

auto cast_double = GetCastToFloating<DoubleType>("cast_double");
functions.push_back(cast_double);

functions.push_back(GetCastToDecimal128());
functions.push_back(GetCastToDecimal256());
Expand Down
4 changes: 4 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_cast_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,10 @@ void AddNumberToStringCasts(CastFunction* func) {
GenerateNumeric<NumericToStringCastFunctor, OutType>(*in_ty),
NullHandling::COMPUTED_NO_PREALLOCATE));
}

DCHECK_OK(func->AddKernel(Type::HALF_FLOAT, {float16()}, out_ty,
NumericToStringCastFunctor<OutType, HalfFloatType>::Exec,
NullHandling::COMPUTED_NO_PREALLOCATE));
}

template <typename OutType>
Expand Down
Loading

0 comments on commit 72d20ad

Please sign in to comment.