From 258a5d412ad5407efa7b029ebdba1b730ca8fd7e Mon Sep 17 00:00:00 2001 From: carlushuang Date: Tue, 25 Jul 2023 11:55:02 -0400 Subject: [PATCH] add some instances for conv/gemm --- .../aitemplate/utils/mk_ck_lib/generator.py | 168 ++++++++++++++---- 1 file changed, 133 insertions(+), 35 deletions(-) diff --git a/python/aitemplate/utils/mk_ck_lib/generator.py b/python/aitemplate/utils/mk_ck_lib/generator.py index e8f89f666..02364dd6f 100644 --- a/python/aitemplate/utils/mk_ck_lib/generator.py +++ b/python/aitemplate/utils/mk_ck_lib/generator.py @@ -53,6 +53,20 @@ def CreateConv2dFwdOperator(manifest, operation_kind, out_element_op, out_data_o conv.GroupTileDesc(1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2), conv.GroupTileDesc(1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1), conv.GroupTileDesc(1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2), + conv.GroupTileDesc(1, 64, 64, 32, 64, 8, 8, 32, 32, 2, 1), + conv.GroupTileDesc(1, 64, 32, 64, 64, 8, 8, 32, 32, 1, 2), + conv.GroupTileDesc(1, 64, 32, 32, 64, 8, 8, 32, 32, 1, 1), + conv.GroupTileDesc(1, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1), + conv.GroupTileDesc(1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2), + conv.GroupTileDesc(1, 128, 128, 64, 64, 8, 8, 32, 32, 2, 2), + conv.GroupTileDesc(1, 128, 64, 128, 64, 8, 8, 32, 32, 2, 2), + conv.GroupTileDesc(1, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1), + conv.GroupTileDesc(1, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2), + conv.GroupTileDesc(1, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1), + conv.GroupTileDesc(1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1), + conv.GroupTileDesc(1, 128, 32, 32, 64, 8, 8, 16, 16, 1, 2), + conv.GroupTileDesc(1, 256, 64, 64, 64, 8, 8, 16, 16, 2, 2), + ] c_block_descriptions = [ @@ -68,17 +82,31 @@ def CreateConv2dFwdOperator(manifest, operation_kind, out_element_op, out_data_o conv.CBlockTransferDesc(1, 1, [1, 16, 1, 8], 8), conv.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 8), conv.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 8), + + conv.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 8), + conv.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 8), + conv.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 8), + conv.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8), + conv.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8), + conv.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8), + conv.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8), + conv.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8), + conv.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8), + conv.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 4), + conv.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 4), + conv.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 4), + conv.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 4), ] block_descriptions = [] for t in tile_descriptions: block_transfer = -1 if t.block_size == 256: - block_transfer = [4, 64, 1] + block_transfer = [8, 32, 1] if t.k_per_block == 64 else [4, 64, 1] if t.block_size == 128: - block_transfer = [4, 32, 1] + block_transfer = [8, 16, 1] if t.k_per_block == 64 else [4, 32, 1] if t.block_size == 64: - block_transfer = [4, 16, 1] + block_transfer = [8, 8, 1] if t.k_per_block == 64 else [4, 16, 1] assert ( block_transfer != -1 and "Cannot determine block_transfer_size with block_size " @@ -150,6 +178,21 @@ def CreateConv2dFwdOperator(manifest, operation_kind, out_element_op, out_data_o conv.BlockTransferDesc([4, 4, 8], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1), conv.BlockTransferDesc([4, 2, 8], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1), conv.BlockTransferDesc([4, 2, 8], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1), + + conv.BlockTransferDesc([2, 8, 4], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1), + conv.BlockTransferDesc([2, 8, 4], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1), + conv.BlockTransferDesc([2, 8, 4], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1), + conv.BlockTransferDesc([2, 32, 4], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1), + conv.BlockTransferDesc([2, 32, 4], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1), + conv.BlockTransferDesc([2, 16, 4], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1), + conv.BlockTransferDesc([2, 16, 4], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1), + conv.BlockTransferDesc([2, 16, 4], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1), + conv.BlockTransferDesc([2, 16, 4], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1), + conv.BlockTransferDesc([2, 8, 4], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1), + conv.BlockTransferDesc([2, 8, 4], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1), + conv.BlockTransferDesc([2, 8, 4], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1), + conv.BlockTransferDesc([2, 32, 4], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1), + conv.BlockTransferDesc([2, 32, 4], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1), conv.BlockTransferDesc([2, 32, 4], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1), conv.BlockTransferDesc([2, 32, 4], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1), @@ -541,6 +584,19 @@ def CreateGemmRCROperator(manifest): gemm.TileDesc(128, 32, 128, 32, 8, 8, 32, 32, 1, 2), gemm.TileDesc(64, 64, 32, 32, 8, 8, 32, 32, 2, 1), gemm.TileDesc(64, 32, 64, 32, 8, 8, 32, 32, 1, 2), + gemm.TileDesc(64, 64, 32, 64, 8, 8, 32, 32, 2, 1), + gemm.TileDesc(64, 32, 64, 64, 8, 8, 32, 32, 1, 2), + gemm.TileDesc(64, 32, 32, 64, 8, 8, 32, 32, 1, 1), + gemm.TileDesc(256, 128, 64, 64, 8, 8, 32, 32, 2, 1), + gemm.TileDesc(256, 64, 128, 64, 8, 8, 32, 32, 1, 2), + gemm.TileDesc(128, 128, 64, 64, 8, 8, 32, 32, 2, 2), + gemm.TileDesc(128, 64, 128, 64, 8, 8, 32, 32, 2, 2), + gemm.TileDesc(128, 128, 32, 64, 8, 8, 32, 32, 2, 1), + gemm.TileDesc(128, 32, 128, 64, 8, 8, 32, 32, 1, 2), + gemm.TileDesc(64, 16, 16, 64, 8, 8, 16, 16, 1, 1), + gemm.TileDesc(64, 16, 16, 32, 8, 8, 16, 16, 1, 1), + gemm.TileDesc(128, 32, 32, 64, 8, 8, 16, 16, 1, 2), + gemm.TileDesc(256, 64, 64, 64, 8, 8, 16, 16, 2, 2), ] block_descriptions = [] @@ -548,18 +604,19 @@ def CreateGemmRCROperator(manifest): for t in tile_descriptions: block_transfer = -1 c_block_transfer = -1 + vec_c = 4 if t.m_per_xdl == 16 and t.n_per_xdl == 16 else 8 if t.block_size == 256: - block_transfer = [4, 64, 1] - c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8) + block_transfer = [8, 32, 1] if t.k_per_block == 64 else [4, 64, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 8], vec_c) if t.block_size == 128: - block_transfer = [4, 32, 1] + block_transfer = [8, 16, 1] if t.k_per_block == 64 else [4, 32, 1] if t.n_per_block == 128: - c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 8], 8) + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 8], vec_c) else: - c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8) + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 4], vec_c) if t.block_size == 64: - block_transfer = [4, 16, 1] - c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 8) + block_transfer = [8, 8, 1] if t.k_per_block == 64 else [4, 16, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 4], vec_c) assert ( block_transfer != -1 @@ -653,6 +710,19 @@ def CreateGemmRCRBilinearOperator(manifest, c_element_op): gemm.TileDesc(128, 32, 128, 32, 8, 8, 32, 32, 1, 2), gemm.TileDesc(64, 64, 32, 32, 8, 8, 32, 32, 2, 1), gemm.TileDesc(64, 32, 64, 32, 8, 8, 32, 32, 1, 2), + gemm.TileDesc(64, 64, 32, 64, 8, 8, 32, 32, 2, 1), + gemm.TileDesc(64, 32, 64, 64, 8, 8, 32, 32, 1, 2), + gemm.TileDesc(64, 32, 32, 64, 8, 8, 32, 32, 1, 1), + gemm.TileDesc(256, 128, 64, 64, 8, 8, 32, 32, 2, 1), + gemm.TileDesc(256, 64, 128, 64, 8, 8, 32, 32, 1, 2), + gemm.TileDesc(128, 128, 64, 64, 8, 8, 32, 32, 2, 2), + gemm.TileDesc(128, 64, 128, 64, 8, 8, 32, 32, 2, 2), + gemm.TileDesc(128, 128, 32, 64, 8, 8, 32, 32, 2, 1), + gemm.TileDesc(128, 32, 128, 64, 8, 8, 32, 32, 1, 2), + gemm.TileDesc(64, 16, 16, 64, 8, 8, 16, 16, 1, 1), + gemm.TileDesc(64, 16, 16, 32, 8, 8, 16, 16, 1, 1), + gemm.TileDesc(128, 32, 32, 64, 8, 8, 16, 16, 1, 2), + gemm.TileDesc(256, 64, 64, 64, 8, 8, 16, 16, 2, 2), ] block_descriptions = [] @@ -660,18 +730,19 @@ def CreateGemmRCRBilinearOperator(manifest, c_element_op): for t in tile_descriptions: block_transfer = -1 c_block_transfer = -1 + vec_c = 4 if t.m_per_xdl == 16 and t.n_per_xdl == 16 else 8 if t.block_size == 256: - block_transfer = [4, 64, 1] - c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8) + block_transfer = [8, 32, 1] if t.k_per_block == 64 else [4, 64, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 8], vec_c) if t.block_size == 128: - block_transfer = [4, 32, 1] + block_transfer = [8, 16, 1] if t.k_per_block == 64 else [4, 32, 1] if t.n_per_block == 128: - c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 8], 8) + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 8], vec_c) else: - c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8) + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 4], vec_c) if t.block_size == 64: - block_transfer = [4, 16, 1] - c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 8) + block_transfer = [8, 8, 1] if t.k_per_block == 64 else [4, 16, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 4], vec_c) assert ( block_transfer != -1 @@ -723,10 +794,11 @@ def CreateGemmRCRBilinearOperator(manifest, c_element_op): for tile_desc, block_desc, c_block_desc in zip( tile_descriptions, block_descriptions, c_block_descriptions ): + vec_c_scale = 4 if tile_desc.m_per_xdl == 16 and tile_desc.n_per_xdl == 16 else 8 c_block_desc = copy.deepcopy(c_block_desc) c_block_desc.scalar_per_vector = 1 - c_block_desc.m_n_block_wave_per_xdl[1] //= 8 - c_block_desc.m_n_block_wave_per_xdl[-1] *= 8 + c_block_desc.m_n_block_wave_per_xdl[1] //= vec_c_scale + c_block_desc.m_n_block_wave_per_xdl[-1] *= vec_c_scale new_operation = gemm.GemmOperation( operation_kind=operation_kind, extra_kind=c_element_op, @@ -1098,6 +1170,19 @@ def CreateGemmRCRm2n3PermOperator(manifest, c_element_op): gemm.TileDesc(128, 32, 128, 32, 8, 8, 32, 32, 1, 2), gemm.TileDesc(64, 64, 32, 32, 8, 8, 32, 32, 2, 1), gemm.TileDesc(64, 32, 64, 32, 8, 8, 32, 32, 1, 2), + gemm.TileDesc(64, 64, 32, 64, 8, 8, 32, 32, 2, 1), + gemm.TileDesc(64, 32, 64, 64, 8, 8, 32, 32, 1, 2), + gemm.TileDesc(64, 32, 32, 64, 8, 8, 32, 32, 1, 1), + gemm.TileDesc(256, 128, 64, 64, 8, 8, 32, 32, 2, 1), + gemm.TileDesc(256, 64, 128, 64, 8, 8, 32, 32, 1, 2), + gemm.TileDesc(128, 128, 64, 64, 8, 8, 32, 32, 2, 2), + gemm.TileDesc(128, 64, 128, 64, 8, 8, 32, 32, 2, 2), + gemm.TileDesc(128, 128, 32, 64, 8, 8, 32, 32, 2, 1), + gemm.TileDesc(128, 32, 128, 64, 8, 8, 32, 32, 1, 2), + gemm.TileDesc(64, 16, 16, 64, 8, 8, 16, 16, 1, 1), + gemm.TileDesc(64, 16, 16, 32, 8, 8, 16, 16, 1, 1), + gemm.TileDesc(128, 32, 32, 64, 8, 8, 16, 16, 1, 2), + gemm.TileDesc(256, 64, 64, 64, 8, 8, 16, 16, 2, 2), ] block_descriptions = [] @@ -1105,19 +1190,19 @@ def CreateGemmRCRm2n3PermOperator(manifest, c_element_op): for t in tile_descriptions: block_transfer = -1 c_block_transfer = -1 + vec_c = 4 if t.m_per_xdl == 16 and t.n_per_xdl == 16 else 8 if t.block_size == 256: - block_transfer = [4, 64, 1] - # TODO:figure out the last dimension - c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8) + block_transfer = [8, 32, 1] if t.k_per_block == 64 else [4, 64, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 8], vec_c) if t.block_size == 128: - block_transfer = [4, 32, 1] + block_transfer = [8, 16, 1] if t.k_per_block == 64 else [4, 32, 1] if t.n_per_block == 128: - c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 8], 8) + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 8], vec_c) else: - c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8) + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 4], vec_c) if t.block_size == 64: - block_transfer = [4, 16, 1] - c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 8) + block_transfer = [8, 8, 1] if t.k_per_block == 64 else [4, 16, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 4], vec_c) assert ( block_transfer != -1 @@ -1190,6 +1275,19 @@ def CreateGemmRCRm3n2PermOperator(manifest, c_element_op): gemm.TileDesc(128, 32, 128, 32, 8, 8, 32, 32, 1, 2), gemm.TileDesc(64, 64, 32, 32, 8, 8, 32, 32, 2, 1), gemm.TileDesc(64, 32, 64, 32, 8, 8, 32, 32, 1, 2), + gemm.TileDesc(64, 64, 32, 64, 8, 8, 32, 32, 2, 1), + gemm.TileDesc(64, 32, 64, 64, 8, 8, 32, 32, 1, 2), + gemm.TileDesc(64, 32, 32, 64, 8, 8, 32, 32, 1, 1), + gemm.TileDesc(256, 128, 64, 64, 8, 8, 32, 32, 2, 1), + gemm.TileDesc(256, 64, 128, 64, 8, 8, 32, 32, 1, 2), + gemm.TileDesc(128, 128, 64, 64, 8, 8, 32, 32, 2, 2), + gemm.TileDesc(128, 64, 128, 64, 8, 8, 32, 32, 2, 2), + gemm.TileDesc(128, 128, 32, 64, 8, 8, 32, 32, 2, 1), + gemm.TileDesc(128, 32, 128, 64, 8, 8, 32, 32, 1, 2), + gemm.TileDesc(64, 16, 16, 64, 8, 8, 16, 16, 1, 1), + gemm.TileDesc(64, 16, 16, 32, 8, 8, 16, 16, 1, 1), + gemm.TileDesc(128, 32, 32, 64, 8, 8, 16, 16, 1, 2), + gemm.TileDesc(256, 64, 64, 64, 8, 8, 16, 16, 2, 2), ] block_descriptions = [] @@ -1197,19 +1295,19 @@ def CreateGemmRCRm3n2PermOperator(manifest, c_element_op): for t in tile_descriptions: block_transfer = -1 c_block_transfer = -1 + vec_c = 4 if t.m_per_xdl == 16 and t.n_per_xdl == 16 else 8 if t.block_size == 256: - block_transfer = [4, 64, 1] - # TODO:figure out the last dimension - c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 1) + block_transfer = [8, 32, 1] if t.k_per_block == 64 else [4, 64, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 8], vec_c) if t.block_size == 128: - block_transfer = [4, 32, 1] + block_transfer = [8, 16, 1] if t.k_per_block == 64 else [4, 32, 1] if t.n_per_block == 128: - c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 8], 1) + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 8], vec_c) else: - c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 1) + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 4], vec_c) if t.block_size == 64: - block_transfer = [4, 16, 1] - c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 1) + block_transfer = [8, 8, 1] if t.k_per_block == 64 else [4, 16, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 4], vec_c) assert ( block_transfer != -1