setup arbitrary test with params
kali committed Jul 12, 2023
1 parent f2b55ee · commit f45b7b4
Showing 6 changed files with 69 additions and 46 deletions.
1 change: 1 addition & 0 deletions test-rt/infra/Cargo.toml
@@ -10,4 +10,5 @@ anyhow.workspace = true
itertools.workspace = true
dyn-clone.workspace = true
env_logger.workspace = true
proptest.workspace = true
tract-core = { path = "../../core", version = "=0.20.7-pre" }
30 changes: 29 additions & 1 deletion test-rt/infra/src/lib.rs
@@ -1,9 +1,12 @@
#![allow(clippy::len_zero)]
use std::collections::HashMap;
use std::io::Write;
use std::marker::PhantomData;

use dyn_clone::DynClone;
use itertools::Itertools;
use proptest::prelude::{any, any_with, Arbitrary};
use proptest::test_runner::{Config, FileFailurePersistence, TestRunner};
use tract_core::runtime::Runtime;
use tract_core::tract_data::TractResult;

@@ -14,7 +17,6 @@ pub fn setup_test_logger() {
pub type TestResult = anyhow::Result<()>;

pub trait Test: 'static + Send + Sync + DynClone {
fn ignore(&self) -> bool;
fn run(&self, runtime: &dyn Runtime) -> TestResult;
}

@@ -49,6 +51,13 @@ impl TestSuite {
}
}

pub fn add_arbitrary<A: Arbitrary + Test + Clone>(&mut self, id: impl ToString, params: A::Parameters)
where
A::Parameters: Clone + Send + Sync,
{
self.add(id, ProptestWrapper::<A>(params));
}

pub fn with(mut self, id: impl ToString, test: impl Into<TestSuite>) -> Self {
self.add(id, test);
self
@@ -139,3 +148,22 @@ impl TestSuite {
self.dump(test_suite, runtime, "", "", &mut rs).unwrap();
}
}

#[derive(Clone, Debug)]
struct ProptestWrapper<A: Arbitrary + Test + Clone>(A::Parameters)
where
A::Parameters: Clone + Send + Sync;

impl<A: Arbitrary + Test + Clone> Test for ProptestWrapper<A>
where
A::Parameters: Clone + Send + Sync,
{
fn run(&self, runtime: &dyn Runtime) -> TestResult {
let mut runner = TestRunner::new(Config {
failure_persistence: Some(Box::new(FileFailurePersistence::Off)),
..Config::default()
});
runner.run(&any_with::<A>(self.0.clone()), |v| Ok(v.run(runtime).unwrap()))?;
Ok(())
}
}
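
To illustrate the new entry point, here is a minimal, hedged sketch of how a suite crate might register a proptest-generated case through `TestSuite::add_arbitrary`. Only `TestSuite`, `Test`, `TestResult`, and `add_arbitrary` come from the diff above; `MyCase`, `MyCaseParams`, the `register` helper, and the crate name `infra` are illustrative assumptions.

```rust
use infra::{Test, TestResult, TestSuite}; // crate name assumed from the test-rt/infra path
use proptest::prelude::*;
use tract_core::runtime::Runtime;

// Hypothetical parameters; `Arbitrary::Parameters` must be `Default`,
// and `add_arbitrary` additionally asks for `Clone + Send + Sync`.
#[derive(Clone, Debug, Default)]
struct MyCaseParams {
    max_len: usize,
}

// Hypothetical proptest-generated test case.
#[derive(Clone, Debug)]
struct MyCase {
    values: Vec<f32>,
}

impl Arbitrary for MyCase {
    type Parameters = MyCaseParams;
    type Strategy = BoxedStrategy<MyCase>;
    fn arbitrary_with(params: Self::Parameters) -> Self::Strategy {
        proptest::collection::vec(any::<f32>(), 1..=params.max_len.max(1))
            .prop_map(|values| MyCase { values })
            .boxed()
    }
}

impl Test for MyCase {
    fn run(&self, _runtime: &dyn Runtime) -> TestResult {
        // A real case would build a tract model and compare the runtime
        // output against a reference here.
        Ok(())
    }
}

fn register(suite: &mut TestSuite) {
    // Internally this wraps MyCase in ProptestWrapper, so each `run` call
    // drives a proptest TestRunner over generated MyCase values.
    suite.add_arbitrary::<MyCase>("my_case_prop", MyCaseParams { max_len: 16 });
}
```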
73 changes: 38 additions & 35 deletions test-rt/suite-conv/src/conv_f32.rs
@@ -144,10 +144,15 @@ impl ConvProblem {
}
}

#[derive(Debug, Clone, Default)]
pub struct ConvProblemParams {
pub no_arbitrary_grouping: bool,
}

impl Arbitrary for ConvProblem {
type Parameters = ();
type Parameters = ConvProblemParams;
type Strategy = BoxedStrategy<ConvProblem>;
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
fn arbitrary_with(params: Self::Parameters) -> Self::Strategy {
(
data_format(),
kernel_format(),
@@ -158,32 +163,38 @@ impl Arbitrary for ConvProblem {
1usize..=3,
(1usize..=3).prop_flat_map(shapes),
)
.prop_flat_map(|(df, kf, pad, n, mut ci0, co0, group, (mut ker_shape, data_shape))| {
// FIXME in HWIO order, only regular and depthwise are supported
if kf == KernelFormat::HWIO && group > 1 {
ci0 = 1;
}
let shape_in = df.from_n_c_hw(n, ci0 * group, data_shape).unwrap();
let data_in = tensor(shape_in.shape.iter().cloned().collect());
match kf {
KernelFormat::HWIO => {
ker_shape.push(ci0 * group);
ker_shape.push(co0)
.prop_flat_map(
move |(df, kf, pad, n, mut ci0, mut co0, group, (mut ker_shape, data_shape))| {
// FIXME in HWIO order, only regular and depthwise are supported
if params.no_arbitrary_grouping && group > 1 {
ci0 = 1;
co0 = 1;
}
KernelFormat::OIHW => {
ker_shape.insert(0, ci0);
ker_shape.insert(0, co0 * group)
if kf == KernelFormat::HWIO && group > 1 {
ci0 = 1;
}
KernelFormat::OHWI => {
ker_shape.insert(0, co0);
ker_shape.push(ci0 * group);
}
};
let strides = vec(1usize..=3, shape_in.hw_rank()..=shape_in.hw_rank());
let kernel = tensor(ker_shape);
let bias = proptest::option::of(tensor(vec![co0 * group]));
(Just((kf, pad, shape_in, group)), data_in, kernel, bias, strides)
})
let shape_in = df.from_n_c_hw(n, ci0 * group, data_shape).unwrap();
let data_in = tensor(shape_in.shape.iter().cloned().collect());
match kf {
KernelFormat::HWIO => {
ker_shape.push(ci0 * group);
ker_shape.push(co0)
}
KernelFormat::OIHW => {
ker_shape.insert(0, ci0);
ker_shape.insert(0, co0 * group)
}
KernelFormat::OHWI => {
ker_shape.insert(0, co0);
ker_shape.push(ci0 * group);
}
};
let strides = vec(1usize..=3, shape_in.hw_rank()..=shape_in.hw_rank());
let kernel = tensor(ker_shape);
let bias = proptest::option::of(tensor(vec![co0 * group]));
(Just((kf, pad, shape_in, group)), data_in, kernel, bias, strides)
},
)
.prop_map(|((kernel_format, pad, shape_in, group), data, kernel, bias, strides)| {
ConvProblem {
shape_in,
@@ -201,10 +212,6 @@ impl Arbitrary for ConvProblem {
}

impl Test for ConvProblem {
fn ignore(&self) -> bool {
false
}

fn run(&self, runtime: &dyn Runtime) -> TestResult {
let reference = self.reference().into_tensor();
let mut output =
@@ -215,13 +222,9 @@ impl Test for ConvProblem {
}

#[derive(Clone)]
struct ConvProptest;
pub struct ConvProptest;

impl Test for ConvProptest {
fn ignore(&self) -> bool {
false
}

fn run(&self, runtime: &dyn Runtime) -> TestResult {
let mut runner = TestRunner::new(Config {
failure_persistence: Some(Box::new(FileFailurePersistence::Off)),
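
As a hedged illustration of how the new `ConvProblemParams` could be consumed alongside `add_arbitrary`, the sketch below registers the generator twice with different parameters. The ids and the `register_conv_proptests` helper are made up; `ConvProblem`, `ConvProblemParams`, and `TestSuite::add_arbitrary` are taken from this commit.

```rust
use infra::TestSuite; // crate name assumed from the test-rt/infra path

// Illustrative helper; assumes ConvProblem and ConvProblemParams (above) are in scope.
fn register_conv_proptests(suite: &mut TestSuite) {
    // Restricted generator: runtimes that cannot handle arbitrary group
    // counts force ci0 = co0 = 1 whenever group > 1.
    suite.add_arbitrary::<ConvProblem>(
        "proptest_f32_no_group",
        ConvProblemParams { no_arbitrary_grouping: true },
    );

    // Default parameters keep the original, fully general generator.
    suite.add_arbitrary::<ConvProblem>("proptest_f32", ConvProblemParams::default());
}
```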
3 changes: 0 additions & 3 deletions test-rt/suite-conv/src/conv_q.rs
@@ -155,9 +155,6 @@ impl QConvProblem {
}

impl Test for QConvProblem {
fn ignore(&self) -> bool {
false
}
fn run(&self, runtime: &dyn Runtime) -> TractResult<()> {
let model = runtime.prepare(self.tract()?)?;
let output = model.run(tvec!(self.data.clone().into_tensor().into_tvalue()))?.remove(0);
6 changes: 0 additions & 6 deletions test-rt/suite-onnx/src/lib.rs
@@ -20,18 +20,13 @@ const MANIFEST_PYTORCH_OPERATOR: &str = include_str!("../pytorch-operator.txt");

#[derive(Clone, Debug)]
struct OnnxTestCase {
skipped: bool,
path: PathBuf,
ignore_output_shapes: bool,
ignore_output_type: bool,
input: Option<String>,
}

impl Test for OnnxTestCase {
fn ignore(&self) -> bool {
self.skipped
}

fn run(&self, runtime: &dyn Runtime) -> TractResult<()> {
setup_test_logger();
let model_file = self.path.join("model.onnx");
@@ -232,7 +227,6 @@ fn full() -> TestSuite {
t,
OnnxTestCase {
path: node_tests.join(t),
skipped,
ignore_output_type,
ignore_output_shapes,
input,
2 changes: 1 addition & 1 deletion test-rt/test-onnx-core/src/lib.rs
@@ -21,7 +21,7 @@ mod unoptimized {
Cow::Borrowed("unoptimized")
}
fn prepare(&self, model: TypedModel) -> TractResult<Box<dyn Runnable>> {
Ok(Box::new(model.into_runnable()?))
Ok(Box::new(Arc::new(model.into_runnable()?)))
}
}
