Skip to content

Commit

Permalink
add some instances for conv/gemm
Browse files Browse the repository at this point in the history
  • Loading branch information
carlushuang committed Jul 25, 2023
1 parent 85e1f71 commit 258a5d4
Showing 1 changed file with 133 additions and 35 deletions.
168 changes: 133 additions & 35 deletions python/aitemplate/utils/mk_ck_lib/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,20 @@ def CreateConv2dFwdOperator(manifest, operation_kind, out_element_op, out_data_o
conv.GroupTileDesc(1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2),
conv.GroupTileDesc(1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1),
conv.GroupTileDesc(1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2),
conv.GroupTileDesc(1, 64, 64, 32, 64, 8, 8, 32, 32, 2, 1),
conv.GroupTileDesc(1, 64, 32, 64, 64, 8, 8, 32, 32, 1, 2),
conv.GroupTileDesc(1, 64, 32, 32, 64, 8, 8, 32, 32, 1, 1),
conv.GroupTileDesc(1, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1),
conv.GroupTileDesc(1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2),
conv.GroupTileDesc(1, 128, 128, 64, 64, 8, 8, 32, 32, 2, 2),
conv.GroupTileDesc(1, 128, 64, 128, 64, 8, 8, 32, 32, 2, 2),
conv.GroupTileDesc(1, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1),
conv.GroupTileDesc(1, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2),
conv.GroupTileDesc(1, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1),
conv.GroupTileDesc(1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1),
conv.GroupTileDesc(1, 128, 32, 32, 64, 8, 8, 16, 16, 1, 2),
conv.GroupTileDesc(1, 256, 64, 64, 64, 8, 8, 16, 16, 2, 2),

]

c_block_descriptions = [
Expand All @@ -68,17 +82,31 @@ def CreateConv2dFwdOperator(manifest, operation_kind, out_element_op, out_data_o
conv.CBlockTransferDesc(1, 1, [1, 16, 1, 8], 8),
conv.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 8),
conv.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 8),

conv.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 8),
conv.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 8),
conv.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 8),
conv.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8),
conv.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8),
conv.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8),
conv.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8),
conv.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8),
conv.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8),
conv.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 4),
conv.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 4),
conv.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 4),
conv.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 4),
]

