diff --git a/src/ops/matmul.rs b/src/ops/matmul.rs index 97e82e14..bd6aa354 100644 --- a/src/ops/matmul.rs +++ b/src/ops/matmul.rs @@ -229,7 +229,7 @@ mod tests { use rten_tensor::prelude::*; use rten_tensor::rng::XorShiftRng; use rten_tensor::test_util::expect_equal; - use rten_tensor::Tensor; + use rten_tensor::{Tensor, TensorView, TensorViewMut}; use crate::gemm::gemm; use crate::test_util::run_bench; @@ -249,6 +249,34 @@ mod tests { ) } + /// Multiply matrices in `a` by corresponding matrices in `b` and write to + /// `c`. The shapes of `a` and `b` are broadcast so that their first N-2 + /// dims match `c`. + fn reference_matmul(mut c: TensorViewMut, a: TensorView, b: TensorView) { + let a_batch_dims = a.ndim() - 2; + let b_batch_dims = b.ndim() - 2; + let out_prefix = &c.shape()[..c.ndim() - 2]; + + let a_bcast = [out_prefix, &a.shape()[a_batch_dims..]].concat(); + let b_bcast = [out_prefix, &b.shape()[b_batch_dims..]].concat(); + + a.broadcast(a_bcast.as_slice()) + .inner_iter::<2>() + .zip(b.broadcast(b_bcast.as_slice()).inner_iter::<2>()) + .zip(c.inner_iter_mut::<2>()) + .for_each(|((a, b), mut c)| { + let c_row_stride = c.stride(0); + gemm( + c.data_mut().unwrap(), + c_row_stride, + a, + b, + 1., /* alpha */ + 0., /* beta */ + ) + }); + } + #[test] fn test_gemm_op() -> Result<(), Box> { let mut rng = XorShiftRng::new(1234); @@ -321,15 +349,104 @@ mod tests { #[test] fn test_matmul() -> Result<(), Box> { - let mut rng = XorShiftRng::new(1234); - let a = Tensor::rand(&[3, 10], &mut rng); - let b = Tensor::rand(&[10, 8], &mut rng); + struct Case<'a> { + a_shape: &'a [usize], + b_shape: &'a [usize], + out_shape: &'a [usize], + } - let mut expected = Tensor::zeros(&[3, 8]); - gemm_tensors(&mut expected, &a, &b, 1., 1.); + let cases = [ + // Simple matmul + Case { + a_shape: &[3, 10], + b_shape: &[10, 8], + out_shape: &[3, 8], + }, + // LHS input is a batch + Case { + a_shape: &[2, 3, 10], + b_shape: &[10, 8], + out_shape: &[2, 3, 8], + }, + // RHS input is a batch + Case { + a_shape: &[3, 10], + b_shape: &[2, 10, 8], + out_shape: &[2, 3, 8], + }, + // Both inputs are batches + Case { + a_shape: &[2, 3, 10], + b_shape: &[2, 10, 8], + out_shape: &[2, 3, 8], + }, + ]; - let result = matmul(a.view(), b.view()).unwrap(); - expect_equal(&result, &expected)?; + for Case { + a_shape, + b_shape, + out_shape, + } in cases + { + let mut rng = XorShiftRng::new(1234); + let a = Tensor::rand(a_shape, &mut rng); + let b = Tensor::rand(b_shape, &mut rng); + let mut expected = Tensor::zeros(out_shape); + + reference_matmul(expected.view_mut(), a.view(), b.view()); + let result = matmul(a.view(), b.view()).unwrap(); + expect_equal(&result, &expected)?; + } + + Ok(()) + } + + #[test] + fn test_matmul_invalid() -> Result<(), Box> { + struct Case<'a> { + a_shape: &'a [usize], + b_shape: &'a [usize], + error: OpError, + } + + let cases = [ + Case { + a_shape: &[3], + b_shape: &[10, 8], + error: OpError::InvalidValue("Inputs must have >= 2 dimensions"), + }, + Case { + a_shape: &[3, 10], + b_shape: &[10], + error: OpError::InvalidValue("Inputs must have >= 2 dimensions"), + }, + Case { + a_shape: &[3, 10], + b_shape: &[11, 8], + error: OpError::IncompatibleInputShapes( + "Columns of first matrix does not match rows of second matrix", + ), + }, + Case { + a_shape: &[2, 3, 10], + b_shape: &[3, 10, 8], + error: OpError::IncompatibleInputShapes("Cannot broadcast shapes"), + }, + ]; + + for Case { + a_shape, + b_shape, + error, + } in cases + { + let mut rng = XorShiftRng::new(1234); + let a = Tensor::rand(a_shape, &mut rng); + let b = Tensor::rand(b_shape, &mut rng); + + let result = matmul(a.view(), b.view()); + assert_eq!(result, Err(error)); + } Ok(()) } @@ -361,46 +478,6 @@ mod tests { } } - #[test] - fn test_matmul_broadcast() -> Result<(), Box> { - let mut rng = XorShiftRng::new(1234); - let mut a = Tensor::rand(&[3, 10], &mut rng); - let mut b = Tensor::rand(&[10, 8], &mut rng); - - let mut expected = Tensor::zeros(&[3, 8]); - gemm_tensors(&mut expected, &a, &b, 1., 1.); - expected.reshape(&[1, 1, 3, 8]); - - // LHS input has excess 1 dims - a.reshape(&[1, 1, 3, 10]); - let result = matmul(a.view(), b.view()).unwrap(); - expect_equal(&result, &expected)?; - - // RHS input has excess 1 dims - a.reshape(&[3, 10]); - b.reshape(&[1, 1, 10, 8]); - let result = matmul(a.view(), b.view()).unwrap(); - expect_equal(&result, &expected)?; - - // RHS input requires broadcasting - let broadcast_a_shape = &[1, 4, 3, 10][..]; - let broadcast_expected_shape = &[1, 4, 3, 8][..]; - let broadcast_a = a.broadcast(broadcast_a_shape); - let broadcast_expected = expected.broadcast(broadcast_expected_shape); - let result = matmul(broadcast_a, b.view()).unwrap(); - expect_equal(&result.view(), &broadcast_expected)?; - - // LHS input requires broadcasting - let broadcast_b_shape = &[1, 3, 10, 8][..]; - let broadcast_expected_shape = &[1, 3, 3, 8][..]; - let broadcast_b = b.broadcast(broadcast_b_shape); - let expected = expected.broadcast(broadcast_expected_shape); - let result = matmul(a.view(), broadcast_b).unwrap(); - expect_equal(&result.view(), &expected)?; - - Ok(()) - } - #[test] #[ignore] fn bench_matmul() {