diff --git a/.travis/cli-tests.sh b/.travis/cli-tests.sh index 7281819110..c209c6f6c9 100755 --- a/.travis/cli-tests.sh +++ b/.travis/cli-tests.sh @@ -105,6 +105,7 @@ $TRACT_RUN $MODELS/en_libri_real/model.onnx \ -O \ run \ --input-from-bundle $MODELS/en_libri_real/io.npz \ + --approx approximate \ --allow-random-input \ --assert-output-bundle $MODELS/en_libri_real/io.npz diff --git a/cli/src/main.rs b/cli/src/main.rs index daf92bb2e9..66a326a0e4 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -405,6 +405,14 @@ fn dump_subcommand<'a>() -> clap::Command<'a> { fn assertions_options(command: clap::Command) -> clap::Command { use clap::*; command + .arg( + Arg::new("approx") + .takes_value(true) + .possible_values(["exact", "close", "approximate", "super"]) + .default_value("close") + .long("approx") + .help("Approximation level used in assertions."), + ) .arg( Arg::new("assert-output") .takes_value(true) diff --git a/cli/src/params.rs b/cli/src/params.rs index 32a0f98a61..280950c066 100644 --- a/cli/src/params.rs +++ b/cli/src/params.rs @@ -1107,6 +1107,7 @@ pub struct Assertions { pub assert_outputs: bool, pub assert_output_facts: Option>, pub assert_op_count: Option>, + pub approximation: Approximation } impl Assertions { @@ -1123,7 +1124,13 @@ impl Assertions { .map(|mut args| Some((args.next()?.to_string(), args.next()?.parse().ok()?))) .collect() }); - - Ok(Assertions { assert_outputs, assert_output_facts, assert_op_count }) + let approximation = match sub.value_of("approx").unwrap() { + "exact" => Approximation::Exact, + "close" => Approximation::Close, + "approximate" => Approximation::Approximate, + "super" => Approximation::SuperApproximate, + _ => panic!() + }; + Ok(Assertions { assert_outputs, assert_output_facts, assert_op_count, approximation }) } } diff --git a/cli/src/utils.rs b/cli/src/utils.rs index 8899c49420..04b4fd82a2 100644 --- a/cli/src/utils.rs +++ b/cli/src/utils.rs @@ -45,7 +45,10 @@ pub fn check_outputs(got: &[Vec], params: &Parameters) -> TractResult<() { exp = exp.cast_to_dt(got.datum_type())?.into_owned().into_tvalue(); } - if let Err(e) = exp.close_enough(&got, true).context(format!("Checking output {ix}")) { + if let Err(e) = exp + .close_enough(&got, params.assertions.approximation) + .context(format!("Checking output {ix}")) + { if error.is_some() { error!("{:?}", e); } else { diff --git a/data/src/tensor.rs b/data/src/tensor.rs index da960c2c78..ea1204ebf8 100644 --- a/data/src/tensor.rs +++ b/data/src/tensor.rs @@ -20,9 +20,10 @@ use std::sync::Arc; pub mod litteral; pub mod view; -#[derive(Copy, Clone, PartialEq, Eq, Debug)] +#[derive(Copy, Clone, Default, PartialEq, Eq, Debug)] pub enum Approximation { Exact, + #[default] Close, Approximate, SuperApproximate,