Skip to content

Commit

Permalink
Use a separate associated type for SIMD masks
Browse files Browse the repository at this point in the history
For WebAssembly and AVX2, the `Self` type can also be used as the SIMD mask
type. However, for i32/f32 and ARM the code had to convert between the actual
mask type (`bool` and `uint32x4_t` respectively) and the `Self` type. Remove
this unnecessary conversion and make the APIs clearer by introducing a separate
`Mask` associated type for SIMD vectors.
  • Loading branch information
robertknight committed Jan 6, 2024
1 parent d083acd commit 3c952b6
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 62 deletions.
56 changes: 26 additions & 30 deletions rten-vecmath/src/simd_vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ pub trait SimdInt: Copy + Sized {
/// vector to a float.
type Float: SimdFloat<Int = Self>;

/// The type used by operations that use or return masks.
type Mask: Copy;

/// Return a new vector with all elements set to zero.
unsafe fn zero() -> Self {
Self::splat(0)
Expand All @@ -43,13 +46,13 @@ pub trait SimdInt: Copy + Sized {
unsafe fn splat(val: i32) -> Self;

/// Return a mask indicating whether `self > other`.
unsafe fn gt(self, other: Self) -> Self;
unsafe fn gt(self, other: Self) -> Self::Mask;

/// Select elements from this vector or `other` according to a mask.
///
/// For each lane, if the mask value is zero (`false` for scalar types), return
/// the element from `self`, otherwise return the value from `other`.
unsafe fn blend(self, other: Self, mask: Self) -> Self;
unsafe fn blend(self, other: Self, mask: Self::Mask) -> Self;

/// Compute `self + rhs`.
unsafe fn add(self, rhs: Self) -> Self;
Expand Down Expand Up @@ -96,6 +99,9 @@ pub trait SimdFloat: Copy + Sized {
/// to a vector of ints.
type Int: SimdInt<Float = Self>;

/// The type used by operations that use or return masks.
type Mask: Copy;

/// Shorthand for `Self::splat(1.0)`.
unsafe fn one() -> Self {
Self::splat(1.0)
Expand Down Expand Up @@ -141,13 +147,13 @@ pub trait SimdFloat: Copy + Sized {
unsafe fn div(self, rhs: Self) -> Self;

/// Compute a mask containing `self >= rhs`.
unsafe fn ge(self, rhs: Self) -> Self;
unsafe fn ge(self, rhs: Self) -> Self::Mask;

/// Compute a mask containing `self <= rhs`.
unsafe fn le(self, rhs: Self) -> Self;
unsafe fn le(self, rhs: Self) -> Self::Mask;

/// Compute a mask containing `self < rhs`.
unsafe fn lt(self, rhs: Self) -> Self;
unsafe fn lt(self, rhs: Self) -> Self::Mask;

/// Compute the maximum of `self` and `rhs`.
unsafe fn max(self, rhs: Self) -> Self;
Expand All @@ -156,7 +162,7 @@ pub trait SimdFloat: Copy + Sized {
///
/// For each lane, if the mask value is zero (`false` for scalar types), return
/// the element from `self`, otherwise return the value from `other`.
unsafe fn blend(self, other: Self, mask: Self) -> Self;
unsafe fn blend(self, other: Self, mask: Self::Mask) -> Self;

/// Evaluate a polynomial using Horner's method.
///
Expand Down Expand Up @@ -200,6 +206,7 @@ impl SimdInt for i32 {
const LEN: usize = 1;

type Float = f32;
type Mask = bool;

unsafe fn zero() -> Self {
0
Expand All @@ -209,12 +216,12 @@ impl SimdInt for i32 {
val
}

unsafe fn gt(self, other: Self) -> Self {
(self > other) as i32
unsafe fn gt(self, other: Self) -> Self::Mask {
self > other
}

unsafe fn blend(self, other: Self, mask: Self) -> Self {
if mask == 0 {
unsafe fn blend(self, other: Self, mask: Self::Mask) -> Self {
if !mask {
self
} else {
other
Expand Down Expand Up @@ -251,6 +258,7 @@ impl SimdFloat for f32 {
const LEN: usize = 1;

type Int = i32;
type Mask = bool;

unsafe fn one() -> Self {
1.
Expand Down Expand Up @@ -292,36 +300,24 @@ impl SimdFloat for f32 {
self / rhs
}

unsafe fn ge(self, rhs: Self) -> Self {
if self >= rhs {
1.
} else {
0.
}
unsafe fn ge(self, rhs: Self) -> Self::Mask {
self >= rhs
}

unsafe fn le(self, rhs: Self) -> Self {
if self <= rhs {
1.
} else {
0.
}
unsafe fn le(self, rhs: Self) -> Self::Mask {
self <= rhs
}

unsafe fn lt(self, rhs: Self) -> Self {
if self < rhs {
1.
} else {
0.
}
unsafe fn lt(self, rhs: Self) -> Self::Mask {
self < rhs
}

unsafe fn max(self, rhs: Self) -> Self {
f32::max(self, rhs)
}

unsafe fn blend(self, rhs: Self, mask: Self) -> Self {
if mask == 0. {
unsafe fn blend(self, rhs: Self, mask: Self::Mask) -> Self {
if !mask {
self
} else {
rhs
Expand Down
35 changes: 18 additions & 17 deletions rten-vecmath/src/simd_vec/aarch64.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
use std::arch::aarch64::{
float32x4_t, int32x4_t, vabsq_f32, vaddq_f32, vaddq_s32, vbslq_f32, vbslq_s32, vcgeq_f32,
vcgtq_s32, vcleq_f32, vcltq_f32, vcvtq_s32_f32, vdivq_f32, vdupq_n_f32, vdupq_n_s32, vfmaq_f32,
vld1q_f32, vld1q_s32, vmaxq_f32, vmulq_f32, vreinterpretq_f32_s32, vreinterpretq_f32_u32,
vreinterpretq_s32_u32, vreinterpretq_u32_f32, vreinterpretq_u32_s32, vshlq_n_s32, vst1q_f32,
vst1q_s32, vsubq_f32, vsubq_s32,
float32x4_t, int32x4_t, uint32x4_t, vabsq_f32, vaddq_f32, vaddq_s32, vbslq_f32, vbslq_s32,
vcgeq_f32, vcgtq_s32, vcleq_f32, vcltq_f32, vcvtq_s32_f32, vdivq_f32, vdupq_n_f32, vdupq_n_s32,
vfmaq_f32, vld1q_f32, vld1q_s32, vmaxq_f32, vmulq_f32, vreinterpretq_f32_s32, vshlq_n_s32,
vst1q_f32, vst1q_s32, vsubq_f32, vsubq_s32,
};

use crate::simd_vec::{SimdFloat, SimdInt};

impl SimdInt for int32x4_t {
type Float = float32x4_t;
type Mask = uint32x4_t;

const LEN: usize = 4;

Expand All @@ -21,12 +21,12 @@ impl SimdInt for int32x4_t {
vdupq_n_s32(val)
}

unsafe fn gt(self, other: Self) -> Self {
vreinterpretq_s32_u32(vcgtq_s32(self, other))
unsafe fn gt(self, other: Self) -> Self::Mask {
vcgtq_s32(self, other)
}

unsafe fn blend(self, other: Self, mask: Self) -> Self {
vbslq_s32(vreinterpretq_u32_s32(mask), other, self)
unsafe fn blend(self, other: Self, mask: Self::Mask) -> Self {
vbslq_s32(mask, other, self)
}

unsafe fn add(self, rhs: Self) -> Self {
Expand Down Expand Up @@ -56,6 +56,7 @@ impl SimdInt for int32x4_t {

impl SimdFloat for float32x4_t {
type Int = int32x4_t;
type Mask = uint32x4_t;

const LEN: usize = 4;

Expand Down Expand Up @@ -91,24 +92,24 @@ impl SimdFloat for float32x4_t {
vdivq_f32(self, rhs)
}

unsafe fn ge(self, rhs: Self) -> Self {
vreinterpretq_f32_u32(vcgeq_f32(self, rhs))
unsafe fn ge(self, rhs: Self) -> Self::Mask {
vcgeq_f32(self, rhs)
}

unsafe fn le(self, rhs: Self) -> Self {
vreinterpretq_f32_u32(vcleq_f32(self, rhs))
unsafe fn le(self, rhs: Self) -> Self::Mask {
vcleq_f32(self, rhs)
}

unsafe fn lt(self, rhs: Self) -> Self {
vreinterpretq_f32_u32(vcltq_f32(self, rhs))
unsafe fn lt(self, rhs: Self) -> Self::Mask {
vcltq_f32(self, rhs)
}

unsafe fn max(self, rhs: Self) -> Self {
vmaxq_f32(self, rhs)
}

unsafe fn blend(self, other: Self, mask: Self) -> Self {
vbslq_f32(vreinterpretq_u32_f32(mask), other, self)
unsafe fn blend(self, other: Self, mask: Self::Mask) -> Self {
vbslq_f32(mask, other, self)
}

unsafe fn load(ptr: *const f32) -> Self {
Expand Down
20 changes: 11 additions & 9 deletions rten-vecmath/src/simd_vec/wasm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,19 @@ pub struct v128f(v128);

impl SimdInt for v128i {
type Float = v128f;
type Mask = v128i;

const LEN: usize = 4;

unsafe fn splat(val: i32) -> Self {
Self(i32x4_splat(val))
}

unsafe fn gt(self, other: Self) -> Self {
unsafe fn gt(self, other: Self) -> Self::Mask {
Self(i32x4_gt(self.0, other.0))
}

unsafe fn blend(self, other: Self, mask: Self) -> Self {
unsafe fn blend(self, other: Self, mask: Self::Mask) -> Self {
Self(v128_bitselect(other.0, self.0, mask.0))
}

Expand Down Expand Up @@ -60,6 +61,7 @@ impl SimdInt for v128i {

impl SimdFloat for v128f {
type Int = v128i;
type Mask = v128i;

const LEN: usize = 4;

Expand Down Expand Up @@ -95,23 +97,23 @@ impl SimdFloat for v128f {
Self(f32x4_div(self.0, rhs.0))
}

unsafe fn ge(self, rhs: Self) -> Self {
Self(f32x4_ge(self.0, rhs.0))
unsafe fn ge(self, rhs: Self) -> Self::Mask {
v128i(f32x4_ge(self.0, rhs.0))
}

unsafe fn le(self, rhs: Self) -> Self {
Self(f32x4_le(self.0, rhs.0))
unsafe fn le(self, rhs: Self) -> Self::Mask {
v128i(f32x4_le(self.0, rhs.0))
}

unsafe fn lt(self, rhs: Self) -> Self {
Self(f32x4_lt(self.0, rhs.0))
unsafe fn lt(self, rhs: Self) -> Self::Mask {
v128i(f32x4_lt(self.0, rhs.0))
}

unsafe fn max(self, rhs: Self) -> Self {
Self(f32x4_max(self.0, rhs.0))
}

unsafe fn blend(self, rhs: Self, mask: Self) -> Self {
unsafe fn blend(self, rhs: Self, mask: Self::Mask) -> Self {
Self(v128_bitselect(rhs.0, self.0, mask.0))
}

Expand Down
14 changes: 8 additions & 6 deletions rten-vecmath/src/simd_vec/x86_64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use crate::simd_vec::{SimdFloat, SimdInt};

impl SimdInt for __m256i {
type Float = __m256;
type Mask = __m256i;

const LEN: usize = 8;

Expand All @@ -28,13 +29,13 @@ impl SimdInt for __m256i {

#[target_feature(enable = "avx2")]
#[target_feature(enable = "fma")]
unsafe fn gt(self, other: Self) -> Self {
unsafe fn gt(self, other: Self) -> Self::Mask {
_mm256_cmpgt_epi32(self, other)
}

#[target_feature(enable = "avx2")]
#[target_feature(enable = "fma")]
unsafe fn blend(self, other: Self, mask: Self) -> Self {
unsafe fn blend(self, other: Self, mask: Self::Mask) -> Self {
_mm256_blendv_epi8(self, other, mask)
}

Expand Down Expand Up @@ -79,6 +80,7 @@ impl SimdInt for __m256i {

impl SimdFloat for __m256 {
type Int = __m256i;
type Mask = __m256;

const LEN: usize = 8;

Expand Down Expand Up @@ -134,19 +136,19 @@ impl SimdFloat for __m256 {

#[target_feature(enable = "avx2")]
#[target_feature(enable = "fma")]
unsafe fn ge(self, rhs: Self) -> Self {
/// Compute a mask containing `self >= rhs`.
///
/// `rhs` is an ordinary vector operand; the comparison result is returned as
/// `Self::Mask`, matching the `SimdFloat::ge` trait declaration.
unsafe fn ge(self, rhs: Self) -> Self::Mask {
    // `_CMP_GE_OQ` selects the ordered, non-signalling `>=` predicate.
    _mm256_cmp_ps(self, rhs, _CMP_GE_OQ)
}

#[target_feature(enable = "avx2")]
#[target_feature(enable = "fma")]
unsafe fn le(self, rhs: Self) -> Self {
/// Compute a mask containing `self <= rhs`.
///
/// `rhs` is an ordinary vector operand; the comparison result is returned as
/// `Self::Mask`, matching the `SimdFloat::le` trait declaration.
unsafe fn le(self, rhs: Self) -> Self::Mask {
    // `_CMP_LE_OQ` selects the ordered, non-signalling `<=` predicate.
    _mm256_cmp_ps(self, rhs, _CMP_LE_OQ)
}

#[target_feature(enable = "avx2")]
#[target_feature(enable = "fma")]
unsafe fn lt(self, rhs: Self) -> Self {
/// Compute a mask containing `self < rhs`.
///
/// `rhs` is an ordinary vector operand; the comparison result is returned as
/// `Self::Mask`, matching the `SimdFloat::lt` trait declaration.
unsafe fn lt(self, rhs: Self) -> Self::Mask {
    // `_CMP_LT_OQ` selects the ordered, non-signalling `<` predicate.
    _mm256_cmp_ps(self, rhs, _CMP_LT_OQ)
}

Expand All @@ -158,7 +160,7 @@ impl SimdFloat for __m256 {

#[target_feature(enable = "avx2")]
#[target_feature(enable = "fma")]
unsafe fn blend(self, rhs: Self, mask: Self) -> Self {
unsafe fn blend(self, rhs: Self, mask: Self::Mask) -> Self {
_mm256_blendv_ps(self, rhs, mask)
}

Expand Down

0 comments on commit 3c952b6

Please sign in to comment.