From faa2fd866a7ec4fe6adb8bc6675ed952ad72ee5b Mon Sep 17 00:00:00 2001 From: XNNPACK Team Date: Tue, 22 Oct 2024 09:30:18 -0700 Subject: [PATCH] Refactor binary-elementwise-nd.cc and binary.cc to reduce the number of test suites via type erasure. This should make the test easier to read and maintain. PiperOrigin-RevId: 688580770 --- src/operator-utils.c | 6 + src/operators/binary-elementwise-nd.c | 6 - src/xnnpack/log.h | 1 + test/BUILD.bazel | 5 +- test/binary-elementwise-nd.cc | 645 ++++++++++++++------------ test/binary.cc | 484 ++++++++++--------- 6 files changed, 593 insertions(+), 554 deletions(-) diff --git a/src/operator-utils.c b/src/operator-utils.c index 9c11701181b..1862d64b44d 100644 --- a/src/operator-utils.c +++ b/src/operator-utils.c @@ -183,3 +183,9 @@ enum xnn_operator_type xnn_reduce_operator_to_operator_type(enum xnn_reduce_oper return xnn_operator_type_invalid; } } + + +const char* xnn_binary_operator_to_string(enum xnn_binary_operator type) { + return xnn_operator_type_to_string( + xnn_binary_operator_to_operator_type(type)); +} diff --git a/src/operators/binary-elementwise-nd.c b/src/operators/binary-elementwise-nd.c index f8a77c842a4..cfe9d5dcd43 100644 --- a/src/operators/binary-elementwise-nd.c +++ b/src/operators/binary-elementwise-nd.c @@ -47,12 +47,6 @@ static uint32_t xnn_datatype_get_log2_element_size(enum xnn_datatype datatype) { } } -static const char* xnn_binary_operator_to_string( - enum xnn_binary_operator type) { - return xnn_operator_type_to_string( - xnn_binary_operator_to_operator_type(type)); -} - static const struct xnn_binary_elementwise_config* init_config( enum xnn_binary_operator type, enum xnn_datatype datatype, int* sign_b) { switch (type) { diff --git a/src/xnnpack/log.h b/src/xnnpack/log.h index e59f3a830cf..5a62c586f89 100644 --- a/src/xnnpack/log.h +++ b/src/xnnpack/log.h @@ -52,6 +52,7 @@ extern "C" { #endif const char* xnn_datatype_to_string(enum xnn_datatype type); +const char* 
xnn_binary_operator_to_string(enum xnn_binary_operator type); #ifdef __cplusplus } // extern "C" diff --git a/test/BUILD.bazel b/test/BUILD.bazel index a9e9f4f6965..215528540f8 100644 --- a/test/BUILD.bazel +++ b/test/BUILD.bazel @@ -1348,7 +1348,9 @@ xnnpack_unit_test( name = "binary_elementwise_nd_test", timeout = "long", srcs = ["binary-elementwise-nd.cc"], - deps = OPERATOR_TEST_DEPS, + deps = OPERATOR_TEST_DEPS + [ + "//:logging", + ], ) xnnpack_unit_test( @@ -1840,6 +1842,7 @@ xnnpack_unit_test( ":replicable_random_device", "//:XNNPACK", "//:buffer", + "//:logging", "//:math", "//:operators", "//:subgraph", diff --git a/test/binary-elementwise-nd.cc b/test/binary-elementwise-nd.cc index e4f26e82a26..f92acce505c 100644 --- a/test/binary-elementwise-nd.cc +++ b/test/binary-elementwise-nd.cc @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -22,72 +23,95 @@ #include #include "xnnpack.h" -#include "xnnpack/math.h" #include "xnnpack/buffer.h" +#include "xnnpack/log.h" +#include "xnnpack/math.h" #include "replicable_random_device.h" +using ::testing::Combine; +using ::testing::Values; + enum class RunMode { kCreateReshapeRun, kEager, }; -template -xnn_datatype datatype_of() { - if (std::is_same::value) { - return xnn_datatype_quint8; - } else if (std::is_same::value) { - return xnn_datatype_qint8; - } else if (std::is_same::value) { - return xnn_datatype_fp16; - } else if (std::is_same::value) { - return xnn_datatype_fp32; - } else if (std::is_same::value) { - return xnn_datatype_int32; - } else { - XNN_UNREACHABLE; +double compute_tolerance(xnn_datatype datatype, double output_ref) { + switch (datatype) { + case xnn_datatype_fp16: + return std::abs(output_ref); + case xnn_datatype_fp32: + return 1.0e-6 * std::abs(output_ref); + default: + return 0.6; } } -template -double compute_tolerance(double output_ref) { - if (std::is_integral::value) { - return 0.6; - } else if (std::is_same::value) { - return 1.0e-3 * 
std::abs(output_ref); - } else { - return 1.0e-6 * std::abs(output_ref); +template +void randomize_buffer(xnn_datatype datatype, + xnnpack::ReplicableRandomDevice& rng, + std::uniform_real_distribution& dist, + Buffer& buf) { + switch (datatype) { + case xnn_datatype_quint8: + std::generate(reinterpret_cast(buf.begin()), + reinterpret_cast(buf.end()), + [&]() { return dist(rng); }); + break; + case xnn_datatype_qint8: + std::generate(reinterpret_cast(buf.begin()), + reinterpret_cast(buf.end()), + [&]() { return dist(rng); }); + break; + case xnn_datatype_int32: + std::generate(reinterpret_cast(buf.begin()), + reinterpret_cast(buf.end()), + [&]() { return dist(rng); }); + break; + case xnn_datatype_fp16: + std::generate(reinterpret_cast(buf.begin()), + reinterpret_cast(buf.end()), + [&]() { return dist(rng); }); + break; + case xnn_datatype_fp32: + std::generate(reinterpret_cast(buf.begin()), + reinterpret_cast(buf.end()), + [&]() { return dist(rng); }); + break; + default: + break; } } -class BinaryElementwiseOperatorTester { - public: - static std::string ToString(xnn_binary_operator operation_type) { - switch (operation_type) { - case xnn_binary_invalid: - return "Unknown"; - case xnn_binary_add: - return "Add"; - case xnn_binary_copysign: - return "CopySign"; - case xnn_binary_divide: - return "Divide"; - case xnn_binary_maximum: - return "Maximum"; - case xnn_binary_minimum: - return "Minimum"; - case xnn_binary_multiply: - return "Multiply"; - case xnn_binary_prelu: - return "Prelu"; - case xnn_binary_subtract: - return "Subtract"; - case xnn_binary_squared_difference: - return "SquaredDifference"; - default: - return "Unknown"; - } +std::string BinaryOperatorToString(xnn_binary_operator operation_type) { + switch (operation_type) { + case xnn_binary_invalid: + return "Unknown"; + case xnn_binary_add: + return "Add"; + case xnn_binary_copysign: + return "CopySign"; + case xnn_binary_divide: + return "Divide"; + case xnn_binary_maximum: + return "Maximum"; + 
case xnn_binary_minimum: + return "Minimum"; + case xnn_binary_multiply: + return "Multiply"; + case xnn_binary_prelu: + return "Prelu"; + case xnn_binary_subtract: + return "Subtract"; + case xnn_binary_squared_difference: + return "SquaredDifference"; + default: + return "Unknown"; } +} +class BinaryElementwiseOperatorTester { + public: double Compute(double a, double b) const { switch (operation_type()) { case xnn_binary_add: @@ -213,6 +237,13 @@ class BinaryElementwiseOperatorTester { xnn_binary_operator operation_type() const { return this->operation_type_; } + BinaryElementwiseOperatorTester& datatype(xnn_datatype datatype) { + this->datatype_ = datatype; + return *this; + } + + xnn_datatype datatype() const { return this->datatype_; } + BinaryElementwiseOperatorTester& iterations(size_t iterations) { this->iterations_ = iterations; return *this; @@ -220,18 +251,105 @@ class BinaryElementwiseOperatorTester { size_t iterations() const { return this->iterations_; } - template + // Some combinations aren't implemented. 
+ bool SupportedBinaryNDTest() const { + switch (datatype()) { + case xnn_datatype_quint8: + case xnn_datatype_qint8: + switch (operation_type()) { + case xnn_binary_add: + case xnn_binary_multiply: + case xnn_binary_subtract: + return true; + default: + return false; + } + case xnn_datatype_int32: + switch (operation_type()) { + case xnn_binary_multiply: + return true; + default: + return false; + } + case xnn_datatype_fp16: + switch (operation_type()) { + case xnn_binary_add: + case xnn_binary_divide: + case xnn_binary_maximum: + case xnn_binary_minimum: + case xnn_binary_multiply: + case xnn_binary_subtract: + case xnn_binary_squared_difference: + return true; + default: + return false; + } + case xnn_datatype_fp32: + switch (operation_type()) { + case xnn_binary_add: + case xnn_binary_copysign: + case xnn_binary_divide: + case xnn_binary_maximum: + case xnn_binary_minimum: + case xnn_binary_multiply: + case xnn_binary_subtract: + case xnn_binary_squared_difference: + return true; + default: + return false; + } + default: + return false; + } + } + void Test(RunMode mode) { ASSERT_NE(operation_type(), xnn_binary_invalid); + ASSERT_NE(datatype(), xnn_datatype_invalid); xnnpack::ReplicableRandomDevice rng; - double input_min = std::is_integral::value - ? static_cast(std::numeric_limits::min()) - : 0.01; - double input_max = std::is_integral::value - ? 
static_cast(std::numeric_limits::max()) - : 1.0; - std::uniform_real_distribution dist(input_min, input_max); + double datatype_min, datatype_max, datatype_low; + size_t datatype_size; + switch (datatype()) { + case xnn_datatype_quint8: + datatype_low = std::numeric_limits::lowest(); + datatype_min = std::numeric_limits::min(); + datatype_max = std::numeric_limits::max(); + datatype_size = sizeof(uint8_t); + break; + case xnn_datatype_qint8: + datatype_low = std::numeric_limits::lowest(); + datatype_min = std::numeric_limits::min(); + datatype_max = std::numeric_limits::max(); + datatype_size = sizeof(int8_t); + break; + case xnn_datatype_int32: + datatype_low = std::numeric_limits::lowest(); + datatype_min = std::numeric_limits::min(); + datatype_max = std::numeric_limits::max(); + datatype_size = sizeof(int32_t); + break; + case xnn_datatype_fp16: + datatype_low = 0.0; // don't use std::numeric_limits here + datatype_min = 0.01; + datatype_max = 1.0; + datatype_size = sizeof(xnn_float16); + break; + case xnn_datatype_fp32: + datatype_low = std::numeric_limits::lowest(); + datatype_min = 0.01; + datatype_max = 1.0; + datatype_size = sizeof(float); + break; + default: + datatype_low = 0; + datatype_min = 0; + datatype_max = 0; + datatype_size = 0; + assert(false); + break; + } + std::uniform_real_distribution dist(datatype_min, datatype_max); // Compute generalized shapes. 
std::array input1_dims; @@ -267,26 +385,27 @@ class BinaryElementwiseOperatorTester { output_stride *= output_dims[i - 1]; } - xnn_datatype datatype = datatype_of(); xnn_quantization_params input1_quantization = {input1_zero_point(), input1_scale()}; xnn_quantization_params input2_quantization = {input2_zero_point(), input2_scale()}; xnn_quantization_params output_quantization = {output_zero_point(), output_scale()}; - xnnpack::Buffer input1(XNN_EXTRA_BYTES / sizeof(T) + num_input1_elements()); - xnnpack::Buffer input2(XNN_EXTRA_BYTES / sizeof(T) + num_input2_elements()); - xnnpack::Buffer output(num_output_elements); + xnnpack::Buffer input1(XNN_EXTRA_BYTES / sizeof(char) + + num_input1_elements() * datatype_size); + xnnpack::Buffer input2(XNN_EXTRA_BYTES / sizeof(char) + + num_input2_elements() * datatype_size); + xnnpack::Buffer output(num_output_elements * datatype_size); for (size_t iteration = 0; iteration < iterations(); iteration++) { - std::generate(input1.begin(), input1.end(), [&]() { return dist(rng); }); - std::generate(input2.begin(), input2.end(), [&]() { return dist(rng); }); + randomize_buffer(datatype(), rng, dist, input1); + randomize_buffer(datatype(), rng, dist, input2); if (mode == RunMode::kCreateReshapeRun) { // Create, setup, run, and destroy a binary elementwise operator. ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); xnn_operator_t binary_elementwise_op = nullptr; xnn_status status = xnn_create_binary_elementwise_nd( - operation_type(), datatype, &input1_quantization, + operation_type(), datatype(), &input1_quantization, &input2_quantization, &output_quantization, 0, &binary_elementwise_op); if (status == xnn_status_unsupported_hardware) { @@ -315,7 +434,7 @@ class BinaryElementwiseOperatorTester { } else if (mode == RunMode::kEager) { // Run a binary elementwise operator without creating it. 
xnn_status status = xnn_run_binary_elementwise_nd( - operation_type(), datatype, &input1_quantization, + operation_type(), datatype(), &input1_quantization, &input2_quantization, &output_quantization, 0, input1_dims.size(), input1_dims.data(), input2_dims.size(), input2_dims.data(), input1.data(), input2.data(), output.data(), @@ -335,31 +454,47 @@ class BinaryElementwiseOperatorTester { for (size_t l = 0; l < output_dims[3]; l++) { for (size_t m = 0; m < output_dims[4]; m++) { for (size_t n = 0; n < output_dims[5]; n++) { + const auto value_of = + [&](const xnnpack::Buffer& buf, + const std::array& + strides) -> double { + const size_t index = i * strides[0] + j * strides[1] + + k * strides[2] + l * strides[3] + + m * strides[4] + n * strides[5]; + const char* base = &buf[0]; + switch (datatype()) { + case xnn_datatype_quint8: + return reinterpret_cast(base)[index]; + case xnn_datatype_qint8: + return reinterpret_cast(base)[index]; + case xnn_datatype_int32: + return reinterpret_cast(base)[index]; + case xnn_datatype_fp16: + return reinterpret_cast( + base)[index]; + case xnn_datatype_fp32: + return reinterpret_cast(base)[index]; + default: + return std::nanf(""); + } + }; const double input1_value = input1_scale() * - (input1[i * input1_strides[0] + j * input1_strides[1] + - k * input1_strides[2] + l * input1_strides[3] + - m * input1_strides[4] + n * input1_strides[5]] - - input1_zero_point()); + (value_of(input1, input1_strides) - input1_zero_point()); const double input2_value = input2_scale() * - (input2[i * input2_strides[0] + j * input2_strides[1] + - k * input2_strides[2] + l * input2_strides[3] + - m * input2_strides[4] + n * input2_strides[5]] - - input2_zero_point()); + (value_of(input2, input2_strides) - input2_zero_point()); double output_ref = Compute(input1_value, input2_value) / output_scale() + output_zero_point(); - const size_t index = - i * output_strides[0] + j * output_strides[1] + - k * output_strides[2] + l * output_strides[3] + - m * 
output_strides[4] + n * output_strides[5]; - if (output_ref < std::numeric_limits::lowest() || - output_ref > std::numeric_limits::max()) { + if (output_ref < datatype_low || output_ref > datatype_max) { // This is expected to overflow. } else { - const double tolerance = compute_tolerance(output_ref); - ASSERT_NEAR(output[index], output_ref, tolerance) + const double tolerance = + compute_tolerance(datatype(), output_ref); + const double output_value = + value_of(output, output_strides); + ASSERT_NEAR(output_value, output_ref, tolerance) << "input1_value = " << input1_value << ", " << "input2_value = " << input2_value << ", " << "(i, j, k, l, m, n) = (" << i << ", " << j << ", " @@ -390,6 +525,7 @@ class BinaryElementwiseOperatorTester { int32_t output_zero_point_{0}; float output_scale_{1.0f}; xnn_binary_operator operation_type_{xnn_binary_invalid}; + xnn_datatype datatype_{xnn_datatype_invalid}; size_t iterations_{3}; }; @@ -420,7 +556,6 @@ std::vector RandomBroadcast(Rng& rng, std::vector dims) { return dims; } -template void RunBinaryOpTester(RunMode run_mode, BinaryElementwiseOperatorTester& tester) { xnnpack::ReplicableRandomDevice rng; @@ -428,253 +563,169 @@ void RunBinaryOpTester(RunMode run_mode, std::vector output_shape = RandomShape(rng); tester.input1_shape(RandomBroadcast(rng, output_shape)) .input2_shape(RandomBroadcast(rng, output_shape)); - tester.Test(run_mode); + tester.Test(run_mode); } } -template -void BinaryNDTestImpl(const Params& params) { - RunMode mode = std::get<0>(params); - xnn_binary_operator op = std::get<1>(params); +struct Param { + using TupleT = std::tuple; + explicit Param(TupleT p) + : datatype(std::get<0>(p)), + run_mode(std::get<1>(p)), + binary_operator(std::get<2>(p)) {} + + std::string Name() const { + std::stringstream sstr; + if (run_mode == RunMode::kEager) { + sstr << "Eager_"; + } else { + sstr << "CreateReshapeRun_"; + } + sstr << xnn_datatype_to_string(datatype) << "_" + << 
xnn_binary_operator_to_string(binary_operator); + std::string s = sstr.str(); + // Test names must be alphanumeric with no spaces + std::replace(s.begin(), s.end(), ' ', '_'); + std::replace(s.begin(), s.end(), '(', '_'); + std::replace(s.begin(), s.end(), ')', '_'); + return s; + } + + xnn_datatype datatype; + RunMode run_mode; + xnn_binary_operator binary_operator; +}; + +class BinaryNDTest : public testing::TestWithParam {}; + +TEST_P(BinaryNDTest, op) { BinaryElementwiseOperatorTester tester; - tester.operation_type(op); - RunBinaryOpTester(mode, tester); + tester.operation_type(GetParam().binary_operator); + tester.datatype(GetParam().datatype); +#ifdef XNN_EXCLUDE_F16_TESTS + if (GetParam().datatype == xnn_datatype_fp16) { + GTEST_SKIP(); + } +#endif + if (!tester.SupportedBinaryNDTest()) { + GTEST_SKIP(); + } + RunBinaryOpTester(GetParam().run_mode, tester); } -template -class BinaryNDTest - : public testing::TestWithParam< - std::tuple> {}; - -using BinaryNDTestQS8 = BinaryNDTest; -using BinaryNDTestQU8 = BinaryNDTest; -#ifndef XNN_EXCLUDE_F16_TESTS -using BinaryNDTestF16 = BinaryNDTest; -#endif // XNN_EXCLUDE_F16_TESTS -using BinaryNDTestF32 = BinaryNDTest; -using BinaryNDTestS32 = BinaryNDTest; - -TEST_P(BinaryNDTestQS8, op) { BinaryNDTestImpl(GetParam()); } -TEST_P(BinaryNDTestQU8, op) { BinaryNDTestImpl(GetParam()); } -#ifndef XNN_EXCLUDE_F16_TESTS -TEST_P(BinaryNDTestF16, op) { BinaryNDTestImpl(GetParam()); } -#endif // XNN_EXCLUDE_F16_TESTS -TEST_P(BinaryNDTestF32, op) { BinaryNDTestImpl(GetParam()); } -TEST_P(BinaryNDTestS32, op) { BinaryNDTestImpl(GetParam()); } - -std::string ToString(const std::tuple& param) { - return BinaryElementwiseOperatorTester::ToString(std::get<1>(param)); +// We do the full Cartesian combination here, but some are inappropriate +// and will be skipped for certain combinations. 
+INSTANTIATE_TEST_SUITE_P( + BinaryNDTest, BinaryNDTest, + testing::ConvertGenerator(Combine( + Values(xnn_datatype_quint8, xnn_datatype_qint8, xnn_datatype_fp16, + xnn_datatype_fp32, xnn_datatype_int32), + Values(RunMode::kCreateReshapeRun, RunMode::kEager), + Values(xnn_binary_add, xnn_binary_copysign, xnn_binary_divide, + xnn_binary_maximum, xnn_binary_minimum, xnn_binary_multiply, + xnn_binary_prelu, xnn_binary_subtract, + xnn_binary_squared_difference))), + [](const auto& info) { return info.param.Name(); }); + +class QuantizedTest : public testing::TestWithParam {}; + +int32_t GetMin(xnn_datatype datatype) { + switch (datatype) { + case xnn_datatype_quint8: + return std::numeric_limits::min(); + case xnn_datatype_qint8: + return std::numeric_limits::min(); + default: + assert(false); + return 0; + } } -INSTANTIATE_TEST_SUITE_P( - CreateReshapeRun, BinaryNDTestQS8, - testing::Combine(testing::Values(RunMode::kCreateReshapeRun), - testing::Values(xnn_binary_add, xnn_binary_subtract, - xnn_binary_multiply)), - [](const auto& info) { return ToString(info.param); }); -INSTANTIATE_TEST_SUITE_P(Eager, BinaryNDTestQS8, - testing::Combine(testing::Values(RunMode::kEager), - testing::Values(xnn_binary_add, - xnn_binary_subtract, - xnn_binary_multiply)), - [](const auto& info) { return ToString(info.param); }); -INSTANTIATE_TEST_SUITE_P( - CreateReshapeRun, BinaryNDTestQU8, - testing::Combine(testing::Values(RunMode::kCreateReshapeRun), - testing::Values(xnn_binary_add, xnn_binary_subtract, - xnn_binary_multiply)), - [](const auto& info) { return ToString(info.param); }); -INSTANTIATE_TEST_SUITE_P(Eager, BinaryNDTestQU8, - testing::Combine(testing::Values(RunMode::kEager), - testing::Values(xnn_binary_add, - xnn_binary_subtract, - xnn_binary_multiply)), - [](const auto& info) { return ToString(info.param); }); -#ifndef XNN_EXCLUDE_F16_TESTS -INSTANTIATE_TEST_SUITE_P( - CreateReshapeRun, BinaryNDTestF16, - testing::Combine( - testing::Values(RunMode::kCreateReshapeRun), 
- testing::Values(xnn_binary_add, xnn_binary_divide, xnn_binary_maximum, - xnn_binary_minimum, xnn_binary_multiply, - xnn_binary_squared_difference, xnn_binary_subtract)), - [](const auto& info) { return ToString(info.param); }); -INSTANTIATE_TEST_SUITE_P( - Eager, BinaryNDTestF16, - testing::Combine( - testing::Values(RunMode::kEager), - testing::Values(xnn_binary_add, xnn_binary_divide, xnn_binary_maximum, - xnn_binary_minimum, xnn_binary_multiply, - xnn_binary_squared_difference, xnn_binary_subtract)), - [](const auto& info) { return ToString(info.param); }); -#endif -INSTANTIATE_TEST_SUITE_P( - CreateReshapeRun, BinaryNDTestF32, - testing::Combine(testing::Values(RunMode::kCreateReshapeRun), - testing::Values(xnn_binary_add, xnn_binary_copysign, - xnn_binary_divide, xnn_binary_maximum, - xnn_binary_minimum, xnn_binary_multiply, - xnn_binary_subtract, - xnn_binary_squared_difference)), - [](const auto& info) { return ToString(info.param); }); -INSTANTIATE_TEST_SUITE_P( - Eager, BinaryNDTestF32, - testing::Combine(testing::Values(RunMode::kEager), - testing::Values(xnn_binary_add, xnn_binary_divide, - xnn_binary_maximum, xnn_binary_minimum, - xnn_binary_multiply, xnn_binary_subtract, - xnn_binary_squared_difference)), - [](const auto& info) { return ToString(info.param); }); -INSTANTIATE_TEST_SUITE_P( - CreateReshapeRun, BinaryNDTestS32, - testing::Combine(testing::Values(RunMode::kCreateReshapeRun), - testing::Values(xnn_binary_multiply)), - [](const auto& info) { return ToString(info.param); }); -INSTANTIATE_TEST_SUITE_P(Eager, BinaryNDTestS32, - testing::Combine(testing::Values(RunMode::kEager), - testing::Values(xnn_binary_multiply)), - [](const auto& info) { return ToString(info.param); }); - -template -void QuantizedTest_Input1Scale(Params params) { +int32_t GetMax(xnn_datatype datatype) { + switch (datatype) { + case xnn_datatype_quint8: + return std::numeric_limits::max(); + case xnn_datatype_qint8: + return std::numeric_limits::max(); + default: + 
assert(false); + return 0; + } +} + +TEST_P(QuantizedTest, input1_scale) { for (float input1_scale = 0.1f; input1_scale <= 10.0f; input1_scale *= 3.14f) { - RunBinaryOpTester(std::get<0>(params), - BinaryElementwiseOperatorTester() - .operation_type(std::get<1>(params)) - .input1_scale(input1_scale)); + RunBinaryOpTester(GetParam().run_mode, + BinaryElementwiseOperatorTester() + .operation_type(GetParam().binary_operator) + .datatype(GetParam().datatype) + .input1_scale(input1_scale)); } } -template -void QuantizedTest_Input1ZeroPoint(Params params) { - for (int32_t input1_zero_point = std::numeric_limits::min(); - input1_zero_point <= std::numeric_limits::max(); +TEST_P(QuantizedTest, input1_zero_point) { + for (int32_t input1_zero_point = GetMin(GetParam().datatype); + input1_zero_point <= GetMax(GetParam().datatype); input1_zero_point += 51) { - RunBinaryOpTester(std::get<0>(params), - BinaryElementwiseOperatorTester() - .operation_type(std::get<1>(params)) - .input1_zero_point(input1_zero_point)); + RunBinaryOpTester(GetParam().run_mode, + BinaryElementwiseOperatorTester() + .operation_type(GetParam().binary_operator) + .datatype(GetParam().datatype) + .input1_zero_point(input1_zero_point)); } } -template -void QuantizedTest_Input2Scale(Params params) { +TEST_P(QuantizedTest, input2_scale) { for (float input2_scale = 0.1f; input2_scale <= 10.0f; input2_scale *= 3.14f) { - RunBinaryOpTester(std::get<0>(params), - BinaryElementwiseOperatorTester() - .operation_type(std::get<1>(params)) - .input2_scale(input2_scale)); + RunBinaryOpTester(GetParam().run_mode, + BinaryElementwiseOperatorTester() + .operation_type(GetParam().binary_operator) + .datatype(GetParam().datatype) + .input2_scale(input2_scale)); } } -template -void QuantizedTest_Input2ZeroPoint(Params params) { - for (int32_t input2_zero_point = std::numeric_limits::min(); - input2_zero_point <= std::numeric_limits::max(); +TEST_P(QuantizedTest, input2_zero_point) { + for (int32_t input2_zero_point = 
GetMin(GetParam().datatype); + input2_zero_point <= GetMax(GetParam().datatype); input2_zero_point += 51) { - RunBinaryOpTester(std::get<0>(params), - BinaryElementwiseOperatorTester() - .operation_type(std::get<1>(params)) - .input2_zero_point(input2_zero_point)); + RunBinaryOpTester(GetParam().run_mode, + BinaryElementwiseOperatorTester() + .operation_type(GetParam().binary_operator) + .datatype(GetParam().datatype) + .input2_zero_point(input2_zero_point)); } } -template -void QuantizedTest_OutputScale(Params params) { +TEST_P(QuantizedTest, output_scale) { for (float output_scale = 0.1f; output_scale <= 10.0f; output_scale *= 3.14f) { - RunBinaryOpTester(std::get<0>(params), - BinaryElementwiseOperatorTester() - .operation_type(std::get<1>(params)) - .output_scale(output_scale)); + RunBinaryOpTester(GetParam().run_mode, + BinaryElementwiseOperatorTester() + .operation_type(GetParam().binary_operator) + .datatype(GetParam().datatype) + .output_scale(output_scale)); } } -template -void QuantizedTest_OutputZeroPoint(Params params) { - for (int32_t output_zero_point = std::numeric_limits::min(); - output_zero_point <= std::numeric_limits::max(); +TEST_P(QuantizedTest, output_zero_point) { + for (int32_t output_zero_point = GetMin(GetParam().datatype); + output_zero_point <= GetMax(GetParam().datatype); output_zero_point += 51) { - RunBinaryOpTester(std::get<0>(params), - BinaryElementwiseOperatorTester() - .operation_type(std::get<1>(params)) - .output_zero_point(output_zero_point)); + RunBinaryOpTester(GetParam().run_mode, + BinaryElementwiseOperatorTester() + .operation_type(GetParam().binary_operator) + .datatype(GetParam().datatype) + .output_zero_point(output_zero_point)); } } -template -class QuantizedTest - : public testing::TestWithParam> { -}; - -using QuantizedTestQS8 = QuantizedTest; - -TEST_P(QuantizedTestQS8, input1_scale) { - QuantizedTest_Input1Scale(GetParam()); -} -TEST_P(QuantizedTestQS8, input1_zero_point) { - 
QuantizedTest_Input1ZeroPoint(GetParam()); -} -TEST_P(QuantizedTestQS8, input2_scale) { - QuantizedTest_Input2Scale(GetParam()); -} -TEST_P(QuantizedTestQS8, input2_zero_point) { - QuantizedTest_Input2ZeroPoint(GetParam()); -} - -TEST_P(QuantizedTestQS8, output_scale) { - QuantizedTest_OutputScale(GetParam()); -} -TEST_P(QuantizedTestQS8, output_zero_point) { - QuantizedTest_OutputZeroPoint(GetParam()); -} - -INSTANTIATE_TEST_SUITE_P( - CreateReshapeRun, QuantizedTestQS8, - testing::Combine(testing::Values(RunMode::kCreateReshapeRun), - testing::Values(xnn_binary_add, xnn_binary_subtract, - xnn_binary_multiply)), - [](const auto& info) { return ToString(info.param); }); -INSTANTIATE_TEST_SUITE_P(Eager, QuantizedTestQS8, - testing::Combine(testing::Values(RunMode::kEager), - testing::Values(xnn_binary_add, - xnn_binary_subtract, - xnn_binary_multiply)), - [](const auto& info) { return ToString(info.param); }); - -using QuantizedTestQU8 = QuantizedTest; - -TEST_P(QuantizedTestQU8, input1_scale) { - QuantizedTest_Input1Scale(GetParam()); -} -TEST_P(QuantizedTestQU8, input1_zero_point) { - QuantizedTest_Input1ZeroPoint(GetParam()); -} -TEST_P(QuantizedTestQU8, input2_scale) { - QuantizedTest_Input2Scale(GetParam()); -} -TEST_P(QuantizedTestQU8, input2_zero_point) { - QuantizedTest_Input2ZeroPoint(GetParam()); -} - -TEST_P(QuantizedTestQU8, output_scale) { - QuantizedTest_OutputScale(GetParam()); -} -TEST_P(QuantizedTestQU8, output_zero_point) { - QuantizedTest_OutputZeroPoint(GetParam()); -} - INSTANTIATE_TEST_SUITE_P( - CreateReshapeRun, QuantizedTestQU8, - testing::Combine(testing::Values(RunMode::kCreateReshapeRun), - testing::Values(xnn_binary_add, xnn_binary_subtract, - xnn_binary_multiply)), - [](const auto& info) { return ToString(info.param); }); -INSTANTIATE_TEST_SUITE_P(Eager, QuantizedTestQU8, - testing::Combine(testing::Values(RunMode::kEager), - testing::Values(xnn_binary_add, - xnn_binary_subtract, - xnn_binary_multiply)), - [](const auto& info) { return 
ToString(info.param); }); + QuantizedTest, QuantizedTest, + testing::ConvertGenerator(Combine( + Values(xnn_datatype_quint8, xnn_datatype_qint8), + Values(RunMode::kCreateReshapeRun, RunMode::kEager), + Values(xnn_binary_add, xnn_binary_subtract, xnn_binary_multiply))), + [](const auto& info) { return info.param.Name(); }); diff --git a/test/binary.cc b/test/binary.cc index 7280dfd90f4..e70bf57f8f9 100644 --- a/test/binary.cc +++ b/test/binary.cc @@ -14,85 +14,23 @@ #include #include #include +#include +#include +#include #include #include #include #include "xnnpack.h" +#include "xnnpack/buffer.h" +#include "xnnpack/log.h" #include "xnnpack/math.h" #include "xnnpack/operator.h" #include "xnnpack/subgraph.h" -#include "xnnpack/buffer.h" #include "replicable_random_device.h" -template -class NumericLimits { - public: - static constexpr T min() { return std::numeric_limits::min(); } - static constexpr T max() { return std::numeric_limits::max(); } -}; - -template <> -class NumericLimits { - public: - static xnn_float16 min() { return -std::numeric_limits::infinity(); } - static xnn_float16 max() { return +std::numeric_limits::infinity(); } -}; - -template -struct UniformDistribution { - std::uniform_real_distribution dist{-10.0f, 10.0f}; - - template - T operator()(Generator& g) { - return dist(g); - } -}; - -template <> -struct UniformDistribution { - std::uniform_real_distribution dist{-10.0f, 10.0f}; - - template - xnn_float16 operator()(Generator& g) { - return dist(g); - } -}; - -template <> -struct UniformDistribution { - std::uniform_int_distribution dist{std::numeric_limits::min(), - std::numeric_limits::max()}; - - template - int8_t operator()(Generator& g) { - return dist(g); - } -}; - -template <> -struct UniformDistribution { - std::uniform_int_distribution dist{ - std::numeric_limits::min(), - std::numeric_limits::max()}; - - template - uint8_t operator()(Generator& g) { - return dist(g); - } -}; - -template <> -struct UniformDistribution { - 
std::uniform_int_distribution dist{ - std::numeric_limits::min(), - std::numeric_limits::max()}; - - template - int32_t operator()(Generator& g) { - return dist(g); - } -}; +using ::testing::Combine; +using ::testing::Values; template size_t RandomRank(Rng& rng) { @@ -112,11 +50,21 @@ std::vector RandomShape(Rng& rng) { return RandomShape(rng, RandomRank(rng)); } -template -xnn_quantization_params RandomQuantization(Rng& rng) { - if (std::is_same::value || std::is_same::value) { +template +xnn_quantization_params RandomQuantization(xnn_datatype datatype, Rng& rng) { + if (datatype == xnn_datatype_qint8) { + std::uniform_int_distribution dist{std::numeric_limits::min(), + std::numeric_limits::max()}; return { - static_cast(UniformDistribution()(rng)), + static_cast(dist(rng)), + std::uniform_real_distribution(0.1f, 5.0f)(rng), + }; + } else if (datatype == xnn_datatype_quint8) { + std::uniform_int_distribution dist{ + std::numeric_limits::min(), + std::numeric_limits::max()}; + return { + static_cast(dist(rng)), std::uniform_real_distribution(0.1f, 5.0f)(rng), }; } else { @@ -150,46 +98,39 @@ bool is_quantized(xnn_datatype t) { } } -static const char* binary_operator_to_string( - xnn_binary_operator operation_type) { - switch (operation_type) { - case xnn_binary_add: - return "Add"; - case xnn_binary_copysign: - return "CopySign"; - case xnn_binary_divide: - return "Divide"; - case xnn_binary_maximum: - return "Maximum"; - case xnn_binary_minimum: - return "Minimum"; - case xnn_binary_multiply: - return "Multiply"; - case xnn_binary_prelu: - return "Prelu"; - case xnn_binary_subtract: - return "Subtract"; - case xnn_binary_squared_difference: - return "SquaredDifference"; +template +void randomize_buffer(xnn_datatype datatype, + xnnpack::ReplicableRandomDevice& rng, + std::uniform_real_distribution& dist, + Buffer& buf) { + switch (datatype) { + case xnn_datatype_quint8: + std::generate(reinterpret_cast(buf.begin()), + reinterpret_cast(buf.end()), + [&]() { return 
dist(rng); }); + break; + case xnn_datatype_qint8: + std::generate(reinterpret_cast(buf.begin()), + reinterpret_cast(buf.end()), + [&]() { return dist(rng); }); + break; + case xnn_datatype_int32: + std::generate(reinterpret_cast(buf.begin()), + reinterpret_cast(buf.end()), + [&]() { return dist(rng); }); + break; + case xnn_datatype_fp16: + std::generate(reinterpret_cast(buf.begin()), + reinterpret_cast(buf.end()), + [&]() { return dist(rng); }); + break; + case xnn_datatype_fp32: + std::generate(reinterpret_cast(buf.begin()), + reinterpret_cast(buf.end()), + [&]() { return dist(rng); }); + break; default: - return "Unknown"; - } -} - -template -xnn_datatype datatype_of() { - if (std::is_same::value) { - return xnn_datatype_quint8; - } else if (std::is_same::value) { - return xnn_datatype_qint8; - } else if (std::is_same::value) { - return xnn_datatype_fp16; - } else if (std::is_same::value) { - return xnn_datatype_fp32; - } else if (std::is_same::value) { - return xnn_datatype_int32; - } else { - XNN_UNREACHABLE; + break; } } @@ -209,29 +150,7 @@ size_t xnn_datatype_size(xnn_datatype datatype) { } } -// TODO(dsharlet): We need a place to put helper functions like this. -// XNNPACK's built-in equivalent helpers are not implemented in release -// builds... 
-const char* datatype_to_string(xnn_datatype datatype) { - switch (datatype) { - case xnn_datatype_qint8: - return "qint8"; - case xnn_datatype_quint8: - return "quint8"; - case xnn_datatype_fp16: - return "fp16"; - case xnn_datatype_fp32: - return "fp32"; - case xnn_datatype_int32: - return "int32"; - default: - XNN_UNREACHABLE; - } -} - -template -void MatchesOperatorApi(xnn_binary_operator binary_op) { - xnn_datatype datatype = datatype_of(); +void MatchesOperatorApi(xnn_datatype datatype, xnn_binary_operator binary_op) { xnnpack::ReplicableRandomDevice rng; std::vector input0_dims = RandomShape(rng); @@ -285,24 +204,56 @@ void MatchesOperatorApi(xnn_binary_operator binary_op) { output_dims.erase(output_dims.begin()); } - xnnpack::Buffer input0(NumElements(input0_dims) + - XNN_EXTRA_BYTES / sizeof(T)); - xnnpack::Buffer input1(NumElements(input1_dims) + - XNN_EXTRA_BYTES / sizeof(T)); - xnnpack::Buffer operator_output( - NumElements(output_dims)); - xnnpack::Buffer subgraph_output( - NumElements(output_dims)); - UniformDistribution dist; - std::generate(input0.begin(), input0.end(), [&]() { return dist(rng); }); - std::generate(input1.begin(), input1.end(), [&]() { return dist(rng); }); + size_t datatype_size = xnn_datatype_size(datatype); + xnnpack::Buffer input0( + NumElements(input0_dims) * datatype_size + + XNN_EXTRA_BYTES / sizeof(char)); + xnnpack::Buffer input1( + NumElements(input1_dims) * datatype_size + + XNN_EXTRA_BYTES / sizeof(char)); + xnnpack::Buffer operator_output( + NumElements(output_dims) * datatype_size); + xnnpack::Buffer subgraph_output( + NumElements(output_dims) * datatype_size); + + double datatype_min, datatype_max; + switch (datatype) { + case xnn_datatype_quint8: + datatype_min = std::numeric_limits::min(); + datatype_max = std::numeric_limits::max(); + break; + case xnn_datatype_qint8: + datatype_min = std::numeric_limits::min(); + datatype_max = std::numeric_limits::max(); + break; + case xnn_datatype_int32: + datatype_min = 
std::numeric_limits::min(); + datatype_max = std::numeric_limits::max(); + break; + case xnn_datatype_fp16: + case xnn_datatype_fp32: + datatype_min = -10.0; + datatype_max = 10.0; + break; + default: + datatype_min = 0; + datatype_max = 0; + assert(false); + break; + } + std::uniform_real_distribution dist(datatype_min, datatype_max); + randomize_buffer(datatype, rng, dist, input0); + randomize_buffer(datatype, rng, dist, input1); ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); bool quantized = is_quantized(datatype); - xnn_quantization_params input0_quantization = RandomQuantization(rng); - xnn_quantization_params input1_quantization = RandomQuantization(rng); - xnn_quantization_params output_quantization = RandomQuantization(rng); + xnn_quantization_params input0_quantization = + RandomQuantization(datatype, rng); + xnn_quantization_params input1_quantization = + RandomQuantization(datatype, rng); + xnn_quantization_params output_quantization = + RandomQuantization(datatype, rng); // Call subgraph API. 
  xnn_subgraph_t subgraph = nullptr;
@@ -815,118 +766,151 @@ void DegenerateDimension(xnn_datatype datatype, xnn_binary_operator binary_op) {
       xnn_status_success);
 }
 
-template <typename T>
-class BinaryTest : public testing::TestWithParam<xnn_binary_operator> {};
-
-using BinaryTestQS8 = BinaryTest<int8_t>;
-using BinaryTestQU8 = BinaryTest<uint8_t>;
-#ifndef XNN_EXCLUDE_F16_TESTS
-using BinaryTestF16 = BinaryTest<xnn_float16>;
-#endif  // XNN_EXCLUDE_F16_TESTS
-using BinaryTestF32 = BinaryTest<float>;
-using BinaryTestS32 = BinaryTest<int32_t>;
+struct Param {
+  using TupleT = std::tuple<xnn_datatype, xnn_binary_operator>;
+  explicit Param(TupleT p)
+      : datatype(std::get<0>(p)), binary_operator(std::get<1>(p)) {}
+
+  std::string Name() const {
+    std::stringstream sstr;
+    sstr << xnn_datatype_to_string(datatype) << "_"
+         << xnn_binary_operator_to_string(binary_operator);
+    std::string s = sstr.str();
+    // Test names must be alphanumeric with no spaces
+    std::replace(s.begin(), s.end(), ' ', '_');
+    std::replace(s.begin(), s.end(), '(', '_');
+    std::replace(s.begin(), s.end(), ')', '_');
+    return s;
+  }
 
-TEST_P(BinaryTestQS8, matches_operator_api) {
-  MatchesOperatorApi<int8_t>(GetParam());
-}
-TEST_P(BinaryTestQU8, matches_operator_api) {
-  MatchesOperatorApi<uint8_t>(GetParam());
-}
-#ifndef XNN_EXCLUDE_F16_TESTS
-TEST_P(BinaryTestF16, matches_operator_api) {
-  MatchesOperatorApi<xnn_float16>(GetParam());
-}
-#endif  // XNN_EXCLUDE_F16_TESTS
-TEST_P(BinaryTestF32, matches_operator_api) {
-  MatchesOperatorApi<float>(GetParam());
-}
-TEST_P(BinaryTestS32, matches_operator_api) {
-  MatchesOperatorApi<int32_t>(GetParam());
-}
+  xnn_datatype datatype;
+  xnn_binary_operator binary_operator;
+};
 
-#ifndef XNN_EXCLUDE_F16_TESTS
-TEST_P(BinaryTestF16, reshape) { Reshape(xnn_datatype_fp16, GetParam()); }
-#endif  // XNN_EXCLUDE_F16_TESTS
-TEST_P(BinaryTestF32, reshape) { Reshape(xnn_datatype_fp32, GetParam()); }
-TEST_P(BinaryTestS32, reshape) { Reshape(xnn_datatype_int32, GetParam()); }
+class BinaryTest : public testing::TestWithParam<Param> {};
 
-#ifndef XNN_EXCLUDE_F16_TESTS
-TEST_P(BinaryTestF16, reshape_broadcast_dim0) {
-  
ReshapeBroadcastDim0(xnn_datatype_fp16, GetParam());
-}
-#endif  // XNN_EXCLUDE_F16_TESTS
-TEST_P(BinaryTestF32, reshape_broadcast_dim0) {
-  ReshapeBroadcastDim0(xnn_datatype_fp32, GetParam());
-}
-TEST_P(BinaryTestS32, reshape_broadcast_dim0) {
-  ReshapeBroadcastDim0(xnn_datatype_int32, GetParam());
+// Some combinations aren't implemented.
+bool SupportedBinaryTest(xnn_datatype datatype, xnn_binary_operator binary_op) {
+  switch (datatype) {
+    case xnn_datatype_quint8:
+    case xnn_datatype_qint8:
+      switch (binary_op) {
+        case xnn_binary_add:
+        case xnn_binary_multiply:
+        case xnn_binary_subtract:
+          return true;
+        default:
+          return false;
+      }
+    case xnn_datatype_int32:
+      switch (binary_op) {
+        case xnn_binary_multiply:
+          return true;
+        default:
+          return false;
+      }
+    case xnn_datatype_fp16:
+#ifdef XNN_EXCLUDE_F16_TESTS
+      return false;
+#else
+      switch (binary_op) {
+        case xnn_binary_add:
+        case xnn_binary_divide:
+        case xnn_binary_maximum:
+        case xnn_binary_minimum:
+        case xnn_binary_multiply:
+        case xnn_binary_prelu:
+        case xnn_binary_squared_difference:
+        case xnn_binary_subtract:
+          return true;
+        default:
+          return false;
+      }
+#endif
+    case xnn_datatype_fp32:
+      switch (binary_op) {
+        case xnn_binary_add:
+        case xnn_binary_copysign:
+        case xnn_binary_divide:
+        case xnn_binary_maximum:
+        case xnn_binary_minimum:
+        case xnn_binary_multiply:
+        case xnn_binary_prelu:
+        case xnn_binary_subtract:
+        case xnn_binary_squared_difference:
+          return true;
+        default:
+          return false;
+      }
+    default:
+      return false;
+  }
 }
 
-#ifndef XNN_EXCLUDE_F16_TESTS
-TEST_P(BinaryTestF16, reshape_broadcast_1d) {
-  ReshapeBroadcast1D(xnn_datatype_fp16, GetParam());
-}
-#endif  // XNN_EXCLUDE_F16_TESTS
-TEST_P(BinaryTestF32, reshape_broadcast_1d) {
-  ReshapeBroadcast1D(xnn_datatype_fp32, GetParam());
-}
-TEST_P(BinaryTestS32, reshape_broadcast_1d) {
-  ReshapeBroadcast1D(xnn_datatype_int32, GetParam());
+TEST_P(BinaryTest, matches_operator_api) {
+  if (!SupportedBinaryTest(GetParam().datatype,
GetParam().binary_operator)) {
+    GTEST_SKIP();
+  }
+  MatchesOperatorApi(GetParam().datatype, GetParam().binary_operator);
 }
 
-#ifndef XNN_EXCLUDE_F16_TESTS
-TEST_P(BinaryTestF16, reshape_broadcast_2d) {
-  ReshapeBroadcast2D(xnn_datatype_fp16, GetParam());
-}
-#endif  // XNN_EXCLUDE_F16_TESTS
-TEST_P(BinaryTestF32, reshape_broadcast_2d) {
-  ReshapeBroadcast2D(xnn_datatype_fp32, GetParam());
-}
-TEST_P(BinaryTestS32, reshape_broadcast_2d) {
-  ReshapeBroadcast2D(xnn_datatype_int32, GetParam());
+TEST_P(BinaryTest, reshape) {
+  if (!SupportedBinaryTest(GetParam().datatype, GetParam().binary_operator)) {
+    GTEST_SKIP();
+  }
+  if (is_quantized(GetParam().datatype)) {
+    GTEST_SKIP();
+  }
+  Reshape(GetParam().datatype, GetParam().binary_operator);
 }
 
-#ifndef XNN_EXCLUDE_F16_TESTS
-TEST_P(BinaryTestF16, degenerate_dimension) {
-  DegenerateDimension(xnn_datatype_fp16, GetParam());
-}
-#endif  // XNN_EXCLUDE_F16_TESTS
-TEST_P(BinaryTestF32, degenerate_dimension) {
-  DegenerateDimension(xnn_datatype_fp32, GetParam());
+TEST_P(BinaryTest, reshape_broadcast_dim0) {
+  if (!SupportedBinaryTest(GetParam().datatype, GetParam().binary_operator)) {
+    GTEST_SKIP();
+  }
+  if (is_quantized(GetParam().datatype)) {
+    GTEST_SKIP();
+  }
+  ReshapeBroadcastDim0(GetParam().datatype, GetParam().binary_operator);
 }
-TEST_P(BinaryTestS32, degenerate_dimension) {
-  DegenerateDimension(xnn_datatype_int32, GetParam());
+
+TEST_P(BinaryTest, reshape_broadcast_1d) {
+  if (!SupportedBinaryTest(GetParam().datatype, GetParam().binary_operator)) {
+    GTEST_SKIP();
+  }
+  if (is_quantized(GetParam().datatype)) {
+    GTEST_SKIP();
+  }
+  ReshapeBroadcast1D(GetParam().datatype, GetParam().binary_operator);
 }
 
-std::string ToString(xnn_binary_operator op) {
-  return binary_operator_to_string(op);
+TEST_P(BinaryTest, reshape_broadcast_2d) {
+  if (!SupportedBinaryTest(GetParam().datatype, GetParam().binary_operator)) {
+    GTEST_SKIP();
+  }
+  if (is_quantized(GetParam().datatype)) {
+    GTEST_SKIP();
+  }
+  
ReshapeBroadcast2D(GetParam().datatype, GetParam().binary_operator);
 }
 
-INSTANTIATE_TEST_SUITE_P(test, BinaryTestQS8,
-                         testing::Values(xnn_binary_add, xnn_binary_subtract,
-                                         xnn_binary_multiply),
-                         [](const auto& info) { return ToString(info.param); });
-INSTANTIATE_TEST_SUITE_P(test, BinaryTestQU8,
-                         testing::Values(xnn_binary_add, xnn_binary_subtract,
-                                         xnn_binary_multiply),
-                         [](const auto& info) { return ToString(info.param); });
-#ifndef XNN_EXCLUDE_F16_TESTS
-INSTANTIATE_TEST_SUITE_P(test, BinaryTestF16,
-                         testing::Values(xnn_binary_add, xnn_binary_subtract,
-                                         xnn_binary_multiply, xnn_binary_divide,
-                                         xnn_binary_maximum, xnn_binary_minimum,
-                                         xnn_binary_squared_difference,
-                                         xnn_binary_prelu),
-                         [](const auto& info) { return ToString(info.param); });
-#endif
-INSTANTIATE_TEST_SUITE_P(test, BinaryTestF32,
-                         testing::Values(xnn_binary_add, xnn_binary_subtract,
-                                         xnn_binary_multiply, xnn_binary_divide,
-                                         xnn_binary_maximum, xnn_binary_minimum,
-                                         xnn_binary_copysign,
-                                         xnn_binary_squared_difference,
-                                         xnn_binary_prelu),
-                         [](const auto& info) { return ToString(info.param); });
-INSTANTIATE_TEST_SUITE_P(test, BinaryTestS32,
-                         testing::Values(xnn_binary_multiply),
-                         [](const auto& info) { return ToString(info.param); });
+TEST_P(BinaryTest, degenerate_dimension) {
+  if (!SupportedBinaryTest(GetParam().datatype, GetParam().binary_operator)) {
+    GTEST_SKIP();
+  }
+  if (is_quantized(GetParam().datatype)) {
+    GTEST_SKIP();
+  }
+  DegenerateDimension(GetParam().datatype, GetParam().binary_operator);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    BinaryTest, BinaryTest,
+    testing::ConvertGenerator<Param::TupleT>(Combine(
+        Values(xnn_datatype_quint8, xnn_datatype_qint8, xnn_datatype_fp16,
+               xnn_datatype_fp32, xnn_datatype_int32),
+        Values(xnn_binary_add, xnn_binary_subtract, xnn_binary_multiply,
+               xnn_binary_divide, xnn_binary_maximum, xnn_binary_minimum,
+               xnn_binary_copysign, xnn_binary_squared_difference,
+               xnn_binary_prelu))),
+    [](const auto& info) { return info.param.Name(); });