Skip to content

Commit

Permalink
q40f32 kernel for avx2
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Aug 30, 2024
1 parent 3c128be commit f5e1087
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 10 deletions.
1 change: 1 addition & 0 deletions linalg/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ fn preprocess_file(
"long": long,
"jump_table": jump_table(),
"align": align,
"offset": if msvc { "offset" } else { "rip + "},
});
for (k, v) in variants {
globals.insert(k.to_string().into(), liquid::model::Value::scalar(*v));
Expand Down
6 changes: 4 additions & 2 deletions linalg/src/frame/block_quant/q4_0.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,10 @@ impl<const QK: usize> BlockQuant for BaseQ4_0<QK> {
zip: usize,
scales_at_end: bool,
) -> TractResult<EagerPackedInput> {
assert!(input.len() % self.block_bytes() == 0);
assert!(k % self.block_len() == 0);
ensure!(input.len() % self.block_bytes() == 0);
ensure!(k % self.block_len() == 0);
// ensure!(input.len() == k * r / self.block_len() * self.block_bytes());
ensure!(zip < r);
let m = if input.len() == 0 {
0
} else {
Expand Down
3 changes: 3 additions & 0 deletions linalg/src/frame/mmm/tests/packed_packed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,11 @@ impl<K: MatMatMulKer + Default> PackedPackedProblem<K> {
}

pub fn check(&self) -> TractResult<()> {
dbg!(self);
let expected = self.reference()?;
dbg!(&expected);
let found = self.run()?;
dbg!(&found);
let app = if K::Acc::datum_type() == f16::datum_type() {
Approximation::SuperApproximate
} else {
Expand Down
6 changes: 2 additions & 4 deletions linalg/src/x86_64_fma/mmm.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
//use crate::frame::block_quant::{PackedBlockQuantFormat, Q4_0};
use crate::mmm::no_prefetch;
use tract_data::prelude::f16;
use crate::frame::block_quant::*;

// const PQ40_R32: PackedBlockQuantFormat = PackedBlockQuantFormat::new(&Q4_0, 32);
const PQ40_R32: PackedBlockQuantFormat = PackedBlockQuantFormat::new(&Q4_0, 32, 16, false);

MMMExternKernel!(f16, fma_mmm_f16_8x8; 8, 8; 32, 2; 0, 0; no_prefetch, is_x86_feature_detected!("fma") && is_x86_feature_detected!("f16c"));

Expand All @@ -14,8 +15,6 @@ MMMExternKernel!(f32, fma_mmm_f32_32x3; 32, 3; 32, 4; 0, 0; no_prefetch, is_x86_
MMMExternKernel!(f32, fma_mmm_f32_40x2; 40, 2; 32, 4; 0, 0; no_prefetch, is_x86_feature_detected!("fma"));
MMMExternKernel!(f32, fma_mmm_f32_64x1; 64, 1; 32, 4; 0, 0; no_prefetch, is_x86_feature_detected!("fma"));

MMMExternKernel!(f32, fma_mmm_f32_32x1; 32, 1; 32, 4; 0, 0; no_prefetch, is_x86_feature_detected!("fma"));
/*
MMMExternKernel!(f32, fma_mmm_f32_32x1; 32, 1; 32, 4; 0, 0; no_prefetch, is_x86_feature_detected!("fma"),
packing_defs: {
const F32_B: PackedFormat = PackedFormat::new(DatumType::F32, 1, 4);
Expand All @@ -24,7 +23,6 @@ MMMExternKernel!(f32, fma_mmm_f32_32x1; 32, 1; 32, 4; 0, 0; no_prefetch, is_x86_
packings: PQ40_F32,
test: mmm_packed_packed_tests!{ is_x86_feature_detected!("fma"), fma_mmm_f32_32x1, q40f32:1 }
);
*/

MMMExternKernel!(f32, avx512_mmm_f32_128x1; 128, 1; 64, 4; 0, 0; no_prefetch, is_x86_feature_detected!("avx512f"));
MMMExternKernel!(f32, avx512_mmm_f32_16x1; 16, 1; 64, 4; 0, 0; no_prefetch, is_x86_feature_detected!("avx512f"));
Expand Down
75 changes: 73 additions & 2 deletions linalg/x86_64/fma/fma_mmm_f32_32x1.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,81 @@ Windows ABI:

jmp {{L}}non_linear_loop

// Constant broadcast into ymm14: 0x0F in every byte, used to isolate one
// 4-bit value per byte of the packed input.
{{L}}q40f32_mask:
{{long}} 0x0F0F0F0F
// Constant broadcast into ymm13: the integer 8, subtracted from every
// unpacked nibble so that the unsigned range 0..15 becomes -8..7.
{{L}}q40f32_eight:
{{long}} 8

// q40f32: accumulate a 32-lane f32 product where the A operand is stored as
// blocks of f16 scales followed by packed 4-bit values, and the B operand is a
// stream of f32 scalars.
//
// Registers as used below:
//   rax: read pointer into the packed A data (scales, then nibbles); advanced
//        by 64 per block for the scales and by 16 per inner iteration
//   rcx: read pointer into B; advanced by 4 bytes (one f32) per inner iteration
//   rbx: remaining count, decremented by 32 per block (presumably k -- confirm
//        against the caller)
//   rdx: inner loop counter, overwritten here
//
// NOTE(review): both loops test their counter only after the body has run, so
// rbx is expected to be a non-zero multiple of 32 on entry; nothing in this
// routine checks that.
// NOTE(review): every load from [rax] uses vmovaps, which requires a 16-byte
// aligned address.
{{L}}q40f32:
// read scales: 64 bytes per block, i.e. 32 f16 scales (one per accumulator lane)
// ymm0-3: acc
// ymm4-7: scales
// ymm13: 8
// ymm14: mask
// ymm15: b value
vbroadcastss ymm14, dword ptr [{{offset}} {{L}}q40f32_mask]
vbroadcastss ymm13, dword ptr [{{offset}} {{L}}q40f32_eight]

{{L}}q40f32_outerloop:
// scales: load 4 x 16 bytes of f16 and widen each group of 8 to f32
vmovaps xmm4, [rax]
vmovaps xmm5, [rax + 16]
vmovaps xmm6, [rax + 32]
vmovaps xmm7, [rax + 48]
vcvtph2ps ymm4, xmm4
vcvtph2ps ymm5, xmm5
vcvtph2ps ymm6, xmm6
vcvtph2ps ymm7, xmm7
add rax, 64

// 32 inner iterations per block of scales
mov rdx, 32

{{L}}q40f32_innerloop:
vbroadcastss ymm15, dword ptr [rcx]
vmovaps xmm8, [rax] // 32 nibbles

// low nibbles of the 16 bytes: bytes 0-7 go to ymm9, bytes 8-15 to ymm10
vpand xmm10, xmm8, xmm14 // 16 bytes

vpmovzxbd ymm9, xmm10 // 8 u32

vpermilpd xmm10, xmm10, 1 // swap 64bit halves
vpmovzxbd ymm10, xmm10 // 8 u32

// high nibbles of the 16 bytes: bytes 0-7 go to ymm11, bytes 8-15 to ymm12.
// The shift works on 16-bit words, so bits leak across byte boundaries; the
// mask applied right after discards them.
vpsrlw xmm8, xmm8, 4
vpand xmm12, xmm8, xmm14 // 16 bytes
vpmovzxbd ymm11, xmm12 // 8 u32
vpermilpd xmm12, xmm12, 1 // swap 64bit halves
vpmovzxbd ymm12, xmm12 // 8 u32

// recentre: 0..15 -> -8..7
vpsubd ymm9, ymm9, ymm13
vpsubd ymm10, ymm10, ymm13
vpsubd ymm11, ymm11, ymm13
vpsubd ymm12, ymm12, ymm13

// signed i32 -> f32
vcvtdq2ps ymm9, ymm9
vcvtdq2ps ymm10, ymm10
vcvtdq2ps ymm11, ymm11
vcvtdq2ps ymm12, ymm12

// apply the per-lane scales of the current block
vmulps ymm9, ymm9, ymm4
vmulps ymm10, ymm10, ymm5
vmulps ymm11, ymm11, ymm6
vmulps ymm12, ymm12, ymm7

// acc += b * dequantized a
vfmadd231ps ymm0, ymm15, ymm9
vfmadd231ps ymm1, ymm15, ymm10
vfmadd231ps ymm2, ymm15, ymm11
vfmadd231ps ymm3, ymm15, ymm12

add rax, 16
add rcx, 4
sub rdx, 1
jnz {{L}}q40f32_innerloop

// one block of 32 consumed; loop until the count reaches exactly zero
sub rbx, 32
jnz {{L}}q40f32_outerloop

jmp {{L}}non_linear_loop

jmp {{L}}unsupported

{% include "fma_mmm_f32_scalars.tmpliq" from:0, to:3, type:"f32" %}
{% include "fma_mmm_f32_per_rows.tmpliq" mr:32, from:0, to:3, type:"f32" %}
Expand Down
2 changes: 0 additions & 2 deletions linalg/x86_64/fma/fma_sigmoid_f32.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,6 @@ fma_sigmoid_f32_{{suffix}} proc
ldmxcsr [rsp]
// ----------------------------------------------------------------------

{%capture offset%}{% if msvc %} offset {%else%} rip + {%endif%} {%endcapture%}

cmp rsi, 0
je {{L}}done

Expand Down

0 comments on commit f5e1087

Please sign in to comment.