diff --git a/rten-vecmath/src/simd_vec.rs b/rten-vecmath/src/simd_vec.rs index 864f185f..f2660021 100644 --- a/rten-vecmath/src/simd_vec.rs +++ b/rten-vecmath/src/simd_vec.rs @@ -34,6 +34,9 @@ pub trait SimdInt: Copy + Sized { /// vector to a float. type Float: SimdFloat; + /// The type used by operations that use or return masks. + type Mask: Copy; + /// Return a new vector with all elements set to zero. unsafe fn zero() -> Self { Self::splat(0) @@ -43,13 +46,13 @@ pub trait SimdInt: Copy + Sized { unsafe fn splat(val: i32) -> Self; /// Return a mask indicating whether `self > other`. - unsafe fn gt(self, other: Self) -> Self; + unsafe fn gt(self, other: Self) -> Self::Mask; /// Select elements from this vector or `other` according to a mask. /// /// For each lane, if the mask value is zero, return the element from /// `self`, otherwise return the value from `other`. - unsafe fn blend(self, other: Self, mask: Self) -> Self; + unsafe fn blend(self, other: Self, mask: Self::Mask) -> Self; /// Compute `self + rhs`. unsafe fn add(self, rhs: Self) -> Self; @@ -96,6 +99,9 @@ pub trait SimdFloat: Copy + Sized { /// to a vector of ints. type Int: SimdInt; + /// The type used by operations that use or return masks. + type Mask: Copy; + /// Shorthand for `Self::splat(1.0)`. unsafe fn one() -> Self { Self::splat(1.0) @@ -141,13 +147,13 @@ pub trait SimdFloat: Copy + Sized { unsafe fn div(self, rhs: Self) -> Self; /// Compute a mask containing `self >= rhs`. - unsafe fn ge(self, rhs: Self) -> Self; + unsafe fn ge(self, rhs: Self) -> Self::Mask; /// Compute a mask containing `self <= rhs`. - unsafe fn le(self, rhs: Self) -> Self; + unsafe fn le(self, rhs: Self) -> Self::Mask; /// Compute a mask containing `self < rhs`. - unsafe fn lt(self, rhs: Self) -> Self; + unsafe fn lt(self, rhs: Self) -> Self::Mask; /// Compute the maximum of `self` and `rhs`. 
unsafe fn max(self, rhs: Self) -> Self; @@ -156,7 +162,7 @@ pub trait SimdFloat: Copy + Sized { /// /// For each lane, if the mask value is zero, return the element from /// `self`, otherwise return the value from `other`. - unsafe fn blend(self, other: Self, mask: Self) -> Self; + unsafe fn blend(self, other: Self, mask: Self::Mask) -> Self; /// Evaluate a polynomial using Horner's method. /// @@ -200,6 +206,7 @@ impl SimdInt for i32 { const LEN: usize = 1; type Float = f32; + type Mask = bool; unsafe fn zero() -> Self { 0 @@ -209,12 +216,12 @@ impl SimdInt for i32 { val } - unsafe fn gt(self, other: Self) -> Self { - (self > other) as i32 + unsafe fn gt(self, other: Self) -> Self::Mask { + self > other } - unsafe fn blend(self, other: Self, mask: Self) -> Self { - if mask == 0 { + unsafe fn blend(self, other: Self, mask: Self::Mask) -> Self { + if !mask { self } else { other @@ -251,6 +258,7 @@ impl SimdFloat for f32 { const LEN: usize = 1; type Int = i32; + type Mask = bool; unsafe fn one() -> Self { 1. @@ -292,36 +300,24 @@ impl SimdFloat for f32 { self / rhs } - unsafe fn ge(self, rhs: Self) -> Self { - if self >= rhs { - 1. - } else { - 0. - } + unsafe fn ge(self, rhs: Self) -> Self::Mask { + self >= rhs } - unsafe fn le(self, rhs: Self) -> Self { - if self <= rhs { - 1. - } else { - 0. - } + unsafe fn le(self, rhs: Self) -> Self::Mask { + self <= rhs } - unsafe fn lt(self, rhs: Self) -> Self { - if self < rhs { - 1. - } else { - 0. - } + unsafe fn lt(self, rhs: Self) -> Self::Mask { + self < rhs } unsafe fn max(self, rhs: Self) -> Self { f32::max(self, rhs) } - unsafe fn blend(self, rhs: Self, mask: Self) -> Self { - if mask == 0. 
{ + unsafe fn blend(self, rhs: Self, mask: Self::Mask) -> Self { + if !mask { self } else { rhs diff --git a/rten-vecmath/src/simd_vec/aarch64.rs b/rten-vecmath/src/simd_vec/aarch64.rs index e2e5b3c8..197dc36b 100644 --- a/rten-vecmath/src/simd_vec/aarch64.rs +++ b/rten-vecmath/src/simd_vec/aarch64.rs @@ -1,15 +1,15 @@ use std::arch::aarch64::{ - float32x4_t, int32x4_t, vabsq_f32, vaddq_f32, vaddq_s32, vbslq_f32, vbslq_s32, vcgeq_f32, - vcgtq_s32, vcleq_f32, vcltq_f32, vcvtq_s32_f32, vdivq_f32, vdupq_n_f32, vdupq_n_s32, vfmaq_f32, - vld1q_f32, vld1q_s32, vmaxq_f32, vmulq_f32, vreinterpretq_f32_s32, vreinterpretq_f32_u32, - vreinterpretq_s32_u32, vreinterpretq_u32_f32, vreinterpretq_u32_s32, vshlq_n_s32, vst1q_f32, - vst1q_s32, vsubq_f32, vsubq_s32, + float32x4_t, int32x4_t, uint32x4_t, vabsq_f32, vaddq_f32, vaddq_s32, vbslq_f32, vbslq_s32, + vcgeq_f32, vcgtq_s32, vcleq_f32, vcltq_f32, vcvtq_s32_f32, vdivq_f32, vdupq_n_f32, vdupq_n_s32, + vfmaq_f32, vld1q_f32, vld1q_s32, vmaxq_f32, vmulq_f32, vreinterpretq_f32_s32, vshlq_n_s32, + vst1q_f32, vst1q_s32, vsubq_f32, vsubq_s32, }; use crate::simd_vec::{SimdFloat, SimdInt}; impl SimdInt for int32x4_t { type Float = float32x4_t; + type Mask = uint32x4_t; const LEN: usize = 4; @@ -21,12 +21,12 @@ impl SimdInt for int32x4_t { vdupq_n_s32(val) } - unsafe fn gt(self, other: Self) -> Self { - vreinterpretq_s32_u32(vcgtq_s32(self, other)) + unsafe fn gt(self, other: Self) -> Self::Mask { + vcgtq_s32(self, other) } - unsafe fn blend(self, other: Self, mask: Self) -> Self { - vbslq_s32(vreinterpretq_u32_s32(mask), other, self) + unsafe fn blend(self, other: Self, mask: Self::Mask) -> Self { + vbslq_s32(mask, other, self) } unsafe fn add(self, rhs: Self) -> Self { @@ -56,6 +56,7 @@ impl SimdInt for int32x4_t { impl SimdFloat for float32x4_t { type Int = int32x4_t; + type Mask = uint32x4_t; const LEN: usize = 4; @@ -91,24 +92,24 @@ impl SimdFloat for float32x4_t { vdivq_f32(self, rhs) } - unsafe fn ge(self, rhs: Self) -> Self { - 
vreinterpretq_f32_u32(vcgeq_f32(self, rhs)) + unsafe fn ge(self, rhs: Self) -> Self::Mask { + vcgeq_f32(self, rhs) } - unsafe fn le(self, rhs: Self) -> Self { - vreinterpretq_f32_u32(vcleq_f32(self, rhs)) + unsafe fn le(self, rhs: Self) -> Self::Mask { + vcleq_f32(self, rhs) } - unsafe fn lt(self, rhs: Self) -> Self { - vreinterpretq_f32_u32(vcltq_f32(self, rhs)) + unsafe fn lt(self, rhs: Self) -> Self::Mask { + vcltq_f32(self, rhs) } unsafe fn max(self, rhs: Self) -> Self { vmaxq_f32(self, rhs) } - unsafe fn blend(self, other: Self, mask: Self) -> Self { - vbslq_f32(vreinterpretq_u32_f32(mask), other, self) + unsafe fn blend(self, other: Self, mask: Self::Mask) -> Self { + vbslq_f32(mask, other, self) } unsafe fn load(ptr: *const f32) -> Self { diff --git a/rten-vecmath/src/simd_vec/wasm.rs b/rten-vecmath/src/simd_vec/wasm.rs index a8ec5f45..9d4b9e9f 100644 --- a/rten-vecmath/src/simd_vec/wasm.rs +++ b/rten-vecmath/src/simd_vec/wasm.rs @@ -18,6 +18,7 @@ pub struct v128f(v128); impl SimdInt for v128i { type Float = v128f; + type Mask = v128i; const LEN: usize = 4; @@ -25,11 +26,11 @@ impl SimdInt for v128i { Self(i32x4_splat(val)) } - unsafe fn gt(self, other: Self) -> Self { + unsafe fn gt(self, other: Self) -> Self::Mask { Self(i32x4_gt(self.0, other.0)) } - unsafe fn blend(self, other: Self, mask: Self) -> Self { + unsafe fn blend(self, other: Self, mask: Self::Mask) -> Self { Self(v128_bitselect(other.0, self.0, mask.0)) } @@ -60,6 +61,7 @@ impl SimdInt for v128i { impl SimdFloat for v128f { type Int = v128i; + type Mask = v128i; const LEN: usize = 4; @@ -95,23 +97,23 @@ impl SimdFloat for v128f { Self(f32x4_div(self.0, rhs.0)) } - unsafe fn ge(self, rhs: Self) -> Self { - Self(f32x4_ge(self.0, rhs.0)) + unsafe fn ge(self, rhs: Self) -> Self::Mask { + v128i(f32x4_ge(self.0, rhs.0)) } - unsafe fn le(self, rhs: Self) -> Self { - Self(f32x4_le(self.0, rhs.0)) + unsafe fn le(self, rhs: Self) -> Self::Mask { + v128i(f32x4_le(self.0, rhs.0)) } - unsafe fn lt(self, 
rhs: Self) -> Self { - Self(f32x4_lt(self.0, rhs.0)) + unsafe fn lt(self, rhs: Self) -> Self::Mask { + v128i(f32x4_lt(self.0, rhs.0)) } unsafe fn max(self, rhs: Self) -> Self { Self(f32x4_max(self.0, rhs.0)) } - unsafe fn blend(self, rhs: Self, mask: Self) -> Self { + unsafe fn blend(self, rhs: Self, mask: Self::Mask) -> Self { Self(v128_bitselect(rhs.0, self.0, mask.0)) } diff --git a/rten-vecmath/src/simd_vec/x86_64.rs b/rten-vecmath/src/simd_vec/x86_64.rs index a6cf201b..073f198c 100644 --- a/rten-vecmath/src/simd_vec/x86_64.rs +++ b/rten-vecmath/src/simd_vec/x86_64.rs @@ -11,6 +11,7 @@ use crate::simd_vec::{SimdFloat, SimdInt}; impl SimdInt for __m256i { type Float = __m256; + type Mask = __m256i; const LEN: usize = 8; @@ -28,13 +29,13 @@ impl SimdInt for __m256i { #[target_feature(enable = "avx2")] #[target_feature(enable = "fma")] - unsafe fn gt(self, other: Self) -> Self { + unsafe fn gt(self, other: Self) -> Self::Mask { _mm256_cmpgt_epi32(self, other) } #[target_feature(enable = "avx2")] #[target_feature(enable = "fma")] - unsafe fn blend(self, other: Self, mask: Self) -> Self { + unsafe fn blend(self, other: Self, mask: Self::Mask) -> Self { _mm256_blendv_epi8(self, other, mask) } @@ -79,6 +80,7 @@ impl SimdInt for __m256i { impl SimdFloat for __m256 { type Int = __m256i; + type Mask = __m256; const LEN: usize = 8; @@ -134,19 +136,19 @@ impl SimdFloat for __m256 { #[target_feature(enable = "avx2")] #[target_feature(enable = "fma")] - unsafe fn ge(self, rhs: Self) -> Self { + unsafe fn ge(self, rhs: Self) -> Self::Mask { _mm256_cmp_ps(self, rhs, _CMP_GE_OQ) } #[target_feature(enable = "avx2")] #[target_feature(enable = "fma")] - unsafe fn le(self, rhs: Self) -> Self { + unsafe fn le(self, rhs: Self) -> Self::Mask { _mm256_cmp_ps(self, rhs, _CMP_LE_OQ) } #[target_feature(enable = "avx2")] #[target_feature(enable = "fma")] - unsafe fn lt(self, rhs: Self) -> Self { + unsafe fn lt(self, rhs: Self) -> Self::Mask { _mm256_cmp_ps(self, rhs, _CMP_LT_OQ) } @@ 
-158,7 +160,7 @@ impl SimdFloat for __m256 { #[target_feature(enable = "avx2")] #[target_feature(enable = "fma")] - unsafe fn blend(self, rhs: Self, mask: Self) -> Self { + unsafe fn blend(self, rhs: Self, mask: Self::Mask) -> Self { _mm256_blendv_ps(self, rhs, mask) }