Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce modes #1528

Merged
merged 20 commits into from
Sep 18, 2024
Merged
19 changes: 11 additions & 8 deletions cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ mod cost;
mod dump;
mod errors {}
mod params;
mod plan_options;
mod run;
#[cfg(feature = "pulse")]
mod stream_check;
mod plan_options;
mod tensor;
mod utils;

Expand Down Expand Up @@ -93,7 +93,8 @@ fn main() -> TractResult<()> {
.arg(Arg::new("constantize").long("constantize").multiple_occurrences(true).takes_value(true).long_help(
"Transorm an input into a Constant"))

.arg(arg!(--"assert").multiple_occurrences(true).takes_value(true).long_help("Adds a TDim pre-condition"))
.arg(arg!(--"assert").multiple_occurrences(true).takes_value(true).long_help("Adds a TDim pre-condition (prefix by optional \"scenario_name:\")"))
.arg(arg!(--"scenario").multiple_occurrences(true).takes_value(true).long_help("Adds a scenario"))

// deprecated
.arg(arg!(--"input-bundle" [input_bundle] "Path to an input container (.npz). This sets input facts and tensor values.").hide(true))
Expand Down Expand Up @@ -287,18 +288,20 @@ fn main() -> TractResult<()> {

let res = if matches.is_present("metal-gpu-trace") {
#[cfg(any(target_os = "macos", target_os = "ios"))]
{
let gpu_trace_path = std::path::Path::new(matches.value_of("metal-gpu-trace").unwrap()).to_path_buf();
{
let gpu_trace_path =
std::path::Path::new(matches.value_of("metal-gpu-trace").unwrap()).to_path_buf();
ensure!(gpu_trace_path.is_absolute(), "Metal GPU trace file has to be absolute");
ensure!(!gpu_trace_path.exists(), format!("Given Metal GPU trace file {:?} already exists.", gpu_trace_path));
ensure!(
!gpu_trace_path.exists(),
format!("Given Metal GPU trace file {:?} already exists.", gpu_trace_path)
);
log::info!("Capturing Metal GPU trace at : {:?}", gpu_trace_path);
std::env::set_var("METAL_CAPTURE_ENABLED", "1");
std::env::set_var("METAL_DEVICE_WRAPPER_TYPE", "1");
let probe_ref = probe.as_ref();
tract_metal::METAL_CONTEXT.with_borrow(move |context| {
context.capture_trace(gpu_trace_path, move |_ctxt| {
handle(matches, probe_ref)
})
context.capture_trace(gpu_trace_path, move |_ctxt| handle(matches, probe_ref))
})
}
#[cfg(not(any(target_os = "macos", target_os = "ios")))]
Expand Down
19 changes: 13 additions & 6 deletions cli/src/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@ use std::io::Cursor;
use std::io::Read;
use std::path::PathBuf;
use std::str::FromStr;
use tract_core::internal::*;
use tract_core::model::TypedModel;
use tract_core::ops::konst::Const;
#[allow(unused_imports)]
use tract_core::transform::ModelTransform;
use tract_hir::internal::*;
#[allow(unused_imports)]
use tract_itertools::Itertools;
use tract_libcli::profile::BenchLimits;
use tract_nnef::tensors::read_tensor;
#[allow(unused_imports)]
use tract_core::transform::ModelTransform;
use tract_core::internal::*;
use tract_core::model::TypedModel;
use tract_hir::internal::*;
#[cfg(feature = "pulse")]
use tract_pulse::internal::*;
#[cfg(feature = "tf")]
Expand Down Expand Up @@ -860,8 +860,15 @@ impl Parameters {
/// Parses the command-line arguments.
pub fn from_clap(matches: &clap::ArgMatches, probe: Option<&Probe>) -> TractResult<Parameters> {
let symbols = SymbolScope::default();
for scenario in matches.values_of("scenario").unwrap_or_default() {
symbols.add_scenario(scenario)?;
}
for rule in matches.values_of("assert").unwrap_or_default() {
symbols.add_inequality(rule)?;
if let Some((scenario, assertion)) = rule.split_once(':') {
symbols.add_scenario_assertion(scenario, assertion)?;
} else {
symbols.add_assertion(rule)?;
}
}
let (filename, onnx_tc) = Self::disco_model(matches)?;
let tensors_values = Self::parse_tensors(matches, &filename, onnx_tc, &symbols)?;
Expand Down
31 changes: 13 additions & 18 deletions core/src/ops/logic/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,6 @@ impl EvalOp for Comp {
if let (Ok(a), Ok(b)) = (a.cast_to::<i64>(), b.cast_to::<i64>()) {
return Ok(tvec!(self.eval::<i64>(&a, &b)?.into_tvalue()));
}
let scope = a
.as_slice::<TDim>()?
.iter()
.chain(b.as_slice::<TDim>().unwrap().iter())
.find_map(|d| d.find_scope())
.unwrap();
let a = inputs[0].to_array_view::<TDim>()?;
let b = inputs[0].to_array_view::<TDim>()?;
let shape = multi_broadcast(&[a.shape(), b.shape()])?;
Expand All @@ -88,43 +82,44 @@ impl EvalOp for Comp {
let b = b.broadcast(&*shape).unwrap();
for ixs in tract_ndarray::indices(&*shape) {
let (a, b) = (&a[&ixs], &b[&ixs]);
let diff = a.clone() - b;
view[&ixs] = match *self {
Eq => a == b,
NE => a != b,
GTE => {
if scope.prove_positive_or_zero(&(a.clone() - b)) {
if diff.prove_positive_or_zero() {
true
} else if scope.prove_positive_or_zero(&(b.clone() - a - 1)) {
} else if diff.prove_strict_negative() {
false
} else {
bail!(UndeterminedSymbol(a.clone() - b));
bail!(UndeterminedSymbol(diff));
}
}
GT => {
if scope.prove_positive_or_zero(&(a.clone() - b - 1)) {
if diff.prove_strict_positive() {
true
} else if scope.prove_positive_or_zero(&(b.clone() - a)) {
} else if diff.prove_negative_or_zero() {
false
} else {
bail!(UndeterminedSymbol(a.clone() - b));
bail!(UndeterminedSymbol(diff));
}
}
LTE => {
if scope.prove_positive_or_zero(&(b.clone() - a)) {
if diff.prove_negative_or_zero() {
true
} else if scope.prove_positive_or_zero(&(a.clone() - b - 1)) {
} else if diff.prove_strict_positive() {
false
} else {
bail!(UndeterminedSymbol(a.clone() - b));
bail!(UndeterminedSymbol(diff));
}
}
LT => {
if scope.prove_positive_or_zero(&(b.clone() - a - 1)) {
if diff.prove_strict_negative() {
true
} else if scope.prove_positive_or_zero(&(a.clone() - b)) {
} else if diff.prove_negative_or_zero() {
false
} else {
bail!(UndeterminedSymbol(a.clone() - b));
bail!(UndeterminedSymbol(diff));
}
}
};
Expand Down
1 change: 1 addition & 0 deletions data/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ smallvec.workspace = true
lazy_static.workspace = true
scan_fmt.workspace = true
string-interner.workspace = true
parking_lot = "0.12.3"

[target.'cfg(not(target_family = "wasm"))'.dev-dependencies]
criterion.workspace = true
Expand Down
169 changes: 169 additions & 0 deletions data/src/dim/assertion.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
use fmt::Display;

use super::*;

#[derive(Debug, PartialEq, Clone, Hash)]
#[allow(clippy::upper_case_acronyms)]
pub enum Assertion {
Eq(TDim, TDim),
LT(TDim, TDim),
GT(TDim, TDim),
LTE(TDim, TDim),
GTE(TDim, TDim),
}

impl Display for Assertion {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use Assertion::*;
match self {
Eq(l, r) => write!(f, "{l} == {r}"),
LT(l, r) => write!(f, "{l} < {r}"),
GT(l, r) => write!(f, "{l} > {r}"),
LTE(l, r) => write!(f, "{l} <= {r}"),
GTE(l, r) => write!(f, "{l} >= {r}"),
}
}
}

impl Assertion {
pub fn as_known_positive(&self) -> Option<TDim> {
use Assertion::*;
match self {
Eq(left, right) => Some(left.clone() - right),
GTE(left, right) => Some(left.clone() - right),
GT(left, right) => Some(left.clone() - 1 - right),
LTE(left, right) => Some(right.clone() - left),
LT(left, right) => Some(right.clone() - 1 - left),
}
}
}


#[cfg(test)]
mod tests {
use super::*;
#[test]
fn use_equalities() {
let s = SymbolScope::default();
s.add_assertion("s==0").unwrap();
assert!(s.parse_tdim("s").unwrap().simplify().is_zero());
}

#[test]
fn prove_positive_with_axiom() {
let s = SymbolScope::default();
s.add_assertion("s>=0").unwrap();
assert!(s.parse_tdim("s").unwrap().prove_positive_or_zero());
}

#[test]
fn prove_positive_with_axiom_2() {
let s = SymbolScope::default();
s.add_assertion("s>=0").unwrap();
s.add_assertion("p>=0").unwrap();
s.add_assertion("p+s<4096").unwrap();
assert!(s.parse_tdim("4096-p").unwrap().prove_positive_or_zero());
}

#[test]
fn min_max_with_axiom() {
let symbols = SymbolScope::default();
symbols.add_assertion("a>=0").unwrap();
assert_eq!(symbols.parse_tdim("min(a,0)").unwrap().simplify(), 0.into());
assert_eq!(
symbols.parse_tdim("max(a,0)").unwrap().simplify(),
symbols.parse_tdim("a").unwrap()
);
}

#[test]
fn low_bound_0() -> TractResult<()> {
let symbols = SymbolScope::default().with_assertion("S>=0")?;
let s = symbols.parse_tdim("S").unwrap();
assert_eq!(s.low_inclusive_bound(), Some(0));
Ok(())
}

#[test]
fn low_bound_1() -> TractResult<()> {
let symbols = SymbolScope::default().with_assertion("S>0")?;
assert_eq!(symbols.parse_tdim("S").unwrap().low_inclusive_bound(), Some(1));
Ok(())
}

#[test]
fn low_bound_2() -> TractResult<()> {
let symbols = SymbolScope::default().with_assertion("S>0")?;
assert_eq!(symbols.parse_tdim("S + 1").unwrap().low_inclusive_bound(), Some(2));
Ok(())
}

#[test]
fn low_bound_3() -> TractResult<()> {
let symbols = SymbolScope::default().with_assertion("S>0")?;
assert_eq!(symbols.parse_tdim("4*S").unwrap().low_inclusive_bound(), Some(4));
Ok(())
}

#[test]
fn low_bound_4() -> TractResult<()> {
let symbols = SymbolScope::default().with_assertion("S>0")?.with_assertion("S>5")?;
assert_eq!(symbols.parse_tdim("S + 3").unwrap().low_inclusive_bound(), Some(9));
Ok(())
}

#[test]
fn max_bug_1() {
let symbols = SymbolScope::default();
symbols.add_assertion("S>8").unwrap();
assert_eq!(
symbols.parse_tdim("max(1,-1+(S+1)/4)").unwrap().simplify(),
symbols.parse_tdim("-1+(S+1)/4").unwrap(),
);
}

#[test]
fn min_bug_1() {
let symbols = SymbolScope::default();
symbols.add_assertion("S>8").unwrap();
assert_eq!(
symbols.parse_tdim("min(1,-1+(S+1)/4)").unwrap().simplify(),
symbols.parse_tdim("1").unwrap()
);
}

#[test]
fn min_bug_2() {
let symbols = SymbolScope::default();
symbols.add_assertion("S>50").unwrap();
assert_eq!(
symbols.parse_tdim("min(-3+2*(S+1)/4,-1+(S+1)/4)").unwrap().simplify(),
symbols.parse_tdim("-1+(S+1)/4").unwrap()
);
}

#[test]
fn min_bug_3() {
let symbols = SymbolScope::default();
symbols.add_assertion("S>=0").unwrap();
symbols.add_assertion("P>=0").unwrap();
assert_eq!(
symbols.parse_tdim("min(0,(S)#(P+S))").unwrap().simplify(),
symbols.parse_tdim("0").unwrap()
);
}

#[test]
fn min_llm_0() -> TractResult<()> {
let symbols = SymbolScope::default()
.with_assertion("S>=0")?
.with_assertion("P>=0")?
.with_scenario_assertion("tg", "S==1")?
.with_scenario_assertion("pp", "P==0")?;
assert_eq!(
symbols.parse_tdim("min(P,(S)#(P+S))").unwrap().simplify(),
symbols.parse_tdim("P").unwrap()
);
Ok(())
}
}
10 changes: 9 additions & 1 deletion data/src/dim/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ use num_traits::Zero;
use std::fmt;
use std::ops;

mod assertion;
mod parse;
mod resolve;
mod sym;
mod tree;

pub use self::assertion::Assertion;
pub use self::parse::parse_tdim;
pub use self::resolve::solve_for;
pub use self::sym::{Symbol, SymbolScope, SymbolValues};
Expand Down Expand Up @@ -163,7 +165,13 @@ impl DimLike for TDim {
}

fn broadcast(self, other: Self) -> TractResult<Self> {
Ok(TDim::Broadcast(vec![self, other]).simplify())
if self.is_one() {
Ok(other)
} else if other.is_one() {
Ok(self)
} else {
Ok(TDim::Broadcast(vec![self, other]).simplify())
}
}

fn compatible_with(&self, other: &Self) -> bool {
Expand Down
Loading
Loading