Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Buddy GPU] Buddy GPU Gemm #395

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions Pipelines/BufferizeOpt.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#include "Pipelines/BufferizeOpt.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "Utils/GemmCodegenUtils.h"
#include "Utils/PipelineUtils.h"

using namespace mlir;

/// Builds the bufferization pipeline: eliminates redundant tensor.empty ops,
/// runs one-shot bufferization across function boundaries, folds memref alias
/// ops, and finishes with the shared clean-up passes.
void mlir::buddy::createBufferizeOptPipeline(
    OpPassManager &pm, const BuddyBufferizeOptOptions &options) {
  mlir::buddy::invokeOpPassPipelineBuilder(
      [&](OpPassManager &nestedPM) {
        // OneShotBufferize leaves funcOp arguments as tensors by default;
        // function-boundary bufferization must be enabled explicitly.
        bufferization::OneShotBufferizationOptions oneShotOptions;
        oneShotOptions.bufferizeFunctionBoundaries = true;
        // oneShotOptions.allowReturnAllocsFromLoops
        nestedPM.addNestedPass<func::FuncOp>(
            bufferization::createEmptyTensorEliminationPass());
        nestedPM.addPass(
            bufferization::createOneShotBufferizePass(oneShotOptions));
        nestedPM.addNestedPass<func::FuncOp>(
            memref::createFoldMemRefAliasOpsPass());
        addCleanUpPassPipeline(nestedPM);
      },
      pm);
}

void mlir::buddy::registerBufferizeOptPassPipeline() {
PassPipelineRegistration<BuddyBufferizeOptOptions>(
"bufferize-opt",
"bufferize opt lowering tensor to memref",
createBufferizeOptPipeline
);
}
15 changes: 15 additions & 0 deletions Pipelines/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# GPU-specific pipeline sources build into their own library (BuddyGPUPipelines).
add_subdirectory(GPU)

# Top-level pipeline library: tensor-level opt, bufferization, and
# memref-level opt pipelines, linked against the GPU pipelines and the
# shared codegen/pipeline utility libraries.
add_mlir_library(BuddyPipelines
  LinalgTensorOpt.cpp
  BufferizeOpt.cpp
  LinalgMemrefOpt.cpp

  LINK_LIBS PUBLIC
  MLIRIR
  BuddyGPUPipelines
  BuddyTransformPasses
  BuddyGPUPasses
  BuddyGemmCodegenUtils
  BuddyPipelineUtils
)
11 changes: 11 additions & 0 deletions Pipelines/GPU/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# GPU gemm codegen transform pipeline library. Depends on the transform/PDL
# dialects used to emit transform-dialect IR, plus the project's transform
# passes and gemm codegen utilities.
add_mlir_library(BuddyGPUPipelines
  GemmCodegenTransform.cpp

  LINK_LIBS PUBLIC
  MLIRIR
  MLIRPDLDialect
  MLIRTransformDialect
  MLIRTransforms
  BuddyTransformPasses
  BuddyGemmCodegenUtils
)
266 changes: 266 additions & 0 deletions Pipelines/GPU/GemmCodegenTransform.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
#include "Transform/Transforms/TransformInsertion.h"
#include "Pipelines/GPU/GemmCodegenTransform.h"
#include "Utils/GemmCodegenUtils.h"
#include "Utils/PipelineUtils.h"

#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/PDL/IR/PDLOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/SmallSet.h"

#include <optional>

using namespace mlir;
using namespace mlir::buddy;

