Skip to content

Commit

Permalink
Backport PR #2046 - Fixing FP16 conversions. (#2222)
Browse files Browse the repository at this point in the history
* Do not rely on conversions between float and extended floating point types (#2046)

The issue we have is that our tests rely extensively on those conversions which makes it incredibly painfull to test

* Fix including `<complex>` when bad CUDA bfloat/half macros are used. (#2226)

* Add <complex> test for bad macros being defined

* Fix <complex> failing upon inclusion when bad macros are defined

* Rather use explicit specializations and some evil hackery to get the complex interop to work

* Fix typos

* Inline everything

* Move workarounds together

* Use conversion functions instead of explicit specializations

* Drop unneeded conversions

---------

Co-authored-by: Michael Schellenberger Costa <[email protected]>

---------

Co-authored-by: Michael Schellenberger Costa <[email protected]>
  • Loading branch information
wmaxey and miscco authored Aug 14, 2024
1 parent 1251f54 commit c67b1c3
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 42 deletions.
87 changes: 76 additions & 11 deletions libcudacxx/include/cuda/std/__complex/nvbf16.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,39 @@ struct __libcpp_complex_overload_traits<__nv_bfloat16, false, false>
typedef complex<__nv_bfloat16> _ComplexType;
};

// This is a workaround against the user defining macros __CUDA_NO_BFLOAT16_CONVERSIONS__ __CUDA_NO_BFLOAT16_OPERATORS__
template <>
struct __complex_can_implicitly_construct<__nv_bfloat16, float> : true_type
{};

template <>
struct __complex_can_implicitly_construct<__nv_bfloat16, double> : true_type
{};

template <>
struct __complex_can_implicitly_construct<float, __nv_bfloat16> : true_type
{};

template <>
struct __complex_can_implicitly_construct<double, __nv_bfloat16> : true_type
{};

template <class _Tp>
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 __convert_to_bfloat16(const _Tp& __value) noexcept
{
return __value;
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 __convert_to_bfloat16(const float& __value) noexcept
{
return __float2bfloat16(__value);
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 __convert_to_bfloat16(const double& __value) noexcept
{
return __double2bfloat16(__value);
}

template <>
class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__nv_bfloat162)) complex<__nv_bfloat16>
{
Expand All @@ -80,14 +113,14 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__nv_bfloat162)) complex<__

template <class _Up, __enable_if_t<__complex_can_implicitly_construct<value_type, _Up>::value, int> = 0>
_LIBCUDACXX_INLINE_VISIBILITY complex(const complex<_Up>& __c)
: __repr_(static_cast<value_type>(__c.real()), static_cast<value_type>(__c.imag()))
: __repr_(__convert_to_bfloat16(__c.real()), __convert_to_bfloat16(__c.imag()))
{}

template <class _Up,
__enable_if_t<!__complex_can_implicitly_construct<value_type, _Up>::value, int> = 0,
__enable_if_t<_CCCL_TRAIT(is_constructible, value_type, _Up), int> = 0>
_LIBCUDACXX_INLINE_VISIBILITY explicit complex(const complex<_Up>& __c)
: __repr_(static_cast<value_type>(__c.real()), static_cast<value_type>(__c.imag()))
: __repr_(__convert_to_bfloat16(__c.real()), __convert_to_bfloat16(__c.imag()))
{}

_LIBCUDACXX_INLINE_VISIBILITY complex& operator=(const value_type& __re)
Expand All @@ -100,8 +133,8 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__nv_bfloat162)) complex<__
template <class _Up>
_LIBCUDACXX_INLINE_VISIBILITY complex& operator=(const complex<_Up>& __c)
{
__repr_.x = __c.real();
__repr_.y = __c.imag();
__repr_.x = __convert_to_bfloat16(__c.real());
__repr_.y = __convert_to_bfloat16(__c.imag());
return *this;
}

Expand Down Expand Up @@ -155,24 +188,24 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__nv_bfloat162)) complex<__

_LIBCUDACXX_INLINE_VISIBILITY complex& operator+=(const value_type& __re)
{
__repr_.x += __re;
__repr_.x = __hadd(__repr_.x, __re);
return *this;
}
_LIBCUDACXX_INLINE_VISIBILITY complex& operator-=(const value_type& __re)
{
__repr_.x -= __re;
__repr_.x = __hsub(__repr_.x, __re);
return *this;
}
_LIBCUDACXX_INLINE_VISIBILITY complex& operator*=(const value_type& __re)
{
__repr_.x *= __re;
__repr_.y *= __re;
__repr_.x = __hmul(__repr_.x, __re);
__repr_.y = __hmul(__repr_.y, __re);
return *this;
}
_LIBCUDACXX_INLINE_VISIBILITY complex& operator/=(const value_type& __re)
{
__repr_.x /= __re;
__repr_.y /= __re;
__repr_.x = __hdiv(__repr_.x, __re);
__repr_.y = __hdiv(__repr_.y, __re);
return *this;
}

