Support brevitas custom op (#2320)
jinchen62 authored and dan-garvey committed Jul 19, 2023
1 parent c9add6b commit 11da3c5
Showing 6 changed files with 314 additions and 2 deletions.
@@ -57,6 +57,9 @@ std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass();
std::unique_ptr<OperationPass<func::FuncOp>>
createFinalizingBackendTypeConversionPass();

std::unique_ptr<OperationPass<func::FuncOp>>
createUnpackTorchTensorPass();

std::unique_ptr<OperationPass<ModuleOp>>
createVerifyLinalgOnTensorsBackendContractPass();

@@ -32,6 +32,11 @@ def FinalizingBackendTypeConversion
}];
}

def UnpackTorchTensor : Pass<"torch-unpack-torch-tensor", "func::FuncOp"> {
let summary = "Unpack Int4 Torch Tensor";
let constructor = "mlir::torch::TorchConversion::createUnpackTorchTensorPass()";
}

def VerifyLinalgOnTensorsBackendContract : Pass<"torch-verify-linalg-on-tensors-backend-contract", "ModuleOp"> {
let summary = "Verifies conformity to the linalg-on-tensors backend contract";
let constructor = "mlir::torch::TorchConversion::createVerifyLinalgOnTensorsBackendContractPass()";
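The UnpackTorchTensor pass declared above runs on func::FuncOp. As a minimal sketch (not part of this commit) of how it could be scheduled, assuming the standard mlir::PassManager API:

// Illustrative only: scheduling the new pass in a pipeline.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"

void addUnpackTorchTensor(mlir::PassManager &pm) {
  // The pass is declared on func::FuncOp above, so it is added as a nested pass.
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::torch::TorchConversion::createUnpackTorchTensorPass());
}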
162 changes: 162 additions & 0 deletions lib/Conversion/TorchToLinalg/Linear.cpp
@@ -427,6 +427,166 @@ class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
};
} // namespace

namespace {
class ConvertCustomQuantizedMatmulOp : public OpConversionPattern<OperatorOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(OperatorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (op.getName().str() != "brevitas.matmul_rhs_group_quant") {
return failure();
}
Location loc = op->getLoc();
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) {
return failure();
}

// get inputs: lhs, q_rhs, scales, zps
Value lhs = adaptor.getOperands()[0];
auto lhsType = lhs.getType().cast<RankedTensorType>();
if (!lhsType) {
return failure();
}
auto lhsShape = lhsType.getShape();
int lhs_reduct_dim_size = lhsShape.back();

Value q_rhs = adaptor.getOperands()[1];
auto rhsType = q_rhs.getType().cast<RankedTensorType>();
if (!rhsType) {
return failure();
}
auto rhsShape = rhsType.getShape();
int rhs_reduct_dim_size = rhsShape.back();
Type rhs_elementType = rhsType.getElementType();

Value scales = adaptor.getOperands()[2];
Value zps = adaptor.getOperands()[3];
Value unpacked_type_width = adaptor.getOperands()[4];
Value group_size = adaptor.getOperands()[5];

auto getConstantIntegerFromDefiningOp = [](Value operand,
int &extractedInt) {
auto castOp = dyn_cast<mlir::UnrealizedConversionCastOp>(operand.getDefiningOp());
if (!castOp) {
return failure();
}
auto constOp =
dyn_cast<Torch::ConstantIntOp>(castOp.getOperand(0).getDefiningOp());
if (!constOp) {
return failure();
}
extractedInt = constOp.getValue();
return success();
};

int gs;
if (failed(getConstantIntegerFromDefiningOp(group_size, gs))) {
return failure();
}
int unpackedBitWidth;
if (failed(getConstantIntegerFromDefiningOp(unpacked_type_width, unpackedBitWidth))) {
return failure();
}
if (unpackedBitWidth != rhs_elementType.getIntOrFloatBitWidth()) {
return failure();
}

// get outputs
Type newResultType = getTypeConverter()->convertType(op.getType(0));
auto resultType = newResultType.cast<RankedTensorType>();
if (!resultType) {
return failure();
}
auto resultShape = resultType.getShape();
Type elementType = resultType.getElementType();

// expand lhs
std::vector<int64_t> lhs_expandedShape = {lhsShape[0], lhsShape[1],
lhs_reduct_dim_size / gs, gs};
RankedTensorType lhs_expandedType = RankedTensorType::get(lhs_expandedShape, elementType);
SmallVector<ReassociationIndices, 4> lhs_reassociation = {{0}, {1}, {2, 3}};
Value expanded_lhs = rewriter.create<tensor::ExpandShapeOp>(
loc, lhs_expandedType, lhs, lhs_reassociation);

// expand rhs
std::vector<int64_t> expandedShape = {rhsShape[0], rhs_reduct_dim_size/gs, gs};
RankedTensorType expandedType = RankedTensorType::get(expandedShape, rhs_elementType);
SmallVector<ReassociationIndices, 4> reassociation = {{0}, {1, 2}};
Value expanded_rhs = rewriter.create<tensor::ExpandShapeOp>(
loc, expandedType, q_rhs, reassociation);
Value cst_0 = rewriter.create<arith::ConstantOp>(
loc, FloatAttr::get(elementType, 0.0));

Value dq_empty = rewriter.create<tensor::EmptyOp>(
loc, expandedShape, elementType);
SmallVector<Value> dynDims;
for (int i = 0; i < lhsType.getRank(); i++) {
if (lhsType.isDynamicDim(i)) {
dynDims.push_back(rewriter.create<tensor::DimOp>(loc, lhs, i));
}
}
Value empty = rewriter.create<tensor::EmptyOp>(
loc, resultShape, elementType, dynDims);
Value output = rewriter.create<linalg::FillOp>(
loc, cst_0, empty).getResult(0);

AffineExpr d0, d1, d2, d3, d4;
bindDims(getContext(), d0, d1, d2, d3, d4);
auto c0 = rewriter.getAffineConstantExpr(0);
auto map = AffineMap::get(3, 0, {d0, d1, d2}, rewriter.getContext());
auto map1 = AffineMap::get(3, 0, {d0, d1, c0}, rewriter.getContext());
auto map2 = AffineMap::get(5, 0, {d0, d1, d3, d4}, rewriter.getContext());
auto map3 = AffineMap::get(5, 0, {d2, d3, d4}, rewriter.getContext());
auto map4 = AffineMap::get(5, 0, {d0, d1, d2}, rewriter.getContext());
SmallVector<AffineMap, 4> dq_indexingMaps = {map, map1, map1, map};
SmallVector<AffineMap, 4> mat_indexingMaps = {map2, map3, map4};

SmallVector<utils::IteratorType> dq_iteratorTypes(3, utils::IteratorType::parallel);
SmallVector<utils::IteratorType> mat_iteratorTypes = {
utils::IteratorType::parallel, utils::IteratorType::parallel,
utils::IteratorType::parallel, utils::IteratorType::reduction,
utils::IteratorType::reduction
};

Value dq_rhs =
rewriter
.create<linalg::GenericOp>(
loc, dq_empty.getType(),
ValueRange{expanded_rhs, scales, zps}, dq_empty,
/*indexingMaps=*/dq_indexingMaps,
/*iteratorTypes=*/dq_iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value w = args[0], scale = args[1], zeroPoint = args[2];
Value extw = b.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), w);
Value fp_extw = b.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), extw);
Value shifted = b.create<arith::SubFOp>(loc, fp_extw, zeroPoint);
Value dqw = b.create<arith::MulFOp>(loc, shifted, scale);
b.create<linalg::YieldOp>(loc, dqw);
})
.getResult(0);

Value quantMat =
rewriter
.create<linalg::GenericOp>(
loc, output.getType(),
ValueRange{expanded_lhs, dq_rhs}, output,
/*indexingMaps=*/mat_indexingMaps,
/*iteratorTypes=*/mat_iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value l = args[0], r = args[1], out = args[2];
Value pd = b.create<arith::MulFOp>(loc, l, r);
Value ac = b.create<arith::AddFOp>(loc, pd, out);
b.create<linalg::YieldOp>(loc, ac);
})
.getResult(0);

rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, quantMat);
return success();
}
};
} // namespace
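For reference, the two linalg.generic ops built above implement a group-quantized matmul: the packed right-hand side is dequantized per (row, group) as (w - zero_point) * scale, then contracted against the expanded left-hand side over the group and in-group dimensions. A scalar sketch of the same computation (illustrative only, not part of the commit; the flat layouts and names below are assumptions):

// Scalar reference for the lowering above (illustrative only).
// Assumed layouts: lhs is [B][M][K] float, qRhs is [N][K] unsigned 4-bit
// values widened to uint8_t, scales/zps are [N][K/GS], out is [B][M][N];
// GS is the quantization group size.
#include <cstdint>

void refGroupQuantMatmul(const float *lhs, const uint8_t *qRhs,
                         const float *scales, const float *zps, float *out,
                         int B, int M, int N, int K, int GS) {
  int numGroups = K / GS;
  for (int b = 0; b < B; ++b)
    for (int m = 0; m < M; ++m)
      for (int n = 0; n < N; ++n) {
        float acc = 0.0f;
        for (int g = 0; g < numGroups; ++g)
          for (int j = 0; j < GS; ++j) {
            int k = g * GS + j;
            // Dequantize one weight: (w - zero_point) * scale, per (row, group).
            float dq = (static_cast<float>(qRhs[n * K + k]) -
                        zps[n * numGroups + g]) *
                       scales[n * numGroups + g];
            acc += lhs[(b * M + m) * K + k] * dq;
          }
        out[(b * M + m) * N + n] = acc;
      }
}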

namespace {
class ConvertAtenBmmOp : public OpConversionPattern<AtenBmmOp> {
public:
@@ -860,6 +1020,8 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
patterns.add<ConvertAtenFlipOp>(typeConverter, context);
target.addIllegalOp<AtenMatmulOp>();
patterns.add<ConvertAtenMatmulOp>(typeConverter, context);
target.addIllegalOp<OperatorOp>();
patterns.add<ConvertCustomQuantizedMatmulOp>(typeConverter, context);
target.addIllegalOp<AtenBmmOp>();
patterns.add<ConvertAtenBmmOp>(typeConverter, context);
target.addIllegalOp<AtenConvolutionOp>();
4 changes: 2 additions & 2 deletions lib/Dialect/Torch/IR/TorchTypes.cpp
@@ -194,13 +194,13 @@ static bool isValidTorchDtype(Type dtype) {
if (type.isSignless() && type.getWidth() == 1)
return true;
if (type.isSigned()) {
for (unsigned width : {8, 16, 32, 64}) {
for (unsigned width : {4, 8, 16, 32, 64}) {
if (type.getWidth() == width)
return true;
}
}
if (type.isUnsigned()) {
return type.getWidth() == 8;
return type.getWidth() == 8 || type.getWidth() == 4;
}
}
return false;
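With the widened check above, 4-bit signed and unsigned integers are now accepted as Torch tensor dtypes. A minimal sketch (not part of the commit) of building a value tensor type with an unsigned 4-bit element type, assuming the public builtin-type and TorchTypes APIs; the shape and function name are illustrative:

// Illustrative only: a Torch value tensor type carrying a ui4 dtype,
// e.g. for the unpacked weights produced by the pass in this change.
#include <cstdint>
#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"

mlir::torch::Torch::ValueTensorType makeUi4WeightType(mlir::MLIRContext *ctx) {
  // Unsigned 4-bit element type, now considered a valid Torch dtype.
  auto ui4 = mlir::IntegerType::get(ctx, 4, mlir::IntegerType::Unsigned);
  // Arbitrary example shape [8, 32].
  return mlir::torch::Torch::ValueTensorType::get(
      ctx, llvm::ArrayRef<int64_t>{8, 32}, ui4);
}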
1 change: 1 addition & 0 deletions lib/Dialect/TorchConversion/Transforms/CMakeLists.txt
@@ -25,6 +25,7 @@ add_mlir_library(TorchMLIRTorchConversionPasses
BackendTypeConversion.cpp
BackendTypeConversionPasses.cpp
Passes.cpp
UnpackTensor.cpp
VerifyLinalgOnTensorsBackendContract.cpp
VerifyTosaBackendContract.cpp
VerifyStablehloBackendContract.cpp
141 changes: 141 additions & 0 deletions lib/Dialect/TorchConversion/Transforms/UnpackTensor.cpp
@@ -0,0 +1,141 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {
class UnpackQuantizedMatmulWeights
: public OpRewritePattern<ValueTensorLiteralOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ValueTensorLiteralOp constOp,
PatternRewriter &rewriter) const override {
if (!constOp->hasOneUse())
return failure();

OpOperand *use = constOp.getResult().use_begin().getOperand();
auto op = dyn_cast<OperatorOp>(use->getOwner());
if (!op)
return failure();

if (use->getOperandNumber() != 1)
return failure();

if (op.getName().str() != "brevitas.matmul_rhs_group_quant") {
return failure();
}

Value rhs = op.getOperand(1);
Value bitWidth = op.getOperand(4);

auto getConstantIntegerFromDefiningOp = [](Value operand,
int &extractedInt) {
auto constOp = dyn_cast<Torch::ConstantIntOp>(operand.getDefiningOp());
if (!constOp) {
return failure();
}
extractedInt = constOp.getValue();
return success();
};
int unpackedBitWidth;
if (failed(getConstantIntegerFromDefiningOp(bitWidth, unpackedBitWidth)))
return failure();

auto rhsType = rhs.getType().dyn_cast<ValueTensorType>();
if (!rhsType)
return failure();

if (!rhsType.hasDtype())
return failure();

Type dType = rhsType.getDtype();
int dTypeWidth = dType.getIntOrFloatBitWidth();
if (dTypeWidth == unpackedBitWidth)
return failure();

if (!rhsType.hasSizes())
return failure();

SmallVector<int64_t> tensorShape(rhsType.getSizes());
if (tensorShape.back() == kUnknownSize)
return failure();
int packRatio = dTypeWidth / unpackedBitWidth;

tensorShape[tensorShape.size() - 1] *= packRatio;
Type unpackedElementType;
if (dType.isSignedInteger())
unpackedElementType = rewriter.getIntegerType(unpackedBitWidth, true);
else
unpackedElementType = rewriter.getIntegerType(unpackedBitWidth, false);
ValueTensorType newRhsType = ValueTensorType::get(
rewriter.getContext(), tensorShape, unpackedElementType);

auto elements = constOp.getValueAttr().dyn_cast<DenseIntElementsAttr>();
if (!elements)
return failure();

auto attrType = RankedTensorType::get(tensorShape, unpackedElementType);

// This is terrible but idk what else to do.
auto data = elements.getRawData();
std::vector<APInt> newData(data.size() * packRatio,
APInt(unpackedBitWidth, 0));
for (int i = 0, e = data.size(); i < e; ++i) {
auto el = data[i];
char mask = (1 << unpackedBitWidth) - 1;
for (int b = 0; b < packRatio; b++) {
newData[i * packRatio + b] =
APInt(unpackedBitWidth, (el & mask) >> (unpackedBitWidth * b));
mask = mask << unpackedBitWidth;
}
}
rewriter.replaceOpWithNewOp<ValueTensorLiteralOp>(
constOp, newRhsType,
DenseElementsAttr::get(attrType, ArrayRef<APInt>(newData)));
return success();
}
};
} // namespace
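To make the mask-and-shift loop above concrete: with int8 storage and unpackedBitWidth = 4, packRatio is 2, so each byte is split into two 4-bit elements, low nibble first. A standalone sketch of the same arithmetic (illustrative only, not part of the commit):

// Illustrative only: scalar model of the nibble-unpacking loop above.
#include <cstdint>
#include <cstdio>

int main() {
  const uint8_t packed = 0xBA;  // low nibble 0xA, high nibble 0xB
  const int unpackedBitWidth = 4;
  const int packRatio = 8 / unpackedBitWidth;       // two elements per byte
  uint8_t mask = (1u << unpackedBitWidth) - 1;      // 0x0F
  for (int b = 0; b < packRatio; ++b) {
    uint8_t element = (packed & mask) >> (unpackedBitWidth * b);
    std::printf("element %d = 0x%X\n", b, element); // prints 0xA, then 0xB
    mask <<= unpackedBitWidth;                      // 0x0F -> 0xF0
  }
  return 0;
}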

namespace {
class UnpackTorchTensorPass
: public TorchConversion::UnpackTorchTensorBase<UnpackTorchTensorPass> {
using UnpackTorchTensorBase<UnpackTorchTensorPass>::UnpackTorchTensorBase;
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<func::FuncDialect>();
registry.insert<Torch::TorchDialect>();
}

void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.add<UnpackQuantizedMatmulWeights>(context);

if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
}
};
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::TorchConversion::createUnpackTorchTensorPass() {
return std::make_unique<UnpackTorchTensorPass>();
}
