Refactor binary #1468

Open
wants to merge 32 commits into base: main
32 commits
15c1418
shapes indicies iterator + collapse_axis in tensorview
emricksinisonos Jul 12, 2024
2c06fe7
Introduce BinOpByScalar & BinOpUnicast
emricksinisonos Jul 12, 2024
c7958c3
Add serialization of BinOpByScalar + BinOpUncast
emricksinisonos Jul 12, 2024
73cf50d
Fix unicast & avoid quant bin op declutter
emricksinisonos Jul 16, 2024
e40c04d
conversion in optimize instead of declutter
emricksinisonos Jul 16, 2024
6d6b13c
Add declutter neutral to typed op
emricksinisonos Jul 16, 2024
b8d4c1a
Fix clippy
emricksinisonos Jul 16, 2024
f7eaefb
Dirty plug in linalg
emricksinisonos Jul 16, 2024
9c7f3e1
Create by_scalar & unicast registries in linalg
emricksinisonos Jul 18, 2024
5eff926
Fix import
emricksinisonos Jul 18, 2024
a4a566f
BinOpX are slower ..
emricksinisonos Jul 19, 2024
5fac248
Replace collapse_axis with prefix_with
emricksinisonos Jul 19, 2024
9955a12
Introduce LirMul with predefined linalg method
emricksinisonos Jul 19, 2024
1a546f9
Change naming
emricksinisonos Oct 4, 2024
6834a92
Reorganize code & remove methods from BinMiniOp trait
emricksinisonos Oct 9, 2024
32a91ab
Add more BinOp support in linalg (Add & Sub)
emricksinisonos Oct 9, 2024
497bb6f
Decluttering to swap operand
emricksinisonos Oct 9, 2024
e3d2e81
cargo clippy
emricksinisonos Oct 10, 2024
ff686fc
Fix compilation x86
emricksinisonos Oct 10, 2024
ab40769
Fix linalg tests
emricksinisonos Oct 10, 2024
43e8d56
File renaming
emricksinisonos Oct 10, 2024
6de56ca
Fix typo
emricksinisonos Oct 10, 2024
c660be8
Avoid axes swap for Scale
emricksinisonos Oct 10, 2024
7c1ec5b
Remove tmp bin_1 method
emricksinisonos Oct 10, 2024
00d9774
Add remaining BinOp kernels (Min, Max, SubF)
emricksinisonos Oct 10, 2024
f4ea5c5
Fix tensor alignement
emricksinisonos Oct 14, 2024
de83827
Fix unicast alignment issue
emricksinisonos Oct 14, 2024
fd1d5e4
Add fusing for OptBinUnicast & OptBinByScalar
emricksinisonos Oct 14, 2024
0e887d1
Update expected for librispeech cli test
emricksinisonos Oct 14, 2024
1a11ff1
Remove dbg in test
emricksinisonos Oct 16, 2024
dd240fe
Fix alignment issue in test
emricksinisonos Oct 16, 2024
d7b99ab
Make check_b_aligment less strict
emricksinisonos Oct 16, 2024
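
The commit trail above boils down to two specialized element-wise paths: a "by scalar" kernel, used when the second operand collapses to a single value for each contiguous slice of the first, and a "unicast" kernel, used when both operands can be walked as equal-length contiguous slices. The sketch below is illustrative only — the kernel names and the `main` driver are invented for this note — but the two shape predicates mirror `check_uniform_is_possible` / `check_unicast_is_possible` shown in the core/src/ops/math/mod.rs diff further down. Shapes that satisfy neither predicate stay on the generic broadcasting evaluator.

```rust
// Minimal sketch (not tract's API) of the two fast paths the commits describe.
fn mul_by_scalar(a: &mut [f32], b: f32) {
    // One scalar applied to a contiguous slice.
    for x in a.iter_mut() {
        *x *= b;
    }
}

fn mul_unicast(a: &mut [f32], b: &[f32]) {
    // Two equal-length contiguous slices combined element by element.
    assert_eq!(a.len(), b.len());
    for (x, y) in a.iter_mut().zip(b.iter()) {
        *x *= *y;
    }
}

// Shape predicates mirroring check_uniform_is_possible / check_unicast_is_possible
// from the math/mod.rs diff below: b may only broadcast over trailing axes
// (by-scalar) or leading axes (unicast).
fn by_scalar_possible(a: &[usize], b: &[usize]) -> bool {
    a.len() == b.len()
        && a.iter()
            .zip(b.iter())
            .skip_while(|(ad, bd)| ad == bd)
            .all(|(_, bd)| *bd == 1)
}

fn unicast_possible(a: &[usize], b: &[usize]) -> bool {
    a.len() == b.len()
        && a.iter()
            .zip(b.iter())
            .skip_while(|(_, bd)| **bd == 1)
            .all(|(ad, bd)| ad == bd)
}

fn main() {
    // b broadcasts over the trailing axes: each slice of `a` shares one scalar.
    assert!(by_scalar_possible(&[2, 3, 4], &[2, 1, 1]));
    // b broadcasts over the leading axes: `a` and `b` line up as equal slices.
    assert!(unicast_possible(&[2, 3, 4], &[1, 3, 4]));
    // Anything else falls back to the generic broadcasting evaluator.
    assert!(!by_scalar_possible(&[2, 3, 4], &[2, 1, 4]));

    let mut a = vec![1.0_f32, 2.0, 3.0];
    mul_by_scalar(&mut a, 2.0);
    mul_unicast(&mut a, &[1.0, 10.0, 100.0]);
    assert_eq!(a, vec![2.0, 40.0, 600.0]);
}
```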
520 changes: 389 additions & 131 deletions core/src/ops/binary.rs

Large diffs are not rendered by default.
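
Since the 520-line rewrite of core/src/ops/binary.rs is not rendered here, the snippet below is only a hypothetical sketch of the counterpart to the change visible in math/mod.rs further down: the per-op `declutter_add`/`declutter_sub`/`declutter_neutral` functions disappear, and each op instead declares `neutral_element:` and `is_commutative:`, which a single generic pass can use to drop `x op neutral` (and `neutral op x` for commutative ops). The type names and the `simplify_neutral` helper are invented for illustration; they are not tract's API.

```rust
// Hypothetical sketch of a generic neutral-element simplification.
#[derive(Clone, Copy)]
enum Operand {
    Var,        // a non-constant input
    Const(i64), // a uniform constant input
}

struct BinOpDecl {
    neutral_element: Option<i64>,
    is_commutative: bool,
}

/// Returns Some(index of the surviving input) when `lhs op rhs` reduces to one
/// of its inputs, None otherwise.
fn simplify_neutral(decl: &BinOpDecl, lhs: Operand, rhs: Operand) -> Option<usize> {
    let neutral = decl.neutral_element?;
    match (lhs, rhs) {
        // x op neutral -> x, always sound (x + 0, x - 0, x * 1, x / 1, x.pow(1)).
        (_, Operand::Const(c)) if c == neutral => Some(0),
        // neutral op x -> x, only for commutative ops (0 - x is not x).
        (Operand::Const(c), _) if c == neutral && decl.is_commutative => Some(1),
        _ => None,
    }
}

fn main() {
    let add = BinOpDecl { neutral_element: Some(0), is_commutative: true };
    let sub = BinOpDecl { neutral_element: Some(0), is_commutative: false };

    assert_eq!(simplify_neutral(&add, Operand::Const(0), Operand::Var), Some(1));
    assert_eq!(simplify_neutral(&sub, Operand::Var, Operand::Const(0)), Some(0));
    assert_eq!(simplify_neutral(&sub, Operand::Const(0), Operand::Var), None);
}
```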

2 changes: 1 addition & 1 deletion core/src/ops/cnn/conv/conv.rs
@@ -92,7 +92,7 @@ impl Conv {
bias: OutletId,
c_group_axis: usize,
) -> TractResult<(ProtoFusedSpec, OutletId)> {
use tract_linalg::mmm::BinOp::Add;
use tract_linalg::BinOp::Add;
let fact = model.outlet_fact(bias)?;
if fact.shape.volume().is_one() {
Ok((ProtoFusedSpec::BinScalar(2, Add), bias))
250 changes: 11 additions & 239 deletions core/src/ops/math/mod.rs
@@ -11,10 +11,7 @@ use num_traits::{Float, Zero};
use tract_data::internal::ClampCast;
use tract_data::itertools::Itertools;
pub use tract_data::prelude::round_ties_to_even;
use tract_linalg::frame::unicast::Unicast;
use tract_linalg::frame::ElementWise;
use tract_linalg::{ScaleShiftAndRound, Scaler};
use tract_ndarray::Axis;
use tract_num_traits::AsPrimitive;

#[cfg(feature = "complex")]
@@ -23,8 +20,8 @@ mod complex;
pub use complex::{ComplexToInnerDim, InnerDimToComplex};

bin_to_super_type!(add, Add,
declutter: declutter_add,
linalg: Add,
neutral_element: 0,
validation: Validation::Rounding,
q: [i8, u8, i32, i32] => add_quant;
q_op_on_f32: |a: f32, b: f32| -> f32 {a+b},
@@ -39,8 +36,9 @@
}

bin_to_super_type!(sub, Sub,
declutter: declutter_sub,
linalg:Sub,
is_commutative: false,
neutral_element: 0,
q: [i8, u8, i32, i32] => sub_quant;
q_op_on_f32: |a: f32, b: f32| -> f32 {a-b},
[f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64, TDim] => |c, a, b| *c = a.clone() - b);
@@ -56,7 +54,6 @@
bin_to_super_type!(mul, Mul,
cost: |dt| tvec!((Cost::FMA(dt), 1)),
declutter: declutter_mul,
eval_in_a: mul_eval_in_a,
eval_override: |a:TValue, b: TValue, c_dt: DatumType| -> TractResult<Tensor> {
// we apply only if type is QU8 zp_scale datum type
if let (DatumType::QU8(QParams::ZpScale {zero_point: a_zp, scale: a_scale}),
@@ -80,36 +77,7 @@ bin_to_super_type!(mul, Mul,
}
},
linalg: Mul,
uniform_in_place: |a: &Tensor, b: &mut Tensor| -> TractResult<bool> {
if b.datum_type() == f32::datum_type() {
let a = a.to_scalar::<f32>()?;
let slice = b.as_slice_mut::<f32>()?;
(tract_linalg::ops().mul_by_scalar_f32)().run_with_params(slice, *a)?;
Ok(true)
} else if b.datum_type() == f16::datum_type() {
let a = a.to_scalar::<f16>()?;
let slice = b.as_slice_mut::<f16>()?;
(tract_linalg::ops().mul_by_scalar_f16)().run_with_params(slice, *a)?;
Ok(true)
} else {
Ok(false)
}
},
unicast_in_place: |a: &Tensor, b: &mut Tensor| -> TractResult<bool> {
if b.datum_type() == f32::datum_type() {
let a = a.as_slice::<f32>()?;
let slice = b.as_slice_mut::<f32>()?;
(tract_linalg::ops().unicast_mul_f32)().run(slice, a)?;
Ok(true)
} else if b.datum_type() == f16::datum_type() {
let a = a.as_slice::<f16>()?;
let slice = b.as_slice_mut::<f16>()?;
(tract_linalg::ops().unicast_mul_f16)().run(slice, a)?;
Ok(true)
} else {
Ok(false)
}
},
neutral_element: 1,
out_of_place: |c:&mut Tensor, a:&Tensor, b: &Tensor| -> TractResult<bool> {
if c.datum_type() == TDim::datum_type() &&
a.datum_type() == TDim::datum_type() && b.datum_type() == TDim::datum_type() {
@@ -155,144 +123,6 @@ bin_to_super_type!(mul, Mul,
[f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64, TDim] => |c, a, b| *c = a.clone() * b
);

fn check_uniform_is_possible(a_shape: &[usize], b_shape: &[usize]) -> bool {
if a_shape.len() != b_shape.len() {
return false;
};

a_shape
.iter()
.zip(b_shape.iter())
.skip_while(|(a_dim, b_dim)| a_dim == b_dim)
.all(|(_, b_dim)| *b_dim == 1)
}

fn check_unicast_is_possible(a_shape: &[usize], b_shape: &[usize]) -> bool {
if a_shape.len() != b_shape.len() {
return false;
};

a_shape
.iter()
.zip(b_shape.iter())
.skip_while(|(_, b_dim)| **b_dim == 1)
.all(|(a_dim, b_dim)| a_dim == b_dim)
}

fn mul_eval_in_a(a: &mut Tensor, b: &Tensor) -> TractResult<bool> {
let b_shape = b.shape();
let leading_unary_dims: Vec<usize> =
b_shape.iter().enumerate().take_while(|&(_, &dim)| dim == 1).map(|(i, _)| i).collect();
let trailing_unary_dims: Vec<usize> = b_shape
.iter()
.enumerate()
.rev()
.take_while(|&(_, &dim)| dim == 1)
.map(|(i, _)| i)
.collect();

let uniform_is_possible = check_uniform_is_possible(a.shape(), b.shape());
let uniform_in_place_should_be_efficient =
trailing_unary_dims.iter().fold(1, |num_elements, it| num_elements * a.shape()[*it]) > 32;
let unicast_is_possible = check_unicast_is_possible(a.shape(), b.shape());
let unicast_in_place_should_be_efficient =
leading_unary_dims.iter().fold(1, |num_elements, it| num_elements * a.shape()[*it]) > 32;

// Better to try uniform in place first (should be more efficient)
if uniform_in_place_should_be_efficient && uniform_is_possible {
if b.datum_type() == f32::datum_type() {
mul_by_scalar::<f32>(
a,
b,
&trailing_unary_dims,
(tract_linalg::ops().mul_by_scalar_f32)(),
)
} else if b.datum_type() == f16::datum_type() {
mul_by_scalar::<f16>(
a,
b,
&trailing_unary_dims,
(tract_linalg::ops().mul_by_scalar_f16)(),
)
} else {
Ok(false)
}
} else if unicast_in_place_should_be_efficient && unicast_is_possible {
if b.datum_type() == f32::datum_type() {
mul_unicast::<f32>(a, b, &leading_unary_dims, (tract_linalg::ops().unicast_mul_f32)())
} else if b.datum_type() == f16::datum_type() {
mul_unicast::<f16>(a, b, &leading_unary_dims, (tract_linalg::ops().unicast_mul_f16)())
} else {
return Ok(false);
}
} else {
Ok(false)
}
}

fn mul_unicast<T: Datum + Float>(
a: &mut Tensor,
b: &Tensor,
leading_unary_dims: &[usize],
eval: Box<dyn Unicast<T>>,
) -> TractResult<bool> {
let mut a_view = a.to_array_view_mut::<T>()?;
let b_view = b.to_array_view::<T>()?;
let mut iterating_shape = a_view.shape().to_vec();
iterating_shape.iter_mut().enumerate().for_each(|(idx, dim)| {
if !leading_unary_dims.contains(&idx) {
*dim = 1
}
});
for it_coords in tract_ndarray::indices(iterating_shape) {
let mut a_view = a_view.view_mut();
for idx in 0..a_view.shape().len() {
if leading_unary_dims.contains(&idx) {
a_view.collapse_axis(Axis(idx), it_coords[idx]);
}
}

if let Some((a_slice, b_slice)) = a_view.as_slice_mut().zip(b_view.as_slice()) {
eval.run(a_slice, b_slice)?;
} else {
return Ok(false);
}
}
Ok(true)
}

fn mul_by_scalar<T: Datum + Float>(
a: &mut Tensor,
b: &Tensor,
trailing_unary_dims: &[usize],
eval: Box<dyn ElementWise<T, T>>,
) -> TractResult<bool> {
let mut view = a.to_array_view_mut::<T>()?;
let b = b.to_array_view::<T>()?;
for it_coords in tract_ndarray::indices(b.shape()) {
// Prepare array view to perform computation
// - view should be a slice
// - b should be a scalar
let mut view = view.view_mut();
let mut b = b.view();
for idx in 0..b.shape().len() {
if !trailing_unary_dims.contains(&idx) {
view.collapse_axis(Axis(idx), it_coords[idx]);
b.collapse_axis(Axis(idx), it_coords[idx]);
}
}

// Perform computation on a slice on the view
let b = b.as_slice().unwrap()[0];
if let Some(slice) = view.as_slice_mut() {
eval.run_with_params(slice, b)?;
} else {
view.iter_mut().for_each(|it| *it = *it * b)
}
}
Ok(true)
}

bin_to_super_type!(div, Div,
cost: |dt| tvec!((Cost::Div(dt), 1)),
declutter: declutter_div,
@@ -338,6 +168,8 @@ eval_override: |a:TValue, b: TValue, c_dt: DatumType| -> TractResult<Tensor> {
Div.generic_eval(a, b, c_dt)
}
},
is_commutative: false,
neutral_element: 1,
out_of_place: |c:&mut Tensor, a:&Tensor, b: &Tensor| -> TractResult<bool> {
if c.datum_type() == TDim::datum_type() &&
a.datum_type() == TDim::datum_type() && b.datum_type() == TDim::datum_type() {
@@ -452,61 +284,19 @@ bin_to_super_type!(max, Max,

bin_to_super_type!(pow, Pow,
declutter: declutter_pow,
is_commutative: false,
neutral_element: 1,
q_op_on_f32: |a: f32, b: f32| -> f32 {a.powf(b)},
[f16, f32, f64] => |c,a,b| *c = a.powf(*b),
[i32, i64] => |c,a,b| *c = a.pow(*b as u32));

bin_to_super_type!(shift_left, ShiftLeft,
is_commutative: false,
[i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = *a << *b);
bin_to_super_type!(shift_right, ShiftRight,
is_commutative: false,
[i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = *a >> *b);

fn declutter_neutral(
model: &TypedModel,
node: &TypedNode,
value: i64,
also_left: bool,
) -> TractResult<Option<TypedModelPatch>> {
if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
// casting to i64 uni quantized type need to be avoided
if uniform.uni.datum_type().is_quantized() {
return Ok(None);
}
let Ok(integer) = uniform.uni.cast_to_scalar::<i64>() else { return Ok(None) };
if tensor0(integer)
.cast_to_dt(uniform.uni.datum_type())?
.close_enough(&uniform.uni, false)
.is_ok()
&& integer == value
&& (also_left || !uniform.left_is_uniform)
{
return Ok(Some(TypedModelPatch::rewire(
model,
&[uniform.var],
&[node.id.into()],
&|_, inputs| Ok(inputs.into()),
)?));
}
}
Ok(None)
}

fn declutter_add(
_op: &Add,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
declutter_neutral(model, node, 0, true)
}

fn declutter_sub(
_op: &Sub,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
declutter_neutral(model, node, 0, false)
}

fn declutter_mul(
_op: &Mul,
model: &TypedModel,
@@ -520,9 +310,7 @@ fn declutter_mul(
square(),
)?));
}
if let Some(p) = declutter_neutral(model, node, 1, true).context("decluttering neutral")? {
return Ok(Some(p));
}

if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
let var_fact = model.outlet_fact(uniform.var)?;
if uniform.uni.cast_to_scalar::<f64>()? == 0.0 {
@@ -577,16 +365,6 @@ fn declutter_mul(
},
)?));
}
if !uniform.left_is_uniform {
let mut swap_input = node.inputs.clone();
swap_input.swap(0, 1);
return Ok(Some(TypedModelPatch::replace_single_op(
model,
node,
&swap_input,
mul(),
)?));
}
}
}
Ok(None)
@@ -597,9 +375,6 @@
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
if let Some(p) = declutter_neutral(model, node, 1, false)? {
return Ok(Some(p));
}
if let &[p, q] = &*model.node_input_facts(node.id)? {
let dt = q.datum_type;
if let Some(q) = &q.uniform {
@@ -648,9 +423,6 @@
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
if let Some(p) = declutter_neutral(model, node, 1, false)? {
return Ok(Some(p));
}
let b = model.outlet_fact(node.inputs[1])?;
if let Some(b) = &b.uniform {
let b = b.cast_to_scalar::<f32>()?;