block_descriptions = []
for t in tile_descriptions:
block_transfer = -1
if t.block_size == 256:
block_transfer = [4, 64, 1]
block_transfer = [8, 32, 1] if t.k_per_block == 64 else [4, 64, 1]
if t.block_size == 128:
block_transfer = [4, 32, 1]
block_transfer = [8, 16, 1] if t.k_per_block == 64 else [4, 32, 1]
if t.block_size == 64:
block_transfer = [4, 16, 1]
block_transfer = [8, 8, 1] if t.k_per_block == 64 else [4, 16, 1]
assert (
block_transfer != -1
and "Cannot determine block_transfer_size with block_size "
Expand Down Expand Up @@ -150,6 +178,21 @@ def CreateConv2dFwdOperator(manifest, operation_kind, out_element_op, out_data_o
conv.BlockTransferDesc([4, 4, 8], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1),
conv.BlockTransferDesc([4, 2, 8], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1),
conv.BlockTransferDesc([4, 2, 8], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1),

conv.BlockTransferDesc([2, 8, 4], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1),
conv.BlockTransferDesc([2, 8, 4], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1),
conv.BlockTransferDesc([2, 8, 4], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1),
conv.BlockTransferDesc([2, 32, 4], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1),
conv.BlockTransferDesc([2, 32, 4], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1),
conv.BlockTransferDesc([2, 16, 4], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1),
conv.BlockTransferDesc([2, 16, 4], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1),
conv.BlockTransferDesc([2, 16, 4], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1),
conv.BlockTransferDesc([2, 16, 4], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1),
conv.BlockTransferDesc([2, 8, 4], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1),
conv.BlockTransferDesc([2, 8, 4], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1),
conv.BlockTransferDesc([2, 8, 4], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1),
conv.BlockTransferDesc([2, 32, 4], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1),

conv.BlockTransferDesc([2, 32, 4], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1),
conv.BlockTransferDesc([2, 32, 4], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1),
conv.BlockTransferDesc([2, 32, 4], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1),
Expand Down Expand Up @@ -541,25 +584,39 @@ def CreateGemmRCROperator(manifest):
gemm.TileDesc(128, 32, 128, 32, 8, 8, 32, 32, 1, 2),
gemm.TileDesc(64, 64, 32, 32, 8, 8, 32, 32, 2, 1),
gemm.TileDesc(64, 32, 64, 32, 8, 8, 32, 32, 1, 2),
gemm.TileDesc(64, 64, 32, 64, 8, 8, 32, 32, 2, 1),
gemm.TileDesc(64, 32, 64, 64, 8, 8, 32, 32, 1, 2),
gemm.TileDesc(64, 32, 32, 64, 8, 8, 32, 32, 1, 1),
gemm.TileDesc(256, 128, 64, 64, 8, 8, 32, 32, 2, 1),
gemm.TileDesc(256, 64, 128, 64, 8, 8, 32, 32, 1, 2),
gemm.TileDesc(128, 128, 64, 64, 8, 8, 32, 32, 2, 2),
gemm.TileDesc(128, 64, 128, 64, 8, 8, 32, 32, 2, 2),
gemm.TileDesc(128, 128, 32, 64, 8, 8, 32, 32, 2, 1),
gemm.TileDesc(128, 32, 128, 64, 8, 8, 32, 32, 1, 2),
gemm.TileDesc(64, 16, 16, 64, 8, 8, 16, 16, 1, 1),
gemm.TileDesc(64, 16, 16, 32, 8, 8, 16, 16, 1, 1),
gemm.TileDesc(128, 32, 32, 64, 8, 8, 16, 16, 1, 2),
gemm.TileDesc(256, 64, 64, 64, 8, 8, 16, 16, 2, 2),
]

block_descriptions = []
c_block_descriptions = []
for t in tile_descriptions:
block_transfer = -1
c_block_transfer = -1
vec_c = 4 if t.m_per_xdl == 16 and t.n_per_xdl == 16 else 8
if t.block_size == 256:
block_transfer = [4, 64, 1]
c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8)
block_transfer = [8, 32, 1] if t.k_per_block == 64 else [4, 64, 1]
c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 8], vec_c)
if t.block_size == 128:
block_transfer = [4, 32, 1]
block_transfer = [8, 16, 1] if t.k_per_block == 64 else [4, 32, 1]
if t.n_per_block == 128:
c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 8], 8)
c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 8], vec_c)
else:
c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8)
c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 4], vec_c)
if t.block_size == 64:
block_transfer = [4, 16, 1]
c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 8)
block_transfer = [8, 8, 1] if t.k_per_block == 64 else [4, 16, 1]
c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 4], vec_c)

assert (
block_transfer != -1
Expand Down Expand Up @@ -653,25 +710,39 @@ def CreateGemmRCRBilinearOperator(manifest, c_element_op):
gemm.TileDesc(128, 32, 128, 32, 8, 8, 32, 32, 1, 2),
gemm.TileDesc(64, 64, 32, 32, 8, 8, 32, 32, 2, 1),
gemm.TileDesc(64, 32, 64, 32, 8, 8, 32, 32, 1, 2),
gemm.TileDesc(64, 64, 32, 64, 8, 8, 32, 32, 2, 1),
gemm.TileDesc(64, 32, 64, 64, 8, 8, 32, 32, 1, 2),
gemm.TileDesc(64, 32, 32, 64, 8, 8, 32, 32, 1, 1),
gemm.TileDesc(256, 128, 64, 64, 8, 8, 32, 32, 2, 1),
gemm.TileDesc(256, 64, 128, 64, 8, 8, 32, 32, 1, 2),
gemm.TileDesc(128, 128, 64, 64, 8, 8, 32, 32, 2, 2),
gemm.TileDesc(128, 64, 128, 64, 8, 8, 32, 32, 2, 2),
gemm.TileDesc(128, 128, 32, 64, 8, 8, 32, 32, 2, 1),
gemm.TileDesc(128, 32, 128, 64, 8, 8, 32, 32, 1, 2),
gemm.TileDesc(64, 16, 16, 64, 8, 8, 16, 16, 1, 1),
gemm.TileDesc(64, 16, 16, 32, 8, 8, 16, 16, 1, 1),
gemm.TileDesc(128, 32, 32, 64, 8, 8, 16, 16, 1, 2),
gemm.TileDesc(256, 64, 64, 64, 8, 8, 16, 16, 2, 2),
]

