diff --git a/core/src/ops/array/gather.rs b/core/src/ops/array/gather.rs
index eb86244fc5..7d12405583 100644
--- a/core/src/ops/array/gather.rs
+++ b/core/src/ops/array/gather.rs
@@ -20,6 +20,7 @@ impl Gather {
         input_shape: &[D],
         indices_shape: &[D],
     ) -> TractResult<TVec<D>> {
+        ensure!(input_shape.len() > self.axis);
         let mut output_shape: TVec<D> = input_shape[..self.axis].into();
         output_shape.extend(indices_shape.iter().cloned());
         output_shape.extend(input_shape[self.axis + 1..].iter().cloned());
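Note on the `Gather` guard: without the `ensure!`, an out-of-range `axis` makes `input_shape[..self.axis]` panic instead of surfacing a `TractResult` error. A self-contained sketch of the failure mode and the fix (plain Rust stand-ins for the tract types; `output_shape` is a hypothetical helper, not the tract API):

```rust
// Simplified stand-in for Gather::compute_output_shape with the new guard.
fn output_shape(input: &[usize], indices: &[usize], axis: usize) -> Result<Vec<usize>, String> {
    // Without this check, `input[..axis]` panics when axis > input.len().
    if input.len() <= axis {
        return Err(format!("gather axis {axis} out of range for rank {}", input.len()));
    }
    let mut out = input[..axis].to_vec();
    out.extend_from_slice(indices); // the gathered axis takes the indices' shape
    out.extend_from_slice(&input[axis + 1..]);
    Ok(out)
}

fn main() {
    assert_eq!(output_shape(&[2, 3], &[4], 1).unwrap(), vec![2, 4]);
    assert!(output_shape(&[2, 3], &[4], 5).is_err()); // previously a panic
}
```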
diff --git a/core/src/ops/binary.rs b/core/src/ops/binary.rs
index 66800433f9..0e3ebb6c24 100644
--- a/core/src/ops/binary.rs
+++ b/core/src/ops/binary.rs
@@ -594,124 +594,6 @@ macro_rules! bin_to_super_type {
     };
 }
 
-macro_rules! bin_to_bool {
-    ($func:ident, $Op:ident,
-     $( codegen: $codegen:expr, )?
-     $( cost: $cost:expr, )?
-     $( declutter: $declutter:expr, )?
-     $( operating_datum_type: $operating_datum_type:expr, )?
-     $( [$($typ:ident),*] => $cab:expr),*) => {
-        #[derive(Debug, Clone, Hash)]
-        pub struct $Op;
-        impl $crate::ops::binary::BinMiniOp for $Op {
-            fn name(&self) -> &'static str {
-                stringify!($Op)
-            }
-
-            fn eval_uniform_in_place(&self, a: &Tensor, b: &mut Tensor) -> TractResult<()> {
-                $(
-                    $(if a.datum_type() == $typ::datum_type() {
-                        let cab: fn(&mut bool, &bool, &bool) -> () = $cab;
-                        let a = &a.as_slice::<bool>()?[0];
-                        let b = b.as_slice_mut::<bool>()?;
-                        unsafe {
-                            for i in 0..b.len() {
-                                let mut c = bool::default();
-                                cab(&mut c, a, b.get_unchecked(i));
-                                *b.get_unchecked_mut(i) = c;
-                            }
-                        }
-                        return Ok(())
-                    }
-                    )*
-                )*
-                bail!("{} does not support {:?} (inplace uniform)", self.name(), a.datum_type());
-            }
-
-            #[allow(unreachable_code)]
-            fn eval_unicast_in_place(&self, a: &Tensor, b: &mut Tensor) -> TractResult<()> {
-                $(
-                    $(if a.datum_type() == $typ::datum_type() {
-                        let cab: fn(&mut bool, &bool, &bool) -> () = $cab;
-                        let a = a.as_slice::<bool>()?;
-                        let b = b.as_slice_mut::<bool>()?;
-                        unsafe {
-                            for i in 0..a.len() {
-                                let mut c = bool::default();
-                                cab(&mut c, a.get_unchecked(i), b.get_unchecked(i));
-                                *b.get_unchecked_mut(i) = c;
-                            }
-                        }
-                        return Ok(())
-                    }
-                    )*
-                )*
-                bail!("{} does not support {:?}", self.name(), a.datum_type());
-            }
-
-            fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()> {
-                $(
-                    $(if a.datum_type() == $typ::datum_type() {
-                        let cab: fn(&mut bool, &$typ, &$typ) -> () = $cab;
-                        let a = a.to_array_view::<$typ>()?;
-                        let b = b.to_array_view::<$typ>()?;
-                        let mut c = c.to_array_view_mut::<bool>()?;
-                        ndarray::Zip::from(&mut c).and_broadcast(a).and_broadcast(b).for_each(cab);
-                        return Ok(())
-                    }
-                    )*
-                )*
-                bail!("{} does not support {:?}", self.name(), a.datum_type());
-            }
-
-            fn eval_in_a(&self, a: &mut Tensor, _b: &Tensor) -> TractResult<()> {
-                bail!("{} does not support {:?}", self.name(), a.datum_type());
-            }
-
-            fn result_datum_type(&self, _a: DatumType, _b: DatumType) -> TractResult<DatumType> {
-                Ok(bool::datum_type())
-            }
-
-            $(
-                fn codegen(
-                    &self,
-                    model: &TypedModel,
-                    node: &TypedNode,
-                ) -> TractResult<Option<TypedModelPatch>> {
-                    ($codegen)(self, model, node)
-                }
-            )?
-
-
-            $(
-                fn declutter(
-                    &self,
-                    model: &TypedModel,
-                    node: &TypedNode,
-                ) -> TractResult<Option<TypedModelPatch>> {
-                    ($declutter)(self, model, node)
-                }
-            )?
-
-            $(
-                fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
-                    ($cost)(dt)
-                }
-            )?
-
-            $(
-                fn operating_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
-                    ($operating_datum_type)(a, b)
-                })?
-
-        }
-
-        pub fn $func() -> $crate::ops::binary::TypedBinOp {
-            $crate::ops::binary::TypedBinOp(Box::new($Op), None)
-        }
-    };
-}
-
 #[derive(Debug)]
 pub(crate) struct OneUniformInput {
     pub uni: Arc<Tensor>,
diff --git a/core/src/ops/logic.rs b/core/src/ops/logic.rs
index ed29487017..84394d157d 100644
--- a/core/src/ops/logic.rs
+++ b/core/src/ops/logic.rs
@@ -1,111 +1,21 @@
 #![allow(clippy::bool_comparison)]
 #![allow(clippy::unnecessary_cast)]
 
+mod comparison;
 mod ite;
 
 pub use ite::IfThenElse;
+pub use comparison::Comp;
 
 use ndarray::*;
 
 use crate::broadcast::multi_broadcast;
 use crate::internal::*;
 
-use super::binary::BinMiniOp;
-use super::element_wise::ElementWiseOp;
-
 bin_to_super_type!(and, And,
     [bool, u8, u16, u32, u64, i8, i16, i32, i64] => |c, &a, &b| *c = (a as i64 != 0 && b as i64 != 0) as _);
 bin_to_super_type!(or, Or,
     [bool, u8, u16, u32, u64, i8, i16, i32, i64] => |c, &a, &b| *c = (a as i64 != 0 || b as i64 != 0) as _);
 bin_to_super_type!(xor, Xor, /*flip: commute, */ [bool] => |c, &a, &b| *c = a ^ b);
 
-bin_to_bool!(equals, Equals,
-    [bool, u8, u16, u32, u64, i8, i16, i32, i64, f16, f32, f64, TDim] => |c, a, b | *c = a == b
-);
-bin_to_bool!(not_equals, NotEquals, /* flip: commute, */
-    [bool, u8, u16, u32, u64, i8, i16, i32, i64, f16, f32, f64, TDim] => |c, a, b | *c = a != b
-);
-
-bin_to_bool!(less, Less,
-    codegen: codegen_compare_to_zero,
-    operating_datum_type: operating_datum_type_for_cmp,
-    [bool, u8, u16, u32, u64, i8, i16, i32, i64, f16, f32, f64] => |c, &a, &b | *c = a < b);
-bin_to_bool!(less_equal, LessEqual,
-    codegen: codegen_compare_to_zero,
-    operating_datum_type: operating_datum_type_for_cmp,
-    [bool, u8, u16, u32, u64, i8, i16, i32, i64, f16, f32, f64] => |c, &a, &b | *c = a <= b);
-bin_to_bool!(greater, Greater,
-    codegen: codegen_compare_to_zero,
-    operating_datum_type: operating_datum_type_for_cmp,
-    [bool, u8, u16, u32, u64, i8, i16, i32, i64, f16, f32, f64] => |c, &a, &b | *c = a > b);
-bin_to_bool!(greater_equal, GreaterEqual,
-    codegen: codegen_compare_to_zero,
-    operating_datum_type: operating_datum_type_for_cmp,
-    [bool, u8, u16, u32, u64, i8, i16, i32, i64, f16, f32, f64] => |c, &a, &b | *c = a >= b);
-
-pub fn operating_datum_type_for_cmp(a: DatumType, b: DatumType) -> TractResult<DatumType> {
-    let dt = a
-        .common_super_type(b)
-        .with_context(|| format_err!("No super type for {:?} and {:?}", a, b))?;
-    if dt == DatumType::TDim {
-        Ok(DatumType::I64)
-    } else {
-        Ok(dt)
-    }
-}
-
-fn codegen_compare_to_zero(
-    op: &dyn BinMiniOp,
-    model: &TypedModel,
-    node: &TypedNode,
-) -> TractResult<Option<TypedModelPatch>> {
-    let facts = model.node_input_facts(node.id)?;
-    if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
-        let dt = facts[0].datum_type;
-        if (dt.is_signed() || dt.is_float()) && *uniform.uni == Tensor::zero_scalar_dt(dt)? {
-            let reversed = uniform.left_is_uniform;
-            let mapped = || -> Box<dyn ElementWiseMiniOp> {
-                macro_rules! m {
-                    ($bin: ty, $same: expr, $other: expr) => {
-                        if op.is::<$bin>() {
-                            return if reversed { Box::new($other) } else { Box::new($same) };
-                        };
-                    };
-                }
-                m!(Less, LessThanZero {}, GreaterEqualThanZero {});
-                m!(LessEqual, LessEqualThanZero {}, GreaterThanZero {});
-                m!(Greater, GreaterThanZero {}, LessEqualThanZero {});
-                m!(GreaterEqual, GreaterEqualThanZero {}, LessThanZero {});
-                unreachable!();
-            };
-            return Ok(Some(TypedModelPatch::replace_single_op(
-                model,
-                node,
-                &[uniform.var],
-                ElementWiseOp(mapped(), None),
-            )?));
-        }
-    }
-    Ok(None)
-}
-
-element_wise_oop!(less_than_zero, LessThanZero, [f16, f32, f64, i8, i16, i32, i64] => bool |_op, xs, ys| {
-    xs.iter().zip(ys.iter_mut()).for_each(|(x,y)| *y = *x < num_traits::Zero::zero());
-    Ok(())
-});
-
-element_wise_oop!(less_equal_than_zero, LessEqualThanZero, [f16, f32, f64, i8, i16, i32, i64] => bool |_op, xs, ys| {
-    xs.iter().zip(ys.iter_mut()).for_each(|(x,y)| *y = *x <= num_traits::Zero::zero());
-    Ok(())
-});
-
-element_wise_oop!(greater_than_zero, GreaterThanZero, [f16, f32, f64, i8, i16, i32, i64] => bool |_op, xs, ys| {
-    xs.iter().zip(ys.iter_mut()).for_each(|(x,y)| *y = *x > num_traits::Zero::zero());
-    Ok(())
-});
-
-element_wise_oop!(greater_equal_than_zero, GreaterEqualThanZero, [f16, f32, f64, i8, i16, i32, i64] => bool |_op, xs, ys| {
-    xs.iter().zip(ys.iter_mut()).for_each(|(x,y)| *y = *x >= num_traits::Zero::zero());
-    Ok(())
-});
 
 element_wise!(not, Not,
     [bool] => |_, vs| {
         vs.iter_mut().for_each(|a| *a = !*a);
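The six macro-generated comparison mini-ops, plus their four compare-to-zero element-wise specializations, collapse into the single `Comp` enum introduced in the new file below. A standalone sketch of the enum-based dispatch shape (not the tract code itself): one enum value per predicate keeps the evaluation, declutter, and serialization logic in one place instead of six macro expansions.

```rust
// Standalone sketch of enum-based comparison dispatch over equal-length slices.
#[derive(Clone, Copy, Debug)]
enum Comp { Eq, NE, LT, GT, GTE, LTE }

fn apply<T: PartialOrd + PartialEq>(op: Comp, a: &[T], b: &[T], out: &mut [bool]) {
    for ((c, a), b) in out.iter_mut().zip(a).zip(b) {
        *c = match op {
            Comp::Eq => a == b,
            Comp::NE => a != b,
            Comp::LT => a < b,
            Comp::GT => a > b,
            Comp::LTE => a <= b,
            Comp::GTE => a >= b,
        };
    }
}

fn main() {
    let mut out = [false; 3];
    apply(Comp::LT, &[1, 5, 3], &[2, 2, 3], &mut out);
    assert_eq!(out, [true, false, false]);
}
```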
diff --git a/core/src/ops/logic/comparison.rs b/core/src/ops/logic/comparison.rs
new file mode 100644
index 0000000000..11d409b8ec
--- /dev/null
+++ b/core/src/ops/logic/comparison.rs
@@ -0,0 +1,190 @@
+use crate::broadcast::multi_broadcast;
+use crate::internal::*;
+use crate::ndarray::Zip;
+
+#[derive(Clone, Copy, Debug, Hash)]
+pub enum Comp {
+    Eq,
+    NE,
+    LT,
+    GT,
+    GTE,
+    LTE,
+}
+
+use tract_data::UndeterminedSymbol;
+use Comp::*;
+
+impl Op for Comp {
+    fn name(&self) -> Cow<str> {
+        match *self {
+            Eq => "==",
+            NE => "!=",
+            LT => "<",
+            GT => ">",
+            LTE => "<=",
+            GTE => ">=",
+        }
+        .into()
+    }
+
+    op_as_typed_op!();
+}
+
+impl Comp {
+    fn eval<T: Datum + PartialOrd>(&self, a: &Tensor, b: &Tensor) -> TractResult<Tensor> {
+        let a = a.to_array_view::<T>()?;
+        let b = b.to_array_view::<T>()?;
+        let shape = multi_broadcast(&[a.shape(), b.shape()])?;
+        let mut c = unsafe { Tensor::uninitialized::<bool>(&shape)? };
+        let mut view = c.to_array_view_mut::<bool>()?;
+        let zipped = Zip::from(&mut view).and_broadcast(&a).and_broadcast(&b);
+        match *self {
+            Eq => zipped.for_each(|c, a, b| *c = a == b),
+            NE => zipped.for_each(|c, a, b| *c = a != b),
+            LT => zipped.for_each(|c, a, b| *c = a < b),
+            GT => zipped.for_each(|c, a, b| *c = a > b),
+            LTE => zipped.for_each(|c, a, b| *c = a <= b),
+            GTE => zipped.for_each(|c, a, b| *c = a >= b),
+        }
+        Ok(c)
+    }
+}
+
+impl EvalOp for Comp {
+    fn is_stateless(&self) -> bool {
+        true
+    }
+
+    fn eval_with_session(
+        &self,
+        session: &SessionState,
+        inputs: TVec<TValue>,
+    ) -> TractResult<TVec<TValue>> {
+        if inputs[0].datum_type() == TDim::datum_type() {
+            let mut a = inputs[0].clone().into_tensor();
+            let mut b = inputs[1].clone().into_tensor();
+            for a in a.as_slice_mut::<TDim>()? {
+                *a = a.eval(&session.resolved_symbols);
+            }
+            for b in b.as_slice_mut::<TDim>()? {
+                *b = b.eval(&session.resolved_symbols);
+            }
+            if let (Ok(a), Ok(b)) = (a.cast_to::<i64>(), b.cast_to::<i64>()) {
+                return Ok(tvec!(self.eval::<i64>(&a, &b)?.into_tvalue()));
+            }
+            let scope = a
+                .as_slice::<TDim>()?
+                .iter()
+                .chain(b.as_slice::<TDim>().unwrap().iter())
+                .find_map(|d| d.find_scope())
+                .unwrap();
+            let a = inputs[0].to_array_view::<TDim>()?;
+            let b = inputs[1].to_array_view::<TDim>()?;
+            let shape = multi_broadcast(&[a.shape(), b.shape()])?;
+            let mut c = unsafe { Tensor::uninitialized::<bool>(&shape)? };
+            let mut view = c.to_array_view_mut::<bool>()?;
+            let a = a.broadcast(&*shape).unwrap();
+            let b = b.broadcast(&*shape).unwrap();
+            for ixs in tract_ndarray::indices(&*shape) {
+                let (a, b) = (&a[&ixs], &b[&ixs]);
+                view[&ixs] = match *self {
+                    Eq => a == b,
+                    NE => a != b,
+                    GTE => {
+                        if scope.prove_positive_or_zero(&(a.clone() - b)) {
+                            true
+                        } else if scope.prove_positive_or_zero(&(b.clone() - a - 1)) {
+                            false
+                        } else {
+                            bail!(UndeterminedSymbol(a.clone() - b));
+                        }
+                    }
+                    GT => {
+                        if scope.prove_positive_or_zero(&(a.clone() - b - 1)) {
+                            true
+                        } else if scope.prove_positive_or_zero(&(b.clone() - a)) {
+                            false
+                        } else {
+                            bail!(UndeterminedSymbol(a.clone() - b));
+                        }
+                    }
+                    LTE => {
+                        if scope.prove_positive_or_zero(&(b.clone() - a)) {
+                            true
+                        } else if scope.prove_positive_or_zero(&(a.clone() - b - 1)) {
+                            false
+                        } else {
+                            bail!(UndeterminedSymbol(a.clone() - b));
+                        }
+                    }
+                    LT => {
+                        if scope.prove_positive_or_zero(&(b.clone() - a - 1)) {
+                            true
+                        } else if scope.prove_positive_or_zero(&(a.clone() - b)) {
+                            false
+                        } else {
+                            bail!(UndeterminedSymbol(a.clone() - b));
+                        }
+                    }
+                };
+            }
+            Ok(tvec!(c.into_tvalue()))
+        } else {
+            let t = dispatch_numbers!(Self::eval(inputs[0].datum_type())(
+                self, &inputs[0], &inputs[1]
+            ))?;
+            Ok(tvec!(t.into_tvalue()))
+        }
+    }
+}
+
+impl TypedOp for Comp {
+    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
+        let shape = multi_broadcast(&[&inputs[0].shape, &inputs[1].shape])?;
+        Ok(tvec!(bool::datum_type().fact(shape)))
+    }
+
+    fn change_axes(
+        &self,
+        model: &TypedModel,
+        node: &TypedNode,
+        _io: InOut,
+        change: &AxisOp,
+    ) -> TractResult<Option<AxisChangeConsequence>> {
+        if let AxisOp::Rm(rm) = change {
+            let (inputs, outputs) = model.node_facts(node.id)?;
+            if !inputs[0].shape[*rm].is_one()
+                || !inputs[1].shape[*rm].is_one()
+                || !outputs[0].shape[*rm].is_one()
+            {
+                return Ok(None);
+            }
+        }
+        Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
+    }
+
+    fn slice(
+        &self,
+        patch: &mut TypedModelPatch,
+        _model: &TypedModel,
+        _node: &TypedNode,
+        prefix: &str,
+        inputs: &[OutletId],
+        _output_axis: usize,
+        _start: usize,
+        _end: usize,
+    ) -> TractResult<Option<TVec<OutletId>>> {
+        Ok(Some(patch.wire_node(prefix, *self, inputs)?))
+    }
+
+    fn axes_mapping(
+        &self,
+        inputs: &[&TypedFact],
+        outputs: &[&TypedFact],
+    ) -> TractResult<AxesMapping> {
+        AxesMapping::natural(inputs, outputs)
+    }
+
+    as_op!();
+}
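For symbolic `TDim` inputs the operator reduces each ordering predicate on integer dims to a non-negativity proof: `a > b` iff `a - b - 1 >= 0`, `a <= b` iff `b - a >= 0`, and so on. When neither the predicate nor its negation can be proven, evaluation bails with `UndeterminedSymbol`. A toy illustration of that reduction, with a closure standing in for `SymbolScope::prove_positive_or_zero`:

```rust
// Toy sketch of the predicate -> non-negativity reduction used for TDim.
#[derive(Clone, Copy)]
enum Comp { LT, GT, GTE, LTE }

fn eval(op: Comp, a: i64, b: i64, prove_nonneg: impl Fn(i64) -> bool) -> Option<bool> {
    // (expression proving "true", expression proving "false")
    let (yes, no) = match op {
        Comp::GTE => (a - b, b - a - 1),
        Comp::GT => (a - b - 1, b - a),
        Comp::LTE => (b - a, a - b - 1),
        Comp::LT => (b - a - 1, a - b),
    };
    if prove_nonneg(yes) {
        Some(true)
    } else if prove_nonneg(no) {
        Some(false)
    } else {
        None // UndeterminedSymbol in the real code
    }
}

fn main() {
    let prover = |d: i64| d >= 0; // with concrete values everything is decidable
    assert_eq!(eval(Comp::GT, 3, 2, prover), Some(true));
    assert_eq!(eval(Comp::LT, 3, 2, prover), Some(false));
}
```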
diff --git a/core/src/ops/math/mod.rs b/core/src/ops/math/mod.rs
index 9fe3745d25..24dc9db68e 100644
--- a/core/src/ops/math/mod.rs
+++ b/core/src/ops/math/mod.rs
@@ -397,7 +397,6 @@ bin_to_super_type!(rem, Rem,
                    [f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64] => |c, a, b| *c = a.clone() % b);
 
 bin_to_super_type!(min, Min, linalg:Min,
-                   operating_datum_type: super::logic::operating_datum_type_for_cmp,
                    q: [i8, u8, i32] => |c, a, b, _, _| *c = if a < b { *a } else { *b };
                    q_op_on_f32: |a: f32, b: f32| a.min(b),
                    [f16, f32, f64] => |c,a,b| *c = a.min(*b),
@@ -445,7 +444,6 @@ bin_to_super_type!(max, Max,
                        Max.generic_eval(a, b, c_dt)
                    },
                    linalg:Max,
-                   operating_datum_type: super::logic::operating_datum_type_for_cmp,
                    q: [i8, u8, i32] => |c, a, b, _, _| *c = if a < b { *b } else { *a };
                    q_op_on_f32: |a: f32, b: f32| -> f32 {a.max(b)},
                    [f16, f32, f64] => |c,a,b| *c = a.max(*b),
diff --git a/data/src/dim/mod.rs b/data/src/dim/mod.rs
index 22d6b67e0b..7ef9e4e03e 100644
--- a/data/src/dim/mod.rs
+++ b/data/src/dim/mod.rs
@@ -175,7 +175,7 @@ impl DimLike for TDim {
     }
 
     fn maxi(self, other: Self) -> Self {
-        TDim::Min(vec![self, other]).simplify()
+        TDim::Max(vec![self, other]).simplify()
     }
 }
diff --git a/data/src/dim/sym.rs b/data/src/dim/sym.rs
index 52c22ecd1d..b832abf926 100644
--- a/data/src/dim/sym.rs
+++ b/data/src/dim/sym.rs
@@ -75,7 +75,7 @@ impl SymbolScope {
     }
 
     #[allow(clippy::mutable_key_type)]
-    pub fn prove_positive(&self, t: &TDim) -> bool {
+    pub fn prove_positive_or_zero(&self, t: &TDim) -> bool {
         if let TDim::Val(v) = t {
             return *v >= 0;
         }
@@ -280,32 +280,32 @@ mod tests {
     #[test]
     fn prove_positive_0() {
         let s = SymbolScope::default();
-        assert!(s.prove_positive(&s.parse_tdim("0").unwrap()));
+        assert!(s.prove_positive_or_zero(&s.parse_tdim("0").unwrap()));
     }
 
     #[test]
     fn prove_positive_1() {
         let s = SymbolScope::default();
-        assert!(s.prove_positive(&s.parse_tdim("1").unwrap()));
+        assert!(s.prove_positive_or_zero(&s.parse_tdim("1").unwrap()));
     }
 
     #[test]
     fn prove_positive_neg1() {
         let s = SymbolScope::default();
-        assert!(!s.prove_positive(&s.parse_tdim("-1").unwrap()));
+        assert!(!s.prove_positive_or_zero(&s.parse_tdim("-1").unwrap()));
     }
 
     #[test]
     fn prove_positive_add_0() {
         let s = SymbolScope::default();
-        assert!(!s.prove_positive(&s.parse_tdim("s+1").unwrap()));
+        assert!(!s.prove_positive_or_zero(&s.parse_tdim("s+1").unwrap()));
     }
 
     #[test]
     fn prove_positive_with_axiom() {
         let s = SymbolScope::default();
         s.add_inequality("s>=0").unwrap();
-        assert!(s.prove_positive(&s.parse_tdim("s").unwrap()));
+        assert!(s.prove_positive_or_zero(&s.parse_tdim("s").unwrap()));
     }
 
     #[test]
@@ -314,6 +314,6 @@ mod tests {
         s.add_inequality("s>=0").unwrap();
         s.add_inequality("p>=0").unwrap();
         s.add_inequality("p+s<4096").unwrap();
-        assert!(s.prove_positive(&s.parse_tdim("4096-p").unwrap()));
+        assert!(s.prove_positive_or_zero(&s.parse_tdim("4096-p").unwrap()));
     }
 }
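The rename is semantic housekeeping: the predicate has always been `t >= 0` (note the `*v >= 0` fast path for `Val`), so `prove_positive_or_zero` says what it actually checks. It also pairs with the `maxi` fix above, which was building a `TDim::Min` where a `TDim::Max` was meant. A hedged usage sketch mirroring the updated tests (the `tract_data::prelude` import path is an assumption):

```rust
// Assumes SymbolScope is reachable via tract_data's prelude (hypothetical path).
use tract_data::prelude::*;

fn main() -> TractResult<()> {
    let s = SymbolScope::default();
    assert!(s.prove_positive_or_zero(&s.parse_tdim("0")?)); // zero counts
    s.add_inequality("s>=0")?;
    assert!(s.prove_positive_or_zero(&s.parse_tdim("s")?));
    assert!(!s.prove_positive_or_zero(&s.parse_tdim("s-1")?)); // s could be 0
    Ok(())
}
```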
diff --git a/data/src/dim/tree.rs b/data/src/dim/tree.rs
index d3d16e1535..6d0427d9c1 100644
--- a/data/src/dim/tree.rs
+++ b/data/src/dim/tree.rs
@@ -10,7 +10,7 @@ use std::fmt::Debug;
 use std::{fmt, ops};
 
 #[derive(Debug)]
-pub struct UndeterminedSymbol(TDim);
+pub struct UndeterminedSymbol(pub TDim);
 
 impl std::fmt::Display for UndeterminedSymbol {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@@ -37,7 +37,7 @@ pub enum TDim {
 
 use TDim::*;
 
-fn tdim_compare(a: &TDim, b: &TDim) -> Ordering {
+fn tdim_lexi_order(a: &TDim, b: &TDim) -> Ordering {
     match (a, b) {
         (Sym(a), Sym(b)) => a.cmp(b),
         (Val(a), Val(b)) => a.cmp(b),
@@ -48,10 +48,10 @@
         | (Max(a), Max(b)) => a.len().cmp(&b.len()).then(
             a.iter()
                 .zip(b.iter())
-                .fold(Ordering::Equal, |acc, (a, b)| acc.then_with(|| tdim_compare(a, b))),
+                .fold(Ordering::Equal, |acc, (a, b)| acc.then_with(|| tdim_lexi_order(a, b))),
         ),
-        (MulInt(p, d), MulInt(q, e)) => p.cmp(q).then_with(|| tdim_compare(d, e)),
-        (Div(d, p), Div(e, q)) => p.cmp(q).then_with(|| tdim_compare(d, e)),
+        (MulInt(p, d), MulInt(q, e)) => p.cmp(q).then_with(|| tdim_lexi_order(d, e)),
+        (Div(d, p), Div(e, q)) => p.cmp(q).then_with(|| tdim_lexi_order(d, e)),
         (Sym(_), _) => Ordering::Less,
         (_, Sym(_)) => Ordering::Greater,
         (Val(_), _) => Ordering::Less,
@@ -187,7 +187,7 @@ impl TDim {
         self.simplify()
             .wiggle()
             .into_iter()
-            .sorted_by(tdim_compare)
+            .sorted_by(tdim_lexi_order)
            .unique()
             .map(|e| e.simplify())
             .min_by_key(|e| e.cost())
@@ -271,24 +271,28 @@ impl TDim {
         }
     }
 
+    fn find_any_sym(tdim: &TDim) -> Option<&Symbol> {
+        match tdim {
+            Val(_) => None,
+            Sym(s) => Some(s),
+            Add(terms) | Mul(terms) | Min(terms) | Max(terms) | Broadcast(terms) => {
+                terms.iter().find_map(Self::find_any_sym)
+            }
+            MulInt(_, t) | Div(t, _) => Self::find_any_sym(t),
+        }
+    }
+
+    pub fn find_scope(&self) -> Option<SymbolScope> {
+        Self::find_any_sym(self).map(|s| s.scope().clone())
+    }
+
     pub fn simplify(self) -> TDim {
         use self::TDim::*;
         if let Val(v) = self {
             return Val(v);
         }
-        fn find_any_sym(tdim: &TDim) -> Option<&Symbol> {
-            match tdim {
-                Val(_) => None,
-                Sym(s) => Some(s),
-                Add(terms) | Mul(terms) | Min(terms) | Max(terms) | Broadcast(terms) => {
-                    terms.iter().find_map(find_any_sym)
-                }
-                MulInt(_, t) | Div(t, _) => find_any_sym(t),
-            }
-        }
-
-        let scope = find_any_sym(&self).map(|s| s.scope().clone());
+        let scope = Self::find_any_sym(&self).map(|s| s.scope().clone());
         self.simplify_rec(scope.as_ref())
     }
@@ -327,7 +331,7 @@ impl TDim {
                 .into_iter()
                 .filter_map(|(term, count)| evaluate_count(term, count))
                 .collect();
-            members.sort_by(tdim_compare);
+            members.sort_by(tdim_lexi_order);
 
             match members.len() {
                 0 => TDim::Val(0),
@@ -355,7 +359,7 @@ impl TDim {
                 gcd = -gcd;
             }
             terms.retain(|t| !t.is_one() && t != &Val(-1));
-            terms.sort_by(tdim_compare);
+            terms.sort_by(tdim_lexi_order);
             match (gcd, terms.len()) {
                 (_, 0) => Val(gcd), // Case #1: If 0 variables, return product
                 (0, _) => Val(0),   // Case #2: Result is 0 if coef is 0 (actually
@@ -457,7 +461,7 @@ impl TDim {
                     .map(|s| s.clone().simplify_rec(scope))
                     .flat_map(|t| if let Broadcast(t) = t { t } else { vec![t] })
                     .filter(|t| !t.is_one())
-                    .sorted_by(tdim_compare)
+                    .sorted_by(tdim_lexi_order)
                     .dedup()
                     .collect_vec();
                 if terms.len() == 0 {
@@ -473,7 +477,7 @@ impl TDim {
                     .into_iter()
                     .map(|t| t.simplify_rec(scope))
                     .flat_map(|t| if let Min(t) = t { t } else { vec![t] })
-                    .sorted_by(tdim_compare)
+                    .sorted_by(tdim_lexi_order)
                     .dedup()
                     .collect();
                 let new_terms: Vec<TDim> = flatten
@@ -483,7 +487,8 @@ impl TDim {
                         && !flatten.iter().filter(|other| other != &t).any(|other| {
                             let diff = t.clone() - other;
                             diff.to_i64().is_ok_and(|i| i >= 0)
-                                || scope.is_some_and(|scope| scope.prove_positive(&diff))
+                                || scope
+                                    .is_some_and(|scope| scope.prove_positive_or_zero(&diff))
                         })
                 })
                 .cloned()
@@ -501,7 +506,7 @@ impl TDim {
                     .into_iter()
                     .map(|t| t.simplify_rec(scope))
                     .flat_map(|t| if let Max(t) = t { t } else { vec![t] })
-                    .sorted_by(tdim_compare)
+                    .sorted_by(tdim_lexi_order)
                     .dedup()
                     .collect();
                 let new_terms: Vec<TDim> = flatten
@@ -511,7 +516,8 @@ impl TDim {
                         && !flatten.iter().filter(|other| other != &t).any(|other| {
                             let diff = other.clone() - t;
                             diff.to_i64().is_ok_and(|i| i >= 0)
-                                || scope.is_some_and(|scope| scope.prove_positive(&diff))
+                                || scope
+                                    .is_some_and(|scope| scope.prove_positive_or_zero(&diff))
                         })
                 })
                 .cloned()
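Two readability renames land here: `tdim_compare` becomes `tdim_lexi_order` (it is a canonical lexicographic ordering used for sorting and dedup, not a numeric comparison), and the `Min`/`Max` pruning now calls `prove_positive_or_zero`. The pruning rule in miniature, over `i64` stand-ins with literal comparison playing the role of the scope prover:

```rust
// A term of Max is dropped when another (deduped) term is provably >= it,
// i.e. when "other - t >= 0" can be established; Min is symmetric.
fn simplify_max(mut terms: Vec<i64>) -> Vec<i64> {
    terms.sort_unstable(); // the real code sorts with tdim_lexi_order ...
    terms.dedup(); // ... then dedups before pruning
    terms.iter().copied().filter(|t| !terms.iter().any(|o| o != t && o - t >= 0)).collect()
}

fn main() {
    assert_eq!(simplify_max(vec![1, 4, 4, 2]), vec![4]);
}
```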
diff --git a/hir/src/ops/activations.rs b/hir/src/ops/activations.rs
index 44fc10bd38..b8f1be6d9b 100644
--- a/hir/src/ops/activations.rs
+++ b/hir/src/ops/activations.rs
@@ -1,4 +1,5 @@
 use crate::internal::*;
+use tract_core::ops::logic::Comp;
 use tract_core::ops::math::*;
 
 macro_rules! activation {
@@ -89,9 +90,11 @@ activation!(Celu, |op, name: &str, model: &mut TypedModel, inputs| {
     cst!(model, inputs, name, zero, 0.0);
     cst!(model, inputs, name, one, 1.0);
     cst!(model, inputs, name, alpha, op.0);
-    let x_over_alpha = model.wire_node(name.to_string() + ".x_over_alpha", div(), &[inputs[0], alpha])?;
+    let x_over_alpha =
+        model.wire_node(name.to_string() + ".x_over_alpha", div(), &[inputs[0], alpha])?;
     let x_over_alpha_exp = model.wire_node(name.to_string() + ".exp", exp(), &[x_over_alpha[0]])?;
-    let minus_one = model.wire_node(name.to_string() + ".minus_one", sub(), &[x_over_alpha_exp[0], one])?;
+    let minus_one =
+        model.wire_node(name.to_string() + ".minus_one", sub(), &[x_over_alpha_exp[0], one])?;
     let wire = model.wire_node(name.to_string() + ".sat-zero", min(), &[zero, minus_one[0]])?;
     let relu = model.wire_node(name.to_string() + ".relu", max(), &[zero, inputs[0]])?;
     let wire = model.wire_node(name.to_string(), add(), &[relu[0], wire[0]])?;
@@ -108,11 +111,7 @@ activation!(Elu, |op, name: &str, model: &mut TypedModel, inputs| {
     let x_exp = model.wire_node(name.to_string() + ".exp", exp(), inputs)?;
     let minus_one = model.wire_node(name.to_string() + ".minus_one", sub(), &[x_exp[0], one])?;
     let neg = model.wire_node(name.to_string() + ".mul_alpha", mul(), &[alpha, minus_one[0]])?;
-    let test = model.wire_node(
-        name.to_string() + ".test",
-        tract_core::ops::logic::less(),
-        &[zero, inputs[0]],
-    )?;
+    let test = model.wire_node(name.to_string() + ".test", Comp::LT, &[zero, inputs[0]])?;
     let wire = model.wire_node(
         name.to_string() + ".iff",
         tract_core::ops::logic::Iff,
@@ -171,10 +170,7 @@ activation!(ScaledTanh, |op, name: &str, model: &mut TypedModel, inputs| {
 });
 
 #[derive(Debug, Clone, new)]
-pub struct Selu(
-    pub f32,
-    pub f32,
-);
+pub struct Selu(pub f32, pub f32);
 
 activation!(Selu, |op, name: &str, model: &mut TypedModel, inputs| {
     cst!(model, inputs, name, zero, 0.0);
@@ -183,11 +179,7 @@ activation!(Selu, |op, name: &str, model: &mut TypedModel, inputs| {
     let wire = model.wire_node(name.to_string() + ".exp", exp(), inputs)?;
     let wire = model.wire_node(name.to_string() + ".mul_alpha", mul(), &[wire[0], alpha])?;
     let wire = model.wire_node(name.to_string() + ".sub_alpha", sub(), &[wire[0], alpha])?;
-    let test = model.wire_node(
-        name.to_string() + ".test",
-        tract_core::ops::logic::less(),
-        &[zero, inputs[0]],
-    )?;
+    let test = model.wire_node(name.to_string() + ".test", Comp::LT, &[zero, inputs[0]])?;
     let wire = model.wire_node(
         name.to_string() + ".iff",
         tract_core::ops::logic::Iff,
@@ -198,10 +190,7 @@ activation!(Selu, |op, name: &str, model: &mut TypedModel, inputs| {
 });
 
 #[derive(Debug, Clone, new)]
-pub struct Shrink(
-    pub f32,
-    pub f32,
-);
+pub struct Shrink(pub f32, pub f32);
 
 activation!(Shrink, |op, name: &str, model: &mut TypedModel, inputs| {
     cst!(model, inputs, name, bias, op.0);
@@ -209,21 +198,15 @@ activation!(Shrink, |op, name: &str, model: &mut TypedModel, inputs| {
     cst!(model, inputs, name, minus_lambda, -op.1);
     let zero = broadcast_scalar(0.0, model, inputs)?;
     let zero = model.add_const(name.to_string() + ".zero", zero)?;
-    let test_pos = model.wire_node(
-        name.to_string() + ".test_pos",
-        tract_core::ops::logic::less(),
-        &[lambda, inputs[0]],
-    )?;
+    let test_pos =
+        model.wire_node(name.to_string() + ".test_pos", Comp::LT, &[lambda, inputs[0]])?;
     let pos = model.wire_node(
         name.to_string() + ".pos",
         tract_core::ops::math::sub(),
         &[inputs[0], bias],
     )?;
-    let test_neg = model.wire_node(
-        name.to_string() + ".test_neg",
-        tract_core::ops::logic::greater(),
-        &[minus_lambda, inputs[0]],
-    )?;
+    let test_neg =
+        model.wire_node(name.to_string() + ".test_neg", Comp::GT, &[minus_lambda, inputs[0]])?;
     let neg = model.wire_node(
         name.to_string() + ".neg",
         tract_core::ops::math::add(),
@@ -248,11 +231,7 @@ pub struct ThresholdRelu(pub f32);
 activation!(ThresholdRelu, |op, name: &str, model: &mut TypedModel, inputs| {
     cst!(model, inputs, name, zero, 0.0);
     cst!(model, inputs, name, alpha, op.0);
-    let test = model.wire_node(
-        name.to_string() + ".test",
-        tract_core::ops::logic::less(),
-        &[alpha, inputs[0]],
-    )?;
+    let test = model.wire_node(name.to_string() + ".test", Comp::LT, &[alpha, inputs[0]])?;
     let wire = model.wire_node(
         name.to_string() + ".iff",
         tract_core::ops::logic::Iff,
diff --git a/hir/src/ops/logic.rs b/hir/src/ops/logic.rs
index d66229cf51..822af3a5d2 100644
--- a/hir/src/ops/logic.rs
+++ b/hir/src/ops/logic.rs
@@ -6,6 +6,38 @@ use tract_core::ops::cast::wire_cast;
 pub use tract_core::ops::change_axes::wire_with_rank_broadcast;
 pub use tract_core::ops::logic::*;
 
+impl Expansion for Comp {
+    fn name(&self) -> Cow<str> {
+        <Comp as Op>::name(self)
+    }
+
+    fn rules<'r, 'p: 'r, 's: 'r>(
+        &'s self,
+        s: &mut Solver<'r>,
+        inputs: &'p [TensorProxy],
+        outputs: &'p [TensorProxy],
+    ) -> InferenceResult {
+        super::binary::rules(s, inputs, outputs, |_, _| Ok(bool::datum_type()))
+    }
+
+    fn wire(
+        &self,
+        prefix: &str,
+        target: &mut TypedModel,
+        inputs: &[OutletId],
+    ) -> TractResult<TVec<OutletId>> {
+        let a = target.outlet_fact(inputs[0])?;
+        let b = target.outlet_fact(inputs[1])?;
+        let operating_datum_type = a
+            .datum_type
+            .common_super_type(b.datum_type)
+            .with_context(|| format!("No super type for {a:?} and {b:?}"))?;
+        let wires = wire_rank_broadcast(prefix, target, inputs)?;
+        let wires = wire_cast(prefix, target, &wires, operating_datum_type)?;
+        target.wire_node(prefix, *self, &wires)
+    }
+}
+
 #[derive(Debug, Clone, Hash)]
 pub struct Iff;
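At the inference layer, `Comp` wires like any binary op: compute the common supertype of the two inputs, rank-broadcast, cast both sides, then emit the node with a `bool` output fact. The cast-then-compare pattern in miniature (plain Rust, not the tract wiring API):

```rust
// Pick a common operating type, cast both sides, then compare.
// common_super_type(i32, f32) would pick f32 here.
fn compare_mixed(a: &[i32], b: &[f32]) -> Vec<bool> {
    a.iter().zip(b).map(|(&a, &b)| (a as f32) < b).collect()
}

fn main() {
    assert_eq!(compare_mixed(&[1, 3], &[2.5, 2.5]), vec![true, false]);
}
```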
"eq" => Comp::Eq, + "ne" => Comp::NE, + "lt" => Comp::LT, + "gt" => Comp::GT, + "le" => Comp::LTE, + "ge" => Comp::GTE, + _ => bail!("Unexpected comparing operator"), + }; + let mut a = + invocation.invocation.arguments[0].rvalue.resolve(builder, &[])?.to::(builder)?; + let mut b = + invocation.invocation.arguments[1].rvalue.resolve(builder, &[])?.to::(builder)?; + let a_dt = builder.model.outlet_fact(a)?.datum_type; + let b_dt = builder.model.outlet_fact(b)?.datum_type; + let dt = a_dt.common_super_type(b_dt).context("no supertype found")?; + a = builder.wire_as_outlets(tract_core::ops::cast::cast(dt), &[a])?[0]; + b = builder.wire_as_outlets(tract_core::ops::cast::cast(dt), &[b])?[0]; + let inputs = crate::registry::multi_rank_broadcast(builder, &[a, b])?; + builder.wire(op, &inputs) +} + /* * fragment select( condition: tensor, # the condition for selecting the result diff --git a/nnef/src/ops/nnef/mod.rs b/nnef/src/ops/nnef/mod.rs index 10fa0b55cc..c7b7fa254c 100644 --- a/nnef/src/ops/nnef/mod.rs +++ b/nnef/src/ops/nnef/mod.rs @@ -89,12 +89,10 @@ pub fn tract_nnef() -> Registry { deser::leaky_relu, ); - registry.register_binary("lt", &ops::logic::Less {}); - registry.register_binary("gt", &ops::logic::Greater {}); - registry.register_binary("le", &ops::logic::LessEqual {}); - registry.register_binary("ge", &ops::logic::GreaterEqual {}); - registry.register_binary("eq", &ops::logic::Equals {}); - registry.register_binary("ne", &ops::logic::NotEquals {}); + registry.register_dumper(ser::comp); + for c in ["eq", "ne", "ge", "gt", "le", "lt"] { + primitive(&mut registry, c, deser::comp); + } registry.register_binary("and", &ops::logic::And {}); registry.register_binary("or", &ops::logic::Or {}); @@ -138,3 +136,4 @@ pub fn tract_nnef() -> Registry { } registry } + diff --git a/nnef/src/ops/nnef/ser.rs b/nnef/src/ops/nnef/ser.rs index 1ebe0c2a2d..f650d24531 100644 --- a/nnef/src/ops/nnef/ser.rs +++ b/nnef/src/ops/nnef/ser.rs @@ -32,7 +32,11 @@ pub fn source( Ok(None) } -pub fn basic_matmul(ast: &mut IntoAst, node: &TypedNode, op: &BasicMatMul) -> TractResult>> { +pub fn basic_matmul( + ast: &mut IntoAst, + node: &TypedNode, + op: &BasicMatMul, +) -> TractResult>> { let inputs = node.inputs.iter().map(|i| (*ast.mapping[i]).clone()).collect_vec(); if op.transpose_c { Ok(Some(invocation( @@ -57,6 +61,24 @@ pub fn konst( Ok(Some(ast.konst(&node.name, &op.0)?)) } +pub fn comp( + ast: &mut IntoAst, + node: &TypedNode, + op: &ops::logic::Comp, +) -> TractResult>> { + use ops::logic::Comp::*; + let inputs = node.inputs.iter().map(|i| Arc::clone(&ast.mapping[i])).collect_vec(); + let name = match *op { + Eq => "eq", + NE => "ne", + LT => "lt", + GT => "gt", + LTE => "le", + GTE => "ge", + }; + Ok(Some(invocation(name, &inputs, &[]))) +} + pub fn concat( ast: &mut IntoAst, node: &TypedNode, @@ -455,7 +477,7 @@ pub fn softmax( op: &ops::nn::Softmax, ) -> TractResult>> { if op.exp != SoftmaxExp::default() { - return Ok(None) + return Ok(None); } let litteral_axes: Vec<_> = op.axes.iter().map(|&it| (it as i64).into()).collect(); Ok(Some(invocation( diff --git a/onnx/src/ops/logic.rs b/onnx/src/ops/logic.rs index a3a926da64..f8d3ad02b3 100644 --- a/onnx/src/ops/logic.rs +++ b/onnx/src/ops/logic.rs @@ -4,6 +4,7 @@ use crate::model::ParsingContext; use crate::pb::NodeProto; use tract_core::ops; use tract_hir::internal::*; +use tract_hir::ops::logic::Comp; use tract_itertools::Itertools; pub fn register_all_ops(reg: &mut OnnxOpRegister) { @@ -12,11 +13,11 @@ pub fn register_all_ops(reg: &mut OnnxOpRegister) 
{ reg.insert("Or", |_, _| Ok((ops::logic::Or.into_hir(), vec![]))); reg.insert("Xor", |_, _| Ok((ops::logic::Xor.into_hir(), vec![]))); - reg.insert("Equal", |_, _| Ok((ops::logic::Equals.into_hir(), vec![]))); - reg.insert("Greater", |_, _| Ok((ops::logic::Greater.into_hir(), vec![]))); - reg.insert("Less", |_, _| Ok((ops::logic::Less.into_hir(), vec![]))); - reg.insert("LessOrEqual", |_, _| Ok((ops::logic::LessEqual.into_hir(), vec![]))); - reg.insert("GreaterOrEqual", |_, _| Ok((ops::logic::GreaterEqual.into_hir(), vec![]))); + reg.insert("Equal", |_, _| Ok((expand(Comp::Eq), vec![]))); + reg.insert("Greater", |_, _| Ok((expand(Comp::GT), vec![]))); + reg.insert("Less", |_, _| Ok((expand(Comp::LT), vec![]))); + reg.insert("LessOrEqual", |_, _| Ok((expand(Comp::LTE), vec![]))); + reg.insert("GreaterOrEqual", |_, _| Ok((expand(Comp::GTE), vec![]))); reg.insert("Where", |_, _| Ok((expand(tract_hir::ops::logic::Iff), vec![]))); diff --git a/onnx/src/ops/math/rem.rs b/onnx/src/ops/math/rem.rs index 4319128f7a..0e41945f9d 100644 --- a/onnx/src/ops/math/rem.rs +++ b/onnx/src/ops/math/rem.rs @@ -2,6 +2,7 @@ use crate::model::ParsingContext; use crate::pb::*; use tract_hir::internal::*; use tract_hir::ops; +use tract_hir::ops::logic::Comp; pub fn rem( _ctx: &ParsingContext, @@ -17,8 +18,6 @@ pub fn rem( #[derive(Debug, Clone, new, Hash)] pub struct RemInt; - - impl Expansion for RemInt { fn name(&self) -> Cow { "Remint".into() @@ -44,9 +43,8 @@ impl Expansion for RemInt { let zero = tract_hir::ops::activations::broadcast_scalar(0.0, model, inputs)?; let a = model.outlet_fact(inputs[0])?.datum_type; let b = model.outlet_fact(inputs[1])?.datum_type; - let dt = a - .common_super_type(b) - .with_context(|| format!("No super type for {a:?} and {b:?}"))?; + let dt = + a.common_super_type(b).with_context(|| format!("No super type for {a:?} and {b:?}"))?; let wires = tract_hir::ops::binary::wire_rank_broadcast(name, model, inputs)?; let wires = tract_hir::ops::binary::wire_cast(name, model, &wires, dt)?; if dt.is_unsigned() || dt == DatumType::TDim { @@ -58,26 +56,14 @@ impl Expansion for RemInt { let zero = model.add_const(name.to_string() + ".zero", zero)?; let rem = model.wire_node(name.to_string() + ".rem", tract_hir::ops::math::rem(), &wires)?[0]; - let rem_is_neg = model.wire_node( - name.to_string() + ".rem_is_neg", - tract_hir::ops::logic::greater(), - &[zero, rem], - )?; - let rem_is_pos = model.wire_node( - name.to_string() + ".rem_is_pos", - tract_hir::ops::logic::less(), - &[zero, rem], - )?; - let b_is_neg = model.wire_node( - name.to_string() + ".b_is_neg", - tract_hir::ops::logic::greater(), - &[zero, wires[1]], - )?; - let b_is_pos = model.wire_node( - name.to_string() + ".b_is_pos", - tract_hir::ops::logic::less(), - &[zero, wires[1]], - )?; + let rem_is_neg = + model.wire_node(name.to_string() + ".rem_is_neg", Comp::GT, &[zero, rem])?; + let rem_is_pos = + model.wire_node(name.to_string() + ".rem_is_pos", Comp::LT, &[zero, rem])?; + let b_is_neg = + model.wire_node(name.to_string() + ".b_is_neg", Comp::GT, &[zero, wires[1]])?; + let b_is_pos = + model.wire_node(name.to_string() + ".b_is_pos", Comp::LT, &[zero, wires[1]])?; let rem_is_neg_b_is_pos = model.wire_node( name.to_string() + ".rem_is_neg_b_is_pos", tract_hir::ops::logic::and(), diff --git a/onnx/src/ops/nn/mod.rs b/onnx/src/ops/nn/mod.rs index 1b8a2784c2..7eb96ea7dd 100644 --- a/onnx/src/ops/nn/mod.rs +++ b/onnx/src/ops/nn/mod.rs @@ -1,5 +1,6 @@ use tract_hir::internal::*; use tract_hir::ops; +use tract_hir::ops::logic::Comp; 
diff --git a/onnx/src/ops/nn/mod.rs b/onnx/src/ops/nn/mod.rs
index 1b8a2784c2..7eb96ea7dd 100644
--- a/onnx/src/ops/nn/mod.rs
+++ b/onnx/src/ops/nn/mod.rs
@@ -1,5 +1,6 @@
 use tract_hir::internal::*;
 use tract_hir::ops;
+use tract_hir::ops::logic::Comp;
 use tract_hir::ops::{cnn, nn};
 
 use crate::model::{OnnxOpRegister, ParsingContext};
@@ -364,11 +365,7 @@ impl Expansion for Prelu {
             .broadcast_into_rank(rank)?;
         let ab = model.wire_node(format!("{name}.mul"), tract_hir::ops::math::mul(), &[a, b])?[0];
         let zero = model.add_const(name.to_string() + ".zero", zero)?;
-        let test = model.wire_node(
-            name.to_string() + ".test",
-            tract_hir::ops::logic::greater(),
-            &[zero, a],
-        )?;
+        let test = model.wire_node(name.to_string() + ".test", Comp::GT, &[zero, a])?;
         model.wire_node(name.to_string() + ".iff", tract_core::ops::logic::Iff, &[test[0], ab, a])
     }
 }
diff --git a/tensorflow/src/ops/logic.rs b/tensorflow/src/ops/logic.rs
index 986e0c85a7..996bc9a8e1 100644
--- a/tensorflow/src/ops/logic.rs
+++ b/tensorflow/src/ops/logic.rs
@@ -1,5 +1,6 @@
 use tract_hir::internal::*;
 use tract_hir::ops;
+use tract_hir::ops::logic::Comp;
 
 use crate::model::ParsingContext;
 use crate::model::TfOpRegister;
@@ -7,11 +8,11 @@ use crate::tfpb::tensorflow::NodeDef;
 use std::collections::HashSet;
 
 pub fn register_all_ops(reg: &mut TfOpRegister) {
-    reg.insert("Equal", |_, _| Ok(ops::logic::Equals.into_hir()));
-    reg.insert("Greater", |_, _| Ok(ops::logic::Greater.into_hir()));
-    reg.insert("GreaterEqual", |_, _| Ok(ops::logic::GreaterEqual.into_hir()));
-    reg.insert("Less", |_, _| Ok(ops::logic::Less.into_hir()));
-    reg.insert("LessEqual", |_, _| Ok(ops::logic::LessEqual.into_hir()));
+    reg.insert("Equal", |_, _| Ok(expand(Comp::Eq)));
+    reg.insert("Greater", |_, _| Ok(expand(Comp::GT)));
+    reg.insert("GreaterEqual", |_, _| Ok(expand(Comp::GTE)));
+    reg.insert("Less", |_, _| Ok(expand(Comp::LT)));
+    reg.insert("LessEqual", |_, _| Ok(expand(Comp::LTE)));
     reg.insert("LogicalAnd", |_, _| Ok(ops::logic::And.into_hir()));
     reg.insert("LogicalOr", |_, _| Ok(ops::logic::Or.into_hir()));
     reg.insert("Merge", merge);
@@ -21,8 +22,6 @@ pub fn register_all_ops(reg: &mut TfOpRegister) {
 #[derive(Debug, Clone, new, Hash)]
 pub struct Switch;
 
-
-
 impl Op for Switch {
     fn name(&self) -> Cow<str> {
         "Switch".into()
@@ -118,8 +117,6 @@ pub struct Merge {
     n: usize,
 }
 
-
-
 impl Op for Merge {
     fn name(&self) -> Cow<str> {
         "Merge".into()
diff --git a/tensorflow/src/ops/math.rs b/tensorflow/src/ops/math.rs
index fac71fd810..5a39cf3f02 100644
--- a/tensorflow/src/ops/math.rs
+++ b/tensorflow/src/ops/math.rs
@@ -24,7 +24,6 @@ pub fn register_all_ops(reg: &mut TfOpRegister) {
     reg.insert("Sum", reduce::sum);
     reg.insert("Maximum", |_, _| Ok(ops::math::Max.into_hir()));
     reg.insert("Minimum", |_, _| Ok(ops::math::Min.into_hir()));
-    reg.insert("Less", |_, _| Ok(ops::logic::Less.into_hir()));
     reg.insert("Log", |_, _| Ok(ops::math::ln().into_hir()));
     reg.insert("Mul", |_, _| Ok(ops::math::Mul.into_hir()));
     reg.insert("Pow", |_, _| Ok(ops::math::Pow.into_hir()));
diff --git a/tflite/src/ops/math.rs b/tflite/src/ops/math.rs
index 3143b42b44..d92d19a4b8 100644
--- a/tflite/src/ops/math.rs
+++ b/tflite/src/ops/math.rs
@@ -10,10 +10,11 @@ use tract_core::internal::*;
 use tract_core::ops::binary::TypedBinOp;
 use tract_core::ops::cast::wire_cast;
 use tract_core::ops::change_axes::wire_rank_broadcast;
-use tract_core::ops::logic;
+use tract_core::ops::logic::{self, Comp};
 
 pub fn register_all(reg: &mut Registry) {
     reg.reg_to_tflite(ser_bin);
+    reg.reg_to_tflite(ser_comp);
 
     reg.reg_to_tract(BuiltinOperator::ADD, deser_add);
     reg.reg_to_tract(BuiltinOperator::SUB, deser_sub);
@@ -22,12 +23,12 @@ pub fn register_all(reg: &mut Registry) {
     reg.reg_to_tract(BuiltinOperator::MAXIMUM, |op| deser_bin(op, tract_core::ops::math::max()));
     reg.reg_to_tract(BuiltinOperator::MINIMUM, |op| deser_bin(op, tract_core::ops::math::min()));
 
-    reg.reg_to_tract(BuiltinOperator::EQUAL, |op| deser_bin(op, tract_core::ops::logic::equals()));
-    reg.reg_to_tract(BuiltinOperator::NOT_EQUAL, |op| deser_bin(op, logic::not_equals()));
-    reg.reg_to_tract(BuiltinOperator::LESS, |op| deser_bin(op, tract_core::ops::logic::less()));
-    reg.reg_to_tract(BuiltinOperator::LESS_EQUAL, |op| deser_bin(op, logic::less_equal()));
-    reg.reg_to_tract(BuiltinOperator::GREATER, |op| deser_bin(op, logic::greater()));
-    reg.reg_to_tract(BuiltinOperator::GREATER_EQUAL, |op| deser_bin(op, logic::greater_equal()));
+    reg.reg_to_tract(BuiltinOperator::EQUAL, |op| deser_comp(op, Comp::Eq));
+    reg.reg_to_tract(BuiltinOperator::NOT_EQUAL, |op| deser_comp(op, Comp::NE));
+    reg.reg_to_tract(BuiltinOperator::LESS, |op| deser_comp(op, Comp::LT));
+    reg.reg_to_tract(BuiltinOperator::LESS_EQUAL, |op| deser_comp(op, Comp::LTE));
+    reg.reg_to_tract(BuiltinOperator::GREATER, |op| deser_comp(op, Comp::GT));
+    reg.reg_to_tract(BuiltinOperator::GREATER_EQUAL, |op| deser_comp(op, Comp::GTE));
 
     reg.reg_to_tract(BuiltinOperator::LOGICAL_OR, |op| deser_bin(op, logic::or()));
     reg.reg_to_tract(BuiltinOperator::LOGICAL_AND, |op| deser_bin(op, logic::and()));
 }
@@ -48,6 +49,11 @@ fn deser_bin(op: &mut DeserOp, mini: TypedBinOp) -> TractResult<TVec<OutletId>> {
     op.ctx.target.wire_node(op.prefix, mini, &wires)
 }
 
+fn deser_comp(op: &mut DeserOp, comp: Comp) -> TractResult<TVec<OutletId>> {
+    let wires = wire_cast_and_rank_broadcast(op)?;
+    op.ctx.target.wire_node(op.prefix, comp, &wires)
+}
+
 fn deser_add(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
     let options = builtin!(op, builtin_options_as_add_options);
     let wires = wire_cast_and_rank_broadcast(op)?;
@@ -99,12 +105,6 @@ fn ser_bin(
             }
         };
     }
-    ser_logic!(logic::Less, 58, 1, LESS);
-    ser_logic!(logic::Greater, 61, 1, GREATER);
-    ser_logic!(logic::GreaterEqual, 62, 1, GREATER_EQUAL);
-    ser_logic!(logic::LessEqual, 63, 1, LESS_EQUAL);
-    ser_logic!(logic::Equals, 71, 1, EQUAL);
-    ser_logic!(logic::NotEquals, 72, 1, NOT_EQUAL);
 
     ser_logic!(logic::Or, 84, 1, LOGICAL_OR);
     ser_logic!(logic::And, 86, 1, LOGICAL_AND);
@@ -198,3 +198,23 @@ fn ser_bin(
         it => todo!("Missing iplementation for binary {it:?} serialization"),
     }
 }
+
+fn ser_comp(
+    builder: &mut SubgraphBuilder,
+    model: &TypedModel,
+    node: &TypedNode,
+    op: &Comp,
+) -> TractResult<()> {
+    use Comp::*;
+    let (code, version, builtin) = match *op {
+        LT => (58, 1, BuiltinOperator::LESS),
+        GT => (61, 1, BuiltinOperator::GREATER),
+        GTE => (62, 1, BuiltinOperator::GREATER_EQUAL),
+        LTE => (63, 1, BuiltinOperator::LESS_EQUAL),
+        Eq => (71, 1, BuiltinOperator::EQUAL),
+        NE => (72, 1, BuiltinOperator::NOT_EQUAL),
+    };
+    let inputs = builder.map_outlets(model, &node.inputs)?;
+    let outputs = builder.map_outlets(model, [OutletId::from(node.id)])?;
+    builder.write_op(&inputs, &outputs, code, version, builtin)
+}
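`ser_comp`/`deser_comp` pin each `Comp` variant to the same TFLite opcode and builtin pairs the deleted `ser_logic!` lines used (opcodes 58 through 72 from the TFLite schema). A hypothetical round-trip check of that table, as a self-contained mirror rather than the tflite crate API:

```rust
// Self-contained mirror of the Comp <-> TFLite opcode mapping above.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Comp { Eq, NE, LT, GT, GTE, LTE }

fn opcode(op: Comp) -> i32 {
    match op {
        Comp::LT => 58,  // LESS
        Comp::GT => 61,  // GREATER
        Comp::GTE => 62, // GREATER_EQUAL
        Comp::LTE => 63, // LESS_EQUAL
        Comp::Eq => 71,  // EQUAL
        Comp::NE => 72,  // NOT_EQUAL
    }
}

fn from_opcode(code: i32) -> Option<Comp> {
    [Comp::Eq, Comp::NE, Comp::LT, Comp::GT, Comp::GTE, Comp::LTE]
        .into_iter()
        .find(|&op| opcode(op) == code)
}

fn main() {
    for op in [Comp::Eq, Comp::NE, Comp::LT, Comp::GT, Comp::GTE, Comp::LTE] {
        assert_eq!(from_opcode(opcode(op)), Some(op)); // mapping is a bijection
    }
}
```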