Skip to content

Commit

Permalink
scales_at_end
Browse files Browse the repository at this point in the history
  • Loading branch information
mathieupoumeyrolsonos authored and kali committed Aug 30, 2024
1 parent 8ec6c01 commit 3c128be
Show file tree
Hide file tree
Showing 6 changed files with 178 additions and 35 deletions.
61 changes: 61 additions & 0 deletions linalg/arm64/arm64fp16/arm64fp16_mmm_f16_64x1.core.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
beq .non_linear_loop

cmp x4, #1
beq .q4f16se

cmp x4, #2
beq .q4f16

.p2align 4
Expand All @@ -60,6 +63,64 @@
.q40f16_const:
.byte 0xc8, 0xc7, 0xc6, 0xc5, 0xc4, 0xc2, 0xc0, 0xbc
.byte 0x00, 0x3c, 0x40, 0x42, 0x44, 0x45, 0x46, 0x47

// Q4_0 x f16 product, "scales at end" packing variant (dispatched above when x4 == 1).
//
// Each block is processed as 32 inner steps of unscaled 4-bit weights, followed by
// 64 f16 scales that are applied once to the per-block partial sums.
//
// Register usage visible in this section:
//   x1       packed weight/scales stream, post-incremented (presumably the A panel)
//   x2       f16 stream, one value consumed per inner step (presumably the B column)
//   x3       decremented by 32 per block; the loop exits when it reaches zero
//            (assumes x3 is a positive multiple of 32 on entry -- TODO confirm)
//   x4       scratch: table address, then inner-loop counter
//   v12      all zeroes            v13  16-entry nibble decode table
//   v15      0x0f byte mask        v8   current f16 value in lane 0
//   v16-v23  per-block unscaled partial sums (64 f16 lanes)
//   v24-v31  scaled accumulators; not initialised here (presumably cleared earlier)
.q4f16se:
adr x4, .q40f16_const
movi v15.16b, 15
ld1 {v13.16b}, [ x4 ]
eor v12.16b, v12.16b, v12.16b

.q4f16se_outerloop:
// Reset the per-block partial sums v16-v23.
{% for i in (0..7) %}
eor v{{i|plus:16}}.16b, v{{i|plus:16}}.16b, v{{i|plus:16}}.16b
{% endfor %}
// 32 inner steps per block.
mov x4, #32

.p2align 4
.q4f16se_innerloop:
// 32 bytes = 64 packed nibbles from x1, plus one f16 from x2 into v8.h[0].
ld1 { v9.16b-v10.16b }, [x1], #32
ld1 { v8.h }[0], [ x2 ], #2

// Split each byte into its low nibble (and) and high nibble (ushr).
and v0.16b, v9.16b, v15.16b
ushr v2.16b, v9.16b, 4

and v4.16b, v10.16b, v15.16b
ushr v6.16b, v10.16b, 4

// Map every nibble n to the table byte at index n. The bytes in .q40f16_const are
// the high bytes of the f16 encodings of -8.0 .. 7.0, i.e. of (n - 8).
tbl v0.16b, { v13.16b }, v0.16b
tbl v2.16b, { v13.16b }, v2.16b
tbl v4.16b, { v13.16b }, v4.16b
tbl v6.16b, { v13.16b }, v6.16b

// Interleave a zero low byte (v12) with each table byte to form full f16 values.
// zip2 (upper 8 bytes) must be emitted first: zip1 below overwrites its source register.
zip2 v1.16b, v12.16b, v0.16b
zip2 v3.16b, v12.16b, v2.16b
zip2 v5.16b, v12.16b, v4.16b
zip2 v7.16b, v12.16b, v6.16b

// zip1 expands the lower 8 bytes in place.
zip1 v0.16b, v12.16b, v0.16b
zip1 v2.16b, v12.16b, v2.16b
zip1 v4.16b, v12.16b, v4.16b
zip1 v6.16b, v12.16b, v6.16b

// v16-v23 += decoded weights (v0-v7) * v8.h[0]; scales are not applied yet.
{% for i in (0..7) %}
fmla v{{ i|plus: 16 }}.8h, v{{i}}.8h, v8.h[0]
{% endfor %}

subs x4, x4, #1
bne .q4f16se_innerloop

