optimize the macro generator related items
LeiWang1999 committed Sep 26, 2024
1 parent 0fb9535 commit 2c93dad
Showing 3 changed files with 30 additions and 50 deletions.
5 changes: 4 additions & 1 deletion bitblas/tl/__init__.py
@@ -7,4 +7,7 @@
get_ldmatrix_offset, # noqa: F401
)

from .macro_generator import TensorCorePTXMacroGenerator # noqa: F401
from .macro_generator import (
TensorCoreIntrinEmitter, # noqa: F401
TensorCoreIntrinEmitterWithLadderTransform, # noqa: F401
)
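Both emitter classes are now re-exported from the package root, so downstream code can import them directly from bitblas.tl rather than from the macro_generator module. A minimal sketch of the new import path; the class names come from this diff, and nothing beyond them is assumed:

    # The renamed emitter classes are importable from the package root.
    from bitblas.tl import (
        TensorCoreIntrinEmitter,
        TensorCoreIntrinEmitterWithLadderTransform,
    )
    # Construction and use of TensorCoreIntrinEmitter is shown in the
    # updated test file (test_tilelang_macro_gemm.py) further below.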
52 changes: 15 additions & 37 deletions bitblas/tl/macro_generator.py
@@ -15,7 +15,7 @@
lift = convert


class TensorCorePTXMacroGenerator(object):
class TensorCoreIntrinEmitter(object):
"""
To eliminate Python syntax within TIR Macro.
"""
@@ -116,9 +116,8 @@ def _initialize_transform_kind(self, transform_kind_a, transform_kind_b):

assert transform_kind_b in [0, 3], "Currently only support 0 and 3"

@staticmethod
@T.macro
def LDMATRIX_A(
def _warp_ldmatrix_a(
inst,
A_local_buf,
A_shared_buf,
@@ -143,9 +142,8 @@ def LDMATRIX_A(
get_ldmatrix_offset("A", tx, 0, stride, inst.a_dtype, inst.a_transposed),
)

@staticmethod
@T.macro
def LDMATRIX_B(
def _warp_ldmatrix_b(
inst,
B_local_buf,
B_shared_buf,
@@ -173,9 +171,8 @@ def LDMATRIX_B(
get_ldmatrix_offset("B", tx, 0, stride, inst.b_dtype, inst.b_transposed),
)

@staticmethod
@T.macro
def MMA(inst, A_local_buf, B_local_buf, C_local_buf):
def _warp_mma(inst, A_local_buf, B_local_buf, C_local_buf):
for i, j in T.grid(inst.warp_rows, inst.warp_cols):
T.ptx_mma(
inst.accum_dtype,
@@ -216,9 +213,8 @@ def MMA(inst, A_local_buf, B_local_buf, C_local_buf):
# MMA Store must be in simulated instead of TVM Intrins
# As TVM Intrins is like a hack that the threadIdx.x should be always
# equal to the warp_size
@staticmethod
@T.macro
def STMATRIX(inst, C_local_buf, C_shared_buf, thread_bindings):
def _warp_stmatrix(inst, C_local_buf, C_shared_buf, thread_bindings):
tx = thread_bindings % inst.WARP_SIZE
ty = (thread_bindings // inst.WARP_SIZE) % inst.block_row_warps
tz = (thread_bindings // (inst.WARP_SIZE * inst.block_row_warps)) % inst.block_col_warps
@@ -231,38 +227,20 @@ def STMATRIX(inst, C_local_buf, C_shared_buf, thread_bindings):
col] = C_local_buf[i * (inst.warp_cols * inst.local_size_out) +
j * inst.local_size_out + local_id]

# Allow GEMM from shared memory to local memory
@staticmethod
@T.macro
def GEMM_SS(inst, A_shared_buf, B_shared_buf, C_local_buf, thread_bindings):
# TODO(lei): alloc_buffer within the macro is not supported yet.
A_local_buf = T.alloc_fragment((inst.warp_rows * inst.local_size_a),
inst.a_dtype,
scope="local")
B_local_buf = T.alloc_fragment((inst.warp_cols * inst.local_size_b),
inst.b_dtype,
scope="local")
for ki in T.serial(0, (inst.chunk // inst.micro_size_k)):
inst.LDMATRIX_A(
inst,
A_local_buf,
A_shared_buf,
ki,
thread_bindings=thread_bindings,
)
def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, thread_bindings, rk=0):
return self._warp_ldmatrix_a(self, A_local_buf, A_shared_buf, ki, thread_bindings, rk)

inst.LDMATRIX_B(
inst,
B_local_buf,
B_shared_buf,
ki,
thread_bindings=thread_bindings,
)
def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk=0):
return self._warp_ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk)

inst.MMA(inst, A_local_buf, B_local_buf, C_local_buf)
def mma(self, A_local_buf, B_local_buf, C_local_buf):
return self._warp_mma(self, A_local_buf, B_local_buf, C_local_buf)

def stmatrix(self, C_local_buf, C_shared_buf, thread_bindings):
return self._warp_stmatrix(self, C_local_buf, C_shared_buf, thread_bindings)


class TensorCorePTXMacroGeneratorWithLadderTransform(object):
class TensorCoreIntrinEmitterWithLadderTransform(object):
"""
To eliminate Python syntax within TIR Macro.
"""
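The net effect at call sites: the @T.macro bodies are now private (_warp_ldmatrix_a, _warp_ldmatrix_b, _warp_mma, _warp_stmatrix) and are reached through thin instance methods that forward self, so callers no longer pass the emitter instance explicitly as the first macro argument. A minimal before/after sketch of one fragment load, using the buffer and loop names that appear in the test kernel below:

    # Before this commit: the static T.macro was invoked directly and the
    # emitter instance had to be forwarded as the first argument.
    ptx_macro_generator.LDMATRIX_A(
        ptx_macro_generator,
        A_local,
        A_shared,
        ki,
        thread_bindings=thread_bindings,
    )

    # After this commit: the instance method forwards `self` to the
    # private macro, so only the buffers and loop index are passed.
    mma_emitter.ldmatrix_a(
        A_local,
        A_shared,
        ki,
        thread_bindings=thread_bindings,
    )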
23 changes: 11 additions & 12 deletions testing/python/tilelang/test_tilelang_macro_gemm.py
@@ -9,7 +9,7 @@
from tvm import tl as TL
import tvm.tl.language as T
from bitblas.tl.utils import get_swizzle_layout
from bitblas.tl.macro_generator import TensorCorePTXMacroGenerator
from bitblas.tl.macro_generator import TensorCoreIntrinEmitter


def make_swizzle_layout(shared_buf):
@@ -41,9 +41,10 @@ def tl_matmul(
"int8",
], "Currently only float16 and int8 are supported"
assert dtypeC in [
"float16",
"float32",
"int32",
], "Currently only float32 and int32 are supported"
], "Currently only float16, float32 and int32 are supported"

micro_size_x = micro_size_y = micro_size_k = 16

@@ -83,7 +84,7 @@ def tl_matmul(
warp_cols = warp_col_tiles // micro_size_y

# MMA Wrapper to Auto Generate Code for MMA
ptx_macro_generator = TensorCorePTXMacroGenerator(
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=dtypeAB,
b_dtype=dtypeAB,
accum_dtype=accum_dtype,
@@ -112,6 +113,7 @@ def main(
C_local = T.alloc_fragment((warp_rows * warp_cols * local_size),
accum_dtype,
scope="local")

thread_bindings = T.thread_binding(0, threads, "threadIdx.x")

T.annotate_layout({
@@ -137,29 +139,26 @@
for ki in T.serial(0, (block_K // micro_size_k)):

# Load A into fragment
ptx_macro_generator.LDMATRIX_A(
ptx_macro_generator,
mma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
thread_bindings=thread_bindings,
)

# Load B into fragment
ptx_macro_generator.LDMATRIX_B(
ptx_macro_generator,
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
thread_bindings=thread_bindings,
)

# Perform Matrix Multiplication
ptx_macro_generator.MMA(ptx_macro_generator, A_local, B_local, C_local)
mma_emitter.mma(A_local, B_local, C_local)

# Perform STMatrix
ptx_macro_generator.STMATRIX(
ptx_macro_generator,
mma_emitter.stmatrix(
C_local,
C_shared,
thread_bindings=thread_bindings,
@@ -202,8 +201,8 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, dtypeC, accum_dtype):


def test_assert_tl_matmul_correctness():
assert_tl_matmul_correctness(128, 128, 128, "float16", "float32", "float32")
assert_tl_matmul_correctness(32, 256, 256, "float16", "float32", "float32")
assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16")
assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32")


if __name__ == "__main__":