block_descriptions = []
c_block_descriptions = []
for t in tile_descriptions:
block_transfer = -1
c_block_transfer = -1
vec_c = 4 if t.m_per_xdl == 16 and t.n_per_xdl == 16 else 8
if t.block_size == 256:
block_transfer = [4, 64, 1]
c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8)
block_transfer = [8, 32, 1] if t.k_per_block == 64 else [4, 64, 1]
c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 8], vec_c)
if t.block_size == 128:
block_transfer = [4, 32, 1]
block_transfer = [8, 16, 1] if t.k_per_block == 64 else [4, 32, 1]
if t.n_per_block == 128:
c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 8], 8)
c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 8], vec_c)
else:
c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8)
c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 4], vec_c)
if t.block_size == 64:
block_transfer = [4, 16, 1]
c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 8)
block_transfer = [8, 8, 1] if t.k_per_block == 64 else [4, 16, 1]
c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 4], vec_c)

assert (
block_transfer != -1
Expand Down Expand Up @@ -723,10 +794,11 @@ def CreateGemmRCRBilinearOperator(manifest, c_element_op):
for tile_desc, block_desc, c_block_desc in zip(
tile_descriptions, block_descriptions, c_block_descriptions
):
vec_c_scale = 4 if tile_desc.m_per_xdl == 16 and tile_desc.n_per_xdl == 16 else 8
c_block_desc = copy.deepcopy(c_block_desc)
c_block_desc.scalar_per_vector = 1
c_block_desc.m_n_block_wave_per_xdl[1] //= 8
c_block_desc.m_n_block_wave_per_xdl[-1] *= 8
c_block_desc.m_n_block_wave_per_xdl[1] //= vec_c_scale
c_block_desc.m_n_block_wave_per_xdl[-1] *= vec_c_scale
new_operation = gemm.GemmOperation(
operation_kind=operation_kind,
extra_kind=c_element_op,
Expand Down Expand Up @@ -1098,26 +1170,39 @@ def CreateGemmRCRm2n3PermOperator(manifest, c_element_op):
gemm.TileDesc(128, 32, 128, 32, 8, 8, 32, 32, 1, 2),
gemm.TileDesc(64, 64, 32, 32, 8, 8, 32, 32, 2, 1),
gemm.TileDesc(64, 32, 64, 32, 8, 8, 32, 32, 1, 2),
gemm.TileDesc(64, 64, 32, 64, 8, 8, 32, 32, 2, 1),
gemm.TileDesc(64, 32, 64, 64, 8, 8, 32, 32, 1, 2),
gemm.TileDesc(64, 32, 32, 64, 8, 8, 32, 32, 1, 1),
gemm.TileDesc(256, 128, 64, 64, 8, 8, 32, 32, 2, 1),
gemm.TileDesc(256, 64, 128, 64, 8, 8, 32, 32, 1, 2),
gemm.TileDesc(128, 128, 64, 64, 8, 8, 32, 32, 2, 2),
gemm.TileDesc(128, 64, 128, 64, 8, 8, 32, 32, 2, 2),
gemm.TileDesc(128, 128, 32, 64, 8, 8, 32, 32, 2, 1),
gemm.TileDesc(128, 32, 128, 64, 8, 8, 32, 32, 1, 2),
gemm.TileDesc(64, 16, 16, 64, 8, 8, 16, 16, 1, 1),
gemm.TileDesc(64, 16, 16, 32, 8, 8, 16, 16, 1, 1),
gemm.TileDesc(128, 32, 32, 64, 8, 8, 16, 16, 1, 2),
gemm.TileDesc(256, 64, 64, 64, 8, 8, 16, 16, 2, 2),
]

block_descriptions = []
c_block_descriptions = []
for t in tile_descriptions:
block_transfer = -1
c_block_transfer = -1
vec_c = 4 if t.m_per_xdl == 16 and t.n_per_xdl == 16 else 8
if t.block_size == 256:
block_transfer = [4, 64, 1]
# TODO:figure out the last dimension
c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8)
block_transfer = [8, 32, 1] if t.k_per_block == 64 else [4, 64, 1]
c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 8], vec_c)
if t.block_size == 128:
block_transfer = [4, 32, 1]
block_transfer = [8, 16, 1] if t.k_per_block == 64 else [4, 32, 1]
if t.n_per_block == 128:
c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 8], 8)
c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 8], vec_c)
else:
c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8)
c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 4], vec_c)
if t.block_size == 64:
block_transfer = [4, 16, 1]
c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 8)
block_transfer = [8, 8, 1] if t.k_per_block == 64 else [4, 16, 1]
c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 4], vec_c)

