Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for pow with inexact base and integer exponent, refactor in-place binary operation type support #1814

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions dpctl/tensor/_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,9 @@ def _acceptance_fn_default_unary(arg_dtype, ret_buf_dt, res_dt, sycl_dev):


def _acceptance_fn_reciprocal(arg_dtype, buf_dt, res_dt, sycl_dev):
# if the kind of result is different from
# the kind of input, use the default data
# we use default dtype for the resulting kind.
# This guarantees alignment of reciprocal and
# divide output types.
# if the kind of result is different from the kind of input, we use the
# default floating-point dtype for the resulting kind. This guarantees
# alignment of reciprocal and divide output types.
if buf_dt.kind != arg_dtype.kind:
default_dt = _get_device_default_dtype(res_dt.kind, sycl_dev)
if res_dt == default_dt:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,53 @@ template <typename argT,
unsigned int n_vecs>
class add_inplace_contig_kernel;

/* @brief Types supported by in-place add */
template <typename argTy, typename resTy> struct AddInplaceTypePairSupport
{
/* value if true a kernel for <argTy, resTy> must be instantiated */
static constexpr bool is_defined = std::disjunction< // disjunction is
// C++17 feature,
// supported by
// DPC++ input bool
td_ns::TypePairDefinedEntry<argTy, bool, resTy, bool>,
td_ns::TypePairDefinedEntry<argTy, std::int8_t, resTy, std::int8_t>,
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, resTy, std::uint8_t>,
td_ns::TypePairDefinedEntry<argTy, std::int16_t, resTy, std::int16_t>,
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, resTy, std::uint16_t>,
td_ns::TypePairDefinedEntry<argTy, std::int32_t, resTy, std::int32_t>,
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, resTy, std::uint32_t>,
td_ns::TypePairDefinedEntry<argTy, std::int64_t, resTy, std::int64_t>,
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, resTy, std::uint64_t>,
td_ns::TypePairDefinedEntry<argTy, sycl::half, resTy, sycl::half>,
td_ns::TypePairDefinedEntry<argTy, float, resTy, float>,
td_ns::TypePairDefinedEntry<argTy, double, resTy, double>,
td_ns::TypePairDefinedEntry<argTy,
std::complex<float>,
resTy,
std::complex<float>>,
td_ns::TypePairDefinedEntry<argTy,
std::complex<double>,
resTy,
std::complex<double>>,
// fall-through
td_ns::NotDefinedEntry>::is_defined;
};

template <typename fnT, typename argT, typename resT>
struct AddInplaceTypeMapFactory
{
/*! @brief get typeid for output type of x += y */
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
{
if constexpr (AddInplaceTypePairSupport<argT, resT>::is_defined) {
return td_ns::GetTypeid<resT>{}.get();
}
else {
return td_ns::GetTypeid<void>{}.get();
}
}
};

