-
Notifications
You must be signed in to change notification settings - Fork 214
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Silu op rewrite and metal implementation
- Loading branch information
1 parent
0d31bbd
commit 18d3763
Showing
9 changed files
with
250 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
use crate::{LibraryName, MetalContext, MetalTensor}; | ||
use anyhow::Result; | ||
use metal::MTLSize; | ||
use tract_core::internal::*; | ||
|
||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] | ||
pub struct Silu; | ||
|
||
impl Silu { | ||
pub fn is_supported_dt(dt: DatumType) -> bool { | ||
Self::tname(dt).is_ok() | ||
} | ||
|
||
pub fn tname(dt: DatumType) -> Result<&'static str> { | ||
let tname = match dt { | ||
DatumType::F32 => "f32", | ||
DatumType::F16 => "f16", | ||
_ => bail!("Unsupport dt {:?} for reducer op", dt), | ||
}; | ||
Ok(tname) | ||
} | ||
|
||
pub fn kernel_name(&self, dt: DatumType) -> Result<String> { | ||
let tname = Self::tname(dt)?; | ||
Ok(format!("nn_ops::silu_{tname}")) | ||
} | ||
|
||
pub fn eval(&self, context: &MetalContext, input: &MetalTensor) -> Result<MetalTensor> { | ||
let o = self.dispatch_eval(context, input)?; | ||
context.wait_until_completed()?; | ||
Ok(o) | ||
} | ||
|
||
pub fn dispatch_eval(&self, context: &MetalContext, a: &MetalTensor) -> Result<MetalTensor> { | ||
a.retain_until_completion(); | ||
|
||
let output = unsafe { MetalTensor::uninitialized_dt(a.datum_type(), a.shape())? }; | ||
output.retained_until_completion(); | ||
|
||
let kernel_name = self.kernel_name(a.datum_type())?; | ||
|
||
let a_buffer = a.metal(); | ||
let output_buffer = output.metal(); | ||
let pipeline = context.shared_context().load_pipeline(LibraryName::NNOps, &kernel_name)?; | ||
let command_buffer = context.command_buffer(); | ||
let encoder = command_buffer.new_compute_command_encoder(); | ||
encoder.set_compute_pipeline_state(&pipeline); | ||
encoder.set_buffer(0, Some(a_buffer), 0); | ||
encoder.set_buffer(1, Some(output_buffer), 0); | ||
|
||
let grid_size = MTLSize { width: output.len() as _, height: 1, depth: 1 }; | ||
let group_size = MTLSize { width: 1, height: 1, depth: 1 }; | ||
encoder.use_resource(a_buffer, metal::MTLResourceUsage::Read); | ||
encoder.use_resource(output_buffer, metal::MTLResourceUsage::Write); | ||
encoder.dispatch_thread_groups(grid_size, group_size); | ||
encoder.end_encoding(); | ||
Ok(output) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
use crate::kernels::nn::Silu; | ||
use crate::tensor::MetalTensorExt; | ||
use derive_new::new; | ||
use tract_core::internal::*; | ||
|
||
#[derive(Clone, Debug, new, Hash)] | ||
pub struct MetalSilu; | ||
|
||
impl Op for MetalSilu { | ||
fn name(&self) -> Cow<str> { | ||
"MetalSilu".into() | ||
} | ||
|
||
op_as_typed_op!(); | ||
} | ||
|
||
impl EvalOp for MetalSilu { | ||
fn is_stateless(&self) -> bool { | ||
true | ||
} | ||
|
||
fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> { | ||
objc::rc::autoreleasepool(|| { | ||
crate::METAL_CONTEXT.with_borrow(|context| { | ||
let input = args_1!(inputs); | ||
let input_metal = input.to_metal_tensor()?; | ||
Ok(tvec!(Silu | ||
.dispatch_eval(context, input_metal)? | ||
.into_opaque_tensor() | ||
.into_tvalue())) | ||
}) | ||
}) | ||
} | ||
} | ||
|
||
impl TypedOp for MetalSilu { | ||
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> { | ||
crate::utils::metal_output_facts(inputs, |facts| { | ||
let dt = facts[0].datum_type; | ||
let fact = dt.fact(facts[0].shape.clone()); | ||
Ok(tvec!(fact)) | ||
}) | ||
.with_context(|| anyhow::anyhow!("Error while computing facts for {:?}", self.name())) | ||
} | ||
|
||
as_op!(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
mod rms_norm; | ||
mod silu; | ||
|
||
use tract_core::internal::*; | ||
use tract_core::ops::konst::Const; | ||
|
||
pub use rms_norm::{as_rms_norm_rule, BasicRmsNorm}; | ||
pub use silu::{as_silu_rule, BasicSilu}; | ||
|
||
#[macro_export] | ||
macro_rules! rule_ensure { | ||
($cond:expr) => { | ||
if !$cond { | ||
return Ok(None); | ||
} | ||
}; | ||
} | ||
|
||
fn next_node<'a>(model: &'a TypedModel, node: &TypedNode) -> Option<&'a TypedNode> { | ||
if node.outputs.iter().map(|of| of.successors.len()).sum::<usize>() != 1 { | ||
return None; | ||
} | ||
let succ = node.outputs[0].successors[0]; | ||
Some(&model.nodes()[succ.node]) | ||
} | ||
|
||
fn collect_node_const_inputs<'a>(model: &'a TypedModel, node: &TypedNode) -> TVec<&'a Const> { | ||
node.inputs | ||
.iter() | ||
.filter_map(|i| { | ||
let prec = &model.nodes()[i.node]; | ||
prec.op_as::<Const>() | ||
}) | ||
.collect::<TVec<_>>() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
use crate::rewrite_rules::next_node; | ||
use crate::rule_ensure; | ||
use tract_core::internal::*; | ||
use tract_core::ops::binary::TypedBinOp; | ||
use tract_core::ops::element_wise::ElementWiseOp; | ||
use tract_core::ops::math::Mul; | ||
use tract_core::ops::nn::Sigmoid; | ||
|
||
#[derive(Clone, Debug, Hash)] | ||
pub struct BasicSilu; | ||
|
||
impl Op for BasicSilu { | ||
fn name(&self) -> Cow<str> { | ||
"BasicSilu".to_string().into() | ||
} | ||
op_as_typed_op!(); | ||
} | ||
|
||
impl EvalOp for BasicSilu { | ||
fn is_stateless(&self) -> bool { | ||
true | ||
} | ||
} | ||
|
||
impl TypedOp for BasicSilu { | ||
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> { | ||
let dt = inputs[0].datum_type; | ||
let fact = dt.fact(inputs[0].shape.clone()); | ||
Ok(tvec!(fact)) | ||
} | ||
|
||
as_op!(); | ||
} | ||
|
||
pub fn as_silu_rule( | ||
_ctx: &(), | ||
model: &TypedModel, | ||
node: &TypedNode, | ||
node_name: &str, | ||
op: &ElementWiseOp, | ||
) -> TractResult<Option<TypedModelPatch>> { | ||
// Search pattern => A = A * SIGMOID(A); | ||
|
||
rule_ensure!(op.0.is::<Sigmoid>()); | ||
|
||
let in_fact = model.node_input_facts(node.id)?[0]; | ||
let dt = in_fact.datum_type; | ||
|
||
// Only F16 and F32 is supported. | ||
rule_ensure!(matches!(dt, DatumType::F32 | DatumType::F16)); | ||
|
||
let mut patch = TypedModelPatch::default(); | ||
let silu_input = patch.taps(model, &node.inputs)?; | ||
// Identify Mul | ||
let Some(mul_succ) = next_node(model, node) else { return Ok(None) }; | ||
let Some(mul_succ_op) = mul_succ.op_as::<TypedBinOp>() else { return Ok(None) }; | ||
rule_ensure!(mul_succ_op.0.is::<Mul>()); | ||
rule_ensure!(mul_succ.inputs.contains(&node.inputs[0])); | ||
|
||
let out = patch.wire_node(format!("{node_name}.silu"), BasicSilu, &silu_input)?; | ||
|
||
patch.shunt_outside(model, mul_succ.id.into(), out[0])?; | ||
|
||
Ok(Some(patch)) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters