Skip to content

Commit

Permalink
more cleanup
Browse files: browse the repository at this point in the history
  • Loading branch information
kali committed Sep 25, 2024
1 parent 9c0719b commit 444ef0a
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 61 deletions.
22 changes: 11 additions & 11 deletions linalg/src/arm32/armv7neon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,21 @@ use crate::DatumType;

const NEON: fn() -> bool = || crate::arm32::has_neon();

MMMExternKernel2!(armv7neon_mmm_f32_8x4_cortexa7 <f32>( 8, 4 )@(4, 4 ) where(NEON));
MMMExternKernel2!(armv7neon_mmm_f32_8x4_cortexa9 <f32>( 8, 4 )@(4, 4 ) where(NEON));
MMMExternKernel2!(armv7neon_mmm_f32_8x4_generic <f32>( 8, 4 )@(4, 4 ) where(NEON));
MMMExternKernel2!(armv7neon_mmm_f32_8x6_cortexa7 <f32>( 8, 6 )@(4, 4 ) where(NEON));
MMMExternKernel2!(armv7neon_mmm_f32_8x6_cortexa9 <f32>( 8, 6 )@(4, 4 ) where(NEON));
MMMExternKernel2!(armv7neon_mmm_f32_8x6_generic <f32>( 8, 6 )@(4, 4 ) where(NEON));
MMMExternKernel2!(armv7neon_mmm_f32_32x1_cortexa7<f32>( 32, 1)@( 4, 4) where(NEON));
MMMExternKernel2!(armv7neon_mmm_f32_32x1_cortexa9<f32>( 32, 1)@( 4, 4) where(NEON));
MMMExternKernel2!(armv7neon_mmm_f32_32x1_generic <f32>(32, 1 )@(4, 4 ) where(NEON));
MMMExternKernel!(armv7neon_mmm_f32_8x4_cortexa7 <f32>( 8, 4 )@(4, 4 ) where(NEON));
MMMExternKernel!(armv7neon_mmm_f32_8x4_cortexa9 <f32>( 8, 4 )@(4, 4 ) where(NEON));
MMMExternKernel!(armv7neon_mmm_f32_8x4_generic <f32>( 8, 4 )@(4, 4 ) where(NEON));
MMMExternKernel!(armv7neon_mmm_f32_8x6_cortexa7 <f32>( 8, 6 )@(4, 4 ) where(NEON));
MMMExternKernel!(armv7neon_mmm_f32_8x6_cortexa9 <f32>( 8, 6 )@(4, 4 ) where(NEON));
MMMExternKernel!(armv7neon_mmm_f32_8x6_generic <f32>( 8, 6 )@(4, 4 ) where(NEON));
MMMExternKernel!(armv7neon_mmm_f32_32x1_cortexa7<f32>( 32, 1)@( 4, 4) where(NEON));
MMMExternKernel!(armv7neon_mmm_f32_32x1_cortexa9<f32>( 32, 1)@( 4, 4) where(NEON));
MMMExternKernel!(armv7neon_mmm_f32_32x1_generic <f32>(32, 1 )@(4, 4 ) where(NEON));

