A rewrite pattern to optimize constant scaling in self-attention layer #2640

Merged: 5 commits merged into onnx:main on Nov 28, 2023

Conversation

@tungld (Collaborator) commented Nov 24, 2023

In the self-attention layer, the output of MatMul is scaled by a constant factor via a division/multiplication operation. This patch rewrites that division/multiplication so that the constant inputs feeding one side of the MatMul are scaled instead of its output; the scaling of those constants can then be folded at compile time.

For example, this patch rewrites the following pattern:

shape_transform(X1 * A1 + B1) * shape_transform(X2 * A2 + B2) / k

into

shape_transform(X1 * A1 + B1) * shape_transform(X2 * A2/k + B2/k)

if A2, B2 and k are constants,

or into

shape_transform(X1 * A1/k + B1/k) * shape_transform(X2 * A2 + B2)

if A1, B1 and k are constants,

where

  • * is matrix multiplication; + and / are element-wise addition and division.
  • A1, A2, B1, B2, and k are constants, so that A1/k, B1/k, A2/k, and B2/k can
    be folded at compile time. k is a scalar constant, so it is broadcastable to
    A1, A2, B1, and B2.
  • shape_transform is a sequence of operations that change the shape of the
    input but not its numerical values, for example Reshape, Transpose, etc.
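
As a quick sanity check of this identity, here is a standalone sketch (not code from this patch; the matrix values and helper names below are made up for illustration, and shape_transform is omitted since it does not change values) that computes both forms on small 2x2 matrices and asserts they agree:

#include <array>
#include <cassert>
#include <cmath>
#include <cstdio>

using Mat2 = std::array<std::array<double, 2>, 2>;

// Plain 2x2 matrix multiplication (the "*" in the pattern above).
static Mat2 matmul(const Mat2 &a, const Mat2 &b) {
  Mat2 c{};
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      for (int l = 0; l < 2; ++l)
        c[i][j] += a[i][l] * b[l][j];
  return c;
}

// Element-wise addition (the "+" in the pattern).
static Mat2 add(const Mat2 &a, const Mat2 &b) {
  Mat2 c{};
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      c[i][j] = a[i][j] + b[i][j];
  return c;
}

// Element-wise division by a scalar (the "/" in the pattern).
static Mat2 divScalar(const Mat2 &a, double k) {
  Mat2 c{};
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      c[i][j] = a[i][j] / k;
  return c;
}

int main() {
  // Arbitrary example values: X1, X2 play the role of activations;
  // A1, B1, A2, B2 are the constants; k is the scalar scaling factor.
  Mat2 X1 = {{{1, 2}, {3, 4}}}, A1 = {{{0.5, 1}, {1, 0.5}}}, B1 = {{{1, 1}, {1, 1}}};
  Mat2 X2 = {{{2, 0}, {1, 3}}}, A2 = {{{1, 2}, {0, 1}}}, B2 = {{{0.5, 0}, {0, 0.5}}};
  double k = 8.0;

  // Original pattern: scale the output of the outer MatMul.
  Mat2 lhs = add(matmul(X1, A1), B1);
  Mat2 rhs = add(matmul(X2, A2), B2);
  Mat2 before = divScalar(matmul(lhs, rhs), k);

  // Rewritten pattern: fold the scaling into the RHS constants A2/k and B2/k.
  Mat2 rhsScaled = add(matmul(X2, divScalar(A2, k)), divScalar(B2, k));
  Mat2 after = matmul(lhs, rhsScaled);

  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      assert(std::fabs(before[i][j] - after[i][j]) < 1e-12);
  std::puts("rewrite preserves the result");
  return 0;
}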

@tehbone (Contributor) commented Nov 24, 2023

Can this be abstracted to any pairwise operation? There could be some attention models that have a multiplication by an inverse sqrt.

@tungld (Collaborator, Author) commented Nov 24, 2023

> Can this be abstracted to any pairwise operation? There could be some attention models that have a multiplication by an inverse sqrt.

Good suggestion! Yes, it is quite straightforward to also support multiplication (I don't think it is applicable to addition or subtraction). Will add that soon. Thanks!
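
Note that the multiplication form folds in exactly the same way, thanks to the linearity of MatMul. In the notation of the PR description, with s a scalar constant (for instance the precomputed inverse-sqrt factor mentioned above), the pattern

shape_transform(X1 * A1 + B1) * shape_transform(X2 * A2 + B2) * s

can be rewritten into

shape_transform(X1 * A1 + B1) * shape_transform(X2 * (A2*s) + (B2*s))

when A2, B2, and s are constants, and symmetrically for the left-hand side when A1 and B1 are constants.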

Signed-off-by: Tung D. Le <[email protected]>
@tungld changed the title from "A rewrite pattern to optimize a scalar Div in self-attention layer" to "A rewrite pattern to optimize constant scaling in self-attention layer" on Nov 27, 2023
@AlexandreEichenberger (Collaborator) left a comment

LGTM

Operation *lhsSubMatOp, *lhsAddOp;
bool matchLHS = matchShapeAddMatMul(lhs, A1, B1, lhsSubMatOp, lhsAddOp);

// Match rhs = shape_transform(X2*A2 + B2)
@AlexandreEichenberger (Collaborator) commented:
Nit: you could test matchRHS only when matchLHS fails. Less testing, same result:

bool matchRHS = !matchLHS && matchShapeAddMatMul(rhs, A2, B2, rhsSubMatOp, rhsAddOp);

Then you can get rid of the case where both match, as that will never happen.

@tungld (Collaborator, Author) replied:

I see. Updated the code to check matchRHS only when matchLHS failed. Thanks!

@@ -209,6 +209,61 @@ bool haveSameStaticShape(Value lhs, Value rhs) {
return hasStaticShape(lhsT) && (getShape(lhsT) == getShape(rhsT));
}

// Match v = shape_transform(X*A + B).
@AlexandreEichenberger (Collaborator) commented:

Do we cover the case where instead of X*A+B we have a Gemm op?

@tungld (Collaborator, Author) replied:

Good idea. I will add Gemm.

@tungld (Collaborator, Author) replied:

I added the case for Gemm.
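
For context on the Gemm case: ONNX Gemm computes alpha * X * A + beta * B (ignoring the transpose attributes), so it already fuses the X * A + B sub-pattern into a single op. The scaling can then be folded directly into Gemm's constant inputs; roughly (a sketch of the idea, not necessarily the exact code in the patch):

shape_transform(X1 * A1 + B1) * shape_transform(Gemm(X2, A2, B2)) / k

becomes

shape_transform(X1 * A1 + B1) * shape_transform(Gemm(X2, A2/k, B2/k))

when A2, B2, and k are constants; equivalently, the alpha and beta attributes could be scaled by 1/k.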

@gongsu832 (Collaborator) commented:
@jenkins-droid test this please

@tungld tungld merged commit 135cac8 into onnx:main Nov 28, 2023
8 checks passed
@jenkins-droid: Jenkins Linux s390x Build #13552 [push] A rewrite pattern to opt... started at 02:01

@jenkins-droid: Jenkins Linux ppc64le Build #12549 [push] A rewrite pattern to opt... started at 02:07

@jenkins-droid: Jenkins Linux amd64 Build #13525 [push] A rewrite pattern to opt... started at 01:01

@jenkins-droid: Jenkins Linux s390x Build #13552 [push] A rewrite pattern to opt... passed after 1 hr 52 min

@jenkins-droid: Jenkins Linux ppc64le Build #12549 [push] A rewrite pattern to opt... passed after 2 hr 5 min

@jenkins-droid: Jenkins Linux amd64 Build #13525 [push] A rewrite pattern to opt... passed after 2 hr 15 min