assert (
block_transfer != -1
Expand Down Expand Up @@ -1190,26 +1275,39 @@ def CreateGemmRCRm3n2PermOperator(manifest, c_element_op):
gemm.TileDesc(128, 32, 128, 32, 8, 8, 32, 32, 1, 2),
gemm.TileDesc(64, 64, 32, 32, 8, 8, 32, 32, 2, 1),
gemm.TileDesc(64, 32, 64, 32, 8, 8, 32, 32, 1, 2),
gemm.TileDesc(64, 64, 32, 64, 8, 8, 32, 32, 2, 1),
gemm.TileDesc(64, 32, 64, 64, 8, 8, 32, 32, 1, 2),
gemm.TileDesc(64, 32, 32, 64, 8, 8, 32, 32, 1, 1),
gemm.TileDesc(256, 128, 64, 64, 8, 8, 32, 32, 2, 1),
gemm.TileDesc(256, 64, 128, 64, 8, 8, 32, 32, 1, 2),
gemm.TileDesc(128, 128, 64, 64, 8, 8, 32, 32, 2, 2),
gemm.TileDesc(128, 64, 128, 64, 8, 8, 32, 32, 2, 2),
gemm.TileDesc(128, 128, 32, 64, 8, 8, 32, 32, 2, 1),
gemm.TileDesc(128, 32, 128, 64, 8, 8, 32, 32, 1, 2),
gemm.TileDesc(64, 16, 16, 64, 8, 8, 16, 16, 1, 1),
gemm.TileDesc(64, 16, 16, 32, 8, 8, 16, 16, 1, 1),
gemm.TileDesc(128, 32, 32, 64, 8, 8, 16, 16, 1, 2),
gemm.TileDesc(256, 64, 64, 64, 8, 8, 16, 16, 2, 2),
]

block_descriptions = []
c_block_descriptions = []
for t in tile_descriptions:
block_transfer = -1
c_block_transfer = -1
vec_c = 4 if t.m_per_xdl == 16 and t.n_per_xdl == 16 else 8
if t.block_size == 256:
block_transfer = [4, 64, 1]
# TODO:figure out the last dimension
c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 1)
block_transfer = [8, 32, 1] if t.k_per_block == 64 else [4, 64, 1]
c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 8], vec_c)
if t.block_size == 128:
block_transfer = [4, 32, 1]
block_transfer = [8, 16, 1] if t.k_per_block == 64 else [4, 32, 1]
if t.n_per_block == 128:
c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 8], 1)
c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 8], vec_c)
else:
c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 1)
c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 4], vec_c)
if t.block_size == 64:
block_transfer = [4, 16, 1]
c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 1)
block_transfer = [8, 8, 1] if t.k_per_block == 64 else [4, 16, 1]
c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 4], vec_c)

assert (
block_transfer != -1
Expand Down

0 comments on commit 258a5d4

Please sign in to comment.