Skip to content

Commit

Permalink
Replaced GlobalAveragePooling with static_mean.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 684367627
  • Loading branch information
Misha Gutman authored and xnnpack-bot committed Oct 10, 2024
1 parent b6b97f6 commit db7fd55
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 33 deletions.
2 changes: 1 addition & 1 deletion src/subgraph.c
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ uint32_t xnn_check_nchw_compatibility(xnn_subgraph_t subgraph, struct xnn_node*
case xnn_node_type_leaky_relu:
case xnn_node_type_static_mean:
if (subgraph->values[node->inputs[0]].shape.num_dims == 4) {
return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW;
return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
} else {
xnn_log_info("Node %s inputs shape is incompatible with sparse inference",
xnn_node_type_to_string(node->type));
Expand Down
30 changes: 26 additions & 4 deletions src/subgraph/global-average-pooling.c
Original file line number Diff line number Diff line change
Expand Up @@ -425,8 +425,18 @@ enum xnn_status xnn_define_global_average_pooling_1d(
uint32_t output_id,
uint32_t flags)
{
return define_global_average_pooling_nd(
subgraph, xnn_node_type_global_average_pooling_1d, output_min, output_max, input_id, output_id, flags);
const struct xnn_value* input_value = &subgraph->values[input_id];
size_t reduction_axes[XNN_MAX_TENSOR_DIMS];

if (input_value->layout == xnn_layout_type_nchw) {
reduction_axes[0] = input_value->shape.num_dims - 1;
} else {
reduction_axes[0] = input_value->shape.num_dims - 2;
}

return xnn_define_static_mean(
subgraph, 1, reduction_axes, input_id,
output_id, flags);
}

enum xnn_status xnn_define_global_average_pooling_2d(
Expand All @@ -437,6 +447,18 @@ enum xnn_status xnn_define_global_average_pooling_2d(
uint32_t output_id,
uint32_t flags)
{
return define_global_average_pooling_nd(
subgraph, xnn_node_type_global_average_pooling_2d, output_min, output_max, input_id, output_id, flags);
const struct xnn_value* input_value = &subgraph->values[input_id];
size_t reduction_axes[XNN_MAX_TENSOR_DIMS];

if (input_value->layout == xnn_layout_type_nchw) {
reduction_axes[0] = input_value->shape.num_dims - 2;
reduction_axes[1] = input_value->shape.num_dims - 1;
} else {
reduction_axes[0] = input_value->shape.num_dims - 3;
reduction_axes[1] = input_value->shape.num_dims - 2;
}

return xnn_define_static_mean(
subgraph, 2, reduction_axes, input_id,
output_id, flags);
}
20 changes: 6 additions & 14 deletions test/global-average-pooling-1d.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,7 @@ TEST_F(GlobalAveragePooling1DTestQS8, define)

ASSERT_EQ(subgraph->num_nodes, 1);
const struct xnn_node* node = &subgraph->nodes[0];
ASSERT_EQ(node->type, xnn_node_type_global_average_pooling_1d);
ASSERT_EQ(node->compute_type, xnn_compute_type_qs8);
ASSERT_EQ(node->activation.output_min, output_min);
ASSERT_EQ(node->activation.output_max, output_max);
ASSERT_EQ(node->num_inputs, 1);
ASSERT_EQ(node->inputs[0], input_id);
ASSERT_EQ(node->num_outputs, 1);
Expand Down Expand Up @@ -160,10 +157,7 @@ TEST_F(GlobalAveragePooling1DTestQU8, define)

ASSERT_EQ(subgraph->num_nodes, 1);
const struct xnn_node* node = &subgraph->nodes[0];
ASSERT_EQ(node->type, xnn_node_type_global_average_pooling_1d);
ASSERT_EQ(node->compute_type, xnn_compute_type_qu8);
ASSERT_EQ(node->activation.output_min, output_min);
ASSERT_EQ(node->activation.output_max, output_max);
ASSERT_EQ(node->num_inputs, 1);
ASSERT_EQ(node->inputs[0], input_id);
ASSERT_EQ(node->num_outputs, 1);
Expand Down Expand Up @@ -199,10 +193,7 @@ TEST_F(GlobalAveragePooling1DTestF16, define)

ASSERT_EQ(subgraph->num_nodes, 1);
const struct xnn_node* node = &subgraph->nodes[0];
ASSERT_EQ(node->type, xnn_node_type_global_average_pooling_1d);
ASSERT_EQ(node->compute_type, xnn_compute_type_fp16);
ASSERT_EQ(node->activation.output_min, output_min);
ASSERT_EQ(node->activation.output_max, output_max);
ASSERT_EQ(node->num_inputs, 1);
ASSERT_EQ(node->inputs[0], input_id);
ASSERT_EQ(node->num_outputs, 1);
Expand Down Expand Up @@ -238,10 +229,7 @@ TEST_F(GlobalAveragePooling1DTestF32, define)

ASSERT_EQ(subgraph->num_nodes, 1);
const struct xnn_node* node = &subgraph->nodes[0];
ASSERT_EQ(node->type, xnn_node_type_global_average_pooling_1d);
ASSERT_EQ(node->compute_type, xnn_compute_type_fp32);
ASSERT_EQ(node->activation.output_min, output_min);
ASSERT_EQ(node->activation.output_max, output_max);
ASSERT_EQ(node->num_inputs, 1);
ASSERT_EQ(node->inputs[0], input_id);
ASSERT_EQ(node->num_outputs, 1);
Expand Down Expand Up @@ -468,7 +456,9 @@ TEST_F(GlobalAveragePooling1DTestF16, matches_operator_api)
ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));

ASSERT_EQ(subgraph_output, operator_output);
for (size_t i = 0; i < subgraph_output.size(); ++i) {
ASSERT_NEAR(subgraph_output[i], operator_output[i], 1e-3);
}
}

