Skip to content

Commit

Permalink
Add an initial non-optimized ARM NEON kernel
Browse files Browse the repository at this point in the history
This is just a copy of the base kernel that has been tweaked to take advantage
of the extra available SIMD registers (32 instead of 16 as in SSE), use FMA
instructions and unroll the inner loop.

On an AWS c6g.xlarge instance (Graviton 2, 4 vCPU) this achieves ~114 GFLOPS vs
~65 with the base kernel. For comparison OpenBLAS achieves ~150 GFLOPS.
  • Loading branch information
robertknight committed Jan 2, 2024
1 parent 262b19f commit ff1d89b
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 1 deletion.
28 changes: 28 additions & 0 deletions src/gemm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,9 @@ pub enum KernelHint {

/// Use the AVX 512 kernel. Intel x64 only.
Avx512,

/// Use the ARM NEON kernel. ARM 64 only.
ArmNeon,
}

impl GemmExecutor {
Expand All @@ -318,6 +321,10 @@ impl GemmExecutor {
if let Some(gemm) = Self::with_kernel(KernelHint::Fma) {
return gemm;
}
#[cfg(target_arch = "aarch64")]
if let Some(gemm) = Self::with_kernel(KernelHint::ArmNeon) {
return gemm;
}
Self::with_base_kernel()
}

Expand Down Expand Up @@ -364,6 +371,21 @@ impl GemmExecutor {
}
None
}
KernelHint::ArmNeon => {
#[cfg(target_arch = "aarch64")]
{
use kernels::aarch64::ArmNeonKernel;

if ArmNeonKernel::supported() {
return Some(GemmExecutor {
kernel: Box::new(ArmNeonKernel {}),
nr: ArmNeonKernel::NR,
mr: ArmNeonKernel::MR,
});
}
}
None
}
KernelHint::Base => Some(Self::with_base_kernel()),
}
}
Expand Down Expand Up @@ -1135,6 +1157,12 @@ mod tests {
test_gemm_with_kernel(KernelHint::Avx512)
}

// Exercise the ARM NEON kernel through the shared GEMM test harness.
// Compiled only on aarch64, where `KernelHint::ArmNeon` is available.
#[cfg(target_arch = "aarch64")]
#[test]
fn test_gemm_with_arm_neon_kernel() -> Result<(), Box<dyn Error>> {
    test_gemm_with_kernel(KernelHint::ArmNeon)
}

// This duplicates one of the other `test_gemm_with_XXX_kernel` tests
// depending on what the preferred kernel is. That's OK as long as this
// test is fast.
Expand Down
3 changes: 3 additions & 0 deletions src/gemm/kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ use rten_tensor::Matrix;

use super::{GemmInputA, GemmInputB};

#[cfg(target_arch = "aarch64")]
pub mod aarch64;

#[cfg(target_arch = "x86_64")]
pub mod x64;

Expand Down
88 changes: 88 additions & 0 deletions src/gemm/kernels/aarch64.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
use super::Kernel;

use crate::iter_util::unroll_loop;

/// GEMM kernel for AArch64 relying on NEON via auto-vectorization.
///
/// This is not a fully optimized ARM NEON kernel, just an initial version
/// which is a copy of the base kernel that has been tweaked to:
///
/// - Use a larger tile size
/// - Use FMA instructions via `f32::mul_add`
/// - Unroll the inner loop
#[derive(Default)]
pub struct ArmNeonKernel {}

impl Kernel for ArmNeonKernel {
    // ARM NEON has 32 registers. Empirically 14x4 is the largest tile size
    // this naive auto-vectorized implementation can use before LLVM spills
    // registers and performance drops. Better kernels in eg. OpenBLAS have
    // 64-element tiles (8x8 or 16x4).

    const MR: usize = 14;
    const NR: usize = 4;

