Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Start tflite dump #1118

Merged
merged 28 commits into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions .travis/onnx-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ export CACHEDIR

opset=onnx_"${1:-1_13_0}"

cargo -q test -p test-onnx-core $CARGO_EXTRA -q --no-default-features --profile opt-no-lto --features $opset

cargo -q test -p test-onnx-nnef-cycle $CARGO_EXTRA -q --no-default-features --profile opt-no-lto
cargo tree -p tflite

cargo -q test -p test-onnx-core $CARGO_EXTRA -q --no-default-features --features $opset
cargo -q test -p test-onnx-nnef-cycle $CARGO_EXTRA -q --no-default-features
cargo -q test -p test-conv-core $CARGO_EXTRA -q
cargo -q test -p test-tflite $CARGO_EXTRA -q
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ members = [
"test-rt/test-conv-core",
"test-rt/test-onnx-core",
"test-rt/test-onnx-nnef-cycle",
"test-rt/test-tflite",
]

[workspace.dependencies]
Expand All @@ -60,6 +61,7 @@ dinghy-test = "0.6"
downcast-rs = "1.2.0"
dyn-clone = "1.0.4"
env_logger = "0.10"
flatbuffers = "23.1.21"
flate2 = "1.0.20"
fs2 = "0.4.3"
getrandom = "0.2"
Expand Down
2 changes: 1 addition & 1 deletion core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ pub mod internal {
pub use {args_1, args_2, args_3, args_4, args_5, args_6, args_7, args_8};
pub use {as_op, impl_op_same_as, not_a_typed_op, op_as_typed_op};
pub use {bin_to_super_type, element_wise, element_wise_oop};
pub use crate::runtime::{Runtime, Runnable, DefaultRuntime};
pub use crate::runtime::{Runtime, Runnable, State, DefaultRuntime};
}

#[cfg(test)]
Expand Down
11 changes: 11 additions & 0 deletions core/src/model/fact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,17 @@ impl<'a> From<&'a TypedFact> for TypedFact {
}
}

impl<'a> From<&'a Arc<Tensor>> for TypedFact {
fn from(t: &'a Arc<Tensor>) -> TypedFact {
TypedFact {
datum_type: t.datum_type(),
shape: ShapeFact::from_dims(t.shape().iter().map(TDim::from)),
uniform: t.as_uniform().map(Arc::new),
konst: Some(t.clone()),
}
}
}

impl fmt::Debug for TypedFact {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
match self.konst {
Expand Down
38 changes: 24 additions & 14 deletions core/src/model/rewriter.rs
Original file line number Diff line number Diff line change
@@ -1,37 +1,47 @@
use std::any::{Any, TypeId};
use std::any::TypeId;

use crate::internal::*;

type RewriteRule<Ctx> =
type GenRewriteRule<Ctx> =
Box<dyn Fn(&Ctx, &TypedModel, &TypedNode) -> TractResult<Option<TypedModelPatch>>>;

#[derive(Default)]
#[allow(clippy::type_complexity)]
pub struct Rewriter<Ctx> {
rules: HashMap<TypeId, (Cow<'static, str>, RewriteRule<Ctx>)>,
rules: HashMap<TypeId, Vec<(Cow<'static, str>, GenRewriteRule<Ctx>)>>,
}

impl<Ctx> Rewriter<Ctx> {
pub fn with_rule_for<O: Any + 'static>(
pub fn with_rule_for<O: Op + 'static>(
mut self,
name: impl Into<Cow<'static, str>>,
rule: RewriteRule<Ctx>,
rule: impl Fn(&Ctx, &TypedModel, &TypedNode, &str, &O) -> TractResult<Option<TypedModelPatch>>
+ 'static,
) -> Self {
self.rules.insert(TypeId::of::<O>(), (name.into(), rule));
self.rules.entry(TypeId::of::<O>()).or_default().push((
name.into(),
Box::new(move |c: &Ctx, m: &TypedModel, n: &TypedNode| {
let o = n.op_as::<O>().unwrap();
rule(c, m, n, &n.name, o)
}),
));
self
}

pub fn rewrite(&self, ctx: &Ctx, model: &mut TypedModel) -> TractResult<()> {
loop {
let mut done_anything = false;
for n in model.eval_order()? {
if let Some((name, rule)) = self.rules.get(&(*model.node(n).op).type_id()) {
if let Some(patch) = (rule)(ctx, model, model.node(n)).with_context(|| {
format!("Matching rule {name} on {}", model.node(n).name)
})? {
patch.apply(model).with_context(|| {
format!("Applying patch for rule {name} on {}", model.node(n).name)
})?;
done_anything = true;
if let Some(rules) = self.rules.get(&(*model.node(n).op).type_id()) {
for (name, rule) in rules {
if let Some(patch) = (rule)(ctx, model, model.node(n))
.with_context(|| format!("Evaluating rule {name} on {}", model.node(n)))?
{
patch.apply(model).with_context(|| {
format!("Applying patch for rule {name} on {}", model.node(n))
})?;
done_anything = true;
}
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions core/src/ops/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ mod range;
mod reshape;
mod scatter_elements;
mod scatter_nd;
pub mod strided_slice;
mod slice;
mod tile;
mod topk;
Expand All @@ -28,6 +29,7 @@ pub use self::reshape::FiniteReshape;
pub use self::range::Range;
pub use self::scatter_elements::ScatterElements;
pub use self::scatter_nd::ScatterNd;
pub use self::strided_slice::StridedSlice;
pub use self::slice::Slice;
pub use self::tile::Tile;
pub use self::topk::Topk;
Expand Down
Loading