// scales
// The 64 f16 scales (128 bytes) follow the block's nibbles in the x1 stream.
ld1 { v0.8h-v3.8h }, [ x1 ], #64
ld1 { v4.8h-v7.8h }, [ x1 ], #64

// v24-v31 += scales (v0-v7) * per-block partial sums (v16-v23), lane by lane.
{% for i in (0..7) %}
fmla v{{i|plus:24}}.8h, v{{i}}.8h, v{{i|plus:16}}.8h
{% endfor %}

// One block of 32 consumed; continue until x3 reaches zero.
subs x3, x3, #32
bne .q4f16se_outerloop

b .non_linear_loop

.q4f16:
adr x4, .q40f16_const
Expand Down
9 changes: 6 additions & 3 deletions linalg/src/arm64/arm64fp16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,15 @@ MMMExternKernel!(f16, arm64fp16_mmm_f16_128x1_a55; 128, 1; 16, 16; 1, 1; no_pref
MMMExternKernel!(f16, arm64fp16_mmm_f16_64x1_gen; 64, 1; 16, 16; 1, 1; no_prefetch, crate::arm64::has_fp16(),
packing_defs: {
use crate::frame::block_quant::*;
const PQ40_R64_Z16: PackedBlockQuantFormat = PackedBlockQuantFormat::new(&Q4_0, 64, 16);
const PQ40_R64_Z16_SE: PackedBlockQuantFormat = PackedBlockQuantFormat::new(&Q4_0, 64, 16, true);
const PQ40_R64_Z16: PackedBlockQuantFormat = PackedBlockQuantFormat::new(&Q4_0, 64, 16, false);
const F16_B: PackedFormat = PackedFormat::new(DatumType::F16, 1, 2);
const PQ40_F16_Z16_SE: (&dyn MMMInputFormat, &dyn MMMInputFormat) = (&PQ40_R64_Z16_SE, &F16_B);
const PQ40_F16_Z16: (&dyn MMMInputFormat, &dyn MMMInputFormat) = (&PQ40_R64_Z16, &F16_B);
},
packings: PQ40_F16_Z16,
test: mmm_packed_packed_tests!{ crate::arm64::has_fp16(), arm64fp16_mmm_f16_64x1_gen, q40f16z16:1 }
packings: PQ40_F16_Z16_SE PQ40_F16_Z16,
test: mmm_packed_packed_tests!{ crate::arm64::has_fp16(), arm64fp16_mmm_f16_64x1_gen, q40f16z16se:1 },
test: mmm_packed_packed_tests!{ crate::arm64::has_fp16(), arm64fp16_mmm_f16_64x1_gen, q40f16z16:2 }
);

tanh_impl!(f16, arm64fp16_tanh_f16_8n, 8, 8, crate::arm64::has_fp16());
Expand Down
23 changes: 18 additions & 5 deletions linalg/src/frame/block_quant/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,14 @@ pub trait BlockQuant: Debug + Display + Send + Sync + DynClone + DynHash + Downc
}
}

fn pack(&self, input: &[u8], k: usize, r: usize, zip: usize) -> TractResult<EagerPackedInput>;
fn pack(
&self,
input: &[u8],
k: usize,
r: usize,
zip: usize,
scales_at_end: bool,
) -> TractResult<EagerPackedInput>;

