Skip to content

Commit

Permalink
proptests against tflite too
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Jul 12, 2023
1 parent f45b7b4 commit 3049aa7
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 52 deletions.
10 changes: 6 additions & 4 deletions test-rt/infra/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
#![allow(clippy::len_zero)]
use std::collections::HashMap;
use std::io::Write;
use std::marker::PhantomData;

use dyn_clone::DynClone;
use itertools::Itertools;
use proptest::prelude::{any, any_with, Arbitrary};
use proptest::prelude::{any_with, Arbitrary};
use proptest::test_runner::{Config, FileFailurePersistence, TestRunner};
use tract_core::runtime::Runtime;
use tract_core::tract_data::TractResult;
Expand Down Expand Up @@ -51,8 +50,11 @@ impl TestSuite {
}
}

pub fn add_arbitrary<A: Arbitrary + Test + Clone>(&mut self, id: impl ToString, params: A::Parameters)
where
pub fn add_arbitrary<A: Arbitrary + Test + Clone>(
&mut self,
id: impl ToString,
params: A::Parameters,
) where
A::Parameters: Clone + Send + Sync,
{
self.add(id, ProptestWrapper::<A>(params));
Expand Down
92 changes: 59 additions & 33 deletions test-rt/suite-conv/src/conv_f32.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
use std::ops::Range;

use super::*;
use infra::*;
use proptest::collection::vec;
use proptest::test_runner::{Config, FileFailurePersistence, TestRunner};
use tract_itertools::izip;

#[derive(Debug, Clone, Default)]
pub struct ConvProblemParams {
pub no_group: bool,
pub no_arbitrary_grouping: bool,
pub geo_rank: Option<Range<usize>>,
}

#[derive(Debug, Clone)]
pub struct ConvProblem {
pub shape_in: DataShape,
Expand All @@ -30,21 +38,41 @@ impl ConvProblem {
KernelFormat::HWIO => self.kernel.shape()[self.kernel.ndim() - 1],
KernelFormat::OHWI => self.kernel.shape()[0],
};
let (shape_out, left_pads): (TVec<_>, TVec<_>) = if self.pad == PaddingSpec::Valid {
izip!(self.shape_in.hw_dims(), self.geo_ker(), &self.strides)
let (shape_out, left_pads): (TVec<_>, TVec<_>) = match &self.pad {
PaddingSpec::Valid => izip!(self.shape_in.hw_dims(), self.geo_ker(), &self.strides)
.map(|(i, k, s)| {
let out = (*i + 1).saturating_sub(*k).divceil(*s);
(out, 0)
})
.unzip()
} else {
izip!(self.shape_in.hw_dims(), self.geo_ker(), &self.strides)
.unzip(),
PaddingSpec::SameUpper => izip!(self.shape_in.hw_dims(), self.geo_ker(), &self.strides)
.map(|(input, k, stride)| {
let out = input.divceil(*stride);
let pad = ((out - 1) * stride + k).saturating_sub(*input);
(out, pad / 2)
})
.unzip()
.unzip(),
PaddingSpec::SameLower => izip!(self.shape_in.hw_dims(), self.geo_ker(), &self.strides)
.map(|(input, k, stride)| {
let out = input.divceil(*stride);
let pad = ((out - 1) * stride + k).saturating_sub(*input);
(out, pad.divceil(2))
})
.unzip(),
PaddingSpec::Explicit(l, r, ceil) => {
izip!(self.shape_in.hw_dims(), self.geo_ker(), &self.strides, l, r)
.map(|(input, k, stride, l, r)| {
let dil = 1;
let kf = (k - 1) * dil + 1;
let out = if *ceil {
(input + l + r).saturating_sub(kf).divceil(*stride) + 1
} else {
(input + l + r).saturating_sub(kf) / *stride + 1
};
(out, *l)
})
.unzip()
}
};
let shape_out = self
.shape_in
Expand Down Expand Up @@ -144,24 +172,20 @@ impl ConvProblem {
}
}

#[derive(Debug, Clone, Default)]
pub struct ConvProblemParams {
pub no_arbitrary_grouping: bool,
}

impl Arbitrary for ConvProblem {
type Parameters = ConvProblemParams;
type Strategy = BoxedStrategy<ConvProblem>;
fn arbitrary_with(params: Self::Parameters) -> Self::Strategy {
let geo_rank = params.geo_rank.unwrap_or(1..4);
(
data_format(),
kernel_format(),
prop_oneof![Just(PaddingSpec::Valid), Just(PaddingSpec::SameUpper)],
1usize..=3,
1usize..=4,
1usize..=4,
1usize..=3,
(1usize..=3).prop_flat_map(shapes),
1usize..=(if params.no_group { 1 } else { 3 }),
geo_rank.prop_flat_map(shapes),
)
.prop_flat_map(
move |(df, kf, pad, n, mut ci0, mut co0, group, (mut ker_shape, data_shape))| {
Expand Down Expand Up @@ -217,28 +241,15 @@ impl Test for ConvProblem {
let mut output =
runtime.prepare(self.tract()?)?.run(tvec![self.data.clone().into_tvalue()])?;
let output = output.remove(0).into_tensor();
reference.close_enough(&output, true)
}
}

#[derive(Clone)]
pub struct ConvProptest;

impl Test for ConvProptest {
fn run(&self, runtime: &dyn Runtime) -> TestResult {
let mut runner = TestRunner::new(Config {
failure_persistence: Some(Box::new(FileFailurePersistence::Off)),
..Config::default()
});
runner.run(&any::<ConvProblem>(), |v| Ok(v.run(runtime).unwrap()))?;
Ok(())
eprintln!("output: {output:?} reference: {reference:?}");
output.close_enough(&reference, true)
}
}

pub fn suite() -> TractResult<TestSuite> {
let mut suite = TestSuite::default();

suite.add("proptest", ConvProptest);
suite.add_arbitrary::<ConvProblem>("proptest", ConvProblemParams::default());

suite.add(
"trivial_0",
Expand Down Expand Up @@ -308,6 +319,21 @@ pub fn suite() -> TractResult<TestSuite> {
strides: tvec!(1),
},
);

suite.add(
"group_0",
ConvProblem {
shape_in: DataFormat::CHW.from_n_c_hw(1, 2, [1])?,
kernel_format: KernelFormat::OIHW,
group: 2,
data: arr2(&[[0.0f32], [0.0]]).into_dyn(),
kernel: arr3(&[[[0.0f32]], [[0.0]]]).into_dyn(),
bias: None,
pad: PaddingSpec::Valid,
strides: tvec!(1),
},
);

suite.add(
"group_1",
ConvProblem {
Expand Down Expand Up @@ -627,7 +653,7 @@ pub fn suite() -> TractResult<TestSuite> {
);

suite.add(
"same_0",
"same_1d_0",
ConvProblem {
shape_in: DataFormat::HWC.from_n_c_hw(1, 1, [1])?,
kernel_format: KernelFormat::OIHW,
Expand All @@ -641,13 +667,13 @@ pub fn suite() -> TractResult<TestSuite> {
);

suite.add(
"same_1",
"same_1d_1",
ConvProblem {
shape_in: DataFormat::HWC.from_n_c_hw(1, 1, [2])?,
kernel_format: KernelFormat::OIHW,
group: 1,
data: arr2(&[[0.0], [1.0]]).into_dyn(),
kernel: arr3(&[[[0.0, 1.0]]]).into_dyn(),
kernel: arr3(&[[[0.0, 2.0]]]).into_dyn(),
bias: None,
pad: PaddingSpec::SameUpper,
strides: tvec!(1),
Expand Down
7 changes: 0 additions & 7 deletions test-rt/suite-conv/tests/proptest.proptest-regressions

This file was deleted.

2 changes: 1 addition & 1 deletion test-rt/test-onnx-nnef-cycle/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ impl Runtime for NnefCyclingRuntime {
self.0.write_to_tar(&model, &mut buffer)?;
info!("Reload from NNEF");
let reloaded = self.0.model_for_read(&mut &*buffer)?;
Ok(Box::new(reloaded.into_optimized()?.into_runnable()?))
Ok(Box::new(Arc::new(reloaded.into_optimized()?.into_runnable()?)))
}
}

Expand Down
3 changes: 2 additions & 1 deletion test-rt/test-tflite/src/tflite_runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ impl Runtime for TfliteRuntime {
fn prepare(&self, model: TypedModel) -> TractResult<Box<dyn Runnable>> {
let mut buffer = vec![];
self.0.write(&model, &mut buffer)?;
// std::fs::write("foo.tflite", &buffer).unwrap();
Ok(Box::new(TfliteRunnable(buffer)))
}
}
Expand All @@ -39,7 +40,7 @@ impl State for TfliteState {
for (ix, input) in inputs.iter().enumerate() {
let input_ix = self.0.inputs()[ix];
let input_tensor = self.0.tensor_info(input_ix).unwrap();
assert_eq!(input_tensor.element_kind as u32, 1);
assert_eq!(input_tensor.element_kind as u32, 1); // 1 is f32
assert_eq!(input_tensor.dims, input.shape());
self.0.tensor_buffer_mut(input_ix).unwrap().copy_from_slice(unsafe { input.as_bytes() })
}
Expand Down
4 changes: 3 additions & 1 deletion test-rt/test-tflite/suite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ pub fn suite() -> infra::TestSuite {
conv.ignore(&ignore_conv);
conv.add_arbitrary::<ConvProblem>(
"proptest",
ConvProblemParams { no_arbitrary_grouping: true, ..ConvProblemParams::default() },
ConvProblemParams { no_group: true, geo_rank: Some(1..3), ..ConvProblemParams::default() },
);
infra::TestSuite::default().with("onnx", onnx).with("conv", conv)
}
Expand All @@ -29,4 +29,6 @@ fn ignore_conv(t: &[String]) -> bool {
|| unit == "lazy_im2col_big_2"
|| unit == "batch_3d"
|| unit == "bias_3d_1"
// nonsense. bug in tfl ? hole in the spec ?
|| unit == "same_1d_1"
}
10 changes: 5 additions & 5 deletions tflite/src/rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,13 @@ fn make_1d_2d(
conv: &ConvUnary,
) -> TractResult<Option<TypedModelPatch>> {
if conv.pool_spec.rank() == 1 {
let pos = conv.pool_spec.data_format.h_axis();
let pos = conv.pool_spec.data_format.h_axis() + 1;
let mut new = conv.clone();
new.pool_spec.kernel_shape.insert(0, 1);
new.pool_spec.dilations.iter_mut().for_each(|dil| dil.insert(0, 1));
new.pool_spec.strides.iter_mut().for_each(|dil| dil.insert(0, 1));
new.pool_spec.kernel_shape.insert(1, 1);
new.pool_spec.dilations.iter_mut().for_each(|dil| dil.insert(1, 1));
new.pool_spec.strides.iter_mut().for_each(|dil| dil.insert(1, 1));
let mut kernel = new.kernel.clone().into_tensor();
kernel.insert_axis(conv.kernel_fmt.h_axis())?;
kernel.insert_axis(conv.kernel_fmt.h_axis() + 1)?;
new.kernel = kernel.into_arc_tensor();
let mut patch = TypedModelPatch::default();
let mut wire = tvec!(patch.tap_model(model, node.inputs[0])?);
Expand Down

0 comments on commit 3049aa7

Please sign in to comment.