Skip to content

Commit

Permalink
[CUTLASS] Add FP8 gemm kernels (#17408)
Browse files Browse the repository at this point in the history
This PR introduces the sm90a FP8 kernels from CUTLASS. These kernels
are helpful in the cases of small `M`, where cuBLAS has unoptimized
performance.
  • Loading branch information
MasterJH5574 authored Sep 25, 2024
1 parent 5648a8e commit 4e70e4a
Show file tree
Hide file tree
Showing 5 changed files with 349 additions and 15 deletions.
1 change: 1 addition & 0 deletions cmake/modules/contrib/CUTLASS.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ if(USE_CUDA AND USE_CUTLASS)
if (CMAKE_CUDA_ARCHITECTURES MATCHES "90a")
list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_group_gemm.cu)
list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_group_gemm.cu)
list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_gemm.cu)
endif()
if(TVM_CUTLASS_RUNTIME_SRCS)
add_library(tvm_cutlass_objs OBJECT ${TVM_CUTLASS_RUNTIME_SRCS})
Expand Down
6 changes: 4 additions & 2 deletions src/runtime/contrib/cublas/cublas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,13 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
&bias->data, sizeof(float*)));
}

if (scaleA != nullptr && scaleB != nullptr) {
if (scaleA != nullptr) {
auto scaleA_data = static_cast<char*>(scaleA->data) + scaleA->byte_offset;
auto scaleB_data = static_cast<char*>(scaleB->data) + scaleB->byte_offset;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
&scaleA_data, sizeof(float*)));
}
if (scaleB != nullptr) {
auto scaleB_data = static_cast<char*>(scaleB->data) + scaleB->byte_offset;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&scaleB_data, sizeof(float*)));
}
Expand Down
95 changes: 95 additions & 0 deletions src/runtime/contrib/cutlass/fp8_gemm.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <cuda_fp16.h>
#include <float.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

#include "../cublas/cublas_utils.h"
#include "gemm_runner.cuh"

#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)

struct KernelTraitsM64 {
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
using TileShape = Shape<_64, _64, _128>;
using ClusterShape = Shape<_1, _8, _1>;
};

namespace tvm {
namespace runtime {

template <typename ElementA, typename ElementB, typename ElementC>
void tvm_cutlass_fp8_gemm(NDArray x, NDArray weight, NDArray workspace, NDArray alpha,
NDArray out) {
// Workspace is used for storing device-side gemm arguments and cutlass internal workspace.
// Recommened size is 4MB.
auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
ICHECK(func != nullptr);
CHECK_GE(x->ndim, 2);
CHECK_EQ(weight->ndim, 2);
CHECK_EQ(workspace->ndim, 1);
CHECK_GE(out->ndim, 2);
CHECK_EQ(alpha->dtype.code, kDLFloat);
CHECK_EQ(alpha->dtype.bits, 32);
CHECK_EQ(alpha->ndim, 1);
CHECK_EQ(alpha->shape[0], 1);
int64_t m = 1;
for (int i = 0; i < x->ndim - 1; ++i) {
m *= x->shape[i];
}
int64_t n = weight->shape[0];
CHECK_EQ(x->shape[x->ndim - 1], weight->shape[1]) << "Only col-major weight is supported now.";
int64_t k = x->shape[x->ndim - 1];
const float* beta = nullptr;
cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
if (m <= 64) {
cutlass_gemm<KernelTraitsM64>(
static_cast<ElementA*>(x->data), static_cast<ElementB*>(weight->data),
static_cast<uint8_t*>(workspace->data), workspace->shape[0], m, n, k,
static_cast<float*>(alpha->data), beta, static_cast<ElementC*>(out->data), stream);
} else {
tvm::contrib::CuBlasLtThreadEntry* cublas_entry =
tvm::contrib::CuBlasLtThreadEntry::ThreadLocal();
tvm::contrib::CallCublasLt(cublas_entry->handle, stream, cublas_entry->matmul_pref_desc,
x.operator->(), weight.operator->(), nullptr, alpha.operator->(),
nullptr, out.operator->(), /*transa=*/false, /*transb=*/true,
cublas_entry->workspace_ptr, cublas_entry->workspace_size,
CUBLASLT_EPILOGUE_DEFAULT, std::nullopt);
}
}

TVM_REGISTER_GLOBAL("cutlass.gemm_e5m2_e5m2_fp16")
.set_body_typed(
tvm_cutlass_fp8_gemm<cutlass::float_e5m2_t, cutlass::float_e5m2_t, cutlass::half_t>);

TVM_REGISTER_GLOBAL("cutlass.gemm_e5m2_e4m3_fp16")
.set_body_typed(
tvm_cutlass_fp8_gemm<cutlass::float_e5m2_t, cutlass::float_e4m3_t, cutlass::half_t>);

TVM_REGISTER_GLOBAL("cutlass.gemm_e4m3_e4m3_fp16")
.set_body_typed(
tvm_cutlass_fp8_gemm<cutlass::float_e4m3_t, cutlass::float_e4m3_t, cutlass::half_t>);

} // namespace runtime
} // namespace tvm