unsafe fn extract_panel(
&self,
Expand Down Expand Up @@ -126,6 +133,7 @@ pub struct PackedBlockQuantFormat {
pub bq: StaticBlockQuant,
pub r: usize,
pub zip: usize,
pub scales_at_end: bool,
}

impl Display for PackedBlockQuantFormat {
Expand All @@ -141,8 +149,13 @@ impl Debug for PackedBlockQuantFormat {
}

impl PackedBlockQuantFormat {
pub const fn new(bq: &'static dyn BlockQuant, r: usize, zip: usize) -> Self {
PackedBlockQuantFormat { bq: StaticBlockQuant::Borrow(bq), r, zip }
pub const fn new(
bq: &'static dyn BlockQuant,
r: usize,
zip: usize,
scales_at_end: bool,
) -> Self {
PackedBlockQuantFormat { bq: StaticBlockQuant::Borrow(bq), r, zip, scales_at_end }
}

#[cfg(test)]
Expand Down Expand Up @@ -172,7 +185,7 @@ impl PackedBlockQuantFormat {
}

pub fn pack(&self, input: &[u8], k: usize) -> TractResult<EagerPackedInput> {
self.bq.pack(input, k, self.r, self.zip)
self.bq.pack(input, k, self.r, self.zip, self.scales_at_end)
}
}

Expand Down Expand Up @@ -201,7 +214,7 @@ impl MMMInputFormat for PackedBlockQuantFormat {
} else {
todo!()
};
Ok(Box::new(self.bq.pack(&quant, k, self.r, self.zip)?))
Ok(Box::new(self.pack(&quant, k)?))
}

fn k_alignment(&self) -> usize {
Expand Down
72 changes: 49 additions & 23 deletions linalg/src/frame/block_quant/q4_0.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::*;
use num_traits::{AsPrimitive, Float};
use num_traits::{AsPrimitive, Float, Zero};
use std::alloc::Layout;

#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
Expand Down Expand Up @@ -68,24 +68,27 @@ impl<const QK: usize> BaseQ4_0<QK> {
let scratch = std::slice::from_raw_parts_mut(scratch as *mut T, value.k * target.r);
let blocks_for_k = value.k / self.block_len();
let row_bytes = blocks_for_k * self.block_bytes();
dbg!(&value);
dbg!(&value.packed);
dbg!(&value.packed[panel * target.r * row_bytes..]);
let mut input = NibbleReader::for_slice(&value.packed[panel * target.r * row_bytes..]);
let input = &value.packed[panel * target.r * row_bytes..];
let mut scales = vec![T::zero(); target.r];
let mut scratch = scratch.iter_mut();
let zipped_order = zipped_order(pbqf.r, pbqf.zip);
let mut weights = vec!(0i8; pbqf.r);
dbg!(pbqf);
dbg!(panel, blocks_for_k);
for _ in 0..blocks_for_k {
let mut weights = vec![0i8; pbqf.r];
let panel_block_bytes = target.r * self.block_bytes();
let (scale_offset, weights_offset) = if pbqf.scales_at_end {
(panel_block_bytes - target.r * f16::datum_type().size_of(), 0)
} else {
(0, target.r * f16::datum_type().size_of())
};
for block in 0..blocks_for_k {
let block = &input[block * panel_block_bytes..][..panel_block_bytes];
let mut s_reader = NibbleReader::for_slice(&block[scale_offset..]);
let mut w_reader = NibbleReader::for_slice(&block[weights_offset..]);
for s in &mut scales {
*s = input.read_f16().as_();
*s = s_reader.read_f16().as_();
}
dbg!(&scales);
for _ in 0..self.block_len() {
for &o in &zipped_order {
weights[o] = input.read_i4();
weights[o] = w_reader.read_i4();
}
for (w, s) in weights.iter().zip(scales.iter()) {
*scratch.next().unwrap() = *s * (*w - 8).as_();
Expand Down Expand Up @@ -148,7 +151,14 @@ impl<const QK: usize> BlockQuant for BaseQ4_0<QK> {
// s0_0 S1_0 S2_0 s3_0 n0_0 n1_0 n2_0 n3_0 n0_1 n1_1 n2_1 n3_1 ... n0_33 n1_33 n2_33 n3_33
// s0_32 S1_32 S2_32 s3_32 n0_0 n1_0 n2_0 n3_0 n0_1 n1_1 n2_1 n3_1 ... n0_33 n1_33 n2_33 n3_33
// ...
fn pack(&self, input: &[u8], k: usize, r: usize, zip: usize) -> TractResult<EagerPackedInput> {
fn pack(
&self,
input: &[u8],
k: usize,
r: usize,
zip: usize,
scales_at_end: bool,
) -> TractResult<EagerPackedInput> {
assert!(input.len() % self.block_bytes() == 0);
assert!(k % self.block_len() == 0);
let m = if input.len() == 0 {
Expand All @@ -164,6 +174,7 @@ impl<const QK: usize> BlockQuant for BaseQ4_0<QK> {
unsafe { Blob::for_layout(Layout::from_size_align(panel_bytes * panels, 128)?) };
let mut writer = NibbleWriter::for_slice(&mut blob);
let order = zipped_order(r, zip);
let mut scales = vec![f16::zero(); r];
for p in 0..panels {
let input = &input[(r * p) * row_bytes..];
let mut readers = (0..r)
Expand All @@ -174,23 +185,29 @@ impl<const QK: usize> BlockQuant for BaseQ4_0<QK> {
})
.collect_vec();
for _ in 0..blocks_for_k {
for reader in &mut readers {
let scale = reader.read_f16();
writer.write_f16(scale);
for (ix, reader) in readers.iter_mut().enumerate() {
scales[ix] = reader.read_f16();
}
if !scales_at_end {
scales.iter().for_each(|s| writer.write_f16(*s))
}
for _ in 0..self.block_len() {
for &ix in &order {
let nib = readers[ix].read_i4();
writer.write_i4(nib);
}
}
if scales_at_end {
scales.iter().for_each(|s| writer.write_f16(*s))
}
}
}
Ok(EagerPackedInput {
format: Box::new(PackedBlockQuantFormat {
bq: StaticBlockQuant::Owned(Box::new(*self)),
r,
zip,
scales_at_end,
}),
packed: blob,
mn: m,
Expand Down Expand Up @@ -279,19 +296,29 @@ mod tests {
cycle_f16(Q4_0, &[-1234.0]);
}



#[test]
fn packing() -> TractResult<()> {
test_packing(BaseQ4_0::<2>, 4, 4, 2, 0)
test_packing(BaseQ4_0::<2>, 4, 4, 2, 0, false)
}

#[test]
fn packing_with_zip() -> TractResult<()> {
test_packing(BaseQ4_0::<2>, 2, 8, 8, 4)
test_packing(BaseQ4_0::<2>, 2, 8, 8, 4, false)
}

#[test]
fn packing_with_scales_at_end() -> TractResult<()> {
test_packing(BaseQ4_0::<2>, 2, 4, 4, 0, true)
}

fn test_packing(q: impl BlockQuant, k: usize, m: usize, r:usize, zip: usize) -> TractResult<()> {
fn test_packing(
q: impl BlockQuant,
k: usize,
m: usize,
r: usize,
zip: usize,
scales_at_end: bool,
) -> TractResult<()> {
let weights_orig =
Array2::from_shape_fn((m, k), |(m, k)| ((m * 31 + k * 17) % 20) as f32 - 10.)
.into_tensor();
Expand All @@ -301,7 +328,7 @@ mod tests {
let packed_f32 = packer.pack_tensor(&weights_f32, 1, 0)?;

let q4 = q.quant_f32(&weights_f32.as_slice::<f32>()?)?;
let packed_q4 = q.pack(&q4, k, r, zip)?;
let packed_q4 = q.pack(&q4, k, r, zip, scales_at_end)?;

for panel in 0..packed_f32.panels_count() {
unsafe {
Expand All @@ -314,5 +341,4 @@ mod tests {
}
Ok(())
}

}
2 changes: 1 addition & 1 deletion linalg/src/frame/mmm/tests/packed_packed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ macro_rules! mmm_packed_packed_tests {

#[test]
fn packed_packed_bug_4() -> TractResult<()> {
if $ker.mr() >= 16 {
if $ker.mr() > 16 {
let mut a = vec![0f32; $ker.mr()];
let mut b = vec![0f32; $ker.nr()];
a[16] = 1.;
Expand Down
46 changes: 43 additions & 3 deletions linalg/src/generic/mmm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,41 @@ unsafe fn add_mat_mul_pq40<const MR: usize, const NR: usize, TI>(
}
}

/// Accumulates the product of a packed Q4_0 panel and a packed `TI` panel into `ab`,
/// for the packing variant where each block's scales are stored *after* its nibbles.
///
/// For every block of `Q4_0.block_len()` steps along `k`, the unscaled products
/// `(nibble - 8) * b` are summed into a temporary `MR x NR` tile. The `MR` f16 scales
/// that follow the block's nibbles are then read, and each row of the temporary tile
/// is multiplied by its scale and added to `ab`.
///
/// # Safety
///
/// * `pa` must be valid for reads of
///   `k * MR / Q4_0.block_len() * Q4_0.block_bytes()` bytes.
/// * `pb` must be properly aligned for `TI` and valid for reads of `k * NR` values.
///
/// # Panics
///
/// Panics if `k` is not a multiple of `Q4_0.block_len()`.
unsafe fn add_mat_mul_pq40_scales_at_end<const MR: usize, const NR: usize, TI>(
    pa: *const u8,
    pb: *const u8,
    k: usize,
    ab: &mut [[TI; NR]; MR],
) where
    TI: LADatum,
    f16: AsPrimitive<TI>,
    i8: AsPrimitive<TI>,
{
    // Single source of truth for the block length: the precondition, the byte length
    // and the loop bounds must all agree on it.
    let block_len = Q4_0.block_len();
    assert!(k % block_len == 0);
    let len = (k * MR) / block_len * Q4_0.block_bytes();
    // SAFETY: the caller guarantees `pa` is readable for `len` bytes.
    let mut pa = NibbleReader::for_slice(std::slice::from_raw_parts(pa, len));
    let b = pb as *const TI;
    for bk in 0..k / block_len {
        // Unscaled partial sums for the current block.
        let mut temp = [[TI::zero(); NR]; MR];
        for ik in 0..block_len {
            // One 4-bit weight per row for this k step, re-centred around zero.
            let mut a: [TI; MR] = [TI::zero(); MR];
            a.iter_mut().for_each(|x| *x = (pa.read_i4() - 8).as_());
            // SAFETY: the caller guarantees `pb` holds `k * NR` values; the highest
            // element touched here is `NR * (k - 1) + NR - 1`.
            let b = std::slice::from_raw_parts(b.add(NR * (ik + block_len * bk)), NR);
            for i in 0..MR {
                for j in 0..NR {
                    temp[i][j] += a[i] * b[j];
                }
            }
        }
        // The block's scales come after its nibbles: apply each row's scale once.
        for i in 0..MR {
            let scale: TI = pa.read_f16().as_();
            for j in 0..NR {
                ab[i][j] += temp[i][j] * scale;
            }
        }
    }
}

unsafe fn store_t<const MR: usize, const NR: usize, TC, TI>(
tile: &OutputStoreKer,
ab: &[[TI; NR]; MR],
Expand Down Expand Up @@ -208,6 +243,8 @@ where
add_mat_mul::<MR, NR, TI, TI, TI>(pa, pb, k, &mut ab);
} else if packing == 1 {
add_mat_mul_pq40(pa, pb, k, &mut ab);
} else if packing == 2 {
add_mat_mul_pq40_scales_at_end(pa, pb, k, &mut ab)
}
} else if TI::datum_type() == i32::datum_type() {
// transmute to allow explicitly using i32 in add_mat_mul generic params
Expand Down Expand Up @@ -237,16 +274,19 @@ where
0
}

const PQ40_R4: PackedBlockQuantFormat = PackedBlockQuantFormat::new(&Q4_0, 4, 0);
const PQ40_R4: PackedBlockQuantFormat = PackedBlockQuantFormat::new(&Q4_0, 4, 0, false);
const PQ40_R4_SE: PackedBlockQuantFormat = PackedBlockQuantFormat::new(&Q4_0, 4, 0, true);

MMMKernelWrapper!(f16, generic_f16_4x4; kernel::<f16, 4, 4>; 4, 4; 4, 4; 0, 0; no_prefetch, true);
MMMKernelWrapper!(f16, generic_f16_4x1; kernel::<f16, 4, 1>; 4, 1; 4, 4; 0, 0; no_prefetch, true,
packing_defs: {
const F16_B: PackedFormat = PackedFormat::new(DatumType::F16, 1, 4);
const PQ40_F16: (&dyn MMMInputFormat, &dyn MMMInputFormat) = (&super::PQ40_R4, &F16_B);
const PQ40_F16_SE: (&dyn MMMInputFormat, &dyn MMMInputFormat) = (&super::PQ40_R4_SE, &F16_B);
},
packings: PQ40_F16,
test: mmm_packed_packed_tests!{ true, generic_f16_4x1, q40f16:1 }
packings: PQ40_F16 PQ40_F16_SE,
test: mmm_packed_packed_tests!{ true, generic_f16_4x1, q40f16:1 },
test: mmm_packed_packed_tests!{ true, generic_f16_4x1, q40f16se:2 }
);

MMMKernelWrapper!(f32, generic_f32_4x4; kernel::<f32, 4, 4>; 4, 4; 4, 4; 0, 0; no_prefetch, true,
Expand Down

0 comments on commit 3c128be

Please sign in to comment.