wip covering tflite (270 ok/3690 ign)
kali committed Jul 31, 2023
1 parent 6c9e5c2 commit 8118dc6
Showing 10 changed files with 242 additions and 80 deletions.
11 changes: 11 additions & 0 deletions core/src/model/fact.rs
@@ -354,6 +354,17 @@ impl<'a> From<&'a TypedFact> for TypedFact {
     }
 }
 
+impl<'a> From<&'a Arc<Tensor>> for TypedFact {
+    fn from(t: &'a Arc<Tensor>) -> TypedFact {
+        TypedFact {
+            datum_type: t.datum_type(),
+            shape: ShapeFact::from_dims(t.shape().iter().map(TDim::from)),
+            uniform: t.as_uniform().map(Arc::new),
+            konst: Some(t.clone()),
+        }
+    }
+}
+
 impl fmt::Debug for TypedFact {
     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
         match self.konst {
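Side note: the new impl lets any constant tensor stand in directly for a fully determined TypedFact. A minimal sketch of what it enables, assuming tract_core's internal prelude; the names below are illustrative, not from the diff:

use tract_core::internal::*;

fn main() -> TractResult<()> {
    // A constant tensor carries everything a TypedFact needs: datum type,
    // concrete shape, the uniform value if any, and the value itself.
    let weights: Arc<Tensor> = rctensor1(&[1f32, 2.0, 3.0]);
    let fact: TypedFact = (&weights).into();
    assert_eq!(fact.datum_type, f32::datum_type());
    assert!(fact.konst.is_some());
    Ok(())
}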
8 changes: 3 additions & 5 deletions core/src/model/rewriter.rs
@@ -33,13 +33,11 @@ impl<Ctx> Rewriter<Ctx> {
         for n in model.eval_order()? {
             if let Some(rules) = self.rules.get(&(*model.node(n).op).type_id()) {
                 for (name, rule) in rules {
-                    if let Some(patch) =
-                        (rule)(ctx, model, model.node(n)).with_context(|| {
-                            format!("Matching rule {name} on {}", model.node(n).name)
-                        })?
+                    if let Some(patch) = (rule)(ctx, model, model.node(n))
+                        .with_context(|| format!("Evaluating rule {name} on {}", model.node(n)))?
                     {
                         patch.apply(model).with_context(|| {
-                            format!("Applying patch for rule {name} on {}", model.node(n).name)
+                            format!("Applying patch for rule {name} on {}", model.node(n))
                         })?;
                         done_anything = true;
                     }
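Side note: a rewrite rule, as driven by this loop, is just a callback producing an optional patch; the two with_context calls wrap its two failure points (rule evaluation, patch application). A sketch with the signature reconstructed from the call site, so treat it as illustrative:

use tract_core::internal::*;

// Ctx is whatever context type the Rewriter is parametrized over; () here.
fn noop_rule(
    _ctx: &(),
    _model: &TypedModel,
    _node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    // Ok(Some(patch)) rewrites the node, Ok(None) leaves it alone, and
    // Err(_) aborts the pass with the "Evaluating rule" context above.
    Ok(None)
}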
4 changes: 4 additions & 0 deletions core/src/ops/cnn/conv/unary.rs
@@ -822,6 +822,10 @@ impl TypedOp for ConvUnary {
                 self
             );
         }
+        if let PaddingSpec::Explicit(before, after, _) = &self.pool_spec.padding {
+            anyhow::ensure!(before.len() == self.pool_spec.rank());
+            anyhow::ensure!(after.len() == self.pool_spec.rank());
+        }
         if let Some(bias) = &self.bias {
             ensure!(
                 bias.rank() == 0 || (bias.rank() == 1 && bias.len() == self.output_channels()),
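Side note: the new check pins the invariant the tflite work leans on: an explicit padding spec must carry one (before, after) pair per spatial axis. A standalone sketch of the same validation, assuming tract_core's cnn module layout:

use tract_core::internal::*;
use tract_core::ops::cnn::PaddingSpec;

fn validate_padding(padding: &PaddingSpec, spatial_rank: usize) -> TractResult<()> {
    if let PaddingSpec::Explicit(before, after, _) = padding {
        ensure!(before.len() == spatial_rank && after.len() == spatial_rank);
    }
    Ok(())
}

fn main() -> TractResult<()> {
    // One (before, after) pair for each of the two spatial axes.
    validate_padding(&PaddingSpec::Explicit(tvec!(1, 1), tvec!(1, 1), false), 2)
}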
4 changes: 4 additions & 0 deletions core/src/ops/cnn/pools.rs
@@ -97,6 +97,10 @@ impl PoolSpec {
             let mut after: TVec<usize> = after.clone();
             op.change_shape_array(&mut before, false)?;
             op.change_shape_array(&mut after, false)?;
+            if let AxisOp::Add(add) = op {
+                before[*add] = 0;
+                after[*add] = 0;
+            }
             PaddingSpec::Explicit(before, after, *round)
         } else {
             self.padding.clone()
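Side note: the fix matters because change_shape_array treats its argument as a shape. If I read AxisOp right, Add inserts a 1 for the new axis, but the neutral element for a padding array is 0. A hypothetical illustration, assuming that behavior:

use tract_core::internal::*;

fn main() -> TractResult<()> {
    let op = AxisOp::Add(1);
    let mut before: TVec<usize> = tvec!(2, 2);
    op.change_shape_array(&mut before, false)?;
    assert_eq!(&*before, &[2, 1, 2]); // shape semantics: the new axis gets a 1
    before[1] = 0; // what the patch enforces: a fresh axis is never padded
    Ok(())
}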
1 change: 0 additions & 1 deletion test-rt/suite-conv/src/conv_f32.rs
@@ -249,7 +249,6 @@ impl Test for ConvProblem {
         let mut output =
             runtime.prepare(self.tract()?)?.run(tvec![self.data.clone().into_tvalue()])?;
         let output = output.remove(0).into_tensor();
-        eprintln!("output: {output:?} reference: {reference:?}");
         output.close_enough(&reference, true)
     }
 }
12 changes: 10 additions & 2 deletions test-rt/test-tflite/suite.rs
@@ -14,15 +14,23 @@ pub fn suite() -> infra::TestSuite {

 fn ignore_onnx(t: &[String]) -> bool {
     let name = t.last().unwrap();
-    !name.contains("_conv_") || name == "test_conv_with_strides_and_asymmetric_padding"
+    let included = "_conv_ Conv1d Conv2d squeeze _transpose_ test_reshape test_flatten";
+    let excluded = "
+        test_Conv1d_groups
+        test_Conv2d_groups
+        test_Conv1d_depthwise_with_multiplier
+        test_Conv2d_depthwise_with_multiplier
+        test_Conv2d_groups_thnn
+        test_reshape_allowzero_reordered";
+    !included.split_whitespace().any(|s| name.contains(s))
+        || excluded.split_whitespace().any(|s| s == name)
 }
 
 fn ignore_conv(t: &[String]) -> bool {
     let unit: &str = t.last().map(|s| &**s).unwrap();
     t[0] == "q"
         || unit == "proptest"
         // grouping and depthwise
-        || unit == "depthwise_0"
         || unit.starts_with("group")
         // conv 3D
         || unit == "lazy_im2col_big"
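Side note: the onnx filter moves from one conv pattern to an include-families / exclude-names scheme. A standalone, runnable sketch of the same logic (exclusion list trimmed):

fn ignored(name: &str) -> bool {
    let included = "_conv_ Conv1d Conv2d squeeze _transpose_ test_reshape test_flatten";
    let excluded = "test_Conv1d_groups test_Conv2d_groups";
    !included.split_whitespace().any(|s| name.contains(s))
        || excluded.split_whitespace().any(|s| s == name)
}

fn main() {
    assert!(!ignored("test_basic_conv_with_padding")); // conv family: runs
    assert!(ignored("test_Conv1d_groups")); // explicitly excluded
    assert!(ignored("test_relu")); // not in any included family
}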
53 changes: 42 additions & 11 deletions tflite/src/ops/array.rs
@@ -2,12 +2,13 @@ use tract_hir::internal::*;
 use tract_hir::ops::array::TypedConcat;
 use tract_hir::ops::binary::wire_cast;
 use tract_hir::prelude::tract_itertools::Itertools;
+use tract_hir::tract_ndarray::ArrayView2;
 
 use crate::registry::{DeserOp, Registry};
 use crate::ser::{BuiltinOp, SubgraphBuilder};
 use crate::tflite::{
-    BuiltinOperator, BuiltinOptions, ExpandDimsOptions, ExpandDimsOptionsArgs, SqueezeOptions,
-    SqueezeOptionsArgs, TransposeOptions, TransposeOptionsArgs,
+    BuiltinOperator, BuiltinOptions, ExpandDimsOptions, ExpandDimsOptionsArgs, ReshapeOptions,
+    ReshapeOptionsArgs, SqueezeOptions, SqueezeOptionsArgs, TransposeOptions, TransposeOptionsArgs,
 };
 
 use super::wire_fused_activation;
@@ -16,6 +17,7 @@ pub fn register_all(reg: &mut Registry) {
     reg.reg_to_tflite::<AxisOp>(ser_axisop);
     reg.to_tract.insert(BuiltinOperator::CONCATENATION, de_concat);
     reg.to_tract.insert(BuiltinOperator::EXPAND_DIMS, de_expand_dims);
+    reg.to_tract.insert(BuiltinOperator::PADV2, de_padv2);
     reg.to_tract.insert(BuiltinOperator::RESHAPE, de_reshape);
     reg.to_tract.insert(BuiltinOperator::SHAPE, de_shape);
     reg.to_tract.insert(BuiltinOperator::SQUEEZE, de_squeeze);
@@ -47,9 +49,25 @@ fn de_expand_dims(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
     Ok(wire)
 }
 
+fn de_padv2(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
+    let (_input, pads, value) = args_3!(op.facts()?);
+    let pads = pads.konst.as_ref().context("Dynamic PADV2 is not supported")?;
+    let prefix = op.prefix;
+    let pads: ArrayView2<i32> = pads.to_array_view::<i32>()?.into_dimensionality()?;
+    let pads: Vec<(usize, usize)> =
+        pads.rows().into_iter().map(|row| (row[0] as usize, row[1] as usize)).collect();
+    let mode = tract_hir::ops::array::PadMode::Constant(value.konst.context("Constant expected")?);
+    op.ctx.target.wire_node(prefix, tract_core::ops::array::Pad { pads, mode }, &op.inputs[0..1])
+}
+
 fn de_reshape(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
     let input_shape: TVec<TDim> = op.ctx.target.outlet_fact(op.inputs[0])?.shape.to_tvec();
-    let shape = op.ctx.target.outlet_fact(op.inputs[1])?.konst.clone().unwrap();
+    let shape = if let Some(outlet) = op.inputs.get(1) {
+        op.ctx.target.outlet_fact(*outlet)?.konst.clone().unwrap()
+    } else {
+        let options = builtin!(op, builtin_options_as_reshape_options);
+        rctensor1(&options.new_shape().as_ref().unwrap().iter().collect::<Vec<i32>>())
+    };
     let shape = shape.cast_to::<TDim>()?;
     let shape = shape.as_slice::<TDim>()?;
     let mut wire = tvec!(op.inputs[0]);
@@ -123,10 +141,7 @@ fn ser_axisop(
             let mut permutation: Vec<i32> = (0..rank).map(|d| d as i32).collect();
             permutation.remove(*from);
             permutation.insert(*to, *from as _);
-            inputs.push(
-                builder
-                    .write_fact(&format!("{}.perm", node.name), &rctensor1(&permutation).into())?,
-            );
+            inputs.push(builder.write_fact(&format!("{}.perm", node.name), tensor1(&permutation))?);
             let options = TransposeOptions::create(builder.fb(), &TransposeOptionsArgs {});
             builder.write_op_with_options(
                 &inputs,
@@ -136,9 +151,7 @@ fn ser_axisop(
             )
         }
         AxisOp::Add(a) => {
-            inputs.push(
-                builder.write_fact(&format!("{}.axis", node.name), &rctensor0(*a as i32).into())?,
-            );
+            inputs.push(builder.write_fact(&format!("{}.axis", node.name), tensor0(*a as i32))?);
             let options = ExpandDimsOptions::create(builder.fb(), &ExpandDimsOptionsArgs {});
             builder.write_op_with_options(
                 &inputs,
@@ -165,6 +178,24 @@ fn ser_axisop(
                 options.as_union_value(),
             )
         }
-        _ => todo!("reshape translation"),
+        AxisOp::Reshape(_, _, _) => {
+            let new_shape = node.outputs[0]
+                .fact
+                .shape
+                .iter()
+                .map(|x| x.to_i32())
+                .collect::<TractResult<Vec<i32>>>()?;
+            let new_shape = builder.fb().create_vector(&new_shape);
+            let options = ReshapeOptions::create(
+                builder.fb(),
+                &ReshapeOptionsArgs { new_shape: Some(new_shape) },
+            );
+            builder.write_op_with_options(
+                &inputs,
+                &[output],
+                BuiltinOp::new(22, 1, BuiltinOperator::RESHAPE, BuiltinOptions::ReshapeOptions),
+                options.as_union_value(),
+            )
+        }
     }
 }
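Side note: de_reshape now accepts the target shape from either of the two places TFLite may put it — a second input tensor, or inline in ReshapeOptions — with the input taking precedence. A schematic, dependency-free version of that selection (names are mine, not the diff's):

fn target_shape(second_input: Option<Vec<i32>>, options_shape: Option<Vec<i32>>) -> Vec<i32> {
    second_input
        .or(options_shape)
        .expect("RESHAPE carries a shape in one of the two places")
}

fn main() {
    assert_eq!(target_shape(Some(vec![2, -1]), None), vec![2, -1]); // input wins
    assert_eq!(target_shape(None, Some(vec![4, 3])), vec![4, 3]); // options fallback
}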
144 changes: 90 additions & 54 deletions tflite/src/ops/cnn.rs
@@ -3,9 +3,10 @@ use crate::registry::{DeserOp, Registry};
 use crate::ser::{BuiltinOp, SubgraphBuilder};
 use crate::tflite::{
     ActivationFunctionType, BuiltinOperator, BuiltinOptions, Conv2DOptions, Conv2DOptionsArgs,
-    Padding,
+    DepthwiseConv2DOptions, DepthwiseConv2DOptionsArgs, PadOptions, PadOptionsArgs, Padding,
 };
 use tract_hir::internal::*;
+use tract_hir::ops::array::{Pad, PadMode};
 use tract_hir::ops::cnn::{ConvUnary, PaddingSpec};
 use tract_hir::ops::nn::DataFormat;
 use tract_hir::prelude::tract_itertools::Itertools;
@@ -17,6 +18,7 @@ pub fn register_all(reg: &mut Registry) {
     reg.to_tract.insert(BuiltinOperator::CONV_2D, conv2d);
     reg.reg_to_tflite::<ConvUnary>(ser_conv);
     reg.to_tract.insert(BuiltinOperator::DEPTHWISE_CONV_2D, dw_conv2d);
+    reg.reg_to_tflite::<Pad>(ser_pad);
 }
 
 fn average_pool_2d(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
@@ -49,66 +51,71 @@ fn ser_conv(
     let conv = node.op_as::<ConvUnary>().unwrap();
     ensure!(conv.pool_spec.data_format == DataFormat::NHWC);
     ensure!(model.node_input_facts(node.id)?[0].rank() == 4);
-    ensure!(conv.group == 1);
+    ensure!(conv.kernel_fmt == KernelFormat::OHWI);
+    ensure!(conv.group == 1 || conv.group.to_dim() == model.node_input_facts(node.id)?[0].shape[3]);
+    ensure!(
+        conv.pool_spec.padding == PaddingSpec::Valid
+            || conv.pool_spec.padding == PaddingSpec::SameUpper
+    );
     let node_name = &node.name;
     let mut inputs = node.inputs.iter().map(|o| builder.outlets_to_tensors[o]).collect_vec();
     let outputs = (0..node.outputs.len())
         .map(|o| builder.outlets_to_tensors[&OutletId::new(node.id, o)])
         .collect_vec();
-    let ohwi = match conv.kernel_fmt {
-        KernelFormat::OHWI => conv.kernel.clone(),
-        KernelFormat::HWIO => conv.kernel.clone().into_tensor().move_axis(3, 0)?.into_arc_tensor(),
-        KernelFormat::OIHW => conv.kernel.clone().into_tensor().move_axis(1, 3)?.into_arc_tensor(),
-    };
-    inputs.push(builder.write_fact(&format!("{node_name}.weights"), &ohwi.into())?);
-    inputs.push(
-        builder.write_fact(
-            &format!("{node_name}.bias"),
-            &conv
-                .bias
-                .clone()
-                .unwrap_or_else(|| {
-                    rctensor1(&vec![0f32; conv.pool_spec.output_channel_override.unwrap()])
-                })
-                .into(),
-        )?,
-    );
-    let padding = &conv.pool_spec.padding;
-    let padding_h = (conv.pool_spec.kernel_shape[0] - 1) * conv.pool_spec.dilation(0);
-    let padding_w = (conv.pool_spec.kernel_shape[1] - 1) * conv.pool_spec.dilation(1);
-    let padding = if padding.valid_dim(0, true) && padding.valid_dim(1, true) {
-        Padding::VALID
-    } else if padding
-        == &PaddingSpec::Explicit(
-            tvec!(padding_h / 2, padding_w / 2),
-            tvec!(padding_h.divceil(2), padding_w.divceil(2)),
-            false,
+    inputs.push(builder.write_fact(&format!("{node_name}.weights"), &conv.kernel)?);
+    inputs.push(builder.write_fact(
+        &format!("{node_name}.bias"),
+        &conv.bias.clone().unwrap_or_else(|| {
+            rctensor1(&vec![0f32; conv.pool_spec.output_channel_override.unwrap()])
+        }),
+    )?);
+    let padding =
+        if conv.pool_spec.padding == PaddingSpec::Valid { Padding::VALID } else { Padding::SAME };
+    if conv.group == 1 {
+        let options = Conv2DOptions::create(
+            builder.fb(),
+            &Conv2DOptionsArgs {
+                padding,
+                stride_h: conv.pool_spec.stride(0) as _,
+                stride_w: conv.pool_spec.stride(1) as _,
+                dilation_h_factor: conv.pool_spec.dilation(0) as _,
+                dilation_w_factor: conv.pool_spec.dilation(1) as _,
+                fused_activation_function: ActivationFunctionType::NONE,
+            },
+        );
+        builder.write_op_with_options(
+            &inputs,
+            &outputs,
+            BuiltinOp::new(3, 2, BuiltinOperator::CONV_2D, BuiltinOptions::Conv2DOptions),
+            options.as_union_value(),
         )
-    {
-        Padding::SAME
-    } else if padding == &PaddingSpec::SameUpper {
-        Padding::SAME
-    } else {
-        todo!();
-    };
-    let options = Conv2DOptions::create(
-        builder.fb(),
-        &Conv2DOptionsArgs {
-            padding,
-            stride_h: conv.pool_spec.stride(0) as _,
-            stride_w: conv.pool_spec.stride(1) as _,
-            dilation_h_factor: conv.pool_spec.dilation(0) as _,
-            dilation_w_factor: conv.pool_spec.dilation(1) as _,
-            fused_activation_function: ActivationFunctionType::NONE,
-        },
-    );
-    builder.write_op_with_options(
-        &inputs,
-        &outputs,
-        BuiltinOp::new(3, 2, BuiltinOperator::CONV_2D, BuiltinOptions::Conv2DOptions),
-        options.as_union_value(),
-    )?;
-    Ok(())
+    } else {
+        let depth_multiplier =
+            (conv.pool_spec.output_channel_override.unwrap() / conv.group) as i32;
+        let options = DepthwiseConv2DOptions::create(
+            builder.fb(),
+            &DepthwiseConv2DOptionsArgs {
+                padding,
+                depth_multiplier,
+                stride_h: conv.pool_spec.stride(0) as _,
+                stride_w: conv.pool_spec.stride(1) as _,
+                dilation_h_factor: conv.pool_spec.dilation(0) as _,
+                dilation_w_factor: conv.pool_spec.dilation(1) as _,
+                fused_activation_function: ActivationFunctionType::NONE,
+            },
+        );
+        builder.write_op_with_options(
+            &inputs,
+            &outputs,
+            BuiltinOp::new(
+                4,
+                2,
+                BuiltinOperator::DEPTHWISE_CONV_2D,
+                BuiltinOptions::DepthwiseConv2DOptions,
+            ),
+            options.as_union_value(),
+        )
+    }
 }
 
 fn conv2d(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
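Side note: the branch above encodes the TFLite-side constraint: a grouped convolution only maps to a builtin when it is depthwise (group == input channels), and DEPTHWISE_CONV_2D wants a depth multiplier rather than a group count. A freestanding sketch of that selection:

fn select_builtin(group: usize, input_channels: usize, output_channels: usize) -> Option<&'static str> {
    if group == 1 {
        Some("CONV_2D")
    } else if group == input_channels {
        // depth_multiplier = output channels produced per input channel
        let _depth_multiplier = (output_channels / group) as i32;
        Some("DEPTHWISE_CONV_2D")
    } else {
        None // a general grouped conv has no direct TFLite builtin
    }
}

fn main() {
    assert_eq!(select_builtin(1, 3, 16), Some("CONV_2D"));
    assert_eq!(select_builtin(8, 8, 16), Some("DEPTHWISE_CONV_2D")); // multiplier 2
    assert_eq!(select_builtin(2, 8, 16), None); // left unserialized
}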
@@ -182,3 +189,32 @@ fn dw_conv2d(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
     let wires = op.ctx.target.wire_node(op.prefix, conv, &op.inputs[0..1])?;
     wire_fused_activation(op, &wires, &options.fused_activation_function())
 }
+
+fn ser_pad(
+    builder: &mut SubgraphBuilder,
+    _model: &TypedModel,
+    node: &TypedNode,
+) -> TractResult<()> {
+    let pad = node.op_as::<Pad>().unwrap();
+    let node_name = &node.name;
+    let mut inputs = tvec!(builder.outlets_to_tensors[&node.inputs[0]]);
+    let outputs = (0..node.outputs.len())
+        .map(|o| builder.outlets_to_tensors[&OutletId::new(node.id, o)])
+        .collect_vec();
+    let paddings = tract_ndarray::Array2::<i32>::from_shape_fn((pad.pads.len(), 2), |(d, side)| {
+        (if side == 0 { pad.pads[d].0 } else { pad.pads[d].1 }) as i32
+    });
+    inputs.push(builder.write_fact(format!("{node_name}.paddings"), paddings.into_tensor())?);
+    let PadMode::Constant(pad_value) = &pad.mode else {
+        bail!("Only constant padding is supported by tflite");
+    };
+    inputs.push(builder.write_fact(format!("{node_name}.pad_value"), pad_value)?);
+    let options = PadOptions::create(builder.fb(), &PadOptionsArgs {});
+    builder.write_op_with_options(
+        &inputs,
+        &outputs,
+        BuiltinOp::new(60, 1, BuiltinOperator::PADV2, BuiltinOptions::PadV2Options),
+        options.as_union_value(),
+    )?;
+    Ok(())
+}
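Side note: ser_pad mirrors de_padv2 above: tract's Vec<(before, after)> becomes the [rank, 2] i32 paddings tensor PADV2 expects, and only PadMode::Constant survives the round trip. The matrix construction as a standalone sketch, assuming the tract_ndarray re-export (plain ndarray behaves the same):

use tract_ndarray::Array2;

fn main() {
    // NHWC padding: one (before, after) pair per axis; H and W padded here.
    let pads: Vec<(usize, usize)> = vec![(0, 0), (1, 2), (1, 2), (0, 0)];
    let paddings = Array2::<i32>::from_shape_fn((pads.len(), 2), |(d, side)| {
        (if side == 0 { pads[d].0 } else { pads[d].1 }) as i32
    });
    assert_eq!(paddings.row(1).to_vec(), vec![1, 2]);
}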