Skip to content

Commit

Permalink
Improve MatMul tests
Browse files Browse the repository at this point in the history
Refactor MatMul tests into a single table-driven test that has cases for when
neither, one, or both of the inputs are batches. Also add tests for various
invalid inputs.
  • Loading branch information
robertknight committed Feb 6, 2024
1 parent 74f13d9 commit abeda07
Showing 1 changed file with 125 additions and 48 deletions.
173 changes: 125 additions & 48 deletions src/ops/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ mod tests {
use rten_tensor::prelude::*;
use rten_tensor::rng::XorShiftRng;
use rten_tensor::test_util::expect_equal;
use rten_tensor::Tensor;
use rten_tensor::{Tensor, TensorView, TensorViewMut};

use crate::gemm::gemm;
use crate::test_util::run_bench;
Expand All @@ -249,6 +249,34 @@ mod tests {
)
}

/// Reference batched matrix multiplication used to produce expected results
/// in tests.
///
/// All but the last two dims of `c` are treated as batch dims. `a` and `b` are
/// each broadcast to those batch dims followed by their own last two dims.
/// Each matrix of `a` is then multiplied by the corresponding matrix of `b`
/// using `gemm`, with the result written to the corresponding matrix of `c`.
fn reference_matmul(mut c: TensorViewMut, a: TensorView, b: TensorView) {
    // Index at which the trailing (matrix) dims of each input begin.
    let a_matrix_start = a.ndim() - 2;
    let b_matrix_start = b.ndim() - 2;

    // Leading dims of the output; both inputs are broadcast to these.
    let batch_shape = &c.shape()[..c.ndim() - 2];

    // Broadcast target = output batch dims + the input's own matrix dims.
    let a_target = [batch_shape, &a.shape()[a_matrix_start..]].concat();
    let b_target = [batch_shape, &b.shape()[b_matrix_start..]].concat();

    let a_batch = a.broadcast(a_target.as_slice());
    let b_batch = b.broadcast(b_target.as_slice());

    // Walk the 2D matrices of both inputs and the output in lockstep.
    let matrices = a_batch
        .inner_iter::<2>()
        .zip(b_batch.inner_iter::<2>())
        .zip(c.inner_iter_mut::<2>());

    for ((a_mat, b_mat), mut c_mat) in matrices {
        // Fetch the row stride before `data_mut()` is called on `c_mat`.
        let row_stride = c_mat.stride(0);
        gemm(
            c_mat.data_mut().unwrap(),
            row_stride,
            a_mat,
            b_mat,
            1., /* alpha */
            0., /* beta */
        );
    }
}

#[test]
fn test_gemm_op() -> Result<(), Box<dyn Error>> {
let mut rng = XorShiftRng::new(1234);
Expand Down Expand Up @@ -321,15 +349,104 @@ mod tests {

/// Checks the output of `matmul` against `reference_matmul` for a table of
/// input shapes: plain matrices, a batched LHS, a batched RHS, and both inputs
/// batched.
#[test]
fn test_matmul() -> Result<(), Box<dyn Error>> {
// NOTE(review): the `[3, 10]` x `[10, 8]` inputs created here are verified on
// their own further down (against `gemm_tensors`), using the same shapes as
// the first table case. This looks like a leftover from before the
// table-driven refactor — confirm whether it is still wanted.
let mut rng = XorShiftRng::new(1234);
let a = Tensor::rand(&[3, 10], &mut rng);
let b = Tensor::rand(&[10, 8], &mut rng);
// One table entry: the shapes of the two inputs and of the expected output.
struct Case<'a> {
// Shape of the first (LHS) input passed to `matmul`.
a_shape: &'a [usize],
// Shape of the second (RHS) input passed to `matmul`.
b_shape: &'a [usize],
// Shape of the zeroed tensor that `reference_matmul` fills in.
out_shape: &'a [usize],
}

// Expected output for the standalone check, computed with `gemm_tensors` into
// a zero-initialized `[3, 8]` tensor.
let mut expected = Tensor::zeros(&[3, 8]);
gemm_tensors(&mut expected, &a, &b, 1., 1.);
let cases = [
// Simple matmul
Case {
a_shape: &[3, 10],
b_shape: &[10, 8],
out_shape: &[3, 8],
},
// LHS input is a batch
Case {
a_shape: &[2, 3, 10],
b_shape: &[10, 8],
out_shape: &[2, 3, 8],
},
// RHS input is a batch
Case {
a_shape: &[3, 10],
b_shape: &[2, 10, 8],
out_shape: &[2, 3, 8],
},
// Both inputs are batches
Case {
a_shape: &[2, 3, 10],
b_shape: &[2, 10, 8],
out_shape: &[2, 3, 8],
},
];

// Standalone check of the `[3, 10]` x `[10, 8]` inputs created at the top.
let result = matmul(a.view(), b.view()).unwrap();
expect_equal(&result, &expected)?;
for Case {
a_shape,
b_shape,
out_shape,
} in cases
{
// A fresh RNG with a fixed seed is created for every case. These bindings
// shadow the `rng`, `a`, `b` and `expected` declared above.
let mut rng = XorShiftRng::new(1234);
let a = Tensor::rand(a_shape, &mut rng);
let b = Tensor::rand(b_shape, &mut rng);
let mut expected = Tensor::zeros(out_shape);

// Fill `expected` via the reference implementation, then compare it with
// the output of the operator under test.
reference_matmul(expected.view_mut(), a.view(), b.view());
let result = matmul(a.view(), b.view()).unwrap();
expect_equal(&result, &expected)?;
}

Ok(())
}

/// Checks that `matmul` returns the expected `OpError` for a table of invalid
/// input shape combinations.
#[test]
fn test_matmul_invalid() -> Result<(), Box<dyn Error>> {
    /// One table entry: two input shapes and the error `matmul` should return.
    struct Case<'a> {
        a_shape: &'a [usize],
        b_shape: &'a [usize],
        error: OpError,
    }

    let cases = [
        // First input has fewer than two dims.
        Case {
            a_shape: &[3],
            b_shape: &[10, 8],
            error: OpError::InvalidValue("Inputs must have >= 2 dimensions"),
        },
        // Second input has fewer than two dims.
        Case {
            a_shape: &[3, 10],
            b_shape: &[10],
            error: OpError::InvalidValue("Inputs must have >= 2 dimensions"),
        },
        // Last dim of the first input (10) differs from the second-to-last dim
        // of the second input (11).
        Case {
            a_shape: &[3, 10],
            b_shape: &[11, 8],
            error: OpError::IncompatibleInputShapes(
                "Columns of first matrix does not match rows of second matrix",
            ),
        },
        // Leading dims (2 vs 3) differ and neither is 1.
        Case {
            a_shape: &[2, 3, 10],
            b_shape: &[3, 10, 8],
            error: OpError::IncompatibleInputShapes("Cannot broadcast shapes"),
        },
    ];

    for case in cases {
        // A fresh RNG with a fixed seed is created for every case.
        let mut rng = XorShiftRng::new(1234);
        let lhs = Tensor::rand(case.a_shape, &mut rng);
        let rhs = Tensor::rand(case.b_shape, &mut rng);

        let outcome = matmul(lhs.view(), rhs.view());
        assert_eq!(outcome, Err(case.error));
    }

    Ok(())
}
Expand Down Expand Up @@ -361,46 +478,6 @@ mod tests {
}
}

