From 07be198af8befbe674be842622fcd6e0cb37ff58 Mon Sep 17 00:00:00 2001 From: XNNPACK Team Date: Wed, 16 Oct 2024 14:09:12 -0700 Subject: [PATCH] Put back missing packing optimization. PiperOrigin-RevId: 686636654 --- src/packing.cc | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/src/packing.cc b/src/packing.cc index 39f81b97dbe..9ccf25ddd05 100644 --- a/src/packing.cc +++ b/src/packing.cc @@ -1025,21 +1025,32 @@ void xnn_pack_f32_gemm_gio_w( const size_t nr_block_size = min(nc - nr_block_start, nr); copy_bias(b, nr_block_start, nr_block_size, packed_weights); packed_weights += nr; + if (sr == 1 && kr == 1) { + for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start++) { + const size_t kc_idx = round_down_po2(kr_block_start, skr); - for (size_t kr_block_start = 0; kr_block_start < round_up_po2(kc, skr); kr_block_start += kr) { - for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) { + if (kc_idx < kc) { + memcpy(packed_weights, &k[kc_idx * k_stride + nr_block_start], nr_block_size * sizeof(float)); + } + packed_weights += nr; + } + packed_weights = (float*) ((uintptr_t) packed_weights + extra_bytes); + } else { + for (size_t kr_block_start = 0; kr_block_start < round_up_po2(kc, skr); kr_block_start += kr) { + for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) { const size_t kc_begin = round_down_po2(kr_block_start, skr) + ((kr_block_start + nr_block_offset * kr) & (skr - 1)); - for (size_t kr_block_offset = 0; kr_block_offset < kr; kr_block_offset++) { + for (size_t kr_block_offset = 0; kr_block_offset < kr; kr_block_offset++) { const size_t kc_idx = kc_begin + kr_block_offset; - if (kc_idx < kc) { - packed_weights[kr_block_offset] = k[kc_idx * k_stride + nr_block_start + nr_block_offset]; + if (kc_idx < kc) { + packed_weights[kr_block_offset] = k[kc_idx * k_stride + nr_block_start + nr_block_offset]; + } } + packed_weights += kr; } - packed_weights += kr; + packed_weights += (nr - nr_block_size) * kr; } - packed_weights += (nr - nr_block_size) * kr; + packed_weights = (float*) ((uintptr_t) packed_weights + extra_bytes); } - packed_weights = (float*) ((uintptr_t) packed_weights + extra_bytes); } k += nc * kc; if XNN_UNPREDICTABLE(b != nullptr) {