    fn name() -> &'static str {
        "arm-neon"
    }

    // NEON is a baseline feature of AArch64, so this kernel is always usable
    // when compiled for that architecture — no runtime feature detection.
    fn supported() -> bool {
        true
    }

    /// Compute one `MR x NR` output tile: `tile = beta * tile + alpha * a * b`.
    ///
    /// The indexing below (`k * MR + i` / `k * NR + j`) implies `a` is a
    /// packed panel with `MR` contiguous values per depth step and `b` a
    /// packed panel with `NR` contiguous values per depth step — assumed to
    /// match the packing routines elsewhere in the crate; confirm there.
    ///
    /// # Safety
    ///
    /// The caller must guarantee that `tile_ptr` is valid for reads and
    /// writes of an `MR x NR` tile laid out with a row stride of
    /// `tile_row_stride` elements.
    unsafe fn kernel(
        tile_ptr: *mut f32,
        tile_row_stride: usize,
        a: &[f32],
        b: &[f32],
        depth: usize,
        alpha: f32,
        beta: f32,
    ) {
        const MR: usize = ArmNeonKernel::MR;
        const NR: usize = ArmNeonKernel::NR;

        // These checks make the `get_unchecked` accesses below sound: every
        // index used is `< depth * MR` (for `a`) or `< depth * NR` (for `b`).
        assert!(a.len() >= depth * MR);
        assert!(b.len() >= depth * NR);

        // Accumulate into a fixed-sized array to allow the compiler to generate
        // more efficient code for the loop over `depth`.
        let mut tmp = [[0.0; NR]; MR];

        unroll_loop!(depth, k, 8, {
            let a_off = k * MR;
            let b_off = k * NR;

            for i in 0..MR {
                for j in 0..NR {
                    // SAFETY: `k < depth`, `i < MR` and `j < NR`, so
                    // `a_off + i` / `b_off + j` are in bounds per the asserts
                    // above. `mul_add` is a fused multiply-add, which maps to
                    // an FMA instruction on this target.
                    tmp[i][j] = a
                        .get_unchecked(a_off + i)
                        .mul_add(*b.get_unchecked(b_off + j), tmp[i][j]);
                }
            }
        });

        // Write the accumulator back, specializing the common alpha/beta
        // combinations so they skip the extra multiplies. All tile accesses
        // rely on the caller-provided validity of `tile_ptr` (see # Safety).
        if beta == 0. && alpha == 1. {
            // tile = tmp
            for i in 0..MR {
                for j in 0..NR {
                    let out_el = tile_ptr.add(tile_row_stride * i + j);
                    *out_el = tmp[i][j];
                }
            }
        } else if beta == 1. && alpha == 1. {
            // tile += tmp
            for i in 0..MR {
                for j in 0..NR {
                    let out_el = tile_ptr.add(tile_row_stride * i + j);
                    *out_el += tmp[i][j];
                }
            }
        } else {
            // General case: tile = beta * tile + alpha * tmp
            for i in 0..MR {
                for j in 0..NR {
                    let out_el = tile_ptr.add(tile_row_stride * i + j);
                    *out_el = beta * *out_el + alpha * tmp[i][j];
                }
            }
        }
    }
}

// Wire this kernel into the common GEMM machinery via the shared macro from
// the parent `kernels` module (presumably generates the boilerplate trait
// impls all kernels share — see kernels.rs).
super::impl_gemmops!(ArmNeonKernel);
35 changes: 34 additions & 1 deletion src/iter_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,35 @@ impl MaybeParIter for Range<usize> {
}
}

/// Unroll a counted loop by a constant factor.
///
/// Expands `$block` inline `$factor` times per pass of an outer loop, then
/// runs the remaining `$count % $factor` iterations one at a time. The macro
/// declares `$loop_var` itself; inside `$block` it holds the current
/// zero-based iteration index, covering `0..$count` in order.
#[macro_export]
macro_rules! unroll_loop {
    ($count:expr, $loop_var:ident, $factor:literal, $block:tt) => {
        let mut remaining = $count;
        let mut $loop_var = 0;
        // Unrolled portion: $factor copies of the body per outer pass.
        while remaining >= $factor {
            for _ in 0..$factor {
                $block;
                $loop_var += 1;
            }
            remaining -= $factor;
        }
        // Tail: the final `remaining < $factor` iterations, one at a time.
        for _ in 0..remaining {
            $block;
            $loop_var += 1;
        }
    };
}

#[allow(unused_imports)]
pub use unroll_loop;

#[cfg(test)]
mod tests {
use std::sync::atomic::{AtomicU32, Ordering};

use super::{range_chunks, range_chunks_exact, MaybeParIter};
use super::{range_chunks, range_chunks_exact, unroll_loop, MaybeParIter};

#[test]
fn test_range_chunks() {
Expand Down Expand Up @@ -197,4 +221,13 @@ mod tests {
});
assert_eq!(count.load(Ordering::SeqCst), 1000);
}

// Use a count (10) that is not a multiple of the unroll factor (4) so that
// both the unrolled main loop and the remainder loop are exercised.
#[test]
fn test_unroll_loop() {
    let mut seen = Vec::new();
    unroll_loop!(10, idx, 4, {
        seen.push(idx);
    });
    let expected: Vec<i32> = (0..10).collect();
    assert_eq!(seen, expected);
}
}

0 comments on commit ff1d89b

Please sign in to comment.