/// Checks `matmul` when the inputs have different numbers of dims or
/// different leading dims, so that one of them has to be broadcast.
///
/// The expected values all derive from a single `[3, 10]` x `[10, 8]` product
/// computed with `gemm_tensors`.
#[test]
fn test_matmul_broadcast() -> Result<(), Box<dyn Error>> {
    let mut rng = XorShiftRng::new(1234);
    let mut lhs = Tensor::rand(&[3, 10], &mut rng);
    let mut rhs = Tensor::rand(&[10, 8], &mut rng);

    // Product of the plain 2D inputs, reshaped to carry two leading 1-sized
    // dims like the outputs checked below.
    let mut expected = Tensor::zeros(&[3, 8]);
    gemm_tensors(&mut expected, &lhs, &rhs, 1., 1.);
    expected.reshape(&[1, 1, 3, 8]);

    // Case 1: the LHS carries extra leading 1-sized dims.
    lhs.reshape(&[1, 1, 3, 10]);
    let actual = matmul(lhs.view(), rhs.view()).unwrap();
    expect_equal(&actual, &expected)?;

    // Case 2: the RHS carries extra leading 1-sized dims (LHS restored to 2D).
    lhs.reshape(&[3, 10]);
    rhs.reshape(&[1, 1, 10, 8]);
    let actual = matmul(lhs.view(), rhs.view()).unwrap();
    expect_equal(&actual, &expected)?;

    // Case 3: the LHS is given with a batch dim of 4, so the RHS is the input
    // that requires broadcasting.
    let batched_lhs_shape = &[1, 4, 3, 10][..];
    let batched_lhs_out_shape = &[1, 4, 3, 8][..];
    let batched_lhs = lhs.broadcast(batched_lhs_shape);
    let batched_lhs_expected = expected.broadcast(batched_lhs_out_shape);
    let actual = matmul(batched_lhs, rhs.view()).unwrap();
    expect_equal(&actual.view(), &batched_lhs_expected)?;

    // Case 4: the RHS is given with a batch dim of 3, so the LHS is the input
    // that requires broadcasting.
    let batched_rhs_shape = &[1, 3, 10, 8][..];
    let batched_rhs_out_shape = &[1, 3, 3, 8][..];
    let batched_rhs = rhs.broadcast(batched_rhs_shape);
    let batched_rhs_expected = expected.broadcast(batched_rhs_out_shape);
    let actual = matmul(lhs.view(), batched_rhs).unwrap();
    expect_equal(&actual.view(), &batched_rhs_expected)?;

    Ok(())
}

#[test]
#[ignore]
fn bench_matmul() {
Expand Down

0 comments on commit abeda07

Please sign in to comment.