#endif // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED
155 changes: 155 additions & 0 deletions src/runtime/contrib/cutlass/gemm_runner.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <fstream>
#include <iostream>
#include <sstream>
#include <variant>
#include <vector>

#include "../../cuda/cuda_common.h"

// clang-format off
#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
// clang-format on

#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
CHECK(error == cutlass::Status::kSuccess) \
<< "Got cutlass error: " << cutlassGetStatusString(error); \
}

using namespace cute;
using ProblemShape = Shape<int, int, int>; // <M, N, K>

template <typename KernelTraits, typename ElementA, typename ElementB, typename ElementC,
typename LayoutA = cutlass::layout::RowMajor,
typename LayoutB = cutlass::layout::ColumnMajor,
typename LayoutC = cutlass::layout::RowMajor>
struct CutlassGemmRunner {
static constexpr int AlignmentA =
128 / cutlass::sizeof_bits<ElementA>::value; // Alignment of A matrix in units of elements
// (up to 16 bytes)

static constexpr int AlignmentB =
128 / cutlass::sizeof_bits<ElementB>::value; // Alignment of B matrix in units of elements
// (up to 16 bytes)

static constexpr int AlignmentC =
128 / cutlass::sizeof_bits<ElementC>::value; // Alignment of C matrix in units of elements
// (up to 16 bytes)

// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ScaleType = std::variant<ElementAccumulator, const ElementAccumulator*>;
using ArchTag =
cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = typename KernelTraits::TileShape;
using ClusterShape = typename KernelTraits::ClusterShape;
using StageCountType =
cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = typename KernelTraits::KernelSchedule; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; // Epilogue to launch

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass, TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, AlignmentC, ElementC, LayoutC, AlignmentC, EpilogueSchedule>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB,
ElementAccumulator, TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;

using GemmKernel =
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;

void run_gemm(const ElementA* ptr_A, const ElementB* ptr_B, const ElementC* ptr_C,
ElementC* ptr_D, ProblemShape* problem_size, StrideA* stride_A, StrideB* stride_B,
StrideC* stride_C, StrideD* stride_D, uint8_t* workspace, int64_t workspace_size,
ScaleType alpha, ScaleType beta, cudaStream_t stream) {
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGemm,
*problem_size,
{ptr_A, *stride_A, ptr_B, *stride_B},
{{}, ptr_C, *stride_C, ptr_D, *stride_D},
// {epilogue_params, ptr_C, *stride_C, ptr_D, *stride_D},
hw_info};

ICHECK(alpha.index() == beta.index()) << "alpha and beta must have the same type";
if (std::holds_alternative<ElementAccumulator>(alpha)) {
arguments.epilogue.thread.alpha = std::get<ElementAccumulator>(alpha);
arguments.epilogue.thread.beta = std::get<ElementAccumulator>(beta);
} else if (std::holds_alternative<const ElementAccumulator*>(alpha)) {
arguments.epilogue.thread.alpha_ptr = std::get<const ElementAccumulator*>(alpha);
arguments.epilogue.thread.beta_ptr = std::get<const ElementAccumulator*>(beta);
} else {
LOG(FATAL) << "Unsupported alpha and beta type";
throw;
}

Gemm gemm_op;
CUTLASS_CHECK(gemm_op.can_implement(arguments));
CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
CUTLASS_CHECK(gemm_op.run(stream));
}
};

template <typename KernelTraits, typename ElementA, typename ElementB, typename ElementC>
void cutlass_gemm(ElementA* x, ElementB* weight, uint8_t* workspace, int64_t workspace_size,
int64_t m, int64_t n, int64_t k, std::variant<float, const float*> alpha,
std::variant<float, const float*> beta, ElementC* out, cudaStream_t stream) {
using Runner = CutlassGemmRunner<KernelTraits, ElementA, ElementB, ElementC>;
using StrideA = typename Runner::StrideA;
using StrideB = typename Runner::StrideB;
using StrideC = typename Runner::StrideC;

Runner runner;
StrideA stride_A = cute::make_stride(k, Int<1>{}, int64_t{0});
StrideB stride_B = cute::make_stride(k, Int<1>{}, int64_t{0});
StrideC stride_D = cute::make_stride(n, Int<1>{}, int64_t{0});
ProblemShape problem_size{static_cast<int>(m), static_cast<int>(n), static_cast<int>(k)};
runner.run_gemm(x, weight, out, out, &problem_size, &stride_A, &stride_B, &stride_D, &stride_D,
workspace, workspace_size, alpha, beta, stream);
}
Loading

0 comments on commit 4e70e4a

Please sign in to comment.