Skip to content

Commit

Permalink
amx, can_fuse
Browse files Browse the repository at this point in the history
  • Loading branch information
mathieupoumeyrolsonos committed Sep 25, 2024
1 parent 752fd0f commit 468418d
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 14 deletions.
20 changes: 6 additions & 14 deletions linalg/src/arm64/apple_amx.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,10 @@
use crate::mmm::*;
use tract_data::prelude::*;

MMMExternKernel!(f32, apple_amx_mmm_f32_32x32; 32, 32; 128, 128; 0, 0; no_prefetch, crate::arm64::has_amx(),
can_fuse: |f| !matches!(f, &FusedSpec::LeakyRelu(_))
);
// Availability predicate for the AMX kernels: a plain `fn() -> bool` that
// forwards to `crate::arm64::has_amx()`.
const AMX: fn() -> bool = || crate::arm64::has_amx();
// Fusion predicate for the AMX kernels: every fused spec is accepted except
// `FusedSpec::LeakyRelu`.
const CAN_FUSE: fn(&FusedSpec) -> bool = |spec| !matches!(spec, FusedSpec::LeakyRelu(_));

MMMExternKernel!(f32, apple_amx_mmm_f32_32x1; 32, 1; 128, 128; 0, 0; no_prefetch, crate::arm64::has_amx(),
can_fuse: |f| !matches!(f, &FusedSpec::LeakyRelu(_))
);

MMMExternKernel!(f16, apple_amx_mmm_f16_64x32; 64, 32; 128, 128; 0, 0; no_prefetch, crate::arm64::has_amx(),
can_fuse: |f| !matches!(f, &FusedSpec::LeakyRelu(_))
);

MMMExternKernel!(f16, apple_amx_mmm_f16_64x1; 64, 1; 128, 128; 0, 0; no_prefetch, crate::arm64::has_amx(),
can_fuse: |f| !matches!(f, &FusedSpec::LeakyRelu(_))
);
// Apple AMX kernel declarations: f32 at 32x32 and 32x1, f16 at 64x32 and 64x1.
// Each uses (128, 128) for the two packing alignments, is gated by the `AMX`
// predicate through `where(...)`, and takes its fusion rule from `CAN_FUSE`.
MMMExternKernel!(apple_amx_mmm_f32_32x32<f32>(32, 32)@(128, 128) where(AMX) can_fuse(CAN_FUSE));
MMMExternKernel!(apple_amx_mmm_f32_32x1<f32>(32, 1)@(128, 128) where(AMX) can_fuse(CAN_FUSE));
MMMExternKernel!(apple_amx_mmm_f16_64x32<f16>(64, 32)@(128, 128) where(AMX) can_fuse(CAN_FUSE));
MMMExternKernel!(apple_amx_mmm_f16_64x1<f16>(64, 1)@(128, 128) where(AMX) can_fuse(CAN_FUSE));
10 changes: 10 additions & 0 deletions linalg/src/frame/mmm/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pub struct DynKernel<const MR: usize, const NR: usize, Acc: LADatum> {
pub default_packing_alignments: (usize, usize),
pub packings: Vec<(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)>,
pub supported_predicate: fn() -> bool,
pub can_fuse: fn(&FusedSpec) -> bool,
}

impl<const MR: usize, const NR: usize, Acc: LADatum> DynKernel<MR, NR, Acc> {
Expand All @@ -50,6 +51,7 @@ impl<const MR: usize, const NR: usize, Acc: LADatum> DynKernel<MR, NR, Acc> {
packings: vec![],
supported_predicate: || true,
default_packing_alignments,
can_fuse: |_| true,
};
let a = kernel.regular_pack_a();
let b = kernel.regular_pack_b();
Expand Down Expand Up @@ -79,6 +81,10 @@ impl<const MR: usize, const NR: usize, Acc: LADatum> DynKernel<MR, NR, Acc> {
PackedFormat::new(Acc::datum_type(), NR, self.default_packing_alignments.1)
}

/// Consumes the kernel and returns it with its `can_fuse` predicate replaced
/// by the given function pointer; every other field is carried over as-is.
pub fn with_can_fuse(mut self, can_fuse: fn(&FusedSpec) -> bool) -> Self {
    self.can_fuse = can_fuse;
    self
}

pub fn mmm(&self) -> Box<dyn MatMatMul> {
Box::new(self.clone())
}
Expand Down Expand Up @@ -108,6 +114,10 @@ impl<const MR: usize, const NR: usize, Acc: LADatum> MatMatMulKer for DynKernel<
NR
}

/// Reports whether `spec` may be fused, by delegating to the function
/// pointer stored in the `can_fuse` field.
fn can_fuse(&self, spec: &FusedSpec) -> bool {
    let predicate = self.can_fuse;
    predicate(spec)
}

/// Invokes the callable stored in the `kernel` field on the fused-op specs
/// in `op` and returns its `isize` result unchanged.
// NOTE(review): the call is wrapped in `unsafe`, but the invariant that makes
// it sound is not visible in this excerpt — confirm it and record it as a
// `// SAFETY:` comment.
fn kernel(&self, op: &[FusedKerSpec<Self::Acc>]) -> isize {
unsafe { (self.kernel)(op) }
}
Expand Down
6 changes: 6 additions & 0 deletions linalg/src/frame/mmm/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ macro_rules! MMMExternKernel {
(
$func:ident<$ti:ident>($mr: expr, $nr: expr)@($align_a:expr, $align_b:expr)
$(where($where:expr))?
$(can_fuse($can_fuse:expr))?
$(packing[$pnum:literal] = $pid:ident => $packing:expr)*
) => {
paste! {
Expand All @@ -20,6 +21,7 @@ macro_rules! MMMExternKernel {

MMMKernel!([<sys_$func>]::rusty as $func<$ti>($mr, $nr)@($align_a, $align_b)
$(where($where))?
$(can_fuse($can_fuse))?
$(packing[$pnum] = $pid => $packing)*
);
}
Expand All @@ -29,6 +31,7 @@ macro_rules! MMMRustKernel {
( $func: path =>
$id:ident<$ti:ident>($mr: expr, $nr: expr)@($align_a:expr, $align_b:expr)
$(where($where:expr))?
$(can_fuse($can_fuse:expr))?
$(packing[$pnum:literal] = $pid:ident => $packing:expr)*
) => {
paste! {
Expand All @@ -43,6 +46,7 @@ macro_rules! MMMRustKernel {
}
MMMKernel!([<sys_$id>]::rusty as $id<$ti>($mr, $nr)@($align_a, $align_b)
$(where($where))?
$(can_fuse($can_fuse))?
$(packing[$pnum] = $pid => $packing)*
);
}
Expand All @@ -54,6 +58,7 @@ macro_rules! MMMKernel {
$func: path as
$id:ident<$ti:ident>($mr: expr, $nr: expr)@($align_a:expr, $align_b:expr)
$(where($where:expr))?
$(can_fuse($can_fuse:expr))?
$(packing[$pnum:literal] = $pid:ident => $packing:expr)*
) => {
paste! {
Expand All @@ -68,6 +73,7 @@ macro_rules! MMMKernel {
let f: fn(DynKernel<$mr, $nr, $ti>) -> DynKernel<$mr, $nr, $ti> = $packing;
k = f(k);
)*
$(k.can_fuse = $can_fuse;)?
k
};
}
Expand Down

0 comments on commit 468418d

Please sign in to comment.