Expand All @@ -195,9 +228,41 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__nv_bfloat162)) complex<__
}
};

template <> // complex<float>
template <> // complex<__half>
inline _LIBCUDACXX_INLINE_VISIBILITY complex<float>::complex(const complex<__nv_bfloat16>& __c)
: __re_(__bfloat162float(__c.real()))
, __im_(__bfloat162float(__c.imag()))
{}

template <> // complex<double>
template <> // complex<__half>
inline _LIBCUDACXX_INLINE_VISIBILITY complex<double>::complex(const complex<__nv_bfloat16>& __c)
: __re_(__bfloat162float(__c.real()))
, __im_(__bfloat162float(__c.imag()))
{}

template <> // complex<float>
template <> // complex<__nv_bfloat16>
inline _LIBCUDACXX_INLINE_VISIBILITY complex<float>& complex<float>::operator=(const complex<__nv_bfloat16>& __c)
{
__re_ = __bfloat162float(__c.real());
__im_ = __bfloat162float(__c.imag());
return *this;
}

template <> // complex<double>
template <> // complex<__nv_bfloat16>
inline _LIBCUDACXX_INLINE_VISIBILITY complex<double>& complex<double>::operator=(const complex<__nv_bfloat16>& __c)
{
__re_ = __bfloat162float(__c.real());
__im_ = __bfloat162float(__c.imag());
return *this;
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 arg(__nv_bfloat16 __re)
{
return _CUDA_VSTD::atan2f(__nv_bfloat16(0), __re);
return _CUDA_VSTD::atan2(__int2bfloat16_rn(0), __re);
}

// We have performance issues with some trigonometric functions with __nv_bfloat16
Expand Down
87 changes: 76 additions & 11 deletions libcudacxx/include/cuda/std/__complex/nvfp16.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,39 @@ struct __libcpp_complex_overload_traits<__half, false, false>
typedef complex<__half> _ComplexType;
};

// This is a workaround against the user defining macros __CUDA_NO_HALF_CONVERSIONS__ __CUDA_NO_HALF_OPERATORS__
template <>
struct __complex_can_implicitly_construct<__half, float> : true_type
{};

template <>
struct __complex_can_implicitly_construct<__half, double> : true_type
{};

template <>
struct __complex_can_implicitly_construct<float, __half> : true_type
{};

template <>
struct __complex_can_implicitly_construct<double, __half> : true_type
{};

template <class _Tp>
inline _LIBCUDACXX_INLINE_VISIBILITY __half __convert_to_half(const _Tp& __value) noexcept
{
return __value;
}

inline _LIBCUDACXX_INLINE_VISIBILITY __half __convert_to_half(const float& __value) noexcept
{
return __float2half(__value);
}

inline _LIBCUDACXX_INLINE_VISIBILITY __half __convert_to_half(const double& __value) noexcept
{
return __double2half(__value);
}

template <>
class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__half2)) complex<__half>
{
Expand All @@ -77,14 +110,14 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__half2)) complex<__half>

template <class _Up, __enable_if_t<__complex_can_implicitly_construct<value_type, _Up>::value, int> = 0>
_LIBCUDACXX_INLINE_VISIBILITY complex(const complex<_Up>& __c)
: __repr_(static_cast<value_type>(__c.real()), static_cast<value_type>(__c.imag()))
: __repr_(__convert_to_half(__c.real()), __convert_to_half(__c.imag()))
{}

template <class _Up,
__enable_if_t<!__complex_can_implicitly_construct<value_type, _Up>::value, int> = 0,
__enable_if_t<_CCCL_TRAIT(is_constructible, value_type, _Up), int> = 0>
_LIBCUDACXX_INLINE_VISIBILITY explicit complex(const complex<_Up>& __c)
: __repr_(static_cast<value_type>(__c.real()), static_cast<value_type>(__c.imag()))
: __repr_(__convert_to_half(__c.real()), __convert_to_half(__c.imag()))
{}

_LIBCUDACXX_INLINE_VISIBILITY complex& operator=(const value_type& __re)
Expand All @@ -97,8 +130,8 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__half2)) complex<__half>
template <class _Up>
_LIBCUDACXX_INLINE_VISIBILITY complex& operator=(const complex<_Up>& __c)
{
__repr_.x = __c.real();
__repr_.y = __c.imag();
__repr_.x = __convert_to_half(__c.real());
__repr_.y = __convert_to_half(__c.imag());
return *this;
}

Expand Down Expand Up @@ -152,24 +185,24 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__half2)) complex<__half>

