A rewrite pattern to optimize constant scaling in self-attention layer #2640

Merged: 5 commits, Nov 28, 2023
11 changes: 11 additions & 0 deletions src/Dialect/ONNX/ONNXOps/OpHelper.cpp
@@ -420,6 +420,17 @@ bool hasShapeAndRank(Operation *op) {
return true;
}

/// Test if a value has exactly one use, ignoring ONNXDimOp users.
bool hasOneUseExceptDimOp(Value val) {
int64_t numOfUsersExceptDim = 0;
for (auto user : val.getUsers()) {
if (isa<ONNXDimOp>(user))
continue;
numOfUsersExceptDim++;
}
return (numOfUsersExceptDim == 1);
}
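
// A small illustration of the helper above (hypothetical IR, not part of this
// change): if a value %v feeds one onnx.Reshape and two onnx.Dim ops,
// hasOneUseExceptDimOp(%v) still returns true, because the Dim users only
// query the dynamic shape and do not consume the data being rewritten.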

//===----------------------------------------------------------------------===//
// Support for rewrite patterns.
//===----------------------------------------------------------------------===//
3 changes: 3 additions & 0 deletions src/Dialect/ONNX/ONNXOps/OpHelper.hpp
@@ -209,6 +209,9 @@ bool isScalarConstantTensor(mlir::Value v);
bool hasShapeAndRank(mlir::Value val);
bool hasShapeAndRank(mlir::Operation *op);

/// Test if a value has exactly one use, ignoring ONNXDimOp users.
bool hasOneUseExceptDimOp(mlir::Value val);

//===----------------------------------------------------------------------===//
// Support for Rewrite.
//===----------------------------------------------------------------------===//
181 changes: 181 additions & 0 deletions src/Dialect/ONNX/Rewrite.cpp
@@ -209,6 +209,77 @@ bool haveSameStaticShape(Value lhs, Value rhs) {
return hasStaticShape(lhsT) && (getShape(lhsT) == getShape(rhsT));
}

[Review thread on this matcher]
Collaborator: Do we cover the case where instead of X*A+B we have a Gemm op?
Author: Good idea. I will add Gemm.
Author: I added the case for Gemm.

// Match v = shape_transform(X*A + B).
// shape_transform is a sequence of operations such as Reshape, Transpose,
// Squeeze, Unsqueeze, etc. that change the shape of the data but not its
// numerical values.
// A and B are constants.
bool matchShapeAddMatMul(Value v, Value &matA, Value &biasB,
Operation *&matmulOrGemmOp, Operation *&addOp, bool &isGemm) {
if (v.isa<BlockArgument>())
return false;
if (!hasOneUseExceptDimOp(v))
return false;
Value origV = v;
// Match a sequence of shape operations. Each shape operation has only one
// use.
while (auto defOp = origV.getDefiningOp()) {
if (!isa<ONNXReshapeOp, ONNXTransposeOp, ONNXSqueezeOp, ONNXUnsqueezeOp>(
defOp))
break;
origV = defOp->getOperands()[0];
if (!hasOneUseExceptDimOp(origV))
break;
}
if (origV.isa<BlockArgument>() || !hasOneUseExceptDimOp(origV))
return false;

// Match Gemm
auto onnxGemmOp = origV.getDefiningOp<ONNXGemmOp>();
if (onnxGemmOp) {
if (!isDenseONNXConstant(onnxGemmOp.getB()))
return false;
if (!isNoneValue(onnxGemmOp.getC()) &&
!isDenseONNXConstant(onnxGemmOp.getC()))
return false;
matmulOrGemmOp = onnxGemmOp.getOperation();
matA = onnxGemmOp.getB();
biasB = onnxGemmOp.getC();
isGemm = true;
return true;
}

// Not Gemm, match Add.
auto onnxAddOp = origV.getDefiningOp<ONNXAddOp>();
if (!onnxAddOp)
return false;
Value lhsAdd = onnxAddOp.getA();
Value rhsAdd = onnxAddOp.getB();

// The LHS of Add must be the only non-Dim use of the MatMul result.
if (!hasOneUseExceptDimOp(lhsAdd))
return false;
auto onnxMatMulOp = lhsAdd.getDefiningOp<ONNXMatMulOp>();
if (!onnxMatMulOp)
return false;
Value rhsMatMul = onnxMatMulOp.getB();
if (!isDenseONNXConstant(rhsMatMul))
return false;

// RHS of Add is a constant.
if (!isDenseONNXConstant(rhsAdd))
return false;

// Passed all tests.
matmulOrGemmOp = onnxMatMulOp.getOperation();
addOp = onnxAddOp.getOperation();
matA = rhsMatMul;
biasB = rhsAdd;
isGemm = false;

return true;
}
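
// As an illustration (a hypothetical subgraph, not taken from this patch), the
// matcher above accepts values produced like this, where A and B are dense
// constants and every intermediate result has a single non-Dim use:
//   %0 = "onnx.MatMul"(%X, %A)
//   %1 = "onnx.Add"(%0, %B)
//   %2 = "onnx.Reshape"(%1, %shape)
//   %v = "onnx.Transpose"(%2)
// or, in the Gemm case, a single onnx.Gemm with a constant B input (and an
// optional constant C) followed by the same shape-only operations.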

} // namespace onnx_mlir

// =============================================================================
@@ -395,6 +466,112 @@ class PropagateReshapeThroughBinaryOpPattern
};
};

// This rewriting is to optimize the scalar Div/Mul in self-attention layers.
// In particular, it rewrites the following pattern:
// ```
// shape_transform(X1 * A1 + B1) * shape_transform(X2 * A2 + B2) / k
// ```
//
// into
// ```
// shape_transform(X1 * A1 + B1) * shape_transform(X2 * A2/k + B2/k)
// ```
// if A2, B2 and k are constants,
//
// or into
// ```
// shape_transform(X1 * A1/k + B1/k) * shape_transform(X2 * A2 + B2)
// ```
// if A1, B1 and k are constants,
//
// where
// - * is matrix multiplication; + and / are element-wise addition and division
// - A1, A2, B1, B2, and k are constants so that A1/k, B1/k, A2/k and B2/k can
// be folded. k is a scalar constant so that it's broadcastable to all A1, A2,
// B1, B2.
// - shape_transform includes a sequence of operations that change the data
// shape of the input but not numerical values, for example: Reshape,
// Transpose, etc.
//
// This pattern supports both division and multiplication by k.
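//
// As a quick sanity check of the algebra (illustrative numbers, not taken from
// the tests below): with X = [[1, 2]], A = [[2, 0], [0, 2]], B = [4, 4], k = 2:
//   (X*A + B) / k = ([[2, 4]] + [4, 4]) / 2 = [[6, 8]] / 2 = [[3, 4]]
//   X*(A/k) + B/k = [[1, 2]] + [2, 2] = [[3, 4]]
// Dividing A and B by the scalar k gives the same result as dividing X*A + B,
// the scalar also passes through the outer matrix product ((P*Q)/k = P*(Q/k)),
// and the shape-only operations commute with element-wise scaling. Together,
// these are what make folding k into the constants valid.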
template <typename ONNXOp>
struct PropagateConstantScalingInAttentionLayerPattern
: public OpRewritePattern<ONNXOp> {
using OpRewritePattern<ONNXOp>::OpRewritePattern;

LogicalResult matchAndRewrite(
ONNXOp omOp, PatternRewriter &rewriter) const final {
Operation *genericOp = omOp.getOperation();
Value lhsOMOp = omOp.getA();
Value K = omOp.getB();

// Match (lhs * rhs) / K.
// The first operand of Div/Mul is produced by MatMulOp.
auto onnxMatMulOp = lhsOMOp.getDefiningOp<ONNXMatMulOp>();
if (!onnxMatMulOp)
return rewriter.notifyMatchFailure(genericOp,
"The first operand of Div/Mul is not produced by MatMulOp");
Value lhs = onnxMatMulOp.getA();
Value rhs = onnxMatMulOp.getB();
// The second operand of Div/Mul is a scalar constant.
if (!isScalarConstantTensor(K))
return rewriter.notifyMatchFailure(
genericOp, "The second operand of Div/Mul is not a scalar constant");

// Match lhs = shape_transform(X1*A1 + B1)
Value A, B;
Operation *matmulOrGemmOp, *addOp;
bool isGemm;
bool matched =
matchShapeAddMatMul(lhs, A, B, matmulOrGemmOp, addOp, isGemm);

if (!matched) {
// Match rhs = shape_transform(X2*A2 + B2)
matched = matchShapeAddMatMul(rhs, A, B, matmulOrGemmOp, addOp, isGemm);
}

if (!matched)
return rewriter.notifyMatchFailure(genericOp,
"There is no constant tensor to replace the first operand "
"of Div/Mul");

// Rewrite.
// Move K's defining op before MatMul/Gemm so that K dominates its new uses.
K.getDefiningOp()->moveBefore(matmulOrGemmOp);
if (isGemm) {
auto onnxGemmOp = cast<ONNXGemmOp>(matmulOrGemmOp);
// Update in place B and C of Gemm.
rewriter.updateRootInPlace(onnxGemmOp, [&] {
rewriter.setInsertionPoint(onnxGemmOp);
onnxGemmOp.getBMutable().assign(rewriter.create<ONNXOp>(
onnxGemmOp.getLoc(), onnxGemmOp.getB().getType(), A, K));
if (!isNoneValue(onnxGemmOp.getC()))
onnxGemmOp.getCMutable().assign(rewriter.create<ONNXOp>(
onnxGemmOp.getLoc(), onnxGemmOp.getC().getType(), B, K));
});
} else {
auto onnxSubMatOp = cast<ONNXMatMulOp>(matmulOrGemmOp);
auto onnxAddOp = cast<ONNXAddOp>(addOp);
// Update in place MatMul and Add.
rewriter.updateRootInPlace(onnxSubMatOp, [&] {
rewriter.setInsertionPoint(onnxSubMatOp);
onnxSubMatOp.getBMutable().assign(rewriter.create<ONNXOp>(
onnxSubMatOp.getLoc(), onnxSubMatOp.getB().getType(), A, K));
});
rewriter.updateRootInPlace(onnxAddOp, [&] {
OnnxBuilder createONNX(rewriter, onnxAddOp.getLoc());
rewriter.setInsertionPoint(onnxAddOp);
onnxAddOp.getBMutable().assign(rewriter.create<ONNXOp>(
onnxAddOp.getLoc(), onnxAddOp.getB().getType(), B, K));
});
}

// Bypass Div/Mul.
rewriter.replaceOp(genericOp, onnxMatMulOp.getY());
return success();
}
};

// =============================================================================
// Rewrite pattern for Resize (not handled in Rewrite.td).
// =============================================================================
@@ -1379,6 +1556,8 @@ void ONNXDivOp::getCanonicalizationPatterns(
result.insert<BinaryOpBroadcastAxisPattern<ONNXDivOp>>(context);
result.insert<PropagateScalarConstantExpandPattern<ONNXDivOp>>(context);
result.insert<PropagateReshapeThroughBinaryOpPattern<ONNXDivOp>>(context);
result.insert<PropagateConstantScalingInAttentionLayerPattern<ONNXDivOp>>(
context);
}

/// on the ONNXDropoutOp.
@@ -1464,6 +1643,8 @@ void ONNXMulOp::getCanonicalizationPatterns(
results.insert<BinaryOpBroadcastAxisPattern<ONNXMulOp>>(context);
results.insert<PropagateScalarConstantExpandPattern<ONNXMulOp>>(context);
results.insert<PropagateReshapeThroughBinaryOpPattern<ONNXMulOp>>(context);
results.insert<PropagateConstantScalingInAttentionLayerPattern<ONNXMulOp>>(
context);
}

/// on the ONNXOrOp.
110 changes: 110 additions & 0 deletions test/mlir/onnx/onnx_canonicalization.mlir
@@ -1620,3 +1620,113 @@ func.func @test_not_replace_sub_by_expand_two_expands(%arg0: tensor<?xi64>) -> t
// CHECK: return [[VAR_7_]] : tensor<2x?xf32>
// CHECK: }
}

// -----
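
// COM: The CHECK lines below are FileCheck patterns; this test file is run
// COM: through onnx-mlir-opt with the canonicalization pass enabled (the exact
// COM: RUN command sits at the top of the file and is not shown in this diff).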

// COM: Optimize the scalar div in self-attention layer.
func.func @test_div_in_attention(%arg0: tensor<?x?x768xf32>, %arg1: tensor<?x?x768xf32>) -> tensor<?x12x?x?xf32> {
%0 = onnx.Constant dense<1.280000e+02> : tensor<768xf32>
%1 = onnx.Constant dense<64> : tensor<1xi64>
%2 = onnx.Constant dense<12> : tensor<1xi64>
%3 = onnx.Constant dense<1.280000e+02> : tensor<768x768xf32>
%4 = "onnx.MatMul"(%arg1, %3) : (tensor<?x?x768xf32>, tensor<768x768xf32>) -> tensor<?x?x768xf32>
%5 = "onnx.Add"(%4, %0) : (tensor<?x?x768xf32>, tensor<768xf32>) -> tensor<?x?x768xf32>
%6 = "onnx.Dim"(%5) {axis = 0 : si64} : (tensor<?x?x768xf32>) -> tensor<1xi64>
%7 = "onnx.Dim"(%5) {axis = 1 : si64} : (tensor<?x?x768xf32>) -> tensor<1xi64>
%8 = "onnx.Concat"(%6, %7, %2, %1) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
%9 = "onnx.Reshape"(%5, %8) {allowzero = 0 : si64} : (tensor<?x?x768xf32>, tensor<4xi64>) -> tensor<?x?x12x64xf32>
%10 = "onnx.Transpose"(%9) {perm = [0, 2, 1, 3]} : (tensor<?x?x12x64xf32>) -> tensor<?x12x?x64xf32>
%11 = "onnx.MatMul"(%arg0, %3) : (tensor<?x?x768xf32>, tensor<768x768xf32>) -> tensor<?x?x768xf32>
%12 = "onnx.Add"(%11, %0) : (tensor<?x?x768xf32>, tensor<768xf32>) -> tensor<?x?x768xf32>
%13 = "onnx.Dim"(%12) {axis = 0 : si64} : (tensor<?x?x768xf32>) -> tensor<1xi64>
%14 = "onnx.Dim"(%12) {axis = 1 : si64} : (tensor<?x?x768xf32>) -> tensor<1xi64>
%15 = "onnx.Concat"(%13, %14, %2, %1) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
%16 = "onnx.Reshape"(%12, %15) {allowzero = 0 : si64} : (tensor<?x?x768xf32>, tensor<4xi64>) -> tensor<?x?x12x64xf32>
%17 = "onnx.Transpose"(%16) {perm = [0, 2, 3, 1]} : (tensor<?x?x12x64xf32>) -> tensor<?x12x64x?xf32>
%18 = "onnx.MatMul"(%10, %17) : (tensor<?x12x?x64xf32>, tensor<?x12x64x?xf32>) -> tensor<?x12x?x?xf32>
%19 = onnx.Constant dense<8.000000e+00> : tensor<f32>
%20 = "onnx.Div"(%18, %19) : (tensor<?x12x?x?xf32>, tensor<f32>) -> tensor<?x12x?x?xf32>
onnx.Return %20 : tensor<?x12x?x?xf32>

// CHECK-LABEL: func.func @test_div_in_attention
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x768xf32>, [[PARAM_1_:%.+]]: tensor<?x?x768xf32>) -> tensor<?x12x?x?xf32> {
// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<1.280000e+02> : tensor<768xf32>
// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<64> : tensor<1xi64>
// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<12> : tensor<1xi64>
// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<1.280000e+02> : tensor<768x768xf32>
// CHECK-DAG: [[VAR_4_:%.+]] = onnx.Constant dense<8.000000e+00> : tensor<f32>
// CHECK: [[VAR_5_:%.+]] = "onnx.Div"([[VAR_3_]], [[VAR_4_]]) : (tensor<768x768xf32>, tensor<f32>) -> tensor<768x768xf32>
// CHECK-DAG: [[VAR_6_:%.+]] = "onnx.MatMul"([[PARAM_1_]], [[VAR_5_]]) : (tensor<?x?x768xf32>, tensor<768x768xf32>) -> tensor<?x?x768xf32>
// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Div"([[VAR_0_]], [[VAR_4_]]) : (tensor<768xf32>, tensor<f32>) -> tensor<768xf32>
// CHECK: [[VAR_8_:%.+]] = "onnx.Add"([[VAR_6_]], [[VAR_7_]]) : (tensor<?x?x768xf32>, tensor<768xf32>) -> tensor<?x?x768xf32>
// CHECK-DAG: [[VAR_9_:%.+]] = "onnx.Dim"([[VAR_8_]]) {axis = 0 : si64} : (tensor<?x?x768xf32>) -> tensor<1xi64>
// CHECK-DAG: [[VAR_10_:%.+]] = "onnx.Dim"([[VAR_8_]]) {axis = 1 : si64} : (tensor<?x?x768xf32>) -> tensor<1xi64>
// CHECK: [[VAR_11_:%.+]] = "onnx.Concat"([[VAR_9_]], [[VAR_10_]], [[VAR_2_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
// CHECK: [[VAR_12_:%.+]] = "onnx.Reshape"([[VAR_8_]], [[VAR_11_]]) {allowzero = 0 : si64} : (tensor<?x?x768xf32>, tensor<4xi64>) -> tensor<?x?x12x64xf32>
// CHECK-DAG: [[VAR_13_:%.+]] = "onnx.Transpose"([[VAR_12_]]) {perm = [0, 2, 1, 3]} : (tensor<?x?x12x64xf32>) -> tensor<?x12x?x64xf32>
// CHECK-DAG: [[VAR_14_:%.+]] = "onnx.MatMul"([[PARAM_0_]], [[VAR_3_]]) : (tensor<?x?x768xf32>, tensor<768x768xf32>) -> tensor<?x?x768xf32>
// CHECK: [[VAR_15_:%.+]] = "onnx.Add"([[VAR_14_]], [[VAR_0_]]) : (tensor<?x?x768xf32>, tensor<768xf32>) -> tensor<?x?x768xf32>
// CHECK-DAG: [[VAR_16_:%.+]] = "onnx.Dim"([[VAR_15_]]) {axis = 0 : si64} : (tensor<?x?x768xf32>) -> tensor<1xi64>
// CHECK-DAG: [[VAR_17_:%.+]] = "onnx.Dim"([[VAR_15_]]) {axis = 1 : si64} : (tensor<?x?x768xf32>) -> tensor<1xi64>
// CHECK: [[VAR_18_:%.+]] = "onnx.Concat"([[VAR_16_]], [[VAR_17_]], [[VAR_2_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
// CHECK: [[VAR_19_:%.+]] = "onnx.Reshape"([[VAR_15_]], [[VAR_18_]]) {allowzero = 0 : si64} : (tensor<?x?x768xf32>, tensor<4xi64>) -> tensor<?x?x12x64xf32>
// CHECK: [[VAR_20_:%.+]] = "onnx.Transpose"([[VAR_19_]]) {perm = [0, 2, 3, 1]} : (tensor<?x?x12x64xf32>) -> tensor<?x12x64x?xf32>
// CHECK: [[VAR_21_:%.+]] = "onnx.MatMul"([[VAR_13_]], [[VAR_20_]]) : (tensor<?x12x?x64xf32>, tensor<?x12x64x?xf32>) -> tensor<?x12x?x?xf32>
// CHECK: onnx.Return [[VAR_21_]] : tensor<?x12x?x?xf32>
// CHECK: }
}

// -----

// COM: Optimize the scalar multiplication in self-attention layer.
func.func @test_mul_in_attention(%arg0: tensor<?x?x768xf32>, %arg1: tensor<?x?x768xf32>) -> tensor<?x12x?x?xf32> {
%0 = onnx.Constant dense<1.280000e+02> : tensor<768xf32>
%1 = onnx.Constant dense<64> : tensor<1xi64>
%2 = onnx.Constant dense<12> : tensor<1xi64>
%3 = onnx.Constant dense<1.280000e+02> : tensor<768x768xf32>
%4 = "onnx.MatMul"(%arg1, %3) : (tensor<?x?x768xf32>, tensor<768x768xf32>) -> tensor<?x?x768xf32>
%5 = "onnx.Add"(%4, %0) : (tensor<?x?x768xf32>, tensor<768xf32>) -> tensor<?x?x768xf32>
%6 = "onnx.Dim"(%5) {axis = 0 : si64} : (tensor<?x?x768xf32>) -> tensor<1xi64>
%7 = "onnx.Dim"(%5) {axis = 1 : si64} : (tensor<?x?x768xf32>) -> tensor<1xi64>
%8 = "onnx.Concat"(%6, %7, %2, %1) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
%9 = "onnx.Reshape"(%5, %8) {allowzero = 0 : si64} : (tensor<?x?x768xf32>, tensor<4xi64>) -> tensor<?x?x12x64xf32>
%10 = "onnx.Transpose"(%9) {perm = [0, 2, 1, 3]} : (tensor<?x?x12x64xf32>) -> tensor<?x12x?x64xf32>
%11 = "onnx.MatMul"(%arg0, %3) : (tensor<?x?x768xf32>, tensor<768x768xf32>) -> tensor<?x?x768xf32>
%12 = "onnx.Add"(%11, %0) : (tensor<?x?x768xf32>, tensor<768xf32>) -> tensor<?x?x768xf32>
%13 = "onnx.Dim"(%12) {axis = 0 : si64} : (tensor<?x?x768xf32>) -> tensor<1xi64>
%14 = "onnx.Dim"(%12) {axis = 1 : si64} : (tensor<?x?x768xf32>) -> tensor<1xi64>
%15 = "onnx.Concat"(%13, %14, %2, %1) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
%16 = "onnx.Reshape"(%12, %15) {allowzero = 0 : si64} : (tensor<?x?x768xf32>, tensor<4xi64>) -> tensor<?x?x12x64xf32>
%17 = "onnx.Transpose"(%16) {perm = [0, 2, 3, 1]} : (tensor<?x?x12x64xf32>) -> tensor<?x12x64x?xf32>
%18 = "onnx.MatMul"(%10, %17) : (tensor<?x12x?x64xf32>, tensor<?x12x64x?xf32>) -> tensor<?x12x?x?xf32>
%19 = onnx.Constant dense<0.125> : tensor<f32>
%20 = "onnx.Mul"(%18, %19) : (tensor<?x12x?x?xf32>, tensor<f32>) -> tensor<?x12x?x?xf32>
onnx.Return %20 : tensor<?x12x?x?xf32>

// CHECK-LABEL: func.func @test_mul_in_attention
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x768xf32>, [[PARAM_1_:%.+]]: tensor<?x?x768xf32>) -> tensor<?x12x?x?xf32> {
// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<1.280000e+02> : tensor<768xf32>
// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<64> : tensor<1xi64>
// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<12> : tensor<1xi64>
// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<1.280000e+02> : tensor<768x768xf32>
// CHECK-DAG: [[VAR_4_:%.+]] = onnx.Constant dense<1.250000e-01> : tensor<f32>
// CHECK: [[VAR_5_:%.+]] = "onnx.Mul"([[VAR_3_]], [[VAR_4_]]) : (tensor<768x768xf32>, tensor<f32>) -> tensor<768x768xf32>
// CHECK-DAG: [[VAR_6_:%.+]] = "onnx.MatMul"([[PARAM_1_]], [[VAR_5_]]) : (tensor<?x?x768xf32>, tensor<768x768xf32>) -> tensor<?x?x768xf32>
// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Mul"([[VAR_0_]], [[VAR_4_]]) : (tensor<768xf32>, tensor<f32>) -> tensor<768xf32>
// CHECK: [[VAR_8_:%.+]] = "onnx.Add"([[VAR_6_]], [[VAR_7_]]) : (tensor<?x?x768xf32>, tensor<768xf32>) -> tensor<?x?x768xf32>
// CHECK-DAG: [[VAR_9_:%.+]] = "onnx.Dim"([[VAR_8_]]) {axis = 0 : si64} : (tensor<?x?x768xf32>) -> tensor<1xi64>
// CHECK-DAG: [[VAR_10_:%.+]] = "onnx.Dim"([[VAR_8_]]) {axis = 1 : si64} : (tensor<?x?x768xf32>) -> tensor<1xi64>
// CHECK: [[VAR_11_:%.+]] = "onnx.Concat"([[VAR_9_]], [[VAR_10_]], [[VAR_2_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
// CHECK: [[VAR_12_:%.+]] = "onnx.Reshape"([[VAR_8_]], [[VAR_11_]]) {allowzero = 0 : si64} : (tensor<?x?x768xf32>, tensor<4xi64>) -> tensor<?x?x12x64xf32>
// CHECK-DAG: [[VAR_13_:%.+]] = "onnx.Transpose"([[VAR_12_]]) {perm = [0, 2, 1, 3]} : (tensor<?x?x12x64xf32>) -> tensor<?x12x?x64xf32>
// CHECK-DAG: [[VAR_14_:%.+]] = "onnx.MatMul"([[PARAM_0_]], [[VAR_3_]]) : (tensor<?x?x768xf32>, tensor<768x768xf32>) -> tensor<?x?x768xf32>
// CHECK: [[VAR_15_:%.+]] = "onnx.Add"([[VAR_14_]], [[VAR_0_]]) : (tensor<?x?x768xf32>, tensor<768xf32>) -> tensor<?x?x768xf32>
// CHECK-DAG: [[VAR_16_:%.+]] = "onnx.Dim"([[VAR_15_]]) {axis = 0 : si64} : (tensor<?x?x768xf32>) -> tensor<1xi64>
// CHECK-DAG: [[VAR_17_:%.+]] = "onnx.Dim"([[VAR_15_]]) {axis = 1 : si64} : (tensor<?x?x768xf32>) -> tensor<1xi64>
// CHECK: [[VAR_18_:%.+]] = "onnx.Concat"([[VAR_16_]], [[VAR_17_]], [[VAR_2_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
// CHECK: [[VAR_19_:%.+]] = "onnx.Reshape"([[VAR_15_]], [[VAR_18_]]) {allowzero = 0 : si64} : (tensor<?x?x768xf32>, tensor<4xi64>) -> tensor<?x?x12x64xf32>
// CHECK: [[VAR_20_:%.+]] = "onnx.Transpose"([[VAR_19_]]) {perm = [0, 2, 3, 1]} : (tensor<?x?x12x64xf32>) -> tensor<?x12x64x?xf32>
// CHECK: [[VAR_21_:%.+]] = "onnx.MatMul"([[VAR_13_]], [[VAR_20_]]) : (tensor<?x12x?x64xf32>, tensor<?x12x64x?xf32>) -> tensor<?x12x?x?xf32>
// CHECK: onnx.Return [[VAR_21_]] : tensor<?x12x?x?xf32>
// CHECK: }
}