TEST_F(GlobalAveragePooling1DTestF32, matches_operator_api)
Expand Down Expand Up @@ -538,7 +528,9 @@ TEST_F(GlobalAveragePooling1DTestF32, matches_operator_api)
ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));

ASSERT_EQ(subgraph_output, operator_output);
for (size_t i = 0; i < subgraph_output.size(); ++i) {
ASSERT_NEAR(subgraph_output[i], operator_output[i], 1e-6);
}
}

TEST_F(GlobalAveragePooling1DTestF32, reshape_output_no_keep_dims)
Expand Down
20 changes: 6 additions & 14 deletions test/global-average-pooling-2d.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,7 @@ TEST_F(GlobalAveragePooling2DTestQS8, define)

ASSERT_EQ(subgraph->num_nodes, 1);
const struct xnn_node* node = &subgraph->nodes[0];
ASSERT_EQ(node->type, xnn_node_type_global_average_pooling_2d);
ASSERT_EQ(node->compute_type, xnn_compute_type_qs8);
ASSERT_EQ(node->activation.output_min, output_min);
ASSERT_EQ(node->activation.output_max, output_max);
ASSERT_EQ(node->num_inputs, 1);
ASSERT_EQ(node->inputs[0], input_id);
ASSERT_EQ(node->num_outputs, 1);
Expand Down Expand Up @@ -158,10 +155,7 @@ TEST_F(GlobalAveragePooling2DTestQU8, define)

ASSERT_EQ(subgraph->num_nodes, 1);
const struct xnn_node* node = &subgraph->nodes[0];
ASSERT_EQ(node->type, xnn_node_type_global_average_pooling_2d);
ASSERT_EQ(node->compute_type, xnn_compute_type_qu8);
ASSERT_EQ(node->activation.output_min, output_min);
ASSERT_EQ(node->activation.output_max, output_max);
ASSERT_EQ(node->num_inputs, 1);
ASSERT_EQ(node->inputs[0], input_id);
ASSERT_EQ(node->num_outputs, 1);
Expand Down Expand Up @@ -197,10 +191,7 @@ TEST_F(GlobalAveragePooling2DTestF16, define)

ASSERT_EQ(subgraph->num_nodes, 1);
const struct xnn_node* node = &subgraph->nodes[0];
ASSERT_EQ(node->type, xnn_node_type_global_average_pooling_2d);
ASSERT_EQ(node->compute_type, xnn_compute_type_fp16);
ASSERT_EQ(node->activation.output_min, output_min);
ASSERT_EQ(node->activation.output_max, output_max);
ASSERT_EQ(node->num_inputs, 1);
ASSERT_EQ(node->inputs[0], input_id);
ASSERT_EQ(node->num_outputs, 1);
Expand Down Expand Up @@ -236,10 +227,7 @@ TEST_F(GlobalAveragePooling2DTestF32, define)

ASSERT_EQ(subgraph->num_nodes, 1);
const struct xnn_node* node = &subgraph->nodes[0];
ASSERT_EQ(node->type, xnn_node_type_global_average_pooling_2d);
ASSERT_EQ(node->compute_type, xnn_compute_type_fp32);
ASSERT_EQ(node->activation.output_min, output_min);
ASSERT_EQ(node->activation.output_max, output_max);
ASSERT_EQ(node->num_inputs, 1);
ASSERT_EQ(node->inputs[0], input_id);
ASSERT_EQ(node->num_outputs, 1);
Expand Down Expand Up @@ -466,7 +454,9 @@ TEST_F(GlobalAveragePooling2DTestF16, matches_operator_api)
ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));

ASSERT_EQ(subgraph_output, operator_output);
for (size_t i = 0; i < subgraph_output.size(); ++i) {
ASSERT_NEAR(subgraph_output[i], operator_output[i], 1e-3);
}
}

TEST_F(GlobalAveragePooling2DTestF32, matches_operator_api)
Expand Down Expand Up @@ -536,7 +526,9 @@ TEST_F(GlobalAveragePooling2DTestF32, matches_operator_api)
ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));

ASSERT_EQ(subgraph_output, operator_output);
for (size_t i = 0; i < subgraph_output.size(); ++i) {
ASSERT_NEAR(subgraph_output[i], operator_output[i], 1e-6);
}
}

TEST_F(GlobalAveragePooling2DTestF32, reshape_output_no_keep_dims)
Expand Down

0 comments on commit db7fd55

Please sign in to comment.