From 4c60f18235f439b17aca55fd7ef2d7d200ec5835 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Mon, 9 Sep 2024 17:24:23 +0200 Subject: [PATCH 01/20] wip --- data/src/dim/parse.rs | 36 +++++++++++------------- data/src/dim/sym.rs | 59 +++++++++++++++------------------------ data/src/dim/tree.rs | 65 ++++++++++++++++++++++++------------------- 3 files changed, 76 insertions(+), 84 deletions(-) diff --git a/data/src/dim/parse.rs b/data/src/dim/parse.rs index d3e13ec430..68dbfbfa2a 100644 --- a/data/src/dim/parse.rs +++ b/data/src/dim/parse.rs @@ -4,9 +4,9 @@ use nom::bytes::complete::tag; use nom::character::complete::{alpha1, alphanumeric1, digit1, one_of}; use nom::combinator::{all_consuming, map, map_res, recognize}; use nom::multi::{many0, separated_list0}; -use nom::sequence::{delimited, pair, preceded, separated_pair, tuple}; +use nom::sequence::{delimited, pair, preceded, separated_pair}; use nom::IResult; -use sym::{Inequality, InequalitySign}; +use sym::Inequality; pub fn parse_tdim(symbol_table: &SymbolScope, input: &str) -> TractResult { match all_consuming(|i| expr(symbol_table, i))(input) { @@ -23,17 +23,19 @@ pub fn parse_inequality(symbol_table: &SymbolScope, input: &str) -> TractResult< } fn inequality<'i>(s: &SymbolScope, i: &'i str) -> IResult<&'i str, Inequality> { - map(tuple((|i| expr(s, i), inequality_sign, |i| expr(s, i))), |(left, sign, right)| { - Inequality { left, sign, right } - })(i) -} - -fn inequality_sign(i: &str) -> IResult<&str, InequalitySign> { alt(( - map(stag("<="), |_| InequalitySign::LTE), - map(stag("<"), |_| InequalitySign::LT), - map(stag(">="), |_| InequalitySign::GTE), - map(stag(">"), |_| InequalitySign::GT), + map(separated_pair(|i| expr(s, i), stag("<="), |i| expr(s, i)), |(a, b)| { + Inequality::LTE(a, b) + }), + map(separated_pair(|i| expr(s, i), stag(">="), |i| expr(s, i)), |(a, b)| { + Inequality::GTE(a, b) + }), + map(separated_pair(|i| expr(s, i), stag("<"), |i| expr(s, i)), |(a, b)| { + Inequality::LT(a, b) + }), + map(separated_pair(|i| expr(s, i), stag(">"), |i| expr(s, i)), |(a, b)| { + Inequality::GT(a, b) + }), ))(i) } @@ -58,9 +60,7 @@ bin!(mul, div, "*", |(a, b)| a * b); fn broadcast<'i>(symbol_table: &SymbolScope, input: &'i str) -> IResult<&'i str, TDim> { let s = symbol_table; alt(( - map_res(separated_pair(|i| add(s, i), stag("#"), |i| add(s, i)), |(a, b)| { - a.broadcast(b) - }), + map_res(separated_pair(|i| add(s, i), stag("#"), |i| add(s, i)), |(a, b)| a.broadcast(b)), |i| add(s, i), ))(input) } @@ -170,11 +170,7 @@ mod test { let table = SymbolScope::default(); assert_eq!( parse_inequality(&table, "P+S<4096").unwrap(), - Inequality { - left: parse_tdim(&table, "P+S").unwrap(), - sign: InequalitySign::LT, - right: 4096.to_dim() - } + Inequality::LT(parse_tdim(&table, "P+S").unwrap(), 4096.to_dim()) ); } } diff --git a/data/src/dim/sym.rs b/data/src/dim/sym.rs index f60ce53b30..041291496b 100644 --- a/data/src/dim/sym.rs +++ b/data/src/dim/sym.rs @@ -87,7 +87,7 @@ impl SymbolScope { let ineqs = self.0.lock().unwrap().inequalities.clone(); let positives = ineqs .iter() - .map(|i| parse_inequality(self, i).unwrap().as_known_positive()) + .filter_map(|i| parse_inequality(self, i).unwrap().as_known_positive()) .collect_vec(); let mut visited = vec![]; let mut todo = vec![t.clone()]; @@ -140,50 +140,37 @@ impl fmt::Debug for SymbolScope { #[derive(Debug, PartialEq, Clone, Hash)] #[allow(clippy::upper_case_acronyms)] -pub enum InequalitySign { - LT, - GT, - LTE, - GTE, +pub enum Inequality { + LT(TDim, TDim), + GT(TDim, TDim), + LTE(TDim, TDim), + GTE(TDim, TDim), } -impl Display for InequalitySign { +impl Display for Inequality { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - use InequalitySign::*; + use Inequality::*; match self { - LT => write!(f, "<"), - GT => write!(f, ">"), - LTE => write!(f, "<="), - GTE => write!(f, ">="), + LT(l, r) => write!(f, "{l} < {r}"), + GT(l, r) => write!(f, "{l} > {r}"), + LTE(l, r) => write!(f, "{l} <= {r}"), + GTE(l, r) => write!(f, "{l} >= {r}"), } } } -#[derive(Debug, PartialEq, Clone, Hash)] -pub struct Inequality { - pub left: TDim, - pub sign: InequalitySign, - pub right: TDim, -} - impl Inequality { - pub fn as_known_positive(&self) -> TDim { - use InequalitySign::*; - match self.sign { - GTE => self.left.clone() - &self.right, - GT => self.left.clone() - 1 - &self.right, - LTE => self.right.clone() - &self.left, - LT => self.right.clone() - 1 - &self.left, + pub fn as_known_positive(&self) -> Option { + use Inequality::*; + match self { + GTE(left, right) => Some(left.clone() - right), + GT(left, right) => Some(left.clone() - 1 - right), + LTE(left, right) => Some(right.clone() - left), + LT(left, right) => Some(right.clone() - 1 - left), } } } -impl Display for Inequality { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{} {} {}", self.left, self.sign, self.right) - } -} - #[derive(Clone, PartialEq, Eq)] pub struct Symbol(SymbolScope, string_interner::DefaultSymbol); @@ -254,7 +241,7 @@ mod tests { let s = SymbolScope::default(); assert_eq!( parse_inequality(&s, "S>=0").unwrap().as_known_positive(), - s.parse_tdim("S").unwrap() + Some(s.parse_tdim("S").unwrap()) ); } @@ -263,7 +250,7 @@ mod tests { let s = SymbolScope::default(); assert_eq!( parse_inequality(&s, "S>0").unwrap().as_known_positive(), - s.parse_tdim("S-1").unwrap() + Some(s.parse_tdim("S-1").unwrap()) ); } @@ -272,7 +259,7 @@ mod tests { let s = SymbolScope::default(); assert_eq!( parse_inequality(&s, "S<=0").unwrap().as_known_positive(), - s.parse_tdim("-S").unwrap() + Some(s.parse_tdim("-S").unwrap()) ); } @@ -281,7 +268,7 @@ mod tests { let s = SymbolScope::default(); assert_eq!( parse_inequality(&s, "S<0").unwrap().as_known_positive(), - s.parse_tdim("-S - 1").unwrap() + Some(s.parse_tdim("-S - 1").unwrap()) ); } diff --git a/data/src/dim/tree.rs b/data/src/dim/tree.rs index b88e034ecd..a7365957e6 100644 --- a/data/src/dim/tree.rs +++ b/data/src/dim/tree.rs @@ -550,20 +550,19 @@ impl TDim { .all_assertions() .iter() .filter_map(|assert| { - let Inequality { left, sign, right } = - parse_inequality(scope, assert).unwrap(); - if &left == self - && sign == InequalitySign::LT - && right.as_i64().is_some() - { - Some(right.as_i64().unwrap() - 1) - } else if &left == self - && sign == InequalitySign::LTE - && right.as_i64().is_some() - { - Some(right.as_i64().unwrap()) - } else { - None + let ineq = parse_inequality(scope, assert).unwrap(); + match &ineq { + Inequality::LT(left, right) + if left == self && right.as_i64().is_some() => + { + Some(right.as_i64().unwrap() - 1) + } + Inequality::LTE(left, right) + if left == self && right.as_i64().is_some() => + { + Some(right.as_i64().unwrap()) + } + _ => None, } }) .min() @@ -572,20 +571,19 @@ impl TDim { .all_assertions() .iter() .filter_map(|assert| { - let Inequality { left, sign, right } = - parse_inequality(scope, assert).unwrap(); - if &left == self - && sign == InequalitySign::GT - && right.as_i64().is_some() - { - Some(right.as_i64().unwrap() + 1) - } else if &left == self - && sign == InequalitySign::GTE - && right.as_i64().is_some() - { - Some(right.as_i64().unwrap()) - } else { - None + let ineq = parse_inequality(scope, assert).unwrap(); + match &ineq { + Inequality::GT(left, right) + if left == self && right.as_i64().is_some() => + { + Some(right.as_i64().unwrap() + 1) + } + Inequality::GTE(left, right) + if left == self && right.as_i64().is_some() => + { + Some(right.as_i64().unwrap()) + } + _ => None, } }) .max() @@ -1379,4 +1377,15 @@ mod tests { symbols.parse_tdim("0").unwrap() ); } + + #[test] + fn min_llm_0() { + let symbols = SymbolScope::default(); + symbols.add_inequality("S>=0").unwrap(); + symbols.add_inequality("P>=0").unwrap(); + assert_eq!( + symbols.parse_tdim("min(P,(S)#(P+S))").unwrap().simplify(), + symbols.parse_tdim("P").unwrap() + ); + } } From 93321fb9407ca18e5669f9f8f7cb0d49c7128eb7 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Tue, 10 Sep 2024 08:55:47 +0200 Subject: [PATCH 02/20] wip --- data/src/dim/parse.rs | 16 ++++++++-------- data/src/dim/sym.rs | 10 +++++----- data/src/dim/tree.rs | 9 +++++---- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/data/src/dim/parse.rs b/data/src/dim/parse.rs index 68dbfbfa2a..4eba764b77 100644 --- a/data/src/dim/parse.rs +++ b/data/src/dim/parse.rs @@ -6,7 +6,7 @@ use nom::combinator::{all_consuming, map, map_res, recognize}; use nom::multi::{many0, separated_list0}; use nom::sequence::{delimited, pair, preceded, separated_pair}; use nom::IResult; -use sym::Inequality; +use sym::Assertions; pub fn parse_tdim(symbol_table: &SymbolScope, input: &str) -> TractResult { match all_consuming(|i| expr(symbol_table, i))(input) { @@ -15,26 +15,26 @@ pub fn parse_tdim(symbol_table: &SymbolScope, input: &str) -> TractResult } } -pub fn parse_inequality(symbol_table: &SymbolScope, input: &str) -> TractResult { +pub fn parse_inequality(symbol_table: &SymbolScope, input: &str) -> TractResult { match all_consuming(|i| inequality(symbol_table, i))(input) { Ok(pair) => Ok(pair.1), Err(e) => bail!("Failed to parse {:?}, {:?}", input, e), } } -fn inequality<'i>(s: &SymbolScope, i: &'i str) -> IResult<&'i str, Inequality> { +fn inequality<'i>(s: &SymbolScope, i: &'i str) -> IResult<&'i str, Assertions> { alt(( map(separated_pair(|i| expr(s, i), stag("<="), |i| expr(s, i)), |(a, b)| { - Inequality::LTE(a, b) + Assertions::LTE(a, b) }), map(separated_pair(|i| expr(s, i), stag(">="), |i| expr(s, i)), |(a, b)| { - Inequality::GTE(a, b) + Assertions::GTE(a, b) }), map(separated_pair(|i| expr(s, i), stag("<"), |i| expr(s, i)), |(a, b)| { - Inequality::LT(a, b) + Assertions::LT(a, b) }), map(separated_pair(|i| expr(s, i), stag(">"), |i| expr(s, i)), |(a, b)| { - Inequality::GT(a, b) + Assertions::GT(a, b) }), ))(i) } @@ -170,7 +170,7 @@ mod test { let table = SymbolScope::default(); assert_eq!( parse_inequality(&table, "P+S<4096").unwrap(), - Inequality::LT(parse_tdim(&table, "P+S").unwrap(), 4096.to_dim()) + Assertions::LT(parse_tdim(&table, "P+S").unwrap(), 4096.to_dim()) ); } } diff --git a/data/src/dim/sym.rs b/data/src/dim/sym.rs index 041291496b..92a2d3b4a4 100644 --- a/data/src/dim/sym.rs +++ b/data/src/dim/sym.rs @@ -140,16 +140,16 @@ impl fmt::Debug for SymbolScope { #[derive(Debug, PartialEq, Clone, Hash)] #[allow(clippy::upper_case_acronyms)] -pub enum Inequality { +pub enum Assertions { LT(TDim, TDim), GT(TDim, TDim), LTE(TDim, TDim), GTE(TDim, TDim), } -impl Display for Inequality { +impl Display for Assertions { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - use Inequality::*; + use Assertions::*; match self { LT(l, r) => write!(f, "{l} < {r}"), GT(l, r) => write!(f, "{l} > {r}"), @@ -159,9 +159,9 @@ impl Display for Inequality { } } -impl Inequality { +impl Assertions { pub fn as_known_positive(&self) -> Option { - use Inequality::*; + use Assertions::*; match self { GTE(left, right) => Some(left.clone() - right), GT(left, right) => Some(left.clone() - 1 - right), diff --git a/data/src/dim/tree.rs b/data/src/dim/tree.rs index a7365957e6..359f5231d5 100644 --- a/data/src/dim/tree.rs +++ b/data/src/dim/tree.rs @@ -552,12 +552,12 @@ impl TDim { .filter_map(|assert| { let ineq = parse_inequality(scope, assert).unwrap(); match &ineq { - Inequality::LT(left, right) + Assertions::LT(left, right) if left == self && right.as_i64().is_some() => { Some(right.as_i64().unwrap() - 1) } - Inequality::LTE(left, right) + Assertions::LTE(left, right) if left == self && right.as_i64().is_some() => { Some(right.as_i64().unwrap()) @@ -573,12 +573,12 @@ impl TDim { .filter_map(|assert| { let ineq = parse_inequality(scope, assert).unwrap(); match &ineq { - Inequality::GT(left, right) + Assertions::GT(left, right) if left == self && right.as_i64().is_some() => { Some(right.as_i64().unwrap() + 1) } - Inequality::GTE(left, right) + Assertions::GTE(left, right) if left == self && right.as_i64().is_some() => { Some(right.as_i64().unwrap()) @@ -1379,6 +1379,7 @@ mod tests { } #[test] + #[ignore] fn min_llm_0() { let symbols = SymbolScope::default(); symbols.add_inequality("S>=0").unwrap(); From 2e47d60eee656c097be4f1b5d7bd2366d5beaefa Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Tue, 10 Sep 2024 13:20:59 +0200 Subject: [PATCH 03/20] wip, using rwlock --- data/src/dim/sym.rs | 109 +++++++++++++++++++++++------------------ data/src/dim/tree.rs | 113 +++++++++++++++++++++++++------------------ 2 files changed, 128 insertions(+), 94 deletions(-) diff --git a/data/src/dim/sym.rs b/data/src/dim/sym.rs index 92a2d3b4a4..b7dd5b81cb 100644 --- a/data/src/dim/sym.rs +++ b/data/src/dim/sym.rs @@ -1,7 +1,7 @@ use itertools::Itertools; use std::collections::HashMap; use std::fmt::{self, Display}; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, MutexGuard, RwLock, RwLockReadGuard, Weak}; use string_interner::DefaultStringInterner; use string_interner::Symbol as _; @@ -11,7 +11,7 @@ use super::parse::parse_inequality; use super::{parse_tdim, TDim}; #[derive(Clone, Default)] -pub struct SymbolScope(Arc>); +pub struct SymbolScope(Arc>); impl PartialEq for SymbolScope { fn eq(&self, other: &Self) -> bool { @@ -24,23 +24,23 @@ impl Eq for SymbolScope {} #[derive(Default)] pub struct SymbolScopeData { table: DefaultStringInterner, - inequalities: Vec, + inequalities: Vec, } impl SymbolScope { pub fn get(&self, name: &str) -> Option { - let locked = self.0.lock().unwrap(); - locked.table.get(name).map(|sym| Symbol(self.clone(), sym)) + let locked = self.0.read().unwrap(); + locked.table.get(name).map(|sym| Symbol(Arc::downgrade(&self.0), sym)) } pub fn sym(&self, name: &str) -> Symbol { - let mut locked = self.0.lock().unwrap(); + let mut locked = self.0.write().unwrap(); let sym = locked.table.get_or_intern(name); - Symbol(self.clone(), sym) + Symbol(Arc::downgrade(&self.0), sym) } pub fn new_with_prefix(&self, prefix: &str) -> Symbol { - let mut locked = self.0.lock().unwrap(); + let mut locked = self.0.write().unwrap(); let sym = if locked.table.get(prefix).is_none() { locked.table.get_or_intern(prefix) } else { @@ -53,14 +53,7 @@ impl SymbolScope { i += 1; } }; - Symbol(self.clone(), sym) - } - - pub fn resolving(&self, sym: &Symbol, f: impl FnOnce(&str) -> R) -> Option { - match self.0.try_lock() { - Ok(lock) => lock.table.resolve(sym.1).map(f), - Err(_) => None, - } + Symbol(Arc::downgrade(&self.0), sym) } pub fn parse_tdim(&self, input: impl AsRef) -> TractResult { @@ -69,8 +62,8 @@ impl SymbolScope { pub fn add_inequality(&self, ineq: impl Into) -> TractResult<()> { let ineq = ineq.into(); - parse_inequality(self, &ineq)?; - self.0.lock().unwrap().inequalities.push(ineq); + let ineq = parse_inequality(self, &ineq)?; + self.0.write().unwrap().inequalities.push(ineq); Ok(()) } @@ -79,23 +72,41 @@ impl SymbolScope { Ok(self) } + pub fn all_symbols(&self) -> Vec { + self.0.read().unwrap().table.into_iter().map(|is| Symbol(Arc::downgrade(&self.0), is.0)).collect() + } + + pub fn all_assertions(&self) -> Vec { + self.0.read().unwrap().inequalities.clone() + } + + pub fn lock(&self) -> Option> { + self.0.read().ok() + } +} + +impl SymbolScopeData { + pub fn all_assertions(&self) -> &[Assertions] { + &self.inequalities + } + + pub fn resolving(&self, sym: &Symbol, f: impl FnOnce(&str) -> R) -> Option { + self.table.resolve(sym.1).map(f) + } + #[allow(clippy::mutable_key_type)] pub fn prove_positive_or_zero(&self, t: &TDim) -> bool { if let TDim::Val(v) = t { return *v >= 0; } - let ineqs = self.0.lock().unwrap().inequalities.clone(); - let positives = ineqs - .iter() - .filter_map(|i| parse_inequality(self, i).unwrap().as_known_positive()) - .collect_vec(); + let positives = self.inequalities.iter().filter_map(|i| i.as_known_positive()).collect_vec(); let mut visited = vec![]; let mut todo = vec![t.clone()]; while let Some(t) = todo.pop() { if t.to_i64().is_ok_and(|i| i >= 0) { return true; } - if t.low_inclusive_bound(self).is_some_and(|l| l >= 0) { + if t.inclusive_bound(self, false).is_some_and(|l| l >= 0) { return true; } let syms = t.symbols(); @@ -122,18 +133,11 @@ impl SymbolScope { false } - pub fn all_symbols(&self) -> Vec { - self.0.lock().unwrap().table.into_iter().map(|is| Symbol(self.clone(), is.0)).collect() - } - - pub fn all_assertions(&self) -> Vec { - self.0.lock().unwrap().inequalities.clone() - } } impl fmt::Debug for SymbolScope { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let locked = self.0.lock().unwrap(); + let locked = self.0.read().unwrap(); write!(f, "{}", locked.table.into_iter().map(|(_, s)| s).join(" ")) } } @@ -171,18 +175,26 @@ impl Assertions { } } -#[derive(Clone, PartialEq, Eq)] -pub struct Symbol(SymbolScope, string_interner::DefaultSymbol); +#[derive(Clone)] +pub struct Symbol(Weak>, string_interner::DefaultSymbol); + +impl Eq for Symbol {} + +impl PartialEq for Symbol { + fn eq(&self, other: &Self) -> bool { + self.1 == other.1 + } +} impl Symbol { - pub fn scope(&self) -> &SymbolScope { - &self.0 + pub fn scope(&self) -> Option { + self.0.upgrade().map(SymbolScope) } } impl PartialOrd for Symbol { fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) + Some(self.1.cmp(&other.1)) } } @@ -200,9 +212,14 @@ impl std::hash::Hash for Symbol { impl std::fmt::Display for Symbol { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0 - .resolving(self, |s| write!(f, "{s}")) - .unwrap_or_else(|| write!(f, "", self.1.to_usize())) + if let Some(scope) = self.scope() { + if let Ok(lock) = scope.0.read() { + if let Some(s) = lock.table.resolve(self.1) { + return write!(f, "{}", s); + } + } + } + write!(f, "", self.1.to_usize()) } } @@ -275,32 +292,32 @@ mod tests { #[test] fn prove_positive_0() { let s = SymbolScope::default(); - assert!(s.prove_positive_or_zero(&s.parse_tdim("0").unwrap())); + assert!(s.parse_tdim("0").unwrap().prove_positive_or_zero()); } #[test] fn prove_positive_1() { let s = SymbolScope::default(); - assert!(s.prove_positive_or_zero(&s.parse_tdim("1").unwrap())); + assert!(s.parse_tdim("1").unwrap().prove_positive_or_zero()); } #[test] fn prove_positive_neg1() { let s = SymbolScope::default(); - assert!(!s.prove_positive_or_zero(&s.parse_tdim("-1").unwrap())); + assert!(!s.parse_tdim("-1").unwrap().prove_positive_or_zero()); } #[test] fn prove_positive_add_0() { let s = SymbolScope::default(); - assert!(!s.prove_positive_or_zero(&s.parse_tdim("s+1").unwrap())); + assert!(!s.parse_tdim("s+1").unwrap().prove_positive_or_zero()); } #[test] fn prove_positive_with_axiom() { let s = SymbolScope::default(); s.add_inequality("s>=0").unwrap(); - assert!(s.prove_positive_or_zero(&s.parse_tdim("s").unwrap())); + assert!(s.parse_tdim("s").unwrap().prove_positive_or_zero()); } #[test] @@ -309,6 +326,6 @@ mod tests { s.add_inequality("s>=0").unwrap(); s.add_inequality("p>=0").unwrap(); s.add_inequality("p+s<4096").unwrap(); - assert!(s.prove_positive_or_zero(&s.parse_tdim("4096-p").unwrap())); + assert!(s.parse_tdim("4096-p").unwrap().prove_positive_or_zero()); } } diff --git a/data/src/dim/tree.rs b/data/src/dim/tree.rs index 359f5231d5..2257cbcae7 100644 --- a/data/src/dim/tree.rs +++ b/data/src/dim/tree.rs @@ -1,4 +1,3 @@ -use crate::dim::parse::parse_inequality; use crate::internal::*; use super::{sym::*, DimLike}; @@ -284,7 +283,7 @@ impl TDim { } pub fn find_scope(&self) -> Option { - Self::find_any_sym(self).map(|s| s.scope().clone()) + Self::find_any_sym(self).and_then(|s| s.scope().clone()) } pub fn simplify(self) -> TDim { @@ -292,12 +291,12 @@ impl TDim { if let Val(v) = self { return Val(v); } - - let scope = Self::find_any_sym(&self).map(|s| s.scope().clone()); - self.simplify_rec(scope.as_ref()) + let scope = self.find_scope(); + let data = scope.as_ref().and_then(|scope| scope.lock()); + self.simplify_rec(data.as_deref()) } - fn simplify_rec(self, scope: Option<&SymbolScope>) -> TDim { + fn simplify_rec(self, scope: Option<&SymbolScopeData>) -> TDim { match self { Add(mut terms) => { #[allow(clippy::mutable_key_type)] @@ -540,7 +539,7 @@ impl TDim { } } - pub fn inclusive_bound(&self, scope: &SymbolScope, upper: bool) -> Option { + pub(super) fn inclusive_bound(&self, scope: &SymbolScopeData, upper: bool) -> Option { use self::TDim::*; match self { Val(n) => Some(*n), @@ -549,42 +548,36 @@ impl TDim { scope .all_assertions() .iter() - .filter_map(|assert| { - let ineq = parse_inequality(scope, assert).unwrap(); - match &ineq { - Assertions::LT(left, right) - if left == self && right.as_i64().is_some() => - { - Some(right.as_i64().unwrap() - 1) - } - Assertions::LTE(left, right) - if left == self && right.as_i64().is_some() => - { - Some(right.as_i64().unwrap()) - } - _ => None, + .filter_map(|assert| match &assert { + Assertions::LT(left, right) + if left == self && right.as_i64().is_some() => + { + Some(right.as_i64().unwrap() - 1) + } + Assertions::LTE(left, right) + if left == self && right.as_i64().is_some() => + { + Some(right.as_i64().unwrap()) } + _ => None, }) .min() } else { scope .all_assertions() .iter() - .filter_map(|assert| { - let ineq = parse_inequality(scope, assert).unwrap(); - match &ineq { - Assertions::GT(left, right) - if left == self && right.as_i64().is_some() => - { - Some(right.as_i64().unwrap() + 1) - } - Assertions::GTE(left, right) - if left == self && right.as_i64().is_some() => - { - Some(right.as_i64().unwrap()) - } - _ => None, + .filter_map(|assert| match &assert { + Assertions::GT(left, right) + if left == self && right.as_i64().is_some() => + { + Some(right.as_i64().unwrap() + 1) + } + Assertions::GTE(left, right) + if left == self && right.as_i64().is_some() => + { + Some(right.as_i64().unwrap()) } + _ => None, }) .max() } @@ -606,26 +599,49 @@ impl TDim { Ordering::Less => a.inclusive_bound(scope, !upper).map(|x| x * p), }, Mul(_) => None, - Min(terms) if !upper => terms.iter().filter_map(|t| t.low_inclusive_bound(scope)).min(), - Max(terms) if upper => terms.iter().filter_map(|t| t.high_inclusive_bound(scope)).max(), + Min(terms) if !upper => { + terms.iter().filter_map(|t| t.inclusive_bound(scope, false)).min() + } + Max(terms) if upper => { + terms.iter().filter_map(|t| t.inclusive_bound(scope, true)).max() + } Div(a, q) => a.inclusive_bound(scope, upper).map(|x| x / (*q as i64)), Broadcast(terms) => { if upper { - Max(terms.clone()).high_inclusive_bound(scope) + Max(terms.clone()).inclusive_bound(scope, true) } else { - Min(terms.clone()).low_inclusive_bound(scope) + Min(terms.clone()).inclusive_bound(scope, false) } } _ => None, } } - pub fn low_inclusive_bound(&self, scope: &SymbolScope) -> Option { - self.inclusive_bound(scope, false) + pub fn low_inclusive_bound(&self) -> Option { + if let TDim::Val(v) = self { + return Some(*v); + } + let Some(scope) = self.find_scope() else { return None }; + let Some(data) = scope.lock() else { return None }; + self.inclusive_bound(&*data, false) } - pub fn high_inclusive_bound(&self, scope: &SymbolScope) -> Option { - self.inclusive_bound(scope, true) + pub fn high_inclusive_bound(&self) -> Option { + if let TDim::Val(v) = self { + return Some(*v); + } + let Some(scope) = self.find_scope() else { return None }; + let Some(data) = scope.lock() else { return None }; + self.inclusive_bound(&*data, true) + } + + pub fn prove_positive_or_zero(&self) -> bool { + if let TDim::Val(v) = self { + return *v >= 0; + } + let Some(scope) = self.find_scope() else { return false }; + let Some(data) = scope.lock() else { return false }; + data.prove_positive_or_zero(&self) } pub fn gcd(&self) -> u64 { @@ -1295,35 +1311,36 @@ mod tests { #[test] fn low_bound_0() -> TractResult<()> { let symbols = SymbolScope::default().with_inequality("S>=0")?; - assert_eq!(symbols.parse_tdim("S").unwrap().low_inclusive_bound(&symbols), Some(0)); + let s = symbols.parse_tdim("S").unwrap(); + assert_eq!(s.low_inclusive_bound(), Some(0)); Ok(()) } #[test] fn low_bound_1() -> TractResult<()> { let symbols = SymbolScope::default().with_inequality("S>0")?; - assert_eq!(symbols.parse_tdim("S").unwrap().low_inclusive_bound(&symbols), Some(1)); + assert_eq!(symbols.parse_tdim("S").unwrap().low_inclusive_bound(), Some(1)); Ok(()) } #[test] fn low_bound_2() -> TractResult<()> { let symbols = SymbolScope::default().with_inequality("S>0")?; - assert_eq!(symbols.parse_tdim("S + 1").unwrap().low_inclusive_bound(&symbols), Some(2)); + assert_eq!(symbols.parse_tdim("S + 1").unwrap().low_inclusive_bound(), Some(2)); Ok(()) } #[test] fn low_bound_3() -> TractResult<()> { let symbols = SymbolScope::default().with_inequality("S>0")?; - assert_eq!(symbols.parse_tdim("4*S").unwrap().low_inclusive_bound(&symbols), Some(4)); + assert_eq!(symbols.parse_tdim("4*S").unwrap().low_inclusive_bound(), Some(4)); Ok(()) } #[test] fn low_bound_4() -> TractResult<()> { let symbols = SymbolScope::default().with_inequality("S>0")?.with_inequality("S>5")?; - assert_eq!(symbols.parse_tdim("S + 3").unwrap().low_inclusive_bound(&symbols), Some(9)); + assert_eq!(symbols.parse_tdim("S + 3").unwrap().low_inclusive_bound(), Some(9)); Ok(()) } From 20d331bfa014319fbbfe38e74899f0e3e91e0b4a Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Thu, 12 Sep 2024 15:58:25 +0200 Subject: [PATCH 04/20] terminology: inequalities->assertion --- cli/src/params.rs | 2 +- data/src/dim/parse.rs | 18 +++++++-------- data/src/dim/sym.rs | 52 +++++++++++++++++++++---------------------- data/src/dim/tree.rs | 34 ++++++++++++++-------------- nnef/src/deser.rs | 2 +- 5 files changed, 54 insertions(+), 54 deletions(-) diff --git a/cli/src/params.rs b/cli/src/params.rs index 36e13d7a53..a31a452e5d 100644 --- a/cli/src/params.rs +++ b/cli/src/params.rs @@ -861,7 +861,7 @@ impl Parameters { pub fn from_clap(matches: &clap::ArgMatches, probe: Option<&Probe>) -> TractResult { let symbols = SymbolScope::default(); for rule in matches.values_of("assert").unwrap_or_default() { - symbols.add_inequality(rule)?; + symbols.add_assertion(rule)?; } let (filename, onnx_tc) = Self::disco_model(matches)?; let tensors_values = Self::parse_tensors(matches, &filename, onnx_tc, &symbols)?; diff --git a/data/src/dim/parse.rs b/data/src/dim/parse.rs index 4eba764b77..1a3e1b14d6 100644 --- a/data/src/dim/parse.rs +++ b/data/src/dim/parse.rs @@ -6,7 +6,7 @@ use nom::combinator::{all_consuming, map, map_res, recognize}; use nom::multi::{many0, separated_list0}; use nom::sequence::{delimited, pair, preceded, separated_pair}; use nom::IResult; -use sym::Assertions; +use sym::Assertion; pub fn parse_tdim(symbol_table: &SymbolScope, input: &str) -> TractResult { match all_consuming(|i| expr(symbol_table, i))(input) { @@ -15,26 +15,26 @@ pub fn parse_tdim(symbol_table: &SymbolScope, input: &str) -> TractResult } } -pub fn parse_inequality(symbol_table: &SymbolScope, input: &str) -> TractResult { +pub fn parse_assertion(symbol_table: &SymbolScope, input: &str) -> TractResult { match all_consuming(|i| inequality(symbol_table, i))(input) { Ok(pair) => Ok(pair.1), Err(e) => bail!("Failed to parse {:?}, {:?}", input, e), } } -fn inequality<'i>(s: &SymbolScope, i: &'i str) -> IResult<&'i str, Assertions> { +fn inequality<'i>(s: &SymbolScope, i: &'i str) -> IResult<&'i str, Assertion> { alt(( map(separated_pair(|i| expr(s, i), stag("<="), |i| expr(s, i)), |(a, b)| { - Assertions::LTE(a, b) + Assertion::LTE(a, b) }), map(separated_pair(|i| expr(s, i), stag(">="), |i| expr(s, i)), |(a, b)| { - Assertions::GTE(a, b) + Assertion::GTE(a, b) }), map(separated_pair(|i| expr(s, i), stag("<"), |i| expr(s, i)), |(a, b)| { - Assertions::LT(a, b) + Assertion::LT(a, b) }), map(separated_pair(|i| expr(s, i), stag(">"), |i| expr(s, i)), |(a, b)| { - Assertions::GT(a, b) + Assertion::GT(a, b) }), ))(i) } @@ -169,8 +169,8 @@ mod test { fn parse_inequality_0() { let table = SymbolScope::default(); assert_eq!( - parse_inequality(&table, "P+S<4096").unwrap(), - Assertions::LT(parse_tdim(&table, "P+S").unwrap(), 4096.to_dim()) + parse_assertion(&table, "P+S<4096").unwrap(), + Assertion::LT(parse_tdim(&table, "P+S").unwrap(), 4096.to_dim()) ); } } diff --git a/data/src/dim/sym.rs b/data/src/dim/sym.rs index b7dd5b81cb..6b2dadb6c2 100644 --- a/data/src/dim/sym.rs +++ b/data/src/dim/sym.rs @@ -7,7 +7,7 @@ use string_interner::Symbol as _; use crate::TractResult; -use super::parse::parse_inequality; +use super::parse::parse_assertion; use super::{parse_tdim, TDim}; #[derive(Clone, Default)] @@ -24,7 +24,7 @@ impl Eq for SymbolScope {} #[derive(Default)] pub struct SymbolScopeData { table: DefaultStringInterner, - inequalities: Vec, + assertions: Vec, } impl SymbolScope { @@ -60,15 +60,15 @@ impl SymbolScope { parse_tdim(self, input.as_ref()) } - pub fn add_inequality(&self, ineq: impl Into) -> TractResult<()> { - let ineq = ineq.into(); - let ineq = parse_inequality(self, &ineq)?; - self.0.write().unwrap().inequalities.push(ineq); + pub fn add_assertion(&self, assert: impl Into) -> TractResult<()> { + let assert = assert.into(); + let assert = parse_assertion(self, &assert)?; + self.0.write().unwrap().assertions.push(assert); Ok(()) } - pub fn with_inequality(self, ineq: impl Into) -> TractResult { - self.add_inequality(ineq)?; + pub fn with_assertion(self, assert: impl Into) -> TractResult { + self.add_assertion(assert)?; Ok(self) } @@ -76,8 +76,8 @@ impl SymbolScope { self.0.read().unwrap().table.into_iter().map(|is| Symbol(Arc::downgrade(&self.0), is.0)).collect() } - pub fn all_assertions(&self) -> Vec { - self.0.read().unwrap().inequalities.clone() + pub fn all_assertions(&self) -> Vec { + self.0.read().unwrap().assertions.clone() } pub fn lock(&self) -> Option> { @@ -86,8 +86,8 @@ impl SymbolScope { } impl SymbolScopeData { - pub fn all_assertions(&self) -> &[Assertions] { - &self.inequalities + pub fn all_assertions(&self) -> &[Assertion] { + &self.assertions } pub fn resolving(&self, sym: &Symbol, f: impl FnOnce(&str) -> R) -> Option { @@ -99,7 +99,7 @@ impl SymbolScopeData { if let TDim::Val(v) = t { return *v >= 0; } - let positives = self.inequalities.iter().filter_map(|i| i.as_known_positive()).collect_vec(); + let positives = self.assertions.iter().filter_map(|i| i.as_known_positive()).collect_vec(); let mut visited = vec![]; let mut todo = vec![t.clone()]; while let Some(t) = todo.pop() { @@ -144,16 +144,16 @@ impl fmt::Debug for SymbolScope { #[derive(Debug, PartialEq, Clone, Hash)] #[allow(clippy::upper_case_acronyms)] -pub enum Assertions { +pub enum Assertion { LT(TDim, TDim), GT(TDim, TDim), LTE(TDim, TDim), GTE(TDim, TDim), } -impl Display for Assertions { +impl Display for Assertion { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - use Assertions::*; + use Assertion::*; match self { LT(l, r) => write!(f, "{l} < {r}"), GT(l, r) => write!(f, "{l} > {r}"), @@ -163,9 +163,9 @@ impl Display for Assertions { } } -impl Assertions { +impl Assertion { pub fn as_known_positive(&self) -> Option { - use Assertions::*; + use Assertion::*; match self { GTE(left, right) => Some(left.clone() - right), GT(left, right) => Some(left.clone() - 1 - right), @@ -257,7 +257,7 @@ mod tests { fn as_known_positive_gte() { let s = SymbolScope::default(); assert_eq!( - parse_inequality(&s, "S>=0").unwrap().as_known_positive(), + parse_assertion(&s, "S>=0").unwrap().as_known_positive(), Some(s.parse_tdim("S").unwrap()) ); } @@ -266,7 +266,7 @@ mod tests { fn as_known_positive_gt() { let s = SymbolScope::default(); assert_eq!( - parse_inequality(&s, "S>0").unwrap().as_known_positive(), + parse_assertion(&s, "S>0").unwrap().as_known_positive(), Some(s.parse_tdim("S-1").unwrap()) ); } @@ -275,7 +275,7 @@ mod tests { fn as_known_positive_lte() { let s = SymbolScope::default(); assert_eq!( - parse_inequality(&s, "S<=0").unwrap().as_known_positive(), + parse_assertion(&s, "S<=0").unwrap().as_known_positive(), Some(s.parse_tdim("-S").unwrap()) ); } @@ -284,7 +284,7 @@ mod tests { fn as_known_positive_lt() { let s = SymbolScope::default(); assert_eq!( - parse_inequality(&s, "S<0").unwrap().as_known_positive(), + parse_assertion(&s, "S<0").unwrap().as_known_positive(), Some(s.parse_tdim("-S - 1").unwrap()) ); } @@ -316,16 +316,16 @@ mod tests { #[test] fn prove_positive_with_axiom() { let s = SymbolScope::default(); - s.add_inequality("s>=0").unwrap(); + s.add_assertion("s>=0").unwrap(); assert!(s.parse_tdim("s").unwrap().prove_positive_or_zero()); } #[test] fn prove_positive_with_axiom_2() { let s = SymbolScope::default(); - s.add_inequality("s>=0").unwrap(); - s.add_inequality("p>=0").unwrap(); - s.add_inequality("p+s<4096").unwrap(); + s.add_assertion("s>=0").unwrap(); + s.add_assertion("p>=0").unwrap(); + s.add_assertion("p+s<4096").unwrap(); assert!(s.parse_tdim("4096-p").unwrap().prove_positive_or_zero()); } } diff --git a/data/src/dim/tree.rs b/data/src/dim/tree.rs index 2257cbcae7..fe2437cf2e 100644 --- a/data/src/dim/tree.rs +++ b/data/src/dim/tree.rs @@ -549,12 +549,12 @@ impl TDim { .all_assertions() .iter() .filter_map(|assert| match &assert { - Assertions::LT(left, right) + Assertion::LT(left, right) if left == self && right.as_i64().is_some() => { Some(right.as_i64().unwrap() - 1) } - Assertions::LTE(left, right) + Assertion::LTE(left, right) if left == self && right.as_i64().is_some() => { Some(right.as_i64().unwrap()) @@ -567,12 +567,12 @@ impl TDim { .all_assertions() .iter() .filter_map(|assert| match &assert { - Assertions::GT(left, right) + Assertion::GT(left, right) if left == self && right.as_i64().is_some() => { Some(right.as_i64().unwrap() + 1) } - Assertions::GTE(left, right) + Assertion::GTE(left, right) if left == self && right.as_i64().is_some() => { Some(right.as_i64().unwrap()) @@ -1300,7 +1300,7 @@ mod tests { #[test] fn min_max_with_axiom() { let symbols = SymbolScope::default(); - symbols.add_inequality("a>=0").unwrap(); + symbols.add_assertion("a>=0").unwrap(); assert_eq!(symbols.parse_tdim("min(a,0)").unwrap().simplify(), 0.into()); assert_eq!( symbols.parse_tdim("max(a,0)").unwrap().simplify(), @@ -1310,7 +1310,7 @@ mod tests { #[test] fn low_bound_0() -> TractResult<()> { - let symbols = SymbolScope::default().with_inequality("S>=0")?; + let symbols = SymbolScope::default().with_assertion("S>=0")?; let s = symbols.parse_tdim("S").unwrap(); assert_eq!(s.low_inclusive_bound(), Some(0)); Ok(()) @@ -1318,28 +1318,28 @@ mod tests { #[test] fn low_bound_1() -> TractResult<()> { - let symbols = SymbolScope::default().with_inequality("S>0")?; + let symbols = SymbolScope::default().with_assertion("S>0")?; assert_eq!(symbols.parse_tdim("S").unwrap().low_inclusive_bound(), Some(1)); Ok(()) } #[test] fn low_bound_2() -> TractResult<()> { - let symbols = SymbolScope::default().with_inequality("S>0")?; + let symbols = SymbolScope::default().with_assertion("S>0")?; assert_eq!(symbols.parse_tdim("S + 1").unwrap().low_inclusive_bound(), Some(2)); Ok(()) } #[test] fn low_bound_3() -> TractResult<()> { - let symbols = SymbolScope::default().with_inequality("S>0")?; + let symbols = SymbolScope::default().with_assertion("S>0")?; assert_eq!(symbols.parse_tdim("4*S").unwrap().low_inclusive_bound(), Some(4)); Ok(()) } #[test] fn low_bound_4() -> TractResult<()> { - let symbols = SymbolScope::default().with_inequality("S>0")?.with_inequality("S>5")?; + let symbols = SymbolScope::default().with_assertion("S>0")?.with_assertion("S>5")?; assert_eq!(symbols.parse_tdim("S + 3").unwrap().low_inclusive_bound(), Some(9)); Ok(()) } @@ -1357,7 +1357,7 @@ mod tests { #[test] fn max_bug_1() { let symbols = SymbolScope::default(); - symbols.add_inequality("S>8").unwrap(); + symbols.add_assertion("S>8").unwrap(); assert_eq!( symbols.parse_tdim("max(1,-1+(S+1)/4)").unwrap().simplify(), symbols.parse_tdim("-1+(S+1)/4").unwrap(), @@ -1367,7 +1367,7 @@ mod tests { #[test] fn min_bug_1() { let symbols = SymbolScope::default(); - symbols.add_inequality("S>8").unwrap(); + symbols.add_assertion("S>8").unwrap(); assert_eq!( symbols.parse_tdim("min(1,-1+(S+1)/4)").unwrap().simplify(), symbols.parse_tdim("1").unwrap() @@ -1377,7 +1377,7 @@ mod tests { #[test] fn min_bug_2() { let symbols = SymbolScope::default(); - symbols.add_inequality("S>50").unwrap(); + symbols.add_assertion("S>50").unwrap(); assert_eq!( symbols.parse_tdim("min(-3+2*(S+1)/4,-1+(S+1)/4)").unwrap().simplify(), symbols.parse_tdim("-1+(S+1)/4").unwrap() @@ -1387,8 +1387,8 @@ mod tests { #[test] fn min_bug_3() { let symbols = SymbolScope::default(); - symbols.add_inequality("S>=0").unwrap(); - symbols.add_inequality("P>=0").unwrap(); + symbols.add_assertion("S>=0").unwrap(); + symbols.add_assertion("P>=0").unwrap(); assert_eq!( symbols.parse_tdim("min(0,(S)#(P+S))").unwrap().simplify(), symbols.parse_tdim("0").unwrap() @@ -1399,8 +1399,8 @@ mod tests { #[ignore] fn min_llm_0() { let symbols = SymbolScope::default(); - symbols.add_inequality("S>=0").unwrap(); - symbols.add_inequality("P>=0").unwrap(); + symbols.add_assertion("S>=0").unwrap(); + symbols.add_assertion("P>=0").unwrap(); assert_eq!( symbols.parse_tdim("min(P,(S)#(P+S))").unwrap().simplify(), symbols.parse_tdim("P").unwrap() diff --git a/nnef/src/deser.rs b/nnef/src/deser.rs index beee27379b..f7b9f452fa 100644 --- a/nnef/src/deser.rs +++ b/nnef/src/deser.rs @@ -62,7 +62,7 @@ impl<'mb> ModelBuilder<'mb> { self.symbols.push(symbol); } "tract_assert" => { - self.model.symbols.add_inequality(&ext.1)?; + self.model.symbols.add_assertion(&ext.1)?; } "KHR_enable_fragment_definitions" | "KHR_enable_operator_expressions" => (), _ => { From 11734b7b2d5d2298cd41ed9af6537f2c29a1f550 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Thu, 12 Sep 2024 16:31:37 +0200 Subject: [PATCH 05/20] introduce scenarios --- data/src/dim/sym.rs | 53 +++++++++++++++++++++++++++++++++++++------- data/src/dim/tree.rs | 19 +++++++++------- 2 files changed, 56 insertions(+), 16 deletions(-) diff --git a/data/src/dim/sym.rs b/data/src/dim/sym.rs index 6b2dadb6c2..ec52d2f794 100644 --- a/data/src/dim/sym.rs +++ b/data/src/dim/sym.rs @@ -1,7 +1,8 @@ use itertools::Itertools; -use std::collections::HashMap; +use std::collections::{BTreeMap, HashMap}; use std::fmt::{self, Display}; -use std::sync::{Arc, Mutex, MutexGuard, RwLock, RwLockReadGuard, Weak}; +use std::ops::Deref; +use std::sync::{Arc, RwLock, Weak}; use string_interner::DefaultStringInterner; use string_interner::Symbol as _; @@ -25,6 +26,7 @@ impl Eq for SymbolScope {} pub struct SymbolScopeData { table: DefaultStringInterner, assertions: Vec, + scenarios: BTreeMap>, } impl SymbolScope { @@ -72,15 +74,51 @@ impl SymbolScope { Ok(self) } - pub fn all_symbols(&self) -> Vec { - self.0.read().unwrap().table.into_iter().map(|is| Symbol(Arc::downgrade(&self.0), is.0)).collect() - } - pub fn all_assertions(&self) -> Vec { self.0.read().unwrap().assertions.clone() } - pub fn lock(&self) -> Option> { + pub fn add_scenario(&self, scenario: impl Into) -> TractResult<()> { + self.0.write().unwrap().scenarios.insert(scenario.into(), vec![]); + Ok(()) + } + + pub fn add_scenario_assertion( + &self, + scenario: impl Into, + assertion: impl Into, + ) -> TractResult<()> { + let assert = parse_assertion(self, &assertion.into())?; + let s = scenario.into(); + self.0.write().unwrap().scenarios.entry(s).or_default().push(assert); + Ok(()) + } + + pub fn with_scenario_assertion( + self, + scenario: impl Into, + assertion: impl Into, + ) -> TractResult { + self.add_scenario_assertion(scenario, assertion)?; + Ok(self) + } + + pub fn with_scenario(self, scenario: impl Into) -> TractResult { + self.add_scenario(scenario)?; + Ok(self) + } + + pub fn all_symbols(&self) -> Vec { + self.0 + .read() + .unwrap() + .table + .into_iter() + .map(|is| Symbol(Arc::downgrade(&self.0), is.0)) + .collect() + } + + pub fn read(&self) -> Option + '_> { self.0.read().ok() } } @@ -132,7 +170,6 @@ impl SymbolScopeData { } false } - } impl fmt::Debug for SymbolScope { diff --git a/data/src/dim/tree.rs b/data/src/dim/tree.rs index fe2437cf2e..cbaaedd2ee 100644 --- a/data/src/dim/tree.rs +++ b/data/src/dim/tree.rs @@ -292,7 +292,7 @@ impl TDim { return Val(v); } let scope = self.find_scope(); - let data = scope.as_ref().and_then(|scope| scope.lock()); + let data = scope.as_ref().and_then(|scope| scope.read()); self.simplify_rec(data.as_deref()) } @@ -622,7 +622,7 @@ impl TDim { return Some(*v); } let Some(scope) = self.find_scope() else { return None }; - let Some(data) = scope.lock() else { return None }; + let Some(data) = scope.read() else { return None }; self.inclusive_bound(&*data, false) } @@ -631,7 +631,7 @@ impl TDim { return Some(*v); } let Some(scope) = self.find_scope() else { return None }; - let Some(data) = scope.lock() else { return None }; + let Some(data) = scope.read() else { return None }; self.inclusive_bound(&*data, true) } @@ -640,7 +640,7 @@ impl TDim { return *v >= 0; } let Some(scope) = self.find_scope() else { return false }; - let Some(data) = scope.lock() else { return false }; + let Some(data) = scope.read() else { return false }; data.prove_positive_or_zero(&self) } @@ -1397,13 +1397,16 @@ mod tests { #[test] #[ignore] - fn min_llm_0() { - let symbols = SymbolScope::default(); - symbols.add_assertion("S>=0").unwrap(); - symbols.add_assertion("P>=0").unwrap(); + fn min_llm_0() -> TractResult<()> { + let symbols = SymbolScope::default() + .with_assertion("S>=0")? + .with_assertion("P>=0")? + .with_scenario_assertion("tg", "S=1")? + .with_scenario_assertion("pp", "P=0")?; assert_eq!( symbols.parse_tdim("min(P,(S)#(P+S))").unwrap().simplify(), symbols.parse_tdim("P").unwrap() ); + Ok(()) } } From 5d8ed66777b5bcd8e09b027ab9e46fa94c3836be Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Fri, 13 Sep 2024 11:09:44 +0200 Subject: [PATCH 06/20] wip, test deadlock --- data/src/dim/sym.rs | 2 +- data/src/dim/tree.rs | 132 ++++++++++++++++++++----------------------- 2 files changed, 61 insertions(+), 73 deletions(-) diff --git a/data/src/dim/sym.rs b/data/src/dim/sym.rs index ec52d2f794..f77ce70c87 100644 --- a/data/src/dim/sym.rs +++ b/data/src/dim/sym.rs @@ -118,7 +118,7 @@ impl SymbolScope { .collect() } - pub fn read(&self) -> Option + '_> { + fn read(&self) -> Option + '_> { self.0.read().ok() } } diff --git a/data/src/dim/tree.rs b/data/src/dim/tree.rs index cbaaedd2ee..f48b552c5c 100644 --- a/data/src/dim/tree.rs +++ b/data/src/dim/tree.rs @@ -993,23 +993,9 @@ mod tests { macro_rules! b( ($e:expr) => { Box::new($e) } ); lazy_static::lazy_static! { - static ref S: (SymbolScope, Symbol) = { - let table = SymbolScope::default(); - let s = table.new_with_prefix("S"); - (table, s) - }; - } - - fn a() -> Symbol { - S.0.sym("a") - } - - fn b() -> Symbol { - S.0.sym("b") - } - - fn s() -> TDim { - S.1.clone().into() + static ref table: SymbolScope = SymbolScope::default(); + static ref A: Symbol = table.sym("a"); + static ref B: Symbol = table.sym("b"); } fn neg(a: &TDim) -> TDim { @@ -1030,50 +1016,54 @@ mod tests { #[test] fn reduce_add() { - assert_eq!(add(&s(), &neg(&s())).reduce(), Val(0)) + assert_eq!(add(&A.to_dim(), &neg(&A.to_dim())).reduce(), Val(0)) } #[test] fn reduce_neg_mul() { - assert_eq!(neg(&mul(2, &s())).reduce(), mul(-2, &s())) + assert_eq!(neg(&mul(2, &A.to_dim())).reduce(), mul(-2, &A.to_dim())) } #[test] fn reduce_cplx_ex_2() { assert_eq!( - add(&add(&Val(-4), &mul(-2, &div(&s(), 4))), &mul(-2, &mul(-1, &div(&s(), 4)))) - .reduce(), + add( + &add(&Val(-4), &mul(-2, &div(&A.to_dim(), 4))), + &mul(-2, &mul(-1, &div(&A.to_dim(), 4))) + ) + .reduce(), Val(-4) ) } #[test] fn reduce_cplx_ex_3() { - assert_eq!(div(&MulInt(1, b!(MulInt(4, b!(s())))), 4).reduce(), s()) + assert_eq!(div(&MulInt(1, b!(MulInt(4, b!(A.to_dim())))), 4).reduce(), A.to_dim()) } #[test] fn reduce_cplx_ex_4() { // (S+1)/2 + (1-S)/2 == 1 assert_eq!( - add(&div(&add(&s(), &Val(1)), 2), &div(&add(&neg(&s()), &Val(1)), 2)).reduce(), + add(&div(&add(&A.to_dim(), &Val(1)), 2), &div(&add(&neg(&A.to_dim()), &Val(1)), 2)) + .reduce(), 1.into() ); } #[test] fn reduce_mul_mul_1() { - assert_eq!(mul(3, &mul(2, &s())).reduce(), mul(6, &s())) + assert_eq!(mul(3, &mul(2, &A.to_dim())).reduce(), mul(6, &A.to_dim())) } #[test] fn reduce_mul_mul_2() { - assert_eq!(mul(-2, &mul(-1, &s())).reduce(), mul(2, &s())) + assert_eq!(mul(-2, &mul(-1, &A.to_dim())).reduce(), mul(2, &A.to_dim())) } #[test] fn reduce_mul_div_1() { - assert_eq!(mul(2, &div(&mul(-1, &s()), 3)).reduce(), mul(-2, &div(&s(), 3))) + assert_eq!(mul(2, &div(&mul(-1, &A.to_dim()), 3)).reduce(), mul(-2, &div(&A.to_dim(), 3))) } #[test] @@ -1090,11 +1080,10 @@ mod tests { #[test] fn substitution() { - let x = S.0.sym("x"); - let e: TDim = x.clone().into(); - assert_eq!(e.eval(&SymbolValues::default().with(&x, 2)).to_i64().unwrap(), 2); - let e = e + 3; - assert_eq!(e.eval(&SymbolValues::default().with(&x, 2)).to_i64().unwrap(), 5); + let a: TDim = A.to_dim(); + assert_eq!(a.eval(&SymbolValues::default().with(&A, 2)).to_i64().unwrap(), 2); + let e = a + 3; + assert_eq!(e.eval(&SymbolValues::default().with(&A, 2)).to_i64().unwrap(), 5); } #[test] @@ -1111,11 +1100,10 @@ mod tests { #[test] fn reduce_muls() { - let e: TDim = Val(1) * s(); - assert_eq!(e, s()); - let b = S.0.sym("b"); - let e: TDim = s() * &b * 1; - assert_eq!(e, s() * &b); + let e: TDim = Val(1) * A.to_dim(); + assert_eq!(e, A.to_dim()); + let e: TDim = A.to_dim() * &B.to_dim() * 1; + assert_eq!(e, A.to_dim() * &B.to_dim()); } #[test] @@ -1134,69 +1122,69 @@ mod tests { #[test] fn reduce_div_bug_0() { - let e1: TDim = (s() + 23) / 2 - 1; - let e2: TDim = (s() + 21) / 2; + let e1: TDim = (A.to_dim() + 23) / 2 - 1; + let e2: TDim = (A.to_dim() + 21) / 2; assert_eq!(e1, e2); } #[test] fn reduce_div_bug_1() { - let e1: TDim = (s() + -1) / 2; - let e2: TDim = (s() + 1) / 2 - 1; + let e1: TDim = (A.to_dim() + -1) / 2; + let e2: TDim = (A.to_dim() + 1) / 2 - 1; assert_eq!(e1, e2); } #[test] fn reduce_div_bug_2() { - let e1: TDim = ((s() + 1) / 2 + 1) / 2; - let e2: TDim = (s() + 3) / 4; + let e1: TDim = ((A.to_dim() + 1) / 2 + 1) / 2; + let e2: TDim = (A.to_dim() + 3) / 4; assert_eq!(e1, e2); } #[test] fn reduce_div_bug_3() { - let e1: TDim = (s() / 2) * -4; - let e2: TDim = (s() / 2) * -4 / 1; + let e1: TDim = (A.to_dim() / 2) * -4; + let e2: TDim = (A.to_dim() / 2) * -4 / 1; assert_eq!(e1, e2); } #[test] fn reduce_mul_div() { - let e: TDim = s() * 2 / 2; - assert_eq!(e, s()); + let e: TDim = A.to_dim() * 2 / 2; + assert_eq!(e, A.to_dim()); } #[test] fn reduce_div_mul() { - let e: TDim = s() / 2 * 2; - assert_ne!(e, s()); + let e: TDim = A.to_dim() / 2 * 2; + assert_ne!(e, A.to_dim()); } #[test] fn reduce_add_div() { - let e: TDim = s() / 2 + 1; - assert_eq!(e, ((s() + 2) / 2)); + let e: TDim = A.to_dim() / 2 + 1; + assert_eq!(e, ((A.to_dim() + 2) / 2)); } #[test] fn reduce_neg_mul_() { - let e: TDim = TDim::from(1) - s() * 2; - assert_eq!(e, TDim::from(1) + s() * -2); + let e: TDim = TDim::from(1) - A.to_dim() * 2; + assert_eq!(e, TDim::from(1) + A.to_dim() * -2); } #[test] fn reduce_add_rem_1() { - assert_eq!(((s() + 4) % 2), (s() % 2)); + assert_eq!(((A.to_dim() + 4) % 2), (A.to_dim() % 2)); } #[test] fn reduce_add_rem_2() { - assert_eq!(((s() - 4) % 2), (s() % 2)); + assert_eq!(((A.to_dim() - 4) % 2), (A.to_dim() % 2)); } #[test] fn reduce_rem_div() { - let e: TDim = s() % 2 / 2; + let e: TDim = A.to_dim() % 2 / 2; assert_eq!(e, TDim::from(0)); } @@ -1208,13 +1196,13 @@ mod tests { #[test] fn conv2d_ex_2() { - let e = (s() - 3 + 1).div_ceil(1); - assert_eq!(e, s() + -2); + let e = (A.to_dim() - 3 + 1).div_ceil(1); + assert_eq!(e, A.to_dim() + -2); } #[test] fn extract_int_gcd_from_muls() { - let term = (s() + 1) / 4; + let term = (A.to_dim() + 1) / 4; let mul = (term.clone() * 24 - 24) * (term.clone() * 2 - 2); let target = (term.clone() - 1) * (term.clone() - 1) * 48; assert_eq!(mul, target); @@ -1222,7 +1210,7 @@ mod tests { #[test] fn equality_of_muls() { - let term = (s() + 1) / 4; + let term = (A.to_dim() + 1) / 4; let mul1 = (term.clone() * 2 - 3) * (term.clone() - 1); let mul2 = (term.clone() - 1) * (term.clone() * 2 - 3); assert_eq!(mul1, mul2); @@ -1230,7 +1218,7 @@ mod tests { #[test] fn factorize_complex_expr_times_int() { - let term = (s() + 1) / 4; + let term = (A.to_dim() + 1) / 4; let e = term.clone() * 2 - &term - 1; assert_eq!(e, term - 1); } @@ -1247,54 +1235,54 @@ mod tests { #[test] fn min_same() { - assert_eq!(s().mini(s()), s()); + assert_eq!(A.to_dim().mini(A.to_dim()), A.to_dim()); } #[test] fn min_noop() { - assert_eq!(s().mini(1.to_dim()), s().mini(1.to_dim())); + assert_eq!(A.to_dim().mini(1.to_dim()), A.to_dim().mini(1.to_dim())); } #[test] fn min_diff_1() { - assert_eq!((s() + 1).mini(s() + 2), s() + 1); + assert_eq!((A.to_dim() + 1).mini(A.to_dim() + 2), A.to_dim() + 1); } #[test] fn slope_0() { - assert_eq!(12.to_dim().guess_slope(&S.1), (0, 1)); + assert_eq!(12.to_dim().guess_slope(&A), (0, 1)); } #[test] fn slope_1() { - assert_eq!(s().guess_slope(&S.1), (1, 1)); + assert_eq!(A.to_dim().guess_slope(&A), (1, 1)); } #[test] fn slope_2() { - assert_eq!((s() * 2).guess_slope(&S.1), (2, 1)); + assert_eq!((A.to_dim() * 2).guess_slope(&A), (2, 1)); } #[test] fn slope_3() { - assert_eq!((s() * 2 + s() / 2).guess_slope(&S.1), (5, 2)); + assert_eq!((A.to_dim() * 2 + A.to_dim() / 2).guess_slope(&A), (5, 2)); } #[test] fn slope_4() { - assert_eq!((a().to_dim()).guess_slope(&b()), (0, 1)); + assert_eq!((A.to_dim()).guess_slope(&B), (0, 1)); } #[test] fn slope_5() { - assert_eq!((a().to_dim() + 1).guess_slope(&a()), (1, 1)); - assert_eq!((a().to_dim() + 1).guess_slope(&b()), (0, 1)); + assert_eq!((A.to_dim() + 1).guess_slope(&A), (1, 1)); + assert_eq!((A.to_dim() + 1).guess_slope(&B), (0, 1)); } #[test] fn slope_6() { - assert_eq!((a().to_dim() + 1).guess_slope(&a()), (1, 1)); - assert_eq!((a().to_dim() + b().to_dim()).guess_slope(&b()), (1, 1)); + assert_eq!((A.to_dim() + 1).guess_slope(&A), (1, 1)); + assert_eq!((A.to_dim() + B.to_dim()).guess_slope(&B), (1, 1)); } #[test] From 3e46cf324be8ced9f081109a54bddfef0178788c Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Mon, 16 Sep 2024 08:32:48 +0200 Subject: [PATCH 07/20] wip --- data/Cargo.toml | 1 + data/src/dim/sym.rs | 52 ++++++++++++++++++++++++++------------------ data/src/dim/tree.rs | 16 +++++++++----- 3 files changed, 42 insertions(+), 27 deletions(-) diff --git a/data/Cargo.toml b/data/Cargo.toml index 8d5cfdd37b..8008f76e93 100644 --- a/data/Cargo.toml +++ b/data/Cargo.toml @@ -31,6 +31,7 @@ smallvec.workspace = true lazy_static.workspace = true scan_fmt.workspace = true string-interner.workspace = true +parking_lot = "0.12.3" [target.'cfg(not(target_family = "wasm"))'.dev-dependencies] criterion.workspace = true diff --git a/data/src/dim/sym.rs b/data/src/dim/sym.rs index f77ce70c87..05f087aef9 100644 --- a/data/src/dim/sym.rs +++ b/data/src/dim/sym.rs @@ -1,8 +1,10 @@ use itertools::Itertools; +use parking_lot::ReentrantMutex; +use std::cell::RefCell; use std::collections::{BTreeMap, HashMap}; use std::fmt::{self, Display}; use std::ops::Deref; -use std::sync::{Arc, RwLock, Weak}; +use std::sync::{Arc, Mutex, RwLock, Weak}; use string_interner::DefaultStringInterner; use string_interner::Symbol as _; @@ -12,7 +14,7 @@ use super::parse::parse_assertion; use super::{parse_tdim, TDim}; #[derive(Clone, Default)] -pub struct SymbolScope(Arc>); +pub struct SymbolScope(pub Arc>>); impl PartialEq for SymbolScope { fn eq(&self, other: &Self) -> bool { @@ -31,18 +33,21 @@ pub struct SymbolScopeData { impl SymbolScope { pub fn get(&self, name: &str) -> Option { - let locked = self.0.read().unwrap(); + let locked = self.0.lock(); + let locked = locked.borrow(); locked.table.get(name).map(|sym| Symbol(Arc::downgrade(&self.0), sym)) } pub fn sym(&self, name: &str) -> Symbol { - let mut locked = self.0.write().unwrap(); + let locked = self.0.lock(); + let mut locked = locked.borrow_mut(); let sym = locked.table.get_or_intern(name); Symbol(Arc::downgrade(&self.0), sym) } pub fn new_with_prefix(&self, prefix: &str) -> Symbol { - let mut locked = self.0.write().unwrap(); + let locked = self.0.lock(); + let mut locked = locked.borrow_mut(); let sym = if locked.table.get(prefix).is_none() { locked.table.get_or_intern(prefix) } else { @@ -65,7 +70,9 @@ impl SymbolScope { pub fn add_assertion(&self, assert: impl Into) -> TractResult<()> { let assert = assert.into(); let assert = parse_assertion(self, &assert)?; - self.0.write().unwrap().assertions.push(assert); + let locked = self.0.lock(); + let mut locked = locked.borrow_mut(); + locked.assertions.push(assert); Ok(()) } @@ -75,11 +82,15 @@ impl SymbolScope { } pub fn all_assertions(&self) -> Vec { - self.0.read().unwrap().assertions.clone() + let locked = self.0.lock(); + let locked = locked.borrow(); + locked.assertions.clone() } pub fn add_scenario(&self, scenario: impl Into) -> TractResult<()> { - self.0.write().unwrap().scenarios.insert(scenario.into(), vec![]); + let locked = self.0.lock(); + let mut locked = locked.borrow_mut(); + locked.scenarios.insert(scenario.into(), vec![]); Ok(()) } @@ -90,7 +101,9 @@ impl SymbolScope { ) -> TractResult<()> { let assert = parse_assertion(self, &assertion.into())?; let s = scenario.into(); - self.0.write().unwrap().scenarios.entry(s).or_default().push(assert); + let locked = self.0.lock(); + let mut locked = locked.borrow_mut(); + locked.scenarios.entry(s).or_default().push(assert); Ok(()) } @@ -110,17 +123,13 @@ impl SymbolScope { pub fn all_symbols(&self) -> Vec { self.0 - .read() - .unwrap() + .lock() + .borrow() .table .into_iter() .map(|is| Symbol(Arc::downgrade(&self.0), is.0)) .collect() } - - fn read(&self) -> Option + '_> { - self.0.read().ok() - } } impl SymbolScopeData { @@ -174,7 +183,8 @@ impl SymbolScopeData { impl fmt::Debug for SymbolScope { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let locked = self.0.read().unwrap(); + let locked = self.0.lock(); + let locked = locked.borrow(); write!(f, "{}", locked.table.into_iter().map(|(_, s)| s).join(" ")) } } @@ -213,7 +223,7 @@ impl Assertion { } #[derive(Clone)] -pub struct Symbol(Weak>, string_interner::DefaultSymbol); +pub struct Symbol(Weak>>, string_interner::DefaultSymbol); impl Eq for Symbol {} @@ -250,10 +260,10 @@ impl std::hash::Hash for Symbol { impl std::fmt::Display for Symbol { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if let Some(scope) = self.scope() { - if let Ok(lock) = scope.0.read() { - if let Some(s) = lock.table.resolve(self.1) { - return write!(f, "{}", s); - } + let lock = scope.0.lock(); + let lock = lock.borrow(); + if let Some(s) = lock.table.resolve(self.1) { + return write!(f, "{}", s); } } write!(f, "", self.1.to_usize()) diff --git a/data/src/dim/tree.rs b/data/src/dim/tree.rs index f48b552c5c..cb3cd5d50e 100644 --- a/data/src/dim/tree.rs +++ b/data/src/dim/tree.rs @@ -291,9 +291,13 @@ impl TDim { if let Val(v) = self { return Val(v); } - let scope = self.find_scope(); - let data = scope.as_ref().and_then(|scope| scope.read()); - self.simplify_rec(data.as_deref()) + if let Some(scope) = self.find_scope() { + let locked = scope.0.lock(); + let borrow = locked.borrow(); + self.simplify_rec(Some(&borrow)) + } else { + self + } } fn simplify_rec(self, scope: Option<&SymbolScopeData>) -> TDim { @@ -622,7 +626,7 @@ impl TDim { return Some(*v); } let Some(scope) = self.find_scope() else { return None }; - let Some(data) = scope.read() else { return None }; + let Some(data) = scope.0.lock().ok() else { return None }; self.inclusive_bound(&*data, false) } @@ -631,7 +635,7 @@ impl TDim { return Some(*v); } let Some(scope) = self.find_scope() else { return None }; - let Some(data) = scope.read() else { return None }; + let Some(data) = scope.0.lock().ok() else { return None }; self.inclusive_bound(&*data, true) } @@ -640,7 +644,7 @@ impl TDim { return *v >= 0; } let Some(scope) = self.find_scope() else { return false }; - let Some(data) = scope.read() else { return false }; + let Some(data) = scope.0.lock().ok() else { return false }; data.prove_positive_or_zero(&self) } From ae6cd6d803dbe57961c99ce9461ebc0f5a7f5907 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Mon, 16 Sep 2024 08:52:31 +0200 Subject: [PATCH 08/20] reentrant mutex --- data/src/dim/sym.rs | 3 +-- data/src/dim/tree.rs | 13 ++++++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/data/src/dim/sym.rs b/data/src/dim/sym.rs index 05f087aef9..621e4b7f9c 100644 --- a/data/src/dim/sym.rs +++ b/data/src/dim/sym.rs @@ -3,8 +3,7 @@ use parking_lot::ReentrantMutex; use std::cell::RefCell; use std::collections::{BTreeMap, HashMap}; use std::fmt::{self, Display}; -use std::ops::Deref; -use std::sync::{Arc, Mutex, RwLock, Weak}; +use std::sync::{Arc, Weak}; use string_interner::DefaultStringInterner; use string_interner::Symbol as _; diff --git a/data/src/dim/tree.rs b/data/src/dim/tree.rs index cb3cd5d50e..9621303bba 100644 --- a/data/src/dim/tree.rs +++ b/data/src/dim/tree.rs @@ -288,7 +288,7 @@ impl TDim { pub fn simplify(self) -> TDim { use self::TDim::*; - if let Val(v) = self { + if let Ok(v) = self.eval_to_i64(&SymbolValues::default()) { return Val(v); } if let Some(scope) = self.find_scope() { @@ -626,8 +626,9 @@ impl TDim { return Some(*v); } let Some(scope) = self.find_scope() else { return None }; - let Some(data) = scope.0.lock().ok() else { return None }; - self.inclusive_bound(&*data, false) + let data = scope.0.lock(); + let data = data.borrow(); + self.inclusive_bound(&data, false) } pub fn high_inclusive_bound(&self) -> Option { @@ -635,7 +636,8 @@ impl TDim { return Some(*v); } let Some(scope) = self.find_scope() else { return None }; - let Some(data) = scope.0.lock().ok() else { return None }; + let data = scope.0.lock(); + let data = data.borrow(); self.inclusive_bound(&*data, true) } @@ -644,7 +646,8 @@ impl TDim { return *v >= 0; } let Some(scope) = self.find_scope() else { return false }; - let Some(data) = scope.0.lock().ok() else { return false }; + let data = scope.0.lock(); + let data = data.borrow(); data.prove_positive_or_zero(&self) } From 22c900cba27ff0123ae2b0437c188093d8b1ffe8 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Mon, 16 Sep 2024 08:56:58 +0200 Subject: [PATCH 09/20] scope is not optional --- data/src/dim/tree.rs | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/data/src/dim/tree.rs b/data/src/dim/tree.rs index 9621303bba..dbbde7f01f 100644 --- a/data/src/dim/tree.rs +++ b/data/src/dim/tree.rs @@ -289,18 +289,17 @@ impl TDim { pub fn simplify(self) -> TDim { use self::TDim::*; if let Ok(v) = self.eval_to_i64(&SymbolValues::default()) { - return Val(v); - } - if let Some(scope) = self.find_scope() { + Val(v) + } else if let Some(scope) = self.find_scope() { let locked = scope.0.lock(); let borrow = locked.borrow(); - self.simplify_rec(Some(&borrow)) + self.simplify_rec(&borrow) } else { self } } - fn simplify_rec(self, scope: Option<&SymbolScopeData>) -> TDim { + fn simplify_rec(self, scope: &SymbolScopeData) -> TDim { match self { Add(mut terms) => { #[allow(clippy::mutable_key_type)] @@ -493,8 +492,7 @@ impl TDim { continue; } let diff = a.clone() - b; - if diff.as_i64().is_some_and(|i| i >= 0) - || scope.is_some_and(|scope| scope.prove_positive_or_zero(&diff)) + if diff.as_i64().is_some_and(|i| i >= 0) || scope.prove_positive_or_zero(&diff) { redundant.insert(a.clone()); } @@ -524,8 +522,7 @@ impl TDim { continue; } let diff = a.clone() - b; - if diff.as_i64().is_some_and(|i| i >= 0) - || scope.is_some_and(|scope| scope.prove_positive_or_zero(&diff)) + if diff.as_i64().is_some_and(|i| i >= 0) || scope.prove_positive_or_zero(&diff) { redundant.insert(b.clone()); } From 13aff17935abe5ef4575d343139c57375e409a0b Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Mon, 16 Sep 2024 09:55:30 +0200 Subject: [PATCH 10/20] adjust api --- core/src/ops/logic/comparison.rs | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/core/src/ops/logic/comparison.rs b/core/src/ops/logic/comparison.rs index 11d409b8ec..ccf87582b1 100644 --- a/core/src/ops/logic/comparison.rs +++ b/core/src/ops/logic/comparison.rs @@ -73,12 +73,6 @@ impl EvalOp for Comp { if let (Ok(a), Ok(b)) = (a.cast_to::(), b.cast_to::()) { return Ok(tvec!(self.eval::(&a, &b)?.into_tvalue())); } - let scope = a - .as_slice::()? - .iter() - .chain(b.as_slice::().unwrap().iter()) - .find_map(|d| d.find_scope()) - .unwrap(); let a = inputs[0].to_array_view::()?; let b = inputs[0].to_array_view::()?; let shape = multi_broadcast(&[a.shape(), b.shape()])?; @@ -92,36 +86,36 @@ impl EvalOp for Comp { Eq => a == b, NE => a != b, GTE => { - if scope.prove_positive_or_zero(&(a.clone() - b)) { + if (a.clone() - b).prove_positive_or_zero() { true - } else if scope.prove_positive_or_zero(&(b.clone() - a - 1)) { + } else if (b.clone() - a - 1).prove_positive_or_zero() { false } else { bail!(UndeterminedSymbol(a.clone() - b)); } } GT => { - if scope.prove_positive_or_zero(&(a.clone() - b - 1)) { + if (a.clone() - b - 1).prove_positive_or_zero() { true - } else if scope.prove_positive_or_zero(&(b.clone() - a)) { + } else if (b.clone() - a).prove_positive_or_zero() { false } else { bail!(UndeterminedSymbol(a.clone() - b)); } } LTE => { - if scope.prove_positive_or_zero(&(b.clone() - a)) { + if (b.clone() - a).prove_positive_or_zero() { true - } else if scope.prove_positive_or_zero(&(a.clone() - b - 1)) { + } else if (a.clone() - b - 1).prove_positive_or_zero() { false } else { bail!(UndeterminedSymbol(a.clone() - b)); } } LT => { - if scope.prove_positive_or_zero(&(b.clone() - a - 1)) { + if (b.clone() - a - 1).prove_positive_or_zero() { true - } else if scope.prove_positive_or_zero(&(a.clone() - b)) { + } else if (a.clone() - b).prove_positive_or_zero() { false } else { bail!(UndeterminedSymbol(a.clone() - b)); From 08ebe839aed7af5173352d79b308bff971d27bcb Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Mon, 16 Sep 2024 10:05:19 +0200 Subject: [PATCH 11/20] derive other sign proovers in data --- core/src/ops/logic/comparison.rs | 25 +++++++++++++------------ data/src/dim/tree.rs | 22 ++++++++++++++++++++++ 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/core/src/ops/logic/comparison.rs b/core/src/ops/logic/comparison.rs index ccf87582b1..b2dae05760 100644 --- a/core/src/ops/logic/comparison.rs +++ b/core/src/ops/logic/comparison.rs @@ -82,43 +82,44 @@ impl EvalOp for Comp { let b = b.broadcast(&*shape).unwrap(); for ixs in tract_ndarray::indices(&*shape) { let (a, b) = (&a[&ixs], &b[&ixs]); + let diff = a.clone() - b; view[&ixs] = match *self { Eq => a == b, NE => a != b, GTE => { - if (a.clone() - b).prove_positive_or_zero() { + if diff.prove_positive_or_zero() { true - } else if (b.clone() - a - 1).prove_positive_or_zero() { + } else if diff.prove_strict_negative() { false } else { - bail!(UndeterminedSymbol(a.clone() - b)); + bail!(UndeterminedSymbol(diff)); } } GT => { - if (a.clone() - b - 1).prove_positive_or_zero() { + if diff.prove_strict_positive() { true - } else if (b.clone() - a).prove_positive_or_zero() { + } else if diff.prove_negative_or_zero() { false } else { - bail!(UndeterminedSymbol(a.clone() - b)); + bail!(UndeterminedSymbol(diff)); } } LTE => { - if (b.clone() - a).prove_positive_or_zero() { + if diff.prove_negative_or_zero() { true - } else if (a.clone() - b - 1).prove_positive_or_zero() { + } else if diff.prove_strict_positive() { false } else { - bail!(UndeterminedSymbol(a.clone() - b)); + bail!(UndeterminedSymbol(diff)); } } LT => { - if (b.clone() - a - 1).prove_positive_or_zero() { + if diff.prove_strict_negative() { true - } else if (a.clone() - b).prove_positive_or_zero() { + } else if diff.prove_negative_or_zero() { false } else { - bail!(UndeterminedSymbol(a.clone() - b)); + bail!(UndeterminedSymbol(diff)); } } }; diff --git a/data/src/dim/tree.rs b/data/src/dim/tree.rs index dbbde7f01f..eb193671b4 100644 --- a/data/src/dim/tree.rs +++ b/data/src/dim/tree.rs @@ -7,6 +7,7 @@ use num_traits::{AsPrimitive, PrimInt, Zero}; use std::cmp::Ordering; use std::collections::{HashMap, HashSet}; use std::fmt::Debug; +use std::ops::Neg; use std::{fmt, ops}; #[derive(Debug)] @@ -648,6 +649,27 @@ impl TDim { data.prove_positive_or_zero(&self) } + pub fn prove_strict_positive(&self) -> bool { + if let TDim::Val(v) = self { + return *v > 0; + } + (self.clone() - 1).prove_positive_or_zero() + } + + pub fn prove_negative_or_zero(&self) -> bool { + if let TDim::Val(v) = self { + return *v <= 0; + } + self.clone().neg().prove_positive_or_zero() + } + + pub fn prove_strict_negative(&self) -> bool { + if let TDim::Val(v) = self { + return *v < 0; + } + self.clone().neg().prove_strict_positive() + } + pub fn gcd(&self) -> u64 { use self::TDim::*; match self { From e98f504425ab2e79fd2f067a75c3b4004e43db94 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Mon, 16 Sep 2024 13:57:45 +0200 Subject: [PATCH 12/20] create fast paths for trivial tdims --- data/src/dim/mod.rs | 8 +++++- data/src/dim/tree.rs | 66 ++++++++++++++++++++++++++++++++++++++------ 2 files changed, 65 insertions(+), 9 deletions(-) diff --git a/data/src/dim/mod.rs b/data/src/dim/mod.rs index 7ef9e4e03e..34a89ac692 100644 --- a/data/src/dim/mod.rs +++ b/data/src/dim/mod.rs @@ -163,7 +163,13 @@ impl DimLike for TDim { } fn broadcast(self, other: Self) -> TractResult { - Ok(TDim::Broadcast(vec![self, other]).simplify()) + if self.is_one() { + Ok(other) + } else if other.is_one() { + Ok(self) + } else { + Ok(TDim::Broadcast(vec![self, other]).simplify()) + } } fn compatible_with(&self, other: &Self) -> bool { diff --git a/data/src/dim/tree.rs b/data/src/dim/tree.rs index eb193671b4..26af9f09c0 100644 --- a/data/src/dim/tree.rs +++ b/data/src/dim/tree.rs @@ -91,7 +91,7 @@ impl fmt::Display for TDim { impl TDim { #[inline] pub fn is_one(&self) -> bool { - self == &Val(1) + matches!(self, Val(1)) } #[inline] @@ -881,13 +881,25 @@ impl<'a> From<&'a Symbol> for TDim { impl ops::Neg for TDim { type Output = Self; fn neg(self) -> Self { - TDim::MulInt(-1, Box::new(self)).reduce() + if let Val(v) = self { + Val(-v) + } else { + TDim::MulInt(-1, Box::new(self)).reduce() + } } } impl<'a> ops::AddAssign<&'a TDim> for TDim { fn add_assign(&mut self, rhs: &'a TDim) { - *self = TDim::Add(vec![std::mem::take(self), rhs.clone()]).reduce() + if rhs.is_zero() { + () + } else if self.is_zero() { + *self = rhs.clone(); + } else if let (Val(s), Val(o)) = (&mut *self, &rhs) { + *s += o; + } else { + *self = TDim::Add(vec![std::mem::take(self), rhs.clone()]).reduce() + } } } @@ -897,7 +909,15 @@ where { fn add_assign(&mut self, rhs: I) { let rhs = rhs.into(); - *self += &rhs + if rhs.is_zero() { + () + } else if self.is_zero() { + *self = rhs; + } else if let (Val(s), Val(o)) = (&mut *self, &rhs) { + *s += o; + } else { + *self = TDim::Add(vec![std::mem::take(self), rhs]).reduce() + } } } @@ -924,7 +944,15 @@ impl<'a> ops::Add<&'a TDim> for TDim { impl<'a> ops::SubAssign<&'a TDim> for TDim { fn sub_assign(&mut self, rhs: &'a TDim) { use std::ops::Neg; - *self += rhs.clone().neg() + if rhs.is_zero() { + () + } else if self.is_zero() { + *self = rhs.clone().neg(); + } else if let (Val(s), Val(o)) = (&mut *self, &rhs) { + *s -= o; + } else { + *self = TDim::Add(vec![std::mem::take(self), rhs.clone().neg()]).reduce() + } } } @@ -933,7 +961,16 @@ where I: Into, { fn sub_assign(&mut self, rhs: I) { - *self -= &rhs.into() + let rhs = rhs.into(); + if rhs.is_zero() { + () + } else if self.is_zero() { + *self = rhs.neg(); + } else if let (Val(s), Val(o)) = (&mut *self, &rhs) { + *s -= o; + } else { + *self = TDim::Add(vec![std::mem::take(self), rhs.neg()]).reduce() + } } } @@ -958,13 +995,26 @@ impl<'a> ops::Sub<&'a TDim> for TDim { impl> ops::MulAssign for TDim { fn mul_assign(&mut self, rhs: I) { - *self = TDim::Mul(vec![rhs.into(), std::mem::take(self)]).reduce() + let rhs = rhs.into(); + if self.is_one() { + *self = rhs + } else if rhs.is_one() { + () + } else { + *self = TDim::Mul(vec![rhs, std::mem::take(self)]).reduce() + } } } impl<'a> ops::MulAssign<&'a TDim> for TDim { fn mul_assign(&mut self, rhs: &'a TDim) { - *self = TDim::Mul(vec![std::mem::take(self), rhs.clone()]).reduce() + if self.is_one() { + *self = rhs.clone() + } else if rhs.is_one() { + () + } else { + *self = TDim::Mul(vec![std::mem::take(self), rhs.clone()]).reduce() + } } } From 001d7b04ee354bc026a3973f80bef2e2f5c2c409 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Mon, 16 Sep 2024 14:11:41 +0200 Subject: [PATCH 13/20] clips --- data/src/dim/tree.rs | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/data/src/dim/tree.rs b/data/src/dim/tree.rs index 26af9f09c0..cbafa60698 100644 --- a/data/src/dim/tree.rs +++ b/data/src/dim/tree.rs @@ -623,7 +623,7 @@ impl TDim { if let TDim::Val(v) = self { return Some(*v); } - let Some(scope) = self.find_scope() else { return None }; + let scope = self.find_scope()?; let data = scope.0.lock(); let data = data.borrow(); self.inclusive_bound(&data, false) @@ -633,10 +633,10 @@ impl TDim { if let TDim::Val(v) = self { return Some(*v); } - let Some(scope) = self.find_scope() else { return None }; + let scope = self.find_scope()?; let data = scope.0.lock(); let data = data.borrow(); - self.inclusive_bound(&*data, true) + self.inclusive_bound(&data, true) } pub fn prove_positive_or_zero(&self) -> bool { @@ -646,7 +646,7 @@ impl TDim { let Some(scope) = self.find_scope() else { return false }; let data = scope.0.lock(); let data = data.borrow(); - data.prove_positive_or_zero(&self) + data.prove_positive_or_zero(self) } pub fn prove_strict_positive(&self) -> bool { @@ -792,16 +792,16 @@ pub(super) fn reduce_ratio(mut p: i64, mut q: i64) -> (i64, u64) { impl Zero for TDim { fn zero() -> Self { - Self::from(0) + Val(0) } fn is_zero(&self) -> bool { - *self == Self::zero() + matches!(self, Val(0)) } } impl Default for TDim { fn default() -> TDim { - TDim::zero() + Val(0) } } @@ -892,7 +892,6 @@ impl ops::Neg for TDim { impl<'a> ops::AddAssign<&'a TDim> for TDim { fn add_assign(&mut self, rhs: &'a TDim) { if rhs.is_zero() { - () } else if self.is_zero() { *self = rhs.clone(); } else if let (Val(s), Val(o)) = (&mut *self, &rhs) { @@ -910,7 +909,6 @@ where fn add_assign(&mut self, rhs: I) { let rhs = rhs.into(); if rhs.is_zero() { - () } else if self.is_zero() { *self = rhs; } else if let (Val(s), Val(o)) = (&mut *self, &rhs) { @@ -945,7 +943,6 @@ impl<'a> ops::SubAssign<&'a TDim> for TDim { fn sub_assign(&mut self, rhs: &'a TDim) { use std::ops::Neg; if rhs.is_zero() { - () } else if self.is_zero() { *self = rhs.clone().neg(); } else if let (Val(s), Val(o)) = (&mut *self, &rhs) { @@ -963,7 +960,6 @@ where fn sub_assign(&mut self, rhs: I) { let rhs = rhs.into(); if rhs.is_zero() { - () } else if self.is_zero() { *self = rhs.neg(); } else if let (Val(s), Val(o)) = (&mut *self, &rhs) { @@ -999,7 +995,6 @@ impl> ops::MulAssign for TDim { if self.is_one() { *self = rhs } else if rhs.is_one() { - () } else { *self = TDim::Mul(vec![rhs, std::mem::take(self)]).reduce() } @@ -1011,7 +1006,6 @@ impl<'a> ops::MulAssign<&'a TDim> for TDim { if self.is_one() { *self = rhs.clone() } else if rhs.is_one() { - () } else { *self = TDim::Mul(vec![std::mem::take(self), rhs.clone()]).reduce() } From 60f92dd44cd1fdc985060bb6019787081f257b6f Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Mon, 16 Sep 2024 14:25:09 +0200 Subject: [PATCH 14/20] redundant import --- data/src/dim/tree.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/data/src/dim/tree.rs b/data/src/dim/tree.rs index cbafa60698..f0c25c0031 100644 --- a/data/src/dim/tree.rs +++ b/data/src/dim/tree.rs @@ -941,7 +941,6 @@ impl<'a> ops::Add<&'a TDim> for TDim { #[allow(clippy::suspicious_op_assign_impl)] impl<'a> ops::SubAssign<&'a TDim> for TDim { fn sub_assign(&mut self, rhs: &'a TDim) { - use std::ops::Neg; if rhs.is_zero() { } else if self.is_zero() { *self = rhs.clone().neg(); From 0b0733e3f16d334aa60ee70fc0edbcd2923bd581 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Mon, 16 Sep 2024 15:20:15 +0200 Subject: [PATCH 15/20] split assertions --- data/src/dim/assertion.rs | 57 +++++++++++++++++++++++++++++++++++++++ data/src/dim/mod.rs | 2 ++ data/src/dim/parse.rs | 1 - data/src/dim/sym.rs | 51 +---------------------------------- data/src/dim/tree.rs | 2 +- 5 files changed, 61 insertions(+), 52 deletions(-) create mode 100644 data/src/dim/assertion.rs diff --git a/data/src/dim/assertion.rs b/data/src/dim/assertion.rs new file mode 100644 index 0000000000..28984da134 --- /dev/null +++ b/data/src/dim/assertion.rs @@ -0,0 +1,57 @@ +use fmt::Display; + +use super::*; + +#[derive(Debug, PartialEq, Clone, Hash)] +#[allow(clippy::upper_case_acronyms)] +pub enum Assertion { + LT(TDim, TDim), + GT(TDim, TDim), + LTE(TDim, TDim), + GTE(TDim, TDim), +} + +impl Display for Assertion { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use Assertion::*; + match self { + LT(l, r) => write!(f, "{l} < {r}"), + GT(l, r) => write!(f, "{l} > {r}"), + LTE(l, r) => write!(f, "{l} <= {r}"), + GTE(l, r) => write!(f, "{l} >= {r}"), + } + } +} + +impl Assertion { + pub fn as_known_positive(&self) -> Option { + use Assertion::*; + match self { + GTE(left, right) => Some(left.clone() - right), + GT(left, right) => Some(left.clone() - 1 - right), + LTE(left, right) => Some(right.clone() - left), + LT(left, right) => Some(right.clone() - 1 - left), + } + } +} + + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn prove_positive_with_axiom() { + let s = SymbolScope::default(); + s.add_assertion("s>=0").unwrap(); + assert!(s.parse_tdim("s").unwrap().prove_positive_or_zero()); + } + + #[test] + fn prove_positive_with_axiom_2() { + let s = SymbolScope::default(); + s.add_assertion("s>=0").unwrap(); + s.add_assertion("p>=0").unwrap(); + s.add_assertion("p+s<4096").unwrap(); + assert!(s.parse_tdim("4096-p").unwrap().prove_positive_or_zero()); + } +} diff --git a/data/src/dim/mod.rs b/data/src/dim/mod.rs index 34a89ac692..db61e89c49 100644 --- a/data/src/dim/mod.rs +++ b/data/src/dim/mod.rs @@ -4,11 +4,13 @@ use num_traits::Zero; use std::fmt; use std::ops; +mod assertion; mod parse; mod resolve; mod sym; mod tree; +pub use self::assertion::Assertion; pub use self::parse::parse_tdim; pub use self::resolve::solve_for; pub use self::sym::{Symbol, SymbolScope, SymbolValues}; diff --git a/data/src/dim/parse.rs b/data/src/dim/parse.rs index 1a3e1b14d6..0900876047 100644 --- a/data/src/dim/parse.rs +++ b/data/src/dim/parse.rs @@ -6,7 +6,6 @@ use nom::combinator::{all_consuming, map, map_res, recognize}; use nom::multi::{many0, separated_list0}; use nom::sequence::{delimited, pair, preceded, separated_pair}; use nom::IResult; -use sym::Assertion; pub fn parse_tdim(symbol_table: &SymbolScope, input: &str) -> TractResult { match all_consuming(|i| expr(symbol_table, i))(input) { diff --git a/data/src/dim/sym.rs b/data/src/dim/sym.rs index 621e4b7f9c..24c9f013e4 100644 --- a/data/src/dim/sym.rs +++ b/data/src/dim/sym.rs @@ -10,7 +10,7 @@ use string_interner::Symbol as _; use crate::TractResult; use super::parse::parse_assertion; -use super::{parse_tdim, TDim}; +use super::{parse_tdim, Assertion, TDim}; #[derive(Clone, Default)] pub struct SymbolScope(pub Arc>>); @@ -188,39 +188,6 @@ impl fmt::Debug for SymbolScope { } } -#[derive(Debug, PartialEq, Clone, Hash)] -#[allow(clippy::upper_case_acronyms)] -pub enum Assertion { - LT(TDim, TDim), - GT(TDim, TDim), - LTE(TDim, TDim), - GTE(TDim, TDim), -} - -impl Display for Assertion { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - use Assertion::*; - match self { - LT(l, r) => write!(f, "{l} < {r}"), - GT(l, r) => write!(f, "{l} > {r}"), - LTE(l, r) => write!(f, "{l} <= {r}"), - GTE(l, r) => write!(f, "{l} >= {r}"), - } - } -} - -impl Assertion { - pub fn as_known_positive(&self) -> Option { - use Assertion::*; - match self { - GTE(left, right) => Some(left.clone() - right), - GT(left, right) => Some(left.clone() - 1 - right), - LTE(left, right) => Some(right.clone() - left), - LT(left, right) => Some(right.clone() - 1 - left), - } - } -} - #[derive(Clone)] pub struct Symbol(Weak>>, string_interner::DefaultSymbol); @@ -358,20 +325,4 @@ mod tests { let s = SymbolScope::default(); assert!(!s.parse_tdim("s+1").unwrap().prove_positive_or_zero()); } - - #[test] - fn prove_positive_with_axiom() { - let s = SymbolScope::default(); - s.add_assertion("s>=0").unwrap(); - assert!(s.parse_tdim("s").unwrap().prove_positive_or_zero()); - } - - #[test] - fn prove_positive_with_axiom_2() { - let s = SymbolScope::default(); - s.add_assertion("s>=0").unwrap(); - s.add_assertion("p>=0").unwrap(); - s.add_assertion("p+s<4096").unwrap(); - assert!(s.parse_tdim("4096-p").unwrap().prove_positive_or_zero()); - } } diff --git a/data/src/dim/tree.rs b/data/src/dim/tree.rs index f0c25c0031..38b0055ab5 100644 --- a/data/src/dim/tree.rs +++ b/data/src/dim/tree.rs @@ -1,3 +1,4 @@ +use crate::dim::Assertion; use crate::internal::*; use super::{sym::*, DimLike}; @@ -1453,7 +1454,6 @@ mod tests { } #[test] - #[ignore] fn min_llm_0() -> TractResult<()> { let symbols = SymbolScope::default() .with_assertion("S>=0")? From 4d2d29db2240a9422d454e65624e2c67a53075f1 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Mon, 16 Sep 2024 15:25:01 +0200 Subject: [PATCH 16/20] parse equal in assertions --- data/src/dim/assertion.rs | 105 ++++++++++++++++++++++++++++++++++++++ data/src/dim/parse.rs | 7 ++- data/src/dim/tree.rs | 100 ------------------------------------ 3 files changed, 110 insertions(+), 102 deletions(-) diff --git a/data/src/dim/assertion.rs b/data/src/dim/assertion.rs index 28984da134..4bf21c1b33 100644 --- a/data/src/dim/assertion.rs +++ b/data/src/dim/assertion.rs @@ -5,6 +5,7 @@ use super::*; #[derive(Debug, PartialEq, Clone, Hash)] #[allow(clippy::upper_case_acronyms)] pub enum Assertion { + Eq(TDim, TDim), LT(TDim, TDim), GT(TDim, TDim), LTE(TDim, TDim), @@ -15,6 +16,7 @@ impl Display for Assertion { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use Assertion::*; match self { + Eq(l, r) => write!(f, "{l} == {r}"), LT(l, r) => write!(f, "{l} < {r}"), GT(l, r) => write!(f, "{l} > {r}"), LTE(l, r) => write!(f, "{l} <= {r}"), @@ -27,6 +29,7 @@ impl Assertion { pub fn as_known_positive(&self) -> Option { use Assertion::*; match self { + Eq(left, right) => Some(left.clone() - right), GTE(left, right) => Some(left.clone() - right), GT(left, right) => Some(left.clone() - 1 - right), LTE(left, right) => Some(right.clone() - left), @@ -54,4 +57,106 @@ mod tests { s.add_assertion("p+s<4096").unwrap(); assert!(s.parse_tdim("4096-p").unwrap().prove_positive_or_zero()); } + + #[test] + fn min_max_with_axiom() { + let symbols = SymbolScope::default(); + symbols.add_assertion("a>=0").unwrap(); + assert_eq!(symbols.parse_tdim("min(a,0)").unwrap().simplify(), 0.into()); + assert_eq!( + symbols.parse_tdim("max(a,0)").unwrap().simplify(), + symbols.parse_tdim("a").unwrap() + ); + } + + #[test] + fn low_bound_0() -> TractResult<()> { + let symbols = SymbolScope::default().with_assertion("S>=0")?; + let s = symbols.parse_tdim("S").unwrap(); + assert_eq!(s.low_inclusive_bound(), Some(0)); + Ok(()) + } + + #[test] + fn low_bound_1() -> TractResult<()> { + let symbols = SymbolScope::default().with_assertion("S>0")?; + assert_eq!(symbols.parse_tdim("S").unwrap().low_inclusive_bound(), Some(1)); + Ok(()) + } + + #[test] + fn low_bound_2() -> TractResult<()> { + let symbols = SymbolScope::default().with_assertion("S>0")?; + assert_eq!(symbols.parse_tdim("S + 1").unwrap().low_inclusive_bound(), Some(2)); + Ok(()) + } + + #[test] + fn low_bound_3() -> TractResult<()> { + let symbols = SymbolScope::default().with_assertion("S>0")?; + assert_eq!(symbols.parse_tdim("4*S").unwrap().low_inclusive_bound(), Some(4)); + Ok(()) + } + + #[test] + fn low_bound_4() -> TractResult<()> { + let symbols = SymbolScope::default().with_assertion("S>0")?.with_assertion("S>5")?; + assert_eq!(symbols.parse_tdim("S + 3").unwrap().low_inclusive_bound(), Some(9)); + Ok(()) + } + + #[test] + fn max_bug_1() { + let symbols = SymbolScope::default(); + symbols.add_assertion("S>8").unwrap(); + assert_eq!( + symbols.parse_tdim("max(1,-1+(S+1)/4)").unwrap().simplify(), + symbols.parse_tdim("-1+(S+1)/4").unwrap(), + ); + } + + #[test] + fn min_bug_1() { + let symbols = SymbolScope::default(); + symbols.add_assertion("S>8").unwrap(); + assert_eq!( + symbols.parse_tdim("min(1,-1+(S+1)/4)").unwrap().simplify(), + symbols.parse_tdim("1").unwrap() + ); + } + + #[test] + fn min_bug_2() { + let symbols = SymbolScope::default(); + symbols.add_assertion("S>50").unwrap(); + assert_eq!( + symbols.parse_tdim("min(-3+2*(S+1)/4,-1+(S+1)/4)").unwrap().simplify(), + symbols.parse_tdim("-1+(S+1)/4").unwrap() + ); + } + + #[test] + fn min_bug_3() { + let symbols = SymbolScope::default(); + symbols.add_assertion("S>=0").unwrap(); + symbols.add_assertion("P>=0").unwrap(); + assert_eq!( + symbols.parse_tdim("min(0,(S)#(P+S))").unwrap().simplify(), + symbols.parse_tdim("0").unwrap() + ); + } + + #[test] + fn min_llm_0() -> TractResult<()> { + let symbols = SymbolScope::default() + .with_assertion("S>=0")? + .with_assertion("P>=0")? + .with_scenario_assertion("tg", "S==1")? + .with_scenario_assertion("pp", "P==0")?; + assert_eq!( + symbols.parse_tdim("min(P,(S)#(P+S))").unwrap().simplify(), + symbols.parse_tdim("P").unwrap() + ); + Ok(()) + } } diff --git a/data/src/dim/parse.rs b/data/src/dim/parse.rs index 0900876047..f23f7c1820 100644 --- a/data/src/dim/parse.rs +++ b/data/src/dim/parse.rs @@ -15,14 +15,17 @@ pub fn parse_tdim(symbol_table: &SymbolScope, input: &str) -> TractResult } pub fn parse_assertion(symbol_table: &SymbolScope, input: &str) -> TractResult { - match all_consuming(|i| inequality(symbol_table, i))(input) { + match all_consuming(|i| assertion(symbol_table, i))(input) { Ok(pair) => Ok(pair.1), Err(e) => bail!("Failed to parse {:?}, {:?}", input, e), } } -fn inequality<'i>(s: &SymbolScope, i: &'i str) -> IResult<&'i str, Assertion> { +fn assertion<'i>(s: &SymbolScope, i: &'i str) -> IResult<&'i str, Assertion> { alt(( + map(separated_pair(|i| expr(s, i), stag("=="), |i| expr(s, i)), |(a, b)| { + Assertion::Eq(a, b) + }), map(separated_pair(|i| expr(s, i), stag("<="), |i| expr(s, i)), |(a, b)| { Assertion::LTE(a, b) }), diff --git a/data/src/dim/tree.rs b/data/src/dim/tree.rs index 38b0055ab5..81122b8d72 100644 --- a/data/src/dim/tree.rs +++ b/data/src/dim/tree.rs @@ -1355,52 +1355,6 @@ mod tests { assert_eq!((A.to_dim() + B.to_dim()).guess_slope(&B), (1, 1)); } - #[test] - fn min_max_with_axiom() { - let symbols = SymbolScope::default(); - symbols.add_assertion("a>=0").unwrap(); - assert_eq!(symbols.parse_tdim("min(a,0)").unwrap().simplify(), 0.into()); - assert_eq!( - symbols.parse_tdim("max(a,0)").unwrap().simplify(), - symbols.parse_tdim("a").unwrap() - ); - } - - #[test] - fn low_bound_0() -> TractResult<()> { - let symbols = SymbolScope::default().with_assertion("S>=0")?; - let s = symbols.parse_tdim("S").unwrap(); - assert_eq!(s.low_inclusive_bound(), Some(0)); - Ok(()) - } - - #[test] - fn low_bound_1() -> TractResult<()> { - let symbols = SymbolScope::default().with_assertion("S>0")?; - assert_eq!(symbols.parse_tdim("S").unwrap().low_inclusive_bound(), Some(1)); - Ok(()) - } - - #[test] - fn low_bound_2() -> TractResult<()> { - let symbols = SymbolScope::default().with_assertion("S>0")?; - assert_eq!(symbols.parse_tdim("S + 1").unwrap().low_inclusive_bound(), Some(2)); - Ok(()) - } - - #[test] - fn low_bound_3() -> TractResult<()> { - let symbols = SymbolScope::default().with_assertion("S>0")?; - assert_eq!(symbols.parse_tdim("4*S").unwrap().low_inclusive_bound(), Some(4)); - Ok(()) - } - - #[test] - fn low_bound_4() -> TractResult<()> { - let symbols = SymbolScope::default().with_assertion("S>0")?.with_assertion("S>5")?; - assert_eq!(symbols.parse_tdim("S + 3").unwrap().low_inclusive_bound(), Some(9)); - Ok(()) - } #[test] fn min_0() -> TractResult<()> { @@ -1412,58 +1366,4 @@ mod tests { Ok(()) } - #[test] - fn max_bug_1() { - let symbols = SymbolScope::default(); - symbols.add_assertion("S>8").unwrap(); - assert_eq!( - symbols.parse_tdim("max(1,-1+(S+1)/4)").unwrap().simplify(), - symbols.parse_tdim("-1+(S+1)/4").unwrap(), - ); - } - - #[test] - fn min_bug_1() { - let symbols = SymbolScope::default(); - symbols.add_assertion("S>8").unwrap(); - assert_eq!( - symbols.parse_tdim("min(1,-1+(S+1)/4)").unwrap().simplify(), - symbols.parse_tdim("1").unwrap() - ); - } - - #[test] - fn min_bug_2() { - let symbols = SymbolScope::default(); - symbols.add_assertion("S>50").unwrap(); - assert_eq!( - symbols.parse_tdim("min(-3+2*(S+1)/4,-1+(S+1)/4)").unwrap().simplify(), - symbols.parse_tdim("-1+(S+1)/4").unwrap() - ); - } - - #[test] - fn min_bug_3() { - let symbols = SymbolScope::default(); - symbols.add_assertion("S>=0").unwrap(); - symbols.add_assertion("P>=0").unwrap(); - assert_eq!( - symbols.parse_tdim("min(0,(S)#(P+S))").unwrap().simplify(), - symbols.parse_tdim("0").unwrap() - ); - } - - #[test] - fn min_llm_0() -> TractResult<()> { - let symbols = SymbolScope::default() - .with_assertion("S>=0")? - .with_assertion("P>=0")? - .with_scenario_assertion("tg", "S=1")? - .with_scenario_assertion("pp", "P=0")?; - assert_eq!( - symbols.parse_tdim("min(P,(S)#(P+S))").unwrap().simplify(), - symbols.parse_tdim("P").unwrap() - ); - Ok(()) - } } From 238be30917bbe97c7ceac268ae761dad37e7e32f Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Mon, 16 Sep 2024 16:19:17 +0200 Subject: [PATCH 17/20] multiple scenario simplification --- data/src/dim/assertion.rs | 7 ++++ data/src/dim/sym.rs | 12 +++++++ data/src/dim/tree.rs | 72 +++++++++++++++++++++++++-------------- 3 files changed, 66 insertions(+), 25 deletions(-) diff --git a/data/src/dim/assertion.rs b/data/src/dim/assertion.rs index 4bf21c1b33..2f262df4ab 100644 --- a/data/src/dim/assertion.rs +++ b/data/src/dim/assertion.rs @@ -42,6 +42,13 @@ impl Assertion { #[cfg(test)] mod tests { use super::*; + #[test] + fn use_equalities() { + let s = SymbolScope::default(); + s.add_assertion("s==0").unwrap(); + assert!(s.parse_tdim("s").unwrap().simplify().is_zero()); + } + #[test] fn prove_positive_with_axiom() { let s = SymbolScope::default(); diff --git a/data/src/dim/sym.rs b/data/src/dim/sym.rs index 24c9f013e4..63212a2d70 100644 --- a/data/src/dim/sym.rs +++ b/data/src/dim/sym.rs @@ -136,6 +136,18 @@ impl SymbolScopeData { &self.assertions } + pub fn assertions(&self, scenario: Option<&str>) -> impl Iterator { + self.assertions.iter().chain(if let Some(s) = scenario { + self.scenarios[s].iter() + } else { + [].iter() + }) + } + + pub fn scenarios(&self) -> impl Iterator { + self.scenarios.keys().map(|s| s.as_ref()) + } + pub fn resolving(&self, sym: &Symbol, f: impl FnOnce(&str) -> R) -> Option { self.table.resolve(sym.1).map(f) } diff --git a/data/src/dim/tree.rs b/data/src/dim/tree.rs index 81122b8d72..3807406432 100644 --- a/data/src/dim/tree.rs +++ b/data/src/dim/tree.rs @@ -291,24 +291,35 @@ impl TDim { pub fn simplify(self) -> TDim { use self::TDim::*; if let Ok(v) = self.eval_to_i64(&SymbolValues::default()) { - Val(v) - } else if let Some(scope) = self.find_scope() { - let locked = scope.0.lock(); - let borrow = locked.borrow(); - self.simplify_rec(&borrow) - } else { - self + return Val(v); + } + let Some(scope) = self.find_scope() else { + return self; + }; + let scope = scope.0; + let locked = scope.lock(); + let scope = locked.borrow(); + let it = self.simplify_rec(&scope, None); + let mut current: Option = None; + for scenario in scope.scenarios() { + let v = it.clone().simplify_rec(&scope, Some(scenario)); + if current.is_some_and(|c| c != v) { + return it; + } else { + current = Some(v); + } } + current.unwrap_or(it) } - fn simplify_rec(self, scope: &SymbolScopeData) -> TDim { + fn simplify_rec(self, scope: &SymbolScopeData, scenario: Option<&str>) -> TDim { match self { Add(mut terms) => { #[allow(clippy::mutable_key_type)] let mut simplified_terms: HashMap = HashMap::new(); // factorize common sub-expr while let Some(term) = terms.pop() { - let simplified = term.simplify_rec(scope); + let simplified = term.simplify_rec(scope, scenario); match simplified { Val(0) => {} // ignore Add(members) => { @@ -354,7 +365,7 @@ impl TDim { .into_iter() .map(|t| { let gcd = t.gcd(); - (t / gcd).simplify_rec(scope) + (t / gcd).simplify_rec(scope, scenario) }) .collect() } else { @@ -377,18 +388,20 @@ impl TDim { } MulInt(coef, expr) => { match *expr { - MulInt(c2, inner) => return MulInt(coef * c2, inner).simplify_rec(scope), + MulInt(c2, inner) => { + return MulInt(coef * c2, inner).simplify_rec(scope, scenario) + } Val(v) => return Val(coef * v), _ => {} } - let simplified = expr.simplify_rec(scope); + let simplified = expr.simplify_rec(scope, scenario); match (coef, simplified) { (0, _) => Val(0), // Case #1: If coef is 0, return 0 (1, s) => s, // Case #2: If coef is 1, return the simplified expression (_, Add(terms)) => Add(terms .into_iter() - .map(|term| MulInt(coef, Box::new(term)).simplify_rec(scope)) + .map(|term| MulInt(coef, Box::new(term)).simplify_rec(scope, scenario)) .collect()), // Case #3: If expression is an addition, distribute the coef (c, Val(v)) => Val(c * v), // Case #4: If expression is a value, combine coefs (c, MulInt(v, inner)) => MulInt(c * v, inner), // Case #5: If expression is a MulInt, combine coefs @@ -397,11 +410,11 @@ impl TDim { } Div(a, q) => { if q == 1 { - return a.simplify_rec(scope); + return a.simplify_rec(scope, scenario); } else if let Div(a, q2) = *a { - return Div(a, q * q2).simplify_rec(scope); + return Div(a, q * q2).simplify_rec(scope, scenario); } - let a = a.simplify_rec(scope); + let a = a.simplify_rec(scope, scenario); if let Val(a) = a { Val(a / q as i64) } else if let MulInt(-1, a) = a { @@ -418,7 +431,7 @@ impl TDim { -1, b!(Div( b!(Add(terms.into_iter().map(|t| MulInt(-1, b!(t))).collect()) - .simplify_rec(scope)), + .simplify_rec(scope, scenario)), q )), ) @@ -434,7 +447,10 @@ impl TDim { }; if let Some(val) = offset { terms.push(Val(-val * q as i64)); - Add(vec![Val(val), Div(b!(Add(terms).simplify_rec(scope)), q)]) + Add(vec![ + Val(val), + Div(b!(Add(terms).simplify_rec(scope, scenario)), q), + ]) } else { Div(b!(Add(terms)), q) } @@ -451,7 +467,8 @@ impl TDim { } else if gcd == q as i64 { MulInt(p / gcd, a) } else if gcd > 1 { - Div(b!(MulInt(p / gcd, a)), q / gcd as u64).simplify_rec(scope) + Div(b!(MulInt(p / gcd, a)), q / gcd as u64) + .simplify_rec(scope, scenario) } else { Div(b!(MulInt(p, a)), q) } @@ -463,7 +480,7 @@ impl TDim { Broadcast(terms) => { let mut terms: Vec = terms .iter() - .map(|s| s.clone().simplify_rec(scope)) + .map(|s| s.clone().simplify_rec(scope, scenario)) .flat_map(|t| if let Broadcast(t) = t { t } else { vec![t] }) .filter(|t| !t.is_one()) .sorted_by(tdim_lexi_order) @@ -481,7 +498,7 @@ impl TDim { Min(terms) => { let mut flatten: Vec = terms .into_iter() - .map(|t| t.simplify_rec(scope)) + .map(|t| t.simplify_rec(scope, scenario)) .flat_map(|t| if let Min(t) = t { t } else { vec![t] }) .sorted_by(tdim_lexi_order) .dedup() @@ -511,7 +528,7 @@ impl TDim { Max(terms) => { let mut flatten: Vec = terms .into_iter() - .map(|t| t.simplify_rec(scope)) + .map(|t| t.simplify_rec(scope, scenario)) .flat_map(|t| if let Max(t) = t { t } else { vec![t] }) .sorted_by(tdim_lexi_order) .dedup() @@ -538,7 +555,14 @@ impl TDim { Max(flatten) } } - Val(_) | Sym(_) => self, + Sym(s) => scope + .assertions(scenario) + .find_map(|a| match a { + Assertion::Eq(Sym(sym), v) if sym == &s => Some(v.clone()), + _ => None, + }) + .unwrap_or(Sym(s)), + Val(_) => self, } } @@ -1355,7 +1379,6 @@ mod tests { assert_eq!((A.to_dim() + B.to_dim()).guess_slope(&B), (1, 1)); } - #[test] fn min_0() -> TractResult<()> { let symbols = SymbolScope::default(); @@ -1365,5 +1388,4 @@ mod tests { ); Ok(()) } - } From 370f17513eceeffb4d41fff1ecee8a374f05f554 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Mon, 16 Sep 2024 16:56:32 +0200 Subject: [PATCH 18/20] scenario and asserts in cli --- cli/src/main.rs | 19 +++++++++++-------- cli/src/params.rs | 19 +++++++++++++------ 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/cli/src/main.rs b/cli/src/main.rs index 9996a9553c..daf92bb2e9 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -26,10 +26,10 @@ mod cost; mod dump; mod errors {} mod params; +mod plan_options; mod run; #[cfg(feature = "pulse")] mod stream_check; -mod plan_options; mod tensor; mod utils; @@ -93,7 +93,8 @@ fn main() -> TractResult<()> { .arg(Arg::new("constantize").long("constantize").multiple_occurrences(true).takes_value(true).long_help( "Transorm an input into a Constant")) - .arg(arg!(--"assert").multiple_occurrences(true).takes_value(true).long_help("Adds a TDim pre-condition")) + .arg(arg!(--"assert").multiple_occurrences(true).takes_value(true).long_help("Adds a TDim pre-condition (prefix by optional \"scenario_name:\")")) + .arg(arg!(--"scenario").multiple_occurrences(true).takes_value(true).long_help("Adds a scenario")) // deprecated .arg(arg!(--"input-bundle" [input_bundle] "Path to an input container (.npz). This sets input facts and tensor values.").hide(true)) @@ -287,18 +288,20 @@ fn main() -> TractResult<()> { let res = if matches.is_present("metal-gpu-trace") { #[cfg(any(target_os = "macos", target_os = "ios"))] - { - let gpu_trace_path = std::path::Path::new(matches.value_of("metal-gpu-trace").unwrap()).to_path_buf(); + { + let gpu_trace_path = + std::path::Path::new(matches.value_of("metal-gpu-trace").unwrap()).to_path_buf(); ensure!(gpu_trace_path.is_absolute(), "Metal GPU trace file has to be absolute"); - ensure!(!gpu_trace_path.exists(), format!("Given Metal GPU trace file {:?} already exists.", gpu_trace_path)); + ensure!( + !gpu_trace_path.exists(), + format!("Given Metal GPU trace file {:?} already exists.", gpu_trace_path) + ); log::info!("Capturing Metal GPU trace at : {:?}", gpu_trace_path); std::env::set_var("METAL_CAPTURE_ENABLED", "1"); std::env::set_var("METAL_DEVICE_WRAPPER_TYPE", "1"); let probe_ref = probe.as_ref(); tract_metal::METAL_CONTEXT.with_borrow(move |context| { - context.capture_trace(gpu_trace_path, move |_ctxt| { - handle(matches, probe_ref) - }) + context.capture_trace(gpu_trace_path, move |_ctxt| handle(matches, probe_ref)) }) } #[cfg(not(any(target_os = "macos", target_os = "ios")))] diff --git a/cli/src/params.rs b/cli/src/params.rs index a31a452e5d..32a0f98a61 100644 --- a/cli/src/params.rs +++ b/cli/src/params.rs @@ -5,16 +5,16 @@ use std::io::Cursor; use std::io::Read; use std::path::PathBuf; use std::str::FromStr; +use tract_core::internal::*; +use tract_core::model::TypedModel; use tract_core::ops::konst::Const; #[allow(unused_imports)] +use tract_core::transform::ModelTransform; +use tract_hir::internal::*; +#[allow(unused_imports)] use tract_itertools::Itertools; use tract_libcli::profile::BenchLimits; use tract_nnef::tensors::read_tensor; -#[allow(unused_imports)] -use tract_core::transform::ModelTransform; -use tract_core::internal::*; -use tract_core::model::TypedModel; -use tract_hir::internal::*; #[cfg(feature = "pulse")] use tract_pulse::internal::*; #[cfg(feature = "tf")] @@ -860,8 +860,15 @@ impl Parameters { /// Parses the command-line arguments. pub fn from_clap(matches: &clap::ArgMatches, probe: Option<&Probe>) -> TractResult { let symbols = SymbolScope::default(); + for scenario in matches.values_of("scenario").unwrap_or_default() { + symbols.add_scenario(scenario)?; + } for rule in matches.values_of("assert").unwrap_or_default() { - symbols.add_assertion(rule)?; + if let Some((scenario, assertion)) = rule.split_once(':') { + symbols.add_scenario_assertion(scenario, assertion)?; + } else { + symbols.add_assertion(rule)?; + } } let (filename, onnx_tc) = Self::disco_model(matches)?; let tensors_values = Self::parse_tensors(matches, &filename, onnx_tc, &symbols)?; From 9d06ad1fc9fb535d7041af0302164b7aedec9512 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Mon, 16 Sep 2024 17:25:22 +0200 Subject: [PATCH 19/20] more work on scanarios --- data/src/dim/parse.rs | 4 ++-- data/src/dim/sym.rs | 4 ++++ nnef/src/deser.rs | 6 +++++- nnef/src/ser.rs | 8 +++++++- 4 files changed, 18 insertions(+), 4 deletions(-) diff --git a/data/src/dim/parse.rs b/data/src/dim/parse.rs index f23f7c1820..6426834199 100644 --- a/data/src/dim/parse.rs +++ b/data/src/dim/parse.rs @@ -22,7 +22,7 @@ pub fn parse_assertion(symbol_table: &SymbolScope, input: &str) -> TractResult(s: &SymbolScope, i: &'i str) -> IResult<&'i str, Assertion> { - alt(( + delimited(spaces, alt(( map(separated_pair(|i| expr(s, i), stag("=="), |i| expr(s, i)), |(a, b)| { Assertion::Eq(a, b) }), @@ -38,7 +38,7 @@ fn assertion<'i>(s: &SymbolScope, i: &'i str) -> IResult<&'i str, Assertion> { map(separated_pair(|i| expr(s, i), stag(">"), |i| expr(s, i)), |(a, b)| { Assertion::GT(a, b) }), - ))(i) + )), spaces)(i) } fn expr<'i>(symbol_table: &SymbolScope, i: &'i str) -> IResult<&'i str, TDim> { diff --git a/data/src/dim/sym.rs b/data/src/dim/sym.rs index 63212a2d70..8965fd1986 100644 --- a/data/src/dim/sym.rs +++ b/data/src/dim/sym.rs @@ -148,6 +148,10 @@ impl SymbolScopeData { self.scenarios.keys().map(|s| s.as_ref()) } + pub fn scenario(&self, s: &str) -> impl Iterator { + self.scenarios[s].iter() + } + pub fn resolving(&self, sym: &Symbol, f: impl FnOnce(&str) -> R) -> Option { self.table.resolve(sym.1).map(f) } diff --git a/nnef/src/deser.rs b/nnef/src/deser.rs index f7b9f452fa..22e843c147 100644 --- a/nnef/src/deser.rs +++ b/nnef/src/deser.rs @@ -62,7 +62,11 @@ impl<'mb> ModelBuilder<'mb> { self.symbols.push(symbol); } "tract_assert" => { - self.model.symbols.add_assertion(&ext.1)?; + if let Some((scen, rule)) = ext.1.split_once(':') { + self.model.symbols.add_scenario_assertion(scen, rule)?; + } else { + self.model.symbols.add_assertion(&ext.1)?; + } } "KHR_enable_fragment_definitions" | "KHR_enable_operator_expressions" => (), _ => { diff --git a/nnef/src/ser.rs b/nnef/src/ser.rs index a0935cdada..e13c119da2 100644 --- a/nnef/src/ser.rs +++ b/nnef/src/ser.rs @@ -204,9 +204,15 @@ impl<'a> IntoAst<'a> { for sym in self.model.symbols.all_symbols() { extension.push(("tract_symbol".into(), sym.to_string())); } - for assert in self.model.symbols.all_assertions() { + let locked = self.model.symbols.0.lock(); + for assert in locked.borrow().all_assertions() { extension.push(("tract_assert".into(), assert.to_string())); } + for scenario in locked.borrow().scenarios() { + for assert in locked.borrow().scenario(scenario) { + extension.push(("tract_assert".into(), format!("{scenario}: {assert}"))); + } + } let properties = FragmentDef { decl: FragmentDecl { id: Identifier("tract_core_properties".to_string()), From 736d887e3723a6e84849281933fe73c6648f6a14 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Wed, 18 Sep 2024 14:40:37 +0200 Subject: [PATCH 20/20] display assertions and scenarios in dump --- data/src/dim/sym.rs | 6 ++++++ libcli/src/model.rs | 5 +++++ libcli/src/terminal.rs | 15 +++++++++++++++ 3 files changed, 26 insertions(+) diff --git a/data/src/dim/sym.rs b/data/src/dim/sym.rs index 8965fd1986..1275900649 100644 --- a/data/src/dim/sym.rs +++ b/data/src/dim/sym.rs @@ -86,6 +86,12 @@ impl SymbolScope { locked.assertions.clone() } + pub fn all_scenarios(&self) -> impl IntoIterator)> { + let locked = self.0.lock(); + let locked = locked.borrow(); + locked.scenarios.clone() + } + pub fn add_scenario(&self, scenario: impl Into) -> TractResult<()> { let locked = self.0.lock(); let mut locked = locked.borrow_mut(); diff --git a/libcli/src/model.rs b/libcli/src/model.rs index 96ae110a8f..1f10cfd3f8 100644 --- a/libcli/src/model.rs +++ b/libcli/src/model.rs @@ -109,6 +109,8 @@ pub trait Model: fn properties(&self) -> &HashMap>; + fn symbols(&self) -> &SymbolScope; + fn get_or_intern_symbol(&self, name: &str) -> Symbol; fn rename_node(&mut self, id: usize, name: &str) -> TractResult<()>; @@ -222,6 +224,9 @@ where &self.properties } + fn symbols(&self) -> &SymbolScope { + &self.symbols + } fn rename_node(&mut self, id: usize, name: &str) -> TractResult<()> { self.rename_node(id, name) } diff --git a/libcli/src/terminal.rs b/libcli/src/terminal.rs index 23e44678f6..8521009c91 100644 --- a/libcli/src/terminal.rs +++ b/libcli/src/terminal.rs @@ -26,6 +26,21 @@ pub fn render( for (k, v) in model.properties().iter().sorted_by_key(|(k, _)| k.to_string()) { println!("* {}: {:?}", White.paint(k), v) } + let symbols = model.symbols(); + if !symbols.all_assertions().is_empty() { + println!("{}", White.bold().paint("# Assertions")); + for a in symbols.all_assertions() { + println!(" * {a}"); + } + } + for (ix, scenario) in symbols.all_scenarios().into_iter().enumerate() { + if ix == 0 { + println!("{}", White.bold().paint("# Scenarios")); + } + for a in scenario.1 { + println!(" * {}: {}", scenario.0, a); + } + } Ok(()) }