namespace {

/// Inserts a transform-dialect script that annotates the function enclosing
/// each matched linalg matmul with the gemm codegen configuration: M/N/K tile
/// sizes, workgroup (block) x/y/z sizes, and pipeline stage count.
///
/// \param anchor     funcOp anchor attribute the insertion pass matches on.
/// \param prefix     prefix for the match annotations.
/// \param tileConfig gemm tile sizes, expected layout {M, N, K}.
/// \param workGroup  workgroup sizes, expected layout {x, y, z}.
/// \param stages     software pipeline stage count.
void createAddGemmCodegenLoweringConfigTransformImpl(
    OpPassManager &pm, const std::string &anchor, const std::string &prefix,
    ArrayRef<int64_t> tileConfig, ArrayRef<int64_t> workGroup, int64_t stages) {

  // Copy into owned storage: the lambdas below capture by value and may
  // outlive the caller's underlying buffers.
  SmallVector<int64_t> vecTileConfig{tileConfig};
  SmallVector<int64_t> vecWorkGroup{workGroup};

  TransformInsertionConfig config;
  config.funcAnchor = anchor;
  config.matchPrefix = prefix;
  // The transform only takes effect on linalg matmul-like ops.
  config.opFilter = [=](Operation *op) { return isLinalgMatmul(op); };

  // pdlV is a transform-dialect handle to the matched op.
  config.transformBuilder = [=](ImplicitLocOpBuilder &b, Operation *op,
                                Value pdlV) {
    // Helper: materialize an int64 as a transform-dialect i64 parameter.
    // Replaces seven hand-rolled ParamConstantOp creations.
    auto makeI64Param = [&](int64_t value) -> Value {
      return b.create<transform::ParamConstantOp>(
          transform::ParamType::get(b.getContext(),
                                    mlir::IntegerType::get(b.getContext(), 64)),
          b.getI64IntegerAttr(value));
    };

    // The annotations go on the enclosing func.func, not the matmul itself.
    auto func = b.create<transform::GetParentOp>(
        pdlV.getType(), pdlV,
        /* isolated_from_above */ true,
        /* allow_empty_results */ false,
        /* op_name */ b.getStringAttr(func::FuncOp::getOperationName()),
        /* deduplicate */ false,
        /* nth_parent */ 1);

    // Tile sizes for the gemm M/N/K dimensions.
    Value tileConfigValue_M = makeI64Param(vecTileConfig[0]);
    Value tileConfigValue_N = makeI64Param(vecTileConfig[1]);
    Value tileConfigValue_K = makeI64Param(vecTileConfig[2]);
    // Workgroup (thread block) sizes along x/y/z.
    Value workGroupValue_X = makeI64Param(vecWorkGroup[0]);
    Value workGroupValue_Y = makeI64Param(vecWorkGroup[1]);
    Value workGroupValue_Z = makeI64Param(vecWorkGroup[2]);
    // Software pipeline stage count.
    Value stagesValue = makeI64Param(stages);

    b.create<transform::AnnotateOp>(func, getGemmTileMConfigAttrName(),
                                    tileConfigValue_M);
    b.create<transform::AnnotateOp>(func, getGemmTileNConfigAttrName(),
                                    tileConfigValue_N);
    b.create<transform::AnnotateOp>(func, getGemmTileKConfigAttrName(),
                                    tileConfigValue_K);
    b.create<transform::AnnotateOp>(func, getGemmBlockXSizeAttrName(),
                                    workGroupValue_X);
    b.create<transform::AnnotateOp>(func, getGemmBlockYSizeAttrName(),
                                    workGroupValue_Y);
    b.create<transform::AnnotateOp>(func, getGemmBlockZSizeAttrName(),
                                    workGroupValue_Z);
    b.create<transform::AnnotateOp>(func, getGemmPipelineStageAttrName(),
                                    stagesValue);
  };

  pm.addPass(createGenericTransformInsertionPass(config));
}

} // namespace

/// Public entry point: forwards the option fields to the config-annotation
/// transform builder via the shared pipeline-builder hook.
void mlir::buddy::createGemmTileConfigInsertTransform(
    OpPassManager &pm, const GPUGemmCodegenConfigOptions &options) {
  invokeOpPassPipelineBuilder(createAddGemmCodegenLoweringConfigTransformImpl,
                              pm, options.funcAnchor, options.annotatePrefix,
                              options.tileConfig, options.workGroup,
                              options.stages);
}

