Skip to content

Commit

Permalink
Merge pull request #31 from robertknight/vecmath-arm
Browse files Browse the repository at this point in the history
Add aarch64 implementations for rten-vecmath
  • Loading branch information
robertknight authored Jan 5, 2024
2 parents 441d8c0 + 3701560 commit 999be27
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 4 deletions.
38 changes: 38 additions & 0 deletions rten-vecmath/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,21 @@ macro_rules! dispatch_unary_op {
return;
}

#[cfg(target_arch = "aarch64")]
{
use std::arch::aarch64::float32x4_t;

unsafe {
vec_unary_op(
$in.into(),
$out.into(),
|x: float32x4_t| $op_func(x),
0., /* pad */
);
}
return;
}

// Generic fallback.
for (x, y) in $in.iter().zip($out.iter_mut()) {
*y = $fallback_func(*x);
Expand Down Expand Up @@ -287,6 +302,21 @@ macro_rules! dispatch_unary_op {
return;
}

#[cfg(target_arch = "aarch64")]
{
use std::arch::aarch64::float32x4_t;

unsafe {
vec_unary_op(
$out.into(),
$out.into(),
|x: float32x4_t| $op_func(x),
0., /* pad */
);
}
return;
}

// Generic fallback.
for x in $out.iter_mut() {
*x = $fallback_func(*x);
Expand Down Expand Up @@ -335,6 +365,14 @@ macro_rules! dispatch_simd {
return;
}

#[cfg(target_arch = "aarch64")]
{
use std::arch::aarch64::float32x4_t;

unsafe { $func::<float32x4_t>($in, $out) };
return;
}

// Generic fallback.
unsafe { $func::<f32>($in, $out) };
}
Expand Down
12 changes: 8 additions & 4 deletions rten-vecmath/src/simd_vec.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#[cfg(target_arch = "aarch64")]
pub(crate) mod aarch64;
#[cfg(target_arch = "x86_64")]
pub(crate) mod avx;
#[cfg(target_arch = "wasm32")]
Expand Down Expand Up @@ -31,6 +33,9 @@ pub trait SimdInt: Copy + Sized {
unsafe fn gt(self, other: Self) -> Self;

/// Select elements from this vector or `other` according to a mask.
///
/// For each lane, if the mask value is zero, return the element from
/// `self`, otherwise return the value from `other`.
unsafe fn blend(self, other: Self, mask: Self) -> Self;

/// Compute `self + rhs`.
Expand Down Expand Up @@ -136,10 +141,9 @@ pub trait SimdFloat: Copy + Sized {

/// Combine elements of `self` and `rhs` according to a mask.
///
/// If the mask bits for an element are off, the corresponding element from
/// `self` is returned, otherwise the corresponding element from `rhs`
/// is returned.
unsafe fn blend(self, rhs: Self, mask: Self) -> Self;
/// For each lane, if the mask value is zero, return the element from
/// `self`, otherwise return the value from `other`.
unsafe fn blend(self, other: Self, mask: Self) -> Self;

/// Evaluate a polynomial using Horner's method.
///
Expand Down
121 changes: 121 additions & 0 deletions rten-vecmath/src/simd_vec/aarch64.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
use std::arch::aarch64::{
float32x4_t, int32x4_t, vabsq_f32, vaddq_f32, vaddq_s32, vbslq_f32, vbslq_s32, vcgeq_f32,
vcgtq_s32, vcleq_f32, vcltq_f32, vcvtq_s32_f32, vdivq_f32, vdupq_n_f32, vdupq_n_s32, vfmaq_f32,
vld1q_f32, vld1q_s32, vmaxq_f32, vmulq_f32, vreinterpretq_f32_s32, vreinterpretq_f32_u32,
vreinterpretq_s32_u32, vreinterpretq_u32_f32, vreinterpretq_u32_s32, vshlq_n_s32, vst1q_f32,
vst1q_s32, vsubq_f32, vsubq_s32,
};

use crate::simd_vec::{SimdFloat, SimdInt};

impl SimdInt for int32x4_t {
type Float = float32x4_t;

const LEN: usize = 4;

unsafe fn zero() -> Self {
vdupq_n_s32(0)
}

unsafe fn splat(val: i32) -> Self {
vdupq_n_s32(val)
}

unsafe fn gt(self, other: Self) -> Self {
vreinterpretq_s32_u32(vcgtq_s32(self, other))
}

unsafe fn blend(self, other: Self, mask: Self) -> Self {
vbslq_s32(vreinterpretq_u32_s32(mask), other, self)
}

unsafe fn add(self, rhs: Self) -> Self {
vaddq_s32(self, rhs)
}

unsafe fn sub(self, rhs: Self) -> Self {
vsubq_s32(self, rhs)
}

unsafe fn shl<const COUNT: i32>(self) -> Self {
vshlq_n_s32(self, COUNT)
}

unsafe fn reinterpret_as_float(self) -> Self::Float {
vreinterpretq_f32_s32(self)
}

unsafe fn load(ptr: *const i32) -> Self {
vld1q_s32(ptr)
}

unsafe fn store(self, ptr: *mut i32) {
vst1q_s32(ptr, self)
}
}

impl SimdFloat for float32x4_t {
type Int = int32x4_t;

const LEN: usize = 4;

unsafe fn splat(val: f32) -> Self {
vdupq_n_f32(val)
}

unsafe fn abs(self) -> Self {
vabsq_f32(self)
}

unsafe fn mul_add(self, a: Self, b: Self) -> Self {
vfmaq_f32(b, self, a)
}

unsafe fn sub(self, rhs: Self) -> Self {
vsubq_f32(self, rhs)
}

unsafe fn add(self, rhs: Self) -> Self {
vaddq_f32(self, rhs)
}

unsafe fn to_int_trunc(self) -> Self::Int {
vcvtq_s32_f32(self)
}

unsafe fn mul(self, rhs: Self) -> Self {
vmulq_f32(self, rhs)
}

unsafe fn div(self, rhs: Self) -> Self {
vdivq_f32(self, rhs)
}

unsafe fn ge(self, rhs: Self) -> Self {
vreinterpretq_f32_u32(vcgeq_f32(self, rhs))
}

unsafe fn le(self, rhs: Self) -> Self {
vreinterpretq_f32_u32(vcleq_f32(self, rhs))
}

unsafe fn lt(self, rhs: Self) -> Self {
vreinterpretq_f32_u32(vcltq_f32(self, rhs))
}

unsafe fn max(self, rhs: Self) -> Self {
vmaxq_f32(self, rhs)
}

unsafe fn blend(self, other: Self, mask: Self) -> Self {
vbslq_f32(vreinterpretq_u32_f32(mask), other, self)
}

unsafe fn load(ptr: *const f32) -> Self {
vld1q_f32(ptr)
}

unsafe fn store(self, ptr: *mut f32) {
vst1q_f32(ptr, self)
}
}

0 comments on commit 999be27

Please sign in to comment.