Add canonicalize pattern for aten.mul.int and aten.floordiv.int (#3680)
This PR adds `floordiv` to the `PY_BUILTIN_TO_TORCH_OP` map. For the
`aten.mul.int` and `aten.floordiv.int` ops, it adds new canonicalization
patterns as follows:

```
%1 = torch.aten.mul.int %input, %const-5
%2 = torch.aten.mul.int %1, %const-6
```

will be rewritten to

`torch.aten.mul.int %input, %const-30`


and

```
%1 = torch.aten.mul.int %input, %const-5
%2 = torch.aten.floordiv.int %1, %const-5
```
will be simplified so that uses of `%2` are replaced directly by `%input`. Both patterns apply only when the intermediate `aten.mul.int` has a single use and the matching operands are integer constants.


This PR also removes the floating-point-only element type constraint in the
TorchToTosa lowering of `AtenRsubScalarOp`, so integer-typed inputs are no longer rejected.
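
As a rough sketch of what this enables (the function name, shape, and `si32` element type below are illustrative and not taken from the test suite), an integer-typed `rsub` like the following is no longer turned away by the float-only check:

```
// Hypothetical IR: rsub over an si32 tensor. Before this change the TOSA
// lowering of AtenRsubScalarOp bailed out on non-float element types.
func.func @rsub_int_example(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],si32> {
  %int2 = torch.constant.int 2
  %int1 = torch.constant.int 1
  %0 = torch.aten.rsub.Scalar %arg0, %int2, %int1 : !torch.vtensor<[3,4],si32>, !torch.int, !torch.int -> !torch.vtensor<[3,4],si32>
  return %0 : !torch.vtensor<[3,4],si32>
}
```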



To test:

`cmake --build build --target check-torch-mlir-all`
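
The new patterns can also be exercised directly by running the canonicalize pass over the updated test file, e.g. `torch-mlir-opt --canonicalize --split-input-file test/Dialect/Torch/canonicalize.mlir` (the location of `torch-mlir-opt` depends on the local build layout).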
zezhang authored Sep 3, 2024
1 parent 70de04a commit b3942ff
Showing 7 changed files with 114 additions and 7 deletions.
2 changes: 2 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -15078,6 +15078,7 @@ def Torch_AtenFloordivIntOp : Torch_Op<"aten.floordiv.int", [
}
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
}

def Torch_AtenRemainderIntOp : Torch_Op<"aten.remainder.int", [
@@ -15226,6 +15227,7 @@ def Torch_AtenMulIntOp : Torch_Op<"aten.mul.int", [
}
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
}

def Torch_AtenDivIntOp : Torch_Op<"aten.div.int", [
4 changes: 0 additions & 4 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -1823,10 +1823,6 @@ LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(
op, "Only ranked tensor types supported in TOSA Rsub");

if (!isa<mlir::FloatType>(selfTy.getElementType()))
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization supported");

Value otherTensor, alphaTensor;

if (failed(torchScalarToTosaTensor(rewriter, op, otherScalar, otherTensor,
77 changes: 77 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -3434,6 +3434,44 @@ OpFoldResult AtenFloordivIntOp::fold(FoldAdaptor adaptor) {
[](int64_t a, int64_t b) { return std::floor(a / (double)b); });
}

void AtenFloordivIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](AtenFloordivIntOp op, PatternRewriter &rewriter) {
int64_t lhs, rhs;
bool lConstant = matchPattern(op.getA(), m_TorchConstantInt(&lhs));
bool rConstant = matchPattern(op.getB(), m_TorchConstantInt(&rhs));
if (lConstant && rConstant)
return failure();
if (lConstant || rConstant) {
int64_t firstConstant = lConstant ? lhs : rhs;
Value firstOperand = lConstant ? op.getB() : op.getA();
if (firstOperand.getDefiningOp() &&
firstOperand.getDefiningOp<AtenMulIntOp>()) {
auto prevMulIntOp = firstOperand.getDefiningOp<AtenMulIntOp>();
int64_t prevLhs, prevRhs;
bool prevLConstant =
matchPattern(prevMulIntOp.getA(), m_TorchConstantInt(&prevLhs));
bool prevRConstant =
matchPattern(prevMulIntOp.getB(), m_TorchConstantInt(&prevRhs));
if (prevLConstant && prevRConstant)
return failure();
if ((prevLConstant || prevRConstant) &&
prevMulIntOp->hasOneUse() == 1) {
int64_t secondConstant = prevLConstant ? prevLhs : prevRhs;
if (secondConstant == firstConstant) {
rewriter.replaceAllUsesWith(
op.getResult(), prevMulIntOp.getOperand(prevLConstant ? 1 : 0));
rewriter.eraseOp(op);
rewriter.eraseOp(prevMulIntOp);
return success();
}
}
}
}
return failure();
});
}

//===----------------------------------------------------------------------===//
// AtenRemainderIntOp
//===----------------------------------------------------------------------===//
@@ -3697,6 +3735,45 @@ OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) {
return nullptr;
}

void AtenMulIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](AtenMulIntOp op, PatternRewriter &rewriter) {
int64_t lhs, rhs;
bool lConstant = matchPattern(op.getA(), m_TorchConstantInt(&lhs));
bool rConstant = matchPattern(op.getB(), m_TorchConstantInt(&rhs));
if (lConstant && rConstant)
return failure();
if (lConstant || rConstant) {
int64_t firstConstant = lConstant ? lhs : rhs;
Value firstOperand = lConstant ? op.getB() : op.getA();
if (firstOperand.getDefiningOp() &&
firstOperand.getDefiningOp<AtenMulIntOp>()) {
auto prevMulIntOp = firstOperand.getDefiningOp<AtenMulIntOp>();
int64_t prevLhs, prevRhs;
bool prevLConstant =
matchPattern(prevMulIntOp.getA(), m_TorchConstantInt(&prevLhs));
bool prevRConstant =
matchPattern(prevMulIntOp.getB(), m_TorchConstantInt(&prevRhs));
if (prevLConstant && prevRConstant)
return failure();
if ((prevLConstant || prevRConstant) &&
prevMulIntOp->hasOneUse() == 1) {
auto newConstant = rewriter.create<Torch::ConstantIntOp>(
op.getLoc(), rewriter.getI64IntegerAttr(
prevLConstant ? prevLhs * firstConstant
: prevRhs * firstConstant));
rewriter.replaceOpWithNewOp<AtenMulIntOp>(
op, op.getType(), prevMulIntOp.getOperand(prevLConstant ? 1 : 0),
newConstant);
rewriter.eraseOp(prevMulIntOp);
return success();
}
}
}
return failure();
});
}

//===----------------------------------------------------------------------===//
// AtenMulFloatOp
//===----------------------------------------------------------------------===//
1 change: 1 addition & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -1963,6 +1963,7 @@
"RsubFloatModule_basic",
"RsubFloatModule_noalpha_basic",
"RsubInt0d_NumToTensor_Module_basic",
"RsubIntModule_basic",
"ScalarTensorDefaultDtypeModule_basic",
"ScalarTensorFloat32Module_basic",
"ScalarTensorInt32Module_basic",
@@ -1060,13 +1060,21 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::le.int : (int, int) -> (bool)", has_folder=True)
emit("aten::ne.int : (int, int) -> (bool)", has_folder=True)
emit("aten::eq.int : (int, int) -> (bool)", has_folder=True)
emit("aten::floordiv.int : (int, int) -> (int)", has_folder=True)
emit(
"aten::floordiv.int : (int, int) -> (int)",
has_folder=True,
has_canonicalizer=True,
)
emit("aten::remainder.int : (int, int) -> (int)", has_folder=True)
emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)")
emit("aten::remainder.Tensor : (Tensor, Tensor) -> (Tensor)")
emit("aten::add.int : (int, int) -> (int)", has_folder=True)
emit("aten::sub.int : (int, int) -> (int)", has_folder=True)
emit("aten::mul.int : (int, int) -> (int)", has_folder=True)
emit(
"aten::mul.int : (int, int) -> (int)",
has_folder=True,
has_canonicalizer=True,
)
emit("aten::div.int : (int, int) -> (float)", has_folder=True)
emit("aten::neg.int : (int) -> (int)", has_folder=True)
emit("aten::log.int : (int) -> (float)")
1 change: 1 addition & 0 deletions python/torch_mlir/extras/fx_importer.py
@@ -279,6 +279,7 @@
"gt": torch.ops.aten.gt,
"mod": torch.ops.aten.fmod,
"eq": torch.ops.aten.eq,
"floordiv": torch.ops.aten.floordiv,
}

# torch with cuda has a __version__ that looks like "2.1.0+cu113",
24 changes: 23 additions & 1 deletion test/Dialect/Torch/canonicalize.mlir
@@ -1168,6 +1168,19 @@ func.func @torch.aten.mul.int() -> !torch.int {
return %ret : !torch.int
}

// CHECK-LABEL: func.func @torch.aten.mul.int$canonicalize(
// CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.int {
// CHECK: %[[CST30:.*]] = torch.constant.int 30
// CHECK: %[[RET:.*]] = torch.aten.mul.int %[[ARG]], %[[CST30]] : !torch.int, !torch.int -> !torch.int
// CHECK: return %[[RET]] : !torch.int
func.func @torch.aten.mul.int$canonicalize(%arg0: !torch.int) -> !torch.int {
%cst6 = torch.constant.int 6
%cst5 = torch.constant.int 5
%1 = torch.aten.mul.int %arg0, %cst5: !torch.int, !torch.int -> !torch.int
%ret = torch.aten.mul.int %1, %cst6: !torch.int, !torch.int -> !torch.int
return %ret : !torch.int
}

// CHECK-LABEL: func.func @torch.aten.mul.float() -> !torch.float {
// CHECK: %[[CST30:.*]] = torch.constant.float 3.000000e+01
// CHECK: return %[[CST30]] : !torch.float
@@ -1207,6 +1220,16 @@ func.func @torch.aten.floordiv.int() -> !torch.int {
return %ret : !torch.int
}

// CHECK-LABEL: func.func @torch.aten.floordiv.int$canonicalize(
// CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.int {
// CHECK: return %[[ARG]] : !torch.int
func.func @torch.aten.floordiv.int$canonicalize(%arg0: !torch.int) -> !torch.int {
%cst6 = torch.constant.int 6
%1 = torch.aten.mul.int %arg0, %cst6: !torch.int, !torch.int -> !torch.int
%ret = torch.aten.floordiv.int %1, %cst6: !torch.int, !torch.int -> !torch.int
return %ret : !torch.int
}

// CHECK-LABEL: func.func @torch.aten.remainder.int() -> !torch.int {
// CHECK: %[[CST3:.*]] = torch.constant.int 3
// CHECK: return %[[CST3]] : !torch.int
@@ -3122,7 +3145,6 @@ func.func @torch.aten.clone$no_fold(%arg0: !torch.vtensor<[1,2,50,4],f32>) -> (!
return %1 : !torch.tensor
}


// -----

// CHECK-LABEL: @torch.symbolic_int$canonicalize(
