From 15c1418aa4c8c6977528fc17973fdce7198e2989 Mon Sep 17 00:00:00 2001
From: Emrick Sinitambirivoutin
Date: Fri, 12 Jul 2024 17:32:30 +0200
Subject: [PATCH 01/32] shape indices iterator + collapse_axis in tensorview

---
 data/src/lib.rs            |  1 +
 data/src/tensor.rs         |  1 +
 data/src/tensor/indices.rs | 73 ++++++++++++++++++++++++++++++++++++++
 data/src/tensor/view.rs    | 54 +++++++++++++++++++++++++---
 4 files changed, 124 insertions(+), 5 deletions(-)
 create mode 100644 data/src/tensor/indices.rs

diff --git a/data/src/lib.rs b/data/src/lib.rs
index 077e7ea24f..c811d584f4 100644
--- a/data/src/lib.rs
+++ b/data/src/lib.rs
@@ -43,6 +43,7 @@ pub mod internal {
     pub use crate::opaque::{ OpaquePayload, OpaqueFact };
     pub use crate::prelude::*;
     pub use crate::tensor::view::TensorView;
+    pub use crate::tensor::indices::iter_indices;
     pub use crate::tensor::Approximation;
     pub use crate::tensor::vector_size;
     pub use anyhow::{anyhow, bail, ensure, format_err, Context as TractErrorContext};
diff --git a/data/src/tensor.rs b/data/src/tensor.rs
index 9d1ee4a5ee..d54d02adc0 100644
--- a/data/src/tensor.rs
+++ b/data/src/tensor.rs
@@ -17,6 +17,7 @@ use std::hash::Hash;
 use std::ops::Range;
 use std::sync::Arc;
 
+pub mod indices;
 pub mod litteral;
 pub mod view;
 
diff --git a/data/src/tensor/indices.rs b/data/src/tensor/indices.rs
new file mode 100644
index 0000000000..ad5bec1068
--- /dev/null
+++ b/data/src/tensor/indices.rs
@@ -0,0 +1,73 @@
+pub struct IndexIterator {
+    shape: Vec<usize>,
+    current_index: Vec<usize>,
+    done: bool,
+}
+
+impl IndexIterator {
+    pub fn new(shape: &[usize]) -> Self {
+        let current_index = vec![0; shape.len()];
+        Self { shape: shape.to_vec(), current_index, done: false }
+    }
+}
+
+impl Iterator for IndexIterator {
+    type Item = Vec<usize>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.done {
+            return None;
+        }
+
+        let result = self.current_index.clone();
+
+        for i in (0..self.shape.len()).rev() {
+            if self.current_index[i] + 1 < self.shape[i] {
+                self.current_index[i] += 1;
+                // Reset all indices to the right of i to 0
+                for j in i + 1..self.shape.len() {
+                    self.current_index[j] = 0;
+                }
+                return Some(result);
+            }
+        }
+
+        self.done = true;
+        Some(result)
+    }
+}
+
+pub fn iter_indices(shape: &[usize]) -> IndexIterator {
+    IndexIterator::new(shape)
+}
+
+#[cfg(test)]
+mod test {
+    use super::iter_indices;
+    #[test]
+    fn test_single_element() {
+        let shape = vec![1, 1, 1];
+        let expected_indices = vec![vec![0, 0, 0]];
+        let iter = iter_indices(&shape);
+        let result: Vec<Vec<usize>> = iter.collect();
+        assert_eq!(result, expected_indices);
+    }
+
+    #[test]
+    fn test_3x1x1() {
+        let shape = vec![3, 1, 1];
+        let expected_indices = vec![vec![0, 0, 0], vec![1, 0, 0], vec![2, 0, 0]];
+        let iter = iter_indices(&shape);
+        let result: Vec<Vec<usize>> = iter.collect();
+        assert_eq!(result, expected_indices);
+    }
+
+    #[test]
+    fn test_2x2x1() {
+        let shape = vec![2, 2, 1];
+        let expected_indices = vec![vec![0, 0, 0], vec![0, 1, 0], vec![1, 0, 0], vec![1, 1, 0]];
+        let iter = iter_indices(&shape);
+        let result: Vec<Vec<usize>> = iter.collect();
+        assert_eq!(result, expected_indices);
+    }
+}
diff --git a/data/src/tensor/view.rs b/data/src/tensor/view.rs
index 29ae462fe7..f6fa718e0d 100644
--- a/data/src/tensor/view.rs
+++ b/data/src/tensor/view.rs
@@ -2,16 +2,16 @@ use super::*;
 use crate::internal::*;
 
 #[derive(Clone, Debug)]
-enum Indexing<'a> {
+enum Indexing {
     Prefix(usize),
-    Custom { shape: &'a [usize], strides: &'a [isize] },
+    Custom { shape: Vec<usize>, strides: Vec<isize> },
 }
 
 #[derive(Clone, Debug)]
 pub struct TensorView<'a> {
     pub tensor: &'a Tensor,
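     // Byte offset of the view's first element inside `tensor`'s data buffer.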
     offset_bytes: isize,
-    indexing: Indexing<'a>,
+    indexing: Indexing,
 }
 
 impl<'a> TensorView<'a> {
@@ -21,7 +21,11 @@ impl<'a> TensorView<'a> {
         shape: &'a [usize],
         strides: &'a [isize],
     ) -> TensorView<'a> {
-        TensorView { tensor, offset_bytes, indexing: Indexing::Custom { shape, strides } }
+        TensorView {
+            tensor,
+            offset_bytes,
+            indexing: Indexing::Custom { shape: shape.to_vec(), strides: strides.to_vec() },
+        }
     }
 
     pub fn offsetting(tensor: &'a Tensor, coords: &[usize]) -> TractResult<TensorView<'a>> {
@@ -41,7 +45,10 @@ impl<'a> TensorView<'a> {
         TensorView {
             tensor,
             offset_bytes,
-            indexing: Indexing::Custom { shape: &tensor.shape, strides: &tensor.strides },
+            indexing: Indexing::Custom {
+                shape: tensor.shape.to_vec(),
+                strides: tensor.strides.to_vec(),
+            },
         }
     }
 
@@ -229,6 +236,29 @@ impl<'a> TensorView<'a> {
         unsafe { Ok(self.at_unchecked(coords)) }
     }
 
+    #[inline]
+    pub fn collapse_axis(&mut self, axis: usize, index: isize) {
+        let stride = self.strides()[axis] * self.datum_type().size_of() as isize;
+        unsafe { self.offset_bytes(stride * index) };
+        match &mut self.indexing {
+            Indexing::Prefix(x) => {
+                if *x == 0 {
+                    let mut new_shape = self.tensor.shape().to_owned();
+                    new_shape[axis] = 1;
+                    self.indexing = Indexing::Custom {
+                        shape: new_shape,
+                        strides: self.tensor.strides().to_owned(),
+                    }
+                } else {
+                    unimplemented!("TODO: understand how it is used")
+                }
+            }
+            Indexing::Custom { shape, .. } => {
+                shape[axis] = 1;
+            }
+        }
+    }
+
     #[inline]
     pub fn at_mut<T: Datum>(&mut self, coords: impl AsRef<[usize]>) -> TractResult<&mut T> {
         self.check_dt::<T>()?;
@@ -254,3 +284,17 @@ impl<'a> TensorView<'a> {
     }
     */
 }
+
+#[cfg(test)]
+mod test {
+    use crate::prelude::Tensor;
+
+    #[test]
+    fn test_collapse_axis() {
+        let a = Tensor::from_shape(&[2, 2], &[1, 2, 3, 4]).unwrap();
+        let mut a_view = a.view();
+        a_view.collapse_axis(0, 1);
+        assert_eq!(a_view.shape(), &[1, 2]);
+        assert_eq!(a_view.as_slice::<i32>().unwrap(), &[3, 4]);
+    }
+}

From 2c06fe756f07234c2e773e91737b0e52bef0bac8 Mon Sep 17 00:00:00 2001
From: Emrick Sinitambirivoutin
Date: Fri, 12 Jul 2024 17:53:53 +0200
Subject: [PATCH 02/32] Introduce BinOpByScalar & BinOpUnicast

---
 core/src/ops/binary.rs           | 268 +++++++++++++++++++++++++++----
 core/src/ops/matmul/optimized.rs |   2 +-
 core/src/ops/mod.rs              |   1 +
 core/src/ops/quant.rs            |   6 +
 4 files changed, 249 insertions(+), 28 deletions(-)

diff --git a/core/src/ops/binary.rs b/core/src/ops/binary.rs
index 0e3ebb6c24..45346f065d 100644
--- a/core/src/ops/binary.rs
+++ b/core/src/ops/binary.rs
@@ -18,6 +18,10 @@ pub trait BinMiniOp: fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static +
     fn eval_in_a(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()>;
     fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()>;
 
+    // Temporarily introduced to test the TensorView approach
+    fn eval_by_scalar(&self, a: &mut TensorView, b: &TensorView) -> TractResult<()>;
+    fn eval_unicast(&self, a: &mut TensorView, b: &TensorView) -> TractResult<()>;
+
     #[allow(unused_variables)]
     fn maybe_eval_qbinary_as_float_op(
         &self,
@@ -202,39 +206,201 @@ impl TypedOp for TypedBinOp {
         model: &TypedModel,
         node: &TypedNode,
     ) -> TractResult<Option<TypedModelPatch>> {
+        // Need to add new methods to BinMiniOp first
+        //if let Some(neutral_patch) = declutter_neutral(model, node, &self.0)? {
+        //    return Ok(Some(neutral_patch))
+        //}
+
+        let (by_scalar_should_be_efficient, unicast_should_be_efficient) = find_most_efficient_config(model, node)?;
+        let can_eval_in_a = if let &[a, b] = &*model.node_input_facts(node.id)?
{ + let c_dt = self.output_datum_type(a.datum_type, b.datum_type)?; + let c_shape = crate::broadcast::multi_broadcast(&[a.shape.clone(), b.shape.clone()])?; + (c_shape == a.shape.to_tvec()) && (c_dt == a.datum_type) + } else { + false + }; + + // Don't declutter yet (missing AST for ByScalar + ByUnicast) + //if by_scalar_should_be_efficient & can_eval_in_a { + // return Ok(Some( + // TypedModelPatch::replace_single_op( + // model, + // node, + // &node.inputs, + // BinOpByScalar(self.0.clone()), + // )? + // .with_context("ByScalar"), + // )) + //} + + //if unicast_should_be_efficient & can_eval_in_a { + // return Ok(Some( + // TypedModelPatch::replace_single_op( + // model, + // node, + // &node.inputs, + // BinOpUnicast(self.0.clone()), + // )? + // .with_context("Unicast"), + // )) + //} self.0.declutter(model, node) } - fn codegen( + as_op!(); +} + +fn find_most_efficient_config( + model: &TypedModel, + node: &TypedNode, +) -> TractResult<(bool, bool)> { + if let &[a, b] = &*model.node_input_facts(node.id)? { + let a_shape = a.shape.clone(); + let b_shape = b.shape.clone(); + + let by_scalar_is_possible = BinOpByScalar::check_input_shapes(&a_shape, &b_shape); + let num_trailing_elements = if by_scalar_is_possible { + a_shape.iter().zip(b_shape.iter()).rev().take_while(|(_, rev_b_dim)| **rev_b_dim == TDim::Val(1)).map(|(rev_a_dim, _)| rev_a_dim).product::() + } else { + TDim::Val(0) + }; + + let unicast_is_possible = BinOpUnicast::check_input_shapes(&a_shape, &b_shape); + let num_leading_elements = if unicast_is_possible { + a_shape.iter().zip(b_shape.iter()).take_while(|(_, b_dim)| **b_dim == TDim::Val(1)).map(|(a_dim, _)| a_dim).product::() + } else { + TDim::Val(0) + }; + + let min_num_elements = 32; + let by_scalar_should_be_efficient = gt_tdim(num_trailing_elements, min_num_elements); + let unicast_should_be_efficient = gt_tdim(num_leading_elements, min_num_elements); + return Ok((by_scalar_should_be_efficient, unicast_should_be_efficient)) + } + Ok((false, false)) +} + +pub fn gt_tdim(x: TDim, min_val: i64) -> bool { + TDim::Val(min_val).mini(x).to_i64().map_or(false, |v| v == min_val) +} + +#[derive(Debug, Clone)] +pub struct BinOpByScalar(pub Box); + +impl BinOpByScalar { + fn check_input_shapes(a_shape: &[TDim], b_shape:&[TDim]) -> bool{ + if a_shape.len() != b_shape.len() {return false}; + + let mut must_be_unary = false; + a_shape.iter().zip(b_shape.iter()).all(|(a_dim, b_dim)| { + // As soon as a and b dimensions differ, b dimensions must be 1 until the end. 
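+            // e.g. (illustrative shapes) A [2, 3, 4] with B [2, 3, 1] or B [2, 1, 1]
+            // passes, while B [2, 1, 4] fails: a non-unary b dim follows the first mismatch.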
+            if (a_dim != b_dim) && !must_be_unary {
+                must_be_unary = true
+            }
+
+            // Leading dimensions: a_dim == b_dim condition
+            // Trailing dimensions: b_dim == 1
+            ((a_dim == b_dim) & !must_be_unary) || ((*b_dim == 1.into()) & must_be_unary)
+        })
+    }
+}
+
+impl Op for BinOpByScalar {
+    fn name(&self) -> Cow<str> {
+        format!("{}ByScalar", self.0.name()).into()
+    }
+
+    op_as_typed_op!();
+}
+
+impl EvalOp for BinOpByScalar {
+    fn is_stateless(&self) -> bool {
+        true
+    }
+
+    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
+        let (a, b) = args_2!(inputs);
+        let mut a = a.into_tensor();
+        let b_shape = b.shape();
+        let view = a.view_mut();
+        let b_view = b.view();
+
+        let trailing_unary_dims: Vec<usize> = b_shape.iter()
+            .enumerate()
+            .rev()
+            .take_while(|&(_, &dim)| dim == 1)
+            .map(|(i, _)| i)
+            .collect();
+        for it_coords in tract_data::internal::iter_indices(b_shape) {
+            // Prepare array views to perform the computation
+            // - view should be a slice
+            // - b should be a scalar
+            let mut view = view.clone();
+            let mut tmp_b_view = b_view.clone();
+            for idx in 0..b_shape.len() {
+                if !trailing_unary_dims.contains(&idx) {
+                    view.collapse_axis(idx, it_coords[idx] as isize);
+                    tmp_b_view.collapse_axis(idx, it_coords[idx] as isize);
+                }
+            }
+
+            self.0.eval_by_scalar(&mut view, &tmp_b_view)?;
+        }
+        Ok(tvec!(a.into_tvalue()))
+    }
+}
+
+impl TypedOp for BinOpByScalar {
+    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
+        ensure!(Self::check_input_shapes(&inputs[0].shape, &inputs[1].shape));
+        let out_dt = self.0.result_datum_type(inputs[0].datum_type, inputs[1].datum_type)?;
+        let out_shape = inputs[0].shape.clone();
+        Ok(tvec!(out_dt.fact(out_shape)))
+    }
+
+    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
+        let count: TDim = self.output_facts(inputs)?[0].shape.iter().product();
+        Ok(self
+            .0
+            .cost_per_element(inputs[0].datum_type)
+            .into_iter()
+            .map(|(c, n)| (c, count.clone() * n))
+            .collect())
+    }
+
+    fn declutter(
         &self,
         model: &TypedModel,
         node: &TypedNode,
     ) -> TractResult<Option<TypedModelPatch>> {
-        let facts = model.node_input_facts(node.id)?;
-        if self.output_datum_type(facts[0].datum_type, facts[1].datum_type)? == facts[0].datum_type
-            && facts[0].without_value() == facts[1].without_value()
-        {
-            Ok(Some(
-                TypedModelPatch::replace_single_op(
-                    model,
-                    node,
-                    &node.inputs,
-                    MergeOpUnicast(self.0.clone()),
-                )?
-                .with_context("Unicast"),
-            ))
-        } else {
-            Ok(None)
-        }
+        self.0.declutter(model, node)
     }
 
     as_op!();
 }
-
+
 #[derive(Debug, Clone)]
-pub struct MergeOpUnicast(pub Box<dyn BinMiniOp>);
+pub struct BinOpUnicast(pub Box<dyn BinMiniOp>);
+
+impl BinOpUnicast {
+    fn check_input_shapes(a_shape: &[TDim], b_shape:&[TDim]) -> bool{
+        if a_shape.len() != b_shape.len() {return false};
 
-impl Op for MergeOpUnicast {
+        let mut must_be_equal = false;
+        a_shape.iter().zip(b_shape.iter()).all(|(a_dim, b_dim)| {
+            // As soon as a b dimension is not equal to one, a and b dimensions must be equal.
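+            // e.g. (illustrative shapes) A [2, 16, 32] with B [1, 16, 32] or B [1, 1, 32]
+            // passes, while B [1, 16, 1] fails: a unary b dim follows a non-unary one.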
+            if (*b_dim != 1.into()) && !must_be_equal {
+                must_be_equal = true
+            }
+
+            // Leading dimensions: b_dim == 1 condition
+            // Trailing dimensions: a_dim == b_dim
+            ((*b_dim == 1.into()) & !must_be_equal) || ((a_dim == b_dim) & must_be_equal)
+        })
+    }
+}
+
+impl Op for BinOpUnicast {
     fn name(&self) -> Cow<str> {
         format!("{}Unicast", self.0.name()).into()
     }
@@ -242,23 +408,42 @@ impl Op for MergeOpUnicast {
     op_as_typed_op!();
 }
 
-impl EvalOp for MergeOpUnicast {
+impl EvalOp for BinOpUnicast {
     fn is_stateless(&self) -> bool {
         true
     }
 
     fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
         let (a, b) = args_2!(inputs);
-        let mut b = b.into_tensor();
-        self.0.eval_unicast_in_place(&a, &mut b)?;
-        Ok(tvec!(b.into_tvalue()))
+        let mut a = a.into_tensor();
+        let b_shape = b.shape();
+        let view = a.view_mut();
+        let b_view = b.view();
+
+        let leading_unary_dims: Vec<usize> = b_shape.iter()
+            .enumerate()
+            .take_while(|&(_, &dim)| dim == 1)
+            .map(|(i, _)| i)
+            .collect();
+        for it_coords in tract_data::internal::iter_indices(b_shape) {
+            let mut view = view.clone();
+            for idx in 0..view.shape().len() {
+                if leading_unary_dims.contains(&idx) {
+                    view.collapse_axis(idx, it_coords[idx] as isize);
+                }
+            }
+            self.0.eval_unicast(&mut view, &b_view)?;
+        }
+        Ok(tvec!(a.into_tvalue()))
     }
 }
 
-impl TypedOp for MergeOpUnicast {
+impl TypedOp for BinOpUnicast {
     fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
-        debug_assert_eq!(inputs[0].shape, inputs[1].shape);
-        Ok(tvec!(inputs[0].without_value()))
+        ensure!(Self::check_input_shapes(&inputs[0].shape, &inputs[1].shape));
+        let out_dt = self.0.result_datum_type(inputs[0].datum_type, inputs[1].datum_type)?;
+        let out_shape = inputs[0].shape.clone();
+        Ok(tvec!(out_dt.fact(out_shape)))
     }
 
     fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
@@ -282,6 +467,7 @@ impl TypedOp for MergeOpUnicast {
     as_op!();
 }
 
+
 #[macro_export]
 macro_rules! bin_to_super_type {
     ($func:ident, $Op:ident,
@@ -294,6 +480,8 @@ macro_rules! bin_to_super_type {
      $(operating_datum_type: $operating_datum_type:expr,)?
      $(uniform_in_place: $uniform_in_place:expr,)?
      $(unicast_in_place: $unicast_in_place:expr,)?
+     $(eval_by_scalar: $eval_by_scalar:expr,)?
+     $(eval_unicast: $eval_unicast:expr,)?
      $(out_of_place: $out_of_place:expr,)?
      $(validation: $validation:expr,)?
      $(q: $([$($typ_dt:ident),*] => $cab_dt:expr),* ;)?
@@ -420,6 +608,32 @@ macro_rules! bin_to_super_type {
                 bail!("{} does not support {:?} (out of place)", self.name(), c.datum_type());
             }
 
+            fn eval_by_scalar(&self, a: &mut TensorView, b: &TensorView) -> TractResult<()> {
+                $(if $eval_by_scalar(a, b)? { return Ok(()) } )?
+                $(
+                    $(if b.datum_type() == $typ::datum_type() {
+                        let cab: fn(&mut $typ, &$typ, &$typ) -> () = $cab;
+                        let b = &b.as_slice::<$typ>()?[0];
+                        let a_slice = a.as_slice_mut::<$typ>()?;
+                        a_slice.iter_mut().for_each(|a| cab(a, &a.clone(), b));
+                        return Ok(())
+                    })*
+                )*
+                bail!("{} does not support {:?} (eval by scalar)", self.name(), a.datum_type());
+            }
+            fn eval_unicast(&self, a: &mut TensorView, b: &TensorView) -> TractResult<()> {
+                $(if $eval_unicast(a, b)? { return Ok(()) } )?
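+                // Generic per-type fallback when no dedicated kernel is wired in:
+                // dispatch on the datum type and apply the `cab` closure in place over `a`.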
+                $(
+                    $(if b.datum_type() == $typ::datum_type() {
+                        let cab: fn(&mut $typ, &$typ, &$typ) -> () = $cab;
+                        let b = &b.as_slice::<$typ>()?[0];
+                        let a_slice = a.as_slice_mut::<$typ>()?;
+                        a_slice.iter_mut().for_each(|a| cab(a, &a.clone(), b));
+                        return Ok(())
+                    })*
+                )*
+                bail!("{} does not support {:?} (eval unicast)", self.name(), a.datum_type());
+            }
             fn eval_in_a(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()> {
                 // c and a are same type
                 $(if $eval_in_a(a, b)? { return Ok(()) } )?
diff --git a/core/src/ops/matmul/optimized.rs b/core/src/ops/matmul/optimized.rs
index f059b52ff8..6b1fb65a71 100644
--- a/core/src/ops/matmul/optimized.rs
+++ b/core/src/ops/matmul/optimized.rs
@@ -552,7 +552,7 @@ impl TypedOp for OptMatMul {
                 }
             }
         }
-        if let Some(op) = succ.op_as::<MergeOpUnicast>() {
+        if let Some(op) = succ.op_as::<BinOpUnicast>() {
             if op.0.is::<Mul>() && self.mmm.len() == 1 {
                 let other_slot = 1 - node.outputs[0].successors[0].slot;
                 let other_input = succ.inputs[other_slot];
diff --git a/core/src/ops/mod.rs b/core/src/ops/mod.rs
index 4760d1207b..353f5b36af 100644
--- a/core/src/ops/mod.rs
+++ b/core/src/ops/mod.rs
@@ -11,6 +11,7 @@ pub mod macros;
 pub mod element_wise;
 #[macro_use]
 pub mod binary;
+//pub mod binary_new;
 
 pub mod array;
 pub mod cast;
diff --git a/core/src/ops/quant.rs b/core/src/ops/quant.rs
index a968184234..b04f37f132 100644
--- a/core/src/ops/quant.rs
+++ b/core/src/ops/quant.rs
@@ -277,6 +277,12 @@ impl crate::ops::binary::BinMiniOp for Scale {
     fn name(&self) -> &'static str {
         "Scale"
     }
+    fn eval_by_scalar(&self, _a: &mut TensorView, _b: &TensorView) -> TractResult<()> {
+        unimplemented!("Eval by scalar not implemented")
+    }
+    fn eval_unicast(&self, _a: &mut TensorView, _b: &TensorView) -> TractResult<()> {
+        unimplemented!("Eval unicast not implemented")
+    }
 
     fn result_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
         if !a.is_float() {

From c7958c37c08f74e51405a2c3ad2e4c563a5dbde8 Mon Sep 17 00:00:00 2001
From: Emrick Sinitambirivoutin
Date: Fri, 12 Jul 2024 18:18:18 +0200
Subject: [PATCH 03/32] Add serialization of BinOpByScalar + BinOpUnicast

---
 core/src/ops/binary.rs | 45 +++++++++++++++++++++---------------------
 nnef/src/registry.rs   | 18 +++++++++++++++++
 2 files changed, 40 insertions(+), 23 deletions(-)

diff --git a/core/src/ops/binary.rs b/core/src/ops/binary.rs
index 45346f065d..fbb8f85b1c 100644
--- a/core/src/ops/binary.rs
+++ b/core/src/ops/binary.rs
@@ -220,30 +220,29 @@ impl TypedOp for TypedBinOp {
             false
         };
 
+        if by_scalar_should_be_efficient & can_eval_in_a {
+            return Ok(Some(
+                TypedModelPatch::replace_single_op(
+                    model,
+                    node,
+                    &node.inputs,
+                    BinOpByScalar(self.0.clone()),
+                )?
+                .with_context("ByScalar"),
+            ))
+        }
 
+        if unicast_should_be_efficient & can_eval_in_a {
+            return Ok(Some(
+                TypedModelPatch::replace_single_op(
+                    model,
+                    node,
+                    &node.inputs,
+                    BinOpUnicast(self.0.clone()),
+                )?
+                .with_context("Unicast"),
+            ))
+        }
         self.0.declutter(model, node)
     }
 
diff --git a/nnef/src/registry.rs b/nnef/src/registry.rs
index 88cecd55c2..323ccb825e 100644
--- a/nnef/src/registry.rs
+++ b/nnef/src/registry.rs
@@ -158,6 +158,24 @@ impl Registry {
             let b = ast.mapping[&node.inputs[1]].clone();
             return Ok(Some(invocation(&op.0, &[a, b], &[])));
         }
+    // Temporarily allow by-scalar serialization
+    } else if let Some(op) = node.op().downcast_ref::<BinOpByScalar>() {
+        if let Some(op) =
+            self.binary_ops.iter().find(|ew| ew.1.as_ref().type_id() == op.0.type_id())
+        {
+            let a = ast.mapping[&node.inputs[0]].clone();
+            let b = ast.mapping[&node.inputs[1]].clone();
+            return Ok(Some(invocation(&op.0, &[a, b], &[])));
+        }
+    // Temporarily allow unicast serialization
+    } else if let Some(op) = node.op().downcast_ref::<BinOpUnicast>() {
+        if let Some(op) =
+            self.binary_ops.iter().find(|ew| ew.1.as_ref().type_id() == op.0.type_id())
+        {
+            let a = ast.mapping[&node.inputs[0]].clone();
+            let b = ast.mapping[&node.inputs[1]].clone();
+            return Ok(Some(invocation(&op.0, &[a, b], &[])));
+        }
     } else if let Some(op) = self.from_tract.get(&node.op().type_id()) {
         if let Some(result) = op(ast, node)? {
             return Ok(Some(result));

From 73cf50d59ba4f3c82584092ff48da036ee956f5a Mon Sep 17 00:00:00 2001
From: Emrick Sinitambirivoutin
Date: Tue, 16 Jul 2024 11:26:03 +0200
Subject: [PATCH 04/32] Fix unicast & avoid quant bin op declutter

---
 core/src/ops/binary.rs           | 52 +++++++++++++++++++++++---------
 test-rt/suite-unit/src/conv_q.rs | 16 ++++++++++
 2 files changed, 53 insertions(+), 15 deletions(-)

diff --git a/core/src/ops/binary.rs b/core/src/ops/binary.rs
index fbb8f85b1c..2725593dcf 100644
--- a/core/src/ops/binary.rs
+++ b/core/src/ops/binary.rs
@@ -212,6 +212,12 @@ impl TypedOp for TypedBinOp {
         //}
 
         let (by_scalar_should_be_efficient, unicast_should_be_efficient) = find_most_efficient_config(model, node)?;
+        let op_is_quant = if let &[a, b] = &*model.node_input_facts(node.id)? {
+            let c_dt = self.output_datum_type(a.datum_type, b.datum_type)?;
+            c_dt.is_quantized() || a.datum_type.is_quantized() || b.datum_type.is_quantized()
+        } else {
+            false
+        };
         let can_eval_in_a = if let &[a, b] = &*model.node_input_facts(node.id)?
 {
            let c_dt = self.output_datum_type(a.datum_type, b.datum_type)?;
            let c_shape = crate::broadcast::multi_broadcast(&[a.shape.clone(), b.shape.clone()])?;
            (c_shape == a.shape.to_tvec()) && (c_dt == a.datum_type)
        } else {
            false
        };
 
-        if by_scalar_should_be_efficient & can_eval_in_a {
+        if by_scalar_should_be_efficient & can_eval_in_a & !op_is_quant {
            return Ok(Some(
                TypedModelPatch::replace_single_op(
                    model,
                    node,
                    &node.inputs,
                    BinOpByScalar(self.0.clone()),
                )?
                .with_context("ByScalar"),
            ))
        }
 
-        if unicast_should_be_efficient & can_eval_in_a {
+        if unicast_should_be_efficient & can_eval_in_a & !op_is_quant {
            return Ok(Some(
                TypedModelPatch::replace_single_op(
                    model,
                    node,
                    &node.inputs,
                    BinOpUnicast(self.0.clone()),
                )?
                .with_context("Unicast"),
            ))
        }
 
        let by_scalar_is_possible = BinOpByScalar::check_input_shapes(&a_shape, &b_shape);
-        let num_trailing_elements = if by_scalar_is_possible {
+        let num_by_scalar_elements = if by_scalar_is_possible {
            a_shape.iter().zip(b_shape.iter()).rev().take_while(|(_, rev_b_dim)| **rev_b_dim == TDim::Val(1)).map(|(rev_a_dim, _)| rev_a_dim).product::<TDim>()
        } else {
            TDim::Val(0)
        };
 
        let unicast_is_possible = BinOpUnicast::check_input_shapes(&a_shape, &b_shape);
-        let num_leading_elements = if unicast_is_possible {
-            a_shape.iter().zip(b_shape.iter()).take_while(|(_, b_dim)| **b_dim == TDim::Val(1)).map(|(a_dim, _)| a_dim).product::<TDim>()
+        let num_unicast_elements = if unicast_is_possible {
+            a_shape.iter().zip(b_shape.iter()).rev().take_while(|(a_dim, b_dim)| a_dim == b_dim).map(|(a_dim, _)| a_dim).product::<TDim>()
        } else {
            TDim::Val(0)
        };
 
        let min_num_elements = 32;
-        let by_scalar_should_be_efficient = gt_tdim(num_trailing_elements, min_num_elements);
-        let unicast_should_be_efficient = gt_tdim(num_leading_elements, min_num_elements);
+        let by_scalar_should_be_efficient = gt_tdim(num_by_scalar_elements, min_num_elements);
+        let unicast_should_be_efficient = gt_tdim(num_unicast_elements, min_num_elements);
        return Ok((by_scalar_should_be_efficient, unicast_should_be_efficient))
    }
    Ok((false, false))
}
@@ -424,13 +430,23 @@ impl EvalOp for BinOpUnicast {
            .take_while(|&(_, &dim)| dim == 1)
            .map(|(i, _)| i)
            .collect();
-        for it_coords in tract_data::internal::iter_indices(b_shape) {
+
+        // We only iterate on the a dims that correspond to b's leading unary dims.
+        // To do so, we set all remaining a dims to 1.
+        // e.g. A: [2, 16, 16, 32], B: [1, 1, 16, 32] -> [2, 16, 1, 1]
+        let mut iterating_shape = view.shape().to_vec();
+        iterating_shape.iter_mut().enumerate().for_each(|(idx, dim)| {
+            if !leading_unary_dims.contains(&idx) {
+                *dim = 1
+            }
+        });
+
+        // Iterate on outer dimensions and evaluate with unicast subviews
+        for it_coords in tract_data::internal::iter_indices(&iterating_shape) {
            let mut view = view.clone();
-            for idx in 0..view.shape().len() {
-                if leading_unary_dims.contains(&idx) {
-                    view.collapse_axis(idx, it_coords[idx] as isize);
-                }
-            }
+            leading_unary_dims.iter().for_each(|idx| {
+                view.collapse_axis(*idx, it_coords[*idx] as isize);
+            });
            self.0.eval_unicast(&mut view, &b_view)?;
        }
        Ok(tvec!(a.into_tvalue()))
    }
}
@@ -625,9 +641,15 @@ macro_rules!
bin_to_super_type { $( $(if b.datum_type() == $typ::datum_type() { let cab: fn(&mut $typ, &$typ, &$typ) -> () = $cab; - let b = &b.as_slice::<$typ>()?[0]; + let b = &b.as_slice::<$typ>()?; let a_slice = a.as_slice_mut::<$typ>()?; - a_slice.iter_mut().for_each(|a| cab(a, &a.clone(), b)); + unsafe { + for i in 0..b.len() { + let mut c = $typ::default(); + cab(&mut c, &a_slice[i], b.get_unchecked(i)); + *a_slice.get_unchecked_mut(i) = c; + } + } return Ok(()) })* )* diff --git a/test-rt/suite-unit/src/conv_q.rs b/test-rt/suite-unit/src/conv_q.rs index 8c83acebfa..6ccac71576 100644 --- a/test-rt/suite-unit/src/conv_q.rs +++ b/test-rt/suite-unit/src/conv_q.rs @@ -1221,5 +1221,21 @@ pub fn suite() -> TractResult { raw_output_dt: DatumType::I32, }, ); + let mut qp = qp_noop_i8(); + qp[0] = tensor0(-3); + suite.add( + "bin_by_scalar_and_bin_unicast_selection_0", + QConvProblem { + shape_in: NHWC.from_n_c_hw(2, 2, [4, 4]).unwrap(), + co: 2, + kernel_format: OIHW, + group: 2, + kernel: tensor4(&[[[[1i8]]],[[[0i8]]]]), + bias: None, + data: Tensor::zero::(&[2, 4, 4, 2]).unwrap(), + qp, + raw_output_dt: DatumType::I8, + }, + ); Ok(suite) } From e40c04d6ca3d0800f2c371a78fde1c316775ca5d Mon Sep 17 00:00:00 2001 From: Emrick Sinitambirivoutin Date: Tue, 16 Jul 2024 15:12:07 +0200 Subject: [PATCH 05/32] conversion in optimize instead of declutter --- core/src/ops/binary.rs | 140 +++++++++++++++++++++++------------------ nnef/src/registry.rs | 18 ------ 2 files changed, 80 insertions(+), 78 deletions(-) diff --git a/core/src/ops/binary.rs b/core/src/ops/binary.rs index 2725593dcf..34e49112e1 100644 --- a/core/src/ops/binary.rs +++ b/core/src/ops/binary.rs @@ -210,8 +210,17 @@ impl TypedOp for TypedBinOp { //if let Some(neutral_patch) = declutter_neutral(model, node, &self.0)? { // return Ok(Some(neutral_patch)) //} - - let (by_scalar_should_be_efficient, unicast_should_be_efficient) = find_most_efficient_config(model, node)?; + + self.0.declutter(model, node) + } + + fn codegen( + &self, + model: &TypedModel, + node: &TypedNode, + ) -> TractResult> { + let (by_scalar_should_be_efficient, unicast_should_be_efficient) = + find_most_efficient_config(model, node)?; let op_is_quant = if let &[a, b] = &*model.node_input_facts(node.id)? { let c_dt = self.output_datum_type(a.datum_type, b.datum_type)?; c_dt.is_quantized() || a.datum_type.is_quantized() || b.datum_type.is_quantized() @@ -226,7 +235,7 @@ impl TypedOp for TypedBinOp { false }; - if by_scalar_should_be_efficient & can_eval_in_a & !op_is_quant{ + if by_scalar_should_be_efficient & can_eval_in_a & !op_is_quant { return Ok(Some( TypedModelPatch::replace_single_op( model, @@ -235,7 +244,7 @@ impl TypedOp for TypedBinOp { BinOpByScalar(self.0.clone()), )? .with_context("ByScalar"), - )) + )); } if unicast_should_be_efficient & can_eval_in_a & !op_is_quant { @@ -247,32 +256,41 @@ impl TypedOp for TypedBinOp { BinOpUnicast(self.0.clone()), )? .with_context("Unicast"), - )) + )); } - self.0.declutter(model, node) - } + Ok(None) + } as_op!(); } -fn find_most_efficient_config( - model: &TypedModel, - node: &TypedNode, -) -> TractResult<(bool, bool)> { +fn find_most_efficient_config(model: &TypedModel, node: &TypedNode) -> TractResult<(bool, bool)> { if let &[a, b] = &*model.node_input_facts(node.id)? 
{ let a_shape = a.shape.clone(); let b_shape = b.shape.clone(); let by_scalar_is_possible = BinOpByScalar::check_input_shapes(&a_shape, &b_shape); - let num_by_scalar_elements = if by_scalar_is_possible { - a_shape.iter().zip(b_shape.iter()).rev().take_while(|(_, rev_b_dim)| **rev_b_dim == TDim::Val(1)).map(|(rev_a_dim, _)| rev_a_dim).product::() + let num_by_scalar_elements = if by_scalar_is_possible { + a_shape + .iter() + .zip(b_shape.iter()) + .rev() + .take_while(|(_, rev_b_dim)| **rev_b_dim == TDim::Val(1)) + .map(|(rev_a_dim, _)| rev_a_dim) + .product::() } else { TDim::Val(0) }; - + let unicast_is_possible = BinOpUnicast::check_input_shapes(&a_shape, &b_shape); let num_unicast_elements = if unicast_is_possible { - a_shape.iter().zip(b_shape.iter()).rev().take_while(|(a_dim, b_dim)| a_dim == b_dim).map(|(a_dim, _)| a_dim).product::() + a_shape + .iter() + .zip(b_shape.iter()) + .rev() + .take_while(|(a_dim, b_dim)| a_dim == b_dim) + .map(|(a_dim, _)| a_dim) + .product::() } else { TDim::Val(0) }; @@ -280,7 +298,7 @@ fn find_most_efficient_config( let min_num_elements = 32; let by_scalar_should_be_efficient = gt_tdim(num_by_scalar_elements, min_num_elements); let unicast_should_be_efficient = gt_tdim(num_unicast_elements, min_num_elements); - return Ok((by_scalar_should_be_efficient, unicast_should_be_efficient)) + return Ok((by_scalar_should_be_efficient, unicast_should_be_efficient)); } Ok((false, false)) } @@ -293,8 +311,10 @@ pub fn gt_tdim(x: TDim, min_val: i64) -> bool { pub struct BinOpByScalar(pub Box); impl BinOpByScalar { - fn check_input_shapes(a_shape: &[TDim], b_shape:&[TDim]) -> bool{ - if a_shape.len() != b_shape.len() {return false}; + fn check_input_shapes(a_shape: &[TDim], b_shape: &[TDim]) -> bool { + if a_shape.len() != b_shape.len() { + return false; + }; let mut must_be_unary = false; a_shape.iter().zip(b_shape.iter()).all(|(a_dim, b_dim)| { @@ -327,29 +347,34 @@ impl EvalOp for BinOpByScalar { let (a, b) = args_2!(inputs); let mut a = a.into_tensor(); let b_shape = b.shape(); - let view = a.view_mut(); + let mut view = a.view_mut(); let b_view = b.view(); - let trailing_unary_dims: Vec = b_shape.iter() + let first_unary_axis = b_shape + .iter() .enumerate() .rev() .take_while(|&(_, &dim)| dim == 1) .map(|(i, _)| i) - .collect(); - for it_coords in tract_data::internal::iter_indices(b_shape) { - // Prepare array view to perform computation - // - view should be a slice - // - b should be a scalar - let mut view = view.clone(); - let mut tmp_b_view = b_view.clone(); - for idx in 0..b_shape.len() { - if !trailing_unary_dims.contains(&idx) { - view.collapse_axis(idx, it_coords[idx] as isize); - tmp_b_view.collapse_axis(idx, it_coords[idx] as isize); + .last() + .context("Cannot use by_scalar when no trailing dimensions are unary")?; + + let iterating_shape = view.shape()[..first_unary_axis].to_vec(); + if !iterating_shape.is_empty() { + for it_coords in tract_data::internal::iter_indices(&iterating_shape) { + let mut view = view.clone(); + let mut tmp_b_view = b_view.clone(); + + // Prepare array view to perform computation + for (axis, idx) in it_coords.iter().enumerate() { + view.collapse_axis(axis, *idx as isize); + tmp_b_view.collapse_axis(axis, *idx as isize); } - } - self.0.eval_by_scalar(&mut view, &tmp_b_view)?; + self.0.eval_by_scalar(&mut view, &tmp_b_view)?; + } + } else { + self.0.eval_by_scalar(&mut view, &b_view)?; } Ok(tvec!(a.into_tvalue())) } @@ -383,18 +408,20 @@ impl TypedOp for BinOpByScalar { as_op!(); } - + #[derive(Debug, Clone)] pub struct 
BinOpUnicast(pub Box<dyn BinMiniOp>);
 
 impl BinOpUnicast {
-    fn check_input_shapes(a_shape: &[TDim], b_shape:&[TDim]) -> bool{
-        if a_shape.len() != b_shape.len() {return false};
+    fn check_input_shapes(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
+        if a_shape.len() != b_shape.len() {
+            return false;
+        };
 
         let mut must_be_equal = false;
         a_shape.iter().zip(b_shape.iter()).all(|(a_dim, b_dim)| {
             // As soon as a b dimension is not equal to one, a and b dimensions must be equal.
             if (*b_dim != 1.into()) && !must_be_equal {
                 must_be_equal = true
             }
 
             // Leading dimensions: b_dim == 1 condition
             // Trailing dimensions: a_dim == b_dim
             ((*b_dim == 1.into()) & !must_be_equal) || ((a_dim == b_dim) & must_be_equal)
         })
     }
 }
@@ -422,23 +448,27 @@ impl EvalOp for BinOpUnicast {
     fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
         let (a, b) = args_2!(inputs);
         let mut a = a.into_tensor();
         let b_shape = b.shape();
-        let view = a.view_mut();
+        let mut view = a.view_mut();
         let b_view = b.view();
 
+        let first_non_unary_axis =
+            b_shape.iter().enumerate().take_while(|&(_, &dim)| dim == 1).map(|(i, _)| i + 1).last();
+
+        if let Some(first_non_unary_axis) = first_non_unary_axis {
+            // Iterate on outer dimensions and evaluate with unicast subviews
+            let iterating_shape = view.shape()[..first_non_unary_axis].to_vec();
+            for it_coords in tract_data::internal::iter_indices(&iterating_shape) {
+                let mut view = view.clone();
+                it_coords.iter().enumerate().for_each(|(axis, idx)| {
+                    view.collapse_axis(axis, *idx as isize);
+                });
+                self.0.eval_unicast(&mut view, &b_view)?;
+            }
+        } else {
+            debug_assert_eq!(view.shape(), b_view.shape());
+            self.0.eval_unicast(&mut view, &b_view)?;
+        }
 
         Ok(tvec!(a.into_tvalue()))
     }
 }
 
diff --git a/nnef/src/registry.rs b/nnef/src/registry.rs
index 323ccb825e..88cecd55c2 100644
--- a/nnef/src/registry.rs
+++ b/nnef/src/registry.rs
@@ -158,24 +158,6 @@ impl Registry {
             let b = ast.mapping[&node.inputs[1]].clone();
             return Ok(Some(invocation(&op.0, &[a, b], &[])));
         }
-    // Temporarily allow by-scalar serialization
-    } else if let Some(op) = node.op().downcast_ref::<BinOpByScalar>() {
-        if let Some(op) =
-            self.binary_ops.iter().find(|ew| ew.1.as_ref().type_id() == op.0.type_id())
-        {
-            let a = ast.mapping[&node.inputs[0]].clone();
-            let b = ast.mapping[&node.inputs[1]].clone();
-            return Ok(Some(invocation(&op.0, &[a, b], &[])));
-        }
-    // Temporarily allow unicast serialization
-    } else if let Some(op) = node.op().downcast_ref::<BinOpUnicast>() {
-        if let Some(op) =
-            self.binary_ops.iter().find(|ew| ew.1.as_ref().type_id() == op.0.type_id())
-        {
-            let a = ast.mapping[&node.inputs[0]].clone();
-            let b = ast.mapping[&node.inputs[1]].clone();
-            return Ok(Some(invocation(&op.0, &[a, b], &[])));
-        }
     } else if let Some(op) = self.from_tract.get(&node.op().type_id()) {
         if let Some(result) = op(ast, node)?
{ return Ok(Some(result)); From 6d6b13cb2153ba02aea90351298e071b5d0f2e1d Mon Sep 17 00:00:00 2001 From: Emrick Sinitambirivoutin Date: Tue, 16 Jul 2024 17:23:14 +0200 Subject: [PATCH 06/32] Add declutter neutral to typed op --- core/src/ops/binary.rs | 74 +++++++++++++++++++++++++++++++++++++--- core/src/ops/math/mod.rs | 68 ++++++------------------------------ 2 files changed, 81 insertions(+), 61 deletions(-) diff --git a/core/src/ops/binary.rs b/core/src/ops/binary.rs index 34e49112e1..484fbfd4f4 100644 --- a/core/src/ops/binary.rs +++ b/core/src/ops/binary.rs @@ -4,6 +4,8 @@ use tract_itertools::Itertools; use std::fmt; use tract_data::itertools::izip; +use super::cast::cast; + pub trait BinMiniOp: fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static + Downcast { fn name(&self) -> &'static str; fn validation(&self) -> Validation { @@ -22,6 +24,13 @@ pub trait BinMiniOp: fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static + fn eval_by_scalar(&self, a: &mut TensorView, b: &TensorView) -> TractResult<()>; fn eval_unicast(&self, a: &mut TensorView, b: &TensorView) -> TractResult<()>; + fn is_commutative(&self) -> bool { + true + } + fn neutral_element(&self) -> Option { + None + } + #[allow(unused_variables)] fn maybe_eval_qbinary_as_float_op( &self, @@ -206,10 +215,10 @@ impl TypedOp for TypedBinOp { model: &TypedModel, node: &TypedNode, ) -> TractResult> { - // Need to add new methods to BinMiniOp firt - //if let Some(neutral_patch) = declutter_neutral(model, node, &self.0)? { - // return Ok(Some(neutral_patch)) - //} + let (a_dt, b_dt) = if let &[a, b] = &*model.node_input_facts(node.id)? {(a.datum_type().unwrap(), b.datum_type().unwrap())} else {unreachable!("")}; + if let Some(neutral_patch) = declutter_neutral(model, node, &self.0, self.output_datum_type(a_dt, b_dt)?)? { + return Ok(Some(neutral_patch)); + } self.0.declutter(model, node) } @@ -264,6 +273,55 @@ impl TypedOp for TypedBinOp { as_op!(); } +fn declutter_neutral( + model: &TypedModel, + node: &TypedNode, + mini_op: &Box, + out_dt: DatumType, +) -> TractResult> { + if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? { + // Not sure to understand why this check was needed + //let integer = uniform.uni.cast_to_scalar::()?; + //let is_scalar = tensor0(integer) + // .cast_to_dt(uniform.uni.datum_type())? + // .close_enough(&uniform.uni, false) + // .is_ok(); + + let is_neutral = mini_op + .neutral_element() + .map(|neutral| tensor0(neutral).close_enough(&uniform.uni, false).is_ok()) + .unwrap_or(false); + + // For some operand neural element can be the left one while for other + // it is not the case (neutral - 1 -> not ok, 1 - neutal -> ok) + let pos_checked = mini_op.is_commutative() || !uniform.left_is_uniform; + + if is_neutral && pos_checked { + // Neutral decluttering for quant values is special. + // - if (fa) (a-az)*as + (fb = 0) (b-bz)*bs = (fc) (c-cz)*cs + // - then even if fa = fc, quant params needs to be updated (a != c). + // So it's not a no_op. + if uniform.uni.datum_type().is_quantized() { + return Ok(Some(TypedModelPatch::replace_single_op( + model, + node, + &[node.inputs[0]], + cast(out_dt), + )?)); + // In the non quantized case, it's a no_op. + } else { + return Ok(Some(TypedModelPatch::rewire( + model, + &[uniform.var], + &[node.id.into()], + &|_, inputs| Ok(inputs.into()), + )?)); + } + } + } + Ok(None) +} + fn find_most_efficient_config(model: &TypedModel, node: &TypedNode) -> TractResult<(bool, bool)> { if let &[a, b] = &*model.node_input_facts(node.id)? 
{ let a_shape = a.shape.clone(); @@ -517,6 +575,8 @@ macro_rules! bin_to_super_type { $(unicast_in_place: $unicast_in_place:expr,)? $(eval_by_scalar: $eval_by_scalar:expr,)? $(eval_unicast: $eval_unicast:expr,)? + $(is_commutative: $is_commutative:expr,)? + $(neutral_element: $neutral_element:expr,)? $(out_of_place: $out_of_place:expr,)? $(validation: $validation:expr,)? $(q: $([$($typ_dt:ident),*] => $cab_dt:expr),* ;)? @@ -675,6 +735,12 @@ macro_rules! bin_to_super_type { )* bail!("{} does not support {:?} (eval unicast)", self.name(), a.datum_type()); } + $(fn is_commutative(&self) -> bool { + $is_commutative + })? + $(fn neutral_element(&self) -> Option { + Some($neutral_element) + })? fn eval_in_a(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()> { // c and a are same type $(if $eval_in_a(a, b)? { return Ok(()) } )? diff --git a/core/src/ops/math/mod.rs b/core/src/ops/math/mod.rs index 6856767040..ffbaed098d 100644 --- a/core/src/ops/math/mod.rs +++ b/core/src/ops/math/mod.rs @@ -23,8 +23,8 @@ mod complex; pub use complex::{ComplexToInnerDim, InnerDimToComplex}; bin_to_super_type!(add, Add, - declutter: declutter_add, linalg: Add, + neutral_element: 0, validation: Validation::Rounding, q: [i8, u8, i32, i32] => add_quant; q_op_on_f32: |a: f32, b: f32| -> f32 {a+b}, @@ -39,8 +39,9 @@ where } bin_to_super_type!(sub, Sub, - declutter: declutter_sub, linalg:Sub, + is_commutative: false, + neutral_element: 0, q: [i8, u8, i32, i32] => sub_quant; q_op_on_f32: |a: f32, b: f32| -> f32 {a-b}, [f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64, TDim] => |c, a, b| *c = a.clone() - b); @@ -110,6 +111,7 @@ bin_to_super_type!(mul, Mul, Ok(false) } }, + neutral_element: 1, out_of_place: |c:&mut Tensor, a:&Tensor, b: &Tensor| -> TractResult { if c.datum_type() == TDim::datum_type() && a.datum_type() == TDim::datum_type() && b.datum_type() == TDim::datum_type() { @@ -338,6 +340,8 @@ eval_override: |a:TValue, b: TValue, c_dt: DatumType| -> TractResult { Div.generic_eval(a, b, c_dt) } }, +is_commutative: false, +neutral_element: 1, out_of_place: |c:&mut Tensor, a:&Tensor, b: &Tensor| -> TractResult { if c.datum_type() == TDim::datum_type() && a.datum_type() == TDim::datum_type() && b.datum_type() == TDim::datum_type() { @@ -452,61 +456,19 @@ bin_to_super_type!(max, Max, bin_to_super_type!(pow, Pow, declutter: declutter_pow, + is_commutative: false, + neutral_element: 1, q_op_on_f32: |a: f32, b: f32| -> f32 {a.powf(b)}, [f16, f32, f64] => |c,a,b| *c = a.powf(*b), [i32, i64] => |c,a,b| *c = a.pow(*b as u32)); bin_to_super_type!(shift_left, ShiftLeft, + is_commutative: false, [i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = *a << *b); bin_to_super_type!(shift_right, ShiftRight, + is_commutative: false, [i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = *a >> *b); -fn declutter_neutral( - model: &TypedModel, - node: &TypedNode, - value: i64, - also_left: bool, -) -> TractResult> { - if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? { - // casting to i64 uni quantized type need to be avoided - if uniform.uni.datum_type().is_quantized() { - return Ok(None); - } - let Ok(integer) = uniform.uni.cast_to_scalar::() else { return Ok(None) }; - if tensor0(integer) - .cast_to_dt(uniform.uni.datum_type())? 
- .close_enough(&uniform.uni, false) - .is_ok() - && integer == value - && (also_left || !uniform.left_is_uniform) - { - return Ok(Some(TypedModelPatch::rewire( - model, - &[uniform.var], - &[node.id.into()], - &|_, inputs| Ok(inputs.into()), - )?)); - } - } - Ok(None) -} - -fn declutter_add( - _op: &Add, - model: &TypedModel, - node: &TypedNode, -) -> TractResult> { - declutter_neutral(model, node, 0, true) -} - -fn declutter_sub( - _op: &Sub, - model: &TypedModel, - node: &TypedNode, -) -> TractResult> { - declutter_neutral(model, node, 0, false) -} - fn declutter_mul( _op: &Mul, model: &TypedModel, @@ -520,9 +482,7 @@ fn declutter_mul( square(), )?)); } - if let Some(p) = declutter_neutral(model, node, 1, true).context("decluttering neutral")? { - return Ok(Some(p)); - } + if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? { let var_fact = model.outlet_fact(uniform.var)?; if uniform.uni.cast_to_scalar::()? == 0.0 { @@ -597,9 +557,6 @@ fn declutter_div( model: &TypedModel, node: &TypedNode, ) -> TractResult> { - if let Some(p) = declutter_neutral(model, node, 1, false)? { - return Ok(Some(p)); - } if let &[p, q] = &*model.node_input_facts(node.id)? { let dt = q.datum_type; if let Some(q) = &q.uniform { @@ -648,9 +605,6 @@ fn declutter_pow( model: &TypedModel, node: &TypedNode, ) -> TractResult> { - if let Some(p) = declutter_neutral(model, node, 1, false)? { - return Ok(Some(p)); - } let b = model.outlet_fact(node.inputs[1])?; if let Some(b) = &b.uniform { let b = b.cast_to_scalar::()?; From b8d4c1aa5483c3be32561340fb09759e4fd99fa2 Mon Sep 17 00:00:00 2001 From: Emrick Sinitambirivoutin Date: Tue, 16 Jul 2024 17:45:36 +0200 Subject: [PATCH 07/32] Fix clippy --- core/src/ops/binary.rs | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/core/src/ops/binary.rs b/core/src/ops/binary.rs index 484fbfd4f4..9976a31ccf 100644 --- a/core/src/ops/binary.rs +++ b/core/src/ops/binary.rs @@ -215,8 +215,14 @@ impl TypedOp for TypedBinOp { model: &TypedModel, node: &TypedNode, ) -> TractResult> { - let (a_dt, b_dt) = if let &[a, b] = &*model.node_input_facts(node.id)? {(a.datum_type().unwrap(), b.datum_type().unwrap())} else {unreachable!("")}; - if let Some(neutral_patch) = declutter_neutral(model, node, &self.0, self.output_datum_type(a_dt, b_dt)?)? { + let (a_dt, b_dt) = if let &[a, b] = &*model.node_input_facts(node.id)? { + (a.datum_type().unwrap(), b.datum_type().unwrap()) + } else { + unreachable!("TypedBinOp has two inputs.") + }; + if let Some(neutral_patch) = + declutter_neutral(model, node, self.0.as_ref(), self.output_datum_type(a_dt, b_dt)?)? + { return Ok(Some(neutral_patch)); } @@ -276,7 +282,7 @@ impl TypedOp for TypedBinOp { fn declutter_neutral( model: &TypedModel, node: &TypedNode, - mini_op: &Box, + mini_op: &dyn BinMiniOp, out_dt: DatumType, ) -> TractResult> { if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? { From f7eaefbd0fb969e4fe7251f6aa87aacb72c0b77d Mon Sep 17 00:00:00 2001 From: Emrick Sinitambirivoutin Date: Tue, 16 Jul 2024 18:54:33 +0200 Subject: [PATCH 08/32] Dirty plug in linalg --- core/src/ops/binary.rs | 2 +- core/src/ops/math/mod.rs | 8 +++++ linalg/src/lib.rs | 68 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 77 insertions(+), 1 deletion(-) diff --git a/core/src/ops/binary.rs b/core/src/ops/binary.rs index 9976a31ccf..c63c0c24c1 100644 --- a/core/src/ops/binary.rs +++ b/core/src/ops/binary.rs @@ -710,7 +710,7 @@ macro_rules! 
bin_to_super_type { } fn eval_by_scalar(&self, a: &mut TensorView, b: &TensorView) -> TractResult<()> { - $(if $eval_by_scalar(a, b)? { return Ok(()) } )? + $(if $eval_by_scalar(a, b)? { return Ok(())} )? $( $(if b.datum_type() == $typ::datum_type() { let cab: fn(&mut $typ, &$typ, &$typ) -> () = $cab; diff --git a/core/src/ops/math/mod.rs b/core/src/ops/math/mod.rs index ffbaed098d..8dcc657ab6 100644 --- a/core/src/ops/math/mod.rs +++ b/core/src/ops/math/mod.rs @@ -111,6 +111,14 @@ bin_to_super_type!(mul, Mul, Ok(false) } }, + eval_by_scalar: |a: &mut TensorView, b: &TensorView | -> TractResult { + let res = tract_linalg::bin_by_scalar(tract_linalg::BinOp::Mul)(a, b).is_ok(); + Ok(res) + }, + eval_unicast: |a: &mut TensorView, b: &TensorView | -> TractResult { + let res = tract_linalg::bin_unicast(tract_linalg::BinOp::Mul)(a, b).is_ok(); + Ok(res) + }, neutral_element: 1, out_of_place: |c:&mut Tensor, a:&Tensor, b: &Tensor| -> TractResult { if c.datum_type() == TDim::datum_type() && diff --git a/linalg/src/lib.rs b/linalg/src/lib.rs index 46a0e3e74d..47c44e6982 100644 --- a/linalg/src/lib.rs +++ b/linalg/src/lib.rs @@ -26,6 +26,7 @@ use frame::reduce::{MapReduceKer, ReduceKer}; use frame::unicast::UnicastKer; use frame::{reduce, unicast, MatMatMul}; pub use generic::{ScaleShiftAndRound, Scaler}; +use tract_data::internal::TensorView; #[cfg(target_arch = "x86_64")] pub mod x86_64_fma; @@ -175,6 +176,73 @@ lazy_static::lazy_static! { }; } +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub enum BinOp { + Min, + Max, + Add, + Mul, + Sub, + SubF, +} + +impl BinOp { + pub fn flip(&self) -> BinOp { + use BinOp::*; + match self { + Sub => SubF, + SubF => Sub, + sym => *sym, + } + } +} + +pub fn bin_by_scalar(bin: BinOp) -> Box TractResult<()>> { + match bin { + BinOp::Mul => { + return Box::new(|a: &mut TensorView, b: &TensorView| -> TractResult<()> { + match b.datum_type() { + DatumType::F32 =>{ + let a_slice = a.as_slice_mut()?; + let b_slice = b.as_slice()?[0]; + (ops().mul_by_scalar_f32)().run_with_params(a_slice, b_slice) + }, + DatumType::F16 => { + let a_slice = a.as_slice_mut()?; + let b_slice = b.as_slice()?[0]; + (ops().mul_by_scalar_f16)().run_with_params(a_slice, b_slice) + }, + _ => unimplemented!(""), + } + }) + }, + _ => unimplemented!() + } +} + +pub fn bin_unicast(bin: BinOp) -> Box TractResult<()>> { + match bin { + BinOp::Mul => { + return Box::new(|a: &mut TensorView, b: &TensorView| -> TractResult<()> { + match b.datum_type() { + DatumType::F32 => { + let a_slice = a.as_slice_mut()?; + let b_slice = b.as_slice()?; + (ops().unicast_mul_f32)().run(a_slice, b_slice) + }, + DatumType::F16 => { + let a_slice = a.as_slice_mut()?; + let b_slice = b.as_slice()?; + (ops().unicast_mul_f32)().run(a_slice, b_slice) + }, + _ => unimplemented!(""), + } + }) + }, + _ => unimplemented!() + } +} + pub fn ops() -> &'static Ops { &OPS } From 9c7f3e1e452746e1a6d2edb73e25f0bf37ffda21 Mon Sep 17 00:00:00 2001 From: Emrick Sinitambirivoutin Date: Thu, 18 Jul 2024 17:14:32 +0200 Subject: [PATCH 09/32] Create by_scalar & unicast registries in linalg --- core/src/ops/math/mod.rs | 8 ++- linalg/src/arm64.rs | 13 +++- linalg/src/arm64/arm64fp16/by_scalar.rs | 2 +- linalg/src/arm64/arm64simd/by_scalar.rs | 2 +- linalg/src/frame/by_scalar.rs | 58 +++++++++++++++++ linalg/src/frame/unicast/by_scalar.rs | 70 --------------------- linalg/src/frame/unicast/mod.rs | 15 +++++ linalg/src/frame/unicast/unicast.rs | 0 linalg/src/generic.rs | 14 +++++ linalg/src/generic/by_scalar.rs | 50 +++++---------- 
 linalg/src/lib.rs                       | 84 ++++++++++++--------
 11 files changed, 163 insertions(+), 153 deletions(-)
 delete mode 100644 linalg/src/frame/unicast/by_scalar.rs
 delete mode 100644 linalg/src/frame/unicast/unicast.rs

diff --git a/core/src/ops/math/mod.rs b/core/src/ops/math/mod.rs
index 8dcc657ab6..6ab6cc37ef 100644
--- a/core/src/ops/math/mod.rs
+++ b/core/src/ops/math/mod.rs
@@ -112,11 +112,15 @@ bin_to_super_type!(mul, Mul,
     }
 },
 eval_by_scalar: |a: &mut TensorView, b: &TensorView | -> TractResult<bool> {
-    let res = tract_linalg::bin_by_scalar(tract_linalg::BinOp::Mul)(a, b).is_ok();
+    let res = tract_linalg::bin_by_scalar(a.datum_type(), tract_linalg::BinOp::Mul)
+        .context("unimplemented mul by scalar")?(a, b)
+        .is_ok();
     Ok(res)
 },
 eval_unicast: |a: &mut TensorView, b: &TensorView | -> TractResult<bool> {
-    let res = tract_linalg::bin_unicast(tract_linalg::BinOp::Mul)(a, b).is_ok();
+    let res = tract_linalg::bin_unicast(a.datum_type(), tract_linalg::BinOp::Mul)
+        .context("unimplemented mul unicast")?(a, b)
+        .is_ok();
     Ok(res)
 },
 neutral_element: 1,
diff --git a/linalg/src/arm64.rs b/linalg/src/arm64.rs
index a0fe99445c..4396d518a5 100644
--- a/linalg/src/arm64.rs
+++ b/linalg/src/arm64.rs
@@ -14,8 +14,9 @@ mod arm64fp16;
 pub use arm64fp16::*;
 
 use crate::f16;
-use crate::Ops;
+use crate::{Ops, LinalgRegistry, DatumType, BinOp};
 
+use crate::frame::by_scalar::ByScalarKer;
 use crate::frame::element_wise::ElementWiseKer;
 use crate::frame::reduce::{MapReduceKer, ReduceKer};
 use crate::frame::unicast::UnicastKer;
@@ -213,6 +214,16 @@ impl Kind {
     }
 }
 
+pub(crate) fn register_all_unicast(registry: &mut LinalgRegistry) {
+    registry.insert((BinOp::Mul, DatumType::F32), Box::new(|| arm64simd_unicast_mul_f32_16n::bin_1()));
+    registry.insert((BinOp::Mul, DatumType::F16), Box::new(|| arm64fp16_unicast_mul_f16_32n::bin_1()));
+}
+
+pub(crate) fn register_all_by_scalar(registry: &mut LinalgRegistry) {
+    registry.insert((BinOp::Mul, DatumType::F32), Box::new(|| arm64simd_mul_by_scalar_f32_16n::bin_1()));
+    registry.insert((BinOp::Mul, DatumType::F16), Box::new(|| arm64fp16_mul_by_scalar_f16_32n::bin_1()));
+}
+
 pub fn plug(ops: &mut Ops) {
     ops.mmm_impls.extend([
         arm64simd_mmm_f32_12x8_gen.mmm(),
diff --git a/linalg/src/arm64/arm64fp16/by_scalar.rs b/linalg/src/arm64/arm64fp16/by_scalar.rs
index 8694e01215..f9c78ac2eb 100644
--- a/linalg/src/arm64/arm64fp16/by_scalar.rs
+++ b/linalg/src/arm64/arm64fp16/by_scalar.rs
@@ -1,6 +1,6 @@
 use crate::f16;
 
-ew_impl_wrap!(
+by_scalar_impl_wrap!(
     f16,
     arm64fp16_mul_by_scalar_f16_32n,
     32,
diff --git a/linalg/src/arm64/arm64simd/by_scalar.rs b/linalg/src/arm64/arm64simd/by_scalar.rs
index 1bd91491cc..ac6f0a06fe 100644
--- a/linalg/src/arm64/arm64simd/by_scalar.rs
+++ b/linalg/src/arm64/arm64simd/by_scalar.rs
@@ -1,4 +1,4 @@
-ew_impl_wrap!(
+by_scalar_impl_wrap!(
     f32,
     arm64simd_mul_by_scalar_f32_16n,
     16,
diff --git a/linalg/src/frame/by_scalar.rs b/linalg/src/frame/by_scalar.rs
index 528c3c3585..cf6c948451 100644
--- a/linalg/src/frame/by_scalar.rs
+++ b/linalg/src/frame/by_scalar.rs
@@ -1,3 +1,61 @@
+use std::{fmt::Debug, marker::PhantomData};
+
+use tract_data::{TractResult, internal::TensorView};
+
+use crate::{LADatum, element_wise::ElementWiseKer};
+
+use super::{ElementWise, element_wise_helper::map_slice_with_alignment};
+
+
+/// Generic implementation struct that unifies all by-scalar kernels.
+/// A by-scalar operation is an ElementWise operation with a scalar parameter.
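+///
+/// For instance (illustrative): a mul-by-scalar kernel multiplies every element
+/// of an aligned slice in place, so `[1.0, 2.0, 3.0, 4.0]` with scalar `2.0`
+/// becomes `[2.0, 4.0, 6.0, 8.0]`.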
+#[derive(Debug, Clone, new)] +pub struct ByScalarImpl +where + T: LADatum, + K: ByScalarKer + Clone, +{ + phantom: PhantomData<(K, T)>, +} + +impl ElementWise for ByScalarImpl +where + T: LADatum, + K: ByScalarKer + Clone, +{ + fn name(&self) -> &'static str { + K::name() + } + fn run_with_params(&self, vec: &mut [T], params: T) -> TractResult<()> { + map_slice_with_alignment(vec, |data| K::run(data, params), K::nr(), K::alignment_bytes()) + } +} + + +pub trait ByScalarKer: ElementWiseKer +where + T: LADatum +{ + fn bin_1() -> Box TractResult<()>> { + Box::new(|a: &mut TensorView, b: &TensorView| { + let a_slice = a.as_slice_mut()?; + let b = b.as_slice()?[0]; + (Self::ew()).run_with_params(a_slice, b) + }) + } +} + +macro_rules! by_scalar_impl_wrap { + ($ti: ident, $func: ident, $nr: expr, $alignment_items: expr, $params: ty, $run: item) => { + paste! { + ew_impl_wrap!($ti, $func, $nr, $alignment_items, $ti, $run); + + impl crate::frame::by_scalar::ByScalarKer<$ti> for $func {} + } + }; +} + + #[cfg(test)] #[macro_use] pub mod test { diff --git a/linalg/src/frame/unicast/by_scalar.rs b/linalg/src/frame/unicast/by_scalar.rs deleted file mode 100644 index 710e8f6a10..0000000000 --- a/linalg/src/frame/unicast/by_scalar.rs +++ /dev/null @@ -1,70 +0,0 @@ -use std::marker::PhantomData; - -use tract_data::TractResult; - -use crate::{element_wise::ElementWiseKer, LADatum, frame::{ElementWise, element_wise_helper::map_slice_with_alignment}}; - -/// Generic implementation struct that unify all by scalar kernels. -/// A by scalar operation is an ElementWise operation with a scalar paramerer. -#[derive(Debug, Clone, new)] -pub struct ByScalarImpl -where - T: LADatum, - K: ByScalarKer + Clone, -{ - phantom: PhantomData<(K, T)>, -} - -impl ElementWise for ByScalarImpl -where - T: LADatum, - K: ByScalarKer + Clone, -{ - fn name(&self) -> &'static str { - K::name() - } - fn run_with_params(&self, vec: &mut [T], params: T) -> TractResult<()> { - map_slice_with_alignment(vec, |data| K::run(data, params), K::nr(), K::alignment_bytes()) - } -} - -pub trait ByScalarKer: ElementWiseKer -where - T: LADatum, -{ - fn name() -> &'static str; - fn alignment_bytes() -> usize { - ElementWiseKer::::alignment_bytes() - } - fn alignment_items() -> usize; - fn nr() -> usize; - fn run(vec: &mut [T], scalar: T) { - ElementWiseKer::::run(vec, scalar) - } - fn by_scalar() -> Box> { - Box::new(ByScalarImpl::::new()) - } -} - -#[cfg(test)] -#[macro_use] -pub mod test { - use crate::frame::element_wise::ElementWiseKer; - use crate::LADatum; - use num_traits::{AsPrimitive, Float}; - use proptest::test_runner::TestCaseResult; - - #[macro_export] - macro_rules! by_scalar_frame_tests { - ($cond:expr, $t: ty, $ker:ty, $reference: expr) => { - proptest::proptest! 
{ - #[test] - fn prop(xs in proptest::collection::vec(-25f32..25.0, 0..100), scalar in -25f32..25f32) { - if $cond { - $crate::frame::element_wise::test::test_element_wise::<$ker, $t>(&*xs, $reference, scalar).unwrap() - } - } - } - }; - } -} diff --git a/linalg/src/frame/unicast/mod.rs b/linalg/src/frame/unicast/mod.rs index bb58ef561c..e6b1e0cea9 100644 --- a/linalg/src/frame/unicast/mod.rs +++ b/linalg/src/frame/unicast/mod.rs @@ -4,6 +4,7 @@ use std::fmt::Debug; use std::marker::PhantomData; use tract_data::TractResult; +use tract_data::internal::TensorView; use crate::frame::element_wise_helper::TempBuffer; use crate::LADatum; @@ -53,6 +54,13 @@ where phantom: PhantomData<(K, T)>, } + +impl UnicastImpl +where + T: LADatum, + K: UnicastKer + Clone, +{ +} impl Unicast for UnicastImpl where T: LADatum, @@ -80,6 +88,13 @@ where fn bin() -> Box> { Box::new(UnicastImpl::::new()) } + fn bin_1() -> Box TractResult<()>> { + Box::new(|a: &mut TensorView, b: &TensorView| { + let a_slice = a.as_slice_mut()?; + let b_slice = b.as_slice()?; + (Self::bin()).run(a_slice, b_slice) + }) + } } std::thread_local! { diff --git a/linalg/src/frame/unicast/unicast.rs b/linalg/src/frame/unicast/unicast.rs deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/linalg/src/generic.rs b/linalg/src/generic.rs index 4764ee4001..65f474cf14 100644 --- a/linalg/src/generic.rs +++ b/linalg/src/generic.rs @@ -9,6 +9,10 @@ pub mod sigmoid; pub mod tanh; pub mod unicast; +use tract_data::prelude::DatumType; + +use crate::{LinalgRegistry, BinOp, UnicastKer, ByScalarKer}; + pub use self::by_scalar::{HMulByScalar8, SMulByScalar4}; pub use self::erf::SErf4; pub use self::leaky_relu::{HLeakyRelu8, SLeakyRelu4}; @@ -17,3 +21,13 @@ pub use self::rounding::{ScaleShiftAndRound, Scaler}; pub use self::sigmoid::{HSigmoid8, SSigmoid4}; pub use self::reduce::softmax_l2::SSoftMaxL2; pub use self::tanh::{HTanh8, STanh4}; + +pub(crate)fn register_all_unicast(registry: &mut LinalgRegistry) { + registry.insert((BinOp::Mul, DatumType::F32),Box::new(|| unicast::SUnicastMul4::bin_1())); + registry.insert((BinOp::Mul, DatumType::F16),Box::new(|| unicast::HUnicastMul8::bin_1())); +} + +pub(crate)fn register_all_by_scalar(registry: &mut LinalgRegistry) { + registry.insert((BinOp::Mul, DatumType::F32),Box::new(|| by_scalar::SMulByScalar4::bin_1())); + registry.insert((BinOp::Mul, DatumType::F16),Box::new(|| by_scalar::HMulByScalar8::bin_1())); +} diff --git a/linalg/src/generic/by_scalar.rs b/linalg/src/generic/by_scalar.rs index 646c1b65ee..ef67b31dc5 100644 --- a/linalg/src/generic/by_scalar.rs +++ b/linalg/src/generic/by_scalar.rs @@ -1,29 +1,17 @@ use tract_data::internal::f16; -use crate::element_wise::ElementWiseKer; - -#[derive(Clone, Debug)] -pub struct SMulByScalar4; - -impl ElementWiseKer for SMulByScalar4 { - fn name() -> &'static str { - "generic" - } - - fn alignment_items() -> usize { - 4 - } - - fn nr() -> usize { - 4 - } - +by_scalar_impl_wrap!( + f32, + SMulByScalar4, + 4, + 4, + f32, fn run(x: &mut [f32], s: f32) { debug_assert!(x.len() % Self::nr() == 0); debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0); x.iter_mut().for_each(|px| *px *= s) } -} +); #[cfg(test)] #[macro_use] @@ -31,28 +19,18 @@ pub mod mul_by_scalar_f32 { mul_by_scalar_frame_tests!(true, f32, crate::generic::by_scalar::SMulByScalar4); } -#[derive(Clone, Debug)] -pub struct HMulByScalar8; - -impl ElementWiseKer for HMulByScalar8 { - fn name() -> &'static str { - "generic" - } - - fn alignment_items() -> usize { - 8 - } - - fn nr() -> 
usize { - 8 - } - +by_scalar_impl_wrap!( + f16, + HMulByScalar8, + 8, + 8, + f16, fn run(x: &mut [f16], s: f16) { debug_assert!(x.len() % Self::nr() == 0); debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0); x.iter_mut().for_each(|px| *px *= s) } -} +); #[cfg(test)] #[macro_use] diff --git a/linalg/src/lib.rs b/linalg/src/lib.rs index 47c44e6982..1c66eb0228 100644 --- a/linalg/src/lib.rs +++ b/linalg/src/lib.rs @@ -21,11 +21,13 @@ include!(concat!(env!("OUT_DIR"), "/extern_kernel_macro.rs")); pub mod frame; pub mod generic; pub mod multithread; +use frame::by_scalar::ByScalarKer; use frame::element_wise::ElementWiseKer; use frame::reduce::{MapReduceKer, ReduceKer}; use frame::unicast::UnicastKer; use frame::{reduce, unicast, MatMatMul}; pub use generic::{ScaleShiftAndRound, Scaler}; +use lazy_static::lazy_static; use tract_data::internal::TensorView; #[cfg(target_arch = "x86_64")] pub mod x86_64_fma; @@ -197,59 +199,57 @@ impl BinOp { } } -pub fn bin_by_scalar(bin: BinOp) -> Box TractResult<()>> { - match bin { - BinOp::Mul => { - return Box::new(|a: &mut TensorView, b: &TensorView| -> TractResult<()> { - match b.datum_type() { - DatumType::F32 =>{ - let a_slice = a.as_slice_mut()?; - let b_slice = b.as_slice()?[0]; - (ops().mul_by_scalar_f32)().run_with_params(a_slice, b_slice) - }, - DatumType::F16 => { - let a_slice = a.as_slice_mut()?; - let b_slice = b.as_slice()?[0]; - (ops().mul_by_scalar_f16)().run_with_params(a_slice, b_slice) - }, - _ => unimplemented!(""), - } - }) - }, - _ => unimplemented!() - } + +fn register_all_unicast(registry: &mut LinalgRegistry) { + generic::register_all_unicast(registry); + #[cfg(target_arch = "aarch64")] + arm64::register_all_unicast(registry); + } -pub fn bin_unicast(bin: BinOp) -> Box TractResult<()>> { - match bin { - BinOp::Mul => { - return Box::new(|a: &mut TensorView, b: &TensorView| -> TractResult<()> { - match b.datum_type() { - DatumType::F32 => { - let a_slice = a.as_slice_mut()?; - let b_slice = b.as_slice()?; - (ops().unicast_mul_f32)().run(a_slice, b_slice) - }, - DatumType::F16 => { - let a_slice = a.as_slice_mut()?; - let b_slice = b.as_slice()?; - (ops().unicast_mul_f32)().run(a_slice, b_slice) - }, - _ => unimplemented!(""), - } - }) - }, - _ => unimplemented!() - } +fn register_all_by_scalar(registry: &mut LinalgRegistry) { + generic::register_all_by_scalar(registry); + #[cfg(target_arch = "aarch64")] + arm64::register_all_by_scalar(registry); + } + +type LinalgFn = Box TractResult<()>>; +type LinalgRegistry = HashMap<(BinOp, DatumType), Box LinalgFn + Send + Sync>>; +lazy_static! 
{ + static ref BIN_UNICAST_OPS: Mutex = { + let mut registry = HashMap::default(); + register_all_unicast(&mut registry); + Mutex::new(registry) + }; + static ref BIN_BY_SCALAR_OPS: Mutex = { + let mut registry = HashMap::default(); + register_all_by_scalar(&mut registry); + Mutex::new(registry) + }; +} + +pub fn bin_by_scalar(dt: DatumType, bin: BinOp) -> Option { + let map = BIN_BY_SCALAR_OPS.lock().unwrap(); + map.get(&(bin, dt)).map(|it| (it)()) +} + +pub fn bin_unicast(dt: DatumType, bin: BinOp) -> Option { + let map = BIN_UNICAST_OPS.lock().unwrap(); + map.get(&(bin, dt)).map(|it| (it)()) +} + + pub fn ops() -> &'static Ops { &OPS } use num_traits::*; +use std::collections::HashMap; use std::fmt::Debug; use std::ops::*; +use std::sync::Mutex; pub trait LADatum: Sized From 5eff92672b89c23462de4f98a4b545e70da067c6 Mon Sep 17 00:00:00 2001 From: Emrick Sinitambirivoutin Date: Thu, 18 Jul 2024 17:53:04 +0200 Subject: [PATCH 10/32] Fix import --- linalg/src/lib.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/linalg/src/lib.rs b/linalg/src/lib.rs index 1c66eb0228..df178e1404 100644 --- a/linalg/src/lib.rs +++ b/linalg/src/lib.rs @@ -24,7 +24,6 @@ pub mod multithread; use frame::by_scalar::ByScalarKer; use frame::element_wise::ElementWiseKer; use frame::reduce::{MapReduceKer, ReduceKer}; -use frame::unicast::UnicastKer; use frame::{reduce, unicast, MatMatMul}; pub use generic::{ScaleShiftAndRound, Scaler}; use lazy_static::lazy_static; From a4a566fcc1c18fb1a084636a268ae2234d58cac9 Mon Sep 17 00:00:00 2001 From: Emrick Sinitambirivoutin Date: Fri, 19 Jul 2024 09:27:57 +0200 Subject: [PATCH 11/32] BinOpX are slower .. --- core/src/ops/binary.rs | 11 +++++++++- core/src/ops/math/mod.rs | 46 ++++++++++++++-------------------------- 2 files changed, 26 insertions(+), 31 deletions(-) diff --git a/core/src/ops/binary.rs b/core/src/ops/binary.rs index c63c0c24c1..b43bf270df 100644 --- a/core/src/ops/binary.rs +++ b/core/src/ops/binary.rs @@ -1,8 +1,8 @@ use crate::internal::*; use downcast_rs::Downcast; -use tract_itertools::Itertools; use std::fmt; use tract_data::itertools::izip; +use tract_itertools::Itertools; use super::cast::cast; @@ -399,6 +399,11 @@ impl Op for BinOpByScalar { format!("{}ByScalar", self.0.name()).into() } + fn same_as(&self, other: &dyn Op) -> bool { + let Some(other) = other.downcast_ref::() else { return false }; + self.0.same_as(&*other.0) + } + op_as_typed_op!(); } @@ -501,6 +506,10 @@ impl Op for BinOpUnicast { format!("{}Unicast", self.0.name()).into() } + fn same_as(&self, other: &dyn Op) -> bool { + let Some(other) = other.downcast_ref::() else { return false }; + self.0.same_as(&*other.0) + } op_as_typed_op!(); } diff --git a/core/src/ops/math/mod.rs b/core/src/ops/math/mod.rs index 6ab6cc37ef..f7786f1b56 100644 --- a/core/src/ops/math/mod.rs +++ b/core/src/ops/math/mod.rs @@ -82,45 +82,31 @@ bin_to_super_type!(mul, Mul, }, linalg: Mul, uniform_in_place: |a: &Tensor, b: &mut Tensor| -> TractResult { - if b.datum_type() == f32::datum_type() { - let a = a.to_scalar::()?; - let slice = b.as_slice_mut::()?; - (tract_linalg::ops().mul_by_scalar_f32)().run_with_params(slice, *a)?; - Ok(true) - } else if b.datum_type() == f16::datum_type() { - let a = a.to_scalar::()?; - let slice = b.as_slice_mut::()?; - (tract_linalg::ops().mul_by_scalar_f16)().run_with_params(slice, *a)?; - Ok(true) - } else { - Ok(false) - } + let mut slice = b.view_mut(); + let scalar = a.view(); + let res = tract_linalg::bin_by_scalar(a.datum_type(), tract_linalg::BinOp::Mul) + .and_then(move 
|func| (func)(&mut slice, &scalar).ok())
+            .is_some();
+        Ok(res)
    },
    unicast_in_place: |a: &Tensor, b: &mut Tensor| -> TractResult<bool> {
-        if b.datum_type() == f32::datum_type() {
-            let a = a.as_slice::<f32>()?;
-            let slice = b.as_slice_mut::<f32>()?;
-            (tract_linalg::ops().unicast_mul_f32)().run(slice, a)?;
-            Ok(true)
-        } else if b.datum_type() == f16::datum_type() {
-            let a = a.as_slice::<f16>()?;
-            let slice = b.as_slice_mut::<f16>()?;
-            (tract_linalg::ops().unicast_mul_f16)().run(slice, a)?;
-            Ok(true)
-        } else {
-            Ok(false)
-        }
+        let mut slice = b.view_mut();
+        let other = a.view();
+        let res = tract_linalg::bin_unicast(a.datum_type(), tract_linalg::BinOp::Mul)
+            .and_then(move |func| (func)(&mut slice, &other).ok())
+            .is_some();
+        Ok(res)
    },
    eval_by_scalar: |a: &mut TensorView, b: &TensorView| -> TractResult<bool> {
        let res = tract_linalg::bin_by_scalar(a.datum_type(), tract_linalg::BinOp::Mul)
-            .context("unimplemented mul by scalar")?(a, b)
-            .is_ok();
+            .and_then(move |func| (func)(a, b).ok())
+            .is_some();
        Ok(res)
    },
    eval_unicast: |a: &mut TensorView, b: &TensorView| -> TractResult<bool> {
        let res = tract_linalg::bin_unicast(a.datum_type(), tract_linalg::BinOp::Mul)
-            .context("unimplemented mul unicast")?(a, b)
-            .is_ok();
+            .and_then(move |func| (func)(a, b).ok())
+            .is_some();
        Ok(res)
    },
    neutral_element: 1,

From 5fac248ee960c210993c0b9470a33b453cd5a03a Mon Sep 17 00:00:00 2001
From: Emrick Sinitambirivoutin
Date: Fri, 19 Jul 2024 17:54:45 +0200
Subject: [PATCH 12/32] Replace collapse_axis with at_prefix

---
 core/src/ops/binary.rs          | 40 +++++++++++++---------------
 data/src/tensor/view.rs         | 45 +++++++++----------------------
 linalg/src/frame/unicast/mod.rs | 47 ++++++++++++++++++++++++++-------
 3 files changed, 68 insertions(+), 64 deletions(-)

diff --git a/core/src/ops/binary.rs b/core/src/ops/binary.rs
index b43bf270df..11f132f610 100644
--- a/core/src/ops/binary.rs
+++ b/core/src/ops/binary.rs
@@ -414,10 +414,10 @@ impl EvalOp for BinOpByScalar {
     fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
         let (a, b) = args_2!(inputs);
-        let mut a = a.into_tensor();
+        // Not a requirement, as TensorView does not need an owned tensor, but in practice
+        // "a" should be mutable (the `mut` is omitted because the Rust compiler advises removing it)
+        let a = a.into_tensor();
         let b_shape = b.shape();
-        let mut view = a.view_mut();
-        let b_view = b.view();

         let first_unary_axis = b_shape
             .iter()
             .enumerate()
             .rev()
             .take_while(|&(_, &dim)| dim == 1)
             .map(|(i, _)| i)
             .last()
             .context("Cannot use by_scalar when no trailing dimensions are unary")?;

-        let iterating_shape = view.shape()[..first_unary_axis].to_vec();
+        let iterating_shape = a.shape()[..first_unary_axis].to_vec();
         if !iterating_shape.is_empty() {
             for it_coords in tract_data::internal::iter_indices(&iterating_shape) {
-                let mut view = view.clone();
-                let mut tmp_b_view = b_view.clone();
-
-                // Prepare array view to perform computation
-                for (axis, idx) in it_coords.iter().enumerate() {
-                    view.collapse_axis(axis, *idx as isize);
-                    tmp_b_view.collapse_axis(axis, *idx as isize);
-                }
-
-                self.0.eval_by_scalar(&mut view, &tmp_b_view)?;
+                let mut view = TensorView::at_prefix(&a, &it_coords)?;
+                let b_view = TensorView::at_prefix(&b, &it_coords)?;
+                debug_assert_eq!(b_view.shape().iter().product::<usize>(), 1);
+                self.0.eval_by_scalar(&mut view, &b_view)?;
             }
         } else {
+            let mut view = a.view();
+            let b_view = b.view();
+            debug_assert_eq!(b_view.shape().iter().product::<usize>(), 1);
             self.0.eval_by_scalar(&mut view, &b_view)?;
         }
         Ok(tvec!(a.into_tvalue()))
     }
 }

 impl EvalOp for BinOpUnicast {
     fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
         let (a, b) = args_2!(inputs);
-        let mut a = a.into_tensor();
+        // Not a requirement, as TensorView does not need an owned tensor, but in practice
+        // "a" should be mutable (the `mut` is omitted because the Rust compiler advises removing it)
+        let a = a.into_tensor();
         let b_shape = b.shape();
-        let mut view = a.view_mut();
         let b_view = b.view();
-
         let first_non_unary_axis =
             b_shape.iter().enumerate().take_while(|&(_, &dim)| dim == 1).map(|(i, _)| i + 1).last();

         if let Some(first_non_unary_axis) = first_non_unary_axis {
             // Iterate on outer dimensions and evaluate with unicast subviews
-            let iterating_shape = view.shape()[..first_non_unary_axis].to_vec();
+            let iterating_shape = a.shape()[..first_non_unary_axis].to_vec();
             for it_coords in tract_data::internal::iter_indices(&iterating_shape) {
-                let mut view = view.clone();
-                it_coords.iter().enumerate().for_each(|(axis, idx)| {
-                    view.collapse_axis(axis, *idx as isize);
-                });
+                let mut view = TensorView::at_prefix(&a, &it_coords)?;
+                debug_assert_eq!(view.shape(), &b_view.shape()[it_coords.len()..]);
                 self.0.eval_unicast(&mut view, &b_view)?;
             }
         } else {
+            let mut view = a.view();
             debug_assert_eq!(view.shape(), b_view.shape());
             self.0.eval_unicast(&mut view, &b_view)?;
         }
         Ok(tvec!(a.into_tvalue()))
     }
 }

diff --git a/data/src/tensor/view.rs b/data/src/tensor/view.rs
index f6fa718e0d..2076732406 100644
--- a/data/src/tensor/view.rs
+++ b/data/src/tensor/view.rs
@@ -2,16 +2,16 @@ use super::*;
 use crate::internal::*;

 #[derive(Clone, Debug)]
-enum Indexing {
+enum Indexing<'a> {
     Prefix(usize),
-    Custom { shape: Vec<usize>, strides: Vec<isize> },
+    Custom { shape: &'a [usize], strides: &'a [isize] },
 }

 #[derive(Clone, Debug)]
 pub struct TensorView<'a> {
     pub tensor: &'a Tensor,
     offset_bytes: isize,
-    indexing: Indexing,
+    indexing: Indexing<'a>,
 }

 impl<'a> TensorView<'a> {
@@ -24,7 +24,7 @@ impl<'a> TensorView<'a> {
         TensorView {
             tensor,
             offset_bytes,
-            indexing: Indexing::Custom { shape: shape.to_vec(), strides: strides.to_vec() },
+            indexing: Indexing::Custom { shape, strides },
         }
     }

@@ -46,8 +46,8 @@ impl<'a> TensorView<'a> {
             tensor,
             offset_bytes,
             indexing: Indexing::Custom {
-                shape: tensor.shape.to_vec(),
-                strides: tensor.strides.to_vec(),
+                shape: &tensor.shape,
+                strides: &tensor.strides,
             },
         }
     }

@@ -236,29 +236,6 @@ impl<'a> TensorView<'a> {
         unsafe { Ok(self.at_unchecked(coords)) }
     }

-    #[inline]
-    pub fn collapse_axis(&mut self, axis: usize, index: isize) {
-        let stride = self.strides()[axis] * self.datum_type().size_of() as isize;
-        unsafe { self.offset_bytes(stride * index) };
-        match &mut self.indexing {
-            Indexing::Prefix(x) => {
-                if *x == 0 {
-                    let mut new_shape = self.tensor.shape().to_owned();
-                    new_shape[axis] = 1;
-                    self.indexing = Indexing::Custom {
-                        shape: new_shape,
-                        strides: self.tensor.strides().to_owned(),
-                    }
-                } else {
-                    unimplemented!("TODO: understand how it is used")
-                }
-            }
-            Indexing::Custom { shape, .. } => {
-                shape[axis] = 1;
-            }
-        }
-    }
-
     #[inline]
     pub fn at_mut<T: Datum>(&mut self, coords: impl AsRef<[usize]>) -> TractResult<&mut T> {
         self.check_dt::<T>()?;
@@ -288,13 +265,15 @@ impl<'a> TensorView<'a> {
 #[cfg(test)]
 mod test {
     use crate::prelude::Tensor;
+    use super::TensorView;

     #[test]
-    fn test_collapse_axis() {
+    fn test_at_prefix() {
         let a = Tensor::from_shape(&[2, 2], &[1, 2, 3, 4]).unwrap();
-        let mut a_view = a.view();
-        a_view.collapse_axis(0, 1);
-        assert_eq!(a_view.shape(), &[1, 2]);
+        let a_view = TensorView::at_prefix(&a, &[1]).unwrap();
+        assert_eq!(a_view.shape(), &[2]);
         assert_eq!(a_view.as_slice::<i32>().unwrap(), &[3, 4]);
     }
 }
diff --git a/linalg/src/frame/unicast/mod.rs b/linalg/src/frame/unicast/mod.rs
index e6b1e0cea9..05689d7e37 100644
--- a/linalg/src/frame/unicast/mod.rs
+++ b/linalg/src/frame/unicast/mod.rs
@@ -101,6 +101,20 @@ std::thread_local! {
     static TMP: std::cell::RefCell<(TempBuffer, TempBuffer)> =
         std::cell::RefCell::new((TempBuffer::default(), TempBuffer::default()));
 }

+fn create_incomplete_tile<'a, T: LADatum>(a: &'a mut [T], b: &'a [T], a_prefix_len: usize, b_prefix_len: usize) -> (&'a mut [T], &'a [T], usize) {
+    let effective_prefix = if (a_prefix_len == 0) || (b_prefix_len == 0) {
+        // One of the two slices is already aligned: the tile must cover the unaligned
+        // elements of the other slice, hence the max of the two values.
+        a_prefix_len.max(b_prefix_len)
+    } else {
+        // Both slices are unaligned: the minimal common subset covering elements of both
+        // a and b is the min of the two values.
+        a_prefix_len.min(b_prefix_len)
+    };
+    (&mut a[..effective_prefix], &b[..effective_prefix], effective_prefix)
+}
+
 pub(crate) fn unicast_with_alignment<T>(
     a: &mut [T],
     b: &[T],
@@ -127,18 +141,33 @@ where
         f(tmp_a, tmp_b);
         a.copy_from_slice(&tmp_a[..a.len()])
     };
-    let prefix_len = a.as_ptr().align_offset(alignment_bytes).min(a.len());
-    if prefix_len > 0 {
-        compute_via_temp_buffer(&mut a[..prefix_len], &b[..prefix_len]);
+
+    let mut num_element_processed = 0;
+    let a_prefix_len = a.as_ptr().align_offset(alignment_bytes).min(a.len());
+    let b_prefix_len = b.as_ptr().align_offset(alignment_bytes).min(b.len());
+    let mut applied_prefix_len = 0;
+    if (a_prefix_len > 0) || (b_prefix_len > 0) {
+        // An incomplete tile needs to be created to process the unaligned data.
+        let (mut sub_a, sub_b, applied_prefix) = create_incomplete_tile(a, b, a_prefix_len, b_prefix_len);
+        applied_prefix_len = applied_prefix;
+        compute_via_temp_buffer(&mut sub_a, &sub_b);
+        num_element_processed += applied_prefix_len;
     }
-    let aligned_len = (a.len() - prefix_len) / nr * nr;
-    if aligned_len > 0 {
-        f(&mut a[prefix_len..][..aligned_len], &b[prefix_len..][..aligned_len]);
+
+    let num_complete_tiles = (a.len() - applied_prefix_len) / nr;
+    if num_complete_tiles > 0 {
+        // Process all tiles that are complete.
+        let mut sub_a = &mut a[applied_prefix_len..][..(num_complete_tiles * nr)];
+        let sub_b = &b[applied_prefix_len..][..(num_complete_tiles * nr)];
+        f(&mut sub_a, &sub_b);
+        num_element_processed += num_complete_tiles * nr;
     }
-    if prefix_len + aligned_len < a.len() {
+
+    if num_element_processed < a.len() {
+        // An incomplete tile needs to be created to process the remaining elements.
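+        // Illustrative walk-through (not part of the original change): with nr = 4,
+        // f32 data, 16-byte alignment, a.len() == 13, `b` aligned and `a` misaligned
+        // by one element, a_prefix_len == 3, so elements 0..3 go through the temp
+        // buffer, elements 3..11 form two complete tiles handled in place by `f`,
+        // and the two remaining elements 11..13 reach this final temp-buffer call.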
         compute_via_temp_buffer(
-            &mut a[prefix_len + aligned_len..],
-            &b[prefix_len + aligned_len..],
+            &mut a[num_element_processed..],
+            &b[num_element_processed..],
         );
     }
 })

From 9955a1256cb133189d6466f4e26ead64b5938ee4 Mon Sep 17 00:00:00 2001
From: Emrick Sinitambirivoutin
Date: Fri, 19 Jul 2024 19:29:04 +0200
Subject: [PATCH 13/32] Introduce LirMul with predefined linalg method

---
 core/src/ops/binary.rs          | 82 ++++++++++++++++++++++++++++++---
 linalg/src/frame/by_scalar.rs   |  4 +-
 linalg/src/frame/unicast/mod.rs |  6 +--
 linalg/src/lib.rs               |  2 +-
 4 files changed, 82 insertions(+), 12 deletions(-)

diff --git a/core/src/ops/binary.rs b/core/src/ops/binary.rs
index 11f132f610..42c22d30ed 100644
--- a/core/src/ops/binary.rs
+++ b/core/src/ops/binary.rs
@@ -1,10 +1,10 @@
 use crate::internal::*;
 use downcast_rs::Downcast;
-use std::fmt;
+use std::fmt::{self, Debug};
 use tract_data::itertools::izip;
 use tract_itertools::Itertools;

-use super::cast::cast;
+use super::{cast::cast, math::Mul};
@@ -416,7 +416,7 @@ impl EvalOp for BinOpByScalar {
     fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
         let (a, b) = args_2!(inputs);
         // Not a requirement, as TensorView does not need an owned tensor, but in practice
         // "a" should be mutable (the `mut` is omitted because the Rust compiler advises removing it)
-        let a = a.into_tensor(); 
+        let a = a.into_tensor();
         let b_shape = b.shape();
@@ -469,7 +469,19 @@ impl TypedOp for BinOpByScalar {
         model: &TypedModel,
         node: &TypedNode,
     ) -> TractResult<Option<TypedModelPatch>> {
-        self.0.declutter(model, node)
+        if self.0.downcast_ref::<Mul>().is_some() {
+            let dt = model.node_input_facts(node.id)?[0].datum_type().unwrap();
+            let func = tract_linalg::bin_by_scalar(dt, tract_linalg::BinOp::Mul).unwrap();
+            let eval = Arc::from(func);
+            return Ok(Some(TypedModelPatch::replace_single_op(
+                model,
+                node,
+                &node.inputs,
+                BinOpByScalar(Box::new(LirMul { eval, return_dt: dt })),
+            )?));
+        }
+        Ok(None)
+        //self.0.declutter(model, node)
     }

     as_op!();
@@ -519,7 +531,7 @@ impl EvalOp for BinOpUnicast {
         let (a, b) = args_2!(inputs);
         // Not a requirement, as TensorView does not need an owned tensor, but in practice
         // "a" should be mutable (the `mut` is omitted because the Rust compiler advises removing it)
-        let a = a.into_tensor(); 
+        let a = a.into_tensor();
         let b_shape = b.shape();
         let b_view = b.view();
         let first_non_unary_axis =
@@ -566,12 +578,70 @@ impl TypedOp for BinOpUnicast {
         model: &TypedModel,
         node: &TypedNode,
     ) -> TractResult<Option<TypedModelPatch>> {
-        self.0.declutter(model, node)
+        if self.0.downcast_ref::<Mul>().is_some() {
+            let dt = model.node_input_facts(node.id)?[0].datum_type().unwrap();
+            let func = tract_linalg::bin_unicast(dt, tract_linalg::BinOp::Mul).unwrap();
+            let eval = Arc::from(func);
+            return Ok(Some(TypedModelPatch::replace_single_op(
+                model,
+                node,
+                &node.inputs,
+                BinOpUnicast(Box::new(LirMul { eval, return_dt: dt })),
+            )?));
+        }
+        Ok(None)
+        //self.0.declutter(model, node)
     }

     as_op!();
 }

+#[derive(Clone)]
+pub struct LirMul {
+    return_dt: DatumType,
+    eval: Arc<dyn Fn(&mut TensorView, &TensorView) -> TractResult<()> + Send + Sync>,
+}
+
+impl Debug for LirMul {
+    fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
+        unimplemented!()
+    }
+}
+
+impl BinMiniOp for LirMul {
+    fn name(&self) -> &'static str {
+        "LirMul"
+    }
+
+    fn result_datum_type(&self, _a: DatumType, _b: DatumType) -> TractResult<DatumType> {
+        Ok(self.return_dt)
+    }
+
+    fn eval_in_a(&self, _a: &mut Tensor, _b: &Tensor) -> TractResult<()> {
+        unimplemented!()
+    }
+    fn eval_unicast(&self, a: &mut TensorView, b: &TensorView) -> TractResult<()> {
+        (self.eval)(a, b)
+    }
+
+    fn eval_by_scalar(&self, a: &mut TensorView, b: &TensorView) -> TractResult<()> {
+        (self.eval)(a, b)
+    }
+
+    fn eval_out_of_place(&self, _c: &mut Tensor, _a: &Tensor, _b: &Tensor) -> TractResult<()> {
+        unimplemented!()
+    }
+
+    fn eval_unicast_in_place(&self, _a: &Tensor, _b: &mut Tensor) -> TractResult<()> {
+        unimplemented!()
+    }
+
+    fn eval_uniform_in_place(&self, _a: &Tensor, _b: &mut Tensor) -> TractResult<()> {
+        unimplemented!()
+    }
+}
+
 #[macro_export]
 macro_rules! bin_to_super_type {
     ($func:ident, $Op:ident,
diff --git a/linalg/src/frame/by_scalar.rs b/linalg/src/frame/by_scalar.rs
index cf6c948451..3f8c0b3bea 100644
--- a/linalg/src/frame/by_scalar.rs
+++ b/linalg/src/frame/by_scalar.rs
@@ -2,7 +2,7 @@ use std::{fmt::Debug, marker::PhantomData};

 use tract_data::{TractResult, internal::TensorView};

-use crate::{LADatum, element_wise::ElementWiseKer};
+use crate::{LADatum, element_wise::ElementWiseKer, LinalgFn};

 use super::{ElementWise, element_wise_helper::map_slice_with_alignment};

@@ -36,7 +36,7 @@ pub trait ByScalarKer<T>: ElementWiseKer<T, T>
 where
     T: LADatum
 {
-    fn bin_1() -> Box<dyn Fn(&mut TensorView, &TensorView) -> TractResult<()>> {
+    fn bin_1() -> LinalgFn {
         Box::new(|a: &mut TensorView, b: &TensorView| {
             let a_slice = a.as_slice_mut()?;
             let b = b.as_slice()?[0];
diff --git a/linalg/src/frame/unicast/mod.rs b/linalg/src/frame/unicast/mod.rs
index 05689d7e37..1c9a7ade7d 100644
--- a/linalg/src/frame/unicast/mod.rs
+++ b/linalg/src/frame/unicast/mod.rs
@@ -7,7 +7,7 @@ use tract_data::TractResult;
 use tract_data::internal::TensorView;

 use crate::frame::element_wise_helper::TempBuffer;
-use crate::LADatum;
+use crate::{LADatum, LinalgFn};

 macro_rules! unicast_impl_wrap {
     ($ti: ident, $func: ident, $nr: expr, $alignment_items: expr, $run: item) => {
@@ -88,11 +88,11 @@ where
     fn bin() -> Box<dyn Unicast<T>> {
         Box::new(UnicastImpl::<Self, T>::new())
     }
-    fn bin_1() -> Box<dyn Fn(&mut TensorView, &TensorView) -> TractResult<()>> {
+    fn bin_1() -> LinalgFn {
         Box::new(|a: &mut TensorView, b: &TensorView| {
             let a_slice = a.as_slice_mut()?;
             let b_slice = b.as_slice()?;
-            (Self::bin()).run(a_slice, b_slice)
+            Self::bin().run(a_slice, b_slice)
         })
     }
 }
diff --git a/linalg/src/lib.rs b/linalg/src/lib.rs
index df178e1404..931aefe738 100644
--- a/linalg/src/lib.rs
+++ b/linalg/src/lib.rs
@@ -214,7 +214,7 @@ fn register_all_by_scalar(registry: &mut LinalgRegistry) {
 }

-type LinalgFn = Box<dyn Fn(&mut TensorView, &TensorView) -> TractResult<()>>;
+pub type LinalgFn = Box<dyn Fn(&mut TensorView, &TensorView) -> TractResult<()> + Send + Sync>;
 type LinalgRegistry = HashMap<(BinOp, DatumType), Box<dyn Fn() -> LinalgFn + Send + Sync>>;

 lazy_static!
 {
     static ref BIN_UNICAST_OPS: Mutex<LinalgRegistry> = {

From 1a546f9b5efe31209d39e8ab71c7a5475762a58d Mon Sep 17 00:00:00 2001
From: Emrick Sinitambirivoutin
Date: Fri, 4 Oct 2024 12:17:06 -0400
Subject: [PATCH 14/32] Change naming

---
 core/src/ops/binary.rs | 12 ++++++------
 linalg/src/lib.rs      |  1 +
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/core/src/ops/binary.rs b/core/src/ops/binary.rs
index 42c22d30ed..160fe13f57 100644
--- a/core/src/ops/binary.rs
+++ b/core/src/ops/binary.rs
@@ -477,7 +477,7 @@ impl TypedOp for BinOpByScalar {
                 model,
                 node,
                 &node.inputs,
-                BinOpByScalar(Box::new(LirMul { eval, return_dt: dt })),
+                BinOpByScalar(Box::new(OptMul { eval, return_dt: dt })),
             )?));
         }
         Ok(None)
@@ -586,7 +586,7 @@ impl TypedOp for BinOpUnicast {
                 model,
                 node,
                 &node.inputs,
-                BinOpUnicast(Box::new(LirMul { eval, return_dt: dt })),
+                BinOpUnicast(Box::new(OptMul { eval, return_dt: dt })),
             )?));
         }
         Ok(None)
@@ -597,20 +597,20 @@ impl TypedOp for BinOpUnicast {
 }

 #[derive(Clone)]
-pub struct LirMul {
+pub struct OptMul {
     return_dt: DatumType,
     eval: Arc<dyn Fn(&mut TensorView, &TensorView) -> TractResult<()> + Send + Sync>,
 }

-impl Debug for LirMul {
+impl Debug for OptMul {
     fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
         unimplemented!()
     }
 }

-impl BinMiniOp for LirMul {
+impl BinMiniOp for OptMul {
     fn name(&self) -> &'static str {
-        "LirMul"
+        "OptMul"
     }

     fn result_datum_type(&self, _a: DatumType, _b: DatumType) -> TractResult<DatumType> {
diff --git a/linalg/src/lib.rs b/linalg/src/lib.rs
index 931aefe738..26e8dc3dba 100644
--- a/linalg/src/lib.rs
+++ b/linalg/src/lib.rs
@@ -24,6 +24,7 @@ pub mod multithread;
 use frame::by_scalar::ByScalarKer;
 use frame::element_wise::ElementWiseKer;
 use frame::reduce::{MapReduceKer, ReduceKer};
+use frame::unicast::UnicastKer;
 use frame::{reduce, unicast, MatMatMul};
 pub use generic::{ScaleShiftAndRound, Scaler};
 use lazy_static::lazy_static;

From 6834a92b7a39754539897b0a7e52fe3077f28cae Mon Sep 17 00:00:00 2001
From: Emrick Sinitambirivoutin
Date: Wed, 9 Oct 2024 11:27:20 -0400
Subject: [PATCH 15/32] Reorganize code & remove methods from BinMiniOp trait

---
 core/src/ops/binary.rs           | 411 ++++++++-----------------------
 core/src/ops/cnn/conv/conv.rs    |   2 +-
 core/src/ops/math/mod.rs         | 170 -------------
 core/src/ops/matmul/optimized.rs |   8 +-
 core/src/ops/quant.rs            |  35 ---
 linalg/src/frame/mmm/fuse.rs     |  23 +-
 linalg/src/frame/mmm/scratch.rs  |   4 +-
 linalg/src/generic.rs            |  16 +-
 linalg/src/lib.rs                |  13 +-
 tflite/src/ops/math.rs           |   2 +-
 10 files changed, 126 insertions(+), 558 deletions(-)

diff --git a/core/src/ops/binary.rs b/core/src/ops/binary.rs
index 160fe13f57..78251f22b8 100644
--- a/core/src/ops/binary.rs
+++ b/core/src/ops/binary.rs
@@ -4,7 +4,7 @@ use std::fmt::{self, Debug};
 use tract_data::itertools::izip;
 use tract_itertools::Itertools;

-use super::{cast::cast, math::Mul};
+use super::cast::cast;

 pub trait BinMiniOp: fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static + Downcast {
     fn name(&self) -> &'static str;
@@ -15,15 +15,9 @@ pub trait BinMiniOp: fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static +
         a.common_super_type(b).with_context(|| format_err!("No super type for {:?} and {:?}", a, b))
     }
     fn result_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType>;
-    fn eval_unicast_in_place(&self, a: &Tensor, b: &mut Tensor) -> TractResult<()>;
-    fn eval_uniform_in_place(&self, a: &Tensor, b: &mut Tensor) -> TractResult<()>;
     fn eval_in_a(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()>;
     fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) ->
TractResult<()>; - // Temporary introduced to test TensorView approach - fn eval_by_scalar(&self, a: &mut TensorView, b: &TensorView) -> TractResult<()>; - fn eval_unicast(&self, a: &mut TensorView, b: &TensorView) -> TractResult<()>; - fn is_commutative(&self) -> bool { true } @@ -44,14 +38,6 @@ pub trait BinMiniOp: fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static + fn generic_eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult { if let Some(tensor) = self.maybe_eval_qbinary_as_float_op(&a, &b, &c_dt)? { Ok(tensor) - } else if c_dt == b.datum_type() && a.len() == 1 { - let mut b = b.into_tensor(); - self.eval_uniform_in_place(&a, &mut b)?; - Ok(b) - } else if a.shape() == b.shape() && c_dt == b.datum_type() { - let mut b = b.into_tensor(); - self.eval_unicast_in_place(&a, &mut b)?; - Ok(b) } else { let c_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?; if &*c_shape == a.shape() && c_dt == a.datum_type() { @@ -88,7 +74,7 @@ pub trait BinMiniOp: fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static + fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> { tvec!() } - fn as_linalg_binop(&self) -> Option { + fn as_linalg_binop(&self) -> Option { None } @@ -234,44 +220,52 @@ impl TypedOp for TypedBinOp { model: &TypedModel, node: &TypedNode, ) -> TractResult> { - let (by_scalar_should_be_efficient, unicast_should_be_efficient) = - find_most_efficient_config(model, node)?; - let op_is_quant = if let &[a, b] = &*model.node_input_facts(node.id)? { - let c_dt = self.output_datum_type(a.datum_type, b.datum_type)?; - c_dt.is_quantized() || a.datum_type.is_quantized() || b.datum_type.is_quantized() - } else { - false - }; - let can_eval_in_a = if let &[a, b] = &*model.node_input_facts(node.id)? { - let c_dt = self.output_datum_type(a.datum_type, b.datum_type)?; - let c_shape = crate::broadcast::multi_broadcast(&[a.shape.clone(), b.shape.clone()])?; - (c_shape == a.shape.to_tvec()) && (c_dt == a.datum_type) - } else { - false - }; + if let Some(linalg_bin_op) = self.0.as_linalg_binop() { + let (by_scalar_should_be_efficient, unicast_should_be_efficient) = + find_most_efficient_config(model, node)?; + let op_is_quant = if let &[a, b] = &*model.node_input_facts(node.id)? { + let c_dt = self.output_datum_type(a.datum_type, b.datum_type)?; + c_dt.is_quantized() || a.datum_type.is_quantized() || b.datum_type.is_quantized() + } else { + false + }; + let can_eval_in_a = if let &[a, b] = &*model.node_input_facts(node.id)? { + let c_dt = self.output_datum_type(a.datum_type, b.datum_type)?; + let c_shape = + crate::broadcast::multi_broadcast(&[a.shape.clone(), b.shape.clone()])?; + (c_shape == a.shape.to_tvec()) && (c_dt == a.datum_type) + } else { + false + }; - if by_scalar_should_be_efficient & can_eval_in_a & !op_is_quant { - return Ok(Some( - TypedModelPatch::replace_single_op( - model, - node, - &node.inputs, - BinOpByScalar(self.0.clone()), - )? - .with_context("ByScalar"), - )); - } + let dt = model.node_input_facts(node.id)?[0].datum_type().unwrap(); + if by_scalar_should_be_efficient & can_eval_in_a & !op_is_quant { + let Some(func) = tract_linalg::bin_by_scalar(dt, linalg_bin_op) else {return Ok(None)}; + let eval_fn = Arc::from(func); + return Ok(Some( + TypedModelPatch::replace_single_op( + model, + node, + &node.inputs, + OptBinByScalar { binop: self.0.clone(), eval_fn }, + )? 
+ .with_context("ByScalar"), + )); + } - if unicast_should_be_efficient & can_eval_in_a & !op_is_quant { - return Ok(Some( - TypedModelPatch::replace_single_op( - model, - node, - &node.inputs, - BinOpUnicast(self.0.clone()), - )? - .with_context("Unicast"), - )); + if unicast_should_be_efficient & can_eval_in_a & !op_is_quant { + let Some(func) = tract_linalg::bin_unicast(dt, linalg_bin_op) else {return Ok(None)}; + let eval_fn = Arc::from(func); + return Ok(Some( + TypedModelPatch::replace_single_op( + model, + node, + &node.inputs, + OptBinUnicast { binop: self.0.clone(), eval_fn }, + )? + .with_context("Unicast"), + )); + } } Ok(None) @@ -333,7 +327,7 @@ fn find_most_efficient_config(model: &TypedModel, node: &TypedNode) -> TractResu let a_shape = a.shape.clone(); let b_shape = b.shape.clone(); - let by_scalar_is_possible = BinOpByScalar::check_input_shapes(&a_shape, &b_shape); + let by_scalar_is_possible = OptBinByScalar::check_input_shapes(&a_shape, &b_shape); let num_by_scalar_elements = if by_scalar_is_possible { a_shape .iter() @@ -346,7 +340,7 @@ fn find_most_efficient_config(model: &TypedModel, node: &TypedNode) -> TractResu TDim::Val(0) }; - let unicast_is_possible = BinOpUnicast::check_input_shapes(&a_shape, &b_shape); + let unicast_is_possible = OptBinUnicast::check_input_shapes(&a_shape, &b_shape); let num_unicast_elements = if unicast_is_possible { a_shape .iter() @@ -371,43 +365,46 @@ pub fn gt_tdim(x: TDim, min_val: i64) -> bool { TDim::Val(min_val).mini(x).to_i64().map_or(false, |v| v == min_val) } -#[derive(Debug, Clone)] -pub struct BinOpByScalar(pub Box); +#[derive(Clone)] +pub struct OptBinByScalar { + pub binop: Box, + eval_fn: Arc TractResult<()> + Send + Sync>, +} -impl BinOpByScalar { +impl Debug for OptBinByScalar { + fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + unimplemented!() + } +} + +impl OptBinByScalar { fn check_input_shapes(a_shape: &[TDim], b_shape: &[TDim]) -> bool { if a_shape.len() != b_shape.len() { return false; }; - let mut must_be_unary = false; - a_shape.iter().zip(b_shape.iter()).all(|(a_dim, b_dim)| { - // As soon as a and b dimensions differ, b dimensions must be 1 until the end. 
- if (a_dim != b_dim) && !must_be_unary { - must_be_unary = true - } - - // Leading dimensions: a_dim==b_dim condition - // Trailing dimensison: b_dim == 1 - ((a_dim == b_dim) & !must_be_unary) || ((*b_dim == 1.into()) & must_be_unary) - }) + a_shape + .iter() + .zip(b_shape.iter()) + .skip_while(|(a_dim, b_dim)| a_dim == b_dim) + .all(|(_, b_dim)| *b_dim == 1.to_dim()) } } -impl Op for BinOpByScalar { +impl Op for OptBinByScalar { fn name(&self) -> Cow { - format!("{}ByScalar", self.0.name()).into() + format!("Opt{}ByScalar", self.binop.name()).into() } fn same_as(&self, other: &dyn Op) -> bool { - let Some(other) = other.downcast_ref::() else { return false }; - self.0.same_as(&*other.0) + let Some(other) = other.downcast_ref::() else { return false }; + self.binop.same_as(&*other.binop) } op_as_typed_op!(); } -impl EvalOp for BinOpByScalar { +impl EvalOp for OptBinByScalar { fn is_stateless(&self) -> bool { true } @@ -434,22 +431,22 @@ impl EvalOp for BinOpByScalar { let mut view = TensorView::at_prefix(&a, &it_coords)?; let b_view = TensorView::at_prefix(&b, &it_coords)?; debug_assert_eq!(b_view.shape().iter().product::(), 1); - self.0.eval_by_scalar(&mut view, &b_view)?; + (self.eval_fn)(&mut view, &b_view)?; } } else { let mut view = a.view(); let b_view = b.view(); debug_assert_eq!(b_view.shape().iter().product::(), 1); - self.0.eval_by_scalar(&mut view, &b_view)?; + (self.eval_fn)(&mut view, &b_view)?; } Ok(tvec!(a.into_tvalue())) } } -impl TypedOp for BinOpByScalar { +impl TypedOp for OptBinByScalar { fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { ensure!(Self::check_input_shapes(&inputs[0].shape, &inputs[1].shape)); - let out_dt = self.0.result_datum_type(inputs[0].datum_type, inputs[1].datum_type)?; + let out_dt = self.binop.result_datum_type(inputs[0].datum_type, inputs[1].datum_type)?; let out_shape = inputs[0].shape.clone(); Ok(tvec!(out_dt.fact(out_shape))) } @@ -457,72 +454,55 @@ impl TypedOp for BinOpByScalar { fn cost(&self, inputs: &[&TypedFact]) -> TractResult> { let count: TDim = self.output_facts(inputs)?[0].shape.iter().product(); Ok(self - .0 + .binop .cost_per_element(inputs[0].datum_type) .into_iter() .map(|(c, n)| (c, count.clone() * n)) .collect()) } - fn declutter( - &self, - model: &TypedModel, - node: &TypedNode, - ) -> TractResult> { - if self.0.downcast_ref::().is_some() { - let dt = model.node_input_facts(node.id)?[0].datum_type().unwrap(); - let func = tract_linalg::bin_by_scalar(dt, tract_linalg::BinOp::Mul).unwrap(); - let eval = Arc::from(func); - return Ok(Some(TypedModelPatch::replace_single_op( - model, - node, - &node.inputs, - BinOpByScalar(Box::new(OptMul { eval, return_dt: dt })), - )?)); - } - Ok(None) - //self.0.declutter(model, node) - } - as_op!(); } -#[derive(Debug, Clone)] -pub struct BinOpUnicast(pub Box); +#[derive(Clone)] +pub struct OptBinUnicast { + pub binop: Box, + eval_fn: Arc TractResult<()> + Send + Sync>, +} + +impl Debug for OptBinUnicast { + fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + unimplemented!() + } +} -impl BinOpUnicast { +impl OptBinUnicast { fn check_input_shapes(a_shape: &[TDim], b_shape: &[TDim]) -> bool { if a_shape.len() != b_shape.len() { return false; }; - let mut must_be_equal = false; - a_shape.iter().zip(b_shape.iter()).all(|(a_dim, b_dim)| { - // As soon as b dimension not equal to one, a and b dimensions must be equal. 
- if (*b_dim != 1.into()) && !must_be_equal { - must_be_equal = true - } - - // Leading dimensions: b_dim==1 condition - // Trailing dimensison: a_dim == b_dim - ((*b_dim == 1.into()) & !must_be_equal) || ((a_dim == b_dim) & must_be_equal) - }) + a_shape + .iter() + .zip(b_shape.iter()) + .skip_while(|(_, b_dim)| **b_dim == 1.to_dim()) + .all(|(a_dim, b_dim)| a_dim == b_dim) } } -impl Op for BinOpUnicast { +impl Op for OptBinUnicast { fn name(&self) -> Cow { - format!("{}Unicast", self.0.name()).into() + format!("Opt{}Unicast", self.binop.name()).into() } fn same_as(&self, other: &dyn Op) -> bool { - let Some(other) = other.downcast_ref::() else { return false }; - self.0.same_as(&*other.0) + let Some(other) = other.downcast_ref::() else { return false }; + self.binop.same_as(&*other.binop) } op_as_typed_op!(); } -impl EvalOp for BinOpUnicast { +impl EvalOp for OptBinUnicast { fn is_stateless(&self) -> bool { true } @@ -543,22 +523,22 @@ impl EvalOp for BinOpUnicast { for it_coords in tract_data::internal::iter_indices(&iterating_shape) { let mut view = TensorView::at_prefix(&a, &it_coords)?; debug_assert_eq!(view.shape(), &b_view.shape()[it_coords.len()..]); - self.0.eval_unicast(&mut view, &b_view)?; + (self.eval_fn)(&mut view, &b_view)?; } } else { let mut view = a.view(); debug_assert_eq!(view.shape(), b_view.shape()); - self.0.eval_unicast(&mut view, &b_view)?; + (self.eval_fn)(&mut view, &b_view)?; } Ok(tvec!(a.into_tvalue())) } } -impl TypedOp for BinOpUnicast { +impl TypedOp for OptBinUnicast { fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { ensure!(Self::check_input_shapes(&inputs[0].shape, &inputs[1].shape)); - let out_dt = self.0.result_datum_type(inputs[0].datum_type, inputs[1].datum_type)?; + let out_dt = self.binop.result_datum_type(inputs[0].datum_type, inputs[1].datum_type)?; let out_shape = inputs[0].shape.clone(); Ok(tvec!(out_dt.fact(out_shape))) } @@ -566,82 +546,16 @@ impl TypedOp for BinOpUnicast { fn cost(&self, inputs: &[&TypedFact]) -> TractResult> { let count: TDim = self.output_facts(inputs)?[0].shape.iter().product(); Ok(self - .0 + .binop .cost_per_element(inputs[0].datum_type) .into_iter() .map(|(c, n)| (c, count.clone() * n)) .collect()) } - fn declutter( - &self, - model: &TypedModel, - node: &TypedNode, - ) -> TractResult> { - if self.0.downcast_ref::().is_some() { - let dt = model.node_input_facts(node.id)?[0].datum_type().unwrap(); - let func = tract_linalg::bin_unicast(dt, tract_linalg::BinOp::Mul).unwrap(); - let eval = Arc::from(func); - return Ok(Some(TypedModelPatch::replace_single_op( - model, - node, - &node.inputs, - BinOpUnicast(Box::new(OptMul { eval, return_dt: dt })), - )?)); - } - Ok(None) - //self.0.declutter(model, node) - } - as_op!(); } -#[derive(Clone)] -pub struct OptMul { - return_dt: DatumType, - eval: Arc TractResult<()> + Send + Sync>, -} - -impl Debug for OptMul { - fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { - unimplemented!() - } -} - -impl BinMiniOp for OptMul { - fn name(&self) -> &'static str { - "OptMul" - } - - fn result_datum_type(&self, _a: DatumType, _b: DatumType) -> TractResult { - Ok(self.return_dt) - } - - fn eval_in_a(&self, _a: &mut Tensor, _b: &Tensor) -> TractResult<()> { - unimplemented!() - } - - fn eval_unicast(&self, a: &mut TensorView, b: &TensorView) -> TractResult<()> { - (self.eval)(a, b) - } - - fn eval_by_scalar(&self, a: &mut TensorView, b: &TensorView) -> TractResult<()> { - (self.eval)(a, b) - } - - fn eval_out_of_place(&self, _c: &mut Tensor, _a: 
&Tensor, _b: &Tensor) -> TractResult<()> { - unimplemented!() - } - - fn eval_unicast_in_place(&self, _a: &Tensor, _b: &mut Tensor) -> TractResult<()> { - unimplemented!() - } - - fn eval_uniform_in_place(&self, _a: &Tensor, _b: &mut Tensor) -> TractResult<()> { - unimplemented!() - } -} - #[macro_export] macro_rules! bin_to_super_type { ($func:ident, $Op:ident, @@ -652,10 +566,6 @@ macro_rules! bin_to_super_type { $(eval_override: $eval_override: expr,)? $(linalg: $linalg:ident,)? $(operating_datum_type: $operating_datum_type:expr,)? - $(uniform_in_place: $uniform_in_place:expr,)? - $(unicast_in_place: $unicast_in_place:expr,)? - $(eval_by_scalar: $eval_by_scalar:expr,)? - $(eval_unicast: $eval_unicast:expr,)? $(is_commutative: $is_commutative:expr,)? $(neutral_element: $neutral_element:expr,)? $(out_of_place: $out_of_place:expr,)? @@ -675,87 +585,6 @@ macro_rules! bin_to_super_type { other.downcast_ref::<$Op>().is_some() } - fn eval_uniform_in_place(&self, a: &Tensor, b: &mut Tensor) -> TractResult<()> { - $(if $uniform_in_place(a, b)? { return Ok(()) } )? - $( - $(if a.datum_type() == $typ::datum_type() { - let cab: fn(&mut $typ, &$typ, &$typ) -> () = $cab; - let a = &a.as_slice::<$typ>()?[0]; - let b = b.as_slice_mut::<$typ>()?; - unsafe { - for i in 0..b.len() { - let mut c = $typ::default(); - cab(&mut c, a, b.get_unchecked_mut(i)); - b[i] = c; - } - } - return Ok(()) - } - )* - )* - - $( - $( - $(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() { - let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) -> () = $cab_dt; - let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.)); - let a = &a.as_slice::<$typ_dt>()?[0]; - let b = b.as_slice_mut::<$typ_dt>()?; - unsafe { - for i in 0..b.len() { - let mut c = $typ_dt::default(); - cab(&mut c, a, b.get_unchecked_mut(i), zp, scale); - b[i] = c; - } - } - return Ok(()) - } - )* - )* - )? - bail!("{} does not support {:?} (inplace uniform)", self.name(), a.datum_type()); - } - - fn eval_unicast_in_place(&self, a: &Tensor, b: &mut Tensor) -> TractResult<()> { - $(if $unicast_in_place(a, b)? { return Ok(()) } )? - $( - $(if a.datum_type() == $typ::datum_type() { - let cab: fn(&mut $typ, &$typ, &$typ) -> () = $cab; - let a = a.as_slice::<$typ>()?; - let b = b.as_slice_mut::<$typ>()?; - unsafe { - for i in 0..a.len() { - let mut c = $typ::default(); - cab(&mut c, &a[i], b.get_unchecked(i)); - *b.get_unchecked_mut(i) = c; - } - } - return Ok(()) - } - )* - )* - $( - $( - $(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() { - let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) -> () = $cab_dt; - let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.)); - let a = a.as_slice::<$typ_dt>()?; - let b = b.as_slice_mut::<$typ_dt>()?; - unsafe { - for i in 0..a.len() { - let mut c = $typ_dt::default(); - cab(&mut c, &a[i], b.get_unchecked(i), zp, scale); - *b.get_unchecked_mut(i) = c; - } - } - return Ok(()) - } - )* - )* - )? - bail!("{} does not support {:?} (inplace)", self.name(), a.datum_type()); - } - fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()> { $(if $out_of_place(c, a, b)? { return Ok(()) } )? $( @@ -784,38 +613,6 @@ macro_rules! bin_to_super_type { bail!("{} does not support {:?} (out of place)", self.name(), c.datum_type()); } - fn eval_by_scalar(&self, a: &mut TensorView, b: &TensorView) -> TractResult<()> { - $(if $eval_by_scalar(a, b)? { return Ok(())} )? 
- $( - $(if b.datum_type() == $typ::datum_type() { - let cab: fn(&mut $typ, &$typ, &$typ) -> () = $cab; - let b = &b.as_slice::<$typ>()?[0]; - let a_slice = a.as_slice_mut::<$typ>()?; - a_slice.iter_mut().for_each(|a| cab(a, &a.clone(), b)); - return Ok(()) - })* - )* - bail!("{} does not support {:?} (eval by scalar)", self.name(), a.datum_type()); - } - fn eval_unicast(&self, a: &mut TensorView, b: &TensorView) -> TractResult<()> { - $(if $eval_unicast(a, b)? { return Ok(()) } )? - $( - $(if b.datum_type() == $typ::datum_type() { - let cab: fn(&mut $typ, &$typ, &$typ) -> () = $cab; - let b = &b.as_slice::<$typ>()?; - let a_slice = a.as_slice_mut::<$typ>()?; - unsafe { - for i in 0..b.len() { - let mut c = $typ::default(); - cab(&mut c, &a_slice[i], b.get_unchecked(i)); - *a_slice.get_unchecked_mut(i) = c; - } - } - return Ok(()) - })* - )* - bail!("{} does not support {:?} (eval unicast)", self.name(), a.datum_type()); - } $(fn is_commutative(&self) -> bool { $is_commutative })? @@ -897,8 +694,8 @@ macro_rules! bin_to_super_type { } )? $( - fn as_linalg_binop(&self) -> Option { - Some(tract_linalg::mmm::BinOp::$linalg) + fn as_linalg_binop(&self) -> Option { + Some(tract_linalg::BinOp::$linalg) } )? $( diff --git a/core/src/ops/cnn/conv/conv.rs b/core/src/ops/cnn/conv/conv.rs index 1abacac7ff..ca11848213 100644 --- a/core/src/ops/cnn/conv/conv.rs +++ b/core/src/ops/cnn/conv/conv.rs @@ -92,7 +92,7 @@ impl Conv { bias: OutletId, c_group_axis: usize, ) -> TractResult<(ProtoFusedSpec, OutletId)> { - use tract_linalg::mmm::BinOp::Add; + use tract_linalg::BinOp::Add; let fact = model.outlet_fact(bias)?; if fact.shape.volume().is_one() { Ok((ProtoFusedSpec::BinScalar(2, Add), bias)) diff --git a/core/src/ops/math/mod.rs b/core/src/ops/math/mod.rs index f7786f1b56..8c2b06bb20 100644 --- a/core/src/ops/math/mod.rs +++ b/core/src/ops/math/mod.rs @@ -11,10 +11,7 @@ use num_traits::{Float, Zero}; use tract_data::internal::ClampCast; use tract_data::itertools::Itertools; pub use tract_data::prelude::round_ties_to_even; -use tract_linalg::frame::unicast::Unicast; -use tract_linalg::frame::ElementWise; use tract_linalg::{ScaleShiftAndRound, Scaler}; -use tract_ndarray::Axis; use tract_num_traits::AsPrimitive; #[cfg(feature = "complex")] @@ -57,7 +54,6 @@ where bin_to_super_type!(mul, Mul, cost: |dt| tvec!((Cost::FMA(dt), 1)), declutter: declutter_mul, - eval_in_a: mul_eval_in_a, eval_override: |a:TValue, b: TValue, c_dt: DatumType| -> TractResult { // we apply only if type is QU8 zp_scale datum type if let (DatumType::QU8(QParams::ZpScale {zero_point: a_zp, scale: a_scale}), @@ -81,34 +77,6 @@ bin_to_super_type!(mul, Mul, } }, linalg: Mul, - uniform_in_place: |a: &Tensor, b: &mut Tensor| -> TractResult { - let mut slice = b.view_mut(); - let scalar = a.view(); - let res = tract_linalg::bin_by_scalar(a.datum_type(), tract_linalg::BinOp::Mul) - .and_then(move |func| (func)(&mut slice, &scalar).ok()) - .is_some(); - Ok(res) - }, - unicast_in_place: |a: &Tensor, b: &mut Tensor| -> TractResult { - let mut slice = b.view_mut(); - let other = a.view(); - let res = tract_linalg::bin_unicast(a.datum_type(), tract_linalg::BinOp::Mul) - .and_then(move |func| (func)(&mut slice, &other).ok()) - .is_some(); - Ok(res) - }, - eval_by_scalar: |a: &mut TensorView, b: &TensorView | -> TractResult { - let res = tract_linalg::bin_by_scalar(a.datum_type(), tract_linalg::BinOp::Mul) - .and_then(move |func| (func)(a, b).ok()) - .is_some(); - Ok(res) - }, - eval_unicast: |a: &mut TensorView, b: &TensorView | -> TractResult { - 
let res = tract_linalg::bin_unicast(a.datum_type(), tract_linalg::BinOp::Mul) - .and_then(move |func| (func)(a, b).ok()) - .is_some(); - Ok(res) - }, neutral_element: 1, out_of_place: |c:&mut Tensor, a:&Tensor, b: &Tensor| -> TractResult { if c.datum_type() == TDim::datum_type() && @@ -155,144 +123,6 @@ bin_to_super_type!(mul, Mul, [f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64, TDim] => |c, a, b| *c = a.clone() * b ); -fn check_uniform_is_possible(a_shape: &[usize], b_shape: &[usize]) -> bool { - if a_shape.len() != b_shape.len() { - return false; - }; - - a_shape - .iter() - .zip(b_shape.iter()) - .skip_while(|(a_dim, b_dim)| a_dim == b_dim) - .all(|(_, b_dim)| *b_dim == 1) -} - -fn check_unicast_is_possible(a_shape: &[usize], b_shape: &[usize]) -> bool { - if a_shape.len() != b_shape.len() { - return false; - }; - - a_shape - .iter() - .zip(b_shape.iter()) - .skip_while(|(_, b_dim)| **b_dim == 1) - .all(|(a_dim, b_dim)| a_dim == b_dim) -} - -fn mul_eval_in_a(a: &mut Tensor, b: &Tensor) -> TractResult { - let b_shape = b.shape(); - let leading_unary_dims: Vec = - b_shape.iter().enumerate().take_while(|&(_, &dim)| dim == 1).map(|(i, _)| i).collect(); - let trailing_unary_dims: Vec = b_shape - .iter() - .enumerate() - .rev() - .take_while(|&(_, &dim)| dim == 1) - .map(|(i, _)| i) - .collect(); - - let uniform_is_possible = check_uniform_is_possible(a.shape(), b.shape()); - let uniform_in_place_should_be_efficient = - trailing_unary_dims.iter().fold(1, |num_elements, it| num_elements * a.shape()[*it]) > 32; - let unicast_is_possible = check_unicast_is_possible(a.shape(), b.shape()); - let unicast_in_place_should_be_efficient = - leading_unary_dims.iter().fold(1, |num_elements, it| num_elements * a.shape()[*it]) > 32; - - // Better to try uniform in place first (should be more efficient) - if uniform_in_place_should_be_efficient && uniform_is_possible { - if b.datum_type() == f32::datum_type() { - mul_by_scalar::( - a, - b, - &trailing_unary_dims, - (tract_linalg::ops().mul_by_scalar_f32)(), - ) - } else if b.datum_type() == f16::datum_type() { - mul_by_scalar::( - a, - b, - &trailing_unary_dims, - (tract_linalg::ops().mul_by_scalar_f16)(), - ) - } else { - Ok(false) - } - } else if unicast_in_place_should_be_efficient && unicast_is_possible { - if b.datum_type() == f32::datum_type() { - mul_unicast::(a, b, &leading_unary_dims, (tract_linalg::ops().unicast_mul_f32)()) - } else if b.datum_type() == f16::datum_type() { - mul_unicast::(a, b, &leading_unary_dims, (tract_linalg::ops().unicast_mul_f16)()) - } else { - return Ok(false); - } - } else { - Ok(false) - } -} - -fn mul_unicast( - a: &mut Tensor, - b: &Tensor, - leading_unary_dims: &[usize], - eval: Box>, -) -> TractResult { - let mut a_view = a.to_array_view_mut::()?; - let b_view = b.to_array_view::()?; - let mut iterating_shape = a_view.shape().to_vec(); - iterating_shape.iter_mut().enumerate().for_each(|(idx, dim)| { - if !leading_unary_dims.contains(&idx) { - *dim = 1 - } - }); - for it_coords in tract_ndarray::indices(iterating_shape) { - let mut a_view = a_view.view_mut(); - for idx in 0..a_view.shape().len() { - if leading_unary_dims.contains(&idx) { - a_view.collapse_axis(Axis(idx), it_coords[idx]); - } - } - - if let Some((a_slice, b_slice)) = a_view.as_slice_mut().zip(b_view.as_slice()) { - eval.run(a_slice, b_slice)?; - } else { - return Ok(false); - } - } - Ok(true) -} - -fn mul_by_scalar( - a: &mut Tensor, - b: &Tensor, - trailing_unary_dims: &[usize], - eval: Box>, -) -> TractResult { - let mut view = 
a.to_array_view_mut::()?; - let b = b.to_array_view::()?; - for it_coords in tract_ndarray::indices(b.shape()) { - // Prepare array view to perform computation - // - view should be a slice - // - b should be a scalar - let mut view = view.view_mut(); - let mut b = b.view(); - for idx in 0..b.shape().len() { - if !trailing_unary_dims.contains(&idx) { - view.collapse_axis(Axis(idx), it_coords[idx]); - b.collapse_axis(Axis(idx), it_coords[idx]); - } - } - - // Perform computation on a slice on the view - let b = b.as_slice().unwrap()[0]; - if let Some(slice) = view.as_slice_mut() { - eval.run_with_params(slice, b)?; - } else { - view.iter_mut().for_each(|it| *it = *it * b) - } - } - Ok(true) -} - bin_to_super_type!(div, Div, cost: |dt| tvec!((Cost::Div(dt), 1)), declutter: declutter_div, diff --git a/core/src/ops/matmul/optimized.rs b/core/src/ops/matmul/optimized.rs index 6b1fb65a71..4b4c1fcf26 100644 --- a/core/src/ops/matmul/optimized.rs +++ b/core/src/ops/matmul/optimized.rs @@ -11,9 +11,9 @@ use tract_linalg::frame::block_quant::{ }; use tract_linalg::frame::PackedFormat; use tract_linalg::mmm::{ - AsInputValue, BinOp, EagerPackedInput, FusedSpec, MMMInputValue, MatMatMul, OutputStoreSpec, + AsInputValue, EagerPackedInput, FusedSpec, MMMInputValue, MatMatMul, OutputStoreSpec, }; -use tract_linalg::Scaler; +use tract_linalg::{Scaler, BinOp}; use tract_smallvec::ToSmallVec; #[derive(Clone, Debug)] @@ -552,8 +552,8 @@ impl TypedOp for OptMatMul { } } } - if let Some(op) = succ.op_as::() { - if op.0.is::() && self.mmm.len() == 1 { + if let Some(op) = succ.op_as::() { + if op.binop.is::() && self.mmm.len() == 1 { let other_slot = 1 - node.outputs[0].successors[0].slot; let other_input = succ.inputs[other_slot]; let other_input = patch.tap_model(model, other_input)?; diff --git a/core/src/ops/quant.rs b/core/src/ops/quant.rs index b04f37f132..c2cac1fc1b 100644 --- a/core/src/ops/quant.rs +++ b/core/src/ops/quant.rs @@ -277,13 +277,6 @@ impl crate::ops::binary::BinMiniOp for Scale { fn name(&self) -> &'static str { "Scale" } - fn eval_by_scalar(&self, _a: &mut TensorView, _b: &TensorView) -> TractResult<()> { - unimplemented!("Eval by scalar not implemented") - } - fn eval_unicast(&self, _a: &mut TensorView, _b: &TensorView) -> TractResult<()> { - unimplemented!("Eval unicast not implemented") - } - fn result_datum_type(&self, a: DatumType, b: DatumType) -> TractResult { if !a.is_float() { bail!("Scale left operand must be float, got {:?}", a); @@ -298,34 +291,6 @@ impl crate::ops::binary::BinMiniOp for Scale { Ok(b) } - fn eval_uniform_in_place(&self, a: &Tensor, b: &mut Tensor) -> TractResult<()> { - let a = a.cast_to_scalar::()?; - unsafe fn eval_in_place_t>(a: f32, b: &mut Tensor) - where - f32: AsPrimitive, - { - b.as_slice_mut_unchecked::().iter_mut().for_each(|x| *x = scale_by(*x, a)); - } - unsafe { dispatch_numbers!(eval_in_place_t(b.datum_type())(a, b)) } - Ok(()) - } - - fn eval_unicast_in_place(&self, a: &Tensor, b: &mut Tensor) -> TractResult<()> { - let a = a.cast_to::()?; - let a = a.to_array_view::()?; - unsafe fn eval_in_place_t>( - a: &ndarray::ArrayViewD, - b: &mut Tensor, - ) where - f32: AsPrimitive, - { - let mut b = b.to_array_view_mut_unchecked::(); - ndarray::Zip::from(&mut b).and_broadcast(a).for_each(|b, a| *b = scale_by(*b, *a)) - } - unsafe { dispatch_numbers!(eval_in_place_t(b.datum_type())(&a, b)) } - Ok(()) - } - fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()> { let a = a.cast_to::()?; let a = a.to_array_view::()?; diff 
--git a/linalg/src/frame/mmm/fuse.rs b/linalg/src/frame/mmm/fuse.rs index 3e1bf4a8f5..24aa60bfc7 100644 --- a/linalg/src/frame/mmm/fuse.rs +++ b/linalg/src/frame/mmm/fuse.rs @@ -2,6 +2,8 @@ use std::fmt::Debug; use std::ops::Deref; use super::pack::PackedFormat; +use crate::BinOp; + use super::{MMMInputValue, OutputStore, OutputStoreKer}; use tract_data::internal::*; @@ -17,27 +19,6 @@ pub enum RoundingPolicy { Odd, } -#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] -pub enum BinOp { - Min, - Max, - Add, - Mul, - Sub, - SubF, -} - -impl BinOp { - pub fn flip(&self) -> BinOp { - use BinOp::*; - match self { - Sub => SubF, - SubF => Sub, - sym => *sym, - } - } -} - #[derive(Clone, Debug)] pub enum AsInputValue<'t> { Owned(Box), diff --git a/linalg/src/frame/mmm/scratch.rs b/linalg/src/frame/mmm/scratch.rs index 9438ca8eef..ac34fa6da4 100644 --- a/linalg/src/frame/mmm/scratch.rs +++ b/linalg/src/frame/mmm/scratch.rs @@ -1,5 +1,5 @@ -use super::{BinOp, FusedKerSpec, FusedSpec, MatMatMulKer, OutputStoreKer}; -use crate::LADatum; +use super::{FusedKerSpec, FusedSpec, MatMatMulKer, OutputStoreKer}; +use crate::{BinOp, LADatum}; use downcast_rs::{impl_downcast, Downcast}; use std::cell::RefCell; use std::fmt::Debug; diff --git a/linalg/src/generic.rs b/linalg/src/generic.rs index 65f474cf14..e7aee437f9 100644 --- a/linalg/src/generic.rs +++ b/linalg/src/generic.rs @@ -11,23 +11,23 @@ pub mod unicast; use tract_data::prelude::DatumType; -use crate::{LinalgRegistry, BinOp, UnicastKer, ByScalarKer}; +use crate::{BinOp, ByScalarKer, LinalgRegistry, UnicastKer}; pub use self::by_scalar::{HMulByScalar8, SMulByScalar4}; pub use self::erf::SErf4; pub use self::leaky_relu::{HLeakyRelu8, SLeakyRelu4}; pub use self::lut::GenericLut8; +pub use self::reduce::softmax_l2::SSoftMaxL2; pub use self::rounding::{ScaleShiftAndRound, Scaler}; pub use self::sigmoid::{HSigmoid8, SSigmoid4}; -pub use self::reduce::softmax_l2::SSoftMaxL2; pub use self::tanh::{HTanh8, STanh4}; -pub(crate)fn register_all_unicast(registry: &mut LinalgRegistry) { - registry.insert((BinOp::Mul, DatumType::F32),Box::new(|| unicast::SUnicastMul4::bin_1())); - registry.insert((BinOp::Mul, DatumType::F16),Box::new(|| unicast::HUnicastMul8::bin_1())); +pub(crate) fn register_all_unicast(registry: &mut LinalgRegistry) { + registry.insert((BinOp::Mul, DatumType::F32), Box::new(|| unicast::SUnicastMul4::bin_1())); + registry.insert((BinOp::Mul, DatumType::F16), Box::new(|| unicast::HUnicastMul8::bin_1())); } -pub(crate)fn register_all_by_scalar(registry: &mut LinalgRegistry) { - registry.insert((BinOp::Mul, DatumType::F32),Box::new(|| by_scalar::SMulByScalar4::bin_1())); - registry.insert((BinOp::Mul, DatumType::F16),Box::new(|| by_scalar::HMulByScalar8::bin_1())); +pub(crate) fn register_all_by_scalar(registry: &mut LinalgRegistry) { + registry.insert((BinOp::Mul, DatumType::F32), Box::new(|| by_scalar::SMulByScalar4::bin_1())); + registry.insert((BinOp::Mul, DatumType::F16), Box::new(|| by_scalar::HMulByScalar8::bin_1())); } diff --git a/linalg/src/lib.rs b/linalg/src/lib.rs index 26e8dc3dba..af0f42bfc1 100644 --- a/linalg/src/lib.rs +++ b/linalg/src/lib.rs @@ -199,22 +199,18 @@ impl BinOp { } } - fn register_all_unicast(registry: &mut LinalgRegistry) { generic::register_all_unicast(registry); #[cfg(target_arch = "aarch64")] arm64::register_all_unicast(registry); - } fn register_all_by_scalar(registry: &mut LinalgRegistry) { generic::register_all_by_scalar(registry); #[cfg(target_arch = "aarch64")] arm64::register_all_by_scalar(registry); - } - 
pub type LinalgFn = Box TractResult<()> + Send + Sync>; type LinalgRegistry = HashMap<(BinOp, DatumType), Box LinalgFn + Send + Sync>>; lazy_static! { @@ -230,17 +226,16 @@ lazy_static! { }; } -pub fn bin_by_scalar(dt: DatumType, bin: BinOp) -> Option { +pub fn bin_by_scalar(dt: DatumType, bin: BinOp) -> Option { let map = BIN_BY_SCALAR_OPS.lock().unwrap(); - map.get(&(bin, dt)).map(|it| (it)()) + map.get(&(bin, dt)).map(|it| (it)()) } -pub fn bin_unicast(dt: DatumType, bin: BinOp) -> Option { +pub fn bin_unicast(dt: DatumType, bin: BinOp) -> Option { let map = BIN_UNICAST_OPS.lock().unwrap(); - map.get(&(bin, dt)).map(|it| (it)()) + map.get(&(bin, dt)).map(|it| (it)()) } - pub fn ops() -> &'static Ops { &OPS } diff --git a/tflite/src/ops/math.rs b/tflite/src/ops/math.rs index d92d19a4b8..0af3af8729 100644 --- a/tflite/src/ops/math.rs +++ b/tflite/src/ops/math.rs @@ -88,7 +88,7 @@ fn ser_bin( node: &TypedNode, op: &TypedBinOp, ) -> TractResult<()> { - use tract_linalg::mmm::BinOp; + use tract_linalg::BinOp; let inputs = builder.map_outlets(model, &node.inputs)?; let outputs = builder.map_outlets(model, [OutletId::from(node.id)])?; From 32a91ab7fd57b64151c392884c360043930d25b7 Mon Sep 17 00:00:00 2001 From: Emrick Sinitambirivoutin Date: Wed, 9 Oct 2024 15:37:45 -0400 Subject: [PATCH 16/32] Add more BinOp support in linalg (Add & Sub) --- linalg/src/arm64.rs | 8 ++ linalg/src/arm64/arm64fp16.rs | 14 ++- linalg/src/arm64/arm64fp16/by_scalar.rs | 70 ++++++++++++++ linalg/src/arm64/arm64fp16/mul.rs | 44 --------- linalg/src/arm64/arm64fp16/unicast.rs | 117 ++++++++++++++++++++++++ linalg/src/arm64/arm64simd.rs | 14 ++- linalg/src/arm64/arm64simd/by_scalar.rs | 66 +++++++++++++ linalg/src/arm64/arm64simd/mul.rs | 42 --------- linalg/src/arm64/arm64simd/unicast.rs | 113 +++++++++++++++++++++++ linalg/src/frame/by_scalar.rs | 12 ++- linalg/src/frame/mmm/tests/frame.rs | 2 +- linalg/src/frame/unicast/mul.rs | 26 +++--- linalg/src/generic.rs | 8 ++ linalg/src/generic/by_scalar.rs | 61 +++++++++++- linalg/src/generic/unicast.rs | 65 ++++++++++++- 15 files changed, 548 insertions(+), 114 deletions(-) delete mode 100644 linalg/src/arm64/arm64fp16/mul.rs create mode 100644 linalg/src/arm64/arm64fp16/unicast.rs delete mode 100644 linalg/src/arm64/arm64simd/mul.rs create mode 100644 linalg/src/arm64/arm64simd/unicast.rs diff --git a/linalg/src/arm64.rs b/linalg/src/arm64.rs index 4396d518a5..fb69166ec3 100644 --- a/linalg/src/arm64.rs +++ b/linalg/src/arm64.rs @@ -217,11 +217,19 @@ impl Kind { pub(crate)fn register_all_unicast(registry: &mut LinalgRegistry) { registry.insert((BinOp::Mul, DatumType::F32),Box::new(|| arm64simd_unicast_mul_f32_16n::bin_1())); registry.insert((BinOp::Mul, DatumType::F16),Box::new(|| arm64fp16_unicast_mul_f16_32n::bin_1())); + registry.insert((BinOp::Add, DatumType::F32),Box::new(|| arm64simd_unicast_add_f32_16n::bin_1())); + registry.insert((BinOp::Add, DatumType::F16),Box::new(|| arm64fp16_unicast_add_f16_32n::bin_1())); + registry.insert((BinOp::Sub, DatumType::F32),Box::new(|| arm64simd_unicast_sub_f32_16n::bin_1())); + registry.insert((BinOp::Sub, DatumType::F16),Box::new(|| arm64fp16_unicast_sub_f16_32n::bin_1())); } pub(crate)fn register_all_by_scalar(registry: &mut LinalgRegistry) { registry.insert((BinOp::Mul, DatumType::F32),Box::new(|| arm64simd_mul_by_scalar_f32_16n::bin_1())); registry.insert((BinOp::Mul, DatumType::F16),Box::new(|| arm64fp16_mul_by_scalar_f16_32n::bin_1())); + registry.insert((BinOp::Add, DatumType::F32),Box::new(|| 
arm64simd_add_by_scalar_f32_16n::bin_1())); + registry.insert((BinOp::Add, DatumType::F16),Box::new(|| arm64fp16_add_by_scalar_f16_32n::bin_1())); + registry.insert((BinOp::Sub, DatumType::F32),Box::new(|| arm64simd_sub_by_scalar_f32_16n::bin_1())); + registry.insert((BinOp::Sub, DatumType::F16),Box::new(|| arm64fp16_sub_by_scalar_f16_32n::bin_1())); } pub fn plug(ops: &mut Ops) { diff --git a/linalg/src/arm64/arm64fp16.rs b/linalg/src/arm64/arm64fp16.rs index 943981d93d..ef99502383 100644 --- a/linalg/src/arm64/arm64fp16.rs +++ b/linalg/src/arm64/arm64fp16.rs @@ -3,12 +3,20 @@ use tract_data::half::f16; mod by_scalar; mod leaky_relu; mod max; -mod mul; +mod unicast; mod sum; -pub use by_scalar::*; +pub use by_scalar::{ + arm64fp16_mul_by_scalar_f16_32n, + arm64fp16_add_by_scalar_f16_32n, + arm64fp16_sub_by_scalar_f16_32n +}; pub use leaky_relu::*; pub use max::*; -pub use mul::*; +pub use unicast::{ + arm64fp16_unicast_mul_f16_32n, + arm64fp16_unicast_add_f16_32n, + arm64fp16_unicast_sub_f16_32n +}; pub use sum::*; use crate::frame::block_quant::Q4_0; diff --git a/linalg/src/arm64/arm64fp16/by_scalar.rs b/linalg/src/arm64/arm64fp16/by_scalar.rs index f9c78ac2eb..f10c51ed6d 100644 --- a/linalg/src/arm64/arm64fp16/by_scalar.rs +++ b/linalg/src/arm64/arm64fp16/by_scalar.rs @@ -34,8 +34,78 @@ by_scalar_impl_wrap!( } ); +by_scalar_impl_wrap!( + f16, + arm64fp16_add_by_scalar_f16_32n, + 32, + 4, + f16, + fn run(buf: &mut [f16], s: f16) { + assert!(buf.len() % 16 == 0); + assert!(buf.len() > 0); + #[target_feature(enable = "fp16")] + unsafe fn run(buf: &mut[f16], s: f16) { + let len = buf.len(); + let ptr = buf.as_ptr(); + std::arch::asm!(" + dup v0.8h, v0.h[0] + 2: + ld1 {{v4.8h, v5.8h, v6.8h, v7.8h}}, [{ptr}] + fadd v4.8h, v4.8h, v0.8h + fadd v5.8h, v5.8h, v0.8h + fadd v6.8h, v6.8h, v0.8h + fadd v7.8h, v7.8h, v0.8h + st1 {{v4.8h, v5.8h, v6.8h, v7.8h}}, [{ptr}], 64 + subs {len}, {len}, 32 + bne 2b + ", + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + in("v0") s.to_bits(), + out("v4") _, out("v5") _, out("v6") _, out("v7") _,); + } + unsafe { run(buf, s) } + } +); + +by_scalar_impl_wrap!( + f16, + arm64fp16_sub_by_scalar_f16_32n, + 32, + 4, + f16, + fn run(buf: &mut [f16], s: f16) { + assert!(buf.len() % 16 == 0); + assert!(buf.len() > 0); + #[target_feature(enable = "fp16")] + unsafe fn run(buf: &mut[f16], s: f16) { + let len = buf.len(); + let ptr = buf.as_ptr(); + std::arch::asm!(" + dup v0.8h, v0.h[0] + 2: + ld1 {{v4.8h, v5.8h, v6.8h, v7.8h}}, [{ptr}] + fsub v4.8h, v4.8h, v0.8h + fsub v5.8h, v5.8h, v0.8h + fsub v6.8h, v6.8h, v0.8h + fsub v7.8h, v7.8h, v0.8h + st1 {{v4.8h, v5.8h, v6.8h, v7.8h}}, [{ptr}], 64 + subs {len}, {len}, 32 + bne 2b + ", + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + in("v0") s.to_bits(), + out("v4") _, out("v5") _, out("v6") _, out("v7") _,); + } + unsafe { run(buf, s) } + } +); + #[cfg(test)] mod test_arm64fp16_mul_by_scalar_f16_32n { use super::*; mul_by_scalar_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_mul_by_scalar_f16_32n); + mul_by_scalar_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_add_by_scalar_f16_32n); + mul_by_scalar_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_sub_by_scalar_f16_32n); } diff --git a/linalg/src/arm64/arm64fp16/mul.rs b/linalg/src/arm64/arm64fp16/mul.rs deleted file mode 100644 index 50ec586d17..0000000000 --- a/linalg/src/arm64/arm64fp16/mul.rs +++ /dev/null @@ -1,44 +0,0 @@ -use tract_data::half::f16; - -unicast_impl_wrap!( - f16, - arm64fp16_unicast_mul_f16_32n, - 32, - 8, - 
#[inline(never)] - fn run(a: &mut [f16], b: &[f16]) { - assert!(a.len() == b.len()); - assert!(a.len() % 32 == 0); - assert!(a.len() > 0); - #[target_feature(enable = "fp16")] - unsafe fn run(a: &mut [f16], b: &[f16]) { - let len = a.len(); - let a_ptr = a.as_ptr(); - let b_ptr = b.as_ptr(); - std::arch::asm!(" - 2: - ld1 {{v0.8h, v1.8h, v2.8h, v3.8h}}, [{a_ptr}] - ld1 {{v4.8h, v5.8h, v6.8h, v7.8h}}, [{b_ptr}], 64 - fmul v0.8h, v0.8h, v4.8h - fmul v1.8h, v1.8h, v5.8h - fmul v2.8h, v2.8h, v6.8h - fmul v3.8h, v3.8h, v7.8h - st1 {{v0.8h, v1.8h, v2.8h, v3.8h}}, [{a_ptr}], 64 - subs {len}, {len}, 32 - bne 2b - ", - len = inout(reg) len => _, - a_ptr = inout(reg) a_ptr => _, - b_ptr = inout(reg) b_ptr => _, - out("v0") _, out("v1") _, out("v2") _, out("v3") _,); - } - unsafe { run(a, b) } - } -); - -#[cfg(test)] -mod test_arm64fp16_unicast_mul_f16_32n { - use super::*; - use proptest::strategy::Strategy; - crate::unicast_mul_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_unicast_mul_f16_32n); -} diff --git a/linalg/src/arm64/arm64fp16/unicast.rs b/linalg/src/arm64/arm64fp16/unicast.rs new file mode 100644 index 0000000000..40b8b995ce --- /dev/null +++ b/linalg/src/arm64/arm64fp16/unicast.rs @@ -0,0 +1,117 @@ +use tract_data::half::f16; + +unicast_impl_wrap!( + f16, + arm64fp16_unicast_mul_f16_32n, + 32, + 8, + #[inline(never)] + fn run(a: &mut [f16], b: &[f16]) { + assert!(a.len() == b.len()); + assert!(a.len() % 32 == 0); + assert!(a.len() > 0); + #[target_feature(enable = "fp16")] + unsafe fn run(a: &mut [f16], b: &[f16]) { + let len = a.len(); + let a_ptr = a.as_ptr(); + let b_ptr = b.as_ptr(); + std::arch::asm!(" + 2: + ld1 {{v0.8h, v1.8h, v2.8h, v3.8h}}, [{a_ptr}] + ld1 {{v4.8h, v5.8h, v6.8h, v7.8h}}, [{b_ptr}], 64 + fmul v0.8h, v0.8h, v4.8h + fmul v1.8h, v1.8h, v5.8h + fmul v2.8h, v2.8h, v6.8h + fmul v3.8h, v3.8h, v7.8h + st1 {{v0.8h, v1.8h, v2.8h, v3.8h}}, [{a_ptr}], 64 + subs {len}, {len}, 32 + bne 2b + ", + len = inout(reg) len => _, + a_ptr = inout(reg) a_ptr => _, + b_ptr = inout(reg) b_ptr => _, + out("v0") _, out("v1") _, out("v2") _, out("v3") _,); + } + unsafe { run(a, b) } + } +); + +unicast_impl_wrap!( + f16, + arm64fp16_unicast_add_f16_32n, + 32, + 8, + #[inline(never)] + fn run(a: &mut [f16], b: &[f16]) { + assert!(a.len() == b.len()); + assert!(a.len() % 32 == 0); + assert!(a.len() > 0); + #[target_feature(enable = "fp16")] + unsafe fn run(a: &mut [f16], b: &[f16]) { + let len = a.len(); + let a_ptr = a.as_ptr(); + let b_ptr = b.as_ptr(); + std::arch::asm!(" + 2: + ld1 {{v0.8h, v1.8h, v2.8h, v3.8h}}, [{a_ptr}] + ld1 {{v4.8h, v5.8h, v6.8h, v7.8h}}, [{b_ptr}], 64 + fadd v0.8h, v0.8h, v4.8h + fadd v1.8h, v1.8h, v5.8h + fadd v2.8h, v2.8h, v6.8h + fadd v3.8h, v3.8h, v7.8h + st1 {{v0.8h, v1.8h, v2.8h, v3.8h}}, [{a_ptr}], 64 + subs {len}, {len}, 32 + bne 2b + ", + len = inout(reg) len => _, + a_ptr = inout(reg) a_ptr => _, + b_ptr = inout(reg) b_ptr => _, + out("v0") _, out("v1") _, out("v2") _, out("v3") _,); + } + unsafe { run(a, b) } + } +); + + +unicast_impl_wrap!( + f16, + arm64fp16_unicast_sub_f16_32n, + 32, + 8, + #[inline(never)] + fn run(a: &mut [f16], b: &[f16]) { + assert!(a.len() == b.len()); + assert!(a.len() % 32 == 0); + assert!(a.len() > 0); + #[target_feature(enable = "fp16")] + unsafe fn run(a: &mut [f16], b: &[f16]) { + let len = a.len(); + let a_ptr = a.as_ptr(); + let b_ptr = b.as_ptr(); + std::arch::asm!(" + 2: + ld1 {{v0.8h, v1.8h, v2.8h, v3.8h}}, [{a_ptr}] + ld1 {{v4.8h, v5.8h, v6.8h, v7.8h}}, [{b_ptr}], 64 + fsub v0.8h, v0.8h, v4.8h + fsub v1.8h, 
v1.8h, v5.8h + fsub v2.8h, v2.8h, v6.8h + fsub v3.8h, v3.8h, v7.8h + st1 {{v0.8h, v1.8h, v2.8h, v3.8h}}, [{a_ptr}], 64 + subs {len}, {len}, 32 + bne 2b + ", + len = inout(reg) len => _, + a_ptr = inout(reg) a_ptr => _, + b_ptr = inout(reg) b_ptr => _, + out("v0") _, out("v1") _, out("v2") _, out("v3") _,); + } + unsafe { run(a, b) } + } +); + +#[cfg(test)] +mod test_arm64fp16_unicast_mul_f16_32n { + use super::*; + use proptest::strategy::Strategy; + crate::unicast_mul_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_unicast_mul_f16_32n); +} diff --git a/linalg/src/arm64/arm64simd.rs b/linalg/src/arm64/arm64simd.rs index 6e30fd8ca7..75bec6d500 100644 --- a/linalg/src/arm64/arm64simd.rs +++ b/linalg/src/arm64/arm64simd.rs @@ -1,16 +1,24 @@ mod by_scalar; mod leaky_relu; mod max; -mod mul; +mod unicast; mod softmax; mod sum; use crate::frame::PackedFormat; -pub use by_scalar::arm64simd_mul_by_scalar_f32_16n; +pub use by_scalar::{ + arm64simd_mul_by_scalar_f32_16n, + arm64simd_add_by_scalar_f32_16n, + arm64simd_sub_by_scalar_f32_16n +}; pub use leaky_relu::arm64simd_leaky_relu_f32_8n; pub use max::arm64simd_max_f32_16n; -pub use mul::arm64simd_unicast_mul_f32_16n; +pub use unicast::{ + arm64simd_unicast_mul_f32_16n, + arm64simd_unicast_add_f32_16n, + arm64simd_unicast_sub_f32_16n +}; pub use softmax::arm64simd_softmax2_fastcompact_f32_16n; pub use sum::arm64simd_sum_f32_16n; diff --git a/linalg/src/arm64/arm64simd/by_scalar.rs b/linalg/src/arm64/arm64simd/by_scalar.rs index ac6f0a06fe..db1928008a 100644 --- a/linalg/src/arm64/arm64simd/by_scalar.rs +++ b/linalg/src/arm64/arm64simd/by_scalar.rs @@ -30,8 +30,74 @@ by_scalar_impl_wrap!( } ); +by_scalar_impl_wrap!( + f32, + arm64simd_add_by_scalar_f32_16n, + 16, + 4, + f32, + fn run(buf: &mut [f32], s: f32) { + assert!(buf.len() % 16 == 0); + assert!(buf.len() > 0); + unsafe { + let len = buf.len(); + let ptr = buf.as_ptr(); + std::arch::asm!(" + dup v0.4s, v0.s[0] + 2: + ld1 {{v4.4s, v5.4s, v6.4s, v7.4s}}, [{ptr}] + fadd v4.4s, v4.4s, v0.4s + fadd v5.4s, v5.4s, v0.4s + fadd v6.4s, v6.4s, v0.4s + fadd v7.4s, v7.4s, v0.4s + st1 {{v4.4s, v5.4s, v6.4s, v7.4s}}, [{ptr}], 64 + subs {len}, {len}, 16 + bne 2b + ", + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + in("v0") s, + out("v4") _, out("v5") _, out("v6") _, out("v7") _,); + } + } +); + +by_scalar_impl_wrap!( + f32, + arm64simd_sub_by_scalar_f32_16n, + 16, + 4, + f32, + fn run(buf: &mut [f32], s: f32) { + assert!(buf.len() % 16 == 0); + assert!(buf.len() > 0); + unsafe { + let len = buf.len(); + let ptr = buf.as_ptr(); + std::arch::asm!(" + dup v0.4s, v0.s[0] + 2: + ld1 {{v4.4s, v5.4s, v6.4s, v7.4s}}, [{ptr}] + fsub v4.4s, v4.4s, v0.4s + fsub v5.4s, v5.4s, v0.4s + fsub v6.4s, v6.4s, v0.4s + fsub v7.4s, v7.4s, v0.4s + st1 {{v4.4s, v5.4s, v6.4s, v7.4s}}, [{ptr}], 64 + subs {len}, {len}, 16 + bne 2b + ", + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + in("v0") s, + out("v4") _, out("v5") _, out("v6") _, out("v7") _,); + } + } +); + #[cfg(test)] mod test_arm64simd_mul_by_scalar_f32_16n { use super::*; mul_by_scalar_frame_tests!(true, f32, arm64simd_mul_by_scalar_f32_16n); + mul_by_scalar_frame_tests!(true, f32, arm64simd_add_by_scalar_f32_16n); + mul_by_scalar_frame_tests!(true, f32, arm64simd_sub_by_scalar_f32_16n); } diff --git a/linalg/src/arm64/arm64simd/mul.rs b/linalg/src/arm64/arm64simd/mul.rs deleted file mode 100644 index 4de143230d..0000000000 --- a/linalg/src/arm64/arm64simd/mul.rs +++ /dev/null @@ -1,42 +0,0 @@ - -unicast_impl_wrap!( - f32, - 
arm64simd_unicast_mul_f32_16n, - 16, - 4, - #[inline(never)] - fn run(a: &mut [f32], b: &[f32]) { - assert!(a.len() == b.len()); - assert!(a.len() % 16 == 0); - assert!(a.len() > 0); - unsafe fn run(a: &mut [f32], b: &[f32]) { - let len = a.len(); - let a_ptr = a.as_ptr(); - let b_ptr = b.as_ptr(); - std::arch::asm!(" - 2: - ld1 {{v0.4s, v1.4s, v2.4s, v3.4s}}, [{a_ptr}] - ld1 {{v4.4s, v5.4s, v6.4s, v7.4s}}, [{b_ptr}], 64 - fmul v0.4s, v0.4s, v4.4s - fmul v1.4s, v1.4s, v5.4s - fmul v2.4s, v2.4s, v6.4s - fmul v3.4s, v3.4s, v7.4s - st1 {{v0.4s, v1.4s, v2.4s, v3.4s}}, [{a_ptr}], 64 - subs {len}, {len}, 16 - bne 2b - ", - len = inout(reg) len => _, - a_ptr = inout(reg) a_ptr => _, - b_ptr = inout(reg) b_ptr => _, - out("v0") _, out("v1") _, out("v2") _, out("v3") _,); - } - unsafe { run(a, b) } - } -); - -#[cfg(test)] -mod test_arm64simd_unicast_mul_f32_16n { - use super::*; - use proptest::strategy::Strategy; - crate::unicast_mul_frame_tests!(true, f32, arm64simd_unicast_mul_f32_16n); -} diff --git a/linalg/src/arm64/arm64simd/unicast.rs b/linalg/src/arm64/arm64simd/unicast.rs new file mode 100644 index 0000000000..381e3592b5 --- /dev/null +++ b/linalg/src/arm64/arm64simd/unicast.rs @@ -0,0 +1,113 @@ + +unicast_impl_wrap!( + f32, + arm64simd_unicast_mul_f32_16n, + 16, + 4, + #[inline(never)] + fn run(a: &mut [f32], b: &[f32]) { + assert!(a.len() == b.len()); + assert!(a.len() % 16 == 0); + assert!(a.len() > 0); + unsafe fn run(a: &mut [f32], b: &[f32]) { + let len = a.len(); + let a_ptr = a.as_ptr(); + let b_ptr = b.as_ptr(); + std::arch::asm!(" + 2: + ld1 {{v0.4s, v1.4s, v2.4s, v3.4s}}, [{a_ptr}] + ld1 {{v4.4s, v5.4s, v6.4s, v7.4s}}, [{b_ptr}], 64 + fmul v0.4s, v0.4s, v4.4s + fmul v1.4s, v1.4s, v5.4s + fmul v2.4s, v2.4s, v6.4s + fmul v3.4s, v3.4s, v7.4s + st1 {{v0.4s, v1.4s, v2.4s, v3.4s}}, [{a_ptr}], 64 + subs {len}, {len}, 16 + bne 2b + ", + len = inout(reg) len => _, + a_ptr = inout(reg) a_ptr => _, + b_ptr = inout(reg) b_ptr => _, + out("v0") _, out("v1") _, out("v2") _, out("v3") _,); + } + unsafe { run(a, b) } + } +); + +unicast_impl_wrap!( + f32, + arm64simd_unicast_add_f32_16n, + 16, + 4, + #[inline(never)] + fn run(a: &mut [f32], b: &[f32]) { + assert!(a.len() == b.len()); + assert!(a.len() % 16 == 0); + assert!(a.len() > 0); + unsafe fn run(a: &mut [f32], b: &[f32]) { + let len = a.len(); + let a_ptr = a.as_ptr(); + let b_ptr = b.as_ptr(); + std::arch::asm!(" + 2: + ld1 {{v0.4s, v1.4s, v2.4s, v3.4s}}, [{a_ptr}] + ld1 {{v4.4s, v5.4s, v6.4s, v7.4s}}, [{b_ptr}], 64 + fadd v0.4s, v0.4s, v4.4s + fadd v1.4s, v1.4s, v5.4s + fadd v2.4s, v2.4s, v6.4s + fadd v3.4s, v3.4s, v7.4s + st1 {{v0.4s, v1.4s, v2.4s, v3.4s}}, [{a_ptr}], 64 + subs {len}, {len}, 16 + bne 2b + ", + len = inout(reg) len => _, + a_ptr = inout(reg) a_ptr => _, + b_ptr = inout(reg) b_ptr => _, + out("v0") _, out("v1") _, out("v2") _, out("v3") _,); + } + unsafe { run(a, b) } + } +); + + +unicast_impl_wrap!( + f32, + arm64simd_unicast_sub_f32_16n, + 16, + 4, + #[inline(never)] + fn run(a: &mut [f32], b: &[f32]) { + assert!(a.len() == b.len()); + assert!(a.len() % 16 == 0); + assert!(a.len() > 0); + unsafe fn run(a: &mut [f32], b: &[f32]) { + let len = a.len(); + let a_ptr = a.as_ptr(); + let b_ptr = b.as_ptr(); + std::arch::asm!(" + 2: + ld1 {{v0.4s, v1.4s, v2.4s, v3.4s}}, [{a_ptr}] + ld1 {{v4.4s, v5.4s, v6.4s, v7.4s}}, [{b_ptr}], 64 + fsub v0.4s, v0.4s, v4.4s + fsub v1.4s, v1.4s, v5.4s + fsub v2.4s, v2.4s, v6.4s + fsub v3.4s, v3.4s, v7.4s + st1 {{v0.4s, v1.4s, v2.4s, v3.4s}}, [{a_ptr}], 64 + subs {len}, {len}, 16 + bne 2b + ", 
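+            // Operand bindings for the asm template above: {len} counts the
+            // remaining f32 lanes (16 consumed per iteration), {a_ptr} is
+            // loaded, combined and stored back with post-increment, {b_ptr}
+            // only advances; all three are consumed by the loop (`=> _`).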
+            len = inout(reg) len => _,
+            a_ptr = inout(reg) a_ptr => _,
+            b_ptr = inout(reg) b_ptr => _,
+            out("v0") _, out("v1") _, out("v2") _, out("v3") _,);
+        }
+        unsafe { run(a, b) }
+    }
+);
+
+#[cfg(test)]
+mod test_arm64simd_unicast_mul_f32_16n {
+    use super::*;
+    use proptest::strategy::Strategy;
+    crate::unicast_mul_frame_tests!(true, f32, arm64simd_unicast_mul_f32_16n);
+}
diff --git a/linalg/src/frame/by_scalar.rs b/linalg/src/frame/by_scalar.rs
index 3f8c0b3bea..9f42e782cd 100644
--- a/linalg/src/frame/by_scalar.rs
+++ b/linalg/src/frame/by_scalar.rs
@@ -67,11 +67,13 @@ pub mod test {
     #[macro_export]
     macro_rules! mul_by_scalar_frame_tests {
         ($cond:expr, $t: ty, $ker:ty) => {
-            proptest::proptest! {
-                #[test]
-                fn prop(xs in proptest::collection::vec(-25f32..25.0, 0..100), scalar in -25f32..25f32) {
-                    if $cond {
-                        $crate::frame::by_scalar::test::test_mul_by_scalar::<$ker, $t>(&*xs, scalar).unwrap()
+            paste::paste! {
+                proptest::proptest! {
+                    #[test]
+                    fn [<prop_ $ker:snake>](xs in proptest::collection::vec(-25f32..25.0, 0..100), scalar in -25f32..25f32) {
+                        if $cond {
+                            $crate::frame::by_scalar::test::test_mul_by_scalar::<$ker, $t>(&*xs, scalar).unwrap()
+                        }
                     }
                 }
             }
diff --git a/linalg/src/frame/mmm/tests/frame.rs b/linalg/src/frame/mmm/tests/frame.rs
index 497090c7f8..acdf73efeb 100644
--- a/linalg/src/frame/mmm/tests/frame.rs
+++ b/linalg/src/frame/mmm/tests/frame.rs
@@ -1,5 +1,5 @@
 use crate::frame::mmm::*;
-use crate::LADatum;
+use crate::{BinOp, LADatum};
 use num_traits::AsPrimitive;
 use std::ops::Neg;
 use tests::display_error;
diff --git a/linalg/src/frame/unicast/mul.rs b/linalg/src/frame/unicast/mul.rs
index 2fdca15964..1598a26fe1 100644
--- a/linalg/src/frame/unicast/mul.rs
+++ b/linalg/src/frame/unicast/mul.rs
@@ -9,21 +9,23 @@ pub mod test {
     #[macro_export]
     macro_rules! unicast_mul_frame_tests {
         ($cond:expr, $t: ty, $ker:ty) => {
-            proptest::proptest! {
-                #[test]
-                fn prop(
-                    (a, b) in (0..100_usize).prop_flat_map(|len| (vec![-25f32..25.0; len], vec![-25f32..25.0; len]))
-                ) {
-                    if $cond {
-                        $crate::frame::unicast::mul::test::test_unicast_mul::<$ker, $t>(&*a, &*b).unwrap()
+            paste::paste! {
+                proptest::proptest!
 {
+                    #[test]
+                    fn [<prop_ $ker>](
+                        (a, b) in (0..100_usize).prop_flat_map(|len| (vec![-25f32..25.0; len], vec![-25f32..25.0; len]))
+                    ) {
+                        if $cond {
+                            $crate::frame::unicast::mul::test::test_unicast_mul::<$ker, $t>(&*a, &*b).unwrap()
+                        }
                     }
                 }
 
-        #[test]
-        fn empty() {
-            if $cond {
-                $crate::frame::unicast::mul::test::test_unicast_mul::<$ker, $t>(&[], &[]).unwrap()
+            #[test]
+            fn [<empty_ $ker>]() {
+                if $cond {
+                    $crate::frame::unicast::mul::test::test_unicast_mul::<$ker, $t>(&[], &[]).unwrap()
+                }
             }
         }
     };
diff --git a/linalg/src/generic.rs b/linalg/src/generic.rs
index e7aee437f9..1a075caa14 100644
--- a/linalg/src/generic.rs
+++ b/linalg/src/generic.rs
@@ -25,9 +25,17 @@ pub use self::tanh::{HTanh8, STanh4};
 pub(crate) fn register_all_unicast(registry: &mut LinalgRegistry) {
     registry.insert((BinOp::Mul, DatumType::F32), Box::new(|| unicast::SUnicastMul4::bin_1()));
     registry.insert((BinOp::Mul, DatumType::F16), Box::new(|| unicast::HUnicastMul8::bin_1()));
+    registry.insert((BinOp::Add, DatumType::F32), Box::new(|| unicast::SUnicastAdd4::bin_1()));
+    registry.insert((BinOp::Add, DatumType::F16), Box::new(|| unicast::HUnicastAdd8::bin_1()));
+    registry.insert((BinOp::Sub, DatumType::F32), Box::new(|| unicast::SUnicastSub4::bin_1()));
+    registry.insert((BinOp::Sub, DatumType::F16), Box::new(|| unicast::HUnicastSub8::bin_1()));
 }
 
 pub(crate) fn register_all_by_scalar(registry: &mut LinalgRegistry) {
     registry.insert((BinOp::Mul, DatumType::F32), Box::new(|| by_scalar::SMulByScalar4::bin_1()));
     registry.insert((BinOp::Mul, DatumType::F16), Box::new(|| by_scalar::HMulByScalar8::bin_1()));
+    registry.insert((BinOp::Add, DatumType::F32), Box::new(|| by_scalar::SAddByScalar4::bin_1()));
+    registry.insert((BinOp::Add, DatumType::F16), Box::new(|| by_scalar::HAddByScalar8::bin_1()));
+    registry.insert((BinOp::Sub, DatumType::F32), Box::new(|| by_scalar::SSubByScalar4::bin_1()));
+    registry.insert((BinOp::Sub, DatumType::F16), Box::new(|| by_scalar::HSubByScalar8::bin_1()));
 }
diff --git a/linalg/src/generic/by_scalar.rs b/linalg/src/generic/by_scalar.rs
index ef67b31dc5..6d5fddc622 100644
--- a/linalg/src/generic/by_scalar.rs
+++ b/linalg/src/generic/by_scalar.rs
@@ -13,10 +13,39 @@ by_scalar_impl_wrap!(
     }
 );
 
+by_scalar_impl_wrap!(
+    f32,
+    SAddByScalar4,
+    4,
+    4,
+    f32,
+    fn run(x: &mut [f32], s: f32) {
+        debug_assert!(x.len() % Self::nr() == 0);
+        debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
+        x.iter_mut().for_each(|px| *px += s)
+    }
+);
+
+by_scalar_impl_wrap!(
+    f32,
+    SSubByScalar4,
+    4,
+    4,
+    f32,
+    fn run(x: &mut [f32], s: f32) {
+        debug_assert!(x.len() % Self::nr() == 0);
+        debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
+        x.iter_mut().for_each(|px| *px -= s)
+    }
+);
+
 #[cfg(test)]
 #[macro_use]
 pub mod mul_by_scalar_f32 {
-    mul_by_scalar_frame_tests!(true, f32, crate::generic::by_scalar::SMulByScalar4);
+    use super::*;
+    mul_by_scalar_frame_tests!(true, f32, SMulByScalar4);
+    mul_by_scalar_frame_tests!(true, f32, SAddByScalar4);
+    mul_by_scalar_frame_tests!(true, f32, SSubByScalar4);
 }
 
 by_scalar_impl_wrap!(
@@ -32,9 +61,37 @@ by_scalar_impl_wrap!(
     }
 );
 
+by_scalar_impl_wrap!(
+    f16,
+    HAddByScalar8,
+    8,
+    8,
+    f16,
+    fn run(x: &mut [f16], s: f16) {
+        debug_assert!(x.len() % Self::nr() == 0);
+        debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
+        x.iter_mut().for_each(|px| *px += s)
+    }
+);
+
+by_scalar_impl_wrap!(
+    f16,
+    HSubByScalar8,
+    8,
+    8,
+    f16,
+    fn run(x: &mut [f16], s: f16) {
+        debug_assert!(x.len() % Self::nr() == 0);
+
debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0); + x.iter_mut().for_each(|px| *px -= s) + } +); + #[cfg(test)] #[macro_use] pub mod mul_by_scalar_f16 { use super::*; - mul_by_scalar_frame_tests!(true, f16, crate::generic::by_scalar::HMulByScalar8); + mul_by_scalar_frame_tests!(true, f16, HMulByScalar8); + mul_by_scalar_frame_tests!(true, f16, HAddByScalar8); + mul_by_scalar_frame_tests!(true, f16, HSubByScalar8); } diff --git a/linalg/src/generic/unicast.rs b/linalg/src/generic/unicast.rs index ce12268cc6..439f089945 100644 --- a/linalg/src/generic/unicast.rs +++ b/linalg/src/generic/unicast.rs @@ -27,11 +27,70 @@ unicast_impl_wrap!( } ); +unicast_impl_wrap!( + f32, + SUnicastAdd4, + 4, + 4, + fn run(a: &mut [f32], b: &[f32]) { + debug_assert!(a.len() == b.len()); + debug_assert!(a.len() % Self::nr() == 0); + debug_assert!(a.as_ptr() as usize % Self::alignment_bytes() == 0); + debug_assert!(b.as_ptr() as usize % Self::alignment_bytes() == 0); + a.iter_mut().zip(b.iter()).for_each(|(a, b)| *a += b) + } +); + +unicast_impl_wrap!( + f16, + HUnicastAdd8, + 8, + 8, + fn run(a: &mut [f16], b: &[f16]) { + debug_assert!(a.len() == b.len()); + debug_assert!(a.len() % Self::nr() == 0); + debug_assert!(a.as_ptr() as usize % Self::alignment_bytes() == 0); + debug_assert!(b.as_ptr() as usize % Self::alignment_bytes() == 0); + a.iter_mut().zip(b.iter()).for_each(|(a, b)| *a += b) + } +); + +unicast_impl_wrap!( + f32, + SUnicastSub4, + 4, + 4, + fn run(a: &mut [f32], b: &[f32]) { + debug_assert!(a.len() == b.len()); + debug_assert!(a.len() % Self::nr() == 0); + debug_assert!(a.as_ptr() as usize % Self::alignment_bytes() == 0); + debug_assert!(b.as_ptr() as usize % Self::alignment_bytes() == 0); + a.iter_mut().zip(b.iter()).for_each(|(a, b)| *a -= b) + } +); + +unicast_impl_wrap!( + f16, + HUnicastSub8, + 8, + 8, + fn run(a: &mut [f16], b: &[f16]) { + debug_assert!(a.len() == b.len()); + debug_assert!(a.len() % Self::nr() == 0); + debug_assert!(a.as_ptr() as usize % Self::alignment_bytes() == 0); + debug_assert!(b.as_ptr() as usize % Self::alignment_bytes() == 0); + a.iter_mut().zip(b.iter()).for_each(|(a, b)| *a -= b) + } +); + #[cfg(test)] #[macro_use] pub mod s { + use super::*; use proptest::strategy::Strategy; - crate::unicast_mul_frame_tests!(true, f32, crate::generic::unicast::SUnicastMul4); + crate::unicast_mul_frame_tests!(true, f32, SUnicastMul4); + crate::unicast_mul_frame_tests!(true, f32, SUnicastAdd4); + crate::unicast_mul_frame_tests!(true, f32, SUnicastSub4); } #[cfg(test)] @@ -39,5 +98,7 @@ pub mod s { pub mod h { use super::*; use proptest::strategy::Strategy; - crate::unicast_mul_frame_tests!(true, f16, crate::generic::unicast::HUnicastMul8); + crate::unicast_mul_frame_tests!(true, f16, HUnicastMul8); + crate::unicast_mul_frame_tests!(true, f16, HUnicastAdd8); + crate::unicast_mul_frame_tests!(true, f16, HUnicastSub8); } From 497bb6f75a02ef36576ad516a1a1f22219cc4986 Mon Sep 17 00:00:00 2001 From: Emrick Sinitambirivoutin Date: Wed, 9 Oct 2024 17:26:57 -0400 Subject: [PATCH 17/32] Decluttering to swap operand --- core/src/ops/binary.rs | 51 +++++++++++++++++++++++++++++++--------- core/src/ops/math/mod.rs | 10 -------- 2 files changed, 40 insertions(+), 21 deletions(-) diff --git a/core/src/ops/binary.rs b/core/src/ops/binary.rs index 78251f22b8..77cfe60596 100644 --- a/core/src/ops/binary.rs +++ b/core/src/ops/binary.rs @@ -155,7 +155,7 @@ impl TypedOp for TypedBinOp { if let AxisOp::Rm(rm) = change { let (inputs, outputs) = model.node_facts(node.id)?; if 
!inputs[0].shape[*rm].is_one()
-                || !inputs[0].shape[*rm].is_one()
+                || !inputs[1].shape[*rm].is_one()
                 || !outputs[0].shape[*rm].is_one()
             {
                 return Ok(None);
@@ -211,7 +211,11 @@ impl TypedOp for TypedBinOp {
         {
             return Ok(Some(neutral_patch));
         }
-
+        if let Some(broadcast_patch) =
+            declutter_broadcasting_operand_1(model, node, self.0.clone())?
+        {
+            return Ok(Some(broadcast_patch));
+        }
         self.0.declutter(model, node)
     }
 
@@ -240,7 +244,9 @@ impl TypedOp for TypedBinOp {
         let dt = model.node_input_facts(node.id)?[0].datum_type().unwrap();
 
         if by_scalar_should_be_efficient & can_eval_in_a & !op_is_quant {
-            let Some(func) = tract_linalg::bin_by_scalar(dt, linalg_bin_op) else {return Ok(None)};
+            let Some(func) = tract_linalg::bin_by_scalar(dt, linalg_bin_op) else {
+                return Ok(None);
+            };
             let eval_fn = Arc::from(func);
             return Ok(Some(
                 TypedModelPatch::replace_single_op(
@@ -254,7 +260,9 @@ impl TypedOp for TypedBinOp {
         }
 
         if unicast_should_be_efficient & can_eval_in_a & !op_is_quant {
-            let Some(func) = tract_linalg::bin_unicast(dt, linalg_bin_op) else {return Ok(None)};
+            let Some(func) = tract_linalg::bin_unicast(dt, linalg_bin_op) else {
+                return Ok(None);
+            };
             let eval_fn = Arc::from(func);
             return Ok(Some(
                 TypedModelPatch::replace_single_op(
@@ -273,6 +281,34 @@ impl TypedOp for TypedBinOp {
     as_op!();
 }
 
+fn declutter_broadcasting_operand_1(
+    model: &TypedModel,
+    node: &TypedNode,
+    mini_op: Box<dyn BinMiniOp>,
+) -> TractResult<Option<TypedModelPatch>> {
+    let (a_shape, b_shape) = if let &[a, b] = &*model.node_input_facts(node.id)? {
+        (a.shape.clone(), b.shape.clone())
+    } else {
+        unreachable!("TypedBinOp has two inputs.")
+    };
+
+    let a_num_elements = a_shape.iter().product::<TDim>();
+    let b_num_elements = b_shape.iter().product::<TDim>();
+    let a_should_be_broadcast = (a_num_elements - b_num_elements).prove_strict_negative();
+    if a_should_be_broadcast & mini_op.is_commutative() {
+        let mut swap_input = node.inputs.clone();
+        swap_input.swap(0, 1);
+        return Ok(Some(TypedModelPatch::replace_single_op(
+            model,
+            node,
+            &swap_input,
+            TypedBinOp(mini_op, None),
+        )?));
+    }
+
+    Ok(None)
+}
+
 fn declutter_neutral(
     model: &TypedModel,
     node: &TypedNode,
@@ -280,13 +316,6 @@ fn declutter_neutral(
     out_dt: DatumType,
 ) -> TractResult<Option<TypedModelPatch>> {
     if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
-        // Not sure to understand why this check was needed
-        //let integer = uniform.uni.cast_to_scalar::<i64>()?;
-        //let is_scalar = tensor0(integer)
-        //    .cast_to_dt(uniform.uni.datum_type())?
-        //    .close_enough(&uniform.uni, false)
-        //    .is_ok();
-
         let is_neutral = mini_op
             .neutral_element()
             .map(|neutral| tensor0(neutral).close_enough(&uniform.uni, false).is_ok())
diff --git a/core/src/ops/math/mod.rs b/core/src/ops/math/mod.rs
index 8c2b06bb20..a53a1af82d 100644
--- a/core/src/ops/math/mod.rs
+++ b/core/src/ops/math/mod.rs
@@ -365,16 +365,6 @@ fn declutter_mul(
                 },
             )?));
         }
-        if !uniform.left_is_uniform {
-            let mut swap_input = node.inputs.clone();
-            swap_input.swap(0, 1);
-            return Ok(Some(TypedModelPatch::replace_single_op(
-                model,
-                node,
-                &swap_input,
-                mul(),
-            )?));
-        }
         }
     }
     Ok(None)

From e3d2e81944f0297e74d2ac0bee6290ff589e1a8c Mon Sep 17 00:00:00 2001
From: Emrick Sinitambirivoutin 
Date: Thu, 10 Oct 2024 07:21:26 -0400
Subject: [PATCH 18/32] cargo clippy

---
 core/src/ops/binary.rs          |  5 +++--
 linalg/src/frame/by_scalar.rs   |  2 +-
 linalg/src/frame/unicast/mod.rs | 22 +++++++++++---------
 linalg/src/lib.rs               |  8 ++++----
 4 files changed, 21 insertions(+), 16 deletions(-)

diff --git a/core/src/ops/binary.rs b/core/src/ops/binary.rs
index 77cfe60596..0cd10d7fe8 100644
--- a/core/src/ops/binary.rs
+++ b/core/src/ops/binary.rs
@@ -3,6 +3,7 @@ use downcast_rs::Downcast;
 use std::fmt::{self, Debug};
 use tract_data::itertools::izip;
 use tract_itertools::Itertools;
+use tract_linalg::LinalgFn;
 
 use super::cast::cast;
 
@@ -397,7 +398,7 @@ pub fn gt_tdim(x: TDim, min_val: i64) -> bool {
 #[derive(Clone)]
 pub struct OptBinByScalar {
     pub binop: Box<dyn BinMiniOp>,
-    eval_fn: Arc<dyn Fn(&mut TensorView, &TensorView) -> TractResult<()> + Send + Sync>,
+    eval_fn: Arc<LinalgFn>,
 }
 
 impl Debug for OptBinByScalar {
@@ -496,7 +497,7 @@ impl TypedOp for OptBinByScalar {
 #[derive(Clone)]
 pub struct OptBinUnicast {
     pub binop: Box<dyn BinMiniOp>,
-    eval_fn: Arc<dyn Fn(&mut TensorView, &TensorView) -> TractResult<()> + Send + Sync>,
+    eval_fn: Arc<LinalgFn>,
 }
 
 impl Debug for OptBinUnicast {
diff --git a/linalg/src/frame/by_scalar.rs b/linalg/src/frame/by_scalar.rs
index 9f42e782cd..5611b7f579 100644
--- a/linalg/src/frame/by_scalar.rs
+++ b/linalg/src/frame/by_scalar.rs
@@ -36,7 +36,7 @@ pub trait ByScalarKer<T>: ElementWiseKer<T>
 where
     T: LADatum
 {
-    fn bin_1() -> LinalgFn {
+    fn bin_1() -> Box<LinalgFn> {
         Box::new(|a: &mut TensorView, b: &TensorView| {
             let a_slice = a.as_slice_mut()?;
             let b = b.as_slice()?[0];
diff --git a/linalg/src/frame/unicast/mod.rs b/linalg/src/frame/unicast/mod.rs
index 1c9a7ade7d..1dfd60cb2d 100644
--- a/linalg/src/frame/unicast/mod.rs
+++ b/linalg/src/frame/unicast/mod.rs
@@ -3,8 +3,8 @@ pub mod mul;
 use std::fmt::Debug;
 use std::marker::PhantomData;
 
-use tract_data::TractResult;
 use tract_data::internal::TensorView;
+use tract_data::TractResult;
 
 use crate::frame::element_wise_helper::TempBuffer;
 use crate::{LADatum, LinalgFn};
@@ -54,7 +54,6 @@ where
     phantom: PhantomData<(K, T)>,
 }
 
-
 impl<K, T> UnicastImpl<K, T>
 where
     T: LADatum,
@@ -88,7 +87,7 @@ where
     fn bin() -> Box<dyn Unicast<T>> {
         Box::new(UnicastImpl::<K, T>::new())
     }
-    fn bin_1() -> LinalgFn {
+    fn bin_1() -> Box<LinalgFn> {
         Box::new(|a: &mut TensorView, b: &TensorView| {
             let a_slice = a.as_slice_mut()?;
             let b_slice = b.as_slice()?;
@@ -101,7 +100,12 @@
 std::thread_local!
 {
     static TMP: std::cell::RefCell<(TempBuffer, TempBuffer)> = std::cell::RefCell::new((TempBuffer::default(), TempBuffer::default()));
 }
 
-fn create_incomplete_tile<'a, T: LADatum>(a: &'a mut [T], b: &'a [T], a_prefix_len: usize, b_prefix_len: usize) -> (&'a mut [T], &'a [T], usize) {
+fn create_incomplete_tile<'a, T: LADatum>(
+    a: &'a mut [T],
+    b: &'a [T],
+    a_prefix_len: usize,
+    b_prefix_len: usize,
+) -> (&'a mut [T], &'a [T], usize) {
     let effective_prefix = if (a_prefix_len == 0) || (b_prefix_len == 0) {
         // One of the two slice is aligned, the target size is the number of unaligned elements of
         // the other slice, the max value between the two.
@@ -114,7 +118,6 @@ fn create_incomplete_tile<'a, T: LADatum>(a: &'a mut [T], b: &'a [T], a_prefix_l
     (&mut a[..effective_prefix], &b[..effective_prefix], effective_prefix)
 }
 
-
 pub(crate) fn unicast_with_alignment<T>(
     a: &mut [T],
     b: &[T],
@@ -148,18 +151,19 @@
     let mut applied_prefix_len = 0;
     if (a_prefix_len > 0) || (b_prefix_len > 0) {
         // Incomplete tile needs to be created to process unaligned data.
-        let (mut sub_a, sub_b, applied_prefix) = create_incomplete_tile(a, b, a_prefix_len, b_prefix_len);
+        let (sub_a, sub_b, applied_prefix) =
+            create_incomplete_tile(a, b, a_prefix_len, b_prefix_len);
         applied_prefix_len = applied_prefix;
-        compute_via_temp_buffer(&mut sub_a, &sub_b);
+        compute_via_temp_buffer(sub_a, sub_b);
         num_element_processed += applied_prefix_len;
     }
     let num_complete_tiles = (a.len() - applied_prefix_len) / nr;
     if num_complete_tiles > 0 {
         // Process all tiles that are complete.
-        let mut sub_a = &mut a[applied_prefix_len..][..(num_complete_tiles * nr)];
+        let sub_a = &mut a[applied_prefix_len..][..(num_complete_tiles * nr)];
         let sub_b = &b[applied_prefix_len..][..(num_complete_tiles * nr)];
-        f(&mut sub_a, &sub_b);
+        f(sub_a, sub_b);
         num_element_processed += num_complete_tiles * nr;
     }
 
diff --git a/linalg/src/lib.rs b/linalg/src/lib.rs
index af0f42bfc1..39195593ae 100644
--- a/linalg/src/lib.rs
+++ b/linalg/src/lib.rs
@@ -211,8 +211,8 @@ fn register_all_by_scalar(registry: &mut LinalgRegistry) {
     arm64::register_all_by_scalar(registry);
 }
 
-pub type LinalgFn = Box<dyn Fn(&mut TensorView, &TensorView) -> TractResult<()> + Send + Sync>;
-type LinalgRegistry = HashMap<(BinOp, DatumType), Box<dyn Fn() -> LinalgFn + Send + Sync>>;
+pub type LinalgFn = dyn Fn(&mut TensorView, &TensorView) -> TractResult<()> + Send + Sync;
+type LinalgRegistry = HashMap<(BinOp, DatumType), Box<dyn Fn() -> Box<LinalgFn> + Send + Sync>>;
 lazy_static! {
@@ -230,17 +226,16 @@ lazy_static! {
     };
 }
 
-pub fn bin_by_scalar(dt: DatumType, bin: BinOp) -> Option<LinalgFn> {
+pub fn bin_by_scalar(dt: DatumType, bin: BinOp) -> Option<Box<LinalgFn>> {
     let map = BIN_BY_SCALAR_OPS.lock().unwrap();
     map.get(&(bin, dt)).map(|it| (it)())
 }
 
-pub fn bin_unicast(dt: DatumType, bin: BinOp) -> Option<LinalgFn> {
+pub fn bin_unicast(dt: DatumType, bin: BinOp) -> Option<Box<LinalgFn>> {
     let map = BIN_UNICAST_OPS.lock().unwrap();
     map.get(&(bin, dt)).map(|it| (it)())
 }

From ff686fc871a08ab9759da04bf240e4b21fe5eb40 Mon Sep 17 00:00:00 2001
From: Emrick Sinitambirivoutin 
Date: Thu, 10 Oct 2024 07:24:37 -0400
Subject: [PATCH 19/32] Fix compilation x86

---
 linalg/src/frame/unicast/mul.rs    | 4 ++--
 linalg/src/x86_64_fma/by_scalar.rs | 3 ++-
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/linalg/src/frame/unicast/mul.rs b/linalg/src/frame/unicast/mul.rs
index 1598a26fe1..9e5958de3c 100644
--- a/linalg/src/frame/unicast/mul.rs
+++ b/linalg/src/frame/unicast/mul.rs
@@ -12,7 +12,7 @@ pub mod test {
             paste::paste!
 {
                 proptest::proptest! {
                     #[test]
-                fn [<prop_ $ker>](
+                fn [<prop_ $ker:snake>](
                     (a, b) in (0..100_usize).prop_flat_map(|len| (vec![-25f32..25.0; len], vec![-25f32..25.0; len]))
                 ) {
                     if $cond {
@@ -22,7 +22,7 @@ pub mod test {
             }
 
             #[test]
-            fn [<empty_ $ker>]() {
+            fn [<empty_ $ker:snake>]() {
                 if $cond {
                     $crate::frame::unicast::mul::test::test_unicast_mul::<$ker, $t>(&[], &[]).unwrap()
                 }
diff --git a/linalg/src/x86_64_fma/by_scalar.rs b/linalg/src/x86_64_fma/by_scalar.rs
index c2e7c9abda..ca461ab981 100644
--- a/linalg/src/x86_64_fma/by_scalar.rs
+++ b/linalg/src/x86_64_fma/by_scalar.rs
@@ -44,9 +44,10 @@ unsafe fn x86_64_avx_f32_mul_by_scalar_32n_run(buf: &mut [f32], scalar: f32) {
 #[cfg(test)]
 #[macro_use]
 pub mod test_x86_64_avx_f32_mul_by_scalar_32n {
+    use super::*;
     mul_by_scalar_frame_tests!(
         is_x86_feature_detected!("avx2"),
         f32,
-        crate::x86_64_fma::by_scalar::x86_64_avx_f32_mul_by_scalar_32n
+        x86_64_avx_f32_mul_by_scalar_32n
     );
 }

From ab407697372a7d89bf5410f99315f081e5a0198b Mon Sep 17 00:00:00 2001
From: Emrick Sinitambirivoutin 
Date: Thu, 10 Oct 2024 11:45:02 -0400
Subject: [PATCH 20/32] Fix linalg tests

---
 linalg/src/arm64/arm64fp16/by_scalar.rs |  6 ++--
 linalg/src/arm64/arm64fp16/unicast.rs   |  4 ++-
 linalg/src/arm64/arm64simd/by_scalar.rs |  6 ++--
 linalg/src/arm64/arm64simd/unicast.rs   |  4 ++-
 linalg/src/frame/by_scalar.rs           | 11 ++++---
 linalg/src/frame/unicast/mod.rs         | 39 ++++++++++++++++++++++
 linalg/src/frame/unicast/mul.rs         | 44 -------------------------
 linalg/src/generic/by_scalar.rs         | 12 +++----
 linalg/src/generic/unicast.rs           | 12 +++----
 linalg/src/x86_64_fma/by_scalar.rs      |  3 +-
 10 files changed, 71 insertions(+), 70 deletions(-)

diff --git a/linalg/src/arm64/arm64fp16/by_scalar.rs b/linalg/src/arm64/arm64fp16/by_scalar.rs
index f10c51ed6d..36c1e0c4bc 100644
--- a/linalg/src/arm64/arm64fp16/by_scalar.rs
+++ b/linalg/src/arm64/arm64fp16/by_scalar.rs
@@ -105,7 +105,7 @@ by_scalar_impl_wrap!(
 #[cfg(test)]
 mod test_arm64fp16_mul_by_scalar_f16_32n {
     use super::*;
-    mul_by_scalar_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_mul_by_scalar_f16_32n);
-    mul_by_scalar_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_add_by_scalar_f16_32n);
-    mul_by_scalar_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_sub_by_scalar_f16_32n);
+    by_scalar_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_mul_by_scalar_f16_32n, |a, b| a * b);
+    by_scalar_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_add_by_scalar_f16_32n, |a, b| a + b);
+    by_scalar_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_sub_by_scalar_f16_32n, |a, b| a - b);
 }
diff --git a/linalg/src/arm64/arm64fp16/unicast.rs b/linalg/src/arm64/arm64fp16/unicast.rs
index 40b8b995ce..c266bb7b95 100644
--- a/linalg/src/arm64/arm64fp16/unicast.rs
+++ b/linalg/src/arm64/arm64fp16/unicast.rs
@@ -113,5 +113,7 @@ unicast_impl_wrap!(
 mod test_arm64fp16_unicast_mul_f16_32n {
     use super::*;
     use proptest::strategy::Strategy;
-    crate::unicast_mul_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_unicast_mul_f16_32n);
+    crate::unicast_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_unicast_mul_f16_32n, |a, b| a * b);
+    crate::unicast_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_unicast_add_f16_32n, |a, b| a + b);
+    crate::unicast_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_unicast_sub_f16_32n, |a, b| a - b);
 }
diff --git a/linalg/src/arm64/arm64simd/by_scalar.rs b/linalg/src/arm64/arm64simd/by_scalar.rs
index db1928008a..000745e5ea 100644
--- a/linalg/src/arm64/arm64simd/by_scalar.rs
+++ b/linalg/src/arm64/arm64simd/by_scalar.rs
@@ -97,7 +97,7 @@ 
by_scalar_impl_wrap!(
 #[cfg(test)]
 mod test_arm64simd_mul_by_scalar_f32_16n {
     use super::*;
-    mul_by_scalar_frame_tests!(true, f32, arm64simd_mul_by_scalar_f32_16n);
-    mul_by_scalar_frame_tests!(true, f32, arm64simd_add_by_scalar_f32_16n);
-    mul_by_scalar_frame_tests!(true, f32, arm64simd_sub_by_scalar_f32_16n);
+    by_scalar_frame_tests!(true, f32, arm64simd_mul_by_scalar_f32_16n, |a, b| a * b);
+    by_scalar_frame_tests!(true, f32, arm64simd_add_by_scalar_f32_16n, |a, b| a + b);
+    by_scalar_frame_tests!(true, f32, arm64simd_sub_by_scalar_f32_16n, |a, b| a - b);
 }
diff --git a/linalg/src/arm64/arm64simd/unicast.rs b/linalg/src/arm64/arm64simd/unicast.rs
index 381e3592b5..d225496f57 100644
--- a/linalg/src/arm64/arm64simd/unicast.rs
+++ b/linalg/src/arm64/arm64simd/unicast.rs
@@ -109,5 +109,7 @@ unicast_impl_wrap!(
 mod test_arm64simd_unicast_mul_f32_16n {
     use super::*;
     use proptest::strategy::Strategy;
-    crate::unicast_mul_frame_tests!(true, f32, arm64simd_unicast_mul_f32_16n);
+    crate::unicast_frame_tests!(true, f32, arm64simd_unicast_mul_f32_16n, |a, b| a * b);
+    crate::unicast_frame_tests!(true, f32, arm64simd_unicast_add_f32_16n, |a, b| a + b);
+    crate::unicast_frame_tests!(true, f32, arm64simd_unicast_sub_f32_16n, |a, b| a - b);
 }
diff --git a/linalg/src/frame/by_scalar.rs b/linalg/src/frame/by_scalar.rs
index 5611b7f579..50dcf9a45e 100644
--- a/linalg/src/frame/by_scalar.rs
+++ b/linalg/src/frame/by_scalar.rs
@@ -65,14 +65,14 @@ pub mod test {
     use proptest::test_runner::TestCaseResult;
 
     #[macro_export]
-    macro_rules! mul_by_scalar_frame_tests {
-        ($cond:expr, $t: ty, $ker:ty) => {
+    macro_rules! by_scalar_frame_tests {
+        ($cond:expr, $t: ty, $ker:ty, $func:expr) => {
             paste::paste! {
                 proptest::proptest! {
                     #[test]
                     fn [<prop_ $ker:snake>](xs in proptest::collection::vec(-25f32..25.0, 0..100), scalar in -25f32..25f32) {
                         if $cond {
-                            $crate::frame::by_scalar::test::test_mul_by_scalar::<$ker, $t>(&*xs, scalar).unwrap()
+                            $crate::frame::by_scalar::test::test_by_scalar::<$ker, $t>(&*xs, scalar, $func).unwrap()
                         }
                     }
                 }
             }
         };
     }
 
-    pub fn test_mul_by_scalar<K: ByScalarKer<T>, T: LADatum + Float>(
+    pub fn test_by_scalar<K: ByScalarKer<T>, T: LADatum + Float>(
         values: &[f32],
         scalar: f32,
+        func: impl Fn(T, T) -> T,
     ) -> TestCaseResult
     where
         f32: AsPrimitive<T>,
@@
         let values: Vec<T> = values.iter().copied().map(|x| x.as_()).collect();
         crate::frame::element_wise::test::test_element_wise_params::<K, T, T>(
             &values,
-            |a| a * scalar.as_(),
+            |a| (func)(a, scalar.as_()),
             scalar.as_(),
         )
     }
diff --git a/linalg/src/frame/unicast/mod.rs b/linalg/src/frame/unicast/mod.rs
index 1dfd60cb2d..2c54bc49ca 100644
--- a/linalg/src/frame/unicast/mod.rs
+++ b/linalg/src/frame/unicast/mod.rs
@@ -180,11 +180,13 @@ where
 }
 
 #[cfg(test)]
+#[macro_use]
 pub mod test {
     use super::*;
     use crate::LADatum;
     use proptest::test_runner::{TestCaseError, TestCaseResult};
     use tract_data::internal::*;
+    use tract_num_traits::{AsPrimitive, Float};
 
     pub fn test_unicast<K: UnicastKer<T>, T: LADatum>(
         a: &[T],
@@ -201,4 +203,41 @@ pub mod test {
             .map_err(|e| TestCaseError::fail(e.root_cause().to_string()))?;
         Ok(())
     }
+
+    pub fn test_unicast_t<K: UnicastKer<T>, T: LADatum + Float>(a: &[f32], b: &[f32], func: impl Fn(T, T) -> T) -> TestCaseResult
+    where
+        f32: AsPrimitive<T>,
+        T: AsPrimitive<f32>,
+    {
+        crate::setup_test_logger();
+        let a: Vec<T> = a.iter().copied().map(|x| x.as_()).collect();
+        let b: Vec<T> = b.iter().copied().map(|x| x.as_()).collect();
+        crate::frame::unicast::test::test_unicast::<K, T>(&a, &b, func)
+    }
+
+    #[macro_export]
+    macro_rules! 
unicast_frame_tests {
+        ($cond:expr, $t: ty, $ker:ty, $func:expr) => {
+            paste::paste! {
+                proptest::proptest! {
+                    #[test]
+                    fn [<prop_ $ker:snake>](
+                        (a, b) in (0..100_usize).prop_flat_map(|len| (vec![-25f32..25.0; len], vec![-25f32..25.0; len]))
+                    ) {
+                        if $cond {
+                            $crate::frame::unicast::test::test_unicast_t::<$ker, $t>(&*a, &*b, $func).unwrap()
+                        }
+                    }
+                }
+
+            #[test]
+            fn [<empty_ $ker:snake>]() {
+                if $cond {
+                    $crate::frame::unicast::test::test_unicast_t::<$ker, $t>(&[], &[], $func).unwrap()
+                }
+            }
+        }
+        };
+    }
+
 }
diff --git a/linalg/src/frame/unicast/mul.rs b/linalg/src/frame/unicast/mul.rs
index 9e5958de3c..e69de29bb2 100644
--- a/linalg/src/frame/unicast/mul.rs
+++ b/linalg/src/frame/unicast/mul.rs
@@ -1,44 +0,0 @@
-#[cfg(test)]
-#[macro_use]
-pub mod test {
-    use crate::frame::unicast::UnicastKer;
-    use crate::LADatum;
-    use num_traits::{AsPrimitive, Float};
-    use proptest::test_runner::TestCaseResult;
-
-    #[macro_export]
-    macro_rules! unicast_mul_frame_tests {
-        ($cond:expr, $t: ty, $ker:ty) => {
-            paste::paste! {
-                proptest::proptest! {
-                    #[test]
-                    fn [<prop_ $ker:snake>](
-                        (a, b) in (0..100_usize).prop_flat_map(|len| (vec![-25f32..25.0; len], vec![-25f32..25.0; len]))
-                    ) {
-                        if $cond {
-                            $crate::frame::unicast::mul::test::test_unicast_mul::<$ker, $t>(&*a, &*b).unwrap()
-                        }
-                    }
-                }
-
-                #[test]
-                fn [<empty_ $ker:snake>]() {
-                    if $cond {
-                        $crate::frame::unicast::mul::test::test_unicast_mul::<$ker, $t>(&[], &[]).unwrap()
-                    }
-                }
-            }
-        };
-    }
-
-    pub fn test_unicast_mul<K: UnicastKer<T>, T: LADatum + Float>(a: &[f32], b: &[f32]) -> TestCaseResult
-    where
-        f32: AsPrimitive<T>,
-        T: AsPrimitive<f32>,
-    {
-        crate::setup_test_logger();
-        let a: Vec<T> = a.iter().copied().map(|x| x.as_()).collect();
-        let b: Vec<T> = b.iter().copied().map(|x| x.as_()).collect();
-        crate::frame::unicast::test::test_unicast::<K, T>(&a, &b, |a, b| a * b)
-    }
-}
diff --git a/linalg/src/generic/by_scalar.rs b/linalg/src/generic/by_scalar.rs
index 6d5fddc622..743ec0b8b7 100644
--- a/linalg/src/generic/by_scalar.rs
+++ b/linalg/src/generic/by_scalar.rs
@@ -43,9 +43,9 @@ by_scalar_impl_wrap!(
 #[macro_use]
 pub mod mul_by_scalar_f32 {
     use super::*;
-    mul_by_scalar_frame_tests!(true, f32, SMulByScalar4);
-    mul_by_scalar_frame_tests!(true, f32, SAddByScalar4);
-    mul_by_scalar_frame_tests!(true, f32, SSubByScalar4);
+    by_scalar_frame_tests!(true, f32, SMulByScalar4, |a, b| a * b);
+    by_scalar_frame_tests!(true, f32, SAddByScalar4, |a, b| a + b );
+    by_scalar_frame_tests!(true, f32, SSubByScalar4, |a, b| a - b);
 }
 
@@ -91,7 +91,7 @@ by_scalar_impl_wrap!(
 #[macro_use]
 pub mod mul_by_scalar_f16 {
     use super::*;
-    mul_by_scalar_frame_tests!(true, f16, HMulByScalar8);
-    mul_by_scalar_frame_tests!(true, f16, HAddByScalar8);
-    mul_by_scalar_frame_tests!(true, f16, HSubByScalar8);
+    by_scalar_frame_tests!(true, f16, HMulByScalar8, |a, b| a * b);
+    by_scalar_frame_tests!(true, f16, HAddByScalar8, |a, b| a + b);
+    by_scalar_frame_tests!(true, f16, HSubByScalar8, |a, b| a - b);
 }
diff --git a/linalg/src/generic/unicast.rs b/linalg/src/generic/unicast.rs
index 439f089945..6e93161eec 100644
--- a/linalg/src/generic/unicast.rs
+++ b/linalg/src/generic/unicast.rs
@@ -88,9 +88,9 @@ unicast_impl_wrap!(
 pub mod s {
     use super::*;
     use proptest::strategy::Strategy;
-    crate::unicast_mul_frame_tests!(true, f32, SUnicastMul4);
-    crate::unicast_mul_frame_tests!(true, f32, SUnicastAdd4);
-    crate::unicast_mul_frame_tests!(true, f32, SUnicastSub4);
+    crate::unicast_frame_tests!(true, f32, SUnicastMul4, |a, b| a * b);
+    crate::unicast_frame_tests!(true, f32, SUnicastAdd4, |a, b| a + b);
+    crate::unicast_frame_tests!(true,
f32, SUnicastSub4, |a, b| a - b); } #[cfg(test)] @@ -98,7 +98,7 @@ pub mod s { pub mod h { use super::*; use proptest::strategy::Strategy; - crate::unicast_mul_frame_tests!(true, f16, HUnicastMul8); - crate::unicast_mul_frame_tests!(true, f16, HUnicastAdd8); - crate::unicast_mul_frame_tests!(true, f16, HUnicastSub8); + crate::unicast_frame_tests!(true, f16, HUnicastMul8, |a, b| a * b); + crate::unicast_frame_tests!(true, f16, HUnicastAdd8, |a, b| a + b); + crate::unicast_frame_tests!(true, f16, HUnicastSub8, |a, b| a - b); } diff --git a/linalg/src/x86_64_fma/by_scalar.rs b/linalg/src/x86_64_fma/by_scalar.rs index ca461ab981..1764372b1d 100644 --- a/linalg/src/x86_64_fma/by_scalar.rs +++ b/linalg/src/x86_64_fma/by_scalar.rs @@ -45,9 +45,10 @@ unsafe fn x86_64_avx_f32_mul_by_scalar_32n_run(buf: &mut [f32], scalar: f32) { #[macro_use] pub mod test_x86_64_avx_f32_mul_by_scalar_32n { use super::*; - mul_by_scalar_frame_tests!( + by_scalar_frame_tests!( is_x86_feature_detected!("avx2"), f32, x86_64_avx_f32_mul_by_scalar_32n + |a, b| a * b ); } From 43e8d5614f6c569a37c02f6ac3b21edc0ac594c2 Mon Sep 17 00:00:00 2001 From: Emrick Sinitambirivoutin Date: Thu, 10 Oct 2024 11:49:14 -0400 Subject: [PATCH 21/32] File renaming --- linalg/src/frame/{unicast/mod.rs => unicast.rs} | 2 -- linalg/src/frame/unicast/mul.rs | 0 2 files changed, 2 deletions(-) rename linalg/src/frame/{unicast/mod.rs => unicast.rs} (99%) delete mode 100644 linalg/src/frame/unicast/mul.rs diff --git a/linalg/src/frame/unicast/mod.rs b/linalg/src/frame/unicast.rs similarity index 99% rename from linalg/src/frame/unicast/mod.rs rename to linalg/src/frame/unicast.rs index 2c54bc49ca..ba6c54c2b1 100644 --- a/linalg/src/frame/unicast/mod.rs +++ b/linalg/src/frame/unicast.rs @@ -1,5 +1,3 @@ -pub mod mul; - use std::fmt::Debug; use std::marker::PhantomData; diff --git a/linalg/src/frame/unicast/mul.rs b/linalg/src/frame/unicast/mul.rs deleted file mode 100644 index e69de29bb2..0000000000 From 6de56ca3b5e1466d1420f816212818cc28552b06 Mon Sep 17 00:00:00 2001 From: Emrick Sinitambirivoutin Date: Thu, 10 Oct 2024 12:07:10 -0400 Subject: [PATCH 22/32] Fix typo --- linalg/src/x86_64_fma/by_scalar.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/linalg/src/x86_64_fma/by_scalar.rs b/linalg/src/x86_64_fma/by_scalar.rs index 1764372b1d..cd3322e2a1 100644 --- a/linalg/src/x86_64_fma/by_scalar.rs +++ b/linalg/src/x86_64_fma/by_scalar.rs @@ -48,7 +48,7 @@ pub mod test_x86_64_avx_f32_mul_by_scalar_32n { by_scalar_frame_tests!( is_x86_feature_detected!("avx2"), f32, - x86_64_avx_f32_mul_by_scalar_32n + x86_64_avx_f32_mul_by_scalar_32n, |a, b| a * b ); } From c660be8b298562a36089081b775f80742350a700 Mon Sep 17 00:00:00 2001 From: Emrick Sinitambirivoutin Date: Thu, 10 Oct 2024 12:13:22 -0400 Subject: [PATCH 23/32] Avoid axes swap for Scale --- core/src/ops/quant.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/core/src/ops/quant.rs b/core/src/ops/quant.rs index c2cac1fc1b..c78ef20b5c 100644 --- a/core/src/ops/quant.rs +++ b/core/src/ops/quant.rs @@ -319,6 +319,10 @@ impl crate::ops::binary::BinMiniOp for Scale { Ok(()) } + fn is_commutative(&self) -> bool { + false + } + fn declutter( &self, model: &TypedModel, From 7c1ec5b3a02f26f3781f5fe6e4a77b711bac8270 Mon Sep 17 00:00:00 2001 From: Emrick Sinitambirivoutin Date: Thu, 10 Oct 2024 12:43:22 -0400 Subject: [PATCH 24/32] Remove tmp bin_1 method --- linalg/src/arm64.rs | 26 ++++++++++++-------------- linalg/src/frame/by_scalar.rs | 2 +- linalg/src/frame/unicast.rs | 7 ++----- 
linalg/src/generic.rs | 24 ++++++++++++------------ linalg/src/lib.rs | 7 +------ 5 files changed, 28 insertions(+), 38 deletions(-) diff --git a/linalg/src/arm64.rs b/linalg/src/arm64.rs index fb69166ec3..3762ec92cf 100644 --- a/linalg/src/arm64.rs +++ b/linalg/src/arm64.rs @@ -215,21 +215,21 @@ impl Kind { } pub(crate)fn register_all_unicast(registry: &mut LinalgRegistry) { - registry.insert((BinOp::Mul, DatumType::F32),Box::new(|| arm64simd_unicast_mul_f32_16n::bin_1())); - registry.insert((BinOp::Mul, DatumType::F16),Box::new(|| arm64fp16_unicast_mul_f16_32n::bin_1())); - registry.insert((BinOp::Add, DatumType::F32),Box::new(|| arm64simd_unicast_add_f32_16n::bin_1())); - registry.insert((BinOp::Add, DatumType::F16),Box::new(|| arm64fp16_unicast_add_f16_32n::bin_1())); - registry.insert((BinOp::Sub, DatumType::F32),Box::new(|| arm64simd_unicast_sub_f32_16n::bin_1())); - registry.insert((BinOp::Sub, DatumType::F16),Box::new(|| arm64fp16_unicast_sub_f16_32n::bin_1())); + registry.insert((BinOp::Mul, DatumType::F32),Box::new(|| arm64simd_unicast_mul_f32_16n::bin())); + registry.insert((BinOp::Mul, DatumType::F16),Box::new(|| arm64fp16_unicast_mul_f16_32n::bin())); + registry.insert((BinOp::Add, DatumType::F32),Box::new(|| arm64simd_unicast_add_f32_16n::bin())); + registry.insert((BinOp::Add, DatumType::F16),Box::new(|| arm64fp16_unicast_add_f16_32n::bin())); + registry.insert((BinOp::Sub, DatumType::F32),Box::new(|| arm64simd_unicast_sub_f32_16n::bin())); + registry.insert((BinOp::Sub, DatumType::F16),Box::new(|| arm64fp16_unicast_sub_f16_32n::bin())); } pub(crate)fn register_all_by_scalar(registry: &mut LinalgRegistry) { - registry.insert((BinOp::Mul, DatumType::F32),Box::new(|| arm64simd_mul_by_scalar_f32_16n::bin_1())); - registry.insert((BinOp::Mul, DatumType::F16),Box::new(|| arm64fp16_mul_by_scalar_f16_32n::bin_1())); - registry.insert((BinOp::Add, DatumType::F32),Box::new(|| arm64simd_add_by_scalar_f32_16n::bin_1())); - registry.insert((BinOp::Add, DatumType::F16),Box::new(|| arm64fp16_add_by_scalar_f16_32n::bin_1())); - registry.insert((BinOp::Sub, DatumType::F32),Box::new(|| arm64simd_sub_by_scalar_f32_16n::bin_1())); - registry.insert((BinOp::Sub, DatumType::F16),Box::new(|| arm64fp16_sub_by_scalar_f16_32n::bin_1())); + registry.insert((BinOp::Mul, DatumType::F32),Box::new(|| arm64simd_mul_by_scalar_f32_16n::bin())); + registry.insert((BinOp::Mul, DatumType::F16),Box::new(|| arm64fp16_mul_by_scalar_f16_32n::bin())); + registry.insert((BinOp::Add, DatumType::F32),Box::new(|| arm64simd_add_by_scalar_f32_16n::bin())); + registry.insert((BinOp::Add, DatumType::F16),Box::new(|| arm64fp16_add_by_scalar_f16_32n::bin())); + registry.insert((BinOp::Sub, DatumType::F32),Box::new(|| arm64simd_sub_by_scalar_f32_16n::bin())); + registry.insert((BinOp::Sub, DatumType::F16),Box::new(|| arm64fp16_sub_by_scalar_f16_32n::bin())); } pub fn plug(ops: &mut Ops) { @@ -326,7 +326,6 @@ pub fn plug(ops: &mut Ops) { ops.tanh_f32 = Box::new(|| arm64simd_tanh_f32_4n::ew()); ops.max_f32 = Box::new(|| arm64simd_max_f32_16n::red()); ops.sum_f32 = Box::new(|| arm64simd_sum_f32_16n::red()); - ops.unicast_mul_f32 = Box::new(|| arm64simd_unicast_mul_f32_16n::bin()); ops.mul_by_scalar_f32 = Box::new(|| arm64simd_mul_by_scalar_f32_16n::ew()); ops.softmax2_fastcompact_f32 = Box::new(|| arm64simd_softmax2_fastcompact_f32_16n::red()); #[cfg(not(feature = "no_fp16"))] @@ -337,7 +336,6 @@ pub fn plug(ops: &mut Ops) { ops.sigmoid_f16 = Box::new(|| arm64fp16_sigmoid_f16_8n::ew()); ops.max_f16 = Box::new(|| 
arm64fp16_max_f16_32n::red());
         ops.sum_f16 = Box::new(|| arm64fp16_sum_f16_32n::red());
-        ops.unicast_mul_f16 = Box::new(|| arm64fp16_unicast_mul_f16_32n::bin());
         ops.mul_by_scalar_f16 = Box::new(|| arm64fp16_mul_by_scalar_f16_32n::ew());
     } else {
         log::info!("No native fp16 support");
diff --git a/linalg/src/frame/by_scalar.rs b/linalg/src/frame/by_scalar.rs
index 50dcf9a45e..ad01cfc216 100644
--- a/linalg/src/frame/by_scalar.rs
+++ b/linalg/src/frame/by_scalar.rs
@@ -36,7 +36,7 @@ pub trait ByScalarKer<T>: ElementWiseKer<T>
 where
     T: LADatum
 {
-    fn bin_1() -> Box<LinalgFn> {
+    fn bin() -> Box<LinalgFn> {
         Box::new(|a: &mut TensorView, b: &TensorView| {
             let a_slice = a.as_slice_mut()?;
             let b = b.as_slice()?[0];
diff --git a/linalg/src/frame/unicast.rs b/linalg/src/frame/unicast.rs
index ba6c54c2b1..c134df42f6 100644
--- a/linalg/src/frame/unicast.rs
+++ b/linalg/src/frame/unicast.rs
@@ -82,14 +82,11 @@ where
     fn alignment_items() -> usize;
     fn nr() -> usize;
     fn run(a: &mut [T], b: &[T]);
-    fn bin() -> Box<dyn Unicast<T>> {
-        Box::new(UnicastImpl::<K, T>::new())
-    }
-    fn bin_1() -> Box<LinalgFn> {
+    fn bin() -> Box<LinalgFn> {
        Box::new(|a: &mut TensorView, b: &TensorView| {
             let a_slice = a.as_slice_mut()?;
             let b_slice = b.as_slice()?;
-            Self::bin().run(a_slice, b_slice)
+            UnicastImpl::<K, T>::new().run(a_slice, b_slice)
         })
     }
 }
diff --git a/linalg/src/generic.rs b/linalg/src/generic.rs
index 1a075caa14..5b7d158797 100644
--- a/linalg/src/generic.rs
+++ b/linalg/src/generic.rs
@@ -23,19 +23,19 @@ pub use self::sigmoid::{HSigmoid8, SSigmoid4};
 pub use self::tanh::{HTanh8, STanh4};
 
 pub(crate) fn register_all_unicast(registry: &mut LinalgRegistry) {
-    registry.insert((BinOp::Mul, DatumType::F32), Box::new(|| unicast::SUnicastMul4::bin_1()));
-    registry.insert((BinOp::Mul, DatumType::F16), Box::new(|| unicast::HUnicastMul8::bin_1()));
-    registry.insert((BinOp::Add, DatumType::F32), Box::new(|| unicast::SUnicastAdd4::bin_1()));
-    registry.insert((BinOp::Add, DatumType::F16), Box::new(|| unicast::HUnicastAdd8::bin_1()));
-    registry.insert((BinOp::Sub, DatumType::F32), Box::new(|| unicast::SUnicastSub4::bin_1()));
-    registry.insert((BinOp::Sub, DatumType::F16), Box::new(|| unicast::HUnicastSub8::bin_1()));
+    registry.insert((BinOp::Mul, DatumType::F32), Box::new(|| unicast::SUnicastMul4::bin()));
+    registry.insert((BinOp::Mul, DatumType::F16), Box::new(|| unicast::HUnicastMul8::bin()));
+    registry.insert((BinOp::Add, DatumType::F32), Box::new(|| unicast::SUnicastAdd4::bin()));
+    registry.insert((BinOp::Add, DatumType::F16), Box::new(|| unicast::HUnicastAdd8::bin()));
+    registry.insert((BinOp::Sub, DatumType::F32), Box::new(|| unicast::SUnicastSub4::bin()));
+    registry.insert((BinOp::Sub, DatumType::F16), Box::new(|| unicast::HUnicastSub8::bin()));
 }
 
 pub(crate) fn register_all_by_scalar(registry: &mut LinalgRegistry) {
-    registry.insert((BinOp::Mul, DatumType::F32), Box::new(|| by_scalar::SMulByScalar4::bin_1()));
-    registry.insert((BinOp::Mul, DatumType::F16), Box::new(|| by_scalar::HMulByScalar8::bin_1()));
-    registry.insert((BinOp::Add, DatumType::F32), Box::new(|| by_scalar::SAddByScalar4::bin_1()));
-    registry.insert((BinOp::Add, DatumType::F16), Box::new(|| by_scalar::HAddByScalar8::bin_1()));
-    registry.insert((BinOp::Sub, DatumType::F32), Box::new(|| by_scalar::SSubByScalar4::bin_1()));
-    registry.insert((BinOp::Sub, DatumType::F16), Box::new(|| by_scalar::HSubByScalar8::bin_1()));
+    registry.insert((BinOp::Mul, DatumType::F32), Box::new(|| by_scalar::SMulByScalar4::bin()));
+    registry.insert((BinOp::Mul, DatumType::F16), Box::new(|| by_scalar::HMulByScalar8::bin()));
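+    // One constructor per (operation, element type) key; `bin()` wraps the
+    // typed kernel behind the type-erased `Box<LinalgFn>` interface. The
+    // arch-specific registries (e.g. arm64) insert over the same keys, so
+    // these generic versions act as portable fallbacks.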
+    registry.insert((BinOp::Add, DatumType::F32), Box::new(|| by_scalar::SAddByScalar4::bin()));
+    registry.insert((BinOp::Add, DatumType::F16), Box::new(|| by_scalar::HAddByScalar8::bin()));
+    registry.insert((BinOp::Sub, DatumType::F32), Box::new(|| by_scalar::SSubByScalar4::bin()));
+    registry.insert((BinOp::Sub, DatumType::F16), Box::new(|| by_scalar::HSubByScalar8::bin()));
 }
diff --git a/linalg/src/lib.rs b/linalg/src/lib.rs
index 39195593ae..f66db9e117 100644
--- a/linalg/src/lib.rs
+++ b/linalg/src/lib.rs
@@ -25,7 +25,7 @@ use frame::by_scalar::ByScalarKer;
 use frame::element_wise::ElementWiseKer;
 use frame::reduce::{MapReduceKer, ReduceKer};
 use frame::unicast::UnicastKer;
-use frame::{reduce, unicast, MatMatMul};
+use frame::{reduce, MatMatMul};
 pub use generic::{ScaleShiftAndRound, Scaler};
 use lazy_static::lazy_static;
 use tract_data::internal::TensorView;
@@ -87,9 +87,6 @@ pub struct Ops {
     pub sum_f16: Box<dyn Fn() -> Box<dyn Reducer<f16, f16>> + Send + Sync>,
     pub sum_f32: Box<dyn Fn() -> Box<dyn Reducer<f32, f32>> + Send + Sync>,
 
-    pub unicast_mul_f16: Box<dyn Fn() -> Box<dyn Unicast<f16>> + Send + Sync>,
-    pub unicast_mul_f32: Box<dyn Fn() -> Box<dyn Unicast<f32>> + Send + Sync>,
-
     pub softmax2_fastcompact_f16:
         Box<dyn Fn() -> Box<dyn MapReducer<f16, f16>> + Send + Sync>,
     pub softmax2_fastcompact_f32:
@@ -147,8 +144,6 @@ pub fn generic() -> Ops {
         max_f32: Box::new(|| generic::reduce::max::SMax4::red()),
         sum_f16: Box::new(|| generic::reduce::sum::HSum8::red()),
         sum_f32: Box::new(|| generic::reduce::sum::SSum4::red()),
-        unicast_mul_f16: Box::new(|| generic::unicast::HUnicastMul8::bin()),
-        unicast_mul_f32: Box::new(|| generic::unicast::SUnicastMul4::bin()),
         /*
         activation_f32: Box::new(|microcode| generic::SActivation::new(microcode))
         */

From 00d977448fbd5ec9ef4a36b6f27d28cdec1710ba Mon Sep 17 00:00:00 2001
From: Emrick Sinitambirivoutin 
Date: Thu, 10 Oct 2024 15:53:22 -0400
Subject: [PATCH 25/32] Add remaining BinOp kernels (Min, Max, SubF)

---
 linalg/src/arm64.rs                     |  12 +++
 linalg/src/arm64/arm64fp16.rs           |  12 +--
 linalg/src/arm64/arm64fp16/by_scalar.rs | 105 ++++++++++++++++++++
 linalg/src/arm64/arm64fp16/unicast.rs   | 111 ++++++++++++++++++++++
 linalg/src/arm64/arm64simd.rs           |  12 +--
 linalg/src/arm64/arm64simd/by_scalar.rs |  99 +++++++++++++++++++
 linalg/src/arm64/arm64simd/unicast.rs   | 108 +++++++++++++++++++++
 linalg/src/generic.rs                   |  12 +++
 linalg/src/generic/by_scalar.rs         |  84 ++++++++++++++++
 linalg/src/generic/unicast.rs           |  90 +++++++++++++++-

diff --git a/linalg/src/arm64.rs b/linalg/src/arm64.rs
index 3762ec92cf..d9419dae8f 100644
--- a/linalg/src/arm64.rs
+++ b/linalg/src/arm64.rs
@@ -221,6 +221,12 @@ pub(crate)fn register_all_unicast(registry: &mut LinalgRegistry) {
     registry.insert((BinOp::Add, DatumType::F16),Box::new(|| arm64fp16_unicast_add_f16_32n::bin()));
     registry.insert((BinOp::Sub, DatumType::F32),Box::new(|| arm64simd_unicast_sub_f32_16n::bin()));
     registry.insert((BinOp::Sub, DatumType::F16),Box::new(|| arm64fp16_unicast_sub_f16_32n::bin()));
+    registry.insert((BinOp::SubF, DatumType::F32),Box::new(|| arm64simd_unicast_subf_f32_16n::bin()));
+    registry.insert((BinOp::SubF, DatumType::F16),Box::new(|| arm64fp16_unicast_subf_f16_32n::bin()));
+    registry.insert((BinOp::Min, DatumType::F32),Box::new(|| arm64simd_unicast_min_f32_16n::bin()));
+    registry.insert((BinOp::Min, DatumType::F16),Box::new(|| arm64fp16_unicast_min_f16_32n::bin()));
+    registry.insert((BinOp::Max, DatumType::F32),Box::new(|| arm64simd_unicast_max_f32_16n::bin()));
+    registry.insert((BinOp::Max, DatumType::F16),Box::new(|| arm64fp16_unicast_max_f16_32n::bin()));
 }
 
 pub(crate)fn
register_all_by_scalar(registry: &mut LinalgRegistry) { @@ -230,6 +236,12 @@ pub(crate)fn register_all_by_scalar(registry: &mut LinalgRegistry) { registry.insert((BinOp::Add, DatumType::F16),Box::new(|| arm64fp16_add_by_scalar_f16_32n::bin())); registry.insert((BinOp::Sub, DatumType::F32),Box::new(|| arm64simd_sub_by_scalar_f32_16n::bin())); registry.insert((BinOp::Sub, DatumType::F16),Box::new(|| arm64fp16_sub_by_scalar_f16_32n::bin())); + registry.insert((BinOp::SubF, DatumType::F32),Box::new(|| arm64simd_subf_by_scalar_f32_16n::bin())); + registry.insert((BinOp::SubF, DatumType::F16),Box::new(|| arm64fp16_subf_by_scalar_f16_32n::bin())); + registry.insert((BinOp::Min, DatumType::F32),Box::new(|| arm64simd_min_by_scalar_f32_16n::bin())); + registry.insert((BinOp::Min, DatumType::F16),Box::new(|| arm64fp16_min_by_scalar_f16_32n::bin())); + registry.insert((BinOp::Max, DatumType::F32),Box::new(|| arm64simd_max_by_scalar_f32_16n::bin())); + registry.insert((BinOp::Max, DatumType::F16),Box::new(|| arm64fp16_max_by_scalar_f16_32n::bin())); } pub fn plug(ops: &mut Ops) { diff --git a/linalg/src/arm64/arm64fp16.rs b/linalg/src/arm64/arm64fp16.rs index ef99502383..59ff8ded80 100644 --- a/linalg/src/arm64/arm64fp16.rs +++ b/linalg/src/arm64/arm64fp16.rs @@ -5,18 +5,10 @@ mod leaky_relu; mod max; mod unicast; mod sum; -pub use by_scalar::{ - arm64fp16_mul_by_scalar_f16_32n, - arm64fp16_add_by_scalar_f16_32n, - arm64fp16_sub_by_scalar_f16_32n -}; +pub use by_scalar::*; pub use leaky_relu::*; pub use max::*; -pub use unicast::{ - arm64fp16_unicast_mul_f16_32n, - arm64fp16_unicast_add_f16_32n, - arm64fp16_unicast_sub_f16_32n -}; +pub use unicast::*; pub use sum::*; use crate::frame::block_quant::Q4_0; diff --git a/linalg/src/arm64/arm64fp16/by_scalar.rs b/linalg/src/arm64/arm64fp16/by_scalar.rs index 36c1e0c4bc..f7156dd27d 100644 --- a/linalg/src/arm64/arm64fp16/by_scalar.rs +++ b/linalg/src/arm64/arm64fp16/by_scalar.rs @@ -102,10 +102,115 @@ by_scalar_impl_wrap!( } ); +by_scalar_impl_wrap!( + f16, + arm64fp16_subf_by_scalar_f16_32n, + 32, + 4, + f16, + fn run(buf: &mut [f16], s: f16) { + assert!(buf.len() % 16 == 0); + assert!(buf.len() > 0); + #[target_feature(enable = "fp16")] + unsafe fn run(buf: &mut[f16], s: f16) { + let len = buf.len(); + let ptr = buf.as_ptr(); + std::arch::asm!(" + dup v0.8h, v0.h[0] + 2: + ld1 {{v4.8h, v5.8h, v6.8h, v7.8h}}, [{ptr}] + fsub v4.8h, v0.8h, v4.8h + fsub v5.8h, v0.8h, v5.8h + fsub v6.8h, v0.8h, v6.8h + fsub v7.8h, v0.8h, v7.8h + st1 {{v4.8h, v5.8h, v6.8h, v7.8h}}, [{ptr}], 64 + subs {len}, {len}, 32 + bne 2b + ", + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + in("v0") s.to_bits(), + out("v4") _, out("v5") _, out("v6") _, out("v7") _,); + } + unsafe { run(buf, s) } + } +); + +by_scalar_impl_wrap!( + f16, + arm64fp16_min_by_scalar_f16_32n, + 32, + 4, + f16, + fn run(buf: &mut [f16], s: f16) { + assert!(buf.len() % 16 == 0); + assert!(buf.len() > 0); + #[target_feature(enable = "fp16")] + unsafe fn run(buf: &mut[f16], s: f16) { + let len = buf.len(); + let ptr = buf.as_ptr(); + std::arch::asm!(" + dup v0.8h, v0.h[0] + 2: + ld1 {{v4.8h, v5.8h, v6.8h, v7.8h}}, [{ptr}] + fmin v4.8h, v4.8h, v0.8h + fmin v5.8h, v5.8h, v0.8h + fmin v6.8h, v6.8h, v0.8h + fmin v7.8h, v7.8h, v0.8h + st1 {{v4.8h, v5.8h, v6.8h, v7.8h}}, [{ptr}], 64 + subs {len}, {len}, 32 + bne 2b + ", + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + in("v0") s.to_bits(), + out("v4") _, out("v5") _, out("v6") _, out("v7") _,); + } + unsafe { run(buf, s) } + } +); + 
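+// Note on the template shared by these by-scalar kernels: `dup v0.8h, v0.h[0]`
+// broadcasts the scalar operand across the eight f16 lanes of v0, then each
+// loop turn loads four vector registers (32 half-precision values), applies
+// the operation lane-wise against v0, stores the block back with a 64-byte
+// post-increment, and decrements the remaining length by 32.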
+by_scalar_impl_wrap!( + f16, + arm64fp16_max_by_scalar_f16_32n, + 32, + 4, + f16, + fn run(buf: &mut [f16], s: f16) { + assert!(buf.len() % 16 == 0); + assert!(buf.len() > 0); + #[target_feature(enable = "fp16")] + unsafe fn run(buf: &mut[f16], s: f16) { + let len = buf.len(); + let ptr = buf.as_ptr(); + std::arch::asm!(" + dup v0.8h, v0.h[0] + 2: + ld1 {{v4.8h, v5.8h, v6.8h, v7.8h}}, [{ptr}] + fmax v4.8h, v4.8h, v0.8h + fmax v5.8h, v5.8h, v0.8h + fmax v6.8h, v6.8h, v0.8h + fmax v7.8h, v7.8h, v0.8h + st1 {{v4.8h, v5.8h, v6.8h, v7.8h}}, [{ptr}], 64 + subs {len}, {len}, 32 + bne 2b + ", + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + in("v0") s.to_bits(), + out("v4") _, out("v5") _, out("v6") _, out("v7") _,); + } + unsafe { run(buf, s) } + } +); + #[cfg(test)] mod test_arm64fp16_mul_by_scalar_f16_32n { use super::*; by_scalar_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_mul_by_scalar_f16_32n, |a, b| a * b); by_scalar_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_add_by_scalar_f16_32n, |a, b| a + b); by_scalar_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_sub_by_scalar_f16_32n, |a, b| a - b); + by_scalar_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_subf_by_scalar_f16_32n, |a, b| b - a); + by_scalar_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_min_by_scalar_f16_32n, |a, b| a.min(b)); + by_scalar_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_max_by_scalar_f16_32n, |a, b| a.max(b)); } diff --git a/linalg/src/arm64/arm64fp16/unicast.rs b/linalg/src/arm64/arm64fp16/unicast.rs index c266bb7b95..680d7245b0 100644 --- a/linalg/src/arm64/arm64fp16/unicast.rs +++ b/linalg/src/arm64/arm64fp16/unicast.rs @@ -109,6 +109,114 @@ unicast_impl_wrap!( } ); +unicast_impl_wrap!( + f16, + arm64fp16_unicast_subf_f16_32n, + 32, + 8, + #[inline(never)] + fn run(a: &mut [f16], b: &[f16]) { + assert!(a.len() == b.len()); + assert!(a.len() % 32 == 0); + assert!(a.len() > 0); + #[target_feature(enable = "fp16")] + unsafe fn run(a: &mut [f16], b: &[f16]) { + let len = a.len(); + let a_ptr = a.as_ptr(); + let b_ptr = b.as_ptr(); + std::arch::asm!(" + 2: + ld1 {{v0.8h, v1.8h, v2.8h, v3.8h}}, [{a_ptr}] + ld1 {{v4.8h, v5.8h, v6.8h, v7.8h}}, [{b_ptr}], 64 + fsub v0.8h, v4.8h, v0.8h + fsub v1.8h, v5.8h, v1.8h + fsub v2.8h, v6.8h, v2.8h + fsub v3.8h, v7.8h, v3.8h + st1 {{v0.8h, v1.8h, v2.8h, v3.8h}}, [{a_ptr}], 64 + subs {len}, {len}, 32 + bne 2b + ", + len = inout(reg) len => _, + a_ptr = inout(reg) a_ptr => _, + b_ptr = inout(reg) b_ptr => _, + out("v0") _, out("v1") _, out("v2") _, out("v3") _,); + } + unsafe { run(a, b) } + } +); + +unicast_impl_wrap!( + f16, + arm64fp16_unicast_min_f16_32n, + 32, + 8, + #[inline(never)] + fn run(a: &mut [f16], b: &[f16]) { + assert!(a.len() == b.len()); + assert!(a.len() % 32 == 0); + assert!(a.len() > 0); + #[target_feature(enable = "fp16")] + unsafe fn run(a: &mut [f16], b: &[f16]) { + let len = a.len(); + let a_ptr = a.as_ptr(); + let b_ptr = b.as_ptr(); + std::arch::asm!(" + 2: + ld1 {{v0.8h, v1.8h, v2.8h, v3.8h}}, [{a_ptr}] + ld1 {{v4.8h, v5.8h, v6.8h, v7.8h}}, [{b_ptr}], 64 + fmin v0.8h, v0.8h, v4.8h + fmin v1.8h, v1.8h, v5.8h + fmin v2.8h, v2.8h, v6.8h + fmin v3.8h, v3.8h, v7.8h + st1 {{v0.8h, v1.8h, v2.8h, v3.8h}}, [{a_ptr}], 64 + subs {len}, {len}, 32 + bne 2b + ", + len = inout(reg) len => _, + a_ptr = inout(reg) a_ptr => _, + b_ptr = inout(reg) b_ptr => _, + out("v0") _, out("v1") _, out("v2") _, out("v3") _,); + } + unsafe { run(a, b) } + } +); + +unicast_impl_wrap!( + f16, + arm64fp16_unicast_max_f16_32n, + 32, + 
8, + #[inline(never)] + fn run(a: &mut [f16], b: &[f16]) { + assert!(a.len() == b.len()); + assert!(a.len() % 32 == 0); + assert!(a.len() > 0); + #[target_feature(enable = "fp16")] + unsafe fn run(a: &mut [f16], b: &[f16]) { + let len = a.len(); + let a_ptr = a.as_ptr(); + let b_ptr = b.as_ptr(); + std::arch::asm!(" + 2: + ld1 {{v0.8h, v1.8h, v2.8h, v3.8h}}, [{a_ptr}] + ld1 {{v4.8h, v5.8h, v6.8h, v7.8h}}, [{b_ptr}], 64 + fmax v0.8h, v0.8h, v4.8h + fmax v1.8h, v1.8h, v5.8h + fmax v2.8h, v2.8h, v6.8h + fmax v3.8h, v3.8h, v7.8h + st1 {{v0.8h, v1.8h, v2.8h, v3.8h}}, [{a_ptr}], 64 + subs {len}, {len}, 32 + bne 2b + ", + len = inout(reg) len => _, + a_ptr = inout(reg) a_ptr => _, + b_ptr = inout(reg) b_ptr => _, + out("v0") _, out("v1") _, out("v2") _, out("v3") _,); + } + unsafe { run(a, b) } + } +); + #[cfg(test)] mod test_arm64fp16_unicast_mul_f16_32n { use super::*; @@ -116,4 +224,7 @@ mod test_arm64fp16_unicast_mul_f16_32n { crate::unicast_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_unicast_mul_f16_32n, |a, b| a * b); crate::unicast_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_unicast_add_f16_32n, |a, b| a + b); crate::unicast_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_unicast_sub_f16_32n, |a, b| a - b); + crate::unicast_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_unicast_subf_f16_32n, |a, b| b - a); + crate::unicast_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_unicast_min_f16_32n, |a, b| a.min(b)); + crate::unicast_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_unicast_max_f16_32n, |a, b| a.max(b)); } diff --git a/linalg/src/arm64/arm64simd.rs b/linalg/src/arm64/arm64simd.rs index 75bec6d500..f302b0341d 100644 --- a/linalg/src/arm64/arm64simd.rs +++ b/linalg/src/arm64/arm64simd.rs @@ -7,18 +7,10 @@ mod sum; use crate::frame::PackedFormat; -pub use by_scalar::{ - arm64simd_mul_by_scalar_f32_16n, - arm64simd_add_by_scalar_f32_16n, - arm64simd_sub_by_scalar_f32_16n -}; +pub use by_scalar::*; pub use leaky_relu::arm64simd_leaky_relu_f32_8n; pub use max::arm64simd_max_f32_16n; -pub use unicast::{ - arm64simd_unicast_mul_f32_16n, - arm64simd_unicast_add_f32_16n, - arm64simd_unicast_sub_f32_16n -}; +pub use unicast::*; pub use softmax::arm64simd_softmax2_fastcompact_f32_16n; pub use sum::arm64simd_sum_f32_16n; diff --git a/linalg/src/arm64/arm64simd/by_scalar.rs b/linalg/src/arm64/arm64simd/by_scalar.rs index 000745e5ea..49b2c7550b 100644 --- a/linalg/src/arm64/arm64simd/by_scalar.rs +++ b/linalg/src/arm64/arm64simd/by_scalar.rs @@ -94,10 +94,109 @@ by_scalar_impl_wrap!( } ); +by_scalar_impl_wrap!( + f32, + arm64simd_subf_by_scalar_f32_16n, + 16, + 4, + f32, + fn run(buf: &mut [f32], s: f32) { + assert!(buf.len() % 16 == 0); + assert!(buf.len() > 0); + unsafe { + let len = buf.len(); + let ptr = buf.as_ptr(); + std::arch::asm!(" + dup v0.4s, v0.s[0] + 2: + ld1 {{v4.4s, v5.4s, v6.4s, v7.4s}}, [{ptr}] + fsub v4.4s, v0.4s, v4.4s + fsub v5.4s, v0.4s, v5.4s + fsub v6.4s, v0.4s, v6.4s + fsub v7.4s, v0.4s, v7.4s + st1 {{v4.4s, v5.4s, v6.4s, v7.4s}}, [{ptr}], 64 + subs {len}, {len}, 16 + bne 2b + ", + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + in("v0") s, + out("v4") _, out("v5") _, out("v6") _, out("v7") _,); + } + } +); + +by_scalar_impl_wrap!( + f32, + arm64simd_min_by_scalar_f32_16n, + 16, + 4, + f32, + fn run(buf: &mut [f32], s: f32) { + assert!(buf.len() % 16 == 0); + assert!(buf.len() > 0); + unsafe { + let len = buf.len(); + let ptr = buf.as_ptr(); + std::arch::asm!(" + dup v0.4s, v0.s[0] + 2: + ld1 {{v4.4s, v5.4s, v6.4s, 
v7.4s}}, [{ptr}] + fmin v4.4s, v4.4s, v0.4s + fmin v5.4s, v5.4s, v0.4s + fmin v6.4s, v6.4s, v0.4s + fmin v7.4s, v7.4s, v0.4s + st1 {{v4.4s, v5.4s, v6.4s, v7.4s}}, [{ptr}], 64 + subs {len}, {len}, 16 + bne 2b + ", + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + in("v0") s, + out("v4") _, out("v5") _, out("v6") _, out("v7") _,); + } + } +); + +by_scalar_impl_wrap!( + f32, + arm64simd_max_by_scalar_f32_16n, + 16, + 4, + f32, + fn run(buf: &mut [f32], s: f32) { + assert!(buf.len() % 16 == 0); + assert!(buf.len() > 0); + unsafe { + let len = buf.len(); + let ptr = buf.as_ptr(); + std::arch::asm!(" + dup v0.4s, v0.s[0] + 2: + ld1 {{v4.4s, v5.4s, v6.4s, v7.4s}}, [{ptr}] + fmax v4.4s, v4.4s, v0.4s + fmax v5.4s, v5.4s, v0.4s + fmax v6.4s, v6.4s, v0.4s + fmax v7.4s, v7.4s, v0.4s + st1 {{v4.4s, v5.4s, v6.4s, v7.4s}}, [{ptr}], 64 + subs {len}, {len}, 16 + bne 2b + ", + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + in("v0") s, + out("v4") _, out("v5") _, out("v6") _, out("v7") _,); + } + } +); + #[cfg(test)] mod test_arm64simd_mul_by_scalar_f32_16n { use super::*; by_scalar_frame_tests!(true, f32, arm64simd_mul_by_scalar_f32_16n, |a, b| a * b); by_scalar_frame_tests!(true, f32, arm64simd_add_by_scalar_f32_16n, |a, b| a + b); by_scalar_frame_tests!(true, f32, arm64simd_sub_by_scalar_f32_16n, |a, b| a - b); + by_scalar_frame_tests!(true, f32, arm64simd_subf_by_scalar_f32_16n, |a, b| b - a); + by_scalar_frame_tests!(true, f32, arm64simd_min_by_scalar_f32_16n, |a, b| a.min(b)); + by_scalar_frame_tests!(true, f32, arm64simd_max_by_scalar_f32_16n, |a, b| a.max(b)); } diff --git a/linalg/src/arm64/arm64simd/unicast.rs b/linalg/src/arm64/arm64simd/unicast.rs index d225496f57..228d701447 100644 --- a/linalg/src/arm64/arm64simd/unicast.rs +++ b/linalg/src/arm64/arm64simd/unicast.rs @@ -105,6 +105,111 @@ unicast_impl_wrap!( } ); +unicast_impl_wrap!( + f32, + arm64simd_unicast_subf_f32_16n, + 16, + 4, + #[inline(never)] + fn run(a: &mut [f32], b: &[f32]) { + assert!(a.len() == b.len()); + assert!(a.len() % 16 == 0); + assert!(a.len() > 0); + unsafe fn run(a: &mut [f32], b: &[f32]) { + let len = a.len(); + let a_ptr = a.as_ptr(); + let b_ptr = b.as_ptr(); + std::arch::asm!(" + 2: + ld1 {{v0.4s, v1.4s, v2.4s, v3.4s}}, [{a_ptr}] + ld1 {{v4.4s, v5.4s, v6.4s, v7.4s}}, [{b_ptr}], 64 + fsub v0.4s, v4.4s, v0.4s + fsub v1.4s, v5.4s, v1.4s + fsub v2.4s, v6.4s, v2.4s + fsub v3.4s, v7.4s, v3.4s + st1 {{v0.4s, v1.4s, v2.4s, v3.4s}}, [{a_ptr}], 64 + subs {len}, {len}, 16 + bne 2b + ", + len = inout(reg) len => _, + a_ptr = inout(reg) a_ptr => _, + b_ptr = inout(reg) b_ptr => _, + out("v0") _, out("v1") _, out("v2") _, out("v3") _,); + } + unsafe { run(a, b) } + } +); + +unicast_impl_wrap!( + f32, + arm64simd_unicast_max_f32_16n, + 16, + 4, + #[inline(never)] + fn run(a: &mut [f32], b: &[f32]) { + assert!(a.len() == b.len()); + assert!(a.len() % 16 == 0); + assert!(a.len() > 0); + unsafe fn run(a: &mut [f32], b: &[f32]) { + let len = a.len(); + let a_ptr = a.as_ptr(); + let b_ptr = b.as_ptr(); + std::arch::asm!(" + 2: + ld1 {{v0.4s, v1.4s, v2.4s, v3.4s}}, [{a_ptr}] + ld1 {{v4.4s, v5.4s, v6.4s, v7.4s}}, [{b_ptr}], 64 + fmax v0.4s, v0.4s, v4.4s + fmax v1.4s, v1.4s, v5.4s + fmax v2.4s, v2.4s, v6.4s + fmax v3.4s, v3.4s, v7.4s + st1 {{v0.4s, v1.4s, v2.4s, v3.4s}}, [{a_ptr}], 64 + subs {len}, {len}, 16 + bne 2b + ", + len = inout(reg) len => _, + a_ptr = inout(reg) a_ptr => _, + b_ptr = inout(reg) b_ptr => _, + out("v0") _, out("v1") _, out("v2") _, out("v3") _,); + } + unsafe { run(a, b) } + } +); + 
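+// As with the by-scalar kernels, the unicast kernels share one template: each
+// loop turn loads 16 f32 values from `a` into v0-v3 and 16 from `b` into
+// v4-v7 (only the `b` load post-increments its pointer), combines the two
+// register groups lane-wise, then the store back through `a_ptr` advances it
+// by the same 64 bytes before the remaining length is decremented by 16.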
+unicast_impl_wrap!( + f32, + arm64simd_unicast_min_f32_16n, + 16, + 4, + #[inline(never)] + fn run(a: &mut [f32], b: &[f32]) { + assert!(a.len() == b.len()); + assert!(a.len() % 16 == 0); + assert!(a.len() > 0); + unsafe fn run(a: &mut [f32], b: &[f32]) { + let len = a.len(); + let a_ptr = a.as_ptr(); + let b_ptr = b.as_ptr(); + std::arch::asm!(" + 2: + ld1 {{v0.4s, v1.4s, v2.4s, v3.4s}}, [{a_ptr}] + ld1 {{v4.4s, v5.4s, v6.4s, v7.4s}}, [{b_ptr}], 64 + fmin v0.4s, v0.4s, v4.4s + fmin v1.4s, v1.4s, v5.4s + fmin v2.4s, v2.4s, v6.4s + fmin v3.4s, v3.4s, v7.4s + st1 {{v0.4s, v1.4s, v2.4s, v3.4s}}, [{a_ptr}], 64 + subs {len}, {len}, 16 + bne 2b + ", + len = inout(reg) len => _, + a_ptr = inout(reg) a_ptr => _, + b_ptr = inout(reg) b_ptr => _, + out("v0") _, out("v1") _, out("v2") _, out("v3") _,); + } + unsafe { run(a, b) } + } +); + #[cfg(test)] mod test_arm64simd_unicast_mul_f32_16n { use super::*; @@ -112,4 +217,7 @@ mod test_arm64simd_unicast_mul_f32_16n { crate::unicast_frame_tests!(true, f32, arm64simd_unicast_mul_f32_16n, |a, b| a * b); crate::unicast_frame_tests!(true, f32, arm64simd_unicast_add_f32_16n, |a, b| a + b); crate::unicast_frame_tests!(true, f32, arm64simd_unicast_sub_f32_16n, |a, b| a - b); + crate::unicast_frame_tests!(true, f32, arm64simd_unicast_subf_f32_16n, |a, b| b - a); + crate::unicast_frame_tests!(true, f32, arm64simd_unicast_min_f32_16n, |a, b| a.min(b)); + crate::unicast_frame_tests!(true, f32, arm64simd_unicast_max_f32_16n, |a, b| a.max(b)); } diff --git a/linalg/src/generic.rs b/linalg/src/generic.rs index 5b7d158797..70ea15acd8 100644 --- a/linalg/src/generic.rs +++ b/linalg/src/generic.rs @@ -29,6 +29,12 @@ pub(crate) fn register_all_unicast(registry: &mut LinalgRegistry) { registry.insert((BinOp::Add, DatumType::F16), Box::new(|| unicast::HUnicastAdd8::bin())); registry.insert((BinOp::Sub, DatumType::F32), Box::new(|| unicast::SUnicastSub4::bin())); registry.insert((BinOp::Sub, DatumType::F16), Box::new(|| unicast::HUnicastSub8::bin())); + registry.insert((BinOp::SubF, DatumType::F32), Box::new(|| unicast::SUnicastSubF4::bin())); + registry.insert((BinOp::SubF, DatumType::F16), Box::new(|| unicast::HUnicastSubF8::bin())); + registry.insert((BinOp::Min, DatumType::F32), Box::new(|| unicast::SUnicastMin4::bin())); + registry.insert((BinOp::Min, DatumType::F16), Box::new(|| unicast::HUnicastMin8::bin())); + registry.insert((BinOp::Max, DatumType::F32), Box::new(|| unicast::SUnicastMax4::bin())); + registry.insert((BinOp::Max, DatumType::F16), Box::new(|| unicast::HUnicastMax8::bin())); } pub(crate) fn register_all_by_scalar(registry: &mut LinalgRegistry) { @@ -38,4 +44,10 @@ pub(crate) fn register_all_by_scalar(registry: &mut LinalgRegistry) { registry.insert((BinOp::Add, DatumType::F16), Box::new(|| by_scalar::HAddByScalar8::bin())); registry.insert((BinOp::Sub, DatumType::F32), Box::new(|| by_scalar::SSubByScalar4::bin())); registry.insert((BinOp::Sub, DatumType::F16), Box::new(|| by_scalar::HSubByScalar8::bin())); + registry.insert((BinOp::SubF, DatumType::F32), Box::new(|| by_scalar::SSubFByScalar4::bin())); + registry.insert((BinOp::SubF, DatumType::F16), Box::new(|| by_scalar::HSubFByScalar8::bin())); + registry.insert((BinOp::Min, DatumType::F32), Box::new(|| by_scalar::SMinByScalar4::bin())); + registry.insert((BinOp::Min, DatumType::F16), Box::new(|| by_scalar::HMinByScalar8::bin())); + registry.insert((BinOp::Max, DatumType::F32), Box::new(|| by_scalar::SMaxByScalar4::bin())); + registry.insert((BinOp::Max, DatumType::F16), Box::new(|| 
by_scalar::HMaxByScalar8::bin())); } diff --git a/linalg/src/generic/by_scalar.rs b/linalg/src/generic/by_scalar.rs index 743ec0b8b7..1aaab592fb 100644 --- a/linalg/src/generic/by_scalar.rs +++ b/linalg/src/generic/by_scalar.rs @@ -39,6 +39,45 @@ by_scalar_impl_wrap!( } ); +by_scalar_impl_wrap!( + f32, + SSubFByScalar4, + 4, + 4, + f32, + fn run(x: &mut [f32], s: f32) { + debug_assert!(x.len() % Self::nr() == 0); + debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0); + x.iter_mut().for_each(|px| *px = s - *px) + } +); + +by_scalar_impl_wrap!( + f32, + SMinByScalar4, + 4, + 4, + f32, + fn run(x: &mut [f32], s: f32) { + debug_assert!(x.len() % Self::nr() == 0); + debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0); + x.iter_mut().for_each(|px| *px = px.min(s)) + } +); + +by_scalar_impl_wrap!( + f32, + SMaxByScalar4, + 4, + 4, + f32, + fn run(x: &mut [f32], s: f32) { + debug_assert!(x.len() % Self::nr() == 0); + debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0); + x.iter_mut().for_each(|px| *px = px.max(s)) + } +); + #[cfg(test)] #[macro_use] pub mod mul_by_scalar_f32 { @@ -46,6 +85,9 @@ pub mod mul_by_scalar_f32 { by_scalar_frame_tests!(true, f32, SMulByScalar4, |a, b| a * b); by_scalar_frame_tests!(true, f32, SAddByScalar4, |a, b| a + b ); by_scalar_frame_tests!(true, f32, SSubByScalar4, |a, b| a - b); + by_scalar_frame_tests!(true, f32, SSubFByScalar4, |a, b| b - a); + by_scalar_frame_tests!(true, f32, SMinByScalar4, |a, b| a.min(b)); + by_scalar_frame_tests!(true, f32, SMaxByScalar4, |a, b| a.max(b)); } by_scalar_impl_wrap!( @@ -87,6 +129,45 @@ by_scalar_impl_wrap!( } ); +by_scalar_impl_wrap!( + f16, + HSubFByScalar8, + 8, + 8, + f16, + fn run(x: &mut [f16], s: f16) { + debug_assert!(x.len() % Self::nr() == 0); + debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0); + x.iter_mut().for_each(|px| *px = s - *px) + } +); + +by_scalar_impl_wrap!( + f16, + HMinByScalar8, + 8, + 8, + f16, + fn run(x: &mut [f16], s: f16) { + debug_assert!(x.len() % Self::nr() == 0); + debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0); + x.iter_mut().for_each(|px| *px = px.min(s)) + } +); + +by_scalar_impl_wrap!( + f16, + HMaxByScalar8, + 8, + 8, + f16, + fn run(x: &mut [f16], s: f16) { + debug_assert!(x.len() % Self::nr() == 0); + debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0); + x.iter_mut().for_each(|px| *px = px.max(s)) + } +); + #[cfg(test)] #[macro_use] pub mod mul_by_scalar_f16 { @@ -94,4 +175,7 @@ pub mod mul_by_scalar_f16 { by_scalar_frame_tests!(true, f16, HMulByScalar8, |a, b| a * b); by_scalar_frame_tests!(true, f16, HAddByScalar8, |a, b| a + b); by_scalar_frame_tests!(true, f16, HSubByScalar8, |a, b| a - b); + by_scalar_frame_tests!(true, f16, HSubFByScalar8, |a, b| b - a); + by_scalar_frame_tests!(true, f16, HMinByScalar8, |a, b| a.min(b)); + by_scalar_frame_tests!(true, f16, HMaxByScalar8, |a, b| a.max(b)); } diff --git a/linalg/src/generic/unicast.rs b/linalg/src/generic/unicast.rs index 6e93161eec..2d7d4875b8 100644 --- a/linalg/src/generic/unicast.rs +++ b/linalg/src/generic/unicast.rs @@ -83,6 +83,90 @@ unicast_impl_wrap!( } ); +unicast_impl_wrap!( + f32, + SUnicastSubF4, + 4, + 4, + fn run(a: &mut [f32], b: &[f32]) { + debug_assert!(a.len() == b.len()); + debug_assert!(a.len() % Self::nr() == 0); + debug_assert!(a.as_ptr() as usize % Self::alignment_bytes() == 0); + debug_assert!(b.as_ptr() as usize % Self::alignment_bytes() == 0); + a.iter_mut().zip(b.iter()).for_each(|(a, b)| *a = *b - *a) + } +); + 
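+// SubF is the flipped subtraction: the destination slice receives b - a
+// rather than a - b, which is why the reference closures in the test modules
+// below are written `|a, b| b - a`. These generic kernels are the portable
+// counterparts of the arm64 implementations registered earlier in this patch.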
+unicast_impl_wrap!(
+    f16,
+    HUnicastSubF8,
+    8,
+    8,
+    fn run(a: &mut [f16], b: &[f16]) {
+        debug_assert!(a.len() == b.len());
+        debug_assert!(a.len() % Self::nr() == 0);
+        debug_assert!(a.as_ptr() as usize % Self::alignment_bytes() == 0);
+        debug_assert!(b.as_ptr() as usize % Self::alignment_bytes() == 0);
+        a.iter_mut().zip(b.iter()).for_each(|(a, b)| *a = *b - *a)
+    }
+);
+
+unicast_impl_wrap!(
+    f32,
+    SUnicastMin4,
+    4,
+    4,
+    fn run(a: &mut [f32], b: &[f32]) {
+        debug_assert!(a.len() == b.len());
+        debug_assert!(a.len() % Self::nr() == 0);
+        debug_assert!(a.as_ptr() as usize % Self::alignment_bytes() == 0);
+        debug_assert!(b.as_ptr() as usize % Self::alignment_bytes() == 0);
+        a.iter_mut().zip(b.iter()).for_each(|(a, b)| *a = a.min(*b))
+    }
+);
+
+unicast_impl_wrap!(
+    f16,
+    HUnicastMin8,
+    8,
+    8,
+    fn run(a: &mut [f16], b: &[f16]) {
+        debug_assert!(a.len() == b.len());
+        debug_assert!(a.len() % Self::nr() == 0);
+        debug_assert!(a.as_ptr() as usize % Self::alignment_bytes() == 0);
+        debug_assert!(b.as_ptr() as usize % Self::alignment_bytes() == 0);
+        a.iter_mut().zip(b.iter()).for_each(|(a, b)| *a = a.min(*b))
+    }
+);
+
+unicast_impl_wrap!(
+    f32,
+    SUnicastMax4,
+    4,
+    4,
+    fn run(a: &mut [f32], b: &[f32]) {
+        debug_assert!(a.len() == b.len());
+        debug_assert!(a.len() % Self::nr() == 0);
+        debug_assert!(a.as_ptr() as usize % Self::alignment_bytes() == 0);
+        debug_assert!(b.as_ptr() as usize % Self::alignment_bytes() == 0);
+        a.iter_mut().zip(b.iter()).for_each(|(a, b)| *a = a.max(*b))
+    }
+);
+
+unicast_impl_wrap!(
+    f16,
+    HUnicastMax8,
+    8,
+    8,
+    fn run(a: &mut [f16], b: &[f16]) {
+        debug_assert!(a.len() == b.len());
+        debug_assert!(a.len() % Self::nr() == 0);
+        debug_assert!(a.as_ptr() as usize % Self::alignment_bytes() == 0);
+        debug_assert!(b.as_ptr() as usize % Self::alignment_bytes() == 0);
+        a.iter_mut().zip(b.iter()).for_each(|(a, b)| *a = a.max(*b))
+    }
+);
+
 #[cfg(test)]
 #[macro_use]
 pub mod s {
@@ -91,6 +175,9 @@ pub mod s {
     crate::unicast_frame_tests!(true, f32, SUnicastMul4, |a, b| a * b);
     crate::unicast_frame_tests!(true, f32, SUnicastAdd4, |a, b| a + b);
     crate::unicast_frame_tests!(true, f32, SUnicastSub4, |a, b| a - b);
+    crate::unicast_frame_tests!(true, f32, SUnicastSubF4, |a, b| b - a);
+    crate::unicast_frame_tests!(true, f32, SUnicastMin4, |a, b| a.min(b));
+    crate::unicast_frame_tests!(true, f32, SUnicastMax4, |a, b| a.max(b));
 }
 
 #[cfg(test)]
@@ -101,4 +188,7 @@ pub mod h {
     crate::unicast_frame_tests!(true, f16, HUnicastMul8, |a, b| a * b);
     crate::unicast_frame_tests!(true, f16, HUnicastAdd8, |a, b| a + b);
     crate::unicast_frame_tests!(true, f16, HUnicastSub8, |a, b| a - b);
+    crate::unicast_frame_tests!(true, f16, HUnicastSubF8, |a, b| b - a);
+    crate::unicast_frame_tests!(true, f16, HUnicastMin8, |a, b| a.min(b));
+    crate::unicast_frame_tests!(true, f16, HUnicastMax8, |a, b| a.max(b));
 }

From f4ea5c58069db657c87bb59b9a9c07d7f3812044 Mon Sep 17 00:00:00 2001
From: Emrick Sinitambirivoutin
Date: Mon, 14 Oct 2024 16:53:15 +0200
Subject: [PATCH 26/32] Fix tensor alignment

---
 data/src/tensor.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/data/src/tensor.rs b/data/src/tensor.rs
index d54d02adc0..5d8ca07a2f 100644
--- a/data/src/tensor.rs
+++ b/data/src/tensor.rs
@@ -178,7 +178,7 @@ impl Tensor {
     /// Create an uninitialized tensor (dt as regular parameter).
     #[inline]
     pub unsafe fn uninitialized_dt(dt: DatumType, shape: &[usize]) -> TractResult<Tensor> {
-        Self::uninitialized_aligned_dt(dt, shape, dt.alignment())
+        Self::uninitialized_aligned_dt(dt, shape, Self::default_alignment(dt, shape))
     }
 
     /// Create an uninitialized tensor with a given alignment (in bytes).

From de83827086706b4b22a6c6495de373ad9c0f478b Mon Sep 17 00:00:00 2001
From: Emrick Sinitambirivoutin
Date: Mon, 14 Oct 2024 17:28:20 +0200
Subject: [PATCH 27/32] Fix unicast alignment issue

---
 core/src/ops/binary.rs      | 12 ++++++++++-
 linalg/src/frame/unicast.rs | 43 ++++++++++++++-----------------------
 2 files changed, 27 insertions(+), 28 deletions(-)

diff --git a/core/src/ops/binary.rs b/core/src/ops/binary.rs
index 0cd10d7fe8..736cee1bd2 100644
--- a/core/src/ops/binary.rs
+++ b/core/src/ops/binary.rs
@@ -371,7 +371,8 @@ fn find_most_efficient_config(model: &TypedModel, node: &TypedNode) -> TractResult<(bool, bool)
     };
 
     let unicast_is_possible = OptBinUnicast::check_input_shapes(&a_shape, &b_shape);
-    let num_unicast_elements = if unicast_is_possible {
+    let unicast_is_aligned = OptBinUnicast::check_b_alignement(&b_shape);
+    let num_unicast_elements = if unicast_is_possible && unicast_is_aligned {
         a_shape
             .iter()
             .zip(b_shape.iter())
@@ -507,6 +508,15 @@ impl Debug for OptBinUnicast {
 }
 
 impl OptBinUnicast {
+    fn check_b_alignement(b_shape: &[TDim]) -> bool {
+        let num_element = b_shape.iter().product::<TDim>();
+        if let Ok(num_element) = num_element.to_i64() {
+            let required_alignment = vector_size();
+            (num_element as usize % required_alignment) == 0
+        } else {
+            false
+        }
+    }
     fn check_input_shapes(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
         if a_shape.len() != b_shape.len() {
             return false;
         }

diff --git a/linalg/src/frame/unicast.rs b/linalg/src/frame/unicast.rs
index c134df42f6..c14cf368cf 100644
--- a/linalg/src/frame/unicast.rs
+++ b/linalg/src/frame/unicast.rs
@@ -95,24 +95,6 @@ std::thread_local! {
     static TMP: std::cell::RefCell<(TempBuffer, TempBuffer)> = std::cell::RefCell::new((TempBuffer::default(), TempBuffer::default()));
 }
 
-fn create_incomplete_tile<'a, T: LADatum>(
-    a: &'a mut [T],
-    b: &'a [T],
-    a_prefix_len: usize,
-    b_prefix_len: usize,
-) -> (&'a mut [T], &'a [T], usize) {
-    let effective_prefix = if (a_prefix_len == 0) || (b_prefix_len == 0) {
-        // One of the two slice is aligned, the target size is the number of unaligned elements of
-        // the other slice, the max value between the two.
-        a_prefix_len.max(b_prefix_len)
-    } else {
-        // Both are unaligned, the minimal common subset is the one including elements from a and b
-        // so it's the min value between the two.
-        a_prefix_len.min(b_prefix_len)
-    };
-    (&mut a[..effective_prefix], &b[..effective_prefix], effective_prefix)
-}
-
 pub(crate) fn unicast_with_alignment<T>(
     a: &mut [T],
     b: &[T],
@@ -143,14 +125,18 @@ where
     let mut num_element_processed = 0;
     let a_prefix_len = a.as_ptr().align_offset(alignment_bytes).min(a.len());
     let b_prefix_len = b.as_ptr().align_offset(alignment_bytes).min(b.len());
+    assert!(
+        a_prefix_len == b_prefix_len,
+        "Both inputs should have the same alignment, got {a_prefix_len:?}, {b_prefix_len:?}"
+    );
     let mut applied_prefix_len = 0;
-    if (a_prefix_len > 0) || (b_prefix_len > 0) {
+    if a_prefix_len > 0 {
         // Incomplete tile needs to be created to process unaligned data.
- let (sub_a, sub_b, applied_prefix) = - create_incomplete_tile(a, b, a_prefix_len, b_prefix_len); - applied_prefix_len = applied_prefix; + let sub_a = &mut a[..a_prefix_len]; + let sub_b = &b[..a_prefix_len]; compute_via_temp_buffer(sub_a, sub_b); - num_element_processed += applied_prefix_len; + num_element_processed += a_prefix_len; + applied_prefix_len = a_prefix_len; } let num_complete_tiles = (a.len() - applied_prefix_len) / nr; @@ -198,8 +184,12 @@ pub mod test { .map_err(|e| TestCaseError::fail(e.root_cause().to_string()))?; Ok(()) } - - pub fn test_unicast_t, T: LADatum + Float>(a: &[f32], b: &[f32], func: impl Fn(T, T) -> T) -> TestCaseResult + + pub fn test_unicast_t, T: LADatum + Float>( + a: &[f32], + b: &[f32], + func: impl Fn(T, T) -> T, + ) -> TestCaseResult where f32: AsPrimitive, T: AsPrimitive, @@ -209,7 +199,7 @@ pub mod test { let b: Vec = b.iter().copied().map(|x| x.as_()).collect(); crate::frame::unicast::test::test_unicast::(&a, &b, func) } - + #[macro_export] macro_rules! unicast_frame_tests { ($cond:expr, $t: ty, $ker:ty, $func:expr) => { @@ -234,5 +224,4 @@ pub mod test { } }; } - } From fd1d5e4981b79a5197d4d880495bd4b84b04397a Mon Sep 17 00:00:00 2001 From: Emrick Sinitambirivoutin Date: Mon, 14 Oct 2024 17:03:01 -0400 Subject: [PATCH 28/32] Add fusing for OptBinUnicast & OptBinByScalar --- core/src/ops/binary.rs | 12 ++++++++---- core/src/ops/matmul/optimized.rs | 21 +++++++++++++++++++++ 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/core/src/ops/binary.rs b/core/src/ops/binary.rs index 736cee1bd2..cb7b45e306 100644 --- a/core/src/ops/binary.rs +++ b/core/src/ops/binary.rs @@ -403,8 +403,10 @@ pub struct OptBinByScalar { } impl Debug for OptBinByScalar { - fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { - unimplemented!() + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + f.debug_struct("OptBinByScalar") + .field("binop", &self.binop) + .finish() } } @@ -502,8 +504,10 @@ pub struct OptBinUnicast { } impl Debug for OptBinUnicast { - fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { - unimplemented!() + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + f.debug_struct("OptBinUnicast") + .field("binop", &self.binop) + .finish() } } diff --git a/core/src/ops/matmul/optimized.rs b/core/src/ops/matmul/optimized.rs index 4b4c1fcf26..1fe6912178 100644 --- a/core/src/ops/matmul/optimized.rs +++ b/core/src/ops/matmul/optimized.rs @@ -448,6 +448,7 @@ impl TypedOp for OptMatMul { } let succ = model.node(node.outputs[0].successors[0].node); let mut patch = TypedModelPatch::new(format!("fusing {succ}")); + if let Some(op) = succ.op_as::() { let mut binop = if let Some(op) = op.0.as_linalg_binop() { op } else { return Ok(None) }; @@ -458,6 +459,26 @@ impl TypedOp for OptMatMul { let other_outlet = succ.inputs[flipped as usize]; return self.fuse_binary(model, node, patch, other_outlet, binop); } + if let Some(op) = succ.op_as::() { + let mut binop = + if let Some(op) = op.binop.as_linalg_binop() { op } else { return Ok(None) }; + let flipped = succ.inputs[0].node == node.id; + if flipped { + binop = binop.flip(); + } + let other_outlet = succ.inputs[flipped as usize]; + return self.fuse_binary(model, node, patch, other_outlet, binop); + } + if let Some(op) = succ.op_as::() { + let mut binop = + if let Some(op) = op.binop.as_linalg_binop() { op } else { return Ok(None) }; + let flipped = succ.inputs[0].node == node.id; + if flipped { + 
binop = binop.flip(); + } + let other_outlet = succ.inputs[flipped as usize]; + return self.fuse_binary(model, node, patch, other_outlet, binop); + } if let Some(op) = succ.op_as::().map(|ew| ew.0.as_ref()) { if let Some(op) = op.downcast_ref::() { From 0e887d12a60464270933a7c8b691ba71bac789bc Mon Sep 17 00:00:00 2001 From: Emrick Sinitambirivoutin Date: Mon, 14 Oct 2024 17:10:21 -0400 Subject: [PATCH 29/32] Update expected for librispeech cli test --- .../mdl-en-2019-Q3-librispeech/expected | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/harness/pre-optimized-graphes/mdl-en-2019-Q3-librispeech/expected b/harness/pre-optimized-graphes/mdl-en-2019-Q3-librispeech/expected index ddeecddb92..8c5d7f19d1 100644 --- a/harness/pre-optimized-graphes/mdl-en-2019-Q3-librispeech/expected +++ b/harness/pre-optimized-graphes/mdl-en-2019-Q3-librispeech/expected @@ -188,7 +188,7 @@ graph network(input) -> (output) { i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024" = matmul(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_a", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_b.1", transposeA = true, transposeB = false); i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_c.0" = squeeze(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", axes = [1]); i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_c.1" = transpose(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_c.0", axes = [1, 0, 2]); - i"tap.tap.fastlstm1.c_init.0-35/0-105/0" = variable(label = "tap.tap.fastlstm1.c_init.0-35/0-105/0", shape = [1, 256]); + i"tap.tap.fastlstm1.c_init.0-35/0-100/0" = variable(label = "tap.tap.fastlstm1.c_init.0-35/0-100/0", shape = [1, 256]); i"tap.fastlstm1.r_init.0-36/0" = variable(label = "tap.fastlstm1.r_init.0-36/0", shape = [1, 128, 1]); i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice" = variable(label = "fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice", shape = [128, 256]); i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice" = variable(label = "fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice", shape = [128, 256]); @@ -203,7 +203,7 @@ graph network(input) -> (output) { i"fastlstm1.peephole0.mul.fix-rank-0-1" = variable(label = "fastlstm1.peephole0.mul.fix-rank-0-1", shape = [1, 256]); i"fastlstm1.peephole1.mul.fix-rank-0-1" = variable(label = "fastlstm1.peephole1.mul.fix-rank-0-1", shape = [1, 256]); i"fastlstm1.peephole2.mul.fix-rank-0-1" = variable(label = "fastlstm1.peephole2.mul.fix-rank-0-1", shape = [1, 256]); - ( i"fastlstm1.c_final", i"fastlstm1.c_final_1" ) = tract_core_scan(body = "scan_body_0", scan = [("fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_c.0", 0, 1), ("fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_c.1", 0, 1), 
("fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_c.1", 0, 1), ("fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_c.1", 0, 1)], full = [("fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice"), ("fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice"), ("fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice"), ("fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice"), ("fastlstm1.four_parts.split-1-over-1.0..256.slice", i"fastlstm1.four_parts.split-1-over-1.0..256.slice"), ("fastlstm1.four_parts.split-1-over-1.256..512.slice", i"fastlstm1.four_parts.split-1-over-1.256..512.slice"), ("fastlstm1.four_parts.split-1-over-1.512..768.slice", i"fastlstm1.four_parts.split-1-over-1.512..768.slice"), ("fastlstm1.four_parts.split-1-over-1.768..1024.slice", i"fastlstm1.four_parts.split-1-over-1.768..1024.slice"), ("fastlstm1.h_new.W.split-1-over-1.0..128.slice", i"fastlstm1.h_new.W.split-1-over-1.0..128.slice"), ("fastlstm1.h_new.split-1-over-1.0..128.slice", i"fastlstm1.h_new.split-1-over-1.0..128.slice"), ("fastlstm1.peephole0.mul.fix-rank-0-1", i"fastlstm1.peephole0.mul.fix-rank-0-1"), ("fastlstm1.peephole1.mul.fix-rank-0-1", i"fastlstm1.peephole1.mul.fix-rank-0-1"), ("fastlstm1.peephole2.mul.fix-rank-0-1", i"fastlstm1.peephole2.mul.fix-rank-0-1")], state = [("fastlstm1.c", i"tap.tap.fastlstm1.c_init.0-35/0-105/0", "fastlstm1.c_new"), ("fastlstm1.r", i"tap.fastlstm1.r_init.0-36/0", "fastlstm1.r_new")], output = [("fastlstm1.r_new", "full", 2, 1), ("fastlstm1.h_new.W.split-over-1.128..256.prop_axis.a.input_0", "full", 0, 1)], skip = 2, reset_every_turn = false); + ( i"fastlstm1.c_final", i"fastlstm1.c_final_1" ) = tract_core_scan(body = "scan_body_0", scan = [("fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_c.0", 0, 1), ("fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_c.1", 0, 1), ("fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_c.1", 0, 1), ("fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_c.1", 0, 1)], full = [("fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice"), ("fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice", 
i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice"), ("fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice"), ("fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice"), ("fastlstm1.four_parts.split-1-over-1.0..256.slice", i"fastlstm1.four_parts.split-1-over-1.0..256.slice"), ("fastlstm1.four_parts.split-1-over-1.256..512.slice", i"fastlstm1.four_parts.split-1-over-1.256..512.slice"), ("fastlstm1.four_parts.split-1-over-1.512..768.slice", i"fastlstm1.four_parts.split-1-over-1.512..768.slice"), ("fastlstm1.four_parts.split-1-over-1.768..1024.slice", i"fastlstm1.four_parts.split-1-over-1.768..1024.slice"), ("fastlstm1.h_new.W.split-1-over-1.0..128.slice", i"fastlstm1.h_new.W.split-1-over-1.0..128.slice"), ("fastlstm1.h_new.split-1-over-1.0..128.slice", i"fastlstm1.h_new.split-1-over-1.0..128.slice"), ("fastlstm1.peephole0.mul.fix-rank-0-1", i"fastlstm1.peephole0.mul.fix-rank-0-1"), ("fastlstm1.peephole1.mul.fix-rank-0-1", i"fastlstm1.peephole1.mul.fix-rank-0-1"), ("fastlstm1.peephole2.mul.fix-rank-0-1", i"fastlstm1.peephole2.mul.fix-rank-0-1")], state = [("fastlstm1.c", i"tap.tap.fastlstm1.c_init.0-35/0-100/0", "fastlstm1.c_new"), ("fastlstm1.r", i"tap.fastlstm1.r_init.0-36/0", "fastlstm1.r_new")], output = [("fastlstm1.r_new", "full", 2, 1), ("fastlstm1.h_new.W.split-over-1.128..256.prop_axis.a.input_0", "full", 0, 1)], skip = 2, reset_every_turn = false); i"fastlstm1.h_new.W.split-over-1.128..256.fix_a" = transpose(i"fastlstm1.c_final_1", axes = [1, 0, 2]); i"fastlstm1.h_new.W.split-over-1.128..256.fix_a.1" = unsqueeze(i"fastlstm1.h_new.W.split-over-1.128..256.fix_a", axes = [0]); i"fastlstm1.h_new.W.split-over-1.128..256.fix_b.1" = variable(label = "fastlstm1.h_new.W.split-over-1.128..256.fix_b.1", shape = [1, 1, 256, 128]); @@ -265,7 +265,7 @@ graph network(input) -> (output) { i"fastlstm2.peephole0.mul.fix-rank-0-1" = variable(label = "fastlstm2.peephole0.mul.fix-rank-0-1", shape = [1, 256]); i"fastlstm2.peephole1.mul.fix-rank-0-1" = variable(label = "fastlstm2.peephole1.mul.fix-rank-0-1", shape = [1, 256]); i"fastlstm2.peephole2.mul.fix-rank-0-1" = variable(label = "fastlstm2.peephole2.mul.fix-rank-0-1", shape = [1, 256]); - ( i"fastlstm2.c_final", i"fastlstm2.c_final_1" ) = tract_core_scan(body = "scan_body_1", scan = [("fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_c.0", 0, 1), ("fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_c.1", 0, 1), ("fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_c.1", 0, 1), ("fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_c.1", 0, 1)], full = [("fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice"), 
("fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice"), ("fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice"), ("fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice"), ("fastlstm2.four_parts.split-1-over-1.0..256.slice", i"fastlstm2.four_parts.split-1-over-1.0..256.slice"), ("fastlstm2.four_parts.split-1-over-1.256..512.slice", i"fastlstm2.four_parts.split-1-over-1.256..512.slice"), ("fastlstm2.four_parts.split-1-over-1.512..768.slice", i"fastlstm2.four_parts.split-1-over-1.512..768.slice"), ("fastlstm2.four_parts.split-1-over-1.768..1024.slice", i"fastlstm2.four_parts.split-1-over-1.768..1024.slice"), ("fastlstm2.h_new.W.split-1-over-1.0..128.slice", i"fastlstm2.h_new.W.split-1-over-1.0..128.slice"), ("fastlstm2.h_new.split-1-over-1.0..128.slice", i"fastlstm2.h_new.split-1-over-1.0..128.slice"), ("fastlstm2.peephole0.mul.fix-rank-0-1", i"fastlstm2.peephole0.mul.fix-rank-0-1"), ("fastlstm2.peephole1.mul.fix-rank-0-1", i"fastlstm2.peephole1.mul.fix-rank-0-1"), ("fastlstm2.peephole2.mul.fix-rank-0-1", i"fastlstm2.peephole2.mul.fix-rank-0-1")], state = [("fastlstm2.c", i"tap.tap.fastlstm1.c_init.0-35/0-105/0", "fastlstm2.c_new"), ("fastlstm2.r", i"tap.fastlstm1.r_init.0-36/0", "fastlstm2.r_new")], output = [("fastlstm2.r_new", "full", 2, 1), ("fastlstm2.h_new.W.split-over-1.128..256.prop_axis.a.input_0", "full", 0, 1)], skip = 6, reset_every_turn = false); + ( i"fastlstm2.c_final", i"fastlstm2.c_final_1" ) = tract_core_scan(body = "scan_body_1", scan = [("fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_c.0", 0, 1), ("fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_c.1", 0, 1), ("fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_c.1", 0, 1), ("fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_c.1", 0, 1)], full = [("fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice"), ("fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice"), ("fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice"), ("fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice"), ("fastlstm2.four_parts.split-1-over-1.0..256.slice", i"fastlstm2.four_parts.split-1-over-1.0..256.slice"), ("fastlstm2.four_parts.split-1-over-1.256..512.slice", i"fastlstm2.four_parts.split-1-over-1.256..512.slice"), 
("fastlstm2.four_parts.split-1-over-1.512..768.slice", i"fastlstm2.four_parts.split-1-over-1.512..768.slice"), ("fastlstm2.four_parts.split-1-over-1.768..1024.slice", i"fastlstm2.four_parts.split-1-over-1.768..1024.slice"), ("fastlstm2.h_new.W.split-1-over-1.0..128.slice", i"fastlstm2.h_new.W.split-1-over-1.0..128.slice"), ("fastlstm2.h_new.split-1-over-1.0..128.slice", i"fastlstm2.h_new.split-1-over-1.0..128.slice"), ("fastlstm2.peephole0.mul.fix-rank-0-1", i"fastlstm2.peephole0.mul.fix-rank-0-1"), ("fastlstm2.peephole1.mul.fix-rank-0-1", i"fastlstm2.peephole1.mul.fix-rank-0-1"), ("fastlstm2.peephole2.mul.fix-rank-0-1", i"fastlstm2.peephole2.mul.fix-rank-0-1")], state = [("fastlstm2.c", i"tap.tap.fastlstm1.c_init.0-35/0-100/0", "fastlstm2.c_new"), ("fastlstm2.r", i"tap.fastlstm1.r_init.0-36/0", "fastlstm2.r_new")], output = [("fastlstm2.r_new", "full", 2, 1), ("fastlstm2.h_new.W.split-over-1.128..256.prop_axis.a.input_0", "full", 0, 1)], skip = 6, reset_every_turn = false); i"output.affine.output.W.concat-einsum-k.0..128.fix_a" = variable(label = "output.affine.output.W.concat-einsum-k.0..128.fix_a", shape = [1, 1690, 128]); i"output.affine.output.W.concat-einsum-k.0..128" = matmul(i"fastlstm2.c_final", i"output.affine.output.W.concat-einsum-k.0..128.fix_a", transposeA = true, transposeB = true); i"output.affine.output.W.concat-einsum-k.0..128.fix_c.0" = squeeze(i"output.affine.output.W.concat-einsum-k.0..128", axes = [0]); From 1a11ff10c8fee2c8a5debef659f6c3b014daa6ba Mon Sep 17 00:00:00 2001 From: Emrick Sinitambirivoutin Date: Wed, 16 Oct 2024 15:40:01 +0200 Subject: [PATCH 30/32] Remove dbg in test --- linalg/src/frame/mmm/tests/packed_packed.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/linalg/src/frame/mmm/tests/packed_packed.rs b/linalg/src/frame/mmm/tests/packed_packed.rs index 99fc177a62..8d190fd87f 100644 --- a/linalg/src/frame/mmm/tests/packed_packed.rs +++ b/linalg/src/frame/mmm/tests/packed_packed.rs @@ -363,11 +363,8 @@ impl PackedPackedProblem { if !self.ker.is_supported_here() { return Ok(()); } - dbg!(self); let expected = self.reference()?; - dbg!(&expected); let found = self.run()?; - dbg!(&found); let app = if K::Acc::datum_type() == f16::datum_type() { Approximation::SuperApproximate } else { From dd240fe97ce20063c844e607446ec7bfcda01222 Mon Sep 17 00:00:00 2001 From: Emrick Sinitambirivoutin Date: Wed, 16 Oct 2024 17:15:45 +0200 Subject: [PATCH 31/32] Fix alignment issue in test --- linalg/src/frame/unicast.rs | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/linalg/src/frame/unicast.rs b/linalg/src/frame/unicast.rs index c14cf368cf..1c37450a24 100644 --- a/linalg/src/frame/unicast.rs +++ b/linalg/src/frame/unicast.rs @@ -170,16 +170,15 @@ pub mod test { use tract_num_traits::{AsPrimitive, Float}; pub fn test_unicast, T: LADatum>( - a: &[T], + a: &mut [T], b: &[T], reference: impl Fn(T, T) -> T, ) -> TestCaseResult { crate::setup_test_logger(); let op = UnicastImpl::::new(); let expected = a.iter().zip(b.iter()).map(|(a, b)| (reference)(*a, *b)).collect::>(); - let mut found = a.to_vec(); - op.run(&mut found, b).unwrap(); - tensor1(&found) + op.run(a, b).unwrap(); + tensor1(&a) .close_enough(&tensor1(&expected), true) .map_err(|e| TestCaseError::fail(e.root_cause().to_string()))?; Ok(()) @@ -195,9 +194,17 @@ pub mod test { T: AsPrimitive, { crate::setup_test_logger(); - let a: Vec = a.iter().copied().map(|x| x.as_()).collect(); - let b: Vec = b.iter().copied().map(|x| x.as_()).collect(); - 
crate::frame::unicast::test::test_unicast::<K, T>(&a, &b, func)
+        let vec_a: Vec<T> = a.iter().copied().map(|x| x.as_()).collect();
+        // We allocate a tensor to ensure allocation is done with alignment
+        let mut a = unsafe { Tensor::from_slice_align(vec_a.as_slice(), vector_size()).unwrap() };
+        let vec_b: Vec<T> = b.iter().copied().map(|x| x.as_()).collect();
+        // We allocate a tensor to ensure allocation is done with alignment
+        let b = unsafe { Tensor::from_slice_align(vec_b.as_slice(), vector_size()).unwrap() };
+        crate::frame::unicast::test::test_unicast::<K, T>(
+            a.as_slice_mut::<T>().unwrap(),
+            &b.as_slice::<T>().unwrap(),
+            func,
+        )
     }
 
     #[macro_export]

From d7b99ab144a483ce00898e0fd1fdab3c7478e13e Mon Sep 17 00:00:00 2001
From: Emrick Sinitambirivoutin
Date: Wed, 16 Oct 2024 15:41:18 -0400
Subject: [PATCH 32/32] Make check_b_alignement less strict

---
 core/src/ops/binary.rs | 41 ++++++++++++++++++++++++++++-------------
 1 file changed, 28 insertions(+), 13 deletions(-)

diff --git a/core/src/ops/binary.rs b/core/src/ops/binary.rs
index cb7b45e306..090e9bdb40 100644
--- a/core/src/ops/binary.rs
+++ b/core/src/ops/binary.rs
@@ -371,8 +371,7 @@ fn find_most_efficient_config(model: &TypedModel, node: &TypedNode) -> TractResult<(bool, bool)
     };
 
     let unicast_is_possible = OptBinUnicast::check_input_shapes(&a_shape, &b_shape);
-    let unicast_is_aligned = OptBinUnicast::check_b_alignement(&b_shape);
-    let num_unicast_elements = if unicast_is_possible && unicast_is_aligned {
+    let num_unicast_elements = if unicast_is_possible {
         a_shape
             .iter()
             .zip(b_shape.iter())
@@ -404,9 +403,7 @@ pub struct OptBinByScalar {
 
 impl Debug for OptBinByScalar {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
-        f.debug_struct("OptBinByScalar")
-            .field("binop", &self.binop)
-            .finish()
+        f.debug_struct("OptBinByScalar").field("binop", &self.binop).finish()
     }
 }
 
@@ -505,16 +502,31 @@ pub struct OptBinUnicast {
 
 impl Debug for OptBinUnicast {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
-        f.debug_struct("OptBinUnicast")
-            .field("binop", &self.binop)
-            .finish()
+        f.debug_struct("OptBinUnicast").field("binop", &self.binop).finish()
     }
 }
 
 impl OptBinUnicast {
-    fn check_b_alignement(b_shape: &[TDim]) -> bool {
-        let num_element = b_shape.iter().product::<TDim>();
-        if let Ok(num_element) = num_element.to_i64() {
+    fn check_b_alignement(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
+        let num_iterations: TDim = a_shape
+            .iter()
+            .zip(b_shape.iter())
+            .take_while(|(_, b_dim)| **b_dim == 1.to_dim())
+            .map(|(a_dim, _)| a_dim)
+            .product();
+
+        if num_iterations.is_one() {
+            return true;
+        }
+
+        let elements_per_iteration: TDim = a_shape
+            .iter()
+            .zip(b_shape.iter())
+            .skip_while(|(_, b_dim)| **b_dim == 1.to_dim())
+            .map(|(_, b_dim)| b_dim)
+            .product();
+
+        if let Ok(num_element) = elements_per_iteration.to_i64() {
             let required_alignment = vector_size();
             (num_element as usize % required_alignment) == 0
         } else {
@@ -526,11 +538,14 @@ impl OptBinUnicast {
             return false;
         };
 
-        a_shape
+        let unicast_possible = a_shape
             .iter()
             .zip(b_shape.iter())
             .skip_while(|(_, b_dim)| **b_dim == 1.to_dim())
-            .all(|(a_dim, b_dim)| a_dim == b_dim)
+            .all(|(a_dim, b_dim)| a_dim == b_dim);
+        let unicast_is_aligned = Self::check_b_alignement(a_shape, b_shape);
+
+        unicast_possible && unicast_is_aligned
     }
 }
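
A worked illustration of the relaxed check introduced in [PATCH 32/32]: the sketch below mirrors the logic of `check_b_alignement` over plain `usize` shapes instead of `TDim`, so the arithmetic is easy to follow. The free-function name and the explicit `vector_size` parameter are illustrative stand-ins (tract resolves the vector size per platform); this is a sketch of the logic, not code from the patch.

fn unicast_b_is_aligned(a_shape: &[usize], b_shape: &[usize], vector_size: usize) -> bool {
    // Leading axes where b is broadcast: each of them restarts a walk over
    // b's buffer, so alignment must hold at every restart point.
    let num_iterations: usize = a_shape
        .iter()
        .zip(b_shape)
        .take_while(|(_, b_dim)| **b_dim == 1)
        .map(|(a_dim, _)| a_dim)
        .product();
    if num_iterations == 1 {
        // b is traversed exactly once, starting from its (aligned) base pointer.
        return true;
    }
    // Elements consumed by one walk over b: if this is a multiple of the SIMD
    // vector size, every restart lands back on an aligned boundary.
    let elements_per_iteration: usize = b_shape.iter().skip_while(|d| **d == 1).product();
    elements_per_iteration % vector_size == 0
}

// For a = [2, 3, 8] against b = [1, 3, 8] with a vector size of 4:
// num_iterations = 2, elements_per_iteration = 3 * 8 = 24, and 24 % 4 == 0,
// so the unicast kernel is eligible; with b = [1, 3, 7] it would not be.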