_LIBCUDACXX_INLINE_VISIBILITY complex& operator+=(const value_type& __re)
{
__repr_.x += __re;
__repr_.x = __hadd(__repr_.x, __re);
return *this;
}
_LIBCUDACXX_INLINE_VISIBILITY complex& operator-=(const value_type& __re)
{
__repr_.x -= __re;
__repr_.x = __hsub(__repr_.x, __re);
return *this;
}
_LIBCUDACXX_INLINE_VISIBILITY complex& operator*=(const value_type& __re)
{
__repr_.x *= __re;
__repr_.y *= __re;
__repr_.x = __hmul(__repr_.x, __re);
__repr_.y = __hmul(__repr_.y, __re);
return *this;
}
_LIBCUDACXX_INLINE_VISIBILITY complex& operator/=(const value_type& __re)
{
__repr_.x /= __re;
__repr_.y /= __re;
__repr_.x = __hdiv(__repr_.x, __re);
__repr_.y = __hdiv(__repr_.y, __re);
return *this;
}

Expand All @@ -192,9 +225,41 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__half2)) complex<__half>
}
};

template <> // complex<float>
template <> // complex<__half>
inline _LIBCUDACXX_INLINE_VISIBILITY complex<float>::complex(const complex<__half>& __c)
: __re_(__half2float(__c.real()))
, __im_(__half2float(__c.imag()))
{}

template <> // complex<double>
template <> // complex<__half>
inline _LIBCUDACXX_INLINE_VISIBILITY complex<double>::complex(const complex<__half>& __c)
: __re_(__half2float(__c.real()))
, __im_(__half2float(__c.imag()))
{}

template <> // complex<float>
template <> // complex<__half>
inline _LIBCUDACXX_INLINE_VISIBILITY complex<float>& complex<float>::operator=(const complex<__half>& __c)
{
__re_ = __half2float(__c.real());
__im_ = __half2float(__c.imag());
return *this;
}

template <> // complex<double>
template <> // complex<__half>
inline _LIBCUDACXX_INLINE_VISIBILITY complex<double>& complex<double>::operator=(const complex<__half>& __c)
{
__re_ = __half2float(__c.real());
__im_ = __half2float(__c.imag());
return *this;
}

inline _LIBCUDACXX_INLINE_VISIBILITY __half arg(__half __re)
{
return _CUDA_VSTD::atan2f(__half(0), __re);
return _CUDA_VSTD::atan2(__int2half_rn(0), __re);
}

// We have performance issues with some trigonometric functions with __half
Expand Down
20 changes: 10 additions & 10 deletions libcudacxx/include/cuda/std/__cuda/cmath_nvbf16.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,47 +37,47 @@ _LIBCUDACXX_BEGIN_NAMESPACE_STD
// trigonometric functions
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 sin(__nv_bfloat16 __v)
{
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hsin(__v);), (return __nv_bfloat16(::sin(float(__v)));))
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hsin(__v);), (return __float2bfloat16(::sin(__bfloat162float(__v)));))
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 sinh(__nv_bfloat16 __v)
{
return __nv_bfloat16(::sinh(float(__v)));
return __float2bfloat16(::sinh(__bfloat162float(__v)));
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 cos(__nv_bfloat16 __v)
{
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hcos(__v);), (return __nv_bfloat16(::cos(float(__v)));))
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hcos(__v);), (return __float2bfloat16(::cos(__bfloat162float(__v)));))
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 cosh(__nv_bfloat16 __v)
{
return __nv_bfloat16(::cosh(float(__v)));
return __float2bfloat16(::cosh(__bfloat162float(__v)));
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 exp(__nv_bfloat16 __v)
{
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hexp(__v);), (return __nv_bfloat16(::exp(float(__v)));))
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hexp(__v);), (return __float2bfloat16(::exp(__bfloat162float(__v)));))
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 hypot(__nv_bfloat16 __x, __nv_bfloat16 __y)
{
return __nv_bfloat16(::hypot(float(__x), float(__y)));
return __float2bfloat16(::hypot(__bfloat162float(__x), __bfloat162float(__y)));
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 atan2(__nv_bfloat16 __x, __nv_bfloat16 __y)
{
return __nv_bfloat16(::atan2(float(__x), float(__y)));
return __float2bfloat16(::atan2(__bfloat162float(__x), __bfloat162float(__y)));
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 log(__nv_bfloat16 __x)
{
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hlog(__x);), (return __nv_bfloat16(::log(float(__x)));))
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hlog(__x);), (return __float2bfloat16(::log(__bfloat162float(__x)));))
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 sqrt(__nv_bfloat16 __x)
{
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hsqrt(__x);), (return __nv_bfloat16(::sqrt(float(__x)));))
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hsqrt(__x);), (return __float2bfloat16(::sqrt(__bfloat162float(__x)));))
}

// floating point helper
Expand Down Expand Up @@ -123,7 +123,7 @@ inline _LIBCUDACXX_INLINE_VISIBILITY bool isfinite(__nv_bfloat16 __v)

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 __constexpr_copysign(__nv_bfloat16 __x, __nv_bfloat16 __y) noexcept
{
return __nv_bfloat16(::copysignf(float(__x), float(__y)));
return __float2bfloat16(::copysignf(__bfloat162float(__x), __bfloat162float(__y)));
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 copysign(__nv_bfloat16 __x, __nv_bfloat16 __y)
Expand Down
Loading

0 comments on commit c67b1c3

Please sign in to comment.