diff --git a/test-rt/test-onnx-tflite-cycle/Cargo.toml b/test-rt/test-onnx-tflite-cycle/Cargo.toml new file mode 100644 index 0000000000..d36f74b2df --- /dev/null +++ b/test-rt/test-onnx-tflite-cycle/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "test-onnx-tflit-cycle" +version = "0.1.0" +edition = "2021" + +[dependencies] + +[build-dependencies] +suite-onnx = { path = "../suite-onnx" } + +[dev-dependencies] +lazy_static.workspace = true +log.workspace = true +tract-tflite = { path = "../../tflite", version = "=0.20.7-pre" } +tract-onnx-opl = { path = "../../onnx-opl", version = "=0.20.7-pre" } +suite-onnx = { path = "../suite-onnx" } diff --git a/test-rt/test-onnx-tflite-cycle/build.rs b/test-rt/test-onnx-tflite-cycle/build.rs new file mode 100644 index 0000000000..0492a6cc17 --- /dev/null +++ b/test-rt/test-onnx-tflite-cycle/build.rs @@ -0,0 +1,10 @@ +fn main() { + let mut suite = suite_onnx::suite().clone(); + suite.ignore(&ignore); + suite.test_runtime("tflite_cycle", "suite_onnx::suite()", "tflite_cycle()"); +} + +fn ignore(t: &[String]) -> bool { + let name = t.last().unwrap(); + !name.contains("_conv_") +} diff --git a/test-rt/test-onnx-tflite-cycle/src/lib.rs b/test-rt/test-onnx-tflite-cycle/src/lib.rs new file mode 100644 index 0000000000..ce87cd6b41 --- /dev/null +++ b/test-rt/test-onnx-tflite-cycle/src/lib.rs @@ -0,0 +1,31 @@ +#![cfg(test)] +use std::borrow::Cow; + +use log::*; +use tract_tflite::{internal::*, Tflite}; + +struct TfliteCyclingRuntime(Tflite); + +impl Runtime for TfliteCyclingRuntime { + fn name(&self) -> Cow { + "nnef_cycle".into() + } + + fn prepare(&self, model: TypedModel) -> TractResult> { + info!("Store to Tflite"); + let mut buffer = vec![]; + self.0.write(&model, &mut buffer)?; + info!("Reload from Tflite"); + let reloaded = self.0.model_for_read(&mut &*buffer)?; + Ok(Box::new(reloaded.into_optimized()?.into_runnable()?)) + } +} + +fn tflite_cycle() -> &'static TfliteCyclingRuntime { + lazy_static::lazy_static! { + static ref RT: TfliteCyclingRuntime = TfliteCyclingRuntime(Tflite::default()); + }; + &RT +} + +include!(concat!(env!("OUT_DIR"), "/tests/tflite_cycle.rs"));