From db7fd55890779ce0fa7381d235bb55ac55fb3bbd Mon Sep 17 00:00:00 2001 From: Misha Gutman Date: Thu, 10 Oct 2024 02:47:59 -0700 Subject: [PATCH] Replaced GlobalAveragePooling with static_mean. PiperOrigin-RevId: 684367627 --- src/subgraph.c | 2 +- src/subgraph/global-average-pooling.c | 30 +++++++++++++++++++++++---- test/global-average-pooling-1d.cc | 20 ++++++------------ test/global-average-pooling-2d.cc | 20 ++++++------------ 4 files changed, 39 insertions(+), 33 deletions(-) diff --git a/src/subgraph.c b/src/subgraph.c index 919609bd7e7..af5c781d274 100644 --- a/src/subgraph.c +++ b/src/subgraph.c @@ -517,7 +517,7 @@ uint32_t xnn_check_nchw_compatibility(xnn_subgraph_t subgraph, struct xnn_node* case xnn_node_type_leaky_relu: case xnn_node_type_static_mean: if (subgraph->values[node->inputs[0]].shape.num_dims == 4) { - return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW; + return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC; } else { xnn_log_info("Node %s inputs shape is incompatible with sparse inference", xnn_node_type_to_string(node->type)); diff --git a/src/subgraph/global-average-pooling.c b/src/subgraph/global-average-pooling.c index 4674b099306..4d84c73980f 100644 --- a/src/subgraph/global-average-pooling.c +++ b/src/subgraph/global-average-pooling.c @@ -425,8 +425,18 @@ enum xnn_status xnn_define_global_average_pooling_1d( uint32_t output_id, uint32_t flags) { - return define_global_average_pooling_nd( - subgraph, xnn_node_type_global_average_pooling_1d, output_min, output_max, input_id, output_id, flags); + const struct xnn_value* input_value = &subgraph->values[input_id]; + size_t reduction_axes[XNN_MAX_TENSOR_DIMS]; + + if (input_value->layout == xnn_layout_type_nchw) { + reduction_axes[0] = input_value->shape.num_dims - 1; + } else { + reduction_axes[0] = input_value->shape.num_dims - 2; + } + + return xnn_define_static_mean( + subgraph, 1, reduction_axes, input_id, + output_id, flags); } enum xnn_status xnn_define_global_average_pooling_2d( @@ -437,6 +447,18 @@ enum xnn_status xnn_define_global_average_pooling_2d( uint32_t output_id, uint32_t flags) { - return define_global_average_pooling_nd( - subgraph, xnn_node_type_global_average_pooling_2d, output_min, output_max, input_id, output_id, flags); + const struct xnn_value* input_value = &subgraph->values[input_id]; + size_t reduction_axes[XNN_MAX_TENSOR_DIMS]; + + if (input_value->layout == xnn_layout_type_nchw) { + reduction_axes[0] = input_value->shape.num_dims - 2; + reduction_axes[1] = input_value->shape.num_dims - 1; + } else { + reduction_axes[0] = input_value->shape.num_dims - 3; + reduction_axes[1] = input_value->shape.num_dims - 2; + } + + return xnn_define_static_mean( + subgraph, 2, reduction_axes, input_id, + output_id, flags); } diff --git a/test/global-average-pooling-1d.cc b/test/global-average-pooling-1d.cc index 4a45f86851d..35c6629fc74 100644 --- a/test/global-average-pooling-1d.cc +++ b/test/global-average-pooling-1d.cc @@ -121,10 +121,7 @@ TEST_F(GlobalAveragePooling1DTestQS8, define) ASSERT_EQ(subgraph->num_nodes, 1); const struct xnn_node* node = &subgraph->nodes[0]; - ASSERT_EQ(node->type, xnn_node_type_global_average_pooling_1d); ASSERT_EQ(node->compute_type, xnn_compute_type_qs8); - ASSERT_EQ(node->activation.output_min, output_min); - ASSERT_EQ(node->activation.output_max, output_max); ASSERT_EQ(node->num_inputs, 1); ASSERT_EQ(node->inputs[0], input_id); ASSERT_EQ(node->num_outputs, 1); @@ -160,10 +157,7 @@ TEST_F(GlobalAveragePooling1DTestQU8, define) ASSERT_EQ(subgraph->num_nodes, 1); const struct xnn_node* node = &subgraph->nodes[0]; - ASSERT_EQ(node->type, xnn_node_type_global_average_pooling_1d); ASSERT_EQ(node->compute_type, xnn_compute_type_qu8); - ASSERT_EQ(node->activation.output_min, output_min); - ASSERT_EQ(node->activation.output_max, output_max); ASSERT_EQ(node->num_inputs, 1); ASSERT_EQ(node->inputs[0], input_id); ASSERT_EQ(node->num_outputs, 1); @@ -199,10 +193,7 @@ TEST_F(GlobalAveragePooling1DTestF16, define) ASSERT_EQ(subgraph->num_nodes, 1); const struct xnn_node* node = &subgraph->nodes[0]; - ASSERT_EQ(node->type, xnn_node_type_global_average_pooling_1d); ASSERT_EQ(node->compute_type, xnn_compute_type_fp16); - ASSERT_EQ(node->activation.output_min, output_min); - ASSERT_EQ(node->activation.output_max, output_max); ASSERT_EQ(node->num_inputs, 1); ASSERT_EQ(node->inputs[0], input_id); ASSERT_EQ(node->num_outputs, 1); @@ -238,10 +229,7 @@ TEST_F(GlobalAveragePooling1DTestF32, define) ASSERT_EQ(subgraph->num_nodes, 1); const struct xnn_node* node = &subgraph->nodes[0]; - ASSERT_EQ(node->type, xnn_node_type_global_average_pooling_1d); ASSERT_EQ(node->compute_type, xnn_compute_type_fp32); - ASSERT_EQ(node->activation.output_min, output_min); - ASSERT_EQ(node->activation.output_max, output_max); ASSERT_EQ(node->num_inputs, 1); ASSERT_EQ(node->inputs[0], input_id); ASSERT_EQ(node->num_outputs, 1); @@ -468,7 +456,9 @@ TEST_F(GlobalAveragePooling1DTestF16, matches_operator_api) ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data())); ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); - ASSERT_EQ(subgraph_output, operator_output); + for (size_t i = 0; i < subgraph_output.size(); ++i) { + ASSERT_NEAR(subgraph_output[i], operator_output[i], 1e-3); + } } TEST_F(GlobalAveragePooling1DTestF32, matches_operator_api) @@ -538,7 +528,9 @@ TEST_F(GlobalAveragePooling1DTestF32, matches_operator_api) ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data())); ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); - ASSERT_EQ(subgraph_output, operator_output); + for (size_t i = 0; i < subgraph_output.size(); ++i) { + ASSERT_NEAR(subgraph_output[i], operator_output[i], 1e-6); + } } TEST_F(GlobalAveragePooling1DTestF32, reshape_output_no_keep_dims) diff --git a/test/global-average-pooling-2d.cc b/test/global-average-pooling-2d.cc index 786fcf47a33..fd1509b88cc 100644 --- a/test/global-average-pooling-2d.cc +++ b/test/global-average-pooling-2d.cc @@ -119,10 +119,7 @@ TEST_F(GlobalAveragePooling2DTestQS8, define) ASSERT_EQ(subgraph->num_nodes, 1); const struct xnn_node* node = &subgraph->nodes[0]; - ASSERT_EQ(node->type, xnn_node_type_global_average_pooling_2d); ASSERT_EQ(node->compute_type, xnn_compute_type_qs8); - ASSERT_EQ(node->activation.output_min, output_min); - ASSERT_EQ(node->activation.output_max, output_max); ASSERT_EQ(node->num_inputs, 1); ASSERT_EQ(node->inputs[0], input_id); ASSERT_EQ(node->num_outputs, 1); @@ -158,10 +155,7 @@ TEST_F(GlobalAveragePooling2DTestQU8, define) ASSERT_EQ(subgraph->num_nodes, 1); const struct xnn_node* node = &subgraph->nodes[0]; - ASSERT_EQ(node->type, xnn_node_type_global_average_pooling_2d); ASSERT_EQ(node->compute_type, xnn_compute_type_qu8); - ASSERT_EQ(node->activation.output_min, output_min); - ASSERT_EQ(node->activation.output_max, output_max); ASSERT_EQ(node->num_inputs, 1); ASSERT_EQ(node->inputs[0], input_id); ASSERT_EQ(node->num_outputs, 1); @@ -197,10 +191,7 @@ TEST_F(GlobalAveragePooling2DTestF16, define) ASSERT_EQ(subgraph->num_nodes, 1); const struct xnn_node* node = &subgraph->nodes[0]; - ASSERT_EQ(node->type, xnn_node_type_global_average_pooling_2d); ASSERT_EQ(node->compute_type, xnn_compute_type_fp16); - ASSERT_EQ(node->activation.output_min, output_min); - ASSERT_EQ(node->activation.output_max, output_max); ASSERT_EQ(node->num_inputs, 1); ASSERT_EQ(node->inputs[0], input_id); ASSERT_EQ(node->num_outputs, 1); @@ -236,10 +227,7 @@ TEST_F(GlobalAveragePooling2DTestF32, define) ASSERT_EQ(subgraph->num_nodes, 1); const struct xnn_node* node = &subgraph->nodes[0]; - ASSERT_EQ(node->type, xnn_node_type_global_average_pooling_2d); ASSERT_EQ(node->compute_type, xnn_compute_type_fp32); - ASSERT_EQ(node->activation.output_min, output_min); - ASSERT_EQ(node->activation.output_max, output_max); ASSERT_EQ(node->num_inputs, 1); ASSERT_EQ(node->inputs[0], input_id); ASSERT_EQ(node->num_outputs, 1); @@ -466,7 +454,9 @@ TEST_F(GlobalAveragePooling2DTestF16, matches_operator_api) ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data())); ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); - ASSERT_EQ(subgraph_output, operator_output); + for (size_t i = 0; i < subgraph_output.size(); ++i) { + ASSERT_NEAR(subgraph_output[i], operator_output[i], 1e-3); + } } TEST_F(GlobalAveragePooling2DTestF32, matches_operator_api) @@ -536,7 +526,9 @@ TEST_F(GlobalAveragePooling2DTestF32, matches_operator_api) ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data())); ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); - ASSERT_EQ(subgraph_output, operator_output); + for (size_t i = 0; i < subgraph_output.size(); ++i) { + ASSERT_NEAR(subgraph_output[i], operator_output[i], 1e-6); + } } TEST_F(GlobalAveragePooling2DTestF32, reshape_output_no_keep_dims)