Silu op rewrite and metal implementation

hubertdelajonquieresonos committed Jul 19, 2024
1 parent 0d31bbd commit 18d3763
Showing 9 changed files with 250 additions and 36 deletions.
2 changes: 2 additions & 0 deletions metal/src/kernels/nn/mod.rs
@@ -1,9 +1,11 @@
pub mod reduce;
pub mod rms_norm;
pub mod silu;
pub mod softmax;

pub use reduce::Reducer;
pub use rms_norm::RmsNorm;
pub use silu::Silu;
pub use softmax::Softmax;

pub fn all_functions() -> Vec<String> {
32 changes: 29 additions & 3 deletions metal/src/kernels/nn/nn_ops.metal
@@ -120,6 +120,22 @@ template<typename F>
}
}

struct Sigmoid {
    template <typename T>
    T operator()(T x) {
        auto y = 1 / (1 + metal::exp(-metal::abs(x)));
        return (x < 0) ? 1 - y : y;
    }
};

template<typename T>
[[kernel]] void silu(device const T *input [[buffer(0)]],
                     device T *output [[buffer(1)]],
                     uint tpig[[thread_position_in_grid]]) {
    output[tpig] = Sigmoid()(input[tpig]) * input[tpig];
}


template<typename F>
[[kernel]] void softmax_nd3(
    device const F *input,
@@ -153,16 +169,15 @@ template<typename F>
        float el = static_cast<float>(input[idx]);
        float exp_el = fast::exp(el - axis_max);
        partial_norm += exp_el;
+       output[idx] = static_cast<F>(exp_el);
    }

    float axis_norm = simd_sum(partial_norm);
    float inv_axis_norm = 1.0 / axis_norm;

    for (size_t i = tiisg; i < dim; i += tpsg) {
        auto idx = base_idx + i * strides[1];
-       float el = static_cast<float>(input[idx]);
-       // TODO: avoid computing exp_el twice
-       float exp_el = fast::exp(el - axis_max);
+       float exp_el = static_cast<float>(output[idx]);
        output[idx] = static_cast<F>(exp_el * inv_axis_norm);
    }
}
@@ -219,6 +234,17 @@ template [[host_name("nn_ops::rms_norm_nd3_" #tname)]] \
INSTANTIATE_RMS_NORM(f32, float)
INSTANTIATE_RMS_NORM(f16, half)

#define INSTANTIATE_SILU(tname, type) \
template [[host_name("nn_ops::silu_" #tname)]] \
[[kernel]] void silu<type>( \
    device const type *input [[buffer(0)]], \
    device type *output [[buffer(1)]], \
    uint tpig[[thread_position_in_grid]] \
);

INSTANTIATE_SILU(f32, float)
INSTANTIATE_SILU(f16, half)
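For reference, the Sigmoid functor above evaluates exp on -|x| so it can never overflow, then uses sigmoid(x) = 1 - sigmoid(-x) to recover the value for negative inputs. The same scheme as a minimal CPU-side Rust sketch (not part of this commit, useful for checking the kernel's output):

fn sigmoid(x: f32) -> f32 {
    // Compute on -|x| so exp() never overflows, then reflect for x < 0.
    let y = 1.0 / (1.0 + (-x.abs()).exp());
    if x < 0.0 { 1.0 - y } else { y }
}

fn silu(x: f32) -> f32 {
    // SiLU (a.k.a. swish): x * sigmoid(x), applied element-wise by the kernel.
    sigmoid(x) * x
}

fn main() {
    for x in [-88.0f32, -1.0, 0.0, 1.0, 88.0] {
        println!("silu({x}) = {}", silu(x));
    }
}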




59 changes: 59 additions & 0 deletions metal/src/kernels/nn/silu.rs
@@ -0,0 +1,59 @@
use crate::{LibraryName, MetalContext, MetalTensor};
use anyhow::Result;
use metal::MTLSize;
use tract_core::internal::*;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Silu;

impl Silu {
    pub fn is_supported_dt(dt: DatumType) -> bool {
        Self::tname(dt).is_ok()
    }

    pub fn tname(dt: DatumType) -> Result<&'static str> {
        let tname = match dt {
            DatumType::F32 => "f32",
            DatumType::F16 => "f16",
            _ => bail!("Unsupported dt {:?} for silu op", dt),
        };
        Ok(tname)
    }

    pub fn kernel_name(&self, dt: DatumType) -> Result<String> {
        let tname = Self::tname(dt)?;
        Ok(format!("nn_ops::silu_{tname}"))
    }

    pub fn eval(&self, context: &MetalContext, input: &MetalTensor) -> Result<MetalTensor> {
        let o = self.dispatch_eval(context, input)?;
        context.wait_until_completed()?;
        Ok(o)
    }

    pub fn dispatch_eval(&self, context: &MetalContext, a: &MetalTensor) -> Result<MetalTensor> {
        a.retain_until_completion();

        let output = unsafe { MetalTensor::uninitialized_dt(a.datum_type(), a.shape())? };
        output.retained_until_completion();

        let kernel_name = self.kernel_name(a.datum_type())?;

        let a_buffer = a.metal();
        let output_buffer = output.metal();
        let pipeline = context.shared_context().load_pipeline(LibraryName::NNOps, &kernel_name)?;
        let command_buffer = context.command_buffer();
        let encoder = command_buffer.new_compute_command_encoder();
        encoder.set_compute_pipeline_state(&pipeline);
        encoder.set_buffer(0, Some(a_buffer), 0);
        encoder.set_buffer(1, Some(output_buffer), 0);

        let grid_size = MTLSize { width: output.len() as _, height: 1, depth: 1 };
        let group_size = MTLSize { width: 1, height: 1, depth: 1 };
        encoder.use_resource(a_buffer, metal::MTLResourceUsage::Read);
        encoder.use_resource(output_buffer, metal::MTLResourceUsage::Write);
        encoder.dispatch_thread_groups(grid_size, group_size);
        encoder.end_encoding();
        Ok(output)
    }
}
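A minimal usage sketch for this wrapper (not part of this commit; assumes a MetalTensor `t` already resident on the GPU and the crate's thread-local METAL_CONTEXT, as used in ops/silu.rs below). `eval` blocks until the command buffer completes, while `dispatch_eval` only encodes the kernel and relies on the retain-until-completion calls above to keep both tensors alive:

use crate::kernels::nn::Silu;
use crate::MetalTensor;

fn silu_on_gpu(t: &MetalTensor) -> anyhow::Result<MetalTensor> {
    // Synchronous variant; prefer dispatch_eval to batch several kernels
    // into one command buffer before waiting on completion.
    crate::METAL_CONTEXT.with_borrow(|context| Silu.eval(context, t))
}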
4 changes: 3 additions & 1 deletion metal/src/ops/mod.rs
@@ -8,6 +8,7 @@ pub mod gemm;
pub mod konst;
pub mod reduce;
pub mod rms_norm;
pub mod silu;
pub mod slice;
pub mod softmax;
pub mod sync;
@@ -22,6 +23,7 @@ pub use gemm::MetalGemm;
pub use konst::MetalConst;
pub use reduce::MetalReduce;
pub use rms_norm::MetalRmsNorm;
+pub use silu::MetalSilu;
pub use slice::MetalSlice;
pub use softmax::MetalSoftmax;
-pub use sync::MetalSync;
+pub use sync::{MetalSync, MetalSyncKind};
47 changes: 47 additions & 0 deletions metal/src/ops/silu.rs
@@ -0,0 +1,47 @@
use crate::kernels::nn::Silu;
use crate::tensor::MetalTensorExt;
use derive_new::new;
use tract_core::internal::*;

#[derive(Clone, Debug, new, Hash)]
pub struct MetalSilu;

impl Op for MetalSilu {
    fn name(&self) -> Cow<str> {
        "MetalSilu".into()
    }

    op_as_typed_op!();
}

impl EvalOp for MetalSilu {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        objc::rc::autoreleasepool(|| {
            crate::METAL_CONTEXT.with_borrow(|context| {
                let input = args_1!(inputs);
                let input_metal = input.to_metal_tensor()?;
                Ok(tvec!(Silu
                    .dispatch_eval(context, input_metal)?
                    .into_opaque_tensor()
                    .into_tvalue()))
            })
        })
    }
}

impl TypedOp for MetalSilu {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        crate::utils::metal_output_facts(inputs, |facts| {
            let dt = facts[0].datum_type;
            let fact = dt.fact(facts[0].shape.clone());
            Ok(tvec!(fact))
        })
        .with_context(|| anyhow::anyhow!("Error while computing facts for {:?}", self.name()))
    }

    as_op!();
}
35 changes: 35 additions & 0 deletions metal/src/rewrite_rules/mod.rs
@@ -0,0 +1,35 @@
mod rms_norm;
mod silu;

use tract_core::internal::*;
use tract_core::ops::konst::Const;

pub use rms_norm::{as_rms_norm_rule, BasicRmsNorm};
pub use silu::{as_silu_rule, BasicSilu};

#[macro_export]
macro_rules! rule_ensure {
    ($cond:expr) => {
        if !$cond {
            return Ok(None);
        }
    };
}

fn next_node<'a>(model: &'a TypedModel, node: &TypedNode) -> Option<&'a TypedNode> {
    if node.outputs.iter().map(|of| of.successors.len()).sum::<usize>() != 1 {
        return None;
    }
    let succ = node.outputs[0].successors[0];
    Some(&model.nodes()[succ.node])
}

fn collect_node_const_inputs<'a>(model: &'a TypedModel, node: &TypedNode) -> TVec<&'a Const> {
    node.inputs
        .iter()
        .filter_map(|i| {
            let prec = &model.nodes()[i.node];
            prec.op_as::<Const>()
        })
        .collect::<TVec<_>>()
}
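`rule_ensure!` encodes the early-exit convention of these rewrite rules: a failed structural check returns Ok(None), meaning "no rewrite applies", which is not an error. A standalone toy sketch of the pattern (simplified types, not tract's):

macro_rules! rule_ensure {
    ($cond:expr) => {
        if !$cond {
            return Ok(None);
        }
    };
}

fn try_rewrite(is_sigmoid: bool, single_successor: bool) -> Result<Option<&'static str>, ()> {
    rule_ensure!(is_sigmoid);       // wrong op: leave the model untouched
    rule_ensure!(single_successor); // fan-out > 1: fusing would change semantics
    Ok(Some("BasicSilu"))
}

fn main() {
    assert_eq!(try_rewrite(true, true), Ok(Some("BasicSilu")));
    assert_eq!(try_rewrite(true, false), Ok(None));
}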
29 changes: 2 additions & 27 deletions metal/src/rewrite_rules.rs → metal/src/rewrite_rules/rms_norm.rs
@@ -1,8 +1,9 @@
+use crate::rewrite_rules::{collect_node_const_inputs, next_node};
+use crate::rule_ensure;
use std::sync::Arc;
use tract_core::internal::*;
use tract_core::ops::binary::TypedBinOp;
use tract_core::ops::element_wise::ElementWiseOp;
-use tract_core::ops::konst::Const;
use tract_core::ops::math::{Add, Mul, Rsqrt};
use tract_core::ops::nn::{Reduce, Reducer};

@@ -38,32 +39,6 @@ impl TypedOp for BasicRmsNorm {
    as_op!();
}

-macro_rules! rule_ensure {
-    ($cond:expr) => {
-        if !$cond {
-            return Ok(None);
-        }
-    };
-}
-
-fn next_node<'a>(model: &'a TypedModel, node: &TypedNode) -> Option<&'a TypedNode> {
-    if node.outputs.iter().map(|of| of.successors.len()).sum::<usize>() != 1 {
-        return None;
-    }
-    let succ = node.outputs[0].successors[0];
-    Some(&model.nodes()[succ.node])
-}
-
-fn collect_node_const_inputs<'a>(model: &'a TypedModel, node: &TypedNode) -> TVec<&'a Const> {
-    node.inputs
-        .iter()
-        .filter_map(|i| {
-            let prec = &model.nodes()[i.node];
-            prec.op_as::<Const>()
-        })
-        .collect::<TVec<_>>()
-}
-
pub fn as_rms_norm_rule(
    _ctx: &(),
    model: &TypedModel,
65 changes: 65 additions & 0 deletions metal/src/rewrite_rules/silu.rs
@@ -0,0 +1,65 @@
use crate::rewrite_rules::next_node;
use crate::rule_ensure;
use tract_core::internal::*;
use tract_core::ops::binary::TypedBinOp;
use tract_core::ops::element_wise::ElementWiseOp;
use tract_core::ops::math::Mul;
use tract_core::ops::nn::Sigmoid;

#[derive(Clone, Debug, Hash)]
pub struct BasicSilu;

impl Op for BasicSilu {
    fn name(&self) -> Cow<str> {
        "BasicSilu".into()
    }
    op_as_typed_op!();
}

impl EvalOp for BasicSilu {
    fn is_stateless(&self) -> bool {
        true
    }
}

impl TypedOp for BasicSilu {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let dt = inputs[0].datum_type;
        let fact = dt.fact(inputs[0].shape.clone());
        Ok(tvec!(fact))
    }

    as_op!();
}

pub fn as_silu_rule(
    _ctx: &(),
    model: &TypedModel,
    node: &TypedNode,
    node_name: &str,
    op: &ElementWiseOp,
) -> TractResult<Option<TypedModelPatch>> {
    // Search pattern: OUT = A * SIGMOID(A)

    rule_ensure!(op.0.is::<Sigmoid>());

    let in_fact = model.node_input_facts(node.id)?[0];
    let dt = in_fact.datum_type;

    // Only F16 and F32 are supported.
    rule_ensure!(matches!(dt, DatumType::F32 | DatumType::F16));

    let mut patch = TypedModelPatch::default();
    let silu_input = patch.taps(model, &node.inputs)?;
    // Identify Mul
    let Some(mul_succ) = next_node(model, node) else { return Ok(None) };
    let Some(mul_succ_op) = mul_succ.op_as::<TypedBinOp>() else { return Ok(None) };
    rule_ensure!(mul_succ_op.0.is::<Mul>());
    rule_ensure!(mul_succ.inputs.contains(&node.inputs[0]));

    let out = patch.wire_node(format!("{node_name}.silu"), BasicSilu, &silu_input)?;

    patch.shunt_outside(model, mul_succ.id.into(), out[0])?;

    Ok(Some(patch))
}
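The guards above are what keep the fusion sound: `next_node` only yields a successor when the Sigmoid has exactly one consumer across all outputs, and that consumer must be a Mul that also reads the Sigmoid's own input, so rewiring the Mul's output through a single BasicSilu cannot change any other part of the graph.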
13 changes: 8 additions & 5 deletions metal/src/transform.rs
@@ -1,10 +1,8 @@
use crate::fact::MetalTypedFactExt;
-use crate::kernels::nn::{Reducer, RmsNorm, Softmax};
+use crate::kernels::nn::{Reducer, RmsNorm, Silu, Softmax};
use crate::ops;
-use crate::ops::sync::MetalSyncKind;
-use crate::ops::MetalSync;
-use crate::rewrite_rules::as_rms_norm_rule;
-use crate::rewrite_rules::BasicRmsNorm;
+use crate::ops::{MetalSync, MetalSyncKind};
+use crate::rewrite_rules::{as_rms_norm_rule, as_silu_rule, BasicRmsNorm, BasicSilu};
use crate::tensor::MetalTensorExt;
use crate::{IntoMetal, MetalFact, MetalTensor};
use anyhow::Result;
@@ -32,6 +30,7 @@ impl ModelTransform for MetalTransform {
        rewrite_einsums_as_matmul(model)?;
        Rewriter::default()
            .with_rule_for::<Reduce>("as-rms-norm", as_rms_norm_rule)
            .with_rule_for::<ElementWiseOp>("as-silu", as_silu_rule)
            .rewrite(&(), model)?;

        model.optimize()?;
@@ -195,6 +194,10 @@ impl Translate<TypedFact, Box<dyn TypedOp>, TypedFact, Box<dyn TypedOp>> for Met
            check_in_dts_are_supported(source, node.id, RmsNorm::is_supported_dt)?
                .then(|| ops::MetalRmsNorm::new(op.axis, op.eps.clone()))
                .map(|o| -> Box<dyn TypedOp> { Box::new(o) })
        } else if let Some(_op) = node.op_as::<BasicSilu>() {
            check_in_dts_are_supported(source, node.id, Silu::is_supported_dt)?
                .then(|| ops::MetalSilu)
                .map(|o| -> Box<dyn TypedOp> { Box::new(o) })
        } else {
            None
        };
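End to end, the transform canonicalizes einsums to matmuls, applies the rewrite rules (as-rms-norm, as-silu), optimizes, and then translates supported ops to their Metal counterparts. A hedged sketch of invoking it, assuming MetalTransform implements Default and that the ModelTransform trait is in scope via the usual prelude (neither is shown in this diff):

use tract_core::internal::*;

fn metalize(model: &mut TypedModel) -> TractResult<()> {
    // BasicSilu nodes produced by as_silu_rule become ops::MetalSilu here,
    // provided their dtypes pass Silu::is_supported_dt (f16/f32 only).
    crate::transform::MetalTransform::default().transform(model)
}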
