Skip to content

Commit

Permalink
[torch] torch.dequantize for per channel tensors to linalg (#2769)
Browse files Browse the repository at this point in the history
Support a lowering for dequantization for per channel tensors from
`torch` dialect to a linalg decomposition. Tested via a numerical
`torch` test.
  • Loading branch information
rsuderman authored Jan 26, 2024
1 parent 0aed231 commit 2ef2283
Show file tree
Hide file tree
Showing 8 changed files with 258 additions and 8 deletions.
53 changes: 53 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -14465,6 +14465,33 @@ def Torch_AtenLeakyReluBackwardOp : Torch_Op<"aten.leaky_relu_backward", [
}];
}

def Torch_AtenQuantizePerChannelOp : Torch_Op<"aten.quantize_per_channel", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$scales,
AnyTorchTensorType:$zero_points,
Torch_IntType:$axis,
Torch_IntType:$dtype
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenQuantizePerChannelOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 1);
}
void AtenQuantizePerChannelOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
}

def Torch_AtenQuantizePerTensorOp : Torch_Op<"aten.quantize_per_tensor", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down Expand Up @@ -14560,6 +14587,32 @@ def Torch_AtenIntReprOp : Torch_Op<"aten.int_repr", [
}];
}

def Torch_Aten_MakePerChannelQuantizedTensorOp : Torch_Op<"aten._make_per_channel_quantized_tensor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::_make_per_channel_quantized_tensor : (Tensor, Tensor, Tensor, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$scale,
AnyTorchTensorType:$zero_point,
Torch_IntType:$axis
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten_MakePerChannelQuantizedTensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void Aten_MakePerChannelQuantizedTensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}

def Torch_Aten_MakePerTensorQuantizedTensorOp : Torch_Op<"aten._make_per_tensor_quantized_tensor", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
114 changes: 106 additions & 8 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1344,7 +1344,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
auto makeQTensor =
qtensor.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>();
if (!makeQTensor) {
op->emitError(
op->emitWarning(
"unimplemented: dequantizing tensor of unknown scale / zero-point");
return nullptr;
}
Expand Down Expand Up @@ -2221,16 +2221,109 @@ class ConvertAtenIntReprOp : public OpConversionPattern<AtenIntReprOp> {
} // namespace

namespace {
class ConvertMakePerTensorQuantizedTensorOp
: public OpConversionPattern<Aten_MakePerTensorQuantizedTensorOp> {
class ConvertDequantizePerChannel
: public OpConversionPattern<AtenDequantizeSelfOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(Aten_MakePerTensorQuantizedTensorOp op, OpAdaptor adaptor,
matchAndRewrite(AtenDequantizeSelfOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
RankedTensorType resultType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
auto loc = op.getLoc();
auto qoperand = op.getOperand();
auto make = qoperand.getDefiningOp<Aten_MakePerChannelQuantizedTensorOp>();
if (!make) {
llvm::errs() << "Did not find make per channel\n";
return rewriter.notifyMatchFailure(op, "did not find per channel qint");
}

auto converter = getTypeConverter();
auto operand = make.getOperand(0);
auto scale = make.getScale();
auto zeropoint = make.getZeroPoint();
auto axis = make.getAxis();

IntegerAttr axisAttr;
if (!matchPattern(axis, m_Constant(&axisAttr))) {
return failure();
}

auto operandDTy = operand.getType().cast<ValueTensorType>().getDtype();
auto zeropointDTy = zeropoint.getType().cast<ValueTensorType>().getDtype();
operand = converter->materializeTargetConversion(
rewriter, loc, converter->convertType(operand.getType()), operand);
scale = converter->materializeTargetConversion(
rewriter, loc, converter->convertType(scale.getType()), scale);
zeropoint = converter->materializeTargetConversion(
rewriter, loc, converter->convertType(zeropoint.getType()), zeropoint);

auto resultType = converter->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();

llvm::SmallVector<Value> dynSizes;
for (auto [index, dim] : llvm::enumerate(resultType.getShape())) {
if (ShapedType::isDynamic(dim)) {
dynSizes.push_back(rewriter.create<tensor::DimOp>(loc, operand, index));
}
}

llvm::SmallVector<utils::IteratorType> iterators(
resultType.getRank(), utils::IteratorType::parallel);
llvm::SmallVector<AffineMap> maps(
4, {rewriter.getMultiDimIdentityMap(resultType.getRank())});
auto broadcastMap = AffineMap::get(
resultType.getRank(), /*symbolCount=*/0,
{rewriter.getAffineDimExpr(axisAttr.getInt())}, rewriter.getContext());
maps[1] = broadcastMap;
maps[2] = broadcastMap;

auto empty =
rewriter.create<tensor::EmptyOp>(op.getLoc(), resultType, dynSizes);
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, resultType, ValueRange{operand, scale, zeropoint},
ValueRange{empty}, maps, iterators,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value operand = args[0];
Value scale = args[1];
Value zeropoint = args[2];
if (operandDTy.isUnsignedInteger(8)) {
operand = b.create<arith::ExtUIOp>(loc, b.getI32Type(), operand);
} else if (operandDTy.isSignedInteger(8)) {
operand = b.create<arith::ExtSIOp>(loc, b.getI32Type(), operand);
}

if (zeropointDTy.isUnsignedInteger(8)) {
zeropoint =
b.create<arith::ExtUIOp>(loc, b.getI32Type(), zeropoint);
} else if (zeropointDTy.isSignedInteger(8)) {
zeropoint =
b.create<arith::ExtSIOp>(loc, b.getI32Type(), zeropoint);
}

Value sub = rewriter.create<arith::SubIOp>(loc, operand, zeropoint);
Value fp =
rewriter.create<arith::SIToFPOp>(loc, args[3].getType(), sub);
Value mul = rewriter.create<arith::MulFOp>(loc, fp, scale);
b.create<linalg::YieldOp>(loc, mul);
});
rewriter.replaceOp(op, linalgOp.getResults());
return success();
}
};
} // namespace

namespace {

template <typename OpTy>
class ConvertCastEquivalentOp : public OpConversionPattern<OpTy> {
using OpConversionPattern<OpTy>::OpConversionPattern;
using OpAdaptor = typename OpTy::Adaptor;

LogicalResult
matchAndRewrite(OpTy op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto converter = this->getTypeConverter();
RankedTensorType resultType = cast<RankedTensorType>(
converter->convertType(op->getResult(0).getType()));
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
adaptor.getSelf());
return success();
Expand Down Expand Up @@ -2283,6 +2376,11 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
target.addIllegalOp<TensorStaticInfoCastOp>();
patterns.add<ConvertAtenIntReprOp>(typeConverter, context);
target.addIllegalOp<AtenIntReprOp>();
patterns.add<ConvertMakePerTensorQuantizedTensorOp>(typeConverter, context);
patterns.add<ConvertCastEquivalentOp<Aten_MakePerChannelQuantizedTensorOp>>(
typeConverter, context);
target.addIllegalOp<Aten_MakePerChannelQuantizedTensorOp>();
patterns.add<ConvertCastEquivalentOp<Aten_MakePerTensorQuantizedTensorOp>>(
typeConverter, context);
target.addIllegalOp<Aten_MakePerTensorQuantizedTensorOp>();
patterns.add<ConvertDequantizePerChannel>(typeConverter, context);
}
32 changes: 32 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6549,6 +6549,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.quantize_per_channel\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.quantize_per_tensor\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
Expand All @@ -6565,6 +6569,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten._make_per_channel_quantized_tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten._make_per_tensor_quantized_tensor\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
Expand Down Expand Up @@ -12632,6 +12640,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.quantize_per_channel\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.int {\n"
" return %arg4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.quantize_per_tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
" return %arg3 : !torch.int\n"
" }\n"
Expand Down Expand Up @@ -12664,6 +12675,27 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %2 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten._make_per_channel_quantized_tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.int) -> !torch.int {\n"
" %int14 = torch.constant.int 14\n"
" %int12 = torch.constant.int 12\n"
" %int1 = torch.constant.int 1\n"
" %int13 = torch.constant.int 13\n"
" %int0 = torch.constant.int 0\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.int) {\n"
" torch.prim.If.yield %int13 : !torch.int\n"
" } else {\n"
" %3 = torch.aten.eq.int %0#1, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" %4 = torch.prim.If %3 -> (!torch.int) {\n"
" torch.prim.If.yield %int12 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %int14 : !torch.int\n"
" }\n"
" torch.prim.If.yield %4 : !torch.int\n"
" }\n"
" return %2 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten._make_per_tensor_quantized_tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.int) -> !torch.int {\n"
" %int14 = torch.constant.int 14\n"
" %int12 = torch.constant.int 12\n"
Expand Down
20 changes: 20 additions & 0 deletions projects/ltc/csrc/base_lazy_backend/shape_inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,20 @@ std::vector<torch::lazy::Shape> compute_shape_div(const at::Tensor& self,
return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<torch::lazy::Shape>
compute_shape__make_per_channel_quantized_tensor(const at::Tensor &self,
const at::Tensor &scale,
const at::Tensor &zero_point,
int64_t axis) {
if (self.scalar_type() == at::kChar)
return {Shape(at::kQInt8, self.sizes().vec())};
if (self.scalar_type() == at::kByte)
return {Shape(at::kQUInt8, self.sizes().vec())};
if (self.scalar_type() == at::kInt)
return {Shape(at::kQInt32, self.sizes().vec())};
assert(false);
}

std::vector<torch::lazy::Shape> compute_shape__make_per_tensor_quantized_tensor(
const at::Tensor &self, double scale, int64_t zero_point) {
if (self.scalar_type() == at::kChar)
Expand Down Expand Up @@ -75,6 +89,12 @@ std::vector<torch::lazy::Shape> compute_shape_isinf(const at::Tensor& self) {
return {Shape(at::kBool, self.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape_quantize_per_channel(
const at::Tensor &self, const at::Tensor &scales,
const at::Tensor &zero_points, int64_t axis, at::ScalarType dtype) {
return {Shape(dtype, self.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape_max_pool3d_with_indices(
const at::Tensor& self, at::IntArrayRef kernel_size,
at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation,
Expand Down
1 change: 1 addition & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@
"GroupNormNoWeightAndBiasModule_basic",

# Dynamo does not support tracing quantized tensors
"ElementwiseDequantizePerChannelModule_basic",
"ElementwiseDequantizePerTensorModule_basic",
"ElementwiseQuantizePerTensorModule_basic",
"AtenMmQuint8_basic",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,9 @@ def aten〇clamp_max〡shape(self: List[int], max: float) -> List[int]:
def aten〇rsub〇Scalar〡shape(self: List[int], other: float, alpha: float = 1) -> List[int]:
return upstream_shape_functions.unary(self)

def aten〇quantize_per_channel〡shape(self: List[int], scales: List[int], zero_points: List[int], axis: int, dtype: int) -> List[int]:
return upstream_shape_functions.unary(self)

def aten〇quantize_per_tensor〡shape(self: List[int], scale: float, zero_point: int, dtype: int) -> List[int]:
return upstream_shape_functions.unary(self)

Expand All @@ -263,6 +266,9 @@ def aten〇dequantize〇tensor〡shape(qtensor: List[int]) -> List[int]:
def aten〇int_repr〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)

def aten〇_make_per_channel_quantized_tensor〡shape(self: List[int], scale: List[int], zero_point: List[int], axis: int) -> List[int]:
return upstream_shape_functions.unary(self)

def aten〇_make_per_tensor_quantized_tensor〡shape(self: List[int], scale: float, zero_point: int) -> List[int]:
return upstream_shape_functions.unary(self)

Expand Down Expand Up @@ -4280,6 +4286,9 @@ def prims〇collapse〡dtype(a_rank_dtype: Tuple[int, int], start: int, end: int
return a_dtype


def aten〇quantize_per_channel〡dtype(self_rank_dtype: Tuple[int, int], scales_rank_dtype: Tuple[int, int], zero_points_rank_dtype: Tuple[int, int], axis: int, dtype: int) -> int:
return dtype

def aten〇quantize_per_tensor〡dtype(self_rank_dtype: Tuple[int, int], scale: float, zero_point: int, dtype: int) -> int:
return dtype

Expand All @@ -4297,6 +4306,14 @@ def aten〇int_repr〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
return torch.int8
return torch.int32

def aten〇_make_per_channel_quantized_tensor〡dtype(self_rank_dtype: Tuple[int, int], scale_rank_dtype: Tuple[int, int], zero_point_rank_dtype: Tuple[int, int], axis: int) -> int:
self_rank, self_dtype = self_rank_dtype
if (self_dtype == torch.uint8):
return torch.quint8
if (self_dtype == torch.int8):
return torch.qint8
return torch.qint32

def aten〇_make_per_tensor_quantized_tensor〡dtype(self_rank_dtype: Tuple[int, int], scale: float, zero_point: int) -> int:
self_rank, self_dtype = self_rank_dtype
if (self_dtype == torch.uint8):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -820,10 +820,12 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)")

# quantized ops
emit("aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)")
emit("aten::quantize_per_tensor : (Tensor, float, int, int) -> (Tensor)")
emit("aten::dequantize.self : (Tensor) -> (Tensor)")
emit("aten::dequantize.tensor : (Tensor) -> (Tensor)")
emit("aten::int_repr : (Tensor) -> (Tensor)")
emit("aten::_make_per_channel_quantized_tensor : (Tensor, Tensor, Tensor, int) -> (Tensor)")
emit("aten::_make_per_tensor_quantized_tensor : (Tensor, float, int) -> (Tensor)")

# ==========================================================================
Expand Down
27 changes: 27 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -4328,6 +4328,33 @@ def ElementwiseDequantizePerTensorModule_basic(module, tu: TestUtils):

# ==============================================================================

class ElementwiseDequantizePerChannelModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([3, 4], torch.int8, True),
([4], torch.int8, True),
([4], torch.float, True),
])
def forward(self, x, zeropoint, scale):
qx = torch._make_per_channel_quantized_tensor(x, scale, zeropoint, axis=1)
qx = torch.dequantize(qx)
return qx

@register_test_case(module_factory=lambda: ElementwiseDequantizePerChannelModule())
def ElementwiseDequantizePerChannelModule_basic(module, tu: TestUtils):
module.forward(
tu.randint(3, 4, low=-128, high=127).to(torch.int8),
tu.randint(4, low=-128, high=127).to(torch.int8),
tu.rand(4)
)

# ==============================================================================

class GluStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down

0 comments on commit 2ef2283

Please sign in to comment.