From 3056b96c2054007ca7dcbe1623f03a24dee8c454 Mon Sep 17 00:00:00 2001
From: "Tung D. Le"
Date: Fri, 24 Nov 2023 16:11:59 +0900
Subject: [PATCH 1/4] A pattern to optimize a scalar Div in self-attention
 layer

Signed-off-by: Tung D. Le
---
 src/Dialect/ONNX/Rewrite.cpp              | 138 ++++++++++++++++++++++
 test/mlir/onnx/onnx_canonicalization.mlir |  55 +++++++++
 2 files changed, 193 insertions(+)

diff --git a/src/Dialect/ONNX/Rewrite.cpp b/src/Dialect/ONNX/Rewrite.cpp
index 7523d7543d..1600abb62c 100644
--- a/src/Dialect/ONNX/Rewrite.cpp
+++ b/src/Dialect/ONNX/Rewrite.cpp
@@ -209,6 +209,46 @@ bool haveSameStaticShape(Value lhs, Value rhs) {
   return hasStaticShape(lhsT) && (getShape(lhsT) == getShape(rhsT));
 }
 
+// Match v = shape_transform(X*A + B).
+// shape_transform is a sequence of operations like Reshape, Transpose,
+// Squeeze, Unsqueeze, etc. that change the data shape but do not change
+// the numerical values.
+// A and B are constants.
+bool matchShapeAddMatMul(Value v, Value &matA, Value &biasB,
+    Operation *&matmulOp, Operation *&addOp) {
+  if (v.isa<BlockArgument>())
+    return false;
+  Value origV = v;
+  // Match shape operations.
+  while (isa<ONNXReshapeOp, ONNXTransposeOp, ONNXSqueezeOp, ONNXUnsqueezeOp>(
+      origV.getDefiningOp())) {
+    origV = origV.getDefiningOp()->getOperands()[0];
+    if (origV.isa<BlockArgument>())
+      return false;
+  }
+
+  // Match Add.
+  addOp = origV.getDefiningOp();
+  if (!addOp || !isa<ONNXAddOp>(addOp))
+    return false;
+
+  // LHS of Add is MatMul.
+  matmulOp = addOp->getOperands()[0].getDefiningOp();
+  if (!matmulOp || !isa<ONNXMatMulOp>(matmulOp))
+    return false;
+  matA = matmulOp->getOperands()[1];
+  if (!isDenseONNXConstant(matA))
+    return false;
+
+  // RHS of Add is a constant.
+  biasB = addOp->getOperands()[1];
+  if (!isDenseONNXConstant(biasB))
+    return false;
+
+  // Passed all tests.
+  return true;
+}
+
 } // namespace onnx_mlir
 
 // =============================================================================
@@ -395,6 +435,103 @@ class PropagateReshapeThroughBinaryOpPattern
   };
 };
 
+// This rewriting is to optimize the scalar Div in self-attention layers.
+// In particular, it rewrites the following pattern:
+// ```
+// shape_transform(X1 * A1 + B1) * shape_transform(X2 * A2 + B2) / k
+// ```
+//
+// into
+// ```
+// shape_transform(X1 * A1 + B1) * shape_transform(X2 * A2/k + B2/k)
+// ```
+// if A2, B2 and k are constants,
+//
+// or into
+// ```
+// shape_transform(X1 * A1/k + B1/k) * shape_transform(X2 * A2 + B2)
+// ```
+// if A1, B1 and k are constants,
+//
+// where
+// - * is matrix multiplication; + and / are element-wise addition and division
+// - A1, A2, B1, B2, and k are constants so that A1/k, B1/k, A2/k and B2/k can
+// be folded. k is a scalar constant so that it's broadcastable to all A1, A2,
+// B1, B2.
+// - shape_transform includes a sequence of operations that change the data
+// shape of the input but not numerical values, for example: Reshape,
+// Transpose, etc.
+//
+struct PropagateScalarDivInAttentionLayerPattern
+    : public OpRewritePattern<ONNXDivOp> {
+  using OpRewritePattern<ONNXDivOp>::OpRewritePattern;
+
+  PropagateScalarDivInAttentionLayerPattern(MLIRContext *context)
+      : OpRewritePattern(context) {}
+
+  LogicalResult matchAndRewrite(
+      ONNXDivOp omDivOp, PatternRewriter &rewriter) const final {
+    Operation *divOp = omDivOp.getOperation();
+    Value divLHS = omDivOp.getA();
+    Value divK = omDivOp.getB();
+
+    // Match (lhs * rhs) / divK.
+    // The first operand of Div is produced by MatMulOp.
+    auto onnxMatMulOp = divLHS.getDefiningOp<ONNXMatMulOp>();
+    if (!onnxMatMulOp)
+      return rewriter.notifyMatchFailure(
+          divOp, "The first operand of Div is not produced by MatMulOp");
+    Value lhs = onnxMatMulOp.getA();
+    Value rhs = onnxMatMulOp.getB();
+    // The second operand of Div is a scalar constant.
+    if (!isScalarConstantTensor(divK))
+      return rewriter.notifyMatchFailure(
+          divOp, "The second operand of Div is not a scalar constant");
+
+    // Match lhs = shape_transform(X1*A1 + B1)
+    Value A1, B1;
+    Operation *lhsSubMatOp, *lhsAddOp;
+    bool matchLHS = matchShapeAddMatMul(lhs, A1, B1, lhsSubMatOp, lhsAddOp);
+
+    // Match rhs = shape_transform(X2*A2 + B2)
+    Value A2, B2;
+    Operation *rhsSubMatOp, *rhsAddOp;
+    bool matchRHS = matchShapeAddMatMul(rhs, A2, B2, rhsSubMatOp, rhsAddOp);
+
+    if (!matchLHS && !matchRHS)
+      return rewriter.notifyMatchFailure(divOp,
+          "There is no constant tensor to replace the first operand of Div");
+
+    // Rewrite.
+    // Only rewrite one side, so use LHS if both sides are matched.
+    if (matchLHS && matchRHS)
+      matchRHS = false;
+    ONNXMatMulOp onnxSubMatOp =
+        cast<ONNXMatMulOp>(matchLHS ? lhsSubMatOp : rhsSubMatOp);
+    ONNXAddOp onnxAddOp = cast<ONNXAddOp>(matchLHS ? lhsAddOp : rhsAddOp);
+    Value A = matchLHS ? A1 : A2;
+    Value B = matchLHS ? B1 : B2;
+
+    // Move divK up before MatMul so that it dominates its new users.
+    divK.getDefiningOp()->moveBefore(onnxSubMatOp);
+    // Update in place MatMul and Add.
+    rewriter.updateRootInPlace(onnxSubMatOp, [&] {
+      OnnxBuilder createONNX(rewriter, onnxSubMatOp.getLoc());
+      rewriter.setInsertionPoint(onnxSubMatOp);
+      onnxSubMatOp.getBMutable().assign(createONNX.div(A, divK));
+    });
+    rewriter.updateRootInPlace(onnxAddOp, [&] {
+      OnnxBuilder createONNX(rewriter, onnxAddOp.getLoc());
+      rewriter.setInsertionPoint(onnxAddOp);
+      onnxAddOp.getBMutable().assign(createONNX.div(B, divK));
+    });
+
+    // Bypass Div.
+    rewriter.replaceOp(divOp, onnxMatMulOp.getY());
+    return success();
+  }
+};
+
 // =============================================================================
 // Rewrite pattern for Resize (not handled in Rewrite.td).
 // =============================================================================
@@ -1379,6 +1516,7 @@ void ONNXDivOp::getCanonicalizationPatterns(
   result.insert>(context);
   result.insert>(context);
   result.insert>(context);
+  result.insert<PropagateScalarDivInAttentionLayerPattern>(context);
 }
 
 /// on the ONNXDropoutOp.
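The rewrite above is sound because dividing after the MatMul and Add is the same as dividing the constant operands first: (X*A + B)/k = X*(A/k) + (B/k) element-wise for any nonzero scalar k. The following self-contained C++ sketch checks that identity numerically on a small example; it is illustrative only, not part of the patch, and the naive matmul helper is made up for the check:

```
// Check (X*A + B)/k == X*(A/k) + (B/k), the identity behind the rewrite.
// Standalone illustration; not onnx-mlir code.
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<float>>;

// Naive matmul: (m x n) * (n x p) -> (m x p).
Matrix matmul(const Matrix &x, const Matrix &a) {
  std::size_t m = x.size(), n = a.size(), p = a[0].size();
  Matrix y(m, std::vector<float>(p, 0.0f));
  for (std::size_t i = 0; i < m; ++i)
    for (std::size_t k = 0; k < n; ++k)
      for (std::size_t j = 0; j < p; ++j)
        y[i][j] += x[i][k] * a[k][j];
  return y;
}

int main() {
  Matrix X = {{1, 2}, {3, 4}};
  Matrix A = {{5, 6}, {7, 8}};
  std::vector<float> B = {0.5f, -1.5f}; // bias, broadcast along rows
  float k = 8.0f;

  // Fold 1/k into the constant A, as the pattern does.
  Matrix Ak = A;
  for (auto &row : Ak)
    for (auto &e : row)
      e /= k;
  Matrix XA = matmul(X, A);
  Matrix XAk = matmul(X, Ak);
  for (std::size_t i = 0; i < XA.size(); ++i)
    for (std::size_t j = 0; j < XA[i].size(); ++j) {
      float notFolded = (XA[i][j] + B[j]) / k; // (X*A + B)/k
      float folded = XAk[i][j] + B[j] / k;     // X*(A/k) + B/k
      assert(std::fabs(notFolded - folded) < 1e-6f);
    }
  return 0;
}
```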
diff --git a/test/mlir/onnx/onnx_canonicalization.mlir b/test/mlir/onnx/onnx_canonicalization.mlir
index b6887b7389..4573795efe 100644
--- a/test/mlir/onnx/onnx_canonicalization.mlir
+++ b/test/mlir/onnx/onnx_canonicalization.mlir
@@ -1620,3 +1620,58 @@ func.func @test_not_replace_sub_by_expand_two_expands(%arg0: tensor) -> t
 // CHECK: return [[VAR_7_]] : tensor<2x?xf32>
 // CHECK: }
 }
+
+// -----
+
+// COM: Optimize the scalar div in self-attention layer.
+func.func @test_div_in_attention(%arg0: tensor<?x?x768xf32>, %arg1: tensor<?x?x768xf32>) -> tensor<?x12x?x?xf32> {
+  %0 = onnx.Constant dense<1.280000e+02> : tensor<768xf32>
+  %1 = onnx.Constant dense<64> : tensor<1xi64>
+  %2 = onnx.Constant dense<12> : tensor<1xi64>
+  %3 = onnx.Constant dense<1.280000e+02> : tensor<768x768xf32>
+  %4 = "onnx.MatMul"(%arg1, %3) {onnx_node_name = "/encoder/layer.0/attention/self/query/MatMul"} : (tensor<?x?x768xf32>, tensor<768x768xf32>) -> tensor<?x?x768xf32>
+  %5 = "onnx.Add"(%4, %0) {onnx_node_name = "/encoder/layer.0/attention/self/query/Add-Initializer_encoder.layer.0.attention.self.query.bias_193"} : (tensor<?x?x768xf32>, tensor<768xf32>) -> tensor<?x?x768xf32>
+  %6 = "onnx.Dim"(%5) {axis = 0 : si64, onnx_node_name = "/encoder/layer.0/attention/self/Shape_4_51"} : (tensor<?x?x768xf32>) -> tensor<1xi64>
+  %7 = "onnx.Dim"(%5) {axis = 1 : si64, onnx_node_name = "/encoder/layer.0/attention/self/Shape_5_205"} : (tensor<?x?x768xf32>) -> tensor<1xi64>
+  %8 = "onnx.Concat"(%6, %7, %2, %1) {axis = 0 : si64, onnx_node_name = "/encoder/layer.0/attention/self/Concat_2"} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
+  %9 = "onnx.Reshape"(%5, %8) {allowzero = 0 : si64, onnx_node_name = "/encoder/layer.0/attention/self/Reshape_2"} : (tensor<?x?x768xf32>, tensor<4xi64>) -> tensor<?x?x12x64xf32>
+  %10 = "onnx.Transpose"(%9) {onnx_node_name = "/encoder/layer.0/attention/self/Transpose_1", perm = [0, 2, 1, 3]} : (tensor<?x?x12x64xf32>) -> tensor<?x12x?x64xf32>
+  %11 = "onnx.MatMul"(%arg0, %3) {onnx_node_name = "/encoder/layer.0/attention/self/key/MatMul"} : (tensor<?x?x768xf32>, tensor<768x768xf32>) -> tensor<?x?x768xf32>
+  %12 = "onnx.Add"(%11, %0) {onnx_node_name = "/encoder/layer.0/attention/self/key/Add-Initializer_encoder.layer.0.attention.self.key.bias_127"} : (tensor<?x?x768xf32>, tensor<768xf32>) -> tensor<?x?x768xf32>
+  %13 = "onnx.Dim"(%12) {axis = 0 : si64, onnx_node_name = "/encoder/layer.0/attention/self/Shape_124"} : (tensor<?x?x768xf32>) -> tensor<1xi64>
+  %14 = "onnx.Dim"(%12) {axis = 1 : si64, onnx_node_name = "/encoder/layer.0/attention/self/Shape_1_133"} : (tensor<?x?x768xf32>) -> tensor<1xi64>
+  %15 = "onnx.Concat"(%13, %14, %2, %1) {axis = 0 : si64, onnx_node_name = "/encoder/layer.0/attention/self/Concat"} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
+  %16 = "onnx.Reshape"(%12, %15) {allowzero = 0 : si64, onnx_node_name = "/encoder/layer.0/attention/self/Reshape"} : (tensor<?x?x768xf32>, tensor<4xi64>) -> tensor<?x?x12x64xf32>
+  %17 = "onnx.Transpose"(%16) {onnx_node_name = "/encoder/layer.0/attention/self/Transpose_2", perm = [0, 2, 3, 1]} : (tensor<?x?x12x64xf32>) -> tensor<?x12x64x?xf32>
+  %18 = "onnx.MatMul"(%10, %17) {onnx_node_name = "/encoder/layer.0/attention/self/MatMul"} : (tensor<?x12x?x64xf32>, tensor<?x12x64x?xf32>) -> tensor<?x12x?x?xf32>
+  %19 = onnx.Constant dense<8.000000e+00> : tensor<f32>
+  %20 = "onnx.Div"(%18, %19) {onnx_node_name = "/encoder/layer.0/attention/self/Div"} : (tensor<?x12x?x?xf32>, tensor<f32>) -> tensor<?x12x?x?xf32>
+  onnx.Return %20 : tensor<?x12x?x?xf32>
+
+// CHECK-LABEL: func.func @test_div_in_attention
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x768xf32>, [[PARAM_1_:%.+]]: tensor<?x?x768xf32>) -> tensor<?x12x?x?xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<1.280000e+02> : tensor<768xf32>
+// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<64> : tensor<1xi64>
+// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<12> : tensor<1xi64>
+// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<1.280000e+02> : tensor<768x768xf32>
+// CHECK-DAG: [[VAR_4_:%.+]] = onnx.Constant dense<8.000000e+00> : tensor<f32>
+// CHECK: [[VAR_5_:%.+]] = "onnx.Div"([[VAR_3_]], [[VAR_4_]]) : (tensor<768x768xf32>, tensor<f32>) -> tensor<768x768xf32>
+// CHECK-DAG: [[VAR_6_:%.+]] = "onnx.MatMul"([[PARAM_1_]], [[VAR_5_]]) {onnx_node_name = "/encoder/layer.0/attention/self/query/MatMul"} : (tensor<?x?x768xf32>, tensor<768x768xf32>) -> tensor<?x?x768xf32>
+// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Div"([[VAR_0_]], [[VAR_4_]]) : (tensor<768xf32>, tensor<f32>) -> tensor<768xf32>
+// CHECK: [[VAR_8_:%.+]] = "onnx.Add"([[VAR_6_]], [[VAR_7_]]) {onnx_node_name = "/encoder/layer.0/attention/self/query/Add-Initializer_encoder.layer.0.attention.self.query.bias_193"} : (tensor<?x?x768xf32>, tensor<768xf32>) -> tensor<?x?x768xf32>
+// CHECK-DAG: [[VAR_9_:%.+]] = "onnx.Dim"([[VAR_8_]]) {axis = 0 : si64, onnx_node_name = "/encoder/layer.0/attention/self/Shape_4_51"} : (tensor<?x?x768xf32>) -> tensor<1xi64>
+// CHECK-DAG: [[VAR_10_:%.+]] = "onnx.Dim"([[VAR_8_]]) {axis = 1 : si64, onnx_node_name = "/encoder/layer.0/attention/self/Shape_5_205"} : (tensor<?x?x768xf32>) -> tensor<1xi64>
+// CHECK: [[VAR_11_:%.+]] = "onnx.Concat"([[VAR_9_]], [[VAR_10_]], [[VAR_2_]], [[VAR_1_]]) {axis = 0 : si64, onnx_node_name = "/encoder/layer.0/attention/self/Concat_2"} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
+// CHECK: [[VAR_12_:%.+]] = "onnx.Reshape"([[VAR_8_]], [[VAR_11_]]) {allowzero = 0 : si64, onnx_node_name = "/encoder/layer.0/attention/self/Reshape_2"} : (tensor<?x?x768xf32>, tensor<4xi64>) -> tensor<?x?x12x64xf32>
+// CHECK-DAG: [[VAR_13_:%.+]] = "onnx.Transpose"([[VAR_12_]]) {onnx_node_name = "/encoder/layer.0/attention/self/Transpose_1", perm = [0, 2, 1, 3]} : (tensor<?x?x12x64xf32>) -> tensor<?x12x?x64xf32>
+// CHECK-DAG: [[VAR_14_:%.+]] = "onnx.MatMul"([[PARAM_0_]], [[VAR_3_]]) {onnx_node_name = "/encoder/layer.0/attention/self/key/MatMul"} : (tensor<?x?x768xf32>, tensor<768x768xf32>) -> tensor<?x?x768xf32>
+// CHECK: [[VAR_15_:%.+]] = "onnx.Add"([[VAR_14_]], [[VAR_0_]]) {onnx_node_name = "/encoder/layer.0/attention/self/key/Add-Initializer_encoder.layer.0.attention.self.key.bias_127"} : (tensor<?x?x768xf32>, tensor<768xf32>) -> tensor<?x?x768xf32>
+// CHECK-DAG: [[VAR_16_:%.+]] = "onnx.Dim"([[VAR_15_]]) {axis = 0 : si64, onnx_node_name = "/encoder/layer.0/attention/self/Shape_124"} : (tensor<?x?x768xf32>) -> tensor<1xi64>
+// CHECK-DAG: [[VAR_17_:%.+]] = "onnx.Dim"([[VAR_15_]]) {axis = 1 : si64, onnx_node_name = "/encoder/layer.0/attention/self/Shape_1_133"} : (tensor<?x?x768xf32>) -> tensor<1xi64>
+// CHECK: [[VAR_18_:%.+]] = "onnx.Concat"([[VAR_16_]], [[VAR_17_]], [[VAR_2_]], [[VAR_1_]]) {axis = 0 : si64, onnx_node_name = "/encoder/layer.0/attention/self/Concat"} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
+// CHECK: [[VAR_19_:%.+]] = "onnx.Reshape"([[VAR_15_]], [[VAR_18_]]) {allowzero = 0 : si64, onnx_node_name = "/encoder/layer.0/attention/self/Reshape"} : (tensor<?x?x768xf32>, tensor<4xi64>) -> tensor<?x?x12x64xf32>
+// CHECK: [[VAR_20_:%.+]] = "onnx.Transpose"([[VAR_19_]]) {onnx_node_name = "/encoder/layer.0/attention/self/Transpose_2", perm = [0, 2, 3, 1]} : (tensor<?x?x12x64xf32>) -> tensor<?x12x64x?xf32>
+// CHECK: [[VAR_21_:%.+]] = "onnx.MatMul"([[VAR_13_]], [[VAR_20_]]) {onnx_node_name = "/encoder/layer.0/attention/self/MatMul"} : (tensor<?x12x?x64xf32>, tensor<?x12x64x?xf32>) -> tensor<?x12x?x?xf32>
+// CHECK: onnx.Return [[VAR_21_]] : tensor<?x12x?x?xf32>
+// CHECK: }
+}
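The pattern comment also claims that a shape_transform does not change numerical values, which is what lets the scalar division slide past the Reshape and Transpose operations in this test: scaling before or after a pure data movement yields the same tensor. A minimal standalone C++ check of that commutation for a 2-D transpose (illustrative only, not part of the patch):

```
// Check transpose(M) / k == transpose(M / k): a shape transform only moves
// elements around, so element-wise scalar division commutes with it.
// Standalone illustration; not onnx-mlir code.
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<float>>;

Matrix transpose(const Matrix &m) {
  Matrix t(m[0].size(), std::vector<float>(m.size()));
  for (std::size_t i = 0; i < m.size(); ++i)
    for (std::size_t j = 0; j < m[i].size(); ++j)
      t[j][i] = m[i][j];
  return t;
}

Matrix divide(const Matrix &m, float k) {
  Matrix r = m;
  for (auto &row : r)
    for (auto &e : row)
      e /= k;
  return r;
}

int main() {
  Matrix M = {{1, 2, 3}, {4, 5, 6}};
  float k = 8.0f;
  // Both orders perform exactly the same division on each element.
  assert(divide(transpose(M), k) == transpose(divide(M, k)));
  return 0;
}
```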
From f094daffb82e3cdc74dc5cf8bb94c560eb3c7a9f Mon Sep 17 00:00:00 2001
From: "Tung D. Le"
Date: Fri, 24 Nov 2023 17:15:16 +0900
Subject: [PATCH 2/4] remove onnx_node_name in the lit test

Signed-off-by: Tung D. Le
---
 test/mlir/onnx/onnx_canonicalization.mlir | 62 +++++++++++------------
 1 file changed, 31 insertions(+), 31 deletions(-)

diff --git a/test/mlir/onnx/onnx_canonicalization.mlir b/test/mlir/onnx/onnx_canonicalization.mlir
index 4573795efe..09cc8d15c9 100644
--- a/test/mlir/onnx/onnx_canonicalization.mlir
+++ b/test/mlir/onnx/onnx_canonicalization.mlir
@@ -1629,23 +1629,23 @@ func.func @test_div_in_attention(%arg0: tensor<?x?x768xf32>, %arg1: tensor<?x?
   %1 = onnx.Constant dense<64> : tensor<1xi64>
   %2 = onnx.Constant dense<12> : tensor<1xi64>
   %3 = onnx.Constant dense<1.280000e+02> : tensor<768x768xf32>
-  %4 = "onnx.MatMul"(%arg1, %3) {onnx_node_name = "/encoder/layer.0/attention/self/query/MatMul"} : (tensor<?x?x768xf32>, tensor<768x768xf32>) -> tensor<?x?x768xf32>
-  %5 = "onnx.Add"(%4, %0) {onnx_node_name = "/encoder/layer.0/attention/self/query/Add-Initializer_encoder.layer.0.attention.self.query.bias_193"} : (tensor<?x?x768xf32>, tensor<768xf32>) -> tensor<?x?x768xf32>
-  %6 = "onnx.Dim"(%5) {axis = 0 : si64, onnx_node_name = "/encoder/layer.0/attention/self/Shape_4_51"} : (tensor<?x?x768xf32>) -> tensor<1xi64>
-  %7 = "onnx.Dim"(%5) {axis = 1 : si64, onnx_node_name = "/encoder/layer.0/attention/self/Shape_5_205"} : (tensor<?x?x768xf32>) -> tensor<1xi64>
-  %8 = "onnx.Concat"(%6, %7, %2, %1) {axis = 0 : si64, onnx_node_name = "/encoder/layer.0/attention/self/Concat_2"} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
-  %9 = "onnx.Reshape"(%5, %8) {allowzero = 0 : si64, onnx_node_name = "/encoder/layer.0/attention/self/Reshape_2"} : (tensor<?x?x768xf32>, tensor<4xi64>) -> tensor<?x?x12x64xf32>
-  %10 = "onnx.Transpose"(%9) {onnx_node_name = "/encoder/layer.0/attention/self/Transpose_1", perm = [0, 2, 1, 3]} : (tensor<?x?x12x64xf32>) -> tensor<?x12x?x64xf32>
-  %11 = "onnx.MatMul"(%arg0, %3) {onnx_node_name = "/encoder/layer.0/attention/self/key/MatMul"} : (tensor<?x?x768xf32>, tensor<768x768xf32>) -> tensor<?x?x768xf32>
-  %12 = "onnx.Add"(%11, %0) {onnx_node_name = "/encoder/layer.0/attention/self/key/Add-Initializer_encoder.layer.0.attention.self.key.bias_127"} : (tensor<?x?x768xf32>, tensor<768xf32>) -> tensor<?x?x768xf32>
-  %13 = "onnx.Dim"(%12) {axis = 0 : si64, onnx_node_name = "/encoder/layer.0/attention/self/Shape_124"} : (tensor<?x?x768xf32>) -> tensor<1xi64>
-  %14 = "onnx.Dim"(%12) {axis = 1 : si64, onnx_node_name = "/encoder/layer.0/attention/self/Shape_1_133"} : (tensor<?x?x768xf32>) -> tensor<1xi64>
-  %15 = "onnx.Concat"(%13, %14, %2, %1) {axis = 0 : si64, onnx_node_name = "/encoder/layer.0/attention/self/Concat"} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
-  %16 = "onnx.Reshape"(%12, %15) {allowzero = 0 : si64, onnx_node_name = "/encoder/layer.0/attention/self/Reshape"} : (tensor<?x?x768xf32>, tensor<4xi64>) -> tensor<?x?x12x64xf32>
-  %17 = "onnx.Transpose"(%16) {onnx_node_name = "/encoder/layer.0/attention/self/Transpose_2", perm = [0, 2, 3, 1]} : (tensor<?x?x12x64xf32>) -> tensor<?x12x64x?xf32>
-  %18 = "onnx.MatMul"(%10, %17) {onnx_node_name = "/encoder/layer.0/attention/self/MatMul"} : (tensor<?x12x?x64xf32>, tensor<?x12x64x?xf32>) -> tensor<?x12x?x?xf32>
+  %4 = "onnx.MatMul"(%arg1, %3) : (tensor<?x?x768xf32>, tensor<768x768xf32>) -> tensor<?x?x768xf32>
+  %5 = "onnx.Add"(%4, %0) : (tensor<?x?x768xf32>, tensor<768xf32>) -> tensor<?x?x768xf32>
+  %6 = "onnx.Dim"(%5) {axis = 0 : si64} : (tensor<?x?x768xf32>) -> tensor<1xi64>
+  %7 = "onnx.Dim"(%5) {axis = 1 : si64} : (tensor<?x?x768xf32>) -> tensor<1xi64>
+  %8 = "onnx.Concat"(%6, %7, %2, %1) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
+  %9 = "onnx.Reshape"(%5, %8) {allowzero = 0 : si64} : (tensor<?x?x768xf32>, tensor<4xi64>) -> tensor<?x?x12x64xf32>
+  %10 = "onnx.Transpose"(%9) {perm = [0, 2, 1, 3]} : (tensor<?x?x12x64xf32>) -> tensor<?x12x?x64xf32>
+  %11 = "onnx.MatMul"(%arg0, %3) : (tensor<?x?x768xf32>, tensor<768x768xf32>) -> tensor<?x?x768xf32>
+  %12 = "onnx.Add"(%11, %0) : (tensor<?x?x768xf32>, tensor<768xf32>) -> tensor<?x?x768xf32>
+  %13 = "onnx.Dim"(%12) {axis = 0 : si64} : (tensor<?x?x768xf32>) -> tensor<1xi64>
+  %14 = "onnx.Dim"(%12) {axis = 1 : si64} : (tensor<?x?x768xf32>) -> tensor<1xi64>
+  %15 = "onnx.Concat"(%13, %14, %2, %1) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
+  %16 = "onnx.Reshape"(%12, %15) {allowzero = 0 : si64} : (tensor<?x?x768xf32>, tensor<4xi64>) -> tensor<?x?x12x64xf32>
+  %17 = "onnx.Transpose"(%16) {perm = [0, 2, 3, 1]} : (tensor<?x?x12x64xf32>) -> tensor<?x12x64x?xf32>
+  %18 = "onnx.MatMul"(%10, %17) : (tensor<?x12x?x64xf32>, tensor<?x12x64x?xf32>) -> tensor<?x12x?x?xf32>
   %19 = onnx.Constant dense<8.000000e+00> : tensor<f32>
-  %20 = "onnx.Div"(%18, %19) {onnx_node_name = "/encoder/layer.0/attention/self/Div"} : (tensor<?x12x?x?xf32>, tensor<f32>) -> tensor<?x12x?x?xf32>
+  %20 = "onnx.Div"(%18, %19) : (tensor<?x12x?x?xf32>, tensor<f32>) -> tensor<?x12x?x?xf32>
   onnx.Return %20 : tensor<?x12x?x?xf32>
 
 // CHECK-LABEL: func.func @test_div_in_attention
"onnx.Dim"(%12) {axis = 0 : si64} : (tensor) -> tensor<1xi64> + %14 = "onnx.Dim"(%12) {axis = 1 : si64} : (tensor) -> tensor<1xi64> + %15 = "onnx.Concat"(%13, %14, %2, %1) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64> + %16 = "onnx.Reshape"(%12, %15) {allowzero = 0 : si64} : (tensor, tensor<4xi64>) -> tensor + %17 = "onnx.Transpose"(%16) {perm = [0, 2, 3, 1]} : (tensor) -> tensor + %18 = "onnx.MatMul"(%10, %17) : (tensor, tensor) -> tensor %19 = onnx.Constant dense<8.000000e+00> : tensor - %20 = "onnx.Div"(%18, %19) {onnx_node_name = "/encoder/layer.0/attention/self/Div"} : (tensor, tensor) -> tensor + %20 = "onnx.Div"(%18, %19) : (tensor, tensor) -> tensor onnx.Return %20 : tensor // CHECK-LABEL: func.func @test_div_in_attention @@ -1656,22 +1656,22 @@ func.func @test_div_in_attention(%arg0: tensor, %arg1: tensor : tensor<768x768xf32> // CHECK-DAG: [[VAR_4_:%.+]] = onnx.Constant dense<8.000000e+00> : tensor // CHECK: [[VAR_5_:%.+]] = "onnx.Div"([[VAR_3_]], [[VAR_4_]]) : (tensor<768x768xf32>, tensor) -> tensor<768x768xf32> -// CHECK-DAG: [[VAR_6_:%.+]] = "onnx.MatMul"([[PARAM_1_]], [[VAR_5_]]) {onnx_node_name = "/encoder/layer.0/attention/self/query/MatMul"} : (tensor, tensor<768x768xf32>) -> tensor +// CHECK-DAG: [[VAR_6_:%.+]] = "onnx.MatMul"([[PARAM_1_]], [[VAR_5_]]) : (tensor, tensor<768x768xf32>) -> tensor // CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Div"([[VAR_0_]], [[VAR_4_]]) : (tensor<768xf32>, tensor) -> tensor<768xf32> -// CHECK: [[VAR_8_:%.+]] = "onnx.Add"([[VAR_6_]], [[VAR_7_]]) {onnx_node_name = "/encoder/layer.0/attention/self/query/Add-Initializer_encoder.layer.0.attention.self.query.bias_193"} : (tensor, tensor<768xf32>) -> tensor -// CHECK-DAG: [[VAR_9_:%.+]] = "onnx.Dim"([[VAR_8_]]) {axis = 0 : si64, onnx_node_name = "/encoder/layer.0/attention/self/Shape_4_51"} : (tensor) -> tensor<1xi64> -// CHECK-DAG: [[VAR_10_:%.+]] = "onnx.Dim"([[VAR_8_]]) {axis = 1 : si64, onnx_node_name = "/encoder/layer.0/attention/self/Shape_5_205"} : (tensor) -> tensor<1xi64> -// CHECK: [[VAR_11_:%.+]] = "onnx.Concat"([[VAR_9_]], [[VAR_10_]], [[VAR_2_]], [[VAR_1_]]) {axis = 0 : si64, onnx_node_name = "/encoder/layer.0/attention/self/Concat_2"} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64> -// CHECK: [[VAR_12_:%.+]] = "onnx.Reshape"([[VAR_8_]], [[VAR_11_]]) {allowzero = 0 : si64, onnx_node_name = "/encoder/layer.0/attention/self/Reshape_2"} : (tensor, tensor<4xi64>) -> tensor -// CHECK-DAG: [[VAR_13_:%.+]] = "onnx.Transpose"([[VAR_12_]]) {onnx_node_name = "/encoder/layer.0/attention/self/Transpose_1", perm = [0, 2, 1, 3]} : (tensor) -> tensor -// CHECK-DAG: [[VAR_14_:%.+]] = "onnx.MatMul"([[PARAM_0_]], [[VAR_3_]]) {onnx_node_name = "/encoder/layer.0/attention/self/key/MatMul"} : (tensor, tensor<768x768xf32>) -> tensor -// CHECK: [[VAR_15_:%.+]] = "onnx.Add"([[VAR_14_]], [[VAR_0_]]) {onnx_node_name = "/encoder/layer.0/attention/self/key/Add-Initializer_encoder.layer.0.attention.self.key.bias_127"} : (tensor, tensor<768xf32>) -> tensor -// CHECK-DAG: [[VAR_16_:%.+]] = "onnx.Dim"([[VAR_15_]]) {axis = 0 : si64, onnx_node_name = "/encoder/layer.0/attention/self/Shape_124"} : (tensor) -> tensor<1xi64> -// CHECK-DAG: [[VAR_17_:%.+]] = "onnx.Dim"([[VAR_15_]]) {axis = 1 : si64, onnx_node_name = "/encoder/layer.0/attention/self/Shape_1_133"} : (tensor) -> tensor<1xi64> -// CHECK: [[VAR_18_:%.+]] = "onnx.Concat"([[VAR_16_]], [[VAR_17_]], [[VAR_2_]], [[VAR_1_]]) {axis = 0 : si64, onnx_node_name = 
"/encoder/layer.0/attention/self/Concat"} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64> -// CHECK: [[VAR_19_:%.+]] = "onnx.Reshape"([[VAR_15_]], [[VAR_18_]]) {allowzero = 0 : si64, onnx_node_name = "/encoder/layer.0/attention/self/Reshape"} : (tensor, tensor<4xi64>) -> tensor -// CHECK: [[VAR_20_:%.+]] = "onnx.Transpose"([[VAR_19_]]) {onnx_node_name = "/encoder/layer.0/attention/self/Transpose_2", perm = [0, 2, 3, 1]} : (tensor) -> tensor -// CHECK: [[VAR_21_:%.+]] = "onnx.MatMul"([[VAR_13_]], [[VAR_20_]]) {onnx_node_name = "/encoder/layer.0/attention/self/MatMul"} : (tensor, tensor) -> tensor +// CHECK: [[VAR_8_:%.+]] = "onnx.Add"([[VAR_6_]], [[VAR_7_]]) : (tensor, tensor<768xf32>) -> tensor +// CHECK-DAG: [[VAR_9_:%.+]] = "onnx.Dim"([[VAR_8_]]) {axis = 0 : si64} : (tensor) -> tensor<1xi64> +// CHECK-DAG: [[VAR_10_:%.+]] = "onnx.Dim"([[VAR_8_]]) {axis = 1 : si64} : (tensor) -> tensor<1xi64> +// CHECK: [[VAR_11_:%.+]] = "onnx.Concat"([[VAR_9_]], [[VAR_10_]], [[VAR_2_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64> +// CHECK: [[VAR_12_:%.+]] = "onnx.Reshape"([[VAR_8_]], [[VAR_11_]]) {allowzero = 0 : si64} : (tensor, tensor<4xi64>) -> tensor +// CHECK-DAG: [[VAR_13_:%.+]] = "onnx.Transpose"([[VAR_12_]]) {perm = [0, 2, 1, 3]} : (tensor) -> tensor +// CHECK-DAG: [[VAR_14_:%.+]] = "onnx.MatMul"([[PARAM_0_]], [[VAR_3_]]) : (tensor, tensor<768x768xf32>) -> tensor +// CHECK: [[VAR_15_:%.+]] = "onnx.Add"([[VAR_14_]], [[VAR_0_]]) : (tensor, tensor<768xf32>) -> tensor +// CHECK-DAG: [[VAR_16_:%.+]] = "onnx.Dim"([[VAR_15_]]) {axis = 0 : si64} : (tensor) -> tensor<1xi64> +// CHECK-DAG: [[VAR_17_:%.+]] = "onnx.Dim"([[VAR_15_]]) {axis = 1 : si64} : (tensor) -> tensor<1xi64> +// CHECK: [[VAR_18_:%.+]] = "onnx.Concat"([[VAR_16_]], [[VAR_17_]], [[VAR_2_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64> +// CHECK: [[VAR_19_:%.+]] = "onnx.Reshape"([[VAR_15_]], [[VAR_18_]]) {allowzero = 0 : si64} : (tensor, tensor<4xi64>) -> tensor +// CHECK: [[VAR_20_:%.+]] = "onnx.Transpose"([[VAR_19_]]) {perm = [0, 2, 3, 1]} : (tensor) -> tensor +// CHECK: [[VAR_21_:%.+]] = "onnx.MatMul"([[VAR_13_]], [[VAR_20_]]) : (tensor, tensor) -> tensor // CHECK: onnx.Return [[VAR_21_]] : tensor // CHECK: } } From f98422a5550c3404bf97db81c6ad73a8b43e1068 Mon Sep 17 00:00:00 2001 From: "Tung D. Le" Date: Mon, 27 Nov 2023 12:57:28 +0900 Subject: [PATCH 3/4] Support Mul Signed-off-by: Tung D. Le --- src/Dialect/ONNX/ONNXOps/OpHelper.cpp | 11 +++ src/Dialect/ONNX/ONNXOps/OpHelper.hpp | 3 + src/Dialect/ONNX/Rewrite.cpp | 113 +++++++++++++--------- test/mlir/onnx/onnx_canonicalization.mlir | 55 +++++++++++ 4 files changed, 135 insertions(+), 47 deletions(-) diff --git a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp index b251b121a0..8f4292b39c 100644 --- a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp +++ b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp @@ -420,6 +420,17 @@ bool hasShapeAndRank(Operation *op) { return true; } +/// Test if a value has only one use except ONNXDimOp. +bool hasOneUseExceptDimOp(Value val) { + int64_t numOfUsersExceptDim = 0; + for (auto user : val.getUsers()) { + if (isa(user)) + continue; + numOfUsersExceptDim++; + } + return (numOfUsersExceptDim == 1); +} + //===----------------------------------------------------------------------===// // Support for rewrite patterns. 
diff --git a/src/Dialect/ONNX/ONNXOps/OpHelper.hpp b/src/Dialect/ONNX/ONNXOps/OpHelper.hpp
index 557b9c7419..84722c2dbb 100644
--- a/src/Dialect/ONNX/ONNXOps/OpHelper.hpp
+++ b/src/Dialect/ONNX/ONNXOps/OpHelper.hpp
@@ -209,6 +209,9 @@ bool isScalarConstantTensor(mlir::Value v);
 bool hasShapeAndRank(mlir::Value val);
 bool hasShapeAndRank(mlir::Operation *op);
 
+/// Test if a value has exactly one use, ignoring ONNXDimOp users.
+bool hasOneUseExceptDimOp(mlir::Value val);
+
 //===----------------------------------------------------------------------===//
 // Support for Rewrite.
 //===----------------------------------------------------------------------===//
diff --git a/src/Dialect/ONNX/Rewrite.cpp b/src/Dialect/ONNX/Rewrite.cpp
index 1600abb62c..653fdf3489 100644
--- a/src/Dialect/ONNX/Rewrite.cpp
+++ b/src/Dialect/ONNX/Rewrite.cpp
@@ -218,34 +218,49 @@ bool matchShapeAddMatMul(Value v, Value &matA, Value &biasB,
     Operation *&matmulOp, Operation *&addOp) {
   if (v.isa<BlockArgument>())
     return false;
+  if (!hasOneUseExceptDimOp(v))
+    return false;
   Value origV = v;
-  // Match shape operations.
-  while (isa<ONNXReshapeOp, ONNXTransposeOp, ONNXSqueezeOp, ONNXUnsqueezeOp>(
-      origV.getDefiningOp())) {
-    origV = origV.getDefiningOp()->getOperands()[0];
-    if (origV.isa<BlockArgument>())
-      return false;
-  }
+  // Match a sequence of shape operations. Each shape operation has only one
+  // use.
+  while (auto defOp = origV.getDefiningOp()) {
+    if (!isa<ONNXReshapeOp, ONNXTransposeOp, ONNXSqueezeOp, ONNXUnsqueezeOp>(
+            defOp))
+      break;
+    origV = defOp->getOperands()[0];
+    if (!hasOneUseExceptDimOp(origV))
+      break;
+  }
+  if (origV.isa<BlockArgument>() || !hasOneUseExceptDimOp(origV))
+    return false;
 
   // Match Add.
-  addOp = origV.getDefiningOp();
-  if (!addOp || !isa<ONNXAddOp>(addOp))
+  auto onnxAddOp = origV.getDefiningOp<ONNXAddOp>();
+  if (!onnxAddOp)
     return false;
+  Value lhsAdd = onnxAddOp.getA();
+  Value rhsAdd = onnxAddOp.getB();
 
-  // LHS of Add is MatMul.
-  matmulOp = addOp->getOperands()[0].getDefiningOp();
-  if (!matmulOp || !isa<ONNXMatMulOp>(matmulOp))
+  // LHS of Add is the only use of MatMul's result.
+  if (!hasOneUseExceptDimOp(lhsAdd))
+    return false;
+  auto onnxMatMulOp = lhsAdd.getDefiningOp<ONNXMatMulOp>();
+  if (!onnxMatMulOp)
     return false;
-  matA = matmulOp->getOperands()[1];
-  if (!isDenseONNXConstant(matA))
+  Value rhsMatMul = onnxMatMulOp.getB();
+  if (!isDenseONNXConstant(rhsMatMul))
     return false;
 
   // RHS of Add is a constant.
-  biasB = addOp->getOperands()[1];
-  if (!isDenseONNXConstant(biasB))
+  if (!isDenseONNXConstant(rhsAdd))
     return false;
 
   // Passed all tests.
+  matmulOp = onnxMatMulOp.getOperation();
+  addOp = onnxAddOp.getOperation();
+  matA = rhsMatMul;
+  biasB = rhsAdd;
+
   return true;
 }
 
@@ -435,7 +450,7 @@ class PropagateReshapeThroughBinaryOpPattern
   };
 };
 
-// This rewriting is to optimize the scalar Div in self-attention layers.
+// This rewriting is to optimize the scalar Div/Mul in self-attention layers.
 // In particular, it rewrites the following pattern:
 // ```
 // shape_transform(X1 * A1 + B1) * shape_transform(X2 * A2 + B2) / k
@@ -462,31 +477,30 @@ class PropagateReshapeThroughBinaryOpPattern
 // shape of the input but not numerical values, for example: Reshape,
 // Transpose, etc.
 //
-struct PropagateScalarDivInAttentionLayerPattern
-    : public OpRewritePattern<ONNXDivOp> {
-  using OpRewritePattern<ONNXDivOp>::OpRewritePattern;
-
-  PropagateScalarDivInAttentionLayerPattern(MLIRContext *context)
-      : OpRewritePattern(context) {}
+// This pattern supports both division and multiplication by k.
+template <typename ONNXOp>
+struct PropagateConstantScalingInAttentionLayerPattern
+    : public OpRewritePattern<ONNXOp> {
+  using OpRewritePattern<ONNXOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(
-      ONNXDivOp omDivOp, PatternRewriter &rewriter) const final {
-    Operation *divOp = omDivOp.getOperation();
-    Value divLHS = omDivOp.getA();
-    Value divK = omDivOp.getB();
-
-    // Match (lhs * rhs) / divK.
-    // The first operand of Div is produced by MatMulOp.
-    auto onnxMatMulOp = divLHS.getDefiningOp<ONNXMatMulOp>();
+      ONNXOp omOp, PatternRewriter &rewriter) const final {
+    Operation *genericOp = omOp.getOperation();
+    Value lhsOMOp = omOp.getA();
+    Value K = omOp.getB();
+
+    // Match (lhs * rhs) / K.
+    // The first operand of Div/Mul is produced by MatMulOp.
+    auto onnxMatMulOp = lhsOMOp.getDefiningOp<ONNXMatMulOp>();
     if (!onnxMatMulOp)
-      return rewriter.notifyMatchFailure(
-          divOp, "The first operand of Div is not produced by MatMulOp");
+      return rewriter.notifyMatchFailure(genericOp,
+          "The first operand of Div/Mul is not produced by MatMulOp");
     Value lhs = onnxMatMulOp.getA();
     Value rhs = onnxMatMulOp.getB();
-    // The second operand of Div is a scalar constant.
-    if (!isScalarConstantTensor(divK))
+    // The second operand of Div/Mul is a scalar constant.
+    if (!isScalarConstantTensor(K))
       return rewriter.notifyMatchFailure(
-          divOp, "The second operand of Div is not a scalar constant");
+          genericOp, "The second operand of Div/Mul is not a scalar constant");
 
     // Match lhs = shape_transform(X1*A1 + B1)
     Value A1, B1;
@@ -499,35 +513,37 @@ struct PropagateScalarDivInAttentionLayerPattern
     bool matchRHS = matchShapeAddMatMul(rhs, A2, B2, rhsSubMatOp, rhsAddOp);
 
     if (!matchLHS && !matchRHS)
-      return rewriter.notifyMatchFailure(divOp,
-          "There is no constant tensor to replace the first operand of Div");
+      return rewriter.notifyMatchFailure(genericOp,
+          "There is no constant tensor to replace the first operand "
+          "of Div/Mul");
 
     // Rewrite.
     // Only rewrite one side, so use LHS if both sides are matched.
     if (matchLHS && matchRHS)
       matchRHS = false;
-    ONNXMatMulOp onnxSubMatOp =
+    auto onnxSubMatOp =
         cast<ONNXMatMulOp>(matchLHS ? lhsSubMatOp : rhsSubMatOp);
-    ONNXAddOp onnxAddOp = cast<ONNXAddOp>(matchLHS ? lhsAddOp : rhsAddOp);
+    auto onnxAddOp = cast<ONNXAddOp>(matchLHS ? lhsAddOp : rhsAddOp);
     Value A = matchLHS ? A1 : A2;
    Value B = matchLHS ? B1 : B2;
 
-    // Move divK up before MatMul so that it dominates its new users.
-    divK.getDefiningOp()->moveBefore(onnxSubMatOp);
+    // Move K up before MatMul so that it dominates its new users.
+    K.getDefiningOp()->moveBefore(onnxSubMatOp);
     // Update in place MatMul and Add.
     rewriter.updateRootInPlace(onnxSubMatOp, [&] {
-      OnnxBuilder createONNX(rewriter, onnxSubMatOp.getLoc());
       rewriter.setInsertionPoint(onnxSubMatOp);
-      onnxSubMatOp.getBMutable().assign(createONNX.div(A, divK));
+      onnxSubMatOp.getBMutable().assign(rewriter.create<ONNXOp>(
+          onnxSubMatOp.getLoc(), onnxSubMatOp.getB().getType(), A, K));
     });
     rewriter.updateRootInPlace(onnxAddOp, [&] {
       OnnxBuilder createONNX(rewriter, onnxAddOp.getLoc());
       rewriter.setInsertionPoint(onnxAddOp);
-      onnxAddOp.getBMutable().assign(createONNX.div(B, divK));
+      onnxAddOp.getBMutable().assign(rewriter.create<ONNXOp>(
+          onnxAddOp.getLoc(), onnxAddOp.getB().getType(), B, K));
     });
 
-    // Bypass Div.
-    rewriter.replaceOp(divOp, onnxMatMulOp.getY());
+    // Bypass Div/Mul.
+    rewriter.replaceOp(genericOp, onnxMatMulOp.getY());
     return success();
   }
 };
@@ -1516,7 +1532,8 @@ void ONNXDivOp::getCanonicalizationPatterns(
   result.insert>(context);
   result.insert>(context);
   result.insert>(context);
-  result.insert<PropagateScalarDivInAttentionLayerPattern>(context);
+  result.insert<PropagateConstantScalingInAttentionLayerPattern<ONNXDivOp>>(
+      context);
 }
 
 /// on the ONNXDropoutOp.
@@ -1602,6 +1619,8 @@ void ONNXMulOp::getCanonicalizationPatterns(
   results.insert>(context);
   results.insert>(context);
   results.insert>(context);
+  results.insert<PropagateConstantScalingInAttentionLayerPattern<ONNXMulOp>>(
+      context);
 }
 
 /// on the ONNXOrOp.
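The pattern can be templated because multiplication folds through X*A + B exactly the way division does: op(X*A + B, k) = X*op(A, k) + op(B, k) when op is element-wise division or multiplication by a scalar k. A minimal standalone C++ check of both instantiations, with scalars standing in for the broadcast tensors (illustrative only, not part of the patch):

```
// Check op(x*a + b, k) == x*op(a, k) + op(b, k) for the two scalar ops the
// pattern is instantiated with. Standalone illustration; not onnx-mlir code.
#include <cassert>
#include <cmath>
#include <functional>

bool foldsThroughMatMulAdd(const std::function<float(float, float)> &op,
    float x, float a, float b, float k) {
  float notFolded = op(x * a + b, k);     // scale the result
  float folded = x * op(a, k) + op(b, k); // scale the constants instead
  return std::fabs(notFolded - folded) < 1e-6f;
}

int main() {
  assert(foldsThroughMatMulAdd(std::divides<float>(), 3, 5, 7, 8));    // Div
  assert(foldsThroughMatMulAdd(std::multiplies<float>(), 3, 5, 7, 8)); // Mul
  return 0;
}
```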
diff --git a/test/mlir/onnx/onnx_canonicalization.mlir b/test/mlir/onnx/onnx_canonicalization.mlir
index 09cc8d15c9..5643e1a8d9 100644
--- a/test/mlir/onnx/onnx_canonicalization.mlir
+++ b/test/mlir/onnx/onnx_canonicalization.mlir
@@ -1675,3 +1675,58 @@ func.func @test_div_in_attention(%arg0: tensor<?x?x768xf32>, %arg1: tensor<?x?
 // CHECK: onnx.Return [[VAR_21_]] : tensor<?x12x?x?xf32>
 // CHECK: }
 }
+
+// -----
+
+// COM: Optimize the scalar multiplication in self-attention layer.
+func.func @test_mul_in_attention(%arg0: tensor<?x?x768xf32>, %arg1: tensor<?x?x768xf32>) -> tensor<?x12x?x?xf32> {
+  %0 = onnx.Constant dense<1.280000e+02> : tensor<768xf32>
+  %1 = onnx.Constant dense<64> : tensor<1xi64>
+  %2 = onnx.Constant dense<12> : tensor<1xi64>
+  %3 = onnx.Constant dense<1.280000e+02> : tensor<768x768xf32>
+  %4 = "onnx.MatMul"(%arg1, %3) : (tensor<?x?x768xf32>, tensor<768x768xf32>) -> tensor<?x?x768xf32>
+  %5 = "onnx.Add"(%4, %0) : (tensor<?x?x768xf32>, tensor<768xf32>) -> tensor<?x?x768xf32>
+  %6 = "onnx.Dim"(%5) {axis = 0 : si64} : (tensor<?x?x768xf32>) -> tensor<1xi64>
+  %7 = "onnx.Dim"(%5) {axis = 1 : si64} : (tensor<?x?x768xf32>) -> tensor<1xi64>
+  %8 = "onnx.Concat"(%6, %7, %2, %1) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
+  %9 = "onnx.Reshape"(%5, %8) {allowzero = 0 : si64} : (tensor<?x?x768xf32>, tensor<4xi64>) -> tensor<?x?x12x64xf32>
+  %10 = "onnx.Transpose"(%9) {perm = [0, 2, 1, 3]} : (tensor<?x?x12x64xf32>) -> tensor<?x12x?x64xf32>
+  %11 = "onnx.MatMul"(%arg0, %3) : (tensor<?x?x768xf32>, tensor<768x768xf32>) -> tensor<?x?x768xf32>
+  %12 = "onnx.Add"(%11, %0) : (tensor<?x?x768xf32>, tensor<768xf32>) -> tensor<?x?x768xf32>
+  %13 = "onnx.Dim"(%12) {axis = 0 : si64} : (tensor<?x?x768xf32>) -> tensor<1xi64>
+  %14 = "onnx.Dim"(%12) {axis = 1 : si64} : (tensor<?x?x768xf32>) -> tensor<1xi64>
+  %15 = "onnx.Concat"(%13, %14, %2, %1) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
+  %16 = "onnx.Reshape"(%12, %15) {allowzero = 0 : si64} : (tensor<?x?x768xf32>, tensor<4xi64>) -> tensor<?x?x12x64xf32>
+  %17 = "onnx.Transpose"(%16) {perm = [0, 2, 3, 1]} : (tensor<?x?x12x64xf32>) -> tensor<?x12x64x?xf32>
+  %18 = "onnx.MatMul"(%10, %17) : (tensor<?x12x?x64xf32>, tensor<?x12x64x?xf32>) -> tensor<?x12x?x?xf32>
+  %19 = onnx.Constant dense<0.125> : tensor<f32>
+  %20 = "onnx.Mul"(%18, %19) : (tensor<?x12x?x?xf32>, tensor<f32>) -> tensor<?x12x?x?xf32>
+  onnx.Return %20 : tensor<?x12x?x?xf32>
+
+// CHECK-LABEL: func.func @test_mul_in_attention
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x768xf32>, [[PARAM_1_:%.+]]: tensor<?x?x768xf32>) -> tensor<?x12x?x?xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<1.280000e+02> : tensor<768xf32>
+// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<64> : tensor<1xi64>
+// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<12> : tensor<1xi64>
+// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<1.280000e+02> : tensor<768x768xf32>
+// CHECK-DAG: [[VAR_4_:%.+]] = onnx.Constant dense<1.250000e-01> : tensor<f32>
+// CHECK: [[VAR_5_:%.+]] = "onnx.Mul"([[VAR_3_]], [[VAR_4_]]) : (tensor<768x768xf32>, tensor<f32>) -> tensor<768x768xf32>
+// CHECK-DAG: [[VAR_6_:%.+]] = "onnx.MatMul"([[PARAM_1_]], [[VAR_5_]]) : (tensor<?x?x768xf32>, tensor<768x768xf32>) -> tensor<?x?x768xf32>
+// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Mul"([[VAR_0_]], [[VAR_4_]]) : (tensor<768xf32>, tensor<f32>) -> tensor<768xf32>
+// CHECK: [[VAR_8_:%.+]] = "onnx.Add"([[VAR_6_]], [[VAR_7_]]) : (tensor<?x?x768xf32>, tensor<768xf32>) -> tensor<?x?x768xf32>
"onnx.Dim"([[VAR_8_]]) {axis = 0 : si64} : (tensor) -> tensor<1xi64> +// CHECK-DAG: [[VAR_10_:%.+]] = "onnx.Dim"([[VAR_8_]]) {axis = 1 : si64} : (tensor) -> tensor<1xi64> +// CHECK: [[VAR_11_:%.+]] = "onnx.Concat"([[VAR_9_]], [[VAR_10_]], [[VAR_2_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64> +// CHECK: [[VAR_12_:%.+]] = "onnx.Reshape"([[VAR_8_]], [[VAR_11_]]) {allowzero = 0 : si64} : (tensor, tensor<4xi64>) -> tensor +// CHECK-DAG: [[VAR_13_:%.+]] = "onnx.Transpose"([[VAR_12_]]) {perm = [0, 2, 1, 3]} : (tensor) -> tensor +// CHECK-DAG: [[VAR_14_:%.+]] = "onnx.MatMul"([[PARAM_0_]], [[VAR_3_]]) : (tensor, tensor<768x768xf32>) -> tensor +// CHECK: [[VAR_15_:%.+]] = "onnx.Add"([[VAR_14_]], [[VAR_0_]]) : (tensor, tensor<768xf32>) -> tensor +// CHECK-DAG: [[VAR_16_:%.+]] = "onnx.Dim"([[VAR_15_]]) {axis = 0 : si64} : (tensor) -> tensor<1xi64> +// CHECK-DAG: [[VAR_17_:%.+]] = "onnx.Dim"([[VAR_15_]]) {axis = 1 : si64} : (tensor) -> tensor<1xi64> +// CHECK: [[VAR_18_:%.+]] = "onnx.Concat"([[VAR_16_]], [[VAR_17_]], [[VAR_2_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64> +// CHECK: [[VAR_19_:%.+]] = "onnx.Reshape"([[VAR_15_]], [[VAR_18_]]) {allowzero = 0 : si64} : (tensor, tensor<4xi64>) -> tensor +// CHECK: [[VAR_20_:%.+]] = "onnx.Transpose"([[VAR_19_]]) {perm = [0, 2, 3, 1]} : (tensor) -> tensor +// CHECK: [[VAR_21_:%.+]] = "onnx.MatMul"([[VAR_13_]], [[VAR_20_]]) : (tensor, tensor) -> tensor +// CHECK: onnx.Return [[VAR_21_]] : tensor +// CHECK: } +} From 2bbf650e51b03181c383784e61fbf759de2bedc4 Mon Sep 17 00:00:00 2001 From: "Tung D. Le" Date: Tue, 28 Nov 2023 13:47:19 +0900 Subject: [PATCH 4/4] Support GEMM Signed-off-by: Tung D. Le --- src/Dialect/ONNX/Rewrite.cpp | 94 ++++++++++++++++++++++-------------- 1 file changed, 59 insertions(+), 35 deletions(-) diff --git a/src/Dialect/ONNX/Rewrite.cpp b/src/Dialect/ONNX/Rewrite.cpp index 653fdf3489..5a7a6c35d2 100644 --- a/src/Dialect/ONNX/Rewrite.cpp +++ b/src/Dialect/ONNX/Rewrite.cpp @@ -215,7 +215,7 @@ bool haveSameStaticShape(Value lhs, Value rhs) { // shape. // A and B are constants. bool matchShapeAddMatMul(Value v, Value &matA, Value &biasB, - Operation *&matmulOp, Operation *&addOp) { + Operation *&matmulOrGemmOp, Operation *&addOp, bool &isGemm) { if (v.isa()) return false; if (!hasOneUseExceptDimOp(v)) @@ -234,7 +234,22 @@ bool matchShapeAddMatMul(Value v, Value &matA, Value &biasB, if (origV.isa() || !hasOneUseExceptDimOp(origV)) return false; - // Match Add. + // Match Gemm + auto onnxGemmOp = origV.getDefiningOp(); + if (onnxGemmOp) { + if (!isDenseONNXConstant(onnxGemmOp.getB())) + return false; + if (!isNoneValue(onnxGemmOp.getC()) && + !isDenseONNXConstant(onnxGemmOp.getC())) + return false; + matmulOrGemmOp = onnxGemmOp.getOperation(); + matA = onnxGemmOp.getB(); + biasB = onnxGemmOp.getC(); + isGemm = true; + return true; + } + + // Not Gemm, match Add. auto onnxAddOp = origV.getDefiningOp(); if (!onnxAddOp) return false; @@ -256,10 +271,11 @@ bool matchShapeAddMatMul(Value v, Value &matA, Value &biasB, return false; // Passed all tests. 
-  matmulOp = onnxMatMulOp.getOperation();
+  matmulOrGemmOp = onnxMatMulOp.getOperation();
   addOp = onnxAddOp.getOperation();
   matA = rhsMatMul;
   biasB = rhsAdd;
+  isGemm = false;
 
   return true;
 }
@@ -503,44 +519,52 @@ struct PropagateConstantScalingInAttentionLayerPattern
           genericOp, "The second operand of Div/Mul is not a scalar constant");
 
     // Match lhs = shape_transform(X1*A1 + B1)
-    Value A1, B1;
-    Operation *lhsSubMatOp, *lhsAddOp;
-    bool matchLHS = matchShapeAddMatMul(lhs, A1, B1, lhsSubMatOp, lhsAddOp);
-
-    // Match rhs = shape_transform(X2*A2 + B2)
-    Value A2, B2;
-    Operation *rhsSubMatOp, *rhsAddOp;
-    bool matchRHS = matchShapeAddMatMul(rhs, A2, B2, rhsSubMatOp, rhsAddOp);
+    Value A, B;
+    Operation *matmulOrGemmOp, *addOp;
+    bool isGemm;
+    bool matched =
+        matchShapeAddMatMul(lhs, A, B, matmulOrGemmOp, addOp, isGemm);
+
+    if (!matched) {
+      // Match rhs = shape_transform(X2*A2 + B2)
+      matched = matchShapeAddMatMul(rhs, A, B, matmulOrGemmOp, addOp, isGemm);
+    }
 
-    if (!matchLHS && !matchRHS)
+    if (!matched)
       return rewriter.notifyMatchFailure(genericOp,
          "There is no constant tensor to replace the first operand "
          "of Div/Mul");
 
     // Rewrite.
-    // Only rewrite one side, so use LHS if both sides are matched.
-    if (matchLHS && matchRHS)
-      matchRHS = false;
-    auto onnxSubMatOp =
-        cast<ONNXMatMulOp>(matchLHS ? lhsSubMatOp : rhsSubMatOp);
-    auto onnxAddOp = cast<ONNXAddOp>(matchLHS ? lhsAddOp : rhsAddOp);
-    Value A = matchLHS ? A1 : A2;
-    Value B = matchLHS ? B1 : B2;
-
-    // Move K up before MatMul so that it dominates its new users.
-    K.getDefiningOp()->moveBefore(onnxSubMatOp);
-    // Update in place MatMul and Add.
-    rewriter.updateRootInPlace(onnxSubMatOp, [&] {
-      rewriter.setInsertionPoint(onnxSubMatOp);
-      onnxSubMatOp.getBMutable().assign(rewriter.create<ONNXOp>(
-          onnxSubMatOp.getLoc(), onnxSubMatOp.getB().getType(), A, K));
-    });
-    rewriter.updateRootInPlace(onnxAddOp, [&] {
-      OnnxBuilder createONNX(rewriter, onnxAddOp.getLoc());
-      rewriter.setInsertionPoint(onnxAddOp);
-      onnxAddOp.getBMutable().assign(rewriter.create<ONNXOp>(
-          onnxAddOp.getLoc(), onnxAddOp.getB().getType(), B, K));
-    });
+    // Move K up before MatMul/Gemm so that it dominates its new users.
+    K.getDefiningOp()->moveBefore(matmulOrGemmOp);
+    if (isGemm) {
+      auto onnxGemmOp = cast<ONNXGemmOp>(matmulOrGemmOp);
+      // Update in place B and C of Gemm.
+      rewriter.updateRootInPlace(onnxGemmOp, [&] {
+        rewriter.setInsertionPoint(onnxGemmOp);
+        onnxGemmOp.getBMutable().assign(rewriter.create<ONNXOp>(
+            onnxGemmOp.getLoc(), onnxGemmOp.getB().getType(), A, K));
+        if (!isNoneValue(onnxGemmOp.getC()))
+          onnxGemmOp.getCMutable().assign(rewriter.create<ONNXOp>(
+              onnxGemmOp.getLoc(), onnxGemmOp.getC().getType(), B, K));
+      });
+    } else {
+      auto onnxSubMatOp = cast<ONNXMatMulOp>(matmulOrGemmOp);
+      auto onnxAddOp = cast<ONNXAddOp>(addOp);
+      // Update in place MatMul and Add.
+      rewriter.updateRootInPlace(onnxSubMatOp, [&] {
+        rewriter.setInsertionPoint(onnxSubMatOp);
+        onnxSubMatOp.getBMutable().assign(rewriter.create<ONNXOp>(
+            onnxSubMatOp.getLoc(), onnxSubMatOp.getB().getType(), A, K));
+      });
+      rewriter.updateRootInPlace(onnxAddOp, [&] {
+        OnnxBuilder createONNX(rewriter, onnxAddOp.getLoc());
+        rewriter.setInsertionPoint(onnxAddOp);
+        onnxAddOp.getBMutable().assign(rewriter.create<ONNXOp>(
+            onnxAddOp.getLoc(), onnxAddOp.getB().getType(), B, K));
+      });
+    }
 
     // Bypass Div/Mul.
     rewriter.replaceOp(genericOp, onnxMatMulOp.getY());
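The Gemm extension works for the same reason as the MatMul+Add case: Gemm(X, B, C) computes alpha*X*B + beta*C, so dividing (or multiplying) both constant operands B and C by k scales both terms, hence the whole result, by the same factor, regardless of alpha and beta. A standalone C++ sketch with scalars standing in for the tensors (illustrative only; the real op also has transA/transB attributes, which do not affect the argument):

```
// Check Gemm(X, B, C) / k == Gemm(X, B/k, C/k), with Gemm = alpha*X*B + beta*C.
// Scalars stand in for the tensors. Standalone illustration; not onnx-mlir code.
#include <cassert>
#include <cmath>

float gemm(float x, float b, float c, float alpha = 1.0f, float beta = 1.0f) {
  return alpha * x * b + beta * c;
}

int main() {
  float x = 3.0f, b = 5.0f, c = 7.0f, k = 8.0f;
  // Folding 1/k into both constant operands leaves the result unchanged.
  assert(std::fabs(gemm(x, b, c) / k - gemm(x, b / k, c / k)) < 1e-6f);
  // It also holds for non-default alpha/beta: both terms scale by 1/k.
  assert(std::fabs(gemm(x, b, c, 0.5f, 2.0f) / k -
                   gemm(x, b / k, c / k, 0.5f, 2.0f)) < 1e-6f);
  return 0;
}
```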