
Add alternative strategy for batched matrix multiplication #51

Merged
merged 3 commits into main from optimize-batched-matmul
Feb 6, 2024

Conversation


@robertknight robertknight commented Feb 4, 2024

Previously, batched matrix multiplication was handled by prepacking one or neither of the inputs, depending on how often each is re-used, and then performing one gemm call per matrix in the output shape. This is inefficient if the A input has only a small number of rows (as in #50). This PR implements a new strategy for the MatMul operator when A is a batch and B is a single matrix: the inputs are reshaped so that a single higher-arithmetic-intensity gemm call replaces many low-arithmetic-intensity ones. The output is then reshaped to restore the batch dimensions.

Testing with the benchmark added here and slight variations, the new method is a big improvement when M <= 8, a modest win for M ~ 8-24, and roughly even or a very slight win beyond that. The AVX kernel has MR=6, so this is as expected.

See #50
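The reshaping strategy described above can be sketched with NumPy (the shapes below are illustrative, not taken from the PR or its benchmark):

```python
import numpy as np

# Hypothetical shapes: a batch of 16 "A" matrices with few rows (M=4),
# all multiplied by one shared "B" matrix.
batch, M, K, N = 16, 4, 32, 64
rng = np.random.default_rng(0)
a = rng.standard_normal((batch, M, K))
b = rng.standard_normal((K, N))

# Old strategy: one low-arithmetic-intensity gemm per batch entry.
per_batch = np.stack([a[i] @ b for i in range(batch)])

# New strategy: collapse the batch and row dims of A into one tall
# matrix, perform a single gemm, then restore the batch dimensions.
single = (a.reshape(batch * M, K) @ b).reshape(batch, M, N)

# Both strategies compute the same result.
assert np.allclose(per_batch, single)
```

The reshape is free here because A's batch and row dimensions are contiguous, so collapsing them does not copy data; only the gemm call granularity changes.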

@robertknight robertknight marked this pull request as ready for review February 6, 2024 20:22
Batched matrix multiplication was handled by prepacking one or neither of
the inputs, depending on how often each is re-used, and then performing
one `gemm` call per matrix in the output shape.

This can be inefficient when the LHS input has a small number of rows. For example in
[1], the LHS / "A" input is a row vector. In the case where the "A" input is a
batch and the "B" input is a single matrix, the "A" input can be reshaped so a
single gemm call can be used, with the output reshaped afterwards to restore the
batch dimensions.

Implement this alternate approach and add a simple benchmark for batched matmul.

[1] #50
Refactor MatMul tests into a single table-driven test with cases where
neither, one, or both of the inputs are a batch. Also add tests for various
invalid inputs.
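A table-driven layout for these cases might look like the following Python sketch, using NumPy's broadcasting matmul as the reference implementation (case names and shapes are illustrative, not from the PR):

```python
import numpy as np

rng = np.random.default_rng(42)

# Each case: (shape_a, shape_b, expected_output_shape), covering the
# neither / one / both-batched combinations.
cases = [
    ((4, 5), (5, 6), (4, 6)),           # neither input is a batch
    ((3, 4, 5), (5, 6), (3, 4, 6)),     # A is a batch, B a single matrix
    ((4, 5), (3, 5, 6), (3, 4, 6)),     # B is a batch, A a single matrix
    ((3, 4, 5), (3, 5, 6), (3, 4, 6)),  # both inputs are batches
]

for shape_a, shape_b, expected in cases:
    a = rng.standard_normal(shape_a)
    b = rng.standard_normal(shape_b)
    out = np.matmul(a, b)  # broadcasting matmul as the reference
    assert out.shape == expected, (shape_a, shape_b, out.shape)
```

Driving all combinations through one table keeps the batched and non-batched paths covered by the same assertions, so a new strategy for one combination cannot silently regress another.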
@robertknight robertknight merged commit 9047942 into main Feb 6, 2024
2 checks passed
@robertknight robertknight deleted the optimize-batched-matmul branch February 6, 2024 20:55