Skip to content

Commit

Permalink
Merge pull request #30 from robertknight/arm-gemm
Browse files Browse the repository at this point in the history
Add an initial non-optimized ARM NEON kernel
  • Loading branch information
robertknight authored Jan 2, 2024
2 parents 262b19f + ff1d89b commit 2d70ee8
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 1 deletion.
28 changes: 28 additions & 0 deletions src/gemm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,9 @@ pub enum KernelHint {

/// Use the AVX 512 kernel. Intel x64 only.
Avx512,

/// Use the ARM NEON kernel. ARM 64 only.
ArmNeon,
}

impl GemmExecutor {
Expand All @@ -318,6 +321,10 @@ impl GemmExecutor {
if let Some(gemm) = Self::with_kernel(KernelHint::Fma) {
return gemm;
}
#[cfg(target_arch = "aarch64")]
if let Some(gemm) = Self::with_kernel(KernelHint::ArmNeon) {
return gemm;
}
Self::with_base_kernel()
}

Expand Down Expand Up @@ -364,6 +371,21 @@ impl GemmExecutor {
}
None
}
KernelHint::ArmNeon => {
#[cfg(target_arch = "aarch64")]
{
use kernels::aarch64::ArmNeonKernel;

if ArmNeonKernel::supported() {
return Some(GemmExecutor {
kernel: Box::new(ArmNeonKernel {}),
nr: ArmNeonKernel::NR,
mr: ArmNeonKernel::MR,
});
}
}
None
}
KernelHint::Base => Some(Self::with_base_kernel()),
}
}
Expand Down Expand Up @@ -1135,6 +1157,12 @@ mod tests {
test_gemm_with_kernel(KernelHint::Avx512)
}

#[cfg(target_arch = "aarch64")]
#[test]
fn test_gemm_with_arm_neon_kernel() -> Result<(), Box<dyn Error>> {
test_gemm_with_kernel(KernelHint::ArmNeon)
}

// This duplicates one of the other `test_gemm_with_XXX_kernel` tests
// depending on what the preferred kernel is. That's OK as long as this
// test is fast.
Expand Down
3 changes: 3 additions & 0 deletions src/gemm/kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ use rten_tensor::Matrix;

use super::{GemmInputA, GemmInputB};

#[cfg(target_arch = "aarch64")]
pub mod aarch64;

#[cfg(target_arch = "x86_64")]
pub mod x64;

Expand Down
88 changes: 88 additions & 0 deletions src/gemm/kernels/aarch64.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
use super::Kernel;

use crate::iter_util::unroll_loop;

/// This is not a fully optimized ARM NEON kernel, just an initial version
/// which is a copy of the base kernel that has been tweaked to:
///
/// - Use a larger tile size
/// - Use FMA instructions via `f32::mul_add`
/// - Unroll the inner loop
#[derive(Default)]
pub struct ArmNeonKernel {}

impl Kernel for ArmNeonKernel {
// ARM NEON has 32 registers. Empirically 14x4 is the largest tile size
// this naive auto-vectorized implementation can use before LLVM spills
// registers and performance drops. Better kernels in eg. OpenBLAS have
// 64-element tiles (8x8 or 16x4).

const MR: usize = 14;
const NR: usize = 4;

fn name() -> &'static str {
"arm-neon"
}

fn supported() -> bool {
true
}

unsafe fn kernel(
tile_ptr: *mut f32,
tile_row_stride: usize,
a: &[f32],
b: &[f32],
depth: usize,
alpha: f32,
beta: f32,
) {
const MR: usize = ArmNeonKernel::MR;
const NR: usize = ArmNeonKernel::NR;

assert!(a.len() >= depth * MR);
assert!(b.len() >= depth * NR);

// Accumulate into a fixed-sized array to allow the compiler to generate
// more efficient code for the loop over `depth`.
let mut tmp = [[0.0; NR]; MR];

unroll_loop!(depth, k, 8, {
let a_off = k * MR;
let b_off = k * NR;

for i in 0..MR {
for j in 0..NR {
tmp[i][j] = a
.get_unchecked(a_off + i)
.mul_add(*b.get_unchecked(b_off + j), tmp[i][j]);
}
}
});

if beta == 0. && alpha == 1. {
for i in 0..MR {
for j in 0..NR {
let out_el = tile_ptr.add(tile_row_stride * i + j);
*out_el = tmp[i][j];
}
}
} else if beta == 1. && alpha == 1. {
for i in 0..MR {
for j in 0..NR {
let out_el = tile_ptr.add(tile_row_stride * i + j);
*out_el += tmp[i][j];
}
}
} else {
for i in 0..MR {
for j in 0..NR {
let out_el = tile_ptr.add(tile_row_stride * i + j);
*out_el = beta * *out_el + alpha * tmp[i][j];
}
}
}
}
}

super::impl_gemmops!(ArmNeonKernel);
35 changes: 34 additions & 1 deletion src/iter_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,35 @@ impl MaybeParIter for Range<usize> {
}
}

#[macro_export]
macro_rules! unroll_loop {
($count:expr, $loop_var:ident, $factor: literal, $block:tt) => {
let mut n = $count;
let mut $loop_var = 0;
while n >= $factor {
for _i in 0..$factor {
$block;
$loop_var += 1;
}
n -= $factor;
}
while n > 0 {
$block;

$loop_var += 1;
n -= 1;
}
};
}

#[allow(unused_imports)]
pub use unroll_loop;

#[cfg(test)]
mod tests {
use std::sync::atomic::{AtomicU32, Ordering};

use super::{range_chunks, range_chunks_exact, MaybeParIter};
use super::{range_chunks, range_chunks_exact, unroll_loop, MaybeParIter};

#[test]
fn test_range_chunks() {
Expand Down Expand Up @@ -197,4 +221,13 @@ mod tests {
});
assert_eq!(count.load(Ordering::SeqCst), 1000);
}

#[test]
fn test_unroll_loop() {
let mut items: Vec<i32> = Vec::new();
unroll_loop!(10, i, 4, {
items.push(i);
});
assert_eq!(items, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
}
}

0 comments on commit 2d70ee8

Please sign in to comment.