
Triton FP8 GEMM does not seem to work #21

Open
mgoin opened this issue May 2, 2024 · 10 comments

@mgoin

mgoin commented May 2, 2024

Hello @AdnanHoque, I am trying to recreate the results from the blog Accelerating Llama3 FP8 Inference with Triton Kernels, but I haven't been able to get the splitk_gemm_fp8.py kernels to work properly; they seem to produce NaNs and Infs. I am using an H100 80GB (driver 535.129.03, CUDA 12.2) with PyTorch 2.3.0. Do you have an example of the benchmark or accuracy eval you used?

Here is the example and output demonstrating my issue.

Output:

python triton_fp8.py 
y_torch: tensor([[  19.7344,  135.1250,  -19.6719,  ..., -128.6250,  -62.2188,
          -49.6562]], device='cuda:0', dtype=torch.float16)
y_triton: tensor([[nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16)
y_fp16: tensor([[  19.7969,  138.8750,  -21.3438,  ..., -129.2500,  -59.4062,
          -51.9688]], device='cuda:0', dtype=torch.float16)
fp16 vs torch cos_sim: tensor(0.9990, device='cuda:0', dtype=torch.float16)
fp16 vs triton cos_sim: tensor(nan, device='cuda:0', dtype=torch.float16)

Script used:

import torch
import triton
import triton.language as tl
import time
import os
os.environ['ENABLE_TMA'] = '1'

@triton.jit
def grouped_launch(pid,
                m, n,
                block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr):

    grid_m = tl.cdiv(m, block_m)
    grid_n = tl.cdiv(n, block_n)

    width = group_m * grid_n
    group_id = pid // width
    group_size = tl.minimum(grid_m - group_id * group_m, group_m)

    pid_m = group_id * group_m + (pid % group_size)
    pid_n = (pid % width) // group_size

    return pid_m, pid_n


@triton.jit()
def col_major(pid,
              m, n,
              block_m: tl.constexpr, block_n: tl.constexpr):

    grid_m = tl.cdiv(m, block_m)

    pid_m = pid % grid_m
    pid_n = pid // grid_m

    return pid_m, pid_n


@triton.jit
def gemm_split_k_kernel(a_ptr, b_ptr, c_ptr,
            stride_am, stride_ak,
            stride_bk, stride_bn,
            stride_cm, stride_cn,
            m, n, k,
            block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr,
            split_k: tl.constexpr, group_m: tl.constexpr):

    pid = tl.program_id(0)
    pid_k = tl.program_id(1)
    grid_k = tl.cdiv(k, block_k*split_k)

    pid_m, pid_n = grouped_launch(pid,
                                  m, n,
                                  block_m, block_n, group_m)

    offs_m = pid_m*block_m + tl.arange(0, block_m)
    offs_n = pid_n*block_n + tl.arange(0, block_n)
    offs_k = pid_k*block_k + tl.arange(0, block_k)

    offs_am = tl.max_contiguous(tl.multiple_of(offs_m, block_m), block_m)
    offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_n), block_n)

    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)


    acc = tl.zeros((block_m, block_n), dtype=tl.float32)
    for k_ in range(0, grid_k):

        k_remaining = k - k_ * (block_k * split_k)

        a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)

        acc = tl.dot(a, b, acc, out_dtype=tl.float32)

        a_ptrs += block_k * split_k * stride_ak
        b_ptrs += block_k * split_k * stride_bk

    acc.to(tl.float16)

    offs_m = pid_m*block_m + tl.arange(0, block_m)
    offs_n = pid_n*block_n + tl.arange(0, block_n)

    c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
    mask = (offs_m < m)[:, None] & (offs_n < n)[None, :]

    tl.atomic_add(c_ptrs, acc, mask=mask)

def gemm_split_k(a, b):

    m, k = a.shape
    _, n = b.shape

    block_m = 64
    block_n = 64
    block_k = 512
    num_stages = 3
    num_warps = 8
    split_k = 4
    group_m = 8

    total_blocks_m = triton.cdiv(m, block_m)
    total_blocks_n = triton.cdiv(n, block_n)
    total_programs_mn = total_blocks_m * total_blocks_n
    total_programs_k = split_k

    grid = (total_programs_mn, total_programs_k)

    c = torch.zeros((m, n), device=a.device, dtype=torch.float16)
    k = gemm_split_k_kernel[grid](a, b, c,
                              a.stride(0), a.stride(1),
                              b.stride(0), b.stride(1),
                              c.stride(0), c.stride(1),
                              m, n, k,
                              block_m, block_n, block_k,
                              split_k, group_m, num_stages=num_stages, num_warps=num_warps)

    return c


def to_float8(x, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    scale = finfo.max / x.abs().max().clamp(min=1e-12)
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()

dtype = torch.float16
qdtype = torch.float8_e4m3fn
m = 1
n = 8096
k = 8096

# create test inputs
x = torch.randn((m, k), dtype=dtype, device='cuda')
w = torch.randn((n, k), dtype=dtype, device='cuda')

x_fp8, x_inv_s = to_float8(x, dtype=qdtype)
w_fp8, w_inv_s = to_float8(w, dtype=qdtype)

y_torch, _ = torch._scaled_mm(x_fp8, w_fp8.t(), out_dtype=dtype, scale_a=x_inv_s, scale_b=w_inv_s)
y_triton = gemm_split_k(x_fp8, w_fp8)
y_fp16 = torch.nn.functional.linear(x, w)

print("y_torch:", y_torch)
print("y_triton:", y_triton)
print("y_fp16:", y_fp16)

print("fp16 vs torch cos_sim:", torch.nn.functional.cosine_similarity(y_fp16.reshape(-1), y_torch.reshape(-1), dim=0))
print("fp16 vs triton cos_sim:", torch.nn.functional.cosine_similarity(y_fp16.reshape(-1), y_triton.reshape(-1), dim=0))
@AdnanHoque
Contributor

Hello! Which Triton version are you using?

@AdnanHoque
Contributor

Can you try building Triton from source? The latest main commit should be fine.

We expect the N and K dimensions to be multiples of the block size. Try with N=K=8192. I'll share an example usage script soon.
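
As a quick sanity check, something like the following could be added in front of gemm_split_k to enforce that expectation (a minimal sketch; the block sizes simply mirror the values hard-coded in the script above):

def check_gemm_shapes(n, k, block_n=64, block_k=512):
    # Reflects the "N and K must be multiples of the block size" expectation.
    assert n % block_n == 0, f"N={n} is not a multiple of block_n={block_n}"
    assert k % block_k == 0, f"K={k} is not a multiple of block_k={block_k}"

check_gemm_shapes(n=8192, k=8192)    # passes
# check_gemm_shapes(n=8096, k=8096)  # would fail: 8096 is a multiple of neither 64 nor 512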

@mgoin
Author

mgoin commented May 2, 2024

I am using torch==2.3.0 and triton==2.3.0, which is the default for the latest non-nightly torch. EDIT: I will try the nightly.

I did mean to use 8192, not 8096, so thanks for that correction. However, I still see NaNs after making that change.

I also tried M=N=K=512, which did give some real numbers in between the NaNs and Infs:

Running with M=512, N=512, K=512
y_torch: tensor([[-15.2812, -14.4297,  -6.2031,  ...,   4.1797, -20.0000,   8.0625],
        [ 70.8125, -12.3281,  -6.3125,  ..., -20.4062, -11.0312,  13.5625],
        [-16.0781, -17.8906,  28.0469,  ..., -54.7500,   3.0391, -21.0781],
        ...,
        [-27.1250,  -7.9062, -10.9375,  ...,  29.9219,  30.6250,   1.7432],
        [ 15.5312, -29.6719,  15.0703,  ..., -41.6875, -36.7188,  41.8125],
        [  7.0078,  21.3906, -12.7578,  ..., -50.5938,  26.6094,  42.5625]],
       device='cuda:0', dtype=torch.float16)
y_triton: tensor([[   -inf,    -inf,    -inf,  ...,     inf,     inf,  11288.],
        [    inf,     inf,    -inf,  ...,   3748.,    -inf,     inf],
        [-32608., -47968., -40800.,  ..., -37184.,     inf, -47424.],
        ...,
        [ 33568.,    -inf,    -inf,  ...,    -inf, -54528., -45696.],
        [   -inf,     inf, -32992.,  ...,     inf,     inf,     inf],
        [   -inf,  34592.,     inf,  ...,     inf,  -4096., -24720.]],
       device='cuda:0', dtype=torch.float16)
y_fp16: tensor([[-15.3984, -14.9141,  -5.8555,  ...,   4.6094, -22.0469,   6.8867],
        [ 71.1875, -12.4531,  -6.8594,  ..., -19.3438, -11.3984,  13.5078],
        [-15.0234, -17.7656,  27.5625,  ..., -54.5938,   2.2539, -22.8594],
        ...,
        [-27.7031,  -7.2266, -11.0703,  ...,  30.2812,  29.4531,   1.6523],
        [ 14.7812, -30.0625,  14.6328,  ..., -42.6562, -35.9375,  43.0000],
        [  7.7227,  21.9375, -11.2109,  ..., -50.2812,  26.8750,  43.1875]],
       device='cuda:0', dtype=torch.float16)
fp16 vs torch cos_sim: tensor(0.9990, device='cuda:0', dtype=torch.float16)
fp16 vs triton cos_sim: tensor(nan, device='cuda:0', dtype=torch.float16)

@mgoin
Author

mgoin commented May 2, 2024

Unfortunately I got the same NaN results using a fresh install from pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121, which installed these versions:

pytorch-triton           3.0.0+45fff310c8
torch                    2.4.0.dev20240502+cu121

@AdnanHoque
Contributor

AdnanHoque commented May 2, 2024

Follow the instructions to build Triton from source here: https://github.com/openai/triton?tab=readme-ov-file#install-from-source

The results look pretty close on my end, with some margin of error expected because of the downcast.

import torch
import triton
import triton.language as tl
import time
import os
os.environ['ENABLE_TMA'] = '1'

@triton.jit
def grouped_launch(pid,
                m, n,
                block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr):
    
    grid_m = tl.cdiv(m, block_m)
    grid_n = tl.cdiv(n, block_n)

    width = group_m * grid_n
    group_id = pid // width
    group_size = tl.minimum(grid_m - group_id * group_m, group_m)

    pid_m = group_id * group_m + (pid % group_size)
    pid_n = (pid % width) // group_size

    return pid_m, pid_n



@triton.jit
def gemm_split_k_kernel(a_ptr, b_ptr, c_ptr,
            stride_am, stride_ak,
            stride_bk, stride_bn,
            stride_cm, stride_cn,
            m, n, k,
            block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr,
            split_k: tl.constexpr, group_m: tl.constexpr):
    
    pid = tl.program_id(0)
    pid_k = tl.program_id(1)
    grid_k = tl.cdiv(k, block_k*split_k)

    pid_m, pid_n = grouped_launch(pid,
                                  m, n,
                                  block_m, block_n, group_m)
    
    offs_m = pid_m*block_m + tl.arange(0, block_m)
    offs_n = pid_n*block_n + tl.arange(0, block_n)
    offs_k = pid_k*block_k + tl.arange(0, block_k)

    offs_am = tl.max_contiguous(tl.multiple_of(offs_m, block_m), block_m)
    offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_n), block_n)

    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)


    acc = tl.zeros((block_m, block_n), dtype=tl.float32)
    for k_ in range(0, grid_k):
        
        k_remaining = k - k_ * (block_k * split_k)

        a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)

        acc = tl.dot(a, b, acc, out_dtype=tl.float32)

        a_ptrs += block_k * split_k * stride_ak
        b_ptrs += block_k * split_k * stride_bk

    acc.to(tl.float16)

    offs_m = pid_m*block_m + tl.arange(0, block_m)
    offs_n = pid_n*block_n + tl.arange(0, block_n)
    
    c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
    mask = (offs_m < m)[:, None] & (offs_n < n)[None, :]
    
    tl.atomic_add(c_ptrs, acc, mask=mask)

def gemm_split_k(a, b):

    m, k = a.shape
    _, n = b.shape
    
    block_m = 64
    block_n = 64
    block_k = 512
    num_stages = 3
    num_warps = 8
    split_k = 4
    group_m = 8

    total_blocks_m = triton.cdiv(m, block_m)
    total_blocks_n = triton.cdiv(n, block_n)
    total_programs_mn = total_blocks_m * total_blocks_n
    total_programs_k = split_k
    
    grid = (total_programs_mn, total_programs_k)
    
    c = torch.zeros((m, n), device=a.device, dtype=torch.float16)
    k = gemm_split_k_kernel[grid](a, b, c,
                              a.stride(0), a.stride(1),
                              b.stride(0), b.stride(1),
                              c.stride(0), c.stride(1),
                              m, n, k,
                              block_m, block_n, block_k,
                              split_k, group_m, num_stages=num_stages, num_warps=num_warps)

    return c


if __name__ == '__main__':
    
    torch.cuda.manual_seed(0)
    
    m = 16
    k = 8192
    n = 8192

    a = torch.randn((m, k), device="cuda", dtype=torch.float16)
    b = torch.randn((k, n), device="cuda", dtype=torch.float16)
    a = a.to(torch.float8_e4m3fn)
    
    # pre-transpose b for efficiency.
    b = b.T
    b = b.to(torch.float8_e4m3fn)

    triton_output = gemm_split_k(a, b)
    torch_output = torch.matmul(a.to(torch.float16), b.to(torch.float16))
    print(f"triton_output_with_fp8_inputs={triton_output}")
    print(f"torch_output={torch_output}")
   
Output:

triton_output_with_fp8_inputs=tensor([[  -4.2812,   31.1094,  -75.3750,  ...,  -32.0000,   57.0625,
          -37.2812],
        [  63.8125,  -54.1250,  -37.6875,  ...,   54.9375,   -5.5312,
           24.8906],
        [  10.1250,  117.5000,    9.4688,  ...,  -15.3906,   89.5000,
            8.0781],
        ...,
        [  17.4688,   58.4375, -118.2500,  ...,  -53.0625,  143.2500,
          -62.5000],
        [  61.7188,  101.3125,   54.8750,  ...,  100.0000,   -2.4785,
          -69.1250],
        [ -84.3750,  -44.5312,  -86.8750,  ...,  -57.8750,  -95.7500,
           71.3125]], device='cuda:0', dtype=torch.float16)
torch_output=tensor([[  -4.2656,   31.2656,  -75.2500,  ...,  -32.0312,   57.0938,
          -37.3438],
        [  63.7188,  -54.1875,  -37.6875,  ...,   55.0625,   -5.5078,
           24.8750],
        [  10.1250,  117.3125,    9.5469,  ...,  -15.3125,   89.4375,
            8.1172],
        ...,
        [  17.7188,   58.4688, -118.2500,  ...,  -53.1562,  143.2500,
          -62.5000],
        [  61.7812,  101.3750,   54.9062,  ...,  100.0000,   -2.4609,
          -69.1250],
        [ -84.5000,  -44.5312,  -86.8125,  ...,  -58.0000,  -95.9375,
           71.2500]], device='cuda:0', dtype=torch.float16)

@mgoin
Author

mgoin commented May 2, 2024

Using your script I get the same correct result on triton==2.3.0, so it seems building Triton from source is not required. Hopefully it is just an issue with input strides/transposing; I'm looking into this.
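
For reference, the kernel takes explicit stride arguments, so a transposed view and a contiguous tensor differ only in the strides that get passed in (a quick illustration, not part of the repro):

import torch

w = torch.randn((8192, 8192), dtype=torch.float16, device="cuda")
print(w.stride())                    # (8192, 1): contiguous, row-major
print(w.t().stride())                # (1, 8192): transposed view, no data copied
print(w.t().contiguous().stride())   # (8192, 1): row-major again after materializing the copy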

@mgoin
Author

mgoin commented May 2, 2024

It seems the issue was just the scaling. If I replace my scaling function to_float8(), which generates per-tensor scales (needed for torch._scaled_mm and generally for better accuracy), with a plain .to(torch.float8_e4m3fn) cast like in your script, then I get proper output. Thanks for the help in debugging this.
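
For clarity, here is a minimal sketch of the difference between the two conversions; the scaled tensor carries values up to the e4m3 maximum (~448), which is presumably what overflows the kernel's fp16 output accumulation:

import torch

x = torch.randn((1, 8192), dtype=torch.float16, device="cuda")
finfo = torch.finfo(torch.float8_e4m3fn)

# Plain cast: values keep their original magnitude, just rounded to e4m3.
x_cast = x.to(torch.float8_e4m3fn)

# Per-tensor scaling (what to_float8 does): stretch so the abs-max lands at ~448,
# cast, and keep the reciprocal scale to undo it later.
scale = finfo.max / x.abs().max().clamp(min=1e-12)
x_scaled = (x * scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)

print(x_cast.to(torch.float16).abs().max().item())    # a few, same range as the fp16 input
print(x_scaled.to(torch.float16).abs().max().item())  # ~448, the e4m3 maximum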

It would be great if you could add support for per-tensor scaling, since it is common and is what we are supporting in vLLM.

For other folks who may stumble across this, here is my updated output and script.

Output:

Running with M=16, N=8192, K=8192
y_torch: tensor([[  19.3906,  -43.7188,   30.7188,  ...,  -63.3125,   32.9062,
           16.8750],
        [ -24.6406,   16.6562,   23.1875,  ...,   29.7500,   33.2188,
          210.7500],
        [ -59.0000,  -50.8438, -147.7500,  ...,   95.1875,  -36.2188,
         -110.2500],
        ...,
        [  89.1875,   85.7500, -121.5625,  ...,   18.7656,    5.5312,
         -128.7500],
        [ -38.2812,  173.8750, -144.7500,  ...,  -53.8750,   14.0078,
           62.0000],
        [ 174.0000,  -58.0000,  -57.7812,  ...,  111.6250,  -75.3750,
         -135.0000]], device='cuda:0', dtype=torch.float16)
y_triton: tensor([[  25.6875,  -47.6562,   19.9844,  ...,  -55.6250,   28.8438,
           17.3594],
        [ -25.2812,   15.5938,   16.7188,  ...,   26.0625,   29.9531,
          209.2500],
        [ -59.9062,  -54.6250, -148.1250,  ...,  102.0000,  -32.7500,
         -114.4375],
        ...,
        [  93.0625,   82.5000, -122.0000,  ...,   24.6250,    6.4375,
         -128.7500],
        [ -42.8125,  168.5000, -140.8750,  ...,  -52.7500,   11.7188,
           63.4375],
        [ 181.7500,  -54.2500,  -52.5625,  ...,  110.1875,  -80.5000,
         -138.6250]], device='cuda:0', dtype=torch.float16)
y_fp16: tensor([[  24.0469,  -46.9375,   27.9219,  ...,  -61.7188,   28.3750,
           19.2344],
        [ -25.3594,   14.8750,   21.5625,  ...,   23.1094,   30.6406,
          208.5000],
        [ -60.1250,  -55.7500, -146.5000,  ...,   97.6250,  -36.5625,
         -110.6250],
        ...,
        [  86.7500,   81.5625, -122.1875,  ...,   18.4062,    5.9336,
         -129.6250],
        [ -41.8750,  166.7500, -142.5000,  ...,  -51.7188,   13.3203,
           61.6875],
        [ 179.1250,  -54.0938,  -56.9688,  ...,  112.3125,  -79.8125,
         -139.5000]], device='cuda:0', dtype=torch.float16)
fp16 vs torch cos_sim: tensor(0.9995, device='cuda:0', dtype=torch.float16)
fp16 vs triton cos_sim: tensor(0.9990, device='cuda:0', dtype=torch.float16)

Script:

import torch
import triton
import triton.language as tl
import time
import os
os.environ['ENABLE_TMA'] = '1'

@triton.jit
def grouped_launch(pid,
                m, n,
                block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr):

    grid_m = tl.cdiv(m, block_m)
    grid_n = tl.cdiv(n, block_n)

    width = group_m * grid_n
    group_id = pid // width
    group_size = tl.minimum(grid_m - group_id * group_m, group_m)

    pid_m = group_id * group_m + (pid % group_size)
    pid_n = (pid % width) // group_size

    return pid_m, pid_n


@triton.jit()
def col_major(pid,
              m, n,
              block_m: tl.constexpr, block_n: tl.constexpr):

    grid_m = tl.cdiv(m, block_m)

    pid_m = pid % grid_m
    pid_n = pid // grid_m

    return pid_m, pid_n


@triton.jit
def gemm_split_k_kernel(a_ptr, b_ptr, c_ptr,
            stride_am, stride_ak,
            stride_bk, stride_bn,
            stride_cm, stride_cn,
            m, n, k,
            block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr,
            split_k: tl.constexpr, group_m: tl.constexpr):

    pid = tl.program_id(0)
    pid_k = tl.program_id(1)
    grid_k = tl.cdiv(k, block_k*split_k)

    pid_m, pid_n = grouped_launch(pid,
                                  m, n,
                                  block_m, block_n, group_m)

    offs_m = pid_m*block_m + tl.arange(0, block_m)
    offs_n = pid_n*block_n + tl.arange(0, block_n)
    offs_k = pid_k*block_k + tl.arange(0, block_k)

    offs_am = tl.max_contiguous(tl.multiple_of(offs_m, block_m), block_m)
    offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_n), block_n)

    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)


    acc = tl.zeros((block_m, block_n), dtype=tl.float32)
    for k_ in range(0, grid_k):

        k_remaining = k - k_ * (block_k * split_k)

        a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)

        acc = tl.dot(a, b, acc, out_dtype=tl.float32)

        a_ptrs += block_k * split_k * stride_ak
        b_ptrs += block_k * split_k * stride_bk

    acc.to(tl.float16)

    offs_m = pid_m*block_m + tl.arange(0, block_m)
    offs_n = pid_n*block_n + tl.arange(0, block_n)

    c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
    mask = (offs_m < m)[:, None] & (offs_n < n)[None, :]

    tl.atomic_add(c_ptrs, acc, mask=mask)

def gemm_split_k(a, b):

    m, k = a.shape
    _, n = b.shape

    block_m = 64
    block_n = 64
    block_k = 512
    num_stages = 3
    num_warps = 8
    split_k = 4
    group_m = 8

    total_blocks_m = triton.cdiv(m, block_m)
    total_blocks_n = triton.cdiv(n, block_n)
    total_programs_mn = total_blocks_m * total_blocks_n
    total_programs_k = split_k

    grid = (total_programs_mn, total_programs_k)

    c = torch.zeros((m, n), device=a.device, dtype=torch.float16)
    k = gemm_split_k_kernel[grid](a, b, c,
                              a.stride(0), a.stride(1),
                              b.stride(0), b.stride(1),
                              c.stride(0), c.stride(1),
                              m, n, k,
                              block_m, block_n, block_k,
                              split_k, group_m, num_stages=num_stages, num_warps=num_warps)

    return c


def to_float8(x, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    scale = finfo.max / x.abs().max().clamp(min=1e-12)
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()

dtype = torch.float16
qdtype = torch.float8_e4m3fn
m = 16
n = 8192
k = 8192

print(f"Running with M={m}, N={n}, K={k}")

# create test inputs
x = torch.randn((m, k), dtype=dtype, device='cuda')
w = torch.randn((k, n), dtype=dtype, device='cuda')

x_fp8_scaled, x_inv_s = to_float8(x, dtype=qdtype)
w_fp8_scaled, w_inv_s = to_float8(w, dtype=qdtype)

x_fp8 = x.to(qdtype)
w_fp8 = w.T.to(qdtype)

y_torch, _ = torch._scaled_mm(x_fp8_scaled, w_fp8_scaled.t(), out_dtype=dtype, scale_a=x_inv_s, scale_b=w_inv_s)
y_triton = gemm_split_k(x_fp8, w_fp8)
y_fp16 = torch.nn.functional.linear(x, w)

print("y_torch:", y_torch)
print("y_triton:", y_triton)
print("y_fp16:", y_fp16)

print("fp16 vs torch cos_sim:", torch.nn.functional.cosine_similarity(y_fp16.reshape(-1), y_torch.reshape(-1), dim=0))
print("fp16 vs triton cos_sim:", torch.nn.functional.cosine_similarity(y_fp16.reshape(-1), y_triton.reshape(-1), dim=0))

@mgoin
Author

mgoin commented May 2, 2024

While I have you here @AdnanHoque, could you possibly share your setup for tuning the Triton kernel? From the blog:

After carefully tuning the other relevant hyperparameters for our kernel such as tile sizes, number of warps and the number of pipeline stages to Llama3-70B problem sizes we were able to produce up to 1.94x speedup over the Triton base implementation.
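
For context, Triton's built-in autotuner is one common way to sweep these knobs; a minimal sketch (not necessarily the setup used for the blog; the config keys just mirror the kernel's constexpr arguments):

import triton

# Candidate configurations over tile sizes, warps, and pipeline stages.
configs = [
    triton.Config({'block_m': bm, 'block_n': bn, 'block_k': bk, 'split_k': 4, 'group_m': 8},
                  num_warps=w, num_stages=s)
    for bm in (64, 128)
    for bn in (64, 128)
    for bk in (256, 512)
    for w in (4, 8)
    for s in (3, 4)
]

# @triton.autotune(configs=configs, key=['m', 'n', 'k'])
# @triton.jit
# def gemm_split_k_kernel(...):   # unchanged kernel body
#     ...
#
# With autotuning, the launch grid must become a callable of the chosen config, e.g.:
# grid = lambda META: (triton.cdiv(m, META['block_m']) * triton.cdiv(n, META['block_n']),
#                      META['split_k'])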

@AdnanHoque
Contributor

Hey, thanks for the suggestion! Try this script for per-tensor scale support. Thanks @cyang49 for getting to this so quickly. We'll push this into main soon.

We haven't done any performance analysis yet, but since the scaling is done in SRAM it shouldn't add too much overhead:

Code:

import torch
import triton
import triton.language as tl
import time
import os
os.environ['ENABLE_TMA'] = '1'

@triton.jit
def grouped_launch(pid,
                m, n,
                block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr):
    
    grid_m = tl.cdiv(m, block_m)
    grid_n = tl.cdiv(n, block_n)

    width = group_m * grid_n
    group_id = pid // width
    group_size = tl.minimum(grid_m - group_id * group_m, group_m)

    pid_m = group_id * group_m + (pid % group_size)
    pid_n = (pid % width) // group_size

    return pid_m, pid_n

@triton.jit
def gemm_split_k_kernel(a_ptr, b_ptr, c_ptr,
            stride_am, stride_ak,
            stride_bk, stride_bn,
            stride_cm, stride_cn,
            scale_a, scale_b,
            m, n, k,
            block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr,
            split_k: tl.constexpr, group_m: tl.constexpr):
    
    pid = tl.program_id(0)
    pid_k = tl.program_id(1)
    grid_k = tl.cdiv(k, block_k*split_k)

    pid_m, pid_n = grouped_launch(pid,
                                  m, n,
                                  block_m, block_n, group_m)

    offs_m = pid_m*block_m + tl.arange(0, block_m)
    offs_n = pid_n*block_n + tl.arange(0, block_n)
    offs_k = pid_k*block_k + tl.arange(0, block_k)

    offs_am = tl.max_contiguous(tl.multiple_of(offs_m, block_m), block_m)
    offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_n), block_n)

    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)


    acc = tl.zeros((block_m, block_n), dtype=tl.float32)
    for k_ in range(0, grid_k):
        
        k_remaining = k - k_ * (block_k * split_k)

        a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)

        acc = tl.dot(a, b, acc, out_dtype=tl.float32)

        a_ptrs += block_k * split_k * stride_ak
        b_ptrs += block_k * split_k * stride_bk
    
    acc = scale_a * scale_b * acc
    acc.to(tl.float16)

    offs_m = pid_m*block_m + tl.arange(0, block_m)
    offs_n = pid_n*block_n + tl.arange(0, block_n)
    
    c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
    mask = (offs_m < m)[:, None] & (offs_n < n)[None, :]
    
    tl.atomic_add(c_ptrs, acc, mask=mask)

def gemm_split_k(a, b, scale_a:float=1.0, scale_b:float=1.0):
    assert a.shape[1] == b.shape[0]
    m, k = a.shape
    _, n = b.shape
    
    block_m = 64
    block_n = 64
    block_k = 512
    num_stages = 3
    num_warps = 8
    split_k = 4
    group_m = 8

    total_blocks_m = triton.cdiv(m, block_m)
    total_blocks_n = triton.cdiv(n, block_n)
    total_programs_mn = total_blocks_m * total_blocks_n
    total_programs_k = split_k
    
    grid = (total_programs_mn, total_programs_k)

    c = torch.zeros((m, n), device=a.device, dtype=torch.float16)
    k = gemm_split_k_kernel[grid](a, b, c,
                              a.stride(0), a.stride(1),
                              b.stride(0), b.stride(1),
                              c.stride(0), c.stride(1),
                              scale_a, scale_b,                              
                              m, n, k,
                              block_m, block_n, block_k,
                              split_k, group_m, num_stages=num_stages, num_warps=num_warps)

    return c


def to_float8(x, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    scale = finfo.max / x.abs().max().clamp(min=1e-12)
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()

dtype = torch.float16
qdtype = torch.float8_e4m3fn

torch.cuda.manual_seed(0)

m = 64
n = 4096
k = 4096

# create test inputs
x = torch.randn((m, k), dtype=dtype, device='cuda')
w = torch.randn((n, k), dtype=dtype, device='cuda')

x_fp8, x_inv_s = to_float8(x, dtype=qdtype)
w_fp8, w_inv_s = to_float8(w, dtype=qdtype)

y_torch, _ = torch._scaled_mm(x_fp8, w_fp8.t(), out_dtype=dtype, scale_a=x_inv_s, scale_b=w_inv_s)
y_triton = gemm_split_k(x_fp8, w_fp8.t(), scale_a=x_inv_s.item(), scale_b=w_inv_s.item())
y_fp16 = torch.nn.functional.linear(x, w)

print("y_torch:", y_torch)
print("y_triton:", y_triton)
print("y_fp16:", y_fp16)

print("fp16 vs torch cos_sim:", torch.nn.functional.cosine_similarity(y_fp16.reshape(-1), y_torch.reshape(-1), dim=0))
print("fp16 vs triton cos_sim:", torch.nn.functional.cosine_similarity(y_fp16.reshape(-1), y_triton.reshape(-1), dim=0))

Output:

y_torch: tensor([[  51.3125,  -48.5312,  -18.3906,  ...,  -42.6562,  -54.5938,
         129.0000],
       [ -34.3125,  -60.8750,   25.5469,  ...,   53.0312,   77.7500,
         -21.8750],
       [  14.8750,   53.6875,  -19.1875,  ...,   -3.6992,   64.8750,
         102.0625],
       ...,
       [ 163.2500,  -10.9375,   33.8438,  ...,  -44.3438,   -1.5117,
          64.6250],
       [ -15.1562,   -2.1172,   14.7812,  ..., -122.5625,  -42.7500,
          29.0469],
       [   1.1836,   55.1875,   68.0000,  ..., -123.8125,   38.1250,
         -48.8750]], device='cuda:0', dtype=torch.float16)
y_triton: tensor([[  51.3125,  -48.5000,  -18.5000,  ...,  -42.6250,  -54.5625,
         128.8750],
       [ -34.3125,  -60.9062,   25.5625,  ...,   53.0312,   77.7500,
         -21.9062],
       [  14.8906,   53.6875,  -19.2031,  ...,   -3.7031,   64.8750,
         102.0625],
       ...,
       [ 163.2500,  -10.9141,   33.8750,  ...,  -44.3125,   -1.4844,
          64.6875],
       [ -15.1719,   -2.1543,   14.7812,  ..., -122.6250,  -42.7500,
          29.0781],
       [   1.2188,   55.2500,   68.0000,  ..., -123.7500,   38.1250,
         -48.8750]], device='cuda:0', dtype=torch.float16)
y_fp16: tensor([[  52.4688,  -52.3438,  -21.5625,  ...,  -41.5000,  -54.0312,
         127.6250],
       [ -36.9375,  -59.6250,   25.1406,  ...,   54.1250,   74.6875,
         -19.4062],
       [  13.5859,   50.5625,  -23.1875,  ...,   -3.5859,   67.1250,
         101.7500],
       ...,
       [ 160.6250,  -10.4297,   37.5000,  ...,  -41.0625,   -1.8691,
          69.6250],
       [ -19.2344,   -0.9331,   15.5234,  ..., -123.1250,  -40.7812,
          31.0625],
       [  -0.4199,   55.9062,   67.0625,  ..., -123.2500,   36.1875,
         -49.5625]], device='cuda:0', dtype=torch.float16)
fp16 vs torch cos_sim: tensor(0.9995, device='cuda:0', dtype=torch.float16)
fp16 vs triton cos_sim: tensor(0.9995, device='cuda:0', dtype=torch.float16)

@xTayEx

xTayEx commented Aug 26, 2024

Hi @AdnanHoque, I ran some benchmarks based on the script you posted in #21 (comment), but I can't reproduce the numbers mentioned in the blog https://pytorch.org/blog/accelerating-llama3/. The Triton kernel is slower than torch._scaled_mm. I ran the benchmark on an H800 machine. For m=1, n=k=8192, the Triton kernel takes 0.00017833 s, while torch._scaled_mm only takes 3.337860107421875e-05 s. Below is my benchmark script. Is there anything wrong with it?

import torch
import triton
import triton.language as tl
import time
import os
from time import time
os.environ['ENABLE_TMA'] = '1'

@triton.jit
def grouped_launch(pid,
                m, n,
                block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr):
    
    grid_m = tl.cdiv(m, block_m)
    grid_n = tl.cdiv(n, block_n)

    width = group_m * grid_n
    group_id = pid // width
    group_size = tl.minimum(grid_m - group_id * group_m, group_m)

    pid_m = group_id * group_m + (pid % group_size)
    pid_n = (pid % width) // group_size

    return pid_m, pid_n

@triton.jit
def gemm_split_k_kernel(a_ptr, b_ptr, c_ptr,
            stride_am, stride_ak,
            stride_bk, stride_bn,
            stride_cm, stride_cn,
            scale_a, scale_b,
            m, n, k,
            block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr,
            split_k: tl.constexpr, group_m: tl.constexpr):
    
    pid = tl.program_id(0)
    pid_k = tl.program_id(1)
    grid_k = tl.cdiv(k, block_k*split_k)

    pid_m, pid_n = grouped_launch(pid,
                                  m, n,
                                  block_m, block_n, group_m)

    offs_m = pid_m*block_m + tl.arange(0, block_m)
    offs_n = pid_n*block_n + tl.arange(0, block_n)
    offs_k = pid_k*block_k + tl.arange(0, block_k)

    offs_am = tl.max_contiguous(tl.multiple_of(offs_m, block_m), block_m)
    offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_n), block_n)

    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)


    acc = tl.zeros((block_m, block_n), dtype=tl.float32)
    for k_ in range(0, grid_k):
        
        k_remaining = k - k_ * (block_k * split_k)

        a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)

        acc = tl.dot(a, b, acc, out_dtype=tl.float32)

        a_ptrs += block_k * split_k * stride_ak
        b_ptrs += block_k * split_k * stride_bk
    
    acc = scale_a * scale_b * acc
    acc.to(tl.float16)

    offs_m = pid_m*block_m + tl.arange(0, block_m)
    offs_n = pid_n*block_n + tl.arange(0, block_n)
    
    c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
    mask = (offs_m < m)[:, None] & (offs_n < n)[None, :]
    
    tl.atomic_add(c_ptrs, acc, mask=mask)

def gemm_split_k(a, b, scale_a:float=1.0, scale_b:float=1.0):
    assert a.shape[1] == b.shape[0]
    m, k = a.shape
    _, n = b.shape
    
    block_m = 64
    block_n = 64
    block_k = 512
    num_stages = 3
    num_warps = 8
    split_k = 4
    group_m = 8

    total_blocks_m = triton.cdiv(m, block_m)
    total_blocks_n = triton.cdiv(n, block_n)
    total_programs_mn = total_blocks_m * total_blocks_n
    total_programs_k = split_k
    
    grid = (total_programs_mn, total_programs_k)

    c = torch.zeros((m, n), device=a.device, dtype=torch.float16)
    k = gemm_split_k_kernel[grid](a, b, c,
                              a.stride(0), a.stride(1),
                              b.stride(0), b.stride(1),
                              c.stride(0), c.stride(1),
                              scale_a, scale_b,                              
                              m, n, k,
                              block_m, block_n, block_k,
                              split_k, group_m, num_stages=num_stages, num_warps=num_warps)

    return c


def to_float8(x, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    scale = finfo.max / x.abs().max().clamp(min=1e-12)
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()

dtype = torch.float16
qdtype = torch.float8_e4m3fn

torch.cuda.manual_seed(0)

m = 1
n = 8192
k = 8192

# create test inputs
x = torch.randn((m, k), dtype=dtype, device='cuda')
w = torch.randn((n, k), dtype=dtype, device='cuda')

x_fp8, x_inv_s = to_float8(x, dtype=qdtype)
w_fp8, w_inv_s = to_float8(w, dtype=qdtype)

for _ in range(10):
    y_triton = gemm_split_k(x_fp8, w_fp8.t(), scale_a=x_inv_s.item(), scale_b=w_inv_s.item())
    y_torch, _ = torch._scaled_mm(x_fp8, w_fp8.t(), out_dtype=dtype, scale_a=x_inv_s, scale_b=w_inv_s)

torch_start_time = time()
y_torch, _ = torch._scaled_mm(x_fp8, w_fp8.t(), out_dtype=dtype, scale_a=x_inv_s, scale_b=w_inv_s)
torch_end_time = time()
print(f"torch duration: {torch_end_time - torch_start_time}")

triton_start_time = time()
y_triton = gemm_split_k(x_fp8, w_fp8.t(), scale_a=x_inv_s.item(), scale_b=w_inv_s.item())
triton_end_time = time()
print(f"triton duration: {triton_end_time - triton_start_time}")

y_fp16 = torch.nn.functional.linear(x, w)

#print("y_torch:", y_torch)
#print("y_triton:", y_triton)
#print("y_fp16:", y_fp16)

#print("fp16 vs torch cos_sim:", torch.nn.functional.cosine_similarity(y_fp16.reshape(-1), y_torch.reshape(-1), dim=0))
#print("fp16 vs triton cos_sim:", torch.nn.functional.cosine_similarity(y_fp16.reshape(-1), y_triton.reshape(-1), dim=0))
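
One caveat with this kind of timing: CUDA launches are asynchronous, so wall-clock timestamps taken without torch.cuda.synchronize() can end up mostly measuring launch overhead. A possible synchronized variant of the timing block above (same tensors and helpers, averaging over several iterations):

iters = 100

torch.cuda.synchronize()
start = time()
for _ in range(iters):
    y_torch, _ = torch._scaled_mm(x_fp8, w_fp8.t(), out_dtype=dtype, scale_a=x_inv_s, scale_b=w_inv_s)
torch.cuda.synchronize()
print(f"torch duration: {(time() - start) / iters}")

torch.cuda.synchronize()
start = time()
for _ in range(iters):
    y_triton = gemm_split_k(x_fp8, w_fp8.t(), scale_a=x_inv_s.item(), scale_b=w_inv_s.item())
torch.cuda.synchronize()
print(f"triton duration: {(time() - start) / iters}")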
