diff --git a/Cargo.toml b/Cargo.toml index 127f0593aa..051d40ab61 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,7 +50,7 @@ members = [ "test-rt/test-onnx-core", "test-rt/test-nnef-cycle", "test-rt/test-tflite" -] +, "metal"] [workspace.dependencies] accelerate-src = "0.3" @@ -93,6 +93,7 @@ liquid-core = "0.26" log = "0.4.14" maplit = "1.0.2" memmap2 = "0.9" +metal = { version = "0.27.0", features = ["mps"] } ndarray = "0.15.3" ndarray-npy = { version = "0.8.0", features = [ "compressed_npz" ] } nom = "7.0.0" diff --git a/linalg/matmul-bench/benches/matmul.rs b/linalg/matmul-bench/benches/matmul.rs index 7673d46846..3359b6f658 100644 --- a/linalg/matmul-bench/benches/matmul.rs +++ b/linalg/matmul-bench/benches/matmul.rs @@ -51,21 +51,14 @@ pub fn tract_blaslike( unsafe { let mmm = tract_linalg::ops().mmm(dt, dt, dt, Some(m), Some(k), Some(n)).unwrap(); - let a_storage = mmm.a_packed(dt.size_of(), k); - let b_storage = mmm.b_packed(dt.size_of(), k); + let c_storage = mmm.c_view(0, 1); - let mut pa = - Tensor::zero_aligned_dt(dt, &[mmm.a_pack().len(k, m)], mmm.a_pack().alignment()) - .unwrap(); - let mut pb = - Tensor::zero_aligned_dt(dt, &[mmm.b_pack().len(k, n)], mmm.b_pack().alignment()) - .unwrap(); let mut scratch = mmm.allocate_scratch_space(); crit.bench_function(&format!("tract_blaslike_{:?}", dt), |be| { - mmm.a_pack().pack(&mut pa.view_mut(), &a.view(), 1, 0); - mmm.b_pack().pack(&mut pb.view_mut(), &b.view(), 0, 1); + let packed_a = mmm.a_pack().pack_tensor(&a, 1, 0).unwrap(); + let packed_b = mmm.b_pack().pack_tensor(&b, 0, 1).unwrap(); be.iter(|| { mmm.run_with_scratch_space( @@ -74,9 +67,8 @@ pub fn tract_blaslike( &mut *scratch, &[ FusedSpec::AddMatMul { - k, - a: a_storage.wrap(&pa.view()), - b: b_storage.wrap(&pb.view()), + a: packed_a.as_ref(), + b: packed_b.as_ref(), }, FusedSpec::Store(c_storage.wrap(&mut c.view_mut())), ], diff --git a/linalg/matmul-bench/src/lib.rs b/linalg/matmul-bench/src/lib.rs index a4a07207a0..e389222102 100644 --- 
a/linalg/matmul-bench/src/lib.rs +++ b/linalg/matmul-bench/src/lib.rs @@ -409,28 +409,15 @@ pub fn tract(m: usize, k: usize, n: usize, a: &[f32], b: &[f32], c: &mut [f32]) let mmm = tract_linalg::ops() .mmm(DatumType::F32, DatumType::F32, DatumType::F32, Some(m), Some(k), Some(n)) .unwrap(); - let a_storage = mmm.a_packed(f32::datum_type().size_of(), k); - let b_storage = mmm.b_packed(f32::datum_type().size_of(), k); + let c_storage = mmm.c_view(0, 1); let a = Tensor::from_shape(&[m, k], a).unwrap(); let b = Tensor::from_shape(&[k, n], b).unwrap(); let mut tc = Tensor::uninitialized_dt(f32::datum_type(), &[m, n]).unwrap(); - let mut pa = Tensor::uninitialized_aligned_dt( - DatumType::F32, - &[mmm.a_pack().len(k, m)], - mmm.a_pack().alignment(), - ) - .unwrap(); - let mut pb = Tensor::uninitialized_aligned_dt( - DatumType::F32, - &[mmm.b_pack().len(k, n)], - mmm.b_pack().alignment(), - ) - .unwrap(); - mmm.a_pack().pack(&mut pa.view_mut(), &a.view(), 1, 0); - mmm.b_pack().pack(&mut pb.view_mut(), &b.view(), 0, 1); + let packed_a = mmm.a_pack().pack_tensor(&a, 1, 0).unwrap(); + let packed_b = mmm.b_pack().pack_tensor(&b, 0, 1).unwrap(); let mut scratch = mmm.allocate_scratch_space(); @@ -440,9 +427,8 @@ pub fn tract(m: usize, k: usize, n: usize, a: &[f32], b: &[f32], c: &mut [f32]) &mut *scratch, &[ FusedSpec::AddMatMul { - k, - a: a_storage.wrap(&pa.view()), - b: b_storage.wrap(&pb.view()), + a: packed_a.as_ref(), + b: packed_b.as_ref(), }, FusedSpec::Store(c_storage.wrap(&mut tc.view_mut())), ], diff --git a/metal/Cargo.toml b/metal/Cargo.toml new file mode 100644 index 0000000000..99c240ceba --- /dev/null +++ b/metal/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "tract-metal" +version = "0.21.6-pre" +license = "MIT OR Apache-2.0" +authors = [ + "Hubert de La Jonquière ", + "Mathieu Poumeyrol ", +] +description = "Tiny, no-nonsense, self contained, TensorFlow and ONNX inference" +repository = "https://github.com/snipsco/tract" +keywords = [ "TensorFlow", 
"NeuralNetworks", "Metal" ] +categories = [ "science" ] +autobenches = false +edition = "2021" +rust-version = "1.75" + +[badges] +maintenance = { status = "actively-developed" } + +[dependencies] +anyhow.workspace = true +metal.workspace = true +objc = { version = "0.2.7" } +num-traits.workspace = true +tract-core = { version = "=0.21.6-pre", path = "../core" } + +[features] +default = [ ] + +[dev-dependencies] +criterion = "*" +proptest.workspace = true +derive-new.workspace = true + +[[bench]] +name = "metal_gemm" +harness = false + diff --git a/metal/README.md b/metal/README.md new file mode 100644 index 0000000000..d53e10fa57 --- /dev/null +++ b/metal/README.md @@ -0,0 +1,16 @@ +# tract-metal + +## Updating Metal Flash Attention library + +``` +git clone https://github.com/philipturner/metal-flash-attention.git +cd metal-flash-attention + +# for iOS +swift build.swift --platform iOS --xcode-path /Applications/Xcode.app +cp build/lib/libMetalFlashAttention.metallib path/to/tract/metal/src/kernels/libMetalFlashAttention-ios.metallib + +# for MacOS +swift build.swift --platform macOS --xcode-path /Applications/Xcode.app +cp build/lib/libMetalFlashAttention.metallib path/to/tract/metal/src/kernels/libMetalFlashAttention-macos.metallib +``` \ No newline at end of file diff --git a/metal/benches/metal_gemm.rs b/metal/benches/metal_gemm.rs new file mode 100644 index 0000000000..1225ce8f15 --- /dev/null +++ b/metal/benches/metal_gemm.rs @@ -0,0 +1,178 @@ +use criterion::measurement::WallTime; +use criterion::*; +use ggml; +use tract_core::internal::*; +use tract_metal::*; + +pub fn ggml_matmul( + crit: &mut BenchmarkGroup, + m: usize, + k: usize, + n: usize, + dt: DatumType, +) { + let ggml_dt = match dt { + DatumType::F32 => ggml::Type::F32, + DatumType::F16 => ggml::Type::F16, + _ => unimplemented!(), + }; + + let ctxt = ggml::Context::new(ggml::ContextStorage::Allocate { mem_size: 10_000_000 }); + + let mut a = ctxt.new_tensor_2d(ggml_dt, k, m); + a.zero_data(); + 
println!("{:?}", a.get_ne()); + let mut b = ctxt.new_tensor_2d(ggml_dt, k, n); // intern transposition + b.zero_data(); + println!("{:?}", b.get_ne()); + + let c = ctxt.op_mul_mat(&a, &b); + + crit.bench_function(&format!("ggml_{:?}", dt), |be| { + be.iter(|| { + let mut graph = ctxt.create_compute_graph(); + graph.build_forward_expand(&c); + + let mut execution_plan = ggml::GraphExecutionPlan::new(&mut graph, 1); + execution_plan.execute(&ctxt); + }); + }); +} + +pub fn tract_with_packing( + crit: &mut BenchmarkGroup, + m: usize, + k: usize, + n: usize, + dt: DatumType, +) { + use tract_linalg::frame::mmm::FusedSpec; + let a = Tensor::zero_dt(dt, &[m, k]).unwrap(); + let b = Tensor::zero_dt(dt, &[k, n]).unwrap(); + let mut c = Tensor::zero_dt(dt, &[m, n]).unwrap(); + + // mk,kn -> mn + unsafe { + let mmm = tract_linalg::ops().mmm(dt, dt, dt, Some(m), Some(k), Some(n)).unwrap(); + + let c_storage = mmm.c_view(0, 1); + + let mut scratch = mmm.allocate_scratch_space(); + + crit.bench_function(&format!("tract_with_packing_{:?}", dt), |be| { + let packed_a = mmm.a_pack().pack_tensor(&a, 1, 0).unwrap(); + let packed_b = mmm.b_pack().pack_tensor(&b, 0, 1).unwrap(); + + be.iter(|| { + mmm.run_with_scratch_space( + m, + n, + &mut *scratch, + &[ + FusedSpec::AddMatMul { a: packed_a.as_ref(), b: packed_b.as_ref() }, + FusedSpec::Store(c_storage.wrap(&mut c.view_mut())), + ], + ) + .unwrap() + }); + }); + } +} + +pub fn metal_gemm( + crit: &mut BenchmarkGroup, + m: usize, + k: usize, + n: usize, + dt: DatumType, +) { + let mut context = MetalContext::new(); + context.shared_context().load_library(LibraryName::MfaLib).unwrap(); + + let a = Tensor::zero_dt(dt, &[1, m, k]).unwrap(); + let b = Tensor::zero_dt(dt, &[1, k, n]).unwrap(); + let metal_a = a.into_metal().unwrap(); + let metal_b = b.into_metal().unwrap(); + // Warmup + let _ = tract_metal::gemm::gemm(&mut context, &metal_a, &metal_b).unwrap(); + + crit.bench_function(&format!("tract_metal_gemm_{:?}", dt), |be| { + 
be.iter(|| { + let _ = tract_metal::gemm::gemm(&mut context, &metal_a, &metal_b).unwrap(); + }); + }); +} + +pub fn metal_tile_8x8(crit: &mut BenchmarkGroup, dim: usize, dt: DatumType) { + let mut context = MetalContext::new(); + crit.bench_function(&format!("tract_metal_mmm_tile_8x8_{:?}", dt), |be| { + let a = Tensor::zero_dt(dt, &[dim, dim]).unwrap(); + let b = Tensor::zero_dt(dt, &[dim, dim]).unwrap(); + let metal_a = a.into_metal().unwrap(); + let metal_b = b.into_metal().unwrap(); + + be.iter(|| { + let _ = tract_metal::kernels::mmm_tile_8x8(&mut context, &metal_a, &metal_b).unwrap(); + }); + }); +} + +fn matmul(c: &mut Criterion, m: usize, k: usize, n: usize) { + let mut c = c.benchmark_group(format!("{}x{}x{}", m, k, n)); + c.throughput(Throughput::Elements((m * k * n) as _)); + // ggml_matmul(&mut c, m, k, n, f32::datum_type()); + tract_with_packing(&mut c, m, k, n, f32::datum_type()); + metal_gemm(&mut c, m, k, n, f32::datum_type()); + // ggml_matmul(&mut c, m, k, n, f16::datum_type()); + tract_with_packing(&mut c, m, k, n, f16::datum_type()); + metal_gemm(&mut c, m, k, n, f16::datum_type()); + c.finish(); +} + +fn tinyllama(c: &mut Criterion) { + let shapes = vec![ + (1, 64, 3), + (1, 64, 1), + (1, 5632, 2048), + (1, 3, 64), + (1, 64, 13), + (1, 12, 64), + (1, 2048, 5632), + (1, 2048, 32003), + (1, 2048, 2048), + (1, 2048, 256), + ]; + for (m, k, n) in shapes { + matmul(c, m, k, n); + } +} + +fn big(c: &mut Criterion) { + matmul(c, 2048, 2048, 1); + matmul(c, 1, 2048, 2048); + matmul(c, 2048, 2048, 2048); + matmul(c, 4096, 4096, 4096); +} + +fn wavenet(c: &mut Criterion) { + matmul(c, 32, 32, 8); + matmul(c, 16, 60, 8); +} + +fn asr_15_m(c: &mut Criterion) { + matmul(c, 768, 200, 24); + matmul(c, 768, 2304, 24); + matmul(c, 768, 2304, 8); + matmul(c, 768, 384, 1); +} + +fn inception(c: &mut Criterion) { + matmul(c, 64, 288, 21609); +} + +fn whisper_base(c: &mut Criterion) { + matmul(c, 512, 512, 1500); +} + +criterion_group!(benches, tinyllama, big, 
wavenet, asr_15_m, inception, whisper_base); +criterion_main!(benches); diff --git a/metal/src/context.rs b/metal/src/context.rs new file mode 100644 index 0000000000..89065d357c --- /dev/null +++ b/metal/src/context.rs @@ -0,0 +1,263 @@ +use crate::func_constants::ConstantValues; +pub use crate::kernels::{LibraryContent, LibraryName}; +pub use crate::tensor::MetalTensor; +use metal::Buffer; +use metal::MTLResourceOptions; +use metal::NSUInteger; +use std::cell::RefCell; +use std::path::Path; +use std::sync::Arc; +use std::sync::{OnceLock, RwLock}; + +use anyhow::{anyhow, Context, Result}; +use metal::{ + CommandBuffer, CommandQueue, CompileOptions, ComputePipelineState, Device, Function, + FunctionConstantValues, Library, +}; +use std::collections::HashMap; + +thread_local! { + pub static METAL_CONTEXT: RefCell = RefCell::new(MetalContext::new()); +} + +fn shared_metal_context() -> SharedMetalContext { + static INSTANCE: OnceLock = OnceLock::new(); + INSTANCE + .get_or_init(|| SharedMetalContext::new().expect("Could not create shared metal context")) + .clone() +} + +#[derive(Debug, Clone)] +pub struct SharedMetalContext { + device: Device, + cache_libraries: Arc>>, + cache_pipelines: Arc< + RwLock), ComputePipelineState>>, + >, +} + +impl SharedMetalContext { + pub fn new() -> Result { + let device = Device::system_default() + .with_context(|| "Could not find system default Metal device")?; + Ok(Self { + device, + cache_libraries: Arc::new(RwLock::new(HashMap::new())), + cache_pipelines: Arc::new(RwLock::new(HashMap::new())), + }) + } + + pub fn load_library(&self, name: LibraryName) -> Result { + { + let cache_libraries = self.cache_libraries.read().map_err(|e| anyhow!("{:?}", e))?; + if let Some(library) = cache_libraries.get(&name) { + return Ok(library.clone()); + } + } + let mut cache_libraries = self.cache_libraries.write().map_err(|e| anyhow!("{:?}", e))?; + let library = match name.content() { + LibraryContent::Data(lib_data) => self + .device + 
.new_library_with_data(lib_data) + .map_err(|e| anyhow!("{}", e)) + .with_context(|| { + anyhow!("Error while loading Metal library from data: {:?}", name) + })?, + LibraryContent::Source(lib_source) => self + .device + .new_library_with_source(lib_source, &CompileOptions::new()) + .map_err(|e| anyhow!("{}", e)) + .with_context(|| { + format!("Error while loading Metal library from source: {:?}", name) + })?, + }; + cache_libraries.insert(name, library.clone()); + Ok(library) + } + + pub fn load_function( + &self, + library_name: LibraryName, + func_name: &'static str, + constants: Option, + ) -> Result { + let func = self + .load_library(library_name)? + .get_function(func_name, constants) + .map_err(|e| anyhow!("{}", e)) + .with_context(|| { + format!( + "Error while loading function {func_name} from library: {:?} with constants", + library_name + ) + })?; + Ok(func) + } + + pub(crate) fn load_pipeline_with_constants( + &self, + library_name: LibraryName, + func_name: &'static str, + constants: Option, + ) -> Result { + let key = (library_name, func_name, constants); + { + let cache_pipelines = self.cache_pipelines.read().map_err(|e| anyhow!("{:?}", e))?; + if let Some(pipeline) = cache_pipelines.get(&key) { + return Ok(pipeline.clone()); + } + } + let mut cache_pipelines = self.cache_pipelines.write().map_err(|e| anyhow!("{:?}", e))?; + + let (library_name, func_name, constants) = key; + let func = self.load_function( + library_name, + func_name, + constants.as_ref().map(|c| c.function_constant_values()), + )?; + let pipeline = self.device + .new_compute_pipeline_state_with_function(&func) + .map_err(|e| anyhow!("{}", e)) + .with_context(|| format!("Error while creating compute pipeline for function {func_name} from source: {:?}", library_name))?; + cache_pipelines.insert((library_name, func_name, constants), pipeline.clone()); + Ok(pipeline) + } + + pub fn load_pipeline( + &mut self, + library_name: LibraryName, + func_name: &'static str, + ) -> Result { + 
self.load_pipeline_with_constants(library_name, func_name, None) + } +} + +#[derive(Debug)] +pub struct MetalContext { + shared: SharedMetalContext, + command_queue: CommandQueue, + command_buffer: RwLock, + command_buffer_idx: RwLock, + command_buffer_capacity: usize, +} + +impl MetalContext { + pub fn new() -> Self { + let shared = shared_metal_context(); + let command_queue = shared.device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer().to_owned(); + command_buffer.enqueue(); + Self { + shared, + command_queue, + command_buffer: RwLock::new(command_buffer), + command_buffer_idx: RwLock::new(0), + command_buffer_capacity: 10, + } + } + + pub fn device(&self) -> &Device { + &self.shared.device + } + + pub fn shared_context(&self) -> &SharedMetalContext { + &self.shared + } + + pub fn buffer_from_slice_with_copy(&self, data: &[T]) -> Buffer { + let size = core::mem::size_of_val(data) as NSUInteger; + self.device().new_buffer_with_bytes_no_copy( + data.as_ptr() as *const core::ffi::c_void, + size, + MTLResourceOptions::StorageModeShared, + None, + ) + } + + pub fn buffer_from_slice_with_copy_mut(&self, data: &mut [T]) -> Buffer { + let size = core::mem::size_of_val(data) as NSUInteger; + self.device().new_buffer_with_bytes_no_copy( + data.as_ptr() as *const core::ffi::c_void, + size, + MTLResourceOptions::StorageModeShared, + None, + ) + } + + pub fn command_buffer(&self) -> Result { + let mut self_command_buffer = + self.command_buffer.try_write().map_err(|e| anyhow!("{:?}", e))?; + let mut command_buffer = self_command_buffer.to_owned(); + + let mut command_buffer_idx = + self.command_buffer_idx.try_write().map_err(|e| anyhow!("{:?}", e))?; + + if *command_buffer_idx > self.command_buffer_capacity { + *command_buffer_idx = 0; + command_buffer.commit(); + command_buffer = self.command_queue.new_command_buffer().to_owned(); + *self_command_buffer = command_buffer.clone(); + Ok(command_buffer.to_owned()) + } else { + *command_buffer_idx 
+= 1; + Ok(command_buffer) + } + } + + pub fn wait_until_completed(&self) -> Result<()> { + let mut command_buffer = self.command_buffer.try_write().map_err(|e| anyhow!("{:?}", e))?; + + match command_buffer.status() { + metal::MTLCommandBufferStatus::Committed + | metal::MTLCommandBufferStatus::Scheduled + | metal::MTLCommandBufferStatus::Completed => { + anyhow::bail!("Current Metal command buffer is already committed.") + } + _ => {} + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + *command_buffer = self.command_queue.new_command_buffer().to_owned(); + Ok(()) + } + + pub fn capture_trace(&self, path: P, compute: F) -> Result<()> + where + P: AsRef, + F: Fn(&Self) -> Result<()>, + { + self.wait_until_completed()?; + + let capture = metal::CaptureManager::shared(); + let descriptor = metal::CaptureDescriptor::new(); + descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument); + descriptor.set_capture_device(self.device()); + descriptor.set_output_url(path); + + capture.start_capture(&descriptor).map_err(|e| anyhow!("{:?}", e))?; + + (compute)(&self)?; + + self.wait_until_completed()?; + capture.stop_capture(); + Ok(()) + } +} + +impl Drop for MetalContext { + fn drop(&mut self) { + let command_buffer = + self.command_buffer.try_write().map_err(|e| anyhow!("{:?}", e)).unwrap(); + + match command_buffer.status() { + metal::MTLCommandBufferStatus::Committed + | metal::MTLCommandBufferStatus::Scheduled + | metal::MTLCommandBufferStatus::Completed => { + panic!("Current Metal command buffer is already committed.") + } + _ => {} + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + } +} diff --git a/metal/src/func_constants.rs b/metal/src/func_constants.rs new file mode 100644 index 0000000000..ded07b3014 --- /dev/null +++ b/metal/src/func_constants.rs @@ -0,0 +1,84 @@ +use metal::{FunctionConstantValues, MTLDataType}; +use std::ffi::c_void; + +/// From candle-metal-kernels +#[derive(Debug, PartialEq)] +pub 
enum Value { + USize(usize), + Bool(bool), + F32(f32), + U16(u16), +} + +impl std::hash::Hash for Value { + fn hash(&self, state: &mut H) { + match self { + Value::F32(v) => v.to_bits().hash(state), + Value::USize(v) => v.hash(state), + Value::U16(v) => v.hash(state), + Value::Bool(v) => v.hash(state), + } + } +} + +impl Value { + fn data_type(&self) -> MTLDataType { + match self { + Value::USize(_) => MTLDataType::UInt, + Value::F32(_) => MTLDataType::Float, + Value::U16(_) => MTLDataType::UShort, + Value::Bool(_) => MTLDataType::Bool, + } + } +} + +// Not true, good enough for our purposes. +impl Eq for Value {} + +/// From candle-metal-kernels +#[derive(Debug, Eq, PartialEq, Hash)] +pub(crate) struct ConstantValues(Vec<(usize, Value)>); + +impl ConstantValues { + pub fn new(values: Vec<(usize, Value)>) -> Self { + Self(values) + } + + pub fn function_constant_values(&self) -> FunctionConstantValues { + let f = FunctionConstantValues::new(); + for (index, value) in &self.0 { + let ty = value.data_type(); + match value { + Value::USize(v) => { + f.set_constant_value_at_index( + v as *const usize as *const c_void, + ty, + *index as u64, + ); + } + Value::F32(v) => { + f.set_constant_value_at_index( + v as *const f32 as *const c_void, + ty, + *index as u64, + ); + } + Value::U16(v) => { + f.set_constant_value_at_index( + v as *const u16 as *const c_void, + ty, + *index as u64, + ); + } + Value::Bool(v) => { + f.set_constant_value_at_index( + v as *const bool as *const c_void, + ty, + *index as u64, + ); + } + } + } + f + } +} diff --git a/metal/src/gemm.rs b/metal/src/gemm.rs new file mode 100644 index 0000000000..69bf5c1f1d --- /dev/null +++ b/metal/src/gemm.rs @@ -0,0 +1,238 @@ +use crate::kernels::GemmPrecision; +use crate::{MetalContext, MetalTensor}; +use anyhow::{bail, ensure, Result}; +use num_traits::Float; +use tract_core::internal::{Datum, DatumType}; + +pub fn gemm_precision_from_dt(dt: DatumType) -> Result { + match dt { + DatumType::F32 => 
Ok(GemmPrecision::Single), + DatumType::F16 => Ok(GemmPrecision::Half), + _ => bail!("Metal GEMM only support F32 or F16 tensors"), + } +} + +pub fn gemm_with_slice( + context: &MetalContext, + (b, m, n, k): (usize, usize, usize, usize), + lhs: &[T], + lhs_strides: &[isize], + rhs: &[T], + rhs_strides: &[isize], + output: &mut [T], +) -> Result<()> { + ensure!( + lhs_strides.len() == rhs_strides.len() && lhs_strides.len() == 3, + "Only 3D tensors are supported in Metal GEMM" + ); + + let precision = gemm_precision_from_dt(T::datum_type())?; + + let lhs_strides = lhs_strides.iter().map(|it| *it as usize).collect::>(); + let rhs_strides = rhs_strides.iter().map(|it| *it as usize).collect::>(); + + let lhs_buff = context.buffer_from_slice_with_copy(lhs); + let rhs_buff = context.buffer_from_slice_with_copy(rhs); + let out_buff = context.buffer_from_slice_with_copy_mut(output); + crate::kernels::metal_gemm( + context, + precision, + (b, m, n, k), + &lhs_strides, + 0, + &lhs_buff, + &rhs_strides, + 0, + &rhs_buff, + &out_buff, + )?; + context.wait_until_completed()?; + Ok(()) +} + +pub fn gemm(context: &MetalContext, lhs: &MetalTensor, rhs: &MetalTensor) -> Result { + ensure!(lhs.rank() == 3 && rhs.rank() == 3); + ensure!(lhs.datum_type() == rhs.datum_type()); + + let precision = gemm_precision_from_dt(lhs.datum_type())?; + + let b = lhs.shape()[0]; + let m = lhs.shape()[1]; + let n = rhs.shape()[2]; + let k = lhs.shape()[2]; + + let lhs_strides = lhs.strides().iter().map(|it| *it as usize).collect::>(); + let rhs_strides = rhs.strides().iter().map(|it| *it as usize).collect::>(); + + let o_dt = lhs.datum_type(); + let o_shape = &[b, m, n]; + + let output = unsafe { MetalTensor::uninitialized_dt(o_dt, o_shape)? 
}; + + crate::kernels::metal_gemm( + context, + precision, + (b, m, n, k), + &lhs_strides, + 0, + &lhs.metal(), + &rhs_strides, + 0, + &rhs.metal(), + output.metal(), + )?; + context.wait_until_completed()?; + Ok(output) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::IntoMetal; + use derive_new::new; + use num_traits::AsPrimitive; + use proptest::collection::vec; + use proptest::prelude::*; + use tract_core::internal::*; + + #[test] + fn test_gemm() -> Result<()> { + objc::rc::autoreleasepool(|| { + crate::METAL_CONTEXT.with_borrow(|context| { + let (b, m, n, k) = (1, 2, 4, 3); + let a = Tensor::from_shape( + &[b, m, k], + &(0..b * m * k).map(|f| f as f32).collect::>(), + )? + .into_metal()?; + let b = Tensor::from_shape( + &[b, k, n], + &(0..b * n * k).map(|f| f as f32).collect::>(), + )? + .into_metal()?; + + let c = gemm(&context, &a, &b)?; + + let expected_c = Tensor::from_shape( + &[1, 2, 4], + &[20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0], + )?; + + let c = c.into_tensor(); + assert!(c.close_enough(&expected_c, Approximation::Close).is_ok()); + + let (b, m, n, k) = (2, 2, 4, 3); + let a = MetalTensor::from_shape( + &[b, m, k], + &(0..b * m * k).map(|f| f as f32).collect::>(), + )?; + let b = MetalTensor::from_shape( + &[b, k, n], + &(0..b * n * k).map(|f| f as f32).collect::>(), + )?; + + let c = gemm(&context, &a, &b)?; + + let expected_c = Tensor::from_shape( + &[2, 2, 4], + &[ + 20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0, + 488.0, 518.0, 548.0, 578.0, + ], + )?; + + assert!(c.into_tensor().close_enough(&expected_c, Approximation::Close).is_ok()); + Ok(()) + }) + }) + } + + proptest::proptest! 
{ + #[test] + fn mmm_prop_f32(pb in any::>()) { + prop_assert_eq!(pb.run().unwrap(), pb.reference()) + } + + #[test] + fn mmm_prop_f16(pb in any::>()) { + prop_assert_eq!(pb.run().unwrap(), pb.reference()) + } + } + + #[derive(Debug, new)] + pub struct MmmProblem + where + F: Datum + Float, + usize: AsPrimitive, + { + pub b: usize, + pub m: usize, + pub k: usize, + pub n: usize, + pub lhs: Vec, + pub rhs: Vec, + } + + impl Arbitrary for MmmProblem + where + F: Datum + Float, + usize: AsPrimitive, + { + type Parameters = (); + type Strategy = BoxedStrategy; + + fn arbitrary_with(_: ()) -> Self::Strategy { + (1usize..2, 1usize..20, 1usize..20, 1usize..20) + .prop_flat_map(|(b, m, k, n)| { + let lhs_len = b * m * k; + let rhs_len = b * k * n; + let lhs = (0usize..10).prop_map(|x| x.as_()); + let rhs = (0usize..10).prop_map(|x| x.as_()); + ( + Just(b), + Just(m), + Just(k), + Just(n), + vec(lhs, lhs_len..=lhs_len), + vec(rhs, rhs_len..=rhs_len), + ) + }) + .prop_map(|(b, m, k, n, lhs, rhs)| Self { b, m, k, n, lhs, rhs }) + .boxed() + } + } + + impl MmmProblem + where + F: Datum + Float + std::ops::AddAssign, + usize: AsPrimitive, + { + pub fn reference(&self) -> Vec { + let mut vi = vec![F::zero(); self.b * self.m * self.n]; + for m in 0..self.m { + for n in 0..self.n { + for k in 0..self.k { + // m, k * k, n + let lhs: F = self.lhs[k + self.k * m]; + let rhs: F = self.rhs[n + self.n * k]; + let offset = n + m * self.n; + vi[offset] += lhs * rhs; + } + } + } + vi + } + + pub fn run(&self) -> Result> { + objc::rc::autoreleasepool(|| { + crate::METAL_CONTEXT.with_borrow(|context| { + let (b, m, n, k) = dbg!((self.b, self.m, self.n, self.k)); + let lhs = Tensor::from_shape(&[b, m, k], &self.lhs)?.into_metal()?; + let rhs = Tensor::from_shape(&[b, k, n], &self.rhs)?.into_metal()?; + let c = gemm(context, &lhs, &rhs)?; + Ok(c.into_tensor().as_slice::()?.to_vec()) + }) + }) + } + } +} diff --git a/metal/src/kernels/libMetalFlashAttention-ios.metallib 
b/metal/src/kernels/libMetalFlashAttention-ios.metallib new file mode 100644 index 0000000000..74c82f12cc Binary files /dev/null and b/metal/src/kernels/libMetalFlashAttention-ios.metallib differ diff --git a/metal/src/kernels/libMetalFlashAttention-macos.metallib b/metal/src/kernels/libMetalFlashAttention-macos.metallib new file mode 100644 index 0000000000..1e2d1acf3d Binary files /dev/null and b/metal/src/kernels/libMetalFlashAttention-macos.metallib differ diff --git a/metal/src/kernels/mfa_gemm.rs b/metal/src/kernels/mfa_gemm.rs new file mode 100644 index 0000000000..cd0a84e30a --- /dev/null +++ b/metal/src/kernels/mfa_gemm.rs @@ -0,0 +1,172 @@ +use crate::{ConstantValues, LibraryName, MetalContext, Value}; +use anyhow::{bail, Result}; +use metal::NSUInteger; +use metal::{Buffer, MTLSize}; +use std::ffi::c_void; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum GemmPrecision { + Single, + Half, +} + +#[allow(clippy::too_many_arguments)] +pub fn metal_gemm( + context: &MetalContext, + precision: GemmPrecision, + (b, m, n, k): (usize, usize, usize, usize), + lhs_stride: &[usize], + lhs_offset: usize, + lhs_buffer: &Buffer, + rhs_stride: &[usize], + rhs_offset: usize, + rhs_buffer: &Buffer, + output: &Buffer, +) -> Result<()> { + assert!(rhs_stride.len() >= 2); + assert!(lhs_stride.len() >= 2); + let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; + let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; + let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; + let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; + let a_trans = if lhs_m1 == 1 && lhs_m2 == k { + false + } else if lhs_m1 == m && lhs_m2 == 1 { + true + } else { + bail!(format!( + "Invalid left matmul argument {:?} {:?} ({m}, {n}, {k})", + lhs_stride, rhs_stride + )) + }; + let b_trans = if rhs_m1 == 1 && rhs_m2 == n { + false + } else if rhs_m1 == k && rhs_m2 == 1 { + true + } else { + bail!(format!( + "Invalid right matmul arguments {:?} {:?} ({m}, {n}, {k})", + lhs_stride, rhs_stride + )) + }; + let d_trans = 
false; + let alpha = 1.0f32; + let beta = 0.0f32; + let batched = b > 1; + let fused_activation = false; + let fused_bias = false; + let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 { + let m_simd = 8; + let n_simd = 8; + let k_simd = 64; + let m_splits = 1; + let n_splits = 1; + (m_simd, n_simd, k_simd, m_splits, n_splits) + } else { + let m_simd = 40; + let n_simd = 40; + let k_simd = 32; + let m_splits = 1; + let n_splits = 1; + (m_simd, n_simd, k_simd, m_splits, n_splits) + }; + let constants = Some(ConstantValues::new(vec![ + (0, Value::USize(m)), + (1, Value::USize(n)), + (2, Value::USize(k)), + (10, Value::Bool(a_trans)), + (11, Value::Bool(b_trans)), + (13, Value::Bool(d_trans)), + (20, Value::F32(alpha)), + (21, Value::F32(beta)), + (100, Value::Bool(batched)), + (101, Value::Bool(fused_activation)), + // Garbage + (102, Value::Bool(false)), + (103, Value::Bool(false)), + (113, Value::Bool(false)), + (50_000, Value::Bool(false)), + // End garbage + (200, Value::U16(m_simd)), + (201, Value::U16(n_simd)), + (202, Value::U16(k_simd)), + (210, Value::U16(m_splits)), + (211, Value::U16(n_splits)), + (50_001, Value::Bool(fused_bias)), + ])); + + let name = match precision { + GemmPrecision::Single => "sgemm", + GemmPrecision::Half => "hgemm", + }; + + let pipeline = context.shared_context().load_pipeline_with_constants( + LibraryName::MfaLib, + name, + constants, + )?; + let m_group = m_simd * m_splits; + let n_group = n_simd * n_splits; + + let a_block_length = m_group * k_simd; + let b_block_length = k_simd * n_group; + + let mut block_elements = a_block_length + b_block_length; + if (m % 8 != 0) && (n % 8 != 0) { + let c_block_length = m_group * n_group; + block_elements = std::cmp::max(c_block_length, block_elements) + } + if fused_bias { + if d_trans { + block_elements = std::cmp::max(block_elements, m_group); + } else { + block_elements = std::cmp::max(block_elements, n_group); + } + } + let bytes = match precision { + GemmPrecision::Single => 
4, + GemmPrecision::Half => 2, + }; + let block_bytes = block_elements * bytes; + + let command_buffer = context.command_buffer()?; + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + encoder.set_threadgroup_memory_length(0, block_bytes.into()); + encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger); + encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger); + encoder.set_buffer(2, Some(output), 0); + // TODO Tensor D + + let grid_z = b; + if batched { + let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize; + let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize; + let byte_stride_c = m * n * bytes as usize; + // TODO byte_stride_d + let byte_stride_d = 0; + + let buffer: Vec = + vec![byte_stride_a as _, byte_stride_b as _, byte_stride_c as _, byte_stride_d as _]; + encoder.set_bytes( + 10, + (buffer.len() * core::mem::size_of::()) as NSUInteger, + buffer.as_ptr() as *const NSUInteger as *const c_void, + ); + } + + let grid_size = MTLSize { + width: crate::utils::div_ceil(n, n_group.into()), + height: crate::utils::div_ceil(m, m_group.into()), + depth: grid_z as NSUInteger, + }; + let group_size = + MTLSize { width: 32 * (m_splits as u64) * (n_splits as u64), height: 1, depth: 1 }; + encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(grid_size, group_size); + encoder.end_encoding(); + + Ok(()) +} diff --git a/metal/src/kernels/mmm_tile_8x8.metal b/metal/src/kernels/mmm_tile_8x8.metal new file mode 100644 index 0000000000..78b0e07d34 --- /dev/null +++ b/metal/src/kernels/mmm_tile_8x8.metal @@ -0,0 +1,73 @@ +// From https://github.com/cyrusmsk/gemm_apple/blob/main/gemm_metal.py + +#include +#include // Available from Metal version 2.3 released with OS X 11.0+ + 
+using namespace metal;
+
+constant uint LID [[function_constant(0)]];
+constant uint dim [[function_constant(1)]];
+
+kernel void mmm_tile_8x8(device float *a [[buffer(0)]], // output
+                         device const float *data1 [[buffer(1)]],
+                         device const float *data2 [[buffer(2)]],
+                         uint3 gid [[threadgroup_position_in_grid]],
+                         uint3 lid [[thread_position_in_threadgroup]]) {
+    a += gid.x * 32 * dim + (gid.y * LID + lid.y) * 32;
+    data1 += gid.x * 32 * dim;
+    data2 += (gid.y * LID + lid.y) * 32;
+
+    simdgroup_float8x8 acc[4][4];
+    for (uint i = 0; i < 4; i++) {
+        for (uint j = 0; j < 4; j++) {
+            acc[i][j] = simdgroup_float8x8(0);
+        }
+    }
+
+    simdgroup_float8x8 A[4];
+    simdgroup_float8x8 B[4];
+    for (uint k = 0; k < dim; k+=8) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        simdgroup_load(A[0], data1+k+(0*dim), dim, ulong2(0, 0));
+        simdgroup_load(A[1], data1+k+(8*dim), dim, ulong2(0, 0));
+        simdgroup_load(A[2], data1+k+(16*dim), dim, ulong2(0, 0));
+        simdgroup_load(A[3], data1+k+(24*dim), dim, ulong2(0, 0));
+        simdgroup_load(B[0], data2+0+k*dim, dim, ulong2(0, 0));
+        simdgroup_load(B[1], data2+8+k*dim, dim, ulong2(0, 0));
+        simdgroup_load(B[2], data2+16+k*dim, dim, ulong2(0, 0));
+        simdgroup_load(B[3], data2+24+k*dim, dim, ulong2(0, 0));
+
+        simdgroup_multiply_accumulate(acc[0][0], A[0], B[0], acc[0][0]);
+        simdgroup_multiply_accumulate(acc[0][1], A[1], B[0], acc[0][1]);
+        simdgroup_multiply_accumulate(acc[0][2], A[2], B[0], acc[0][2]);
+        simdgroup_multiply_accumulate(acc[0][3], A[3], B[0], acc[0][3]);
+        simdgroup_multiply_accumulate(acc[1][0], A[0], B[1], acc[1][0]);
+        simdgroup_multiply_accumulate(acc[1][1], A[1], B[1], acc[1][1]);
+        simdgroup_multiply_accumulate(acc[1][2], A[2], B[1], acc[1][2]);
+        simdgroup_multiply_accumulate(acc[1][3], A[3], B[1], acc[1][3]);
+        simdgroup_multiply_accumulate(acc[2][0], A[0], B[2], acc[2][0]);
+        simdgroup_multiply_accumulate(acc[2][1], A[1], B[2], acc[2][1]);
+        simdgroup_multiply_accumulate(acc[2][2], A[2], B[2], acc[2][2]);
+        simdgroup_multiply_accumulate(acc[2][3], A[3], B[2], acc[2][3]);
+        simdgroup_multiply_accumulate(acc[3][0], A[0], B[3], acc[3][0]);
+        simdgroup_multiply_accumulate(acc[3][1], A[1], B[3], acc[3][1]);
+        simdgroup_multiply_accumulate(acc[3][2], A[2], B[3], acc[3][2]);
+        simdgroup_multiply_accumulate(acc[3][3], A[3], B[3], acc[3][3]);
+    }
+    simdgroup_store(acc[0][0], a+(0+0*dim), dim, ulong2(0, 0));
+    simdgroup_store(acc[1][0], a+(8+0*dim), dim, ulong2(0, 0));
+    simdgroup_store(acc[2][0], a+(16+0*dim), dim, ulong2(0, 0));
+    simdgroup_store(acc[3][0], a+(24+0*dim), dim, ulong2(0, 0));
+    simdgroup_store(acc[0][1], a+(0+8*dim), dim, ulong2(0, 0));
+    simdgroup_store(acc[1][1], a+(8+8*dim), dim, ulong2(0, 0));
+    simdgroup_store(acc[2][1], a+(16+8*dim), dim, ulong2(0, 0));
+    simdgroup_store(acc[3][1], a+(24+8*dim), dim, ulong2(0, 0));
+    simdgroup_store(acc[0][2], a+(0+16*dim), dim, ulong2(0, 0));
+    simdgroup_store(acc[1][2], a+(8+16*dim), dim, ulong2(0, 0));
+    simdgroup_store(acc[2][2], a+(16+16*dim), dim, ulong2(0, 0));
+    simdgroup_store(acc[3][2], a+(24+16*dim), dim, ulong2(0, 0));
+    simdgroup_store(acc[0][3], a+(0+24*dim), dim, ulong2(0, 0));
+    simdgroup_store(acc[1][3], a+(8+24*dim), dim, ulong2(0, 0));
+    simdgroup_store(acc[2][3], a+(16+24*dim), dim, ulong2(0, 0));
+    simdgroup_store(acc[3][3], a+(24+24*dim), dim, ulong2(0, 0));
+}
diff --git a/metal/src/kernels/mmm_tile_8x8.rs b/metal/src/kernels/mmm_tile_8x8.rs
new file mode 100644
index 0000000000..c4d6686d70
--- /dev/null
+++ b/metal/src/kernels/mmm_tile_8x8.rs
@@ -0,0 +1,128 @@
+use crate::func_constants::{ConstantValues, Value};
+use crate::MetalTensor;
+use crate::{LibraryName, MetalContext};
+use anyhow::{ensure, Result};
+use metal::{Buffer, MTLSize, NSUInteger};
+use tract_core::internal::DatumType;
+
+pub fn mmm_tile_8x8(
+    context: &MetalContext,
+    lhs: &MetalTensor,
+    rhs: &MetalTensor,
+) -> Result<MetalTensor> {
+    ensure!(lhs.rank() == 2 && rhs.rank() == 2);
+    ensure!(lhs.datum_type() == rhs.datum_type());
+
ensure!(lhs.datum_type() == DatumType::F32); + + let m = lhs.shape()[0]; + let n = rhs.shape()[1]; + let k = lhs.shape()[1]; + + ensure!(m == n && m == k); + + let o_dt = lhs.datum_type(); + let o_shape = &[m, m]; + + let output = MetalTensor::zero_dt(context, o_dt, o_shape)?; + + crate::kernels::metal_mmm_tile_8x8(context, m, &lhs.metal(), &rhs.metal(), output.metal())?; + context.wait_until_completed()?; + + Ok(output) +} + +#[allow(clippy::too_many_arguments)] +pub fn metal_mmm_tile_8x8( + context: &MetalContext, + dim: usize, + lhs_buffer: &Buffer, + rhs_buffer: &Buffer, + output: &Buffer, +) -> Result<()> { + ensure!(dim % 8 == 0, "Dim must be a multiple of 8"); + + let constants = Some(ConstantValues::new(vec![(0, Value::USize(2)), (1, Value::USize(dim))])); + let pipeline = context.shared_context().load_pipeline_with_constants( + LibraryName::MmmTile8x8, + "mmm_tile_8x8", + constants, + )?; + + let command_buffer = context.command_buffer()?; + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + encoder.set_buffer(0, Some(output), 0); + encoder.set_buffer(1, Some(lhs_buffer), 0); + encoder.set_buffer(2, Some(rhs_buffer), 0); + + let grid_size = MTLSize { + width: crate::utils::div_ceil(dim, 8 * 4), + height: crate::utils::div_ceil(dim, 8 * 4 * 2), + depth: 1 as NSUInteger, + }; + let group_size = MTLSize { width: 32, height: 2, depth: 1 }; + encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(grid_size, group_size); + encoder.end_encoding(); + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::IntoMetal; + use tract_core::internal::Tensor; + + #[test] + fn test_mmm_tile_8x8() -> Result<()> { + objc::rc::autoreleasepool(|| { + crate::METAL_CONTEXT.with_borrow(|context| { + let n = 512; + + let constants = + 
Some(ConstantValues::new(vec![(0, Value::USize(2)), (1, Value::USize(n))]));
+
+                context.shared_context().load_pipeline_with_constants(
+                    LibraryName::MmmTile8x8,
+                    "mmm_tile_8x8",
+                    constants,
+                )?;
+                context.wait_until_completed()?;
+
+                let mut cpu_start = 0;
+                let mut gpu_start = 0;
+                context.device().sample_timestamps(&mut cpu_start, &mut gpu_start);
+
+                let a = Tensor::from_shape(
+                    &[n, n],
+                    &(0..n * n).map(|_f| 1 as f32).collect::<Vec<f32>>(),
+                )?
+                .into_metal()?;
+                let b = Tensor::from_shape(
+                    &[n, n],
+                    &(0..n * n).map(|_f| 1 as f32).collect::<Vec<f32>>(),
+                )?
+                .into_metal()?;
+                let start = std::time::Instant::now();
+                let num_iter = 100;
+                for _ in 0..num_iter {
+                    let _c = mmm_tile_8x8(&context, &a, &b)?;
+                }
+
+                let mut cpu_end = 0;
+                let mut gpu_end = 0;
+                context.device().sample_timestamps(&mut cpu_end, &mut gpu_end);
+
+                dbg!(start.elapsed().as_secs_f32() / num_iter as f32);
+                println!(
+                    "{:3?} GOP/s",
+                    (n * n * n * 2 * num_iter) as f32 / start.elapsed().as_secs_f32() / 1.0e9
+                );
+                Ok(())
+            })
+        })
+    }
+}
diff --git a/metal/src/kernels/mod.rs b/metal/src/kernels/mod.rs
new file mode 100644
index 0000000000..5b81817004
--- /dev/null
+++ b/metal/src/kernels/mod.rs
@@ -0,0 +1,33 @@
+mod mfa_gemm;
+mod mmm_tile_8x8;
+
+pub use mfa_gemm::{metal_gemm, GemmPrecision};
+pub use mmm_tile_8x8::{metal_mmm_tile_8x8, mmm_tile_8x8};
+
+#[cfg(target_os = "ios")]
+pub const METAL_FLASH_ATTENTION_LIB: &[u8] = include_bytes!("libMetalFlashAttention-ios.metallib");
+#[cfg(target_os = "macos")]
+pub const METAL_FLASH_ATTENTION_LIB: &[u8] =
+    include_bytes!("libMetalFlashAttention-macos.metallib");
+pub const MMM_TILE_8X8_METAL_SOURCE: &str = include_str!("mmm_tile_8x8.metal");
+
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub enum LibraryContent<'a> {
+    Data(&'a [u8]),
+    Source(&'a str),
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum LibraryName {
+    MfaLib,
+    MmmTile8x8,
+}
+
+impl LibraryName {
+    pub fn content(&self) -> LibraryContent<'static> {
+        match
self {
+            Self::MfaLib => LibraryContent::Data(METAL_FLASH_ATTENTION_LIB),
+            Self::MmmTile8x8 => LibraryContent::Source(MMM_TILE_8X8_METAL_SOURCE),
+        }
+    }
+}
diff --git a/metal/src/lib.rs b/metal/src/lib.rs
new file mode 100644
index 0000000000..5432152ffd
--- /dev/null
+++ b/metal/src/lib.rs
@@ -0,0 +1,18 @@
+pub mod context;
+pub mod func_constants;
+pub mod gemm;
+pub mod kernels;
+pub mod tensor;
+pub mod transform;
+pub mod utils;
+
+pub use crate::context::{MetalContext, METAL_CONTEXT};
+use crate::func_constants::{ConstantValues, Value};
+pub use crate::kernels::{LibraryContent, LibraryName};
+pub use crate::tensor::MetalTensor;
+pub use crate::transform::MetalGemmTransform;
+use anyhow::Result;
+
+pub trait IntoMetal<T> {
+    fn into_metal(self) -> Result<T>;
+}
diff --git a/metal/src/tensor.rs b/metal/src/tensor.rs
new file mode 100644
index 0000000000..5c8e2b21f2
--- /dev/null
+++ b/metal/src/tensor.rs
@@ -0,0 +1,120 @@
+use crate::{IntoMetal, MetalContext};
+use anyhow::Result;
+use metal::{Buffer, MTLResourceOptions, NSUInteger};
+use tract_core::internal::*;
+
+impl IntoMetal<MetalTensor> for Tensor {
+    fn into_metal(self) -> Result<MetalTensor> {
+        crate::METAL_CONTEXT.with_borrow(|ctxt| MetalTensor::from_tensor(ctxt, self))
+    }
+}
+
+/// This struct represents a metal tensor that can be accessed from the
+/// GPU and the CPU. Metal's MTLResourceStorageModeShared is used.
+#[derive(Debug)]
+pub struct MetalTensor {
+    pub inner: Tensor,
+    metal: Buffer,
+}
+
+impl MetalTensor {
+    // Create a metal tensor with a given shape and a slice of elements. The data is copied and aligned to size of T.
+    pub fn from_shape<T: Copy + Datum>(shape: &[usize], data: &[T]) -> Result<MetalTensor> {
+        crate::METAL_CONTEXT.with_borrow(|ctxt| {
+            let tensor = Tensor::from_shape(shape, data)?;
+            Self::from_tensor(ctxt, tensor)
+        })
+    }
+
+    /// Create an uninitialized MetalTensor
+    pub unsafe fn uninitialized_dt(dt: DatumType, shape: &[usize]) -> Result<MetalTensor> {
+        crate::METAL_CONTEXT.with_borrow(|ctxt| {
+            let tensor = Tensor::uninitialized_dt(dt, shape)?;
+            Self::from_tensor(ctxt, tensor)
+        })
+    }
+
+    /// Create a MetalTensor filled of zeros
+    pub fn zero_dt(
+        context: &MetalContext,
+        datum_type: DatumType,
+        shape: &[usize],
+    ) -> Result<MetalTensor> {
+        let t = Tensor::zero_dt(datum_type, shape)?;
+        Self::from_tensor(context, t)
+    }
+
+    /// Create a metal tensor from a cpu tensor.
+    pub fn from_tensor(context: &MetalContext, tensor: Tensor) -> Result<MetalTensor> {
+        ensure!(
+            tensor.datum_type().is_copy(),
+            "Tensor is not copied. No Metal buffer can be allocated for it."
+        );
+        let size = (tensor.datum_type().size_of() * tensor.len()) as NSUInteger;
+        let buffer = context.device().new_buffer_with_bytes_no_copy(
+            tensor.as_bytes().as_ptr() as *const core::ffi::c_void,
+            size,
+            MTLResourceOptions::StorageModeShared,
+            None,
+        );
+        Ok(Self { inner: tensor, metal: buffer })
+    }
+
+    /// Get the datum type of the tensor.
+    #[inline]
+    pub fn datum_type(&self) -> DatumType {
+        self.inner.datum_type()
+    }
+
+    /// Get the number of dimensions (or axes) of the tensor.
+    #[inline]
+    pub fn rank(&self) -> usize {
+        self.inner.rank()
+    }
+
+    /// Get the shape of the tensor.
+    #[inline]
+    pub fn shape(&self) -> &[usize] {
+        self.inner.shape()
+    }
+
+    /// Get the number of values in the tensor.
+    #[inline]
+    pub fn len(&self) -> usize {
+        self.inner.len()
+    }
+
+    /// Get the strides of the tensor.
+    #[inline]
+    pub fn strides(&self) -> &[isize] {
+        &self.inner.strides()
+    }
+
+    /// Get underlying inner metal buffer.
+    pub fn metal<'a>(&'a self) -> &'a Buffer {
+        &self.metal
+    }
+
+    /// Get mutable underlying inner metal buffer.
+    pub fn metal_mut<'a>(&'a mut self) -> &'a mut Buffer {
+        &mut self.metal
+    }
+}
+
+impl IntoTensor for MetalTensor {
+    fn into_tensor(self) -> Tensor {
+        self.inner
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_metal_tensor() -> Result<()> {
+        let a = MetalTensor::from_shape(&[1], &[0f32])?;
+        assert_eq!(a.into_tensor().as_slice::<f32>()?, &[0.0]);
+        Ok(())
+    }
+}
diff --git a/metal/src/transform.rs b/metal/src/transform.rs
new file mode 100644
index 0000000000..81cfba089b
--- /dev/null
+++ b/metal/src/transform.rs
@@ -0,0 +1,169 @@
+use crate::gemm;
+use anyhow::Result;
+use num_traits::Float;
+use std::borrow::Cow;
+use std::fmt::Debug;
+use tract_core::broadcast;
+use tract_core::internal::*;
+use tract_core::ndarray;
+use tract_core::transform::ModelTransform;
+use tract_ndarray::Dimension;
+
+use tract_core::ops::einsum::{rewrite_einsums_as_matmul, BasicMatMul};
+
+#[derive(Debug, Default)]
+pub struct MetalGemmTransform;
+
+impl ModelTransform for MetalGemmTransform {
+    fn name(&self) -> Cow<str> {
+        "metal-gemm-transform".into()
+    }
+
+    fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
+        rewrite_einsums_as_matmul(model)?;
+        Rewriter::default()
+            .with_rule_for("matmul-to-metal-gemm", matmul_to_gemm)
+            .rewrite(&(), model)?;
+        Ok(())
+    }
+}
+
+fn matmul_to_gemm(
+    _ctx: &(),
+    model: &TypedModel,
+    node: &TypedNode,
+    _node_name: &str,
+    op: &BasicMatMul,
+) -> Result<Option<TypedModelPatch>> {
+    if !op.transpose_a
+        && !op.transpose_b
+        && !op.transpose_c
+        && op.quantize_output.is_none()
+        && (model.node_input_facts(node.id)?.iter().all(|f| f.datum_type == f32::datum_type())
+            || model.node_input_facts(node.id)?.iter().all(|f| f.datum_type == f16::datum_type()))
+    {
+        TypedModelPatch::replace_single_op(model, node, &node.inputs, MetalGemm::default())
+            .map(Some)
+    } else {
+        Ok(None)
+    }
+}
+
+#[derive(Debug, Default, Clone)]
+pub struct MetalGemm {}
+
+impl Op for MetalGemm {
+    fn name(&self) -> Cow<str> {
+        "MetalGemm".into()
+    }
+
+
    op_as_typed_op!();
+}
+
+impl MetalGemm {
+    fn output_shape<D: DimLike>(&self, a: &[D], b: &[D]) -> TractResult<TVec<D>> {
+        ensure!(a.len() == b.len());
+        let a_rank = a.len();
+        let b_rank = b.len();
+        let m = a[a_rank - 2].clone();
+        let n = b[b_rank - 1].clone();
+        let mut c_shape = broadcast::multi_broadcast(&[&a[..a_rank - 2], &b[..b_rank - 2]])
+            .context("Unable to broadcast")?;
+        c_shape.push(m);
+        c_shape.push(n);
+        Ok(c_shape)
+    }
+
+    fn _eval<F: Datum + Float>(&self, a: TValue, b: TValue) -> TractResult<TVec<TValue>> {
+        objc::rc::autoreleasepool(|| {
+            crate::METAL_CONTEXT.with_borrow(|context| {
+                let a_ptr = a.as_ptr::<F>()?;
+                let b_ptr = b.as_ptr::<F>()?;
+                let c_shape = self.output_shape(a.shape(), b.shape())?;
+                let rank = c_shape.len();
+                let m = c_shape[rank - 2];
+                let n = c_shape[rank - 1];
+                let k = a.shape()[rank - 1];
+
+                let a_mk_strides = natural_strides(&[1, m, k]);
+                let b_kn_strides = natural_strides(&[1, k, n]);
+                unsafe {
+                    let mut c = Tensor::uninitialized::<F>(&c_shape)?;
+                    let c_ptr = c.as_ptr_mut::<F>()?;
+                    let silent_a_axis = c.rank() - a.rank();
+                    let silent_b_axis = c.rank() - b.rank();
+                    for prefix in ndarray::indices(&c_shape[0..rank - 2]) {
+                        let mut a_ptr = a_ptr;
+                        let mut b_ptr = b_ptr;
+                        let mut c_ptr = c_ptr;
+                        for (axis, x) in prefix.as_array_view().iter().enumerate() {
+                            if axis >= silent_a_axis && a.shape()[axis - silent_a_axis] != 1 {
+                                a_ptr =
+                                    a_ptr.offset(*x as isize * a.strides()[axis - silent_a_axis]);
+                            }
+                            if axis >= silent_b_axis && b.shape()[axis - silent_b_axis] != 1 {
+                                b_ptr =
+                                    b_ptr.offset(*x as isize * b.strides()[axis - silent_b_axis]);
+                            }
+                            c_ptr = c_ptr.offset(*x as isize * c.strides()[axis]);
+                        }
+
+                        gemm::gemm_with_slice(
+                            context,
+                            (1, m, n, k),
+                            std::slice::from_raw_parts(a_ptr, m * k),
+                            &a_mk_strides,
+                            std::slice::from_raw_parts(b_ptr, k * n),
+                            &b_kn_strides,
+                            std::slice::from_raw_parts_mut(c_ptr, m * n),
+                        )?;
+                    }
+
+                    Ok(tvec!(c.into_tvalue()))
+                }
+            })
+        })
+    }
+}
+
+impl EvalOp for MetalGemm {
+    fn is_stateless(&self) -> bool {
+        true
+    }
+
+    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
+        let (a, b) = args_2!(inputs);
+        if a.datum_type() == DatumType::F32 {
+            self._eval::<f32>(a, b)
+        } else if a.datum_type() == DatumType::F16 {
+            self._eval::<f16>(a, b)
+        } else {
+            bail!("MetalGemm doesn't support this datum type: {:?}", a.datum_type())
+        }
+    }
+}
+
+impl TypedOp for MetalGemm {
+    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
+        if inputs[0].datum_type == f16::datum_type() {
+            ensure!(inputs[1].datum_type == f16::datum_type());
+            Ok(tvec!(f16::fact(&self.output_shape(&inputs[0].shape, &inputs[1].shape)?)))
+        } else {
+            ensure!(inputs[0].datum_type == f32::datum_type());
+            ensure!(inputs[1].datum_type == f32::datum_type());
+            Ok(tvec!(f32::fact(&self.output_shape(&inputs[0].shape, &inputs[1].shape)?)))
+        }
+    }
+
+    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
+        let fma = self.output_shape(&inputs[0].shape, &inputs[1].shape)?.iter().product::<TDim>()
+            * inputs[0].shape.last().unwrap();
+        if inputs[0].datum_type == f16::datum_type() {
+            Ok(tvec!((Cost::FMA(f16::datum_type()), fma)))
+        } else {
+            Ok(tvec!((Cost::FMA(f32::datum_type()), fma)))
+        }
+    }
+
+    as_op!();
+}
diff --git a/metal/src/utils.rs b/metal/src/utils.rs
new file mode 100644
index 0000000000..ca1abfc9b2
--- /dev/null
+++ b/metal/src/utils.rs
@@ -0,0 +1,3 @@
+pub fn div_ceil(m: usize, b: usize) -> metal::NSUInteger {
+    ((m + b - 1) / b) as metal::NSUInteger
+}