A rewrite pattern to optimize constant scaling in self-attention layer #2640
Conversation
Signed-off-by: Tung D. Le <[email protected]>
Signed-off-by: Tung D. Le <[email protected]>
Can this be abstracted to any pairwise operation? There could be some attention models that multiply by an inverse sqrt instead.
Good suggestion! Yes, it is quite straightforward to support multiplication also (I don't think it is applicable to addition and subtraction). Will add that soon. Thanks!
Signed-off-by: Tung D. Le <[email protected]>
LGTM
src/Dialect/ONNX/Rewrite.cpp
Outdated
Operation *lhsSubMatOp, *lhsAddOp;
bool matchLHS = matchShapeAddMatMul(lhs, A1, B1, lhsSubMatOp, lhsAddOp);

// Match rhs = shape_transform(X2*A2 + B2)
Nit: you could test matchRHS only when matchLHS fails. Less testing, same results:
bool matchRHS = !matchLHS && matchShapeAddMatMul(rhs, A2, B2, rhsSubMatOp, rhsAddOp);
Then you can get rid of the case where both match, as that will never happen.
I see. Updated the code to check matchRHS only when matchLHS failed. Thanks!
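For illustration, a standalone sketch of the short-circuit control flow (plain C++ with a hypothetical stub matcher; the real code calls matchShapeAddMatMul on MLIR values):

```cpp
#include <cstdio>

// Hypothetical stub standing in for matchShapeAddMatMul(lhs/rhs, ...).
static bool matchPattern(bool pretendItMatches) { return pretendItMatches; }

int main() {
  bool matchLHS = matchPattern(/*pretendItMatches=*/false);
  // Only probe the RHS when the LHS did not match; with the short-circuit,
  // matchLHS and matchRHS can never both be true, so that branch disappears.
  bool matchRHS = !matchLHS && matchPattern(/*pretendItMatches=*/true);
  if (!matchLHS && !matchRHS) {
    std::puts("no operand matches shape_transform(X*A + B): rewrite does not apply");
    return 0;
  }
  std::puts(matchLHS ? "scale the lhs constants" : "scale the rhs constants");
  return 0;
}
```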
@@ -209,6 +209,61 @@ bool haveSameStaticShape(Value lhs, Value rhs) {
  return hasStaticShape(lhsT) && (getShape(lhsT) == getShape(rhsT));
}

// Match v = shape_transform(X*A + B).
Do we cover the case where instead of X*A+B we have a Gemm op?
Good idea. I will add Gemm.
I added the case for Gemm.
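For context on why Gemm fits the same pattern: with its default attributes (alpha = 1, beta = 1, no transposition), ONNX Gemm computes Y = X*A + B, i.e. exactly the MatMul + Add pair the matcher looks for, so the same constant scaling applies. A minimal standalone check of that identity (tiny hard-coded 2x2 matrices in plain C++; an illustration, not the actual onnx-mlir matcher):

```cpp
#include <array>
#include <cassert>
#include <cstdio>

using Mat2 = std::array<std::array<double, 2>, 2>;

// Y = alpha * X*A + beta * B, the ONNX Gemm formula (no transposition here).
static Mat2 gemm(const Mat2 &x, const Mat2 &a, const Mat2 &b, double alpha,
    double beta) {
  Mat2 y{};
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j) {
      for (int p = 0; p < 2; ++p)
        y[i][j] += alpha * x[i][p] * a[p][j];
      y[i][j] += beta * b[i][j];
    }
  return y;
}

int main() {
  const Mat2 X{{{1, 2}, {3, 4}}}, A{{{5, 6}, {7, 8}}}, B{{{1, 1}, {1, 1}}};
  const double k = 8.0;

  // Dividing the Gemm output by k ...
  Mat2 before = gemm(X, A, B, 1.0, 1.0);
  // ... equals running Gemm on the pre-scaled constants A/k and B/k.
  Mat2 Ak = A, Bk = B;
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j) {
      Ak[i][j] /= k;
      Bk[i][j] /= k;
    }
  Mat2 after = gemm(X, Ak, Bk, 1.0, 1.0);
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      assert(before[i][j] / k == after[i][j]);
  std::puts("Gemm(X, A/k, B/k) == Gemm(X, A, B) / k with default alpha/beta");
  return 0;
}
```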
@jenkins-droid test this please
Signed-off-by: Tung D. Le <[email protected]>
Jenkins Linux s390x Build #13552 [push] A rewrite pattern to opt... started at 02:01
Jenkins Linux ppc64le Build #12549 [push] A rewrite pattern to opt... started at 02:07
Jenkins Linux amd64 Build #13525 [push] A rewrite pattern to opt... started at 01:01
Jenkins Linux s390x Build #13552 [push] A rewrite pattern to opt... passed after 1 hr 52 min
Jenkins Linux ppc64le Build #12549 [push] A rewrite pattern to opt... passed after 2 hr 5 min
Jenkins Linux amd64 Build #13525 [push] A rewrite pattern to opt... passed after 2 hr 15 min
In the self-attention layer, the output of MatMul is scaled by a constant factor via a division/multiplication operation. This patch rewrites the division/multiplication operation so that the constant input of MatMul will be scaled instead of its output. Thus, the scaling of the constant inputs can be folded at compile time.
For example, this patch rewrites the following pattern:

shape_transform(X1*A1 + B1) * shape_transform(X2*A2 + B2) / k

into

shape_transform(X1*A1 + B1) * shape_transform(X2*(A2/k) + B2/k)

if A2, B2 and k are constants, or into

shape_transform(X1*(A1/k) + B1/k) * shape_transform(X2*A2 + B2)

if A1, B1 and k are constants,
where
- * is matrix multiplication; + and / are element-wise addition and division.
- A1, A2, B1 and B2 are constants, so that the divisions A1/k, B1/k (or A2/k, B2/k) can be folded at compile time. k is a scalar constant so that it's broadcastable to all of A1, A2, B1, B2.
- shape_transform stands for a sequence of operations that only change the shape of the input but not its numerical values, for example: Reshape, Transpose, etc.
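To make the arithmetic behind the rewrite concrete, here is a minimal standalone check in plain C++ (tiny hard-coded 2x2 matrices, shape_transform taken as the identity; an illustration of the equivalence, not the actual rewrite code):

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstdio>

using Mat2 = std::array<std::array<double, 2>, 2>;

// Plain 2x2 matrix multiplication (the '*' in the pattern).
static Mat2 matmul(const Mat2 &x, const Mat2 &a) {
  Mat2 y{};
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      for (int p = 0; p < 2; ++p)
        y[i][j] += x[i][p] * a[p][j];
  return y;
}

// Element-wise addition (the '+' in the pattern).
static Mat2 add(const Mat2 &x, const Mat2 &b) {
  Mat2 y{};
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      y[i][j] = x[i][j] + b[i][j];
  return y;
}

// Element-wise division by a scalar (the '/' in the pattern).
static Mat2 divScalar(const Mat2 &x, double k) {
  Mat2 y{};
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      y[i][j] = x[i][j] / k;
  return y;
}

int main() {
  const Mat2 X1{{{1, 2}, {3, 4}}}, A1{{{5, -6}, {7, 8}}}, B1{{{1, 0}, {0, 1}}};
  const Mat2 X2{{{2, 1}, {0, 3}}}, A2{{{4, 2}, {-1, 5}}}, B2{{{2, 2}, {2, 2}}};
  const double k = std::sqrt(64.0); // e.g. sqrt(d_k) in self-attention

  // Original pattern: (X1*A1 + B1) * (X2*A2 + B2) / k.
  const Mat2 lhs = add(matmul(X1, A1), B1);
  const Mat2 rhs = add(matmul(X2, A2), B2);
  const Mat2 before = divScalar(matmul(lhs, rhs), k);

  // Rewritten pattern: (X1*A1 + B1) * (X2*(A2/k) + B2/k).
  const Mat2 rhsScaled = add(matmul(X2, divScalar(A2, k)), divScalar(B2, k));
  const Mat2 after = matmul(lhs, rhsScaled);

  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      assert(std::fabs(before[i][j] - after[i][j]) < 1e-9);
  std::puts("scaling A2 and B2 by 1/k matches dividing the MatMul output by k");
  return 0;
}
```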