template <typename argTy, typename resTy>
sycl::event
add_inplace_contig_impl(sycl::queue &exec_q,
Expand All @@ -457,9 +504,7 @@ template <typename fnT, typename T1, typename T2> struct AddInplaceContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
void>)
{
if constexpr (!AddInplaceTypePairSupport<T1, T2>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down Expand Up @@ -497,9 +542,7 @@ struct AddInplaceStridedFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
void>)
{
if constexpr (!AddInplaceTypePairSupport<T1, T2>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down Expand Up @@ -544,8 +587,7 @@ struct AddInplaceRowMatrixBroadcastFactory
{
fnT get()
{
using resT = typename AddOutputType<T1, T2>::value_type;
if constexpr (!std::is_same_v<resT, T2>) {
if constexpr (!AddInplaceTypePairSupport<T1, T2>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,44 @@ template <typename argT,
unsigned int n_vecs>
class bitwise_and_inplace_contig_kernel;

/* @brief Types supported by in-place bitwise AND */
template <typename argTy, typename resTy>
struct BitwiseAndInplaceTypePairSupport
{
/* value if true a kernel for <argTy, resTy> must be instantiated */
static constexpr bool is_defined = std::disjunction< // disjunction is
// C++17 feature,
// supported by
// DPC++ input bool
td_ns::TypePairDefinedEntry<argTy, bool, resTy, bool>,
td_ns::TypePairDefinedEntry<argTy, std::int8_t, resTy, std::int8_t>,
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, resTy, std::uint8_t>,
td_ns::TypePairDefinedEntry<argTy, std::int16_t, resTy, std::int16_t>,
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, resTy, std::uint16_t>,
td_ns::TypePairDefinedEntry<argTy, std::int32_t, resTy, std::int32_t>,
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, resTy, std::uint32_t>,
td_ns::TypePairDefinedEntry<argTy, std::int64_t, resTy, std::int64_t>,
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, resTy, std::uint64_t>,
// fall-through
td_ns::NotDefinedEntry>::is_defined;
};

template <typename fnT, typename argT, typename resT>
struct BitwiseAndInplaceTypeMapFactory
{
/*! @brief get typeid for output type of x &= y */
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
{
if constexpr (BitwiseAndInplaceTypePairSupport<argT, resT>::is_defined)
{
return td_ns::GetTypeid<resT>{}.get();
}
else {
return td_ns::GetTypeid<void>{}.get();
}
}
};

template <typename argTy, typename resTy>
sycl::event
bitwise_and_inplace_contig_impl(sycl::queue &exec_q,
Expand All @@ -343,10 +381,7 @@ struct BitwiseAndInplaceContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<
typename BitwiseAndOutputType<T1, T2>::value_type,
void>)
{
if constexpr (!BitwiseAndInplaceTypePairSupport<T1, T2>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down Expand Up @@ -385,10 +420,7 @@ struct BitwiseAndInplaceStridedFactory
{
fnT get()
{
if constexpr (std::is_same_v<
typename BitwiseAndOutputType<T1, T2>::value_type,
void>)
{
if constexpr (!BitwiseAndInplaceTypePairSupport<T1, T2>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,44 @@ template <typename argT,
unsigned int n_vecs>
class bitwise_left_shift_inplace_contig_kernel;

/* @brief Types supported by in-place bitwise left shift */
template <typename argTy, typename resTy>
struct BitwiseLeftShiftInplaceTypePairSupport
{
/* value if true a kernel for <argTy, resTy> must be instantiated */
static constexpr bool is_defined = std::disjunction< // disjunction is
// C++17 feature,
// supported by
// DPC++ input bool
td_ns::TypePairDefinedEntry<argTy, std::int8_t, resTy, std::int8_t>,
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, resTy, std::uint8_t>,
td_ns::TypePairDefinedEntry<argTy, std::int16_t, resTy, std::int16_t>,
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, resTy, std::uint16_t>,
td_ns::TypePairDefinedEntry<argTy, std::int32_t, resTy, std::int32_t>,
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, resTy, std::uint32_t>,
td_ns::TypePairDefinedEntry<argTy, std::int64_t, resTy, std::int64_t>,
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, resTy, std::uint64_t>,
// fall-through
td_ns::NotDefinedEntry>::is_defined;
};

template <typename fnT, typename argT, typename resT>
struct BitwiseLeftShiftInplaceTypeMapFactory
{
/*! @brief get typeid for output type of x <<= y */
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
{
if constexpr (BitwiseLeftShiftInplaceTypePairSupport<argT,
resT>::is_defined)
{
return td_ns::GetTypeid<resT>{}.get();
}
else {
return td_ns::GetTypeid<void>{}.get();
}
}
};

template <typename argTy, typename resTy>
sycl::event bitwise_left_shift_inplace_contig_impl(
sycl::queue &exec_q,
Expand All @@ -357,9 +395,8 @@ struct BitwiseLeftShiftInplaceContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename BitwiseLeftShiftOutputType<
T1, T2>::value_type,
void>)
if constexpr (!BitwiseLeftShiftInplaceTypePairSupport<T1,
T2>::is_defined)
{
fnT fn = nullptr;
return fn;
Expand Down Expand Up @@ -399,9 +436,8 @@ struct BitwiseLeftShiftInplaceStridedFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename BitwiseLeftShiftOutputType<
T1, T2>::value_type,
void>)
if constexpr (!BitwiseLeftShiftInplaceTypePairSupport<T1,
T2>::is_defined)
{
fnT fn = nullptr;
return fn;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,42 @@ template <typename argT,
unsigned int n_vecs>
class bitwise_or_inplace_contig_kernel;

/* @brief Types supported by in-place bitwise OR */
template <typename argTy, typename resTy> struct BitwiseOrInplaceTypePairSupport
{
/* value if true a kernel for <argTy, resTy> must be instantiated */
static constexpr bool is_defined = std::disjunction< // disjunction is
// C++17 feature,
// supported by
// DPC++ input bool
td_ns::TypePairDefinedEntry<argTy, bool, resTy, bool>,
td_ns::TypePairDefinedEntry<argTy, std::int8_t, resTy, std::int8_t>,
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, resTy, std::uint8_t>,
td_ns::TypePairDefinedEntry<argTy, std::int16_t, resTy, std::int16_t>,
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, resTy, std::uint16_t>,
td_ns::TypePairDefinedEntry<argTy, std::int32_t, resTy, std::int32_t>,
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, resTy, std::uint32_t>,
td_ns::TypePairDefinedEntry<argTy, std::int64_t, resTy, std::int64_t>,
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, resTy, std::uint64_t>,
// fall-through
td_ns::NotDefinedEntry>::is_defined;
};

template <typename fnT, typename argT, typename resT>
struct BitwiseOrInplaceTypeMapFactory
{
/*! @brief get typeid for output type of x |= y */
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
{
if constexpr (BitwiseOrInplaceTypePairSupport<argT, resT>::is_defined) {
return td_ns::GetTypeid<resT>{}.get();
}
else {
return td_ns::GetTypeid<void>{}.get();
}
}
};

template <typename argTy, typename resTy>
sycl::event
bitwise_or_inplace_contig_impl(sycl::queue &exec_q,
Expand All @@ -339,10 +375,7 @@ struct BitwiseOrInplaceContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<
typename BitwiseOrOutputType<T1, T2>::value_type,
void>)
{
if constexpr (!BitwiseOrInplaceTypePairSupport<T1, T2>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down Expand Up @@ -381,10 +414,7 @@ struct BitwiseOrInplaceStridedFactory
{
fnT get()
{
if constexpr (std::is_same_v<
typename BitwiseOrOutputType<T1, T2>::value_type,
void>)
{
if constexpr (!BitwiseOrInplaceTypePairSupport<T1, T2>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,44 @@ template <typename argT,
unsigned int n_vecs>
class bitwise_right_shift_inplace_contig_kernel;

/* @brief Types supported by in-place bitwise right shift */
template <typename argTy, typename resTy>
struct BitwiseRightShiftInplaceTypePairSupport
{
/* value if true a kernel for <argTy, resTy> must be instantiated */
static constexpr bool is_defined = std::disjunction< // disjunction is
// C++17 feature,
// supported by
// DPC++ input bool
td_ns::TypePairDefinedEntry<argTy, std::int8_t, resTy, std::int8_t>,
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, resTy, std::uint8_t>,
td_ns::TypePairDefinedEntry<argTy, std::int16_t, resTy, std::int16_t>,
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, resTy, std::uint16_t>,
td_ns::TypePairDefinedEntry<argTy, std::int32_t, resTy, std::int32_t>,
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, resTy, std::uint32_t>,
td_ns::TypePairDefinedEntry<argTy, std::int64_t, resTy, std::int64_t>,
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, resTy, std::uint64_t>,
// fall-through
td_ns::NotDefinedEntry>::is_defined;
};

template <typename fnT, typename argT, typename resT>
struct BitwiseRightShiftInplaceTypeMapFactory
{
/*! @brief get typeid for output type of x >>= y */
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
{
if constexpr (BitwiseRightShiftInplaceTypePairSupport<argT,
resT>::is_defined)
{
return td_ns::GetTypeid<resT>{}.get();
}
else {
return td_ns::GetTypeid<void>{}.get();
}
}
};

template <typename argTy, typename resTy>
sycl::event bitwise_right_shift_inplace_contig_impl(
sycl::queue &exec_q,
Expand All @@ -361,9 +399,8 @@ struct BitwiseRightShiftInplaceContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename BitwiseRightShiftOutputType<
T1, T2>::value_type,
void>)
if constexpr (!BitwiseRightShiftInplaceTypePairSupport<T1,
T2>::is_defined)
{
fnT fn = nullptr;
return fn;
Expand Down Expand Up @@ -403,9 +440,8 @@ struct BitwiseRightShiftInplaceStridedFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename BitwiseRightShiftOutputType<
T1, T2>::value_type,
void>)
if constexpr (!BitwiseRightShiftInplaceTypePairSupport<T1,
T2>::is_defined)
{
fnT fn = nullptr;
return fn;
Expand Down
Loading
Loading