namespace {

// TODO: Epilogue
// Builds the transform-dialect script that tiles a matched matmul for GPU:
// the parallel dims (M/N, plus batch for batched matmul) are tiled with an
// scf.forall mapped to GPU block ids, the linalg.fill producer is fused into
// that loop, and the reduction dim K is then tiled with an scf.for.
void createGemmTileTransformImpl(OpPassManager &pm,
                                const std::string &anchor,
                                const std::string &prefix) {
  TransformInsertionConfig config;
  config.funcAnchor = anchor;
  config.matchPrefix = prefix;
  // Only matmul-like linalg ops are rewritten.
  config.opFilter = [=](Operation *op){
    if (isLinalgMatmul(op)) {
      return true;
    }
    return false;
  };
  // pdlV is a transform-dialect handle to the matched matmul op.
  config.transformBuilder = [=](ImplicitLocOpBuilder &b, Operation *op, Value pdlV) {
    func::FuncOp funcOp = op->getParentOfType<func::FuncOp>();
    linalg::LinalgOp linalgOp = cast<linalg::LinalgOp>(op);

    // Codegen parameters attached to the func by the config-insertion pass;
    // .value() assumes those attributes are present on funcOp.
    SmallVector<int64_t, 3> tileConfig = getGemmTileSize(funcOp).value();
    SmallVector<int64_t, 3> workGroup = getGemmBlockSize(funcOp).value();
    // NOTE(review): `stages` (and `workGroup`) are read but unused below —
    // presumably reserved for later pipelining/thread mapping; confirm
    // before removing.
    int64_t stages = getGemmPipelineStages(funcOp).value();

    // Epilogue fusion is not implemented yet (see TODO above).
    bool hasEpilogue = false;

    // Handle to the func.func enclosing the matched matmul.
    auto func = b.create<transform::GetParentOp>(
        pdlV.getType(), pdlV,
        /* isolated_from_above */ false,
        /* allow_empty_results */ false,
        /* op_name */ b.getStringAttr(func::FuncOp::getOperationName()),
        /* deduplicate */ false,
        /* nth_parent */ 1);

    // Match the linalg.fill (output initialization) inside the func so it
    // can be fused into the tiled loop nest below.
    auto linalgFillType = transform::OperationType::get(
        b.getContext(), linalg::FillOp::getOperationName()
    );
    auto linalgFillOp = b.create<transform::MatchOp>(
        /* resultTypes */ linalgFillType,
        /* target */ func,
        /* opNames */ linalg::FillOp::getOperationName()
    );

    // Map loop dims to GPU block ids (innermost dim -> blockIdx.x).
    SmallVector<int64_t> mappingIdx;
    bool isBMM = linalgOp.getNumParallelLoops() == 3;
    if (isBMM) {
      // 2 -> blockIdx.z 1 -> blockIdx.y 0->blockIdx.x
      mappingIdx = {2, 1, 0};
    } else {
      // 1 -> blockIdx.y 0 -> blockIdx.x
      mappingIdx = {1, 0};
    }

    // get GPU BlockIdx mapping
    auto mapping = llvm::to_vector(llvm::map_range(
        mappingIdx,
        [](int64_t i){return static_cast<gpu::MappingId>(i);
    }));
    auto mappingAttrs = llvm::to_vector(llvm::map_range(
        mapping,
        [&](gpu::MappingId dim) -> Attribute {
          return gpu::GPUBlockMappingAttr::get(b.getContext(), dim);
        }));

    // Tile sizes for the parallel dims; the batch dim of a BMM is tiled by
    // 1 so each block handles a single batch slice.
    SmallVector<int64_t> parallelTileSizes;
    if (isBMM) {
      parallelTileSizes = {1, tileConfig[0], tileConfig[1]};
    } else {
      parallelTileSizes = {tileConfig[0], tileConfig[1]};
    }

    // tile DimM and DimN and each tile dispathes to block
    Value tiledMatmulOp;
    if (hasEpilogue) {
      // TODO
    } else {
      transform::TileUsingForallOp tiledResultOp =
          b.create<transform::TileUsingForallOp>(
              /* target */ pdlV,
              /* staticTileSizes */ parallelTileSizes,
              /* ctor tag */ transform::TileSizesSpec(),
              /* mapping */ b.getArrayAttr(mappingAttrs)
          );

      // Pull the fill into the forall so each block initializes its own tile.
      if (linalgFillOp) {
        b.create<transform::FuseIntoContainingOp>(
            /* producerOp */ linalgFillOp,
            /* containingOp */ tiledResultOp.getForallOp()
        );
      }
      tiledMatmulOp = tiledResultOp.getTiledOp();
    }

    // only tile DimK of the matmul which is dispatched to each block
    // (zeros mean "do not tile this dim").
    SmallVector<int64_t> reduceTileSize;
    if (isBMM) {
      reduceTileSize = {0, 0, 0, tileConfig[2]};
    } else {
      reduceTileSize = {0, 0, tileConfig[2]};
    }

    auto tiledKMatmulOp =
        b.create<transform::TileUsingForOp>(
            /* target */ tiledMatmulOp,
            /* staticTileSizes */ reduceTileSize
        );

    // for k in K steps tileConfig[2]
    auto forLoops = tiledKMatmulOp.getLoops();
    // tiledmatmul computes at (BM, BN, tileConfig[2])
    auto kMatmulOp = tiledKMatmulOp.getTiledLinalgOp();

    // Mark the K main loop (or the op itself when no loop was created) so
    // later passes can locate the gemm main loop.
    if (!forLoops.empty()) {
      b.create<transform::AnnotateOp>(forLoops[0], getMatmulKMainLoopMarker(),
                                      Value());
    } else {
      b.create<transform::AnnotateOp>(kMatmulOp, getMatmulKMainLoopMarker(),
                                      Value());
    }

    // Value mmaLevel = b.create<transform::ParamConstantOp>(
    //     /* type */ transform::ParamType::get(b.getContext(), b.getStringAttr()),
    //     /* value */ b.getStringAttr("Threadblock")
    // );

    // b.create<transform::AnnotateOp>(kMatmulOp, getLinalgMMALevelAttrName(),
    //                                 mmaLevel);
    // Mark the inner matmul as an MMA pattern for later vectorization.
    b.create<transform::AnnotateOp>(kMatmulOp, getMMAPatternAttrName(),
                                    Value());
  };
  pm.addPass(createGenericTransformInsertionPass(config));
}
} // namespace

