Skip to content

Commit

Permalink
Add alternative strategy for batched matrix multiplication
Browse files Browse the repository at this point in the history
Batched matrix multiplication was previously handled by prepacking one of
the inputs (or neither), depending on how often each is re-used, and then
performing one `gemm` call per matrix in the output shape.

This can be inefficient if one of the matrices passed to a gemm call ends up
being small in one or both dimensions. For example in [1], the LHS / "A" input
is a vector. In the case where the "A" input is a batch and the "B" input is a
single matrix, the "A" input can be reshaped so a single gemm call can be used,
with the output reshaped afterwards to restore the batch dimensions.

In addition to the strategy, add a simple benchmark for different input shapes.

[1] #50
  • Loading branch information
robertknight committed Feb 6, 2024
1 parent 13772c3 commit 74f13d9
Showing 1 changed file with 93 additions and 6 deletions.
99 changes: 93 additions & 6 deletions src/ops/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,24 @@ impl Operator for Gemm {
}
}

/// Strategy hint controlling how a batched MatMul is executed.
///
/// This exists so that tests and benchmarks can compare the available
/// execution strategies against each other.
#[derive(Copy, Clone, Debug, PartialEq)]
enum MatmulStrategy {
    /// Pick whichever strategy best suits the input shapes.
    Auto,

    /// Run one separate GEMM call for each pair of matrices to multiply
    /// in the batch.
    #[cfg(test)]
    Batch,
}

/// Multiply the matrices in `a` and `b`, choosing the execution strategy
/// automatically based on the input shapes.
pub fn matmul(a: TensorView, b: TensorView) -> Result<Tensor, OpError> {
    let strategy = MatmulStrategy::Auto;
    matmul_impl(a, b, strategy)
}

fn matmul_impl(a: TensorView, b: TensorView, strategy: MatmulStrategy) -> Result<Tensor, OpError> {
if a.ndim() < 2 || b.ndim() < 2 {
return Err(OpError::InvalidValue("Inputs must have >= 2 dimensions"));
}
Expand All @@ -106,12 +123,31 @@ pub fn matmul(a: TensorView, b: TensorView) -> Result<Tensor, OpError> {

let a_prefix = &a.shape()[..a.ndim() - 2];
let b_prefix = &b.shape()[..b.ndim() - 2];

let num_a_matrices: usize = a_prefix.iter().product();
let num_b_matrices: usize = b_prefix.iter().product();

let out_prefix = broadcast_shapes(a_prefix, b_prefix)
.ok_or(OpError::IncompatibleInputShapes("Cannot broadcast shapes"))?;

let out_shape = &[out_prefix.as_slice(), &[a_rows, b_cols]].concat();
let mut output = Tensor::zeros(out_shape);

// A batched matrix multiplication with `[A, M, K] x [K, N]`, where `A`
// can consist of multiple dimensions, can be converted to a non-batched
// matmul by reshaping the inputs as `[A * M, K]` * `[K, N]`, and then
// reshaping the `[A * M, N]` output to `[A, M, N]`.
//
// The upside is that one larger matmul is likely to be more efficient than
// `A` smaller matmuls. This is especially true if `M` is small (eg. 1).
if strategy == MatmulStrategy::Auto && a.ndim() > 2 && b.ndim() == 2 {
// nb. We assume `a` is likely already contiguous, so this will be cheap.
let a_contig = a.to_contiguous();
let a_matrix = a_contig.reshaped([num_a_matrices * a_rows, a_cols].as_slice());
let mut output = matmul(a_matrix, b.clone())?;
output.reshape(out_shape);
return Ok(output);
}

let mut output = Tensor::zeros(out_shape);
if output.is_empty() {
return Ok(output);
}
Expand All @@ -128,9 +164,6 @@ pub fn matmul(a: TensorView, b: TensorView) -> Result<Tensor, OpError> {
.unwrap()
.chunks_mut(out_row_stride * a_rows);

let num_a_matrices: usize = a_prefix.iter().product();
let num_b_matrices: usize = b_prefix.iter().product();

let gemm = GemmExecutor::new();

// Prepack re-used inputs to amortize packing cost.
Expand Down Expand Up @@ -199,7 +232,9 @@ mod tests {
use rten_tensor::Tensor;

use crate::gemm::gemm;
use crate::ops::matmul::{gemm_op, matmul, OpError};
use crate::test_util::run_bench;

use super::{gemm_op, matmul, matmul_impl, MatmulStrategy, OpError};

fn gemm_tensors(c: &mut Tensor, a: &Tensor, b: &Tensor, alpha: f32, beta: f32) {
c.make_contiguous();
Expand Down Expand Up @@ -365,4 +400,56 @@ mod tests {

Ok(())
}

#[test]
#[ignore]
fn bench_matmul() {
    // One benchmark configuration: the LHS is `[a_batch, a_rows, a_cols]`
    // and the RHS is a single `[a_cols, b_cols]` matrix.
    struct Case {
        a_batch: usize,
        a_rows: usize,
        a_cols: usize,
        b_cols: usize,
    }

    let a_cols = 512;
    let b_cols = 1536;

    // Cross product of batch sizes and LHS row counts.
    let cases: Vec<Case> = [1, 10, 128, 256, 512, 1024]
        .iter()
        .flat_map(|&a_batch| {
            [1, 16, 32, 64].iter().map(move |&a_rows| Case {
                a_batch,
                a_rows,
                a_cols,
                b_cols,
            })
        })
        .collect();

    for Case {
        a_batch,
        a_rows,
        a_cols,
        b_cols,
    } in cases
    {
        // Fixed seed so every strategy sees identical inputs.
        let mut rng = XorShiftRng::new(1234);
        let a = Tensor::rand(&[a_batch, a_rows, a_cols], &mut rng);
        let b = Tensor::rand(&[a_cols, b_cols], &mut rng);

        // Time one strategy on the shared inputs.
        let run_trial = |strategy| {
            let trials = 10;
            let desc = format!(
                "matmul [{a_batch},{a_rows},{a_cols}] x [{a_cols},{b_cols}], strategy={strategy:?}",
            );
            run_bench(trials, &desc, || {
                matmul_impl(a.view(), b.view(), strategy).unwrap();
            });
        };

        run_trial(MatmulStrategy::Batch);
        run_trial(MatmulStrategy::Auto);
        println!();
    }
}
}

0 comments on commit 74f13d9

Please sign in to comment.