Commit 6356f54

Initial work on metal backend for tract using metal flash attention

hubertdelajonquieresonos committed Jun 10, 2024
1 parent c54c235 commit 6356f54

Showing 19 changed files with 1,545 additions and 33 deletions.

Cargo.toml (3 changes: 2 additions, 1 deletion)
@@ -50,7 +50,7 @@ members = [
     "test-rt/test-onnx-core",
     "test-rt/test-nnef-cycle",
     "test-rt/test-tflite"
-]
+, "metal"]
 
 [workspace.dependencies]
 accelerate-src = "0.3"
@@ -93,6 +93,7 @@ liquid-core = "0.26"
 log = "0.4.14"
 maplit = "1.0.2"
 memmap2 = "0.9"
+metal = { version = "0.27.0", features = ["mps"] }
 ndarray = "0.15.3"
 ndarray-npy = { version = "0.8.0", features = [ "compressed_npz" ] }
 nom = "7.0.0"

linalg/matmul-bench/benches/matmul.rs (18 changes: 5 additions, 13 deletions)
@@ -51,21 +51,14 @@ pub fn tract_blaslike(
 
     unsafe {
         let mmm = tract_linalg::ops().mmm(dt, dt, dt, Some(m), Some(k), Some(n)).unwrap();
-        let a_storage = mmm.a_packed(dt.size_of(), k);
-        let b_storage = mmm.b_packed(dt.size_of(), k);
 
         let c_storage = mmm.c_view(0, 1);
 
-        let mut pa =
-            Tensor::zero_aligned_dt(dt, &[mmm.a_pack().len(k, m)], mmm.a_pack().alignment())
-                .unwrap();
-        let mut pb =
-            Tensor::zero_aligned_dt(dt, &[mmm.b_pack().len(k, n)], mmm.b_pack().alignment())
-                .unwrap();
         let mut scratch = mmm.allocate_scratch_space();
 
         crit.bench_function(&format!("tract_blaslike_{:?}", dt), |be| {
-            mmm.a_pack().pack(&mut pa.view_mut(), &a.view(), 1, 0);
-            mmm.b_pack().pack(&mut pb.view_mut(), &b.view(), 0, 1);
+            let packed_a = mmm.a_pack().pack_tensor(&a, 1, 0).unwrap();
+            let packed_b = mmm.b_pack().pack_tensor(&b, 0, 1).unwrap();
 
             be.iter(|| {
                 mmm.run_with_scratch_space(
@@ -74,9 +67,8 @@ pub fn tract_blaslike(
                     &mut *scratch,
                     &[
                         FusedSpec::AddMatMul {
-                            k,
-                            a: a_storage.wrap(&pa.view()),
-                            b: b_storage.wrap(&pb.view()),
+                            a: packed_a.as_ref(),
+                            b: packed_b.as_ref(),
                         },
                         FusedSpec::Store(c_storage.wrap(&mut c.view_mut())),
                     ],

linalg/matmul-bench/src/lib.rs (24 changes: 5 additions, 19 deletions)
@@ -409,28 +409,15 @@ pub fn tract(m: usize, k: usize, n: usize, a: &[f32], b: &[f32], c: &mut [f32])
     let mmm = tract_linalg::ops()
         .mmm(DatumType::F32, DatumType::F32, DatumType::F32, Some(m), Some(k), Some(n))
         .unwrap();
-    let a_storage = mmm.a_packed(f32::datum_type().size_of(), k);
-    let b_storage = mmm.b_packed(f32::datum_type().size_of(), k);
 
     let c_storage = mmm.c_view(0, 1);
 
     let a = Tensor::from_shape(&[m, k], a).unwrap();
     let b = Tensor::from_shape(&[k, n], b).unwrap();
     let mut tc = Tensor::uninitialized_dt(f32::datum_type(), &[m, n]).unwrap();
 
-    let mut pa = Tensor::uninitialized_aligned_dt(
-        DatumType::F32,
-        &[mmm.a_pack().len(k, m)],
-        mmm.a_pack().alignment(),
-    )
-    .unwrap();
-    let mut pb = Tensor::uninitialized_aligned_dt(
-        DatumType::F32,
-        &[mmm.b_pack().len(k, n)],
-        mmm.b_pack().alignment(),
-    )
-    .unwrap();
-    mmm.a_pack().pack(&mut pa.view_mut(), &a.view(), 1, 0);
-    mmm.b_pack().pack(&mut pb.view_mut(), &b.view(), 0, 1);
+    let packed_a = mmm.a_pack().pack_tensor(&a, 1, 0).unwrap();
+    let packed_b = mmm.b_pack().pack_tensor(&b, 0, 1).unwrap();
 
     let mut scratch = mmm.allocate_scratch_space();
 
@@ -440,9 +427,8 @@ pub fn tract(m: usize, k: usize, n: usize, a: &[f32], b: &[f32], c: &mut [f32])
             &mut *scratch,
             &[
                 FusedSpec::AddMatMul {
-                    k,
-                    a: a_storage.wrap(&pa.view()),
-                    b: b_storage.wrap(&pb.view()),
+                    a: packed_a.as_ref(),
+                    b: packed_b.as_ref(),
                 },
                 FusedSpec::Store(c_storage.wrap(&mut tc.view_mut())),
             ],
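Both hunks above show the same tract-linalg API migration: the `a_packed()`/`b_packed()` storage descriptors, the hand-allocated aligned tensors, and the explicit `k` field on `FusedSpec::AddMatMul` are replaced by `pack_tensor()`, which allocates and packs in one step and returns an operand usable directly. A condensed sketch of the new flow, with signatures taken from this diff rather than checked against the tract-linalg sources:

```rust
use tract_core::internal::*;
use tract_linalg::frame::mmm::FusedSpec;

unsafe fn matmul_f32(m: usize, k: usize, n: usize, a: &Tensor, b: &Tensor, c: &mut Tensor) {
    let dt = f32::datum_type();
    let mmm = tract_linalg::ops().mmm(dt, dt, dt, Some(m), Some(k), Some(n)).unwrap();

    // pack_tensor() allocates and packs in one step; the old API needed an
    // aligned scratch tensor plus a separate pack() call on its view.
    let packed_a = mmm.a_pack().pack_tensor(a, 1, 0).unwrap();
    let packed_b = mmm.b_pack().pack_tensor(b, 0, 1).unwrap();

    let c_storage = mmm.c_view(0, 1);
    let mut scratch = mmm.allocate_scratch_space();

    mmm.run_with_scratch_space(
        m,
        n,
        &mut *scratch,
        &[
            // AddMatMul no longer carries `k`: the packed operands know their own geometry.
            FusedSpec::AddMatMul { a: packed_a.as_ref(), b: packed_b.as_ref() },
            FusedSpec::Store(c_storage.wrap(&mut c.view_mut())),
        ],
    )
    .unwrap();
}
```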

metal/Cargo.toml (38 changes: 38 additions, 0 deletions)
@@ -0,0 +1,38 @@
[package]
name = "tract-metal"
version = "0.21.6-pre"
license = "MIT OR Apache-2.0"
authors = [
    "Hubert de La Jonquière <[email protected]>",
    "Mathieu Poumeyrol <[email protected]>",
]
description = "Tiny, no-nonsense, self contained, TensorFlow and ONNX inference"
repository = "https://github.com/snipsco/tract"
keywords = [ "TensorFlow", "NeuralNetworks", "Metal" ]
categories = [ "science" ]
autobenches = false
edition = "2021"
rust-version = "1.75"

[badges]
maintenance = { status = "actively-developed" }

[dependencies]
anyhow.workspace = true
metal.workspace = true
objc = { version = "0.2.7" }
num-traits.workspace = true
tract-core = { version = "=0.21.6-pre", path = "../core" }

[features]
default = [ ]

[dev-dependencies]
criterion = "*"
proptest.workspace = true
derive-new.workspace = true

[[bench]]
name = "metal_gemm"
harness = false
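
Note: `harness = false` is what lets criterion drive the benchmark; it disables the default libtest harness so that the `main` generated by `criterion_main!` in metal_gemm.rs below takes over.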

metal/README.md (16 changes: 16 additions, 0 deletions)
@@ -0,0 +1,16 @@
# tract-metal

## Updating Metal Flash Attention library

```
git clone https://github.com/philipturner/metal-flash-attention.git
cd metal-flash-attention
# for iOS
swift build.swift --platform iOS --xcode-path /Applications/Xcode.app
cp build/lib/libMetalFlashAttention.metallib path/to/tract/metal/src/kernels/libMetalFlashAttention-ios.metallib
# for macOS
swift build.swift --platform macOS --xcode-path /Applications/Xcode.app
cp build/lib/libMetalFlashAttention.metallib path/to/tract/metal/src/kernels/libMetalFlashAttention-macos.metallib
```
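
At runtime the prebuilt metallib has to be handed to a Metal device. A minimal sketch using the `metal` crate pinned in the workspace; the constant, embed path, and function name are illustrative, not tract-metal's actual API:

```rust
use metal::{Device, Library};

// Embed the prebuilt library next to the kernels, mirroring the cp target above.
const MFA_LIB_MACOS: &[u8] = include_bytes!("kernels/libMetalFlashAttention-macos.metallib");

fn load_mfa_library() -> Result<Library, String> {
    let device = Device::system_default().ok_or_else(|| "no Metal device found".to_string())?;
    // Hand the raw metallib bytes to the Metal runtime.
    device.new_library_with_data(MFA_LIB_MACOS)
}
```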

metal/benches/metal_gemm.rs (178 changes: 178 additions, 0 deletions)
@@ -0,0 +1,178 @@
use criterion::measurement::WallTime;
use criterion::*;
use ggml;
use tract_core::internal::*;
use tract_metal::*;

pub fn ggml_matmul(
    crit: &mut BenchmarkGroup<WallTime>,
    m: usize,
    k: usize,
    n: usize,
    dt: DatumType,
) {
    let ggml_dt = match dt {
        DatumType::F32 => ggml::Type::F32,
        DatumType::F16 => ggml::Type::F16,
        _ => unimplemented!(),
    };

    let ctxt = ggml::Context::new(ggml::ContextStorage::Allocate { mem_size: 10_000_000 });

    let mut a = ctxt.new_tensor_2d(ggml_dt, k, m);
    a.zero_data();
    println!("{:?}", a.get_ne());
    let mut b = ctxt.new_tensor_2d(ggml_dt, k, n); // internal transposition
    b.zero_data();
    println!("{:?}", b.get_ne());

    let c = ctxt.op_mul_mat(&a, &b);

    crit.bench_function(&format!("ggml_{:?}", dt), |be| {
        be.iter(|| {
            let mut graph = ctxt.create_compute_graph();
            graph.build_forward_expand(&c);

            let mut execution_plan = ggml::GraphExecutionPlan::new(&mut graph, 1);
            execution_plan.execute(&ctxt);
        });
    });
}

pub fn tract_with_packing(
    crit: &mut BenchmarkGroup<WallTime>,
    m: usize,
    k: usize,
    n: usize,
    dt: DatumType,
) {
    use tract_linalg::frame::mmm::FusedSpec;
    let a = Tensor::zero_dt(dt, &[m, k]).unwrap();
    let b = Tensor::zero_dt(dt, &[k, n]).unwrap();
    let mut c = Tensor::zero_dt(dt, &[m, n]).unwrap();

    // mk,kn -> mn
    unsafe {
        let mmm = tract_linalg::ops().mmm(dt, dt, dt, Some(m), Some(k), Some(n)).unwrap();

        let c_storage = mmm.c_view(0, 1);

        let mut scratch = mmm.allocate_scratch_space();

        crit.bench_function(&format!("tract_with_packing_{:?}", dt), |be| {
            let packed_a = mmm.a_pack().pack_tensor(&a, 1, 0).unwrap();
            let packed_b = mmm.b_pack().pack_tensor(&b, 0, 1).unwrap();

            be.iter(|| {
                mmm.run_with_scratch_space(
                    m,
                    n,
                    &mut *scratch,
                    &[
                        FusedSpec::AddMatMul { a: packed_a.as_ref(), b: packed_b.as_ref() },
                        FusedSpec::Store(c_storage.wrap(&mut c.view_mut())),
                    ],
                )
                .unwrap()
            });
        });
    }
}

pub fn metal_gemm(
    crit: &mut BenchmarkGroup<WallTime>,
    m: usize,
    k: usize,
    n: usize,
    dt: DatumType,
) {
    let mut context = MetalContext::new();
    context.shared_context().load_library(LibraryName::MfaLib).unwrap();

    let a = Tensor::zero_dt(dt, &[1, m, k]).unwrap();
    let b = Tensor::zero_dt(dt, &[1, k, n]).unwrap();
    let metal_a = a.into_metal().unwrap();
    let metal_b = b.into_metal().unwrap();
    // Warmup
    let _ = tract_metal::gemm::gemm(&mut context, &metal_a, &metal_b).unwrap();

    crit.bench_function(&format!("tract_metal_gemm_{:?}", dt), |be| {
        be.iter(|| {
            let _ = tract_metal::gemm::gemm(&mut context, &metal_a, &metal_b).unwrap();
        });
    });
}
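
The dispatch outside the timing loop, marked `// Warmup`, keeps one-time costs out of the measurement: the MFA library was just loaded via `load_library`, and the first call presumably also triggers pipeline-state creation, so `be.iter` only times steady-state GEMM dispatches.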

pub fn metal_tile_8x8(crit: &mut BenchmarkGroup<WallTime>, dim: usize, dt: DatumType) {
    let mut context = MetalContext::new();
    crit.bench_function(&format!("tract_metal_mmm_tile_8x8_{:?}", dt), |be| {
        let a = Tensor::zero_dt(dt, &[dim, dim]).unwrap();
        let b = Tensor::zero_dt(dt, &[dim, dim]).unwrap();
        let metal_a = a.into_metal().unwrap();
        let metal_b = b.into_metal().unwrap();

        be.iter(|| {
            let _ = tract_metal::kernels::mmm_tile_8x8(&mut context, &metal_a, &metal_b).unwrap();
        });
    });
}

fn matmul(c: &mut Criterion, m: usize, k: usize, n: usize) {
    let mut c = c.benchmark_group(format!("{}x{}x{}", m, k, n));
    c.throughput(Throughput::Elements((m * k * n) as _));
    // ggml_matmul(&mut c, m, k, n, f32::datum_type());
    tract_with_packing(&mut c, m, k, n, f32::datum_type());
    metal_gemm(&mut c, m, k, n, f32::datum_type());
    // ggml_matmul(&mut c, m, k, n, f16::datum_type());
    tract_with_packing(&mut c, m, k, n, f16::datum_type());
    metal_gemm(&mut c, m, k, n, f16::datum_type());
    c.finish();
}

fn tinyllama(c: &mut Criterion) {
    let shapes = vec![
        (1, 64, 3),
        (1, 64, 1),
        (1, 5632, 2048),
        (1, 3, 64),
        (1, 64, 13),
        (1, 12, 64),
        (1, 2048, 5632),
        (1, 2048, 32003),
        (1, 2048, 2048),
        (1, 2048, 256),
    ];
    for (m, k, n) in shapes {
        matmul(c, m, k, n);
    }
}

fn big(c: &mut Criterion) {
    matmul(c, 2048, 2048, 1);
    matmul(c, 1, 2048, 2048);
    matmul(c, 2048, 2048, 2048);
    matmul(c, 4096, 4096, 4096);
}

fn wavenet(c: &mut Criterion) {
    matmul(c, 32, 32, 8);
    matmul(c, 16, 60, 8);
}

fn asr_15_m(c: &mut Criterion) {
    matmul(c, 768, 200, 24);
    matmul(c, 768, 2304, 24);
    matmul(c, 768, 2304, 8);
    matmul(c, 768, 384, 1);
}

fn inception(c: &mut Criterion) {
    matmul(c, 64, 288, 21609);
}

fn whisper_base(c: &mut Criterion) {
    matmul(c, 512, 512, 1500);
}

criterion_group!(benches, tinyllama, big, wavenet, asr_15_m, inception, whisper_base);
criterion_main!(benches);
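
As committed, this file imports `ggml` while the `ggml_matmul` call sites are commented out and metal/Cargo.toml above declares no ggml dev-dependency, so building the bench presumably requires adding one (or dropping the import). Once it builds, the suite runs with `cargo bench -p tract-metal --bench metal_gemm`.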
(diff for the remaining 13 of the 19 changed files not shown)
