Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add alternative strategy for batched matrix multiplication #51

Merged
merged 3 commits into from
Feb 6, 2024

Commits on Feb 6, 2024

  1. Configuration menu
    Copy the full SHA
    13772c3 View commit details
    Browse the repository at this point in the history
  2. Add alternative strategy for batched matrix multiplication

    Batched matrix multiplication was handled by prepacking one or neither of
    the inputs, depending on how often each is re-used, and then performing
    one `gemm` call per matrix in the output shape.
    
    This can be inefficient the LHS input has a small number of rows. For example in
    [1], the LHS / "A" input is a row vector. In the case where the "A" input is a
    batch and the "B" input is a single matrix, the "A" input can be reshaped so a
    single gemm call can be used, with the output reshaped afterwards to restore the
    batch dimensions.
    
    Implement this alternate approach and add a simple benchmark for batched matmul.
    
    [1] #50
    robertknight committed Feb 6, 2024
    Configuration menu
    Copy the full SHA
    5e4e5fd View commit details
    Browse the repository at this point in the history
  3. Improve MatMul tests

    Refactor MatMul tests into a single table-driven test that has cases for when
    neither, one or both of the inputs is a batch. Also add tests for various
    invalid inputs.
    robertknight committed Feb 6, 2024
    Configuration menu
    Copy the full SHA
    e33f6a8 View commit details
    Browse the repository at this point in the history