diff --git a/linalg/build.rs b/linalg/build.rs
index f13450cc50..d46d2b610b 100644
--- a/linalg/build.rs
+++ b/linalg/build.rs
@@ -283,6 +283,7 @@ fn preprocess_file(
         "long": long,
         "jump_table": jump_table(),
         "align": align,
+        "offset": if msvc { "offset" } else { "rip + "},
     });
     for (k, v) in variants {
         globals.insert(k.to_string().into(), liquid::model::Value::scalar(*v));
diff --git a/linalg/src/frame/block_quant/q4_0.rs b/linalg/src/frame/block_quant/q4_0.rs
index 877e81d9c8..f6a8d2cac7 100644
--- a/linalg/src/frame/block_quant/q4_0.rs
+++ b/linalg/src/frame/block_quant/q4_0.rs
@@ -159,8 +159,10 @@ impl BlockQuant for BaseQ4_0 {
         zip: usize,
         scales_at_end: bool,
     ) -> TractResult {
-        assert!(input.len() % self.block_bytes() == 0);
-        assert!(k % self.block_len() == 0);
+        ensure!(input.len() % self.block_bytes() == 0);
+        ensure!(k % self.block_len() == 0);
+        // ensure!(input.len() == k * r / self.block_len() * self.block_bytes());
+        ensure!(zip < r);
         let m = if input.len() == 0 {
             0
         } else {
diff --git a/linalg/src/frame/mmm/tests/packed_packed.rs b/linalg/src/frame/mmm/tests/packed_packed.rs
index 937e70e17e..35ea87cacc 100644
--- a/linalg/src/frame/mmm/tests/packed_packed.rs
+++ b/linalg/src/frame/mmm/tests/packed_packed.rs
@@ -352,8 +352,11 @@ impl PackedPackedProblem {
     }
 
     pub fn check(&self) -> TractResult<()> {
+        dbg!(self);
         let expected = self.reference()?;
+        dbg!(&expected);
         let found = self.run()?;
+        dbg!(&found);
         let app = if K::Acc::datum_type() == f16::datum_type() {
             Approximation::SuperApproximate
         } else {
diff --git a/linalg/src/x86_64_fma/mmm.rs b/linalg/src/x86_64_fma/mmm.rs
index 71870650e0..384157bc85 100644
--- a/linalg/src/x86_64_fma/mmm.rs
+++ b/linalg/src/x86_64_fma/mmm.rs
@@ -1,8 +1,9 @@
 //use crate::frame::block_quant::{PackedBlockQuantFormat, Q4_0};
 use crate::mmm::no_prefetch;
 use tract_data::prelude::f16;
+use crate::frame::block_quant::*;
 
-// const PQ40_R32: PackedBlockQuantFormat = PackedBlockQuantFormat::new(&Q4_0, 32);
+const PQ40_R32: PackedBlockQuantFormat = PackedBlockQuantFormat::new(&Q4_0, 32, 16, false);
 
 MMMExternKernel!(f16, fma_mmm_f16_8x8; 8, 8; 32, 2; 0, 0; no_prefetch, is_x86_feature_detected!("fma") && is_x86_feature_detected!("f16c"));
 
@@ -14,8 +15,6 @@ MMMExternKernel!(f32, fma_mmm_f32_32x3; 32, 3; 32, 4; 0, 0; no_prefetch, is_x86_
 MMMExternKernel!(f32, fma_mmm_f32_40x2; 40, 2; 32, 4; 0, 0; no_prefetch, is_x86_feature_detected!("fma"));
 MMMExternKernel!(f32, fma_mmm_f32_64x1; 64, 1; 32, 4; 0, 0; no_prefetch, is_x86_feature_detected!("fma"));
 
-MMMExternKernel!(f32, fma_mmm_f32_32x1; 32, 1; 32, 4; 0, 0; no_prefetch, is_x86_feature_detected!("fma"));
-/*
 MMMExternKernel!(f32, fma_mmm_f32_32x1; 32, 1; 32, 4; 0, 0; no_prefetch, is_x86_feature_detected!("fma"),
     packing_defs: {
         const F32_B: PackedFormat = PackedFormat::new(DatumType::F32, 1, 4);
@@ -24,7 +23,6 @@ MMMExternKernel!(f32, fma_mmm_f32_32x1; 32, 1; 32, 4; 0, 0; no_prefetch, is_x86_
     packings: PQ40_F32,
     test: mmm_packed_packed_tests!{ is_x86_feature_detected!("fma"), fma_mmm_f32_32x1, q40f32:1 }
 );
-*/
 
 MMMExternKernel!(f32, avx512_mmm_f32_128x1; 128, 1; 64, 4; 0, 0; no_prefetch, is_x86_feature_detected!("avx512f"));
 MMMExternKernel!(f32, avx512_mmm_f32_16x1; 16, 1; 64, 4; 0, 0; no_prefetch, is_x86_feature_detected!("avx512f"));
diff --git a/linalg/x86_64/fma/fma_mmm_f32_32x1.tmpl b/linalg/x86_64/fma/fma_mmm_f32_32x1.tmpl
index a9eafa9b79..0ec801adf1 100644
--- a/linalg/x86_64/fma/fma_mmm_f32_32x1.tmpl
+++ b/linalg/x86_64/fma/fma_mmm_f32_32x1.tmpl
@@ -61,10 +61,81 @@ Windows ABI:
 
     jmp {{L}}non_linear_loop
 
+{{L}}q40f32_mask:
+{{long}} 0x0F0F0F0F
+{{L}}q40f32_eight:
+{{long}} 8
+
 {{L}}q40f32:
-    // read scales:64 f16 scales
+    // ymm0-3: acc
+    // ymm4-7: scales
+    // ymm13: 8
+    // ymm14: mask
+    // ymm15: b value
+    vbroadcastss ymm14, dword ptr [{{offset}} {{L}}q40f32_mask]
+    vbroadcastss ymm13, dword ptr [{{offset}} {{L}}q40f32_eight]
+
+{{L}}q40f32_outerloop:
+    // scales
+    vmovaps xmm4, [rax]
+    vmovaps xmm5, [rax + 16]
+    vmovaps xmm6, [rax + 32]
+    vmovaps xmm7, [rax + 48]
+    vcvtph2ps ymm4, xmm4
+    vcvtph2ps ymm5, xmm5
+    vcvtph2ps ymm6, xmm6
+    vcvtph2ps ymm7, xmm7
+    add rax, 64
+
+    mov rdx, 32
+
+{{L}}q40f32_innerloop:
+    vbroadcastss ymm15, dword ptr [rcx]
+    vmovaps xmm8, [rax] // 32 nibbles
+
+    vpand xmm10, xmm8, xmm14 // 16 bytes
+
+    vpmovzxbd ymm9, xmm10 // 8 u32
+
+    vpermilpd xmm10, xmm10, 1 // swap 64bit halves
+    vpmovzxbd ymm10, xmm10 // 8 u32
+
+    vpsrlw xmm8, xmm8, 4
+    vpand xmm12, xmm8, xmm14 // 16 bytes
+    vpmovzxbd ymm11, xmm12 // 8 u32
+    vpermilpd xmm12, xmm12, 1 // swap 64bit halves
+    vpmovzxbd ymm12, xmm12 // 8 u32
+
+    vpsubd ymm9, ymm9, ymm13
+    vpsubd ymm10, ymm10, ymm13
+    vpsubd ymm11, ymm11, ymm13
+    vpsubd ymm12, ymm12, ymm13
+
+    vcvtdq2ps ymm9, ymm9
+    vcvtdq2ps ymm10, ymm10
+    vcvtdq2ps ymm11, ymm11
+    vcvtdq2ps ymm12, ymm12
+
+    vmulps ymm9, ymm9, ymm4
+    vmulps ymm10, ymm10, ymm5
+    vmulps ymm11, ymm11, ymm6
+    vmulps ymm12, ymm12, ymm7
+
+    vfmadd231ps ymm0, ymm15, ymm9
+    vfmadd231ps ymm1, ymm15, ymm10
+    vfmadd231ps ymm2, ymm15, ymm11
+    vfmadd231ps ymm3, ymm15, ymm12
+
+    add rax, 16
+    add rcx, 4
+    sub rdx, 1
+    jnz {{L}}q40f32_innerloop
+
+    sub rbx, 32
+    jnz {{L}}q40f32_outerloop
+
+    jmp {{L}}non_linear_loop
 
-    jmp {{L}}unsupported
 
 {% include "fma_mmm_f32_scalars.tmpliq" from:0, to:3, type:"f32" %}
 {% include "fma_mmm_f32_per_rows.tmpliq" mr:32, from:0, to:3, type:"f32" %}
diff --git a/linalg/x86_64/fma/fma_sigmoid_f32.tmpl b/linalg/x86_64/fma/fma_sigmoid_f32.tmpl
index e8b0aaf4d3..4f650dc102 100644
--- a/linalg/x86_64/fma/fma_sigmoid_f32.tmpl
+++ b/linalg/x86_64/fma/fma_sigmoid_f32.tmpl
@@ -81,8 +81,6 @@ fma_sigmoid_f32_{{suffix}} proc
     ldmxcsr [rsp]
 
 // ----------------------------------------------------------------------
 
-{%capture offset%}{% if msvc %} offset {%else%} rip + {%endif%} {%endcapture%}
-
     cmp rsi, 0
     je {{L}}done
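
Reviewer note, not part of the patch: a scalar Rust sketch of what the new q40f32 loop in fma_mmm_f32_32x1.tmpl appears to compute for one 32-row panel. The packed layout is an assumption read off the asm (per block of 32 k values: 32 little-endian f16 scales, then 32 steps of 16 nibble-bytes, low nibbles feeding rows 0..16 and high nibbles rows 16..32), and q40f32_panel_ref is a hypothetical helper name.

use tract_data::prelude::f16;

// Assumed layout (derived from the asm, not from documented structs): for each
// block of 32 k values, packed_a holds 32 little-endian f16 scales (one per row),
// then 32 groups of 16 bytes where byte j packs row j in its low nibble and
// row j + 16 in its high nibble. Each nibble decodes as (nibble - 8) * scale[row].
fn q40f32_panel_ref(packed_a: &[u8], b: &[f32], acc: &mut [f32; 32]) {
    assert!(b.len() % 32 == 0);
    assert!(packed_a.len() == b.len() / 32 * (64 + 32 * 16));
    let mut a = packed_a;
    for b_block in b.chunks(32) {
        // 32 f16 scales, one per row of the panel (the vcvtph2ps block)
        let scales: Vec<f32> = (0..32)
            .map(|i| f16::from_bits(u16::from_le_bytes([a[2 * i], a[2 * i + 1]])).to_f32())
            .collect();
        a = &a[64..];
        for &bk in b_block {
            // 16 bytes = 32 nibbles = one 4-bit weight per row for this k step
            for j in 0..16 {
                let lo = (a[j] & 0x0F) as i32 - 8; // row j
                let hi = (a[j] >> 4) as i32 - 8; // row j + 16
                acc[j] += lo as f32 * scales[j] * bk;
                acc[j + 16] += hi as f32 * scales[j + 16] * bk;
            }
            a = &a[16..];
        }
    }
}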