void mlir::buddy::createGemmTileTransform(OpPassManager &pm,
const GPUGemmGeneralOptions &options) {
invokeOpPassPipelineBuilder(
createGemmTileTransformImpl, pm,
options.funcAnchor, options.annotatePrefix);
}
44 changes: 44 additions & 0 deletions Pipelines/LinalgMemrefOpt.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#include "Pipelines/LinalgMemrefOpt.h"
#include "GPU/Transforms/GPUDistributeToWarp.h"
#include "GPU/Transforms/RemoveReduntantLoops.h"
#include "GPU/Transforms/TensorCoreVectorization.h"
#include "Linalg/Transforms/LinalgPromotion.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "Utils/PipelineUtils.h"
#include "mlir/Transforms/Passes.h"
#include <string>

using namespace mlir;

namespace {

/// Memref-level gemm pipeline: promote operands to workgroup memory, clean
/// up, distribute to warps, drop redundant loops, then vectorize for tensor
/// cores. All passes run nested on each func.func.
void addGemmLinalgMemrefOptPipeline(OpPassManager &pm) {
  // TODO : use funcAnchor to nest the specific matmul func
  auto addFuncPass = [&](auto pass) {
    pm.addNestedPass<func::FuncOp>(std::move(pass));
  };
  addFuncPass(createLinalgPromotionPass());
  addFuncPass(createCanonicalizerPass());
  addFuncPass(createCSEPass());
  addFuncPass(createCanonicalizerPass());
  addFuncPass(createGPUDistributeToWarpPass());
  addFuncPass(createRemoveReduntantLoops());
  addFuncPass(createTensorCoreVectorizationPass());
}

/// Pipeline-builder impl dispatched by invokeOpPassPipelineBuilder.
/// \param target backend target name; pass by const reference instead of by
///               value to avoid copying the string on every invocation.
void createLinalgMemrefOptPipelineImpl(OpPassManager &pm,
                                       const std::string &target) {
  // `target` is currently unused; kept for interface symmetry with the other
  // pipeline impls. TODO: dispatch on target once more backends exist.
  (void)target;
  addGemmLinalgMemrefOptPipeline(pm);
}

}

/// Public entry point: delegates to the impl through the shared
/// pipeline-builder hook, forwarding only the target option.
void mlir::buddy::createLinalgMemrefOptPipeline(
    OpPassManager &pm, const LinalgMemrefOptPipelineOptions &options) {
  invokeOpPassPipelineBuilder(createLinalgMemrefOptPipelineImpl, pm,
                              options.target);
}

void mlir::buddy::registerLinalgMemrefOptPipeline() {
PassPipelineRegistration<LinalgMemrefOptPipelineOptions>(
"linalg-memref-opt", "Linalg Opt Pipeline with Memref",
createLinalgMemrefOptPipeline);
}
Loading