MMMExternKernel2!(armv7neon_mmm_i32_8x4<i32>(8, 4)@(32, 4) where(NEON)
MMMExternKernel!(armv7neon_mmm_i32_8x4<i32>(8, 4)@(32, 4) where(NEON)
packing[1] = i8i8 => |k| k.with_packing(PackedFormat::new(DatumType::I8, 8, 32), PackedFormat::new(DatumType::I8, 4, 32))
);

MMMExternKernel2!(armv7neon_mmm_i32_32x1<i32>(32, 1)@(32, 4) where(NEON)
MMMExternKernel!(armv7neon_mmm_i32_32x1<i32>(32, 1)@(32, 4) where(NEON)
packing[1] = i8i8 => |k| k.with_packing(PackedFormat::new(DatumType::I8, 32, 32), PackedFormat::new(DatumType::I8, 1, 4))
);

Expand Down
2 changes: 1 addition & 1 deletion linalg/src/arm32/armvfpv2.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
use crate::frame::mmm::*;

MMMExternKernel2!(armvfpv2_mmm_f32_4x4<f32>(4, 4)@(4, 4));
MMMExternKernel!(armvfpv2_mmm_f32_4x4<f32>(4, 4)@(4, 4));
18 changes: 9 additions & 9 deletions linalg/src/arm64/arm64fp16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@ use crate::frame::block_quant::PackedBlockQuantFormat;

const FP16: fn() -> bool = crate::arm64::has_fp16;

MMMExternKernel2!(arm64fp16_mmm_f16_16x8_gen<f16>(16, 8)@(16, 16) where(FP16));
MMMExternKernel2!(arm64fp16_mmm_f16_16x8_a55<f16>(16, 8)@(16, 16) where(FP16));
MMMExternKernel2!(arm64fp16_mmm_f16_32x4_gen<f16>(32, 4)@(16, 16) where(FP16));
MMMExternKernel2!(arm64fp16_mmm_f16_32x4_a55<f16>(32, 4)@(16, 16) where(FP16));
MMMExternKernel2!(arm64fp16_mmm_f16_128x1_gen<f16>(128,1)@(16, 16) where(FP16));
MMMExternKernel2!(arm64fp16_mmm_f16_128x1_a55<f16>(128,1)@(16, 16) where(FP16));
MMMExternKernel!(arm64fp16_mmm_f16_16x8_gen<f16>(16, 8)@(16, 16) where(FP16));
MMMExternKernel!(arm64fp16_mmm_f16_16x8_a55<f16>(16, 8)@(16, 16) where(FP16));
MMMExternKernel!(arm64fp16_mmm_f16_32x4_gen<f16>(32, 4)@(16, 16) where(FP16));
MMMExternKernel!(arm64fp16_mmm_f16_32x4_a55<f16>(32, 4)@(16, 16) where(FP16));
MMMExternKernel!(arm64fp16_mmm_f16_128x1_gen<f16>(128,1)@(16, 16) where(FP16));
MMMExternKernel!(arm64fp16_mmm_f16_128x1_a55<f16>(128,1)@(16, 16) where(FP16));

MMMExternKernel2!(arm64fp16_mmm_f16_64x3_gen<f16>(64, 3)@(16, 16) where(FP16));
MMMExternKernel2!(arm64fp16_mmm_f16_32x6_gen<f16>(32, 6)@(16, 16) where(FP16));
MMMExternKernel!(arm64fp16_mmm_f16_64x3_gen<f16>(64, 3)@(16, 16) where(FP16));
MMMExternKernel!(arm64fp16_mmm_f16_32x6_gen<f16>(32, 6)@(16, 16) where(FP16));

MMMExternKernel2! { arm64fp16_mmm_f16_64x1_gen<f16>(64, 1)@(16, 16) where(FP16)
MMMExternKernel! { arm64fp16_mmm_f16_64x1_gen<f16>(64, 1)@(16, 16) where(FP16)
packing[1] = q40f16z16se => |k| k.with_packing_a(PackedBlockQuantFormat::new(&Q4_0, 64, 16, true))
packing[2] = q40f16z16 => |k| k.with_packing_a(PackedBlockQuantFormat::new(&Q4_0, 64, 16, false))
}
Expand Down
40 changes: 20 additions & 20 deletions linalg/src/arm64/arm64simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,29 @@ pub use mul::arm64simd_unicast_mul_f32_16n;
pub use softmax::arm64simd_softmax2_fastcompact_f32_16n;
pub use sum::arm64simd_sum_f32_16n;

MMMExternKernel2!(arm64simd_mmm_f32_8x8_a55 <f32>(8, 8)@(16, 16));
MMMExternKernel2!(arm64simd_mmm_f32_12x8_a55<f32>(12, 8)@(16, 16));
MMMExternKernel2!(arm64simd_mmm_f32_16x4_a55<f32>(16, 4)@(16, 16));
MMMExternKernel2!(arm64simd_mmm_f32_24x4_a55<f32>(24, 4)@(16, 16));
MMMExternKernel2!(arm64simd_mmm_f32_64x1_a55<f32>(64, 1)@(16, 16));

MMMExternKernel2!(arm64simd_mmm_f32_16x4_a53<f32>(16, 4)@(16, 16));
MMMExternKernel2!(arm64simd_mmm_f32_24x4_a53<f32>(24, 4)@(16, 16));
MMMExternKernel2!(arm64simd_mmm_f32_8x8_a53 <f32>(8, 8)@(16, 16));
MMMExternKernel2!(arm64simd_mmm_f32_12x8_a53<f32>(12, 8)@(16, 16));
MMMExternKernel2!(arm64simd_mmm_f32_64x1_a53<f32>(64, 1)@(16, 16));

MMMExternKernel2!(arm64simd_mmm_f32_16x4_gen<f32>(16, 4)@(16, 16));
MMMExternKernel2!(arm64simd_mmm_f32_24x4_gen<f32>(24, 4)@(16, 16));
MMMExternKernel2!(arm64simd_mmm_f32_8x8_gen <f32>(8, 8)@(16, 16));
MMMExternKernel2!(arm64simd_mmm_f32_12x8_gen<f32>(12, 8)@(16, 16));
MMMExternKernel2!(arm64simd_mmm_f32_64x1_gen<f32>(64, 1)@(16, 16));

MMMExternKernel2!(arm64simd_mmm_i32_8x8<i32>(8, 8)@(16, 16)
MMMExternKernel!(arm64simd_mmm_f32_8x8_a55 <f32>(8, 8)@(16, 16));
MMMExternKernel!(arm64simd_mmm_f32_12x8_a55<f32>(12, 8)@(16, 16));
MMMExternKernel!(arm64simd_mmm_f32_16x4_a55<f32>(16, 4)@(16, 16));
MMMExternKernel!(arm64simd_mmm_f32_24x4_a55<f32>(24, 4)@(16, 16));
MMMExternKernel!(arm64simd_mmm_f32_64x1_a55<f32>(64, 1)@(16, 16));

MMMExternKernel!(arm64simd_mmm_f32_16x4_a53<f32>(16, 4)@(16, 16));
MMMExternKernel!(arm64simd_mmm_f32_24x4_a53<f32>(24, 4)@(16, 16));
MMMExternKernel!(arm64simd_mmm_f32_8x8_a53 <f32>(8, 8)@(16, 16));
MMMExternKernel!(arm64simd_mmm_f32_12x8_a53<f32>(12, 8)@(16, 16));
MMMExternKernel!(arm64simd_mmm_f32_64x1_a53<f32>(64, 1)@(16, 16));

MMMExternKernel!(arm64simd_mmm_f32_16x4_gen<f32>(16, 4)@(16, 16));
MMMExternKernel!(arm64simd_mmm_f32_24x4_gen<f32>(24, 4)@(16, 16));
MMMExternKernel!(arm64simd_mmm_f32_8x8_gen <f32>(8, 8)@(16, 16));
MMMExternKernel!(arm64simd_mmm_f32_12x8_gen<f32>(12, 8)@(16, 16));
MMMExternKernel!(arm64simd_mmm_f32_64x1_gen<f32>(64, 1)@(16, 16));

MMMExternKernel!(arm64simd_mmm_i32_8x8<i32>(8, 8)@(16, 16)
packing[1] = i8i8 => |k| k.with_packing(PackedFormat::new(DatumType::I8, 8, 16), PackedFormat::new(DatumType::I8, 8, 16))
);

MMMExternKernel2!(arm64simd_mmm_i32_64x1<i32>(64, 1)@(16, 1)
MMMExternKernel!(arm64simd_mmm_i32_64x1<i32>(64, 1)@(16, 1)
packing[1] = i8i8 => |k| k.with_packing(PackedFormat::new(DatumType::I8, 64,16), PackedFormat::new(DatumType::I8, 1, 1))
);

Expand Down
2 changes: 1 addition & 1 deletion linalg/src/frame/mmm/macros.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
macro_rules! MMMExternKernel2 {
macro_rules! MMMExternKernel {
(
$func:ident<$ti:ident>($mr: expr, $nr: expr)@($align_a:expr, $align_b:expr)
$(where($where:expr))?
Expand Down
38 changes: 19 additions & 19 deletions linalg/src/x86_64_fma/mmm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,31 @@ const FMA_F16C: fn() -> bool =
|| is_x86_feature_detected!("fma") && is_x86_feature_detected!("f16c");
const AVX512F: fn() -> bool = || is_x86_feature_detected!("avx512f");

MMMExternKernel2!(fma_mmm_f16_8x8<f16>(8,8)@(32,2) where(FMA_F16C));
MMMExternKernel!(fma_mmm_f16_8x8<f16>(8,8)@(32,2) where(FMA_F16C));

MMMExternKernel2!(fma_mmm_f32_8x8 <f32>(8, 8)@(32,4) where(FMA));
MMMExternKernel2!(fma_mmm_f32_16x6<f32>(16,6)@(32,4) where(FMA));
MMMExternKernel2!(fma_mmm_f32_16x5<f32>(16,5)@(32,4) where(FMA));
MMMExternKernel2!(fma_mmm_f32_24x4<f32>(24,4)@(32,4) where(FMA));
MMMExternKernel2!(fma_mmm_f32_32x3<f32>(32,3)@(32,4) where(FMA));
MMMExternKernel2!(fma_mmm_f32_40x2<f32>(40,2)@(32,4) where(FMA));
MMMExternKernel2!(fma_mmm_f32_64x1<f32>(64,1)@(32,4) where(FMA));
MMMExternKernel!(fma_mmm_f32_8x8 <f32>(8, 8)@(32,4) where(FMA));
MMMExternKernel!(fma_mmm_f32_16x6<f32>(16,6)@(32,4) where(FMA));
MMMExternKernel!(fma_mmm_f32_16x5<f32>(16,5)@(32,4) where(FMA));
MMMExternKernel!(fma_mmm_f32_24x4<f32>(24,4)@(32,4) where(FMA));
MMMExternKernel!(fma_mmm_f32_32x3<f32>(32,3)@(32,4) where(FMA));
MMMExternKernel!(fma_mmm_f32_40x2<f32>(40,2)@(32,4) where(FMA));
MMMExternKernel!(fma_mmm_f32_64x1<f32>(64,1)@(32,4) where(FMA));

const PQ40_R32: PackedBlockQuantFormat = PackedBlockQuantFormat::new(&Q4_0, 32, 16, false);
MMMExternKernel2! {fma_mmm_f32_32x1<f32>(32,1)@(32,4) where(FMA)
MMMExternKernel! {fma_mmm_f32_32x1<f32>(32,1)@(32,4) where(FMA)
packing[1] = q40f32 => |k| k.with_packing_a(PQ40_R32)
}

MMMExternKernel2!(avx512_mmm_f32_128x1<f32>(128, 1)@(64,4) where (AVX512F));
MMMExternKernel2!(avx512_mmm_f32_16x1 <f32>( 16, 1)@(64,4) where (AVX512F));
MMMExternKernel2!(avx512_mmm_f32_16x12<f32>( 16,12)@(64,4) where (AVX512F));
MMMExternKernel2!(avx512_mmm_f32_16x8 <f32>( 16, 8)@(64,4) where (AVX512F));
MMMExternKernel2!(avx512_mmm_f32_32x6 <f32>( 32, 6)@(64,4) where (AVX512F));
MMMExternKernel2!(avx512_mmm_f32_32x5 <f32>( 32, 5)@(64,4) where (AVX512F));
MMMExternKernel2!(avx512_mmm_f32_48x4 <f32>( 48, 4)@(64,4) where (AVX512F));
MMMExternKernel2!(avx512_mmm_f32_64x3 <f32>( 64, 3)@(64,4) where (AVX512F));
MMMExternKernel2!(avx512_mmm_f32_80x2 <f32>( 80, 2)@(64,4) where (AVX512F));
MMMExternKernel!(avx512_mmm_f32_128x1<f32>(128, 1)@(64,4) where (AVX512F));
MMMExternKernel!(avx512_mmm_f32_16x1 <f32>( 16, 1)@(64,4) where (AVX512F));
MMMExternKernel!(avx512_mmm_f32_16x12<f32>( 16,12)@(64,4) where (AVX512F));
MMMExternKernel!(avx512_mmm_f32_16x8 <f32>( 16, 8)@(64,4) where (AVX512F));
MMMExternKernel!(avx512_mmm_f32_32x6 <f32>( 32, 6)@(64,4) where (AVX512F));
MMMExternKernel!(avx512_mmm_f32_32x5 <f32>( 32, 5)@(64,4) where (AVX512F));
MMMExternKernel!(avx512_mmm_f32_48x4 <f32>( 48, 4)@(64,4) where (AVX512F));
MMMExternKernel!(avx512_mmm_f32_64x3 <f32>( 64, 3)@(64,4) where (AVX512F));
MMMExternKernel!(avx512_mmm_f32_80x2 <f32>( 80, 2)@(64,4) where (AVX512F));

MMMExternKernel2! { avx2_mmm_i32_8x8<i32>(8,8)@(32,4) where(AVX2)
MMMExternKernel! { avx2_mmm_i32_8x8<i32>(8,8)@(32,4) where(AVX2)
packing[1] = i8i8 => |k| k.with_packing(PackedFormat::new(DatumType::I8, 8,32), PackedFormat::new(DatumType::I8, 8, 4))
}

0 comments on commit 444ef0a

Please sign in to comment.