From 9d5deff7117114e0ea263848435877948253b550 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Mon, 11 Jul 2022 13:09:39 +0200 Subject: [PATCH] allow a dilibert use case (#757) * allow a dilibert use case * enable some onnx tests --- harness/onnx-test-suite/node-1.10.1.txt | 3 +++ harness/onnx-test-suite/node-1.7.0.txt | 2 ++ harness/onnx-test-suite/node-1.8.1.txt | 2 ++ harness/onnx-test-suite/node-1.9.0.txt | 2 ++ onnx/src/ops/cumsum.rs | 22 +++++++++++++--------- 5 files changed, 22 insertions(+), 9 deletions(-) diff --git a/harness/onnx-test-suite/node-1.10.1.txt b/harness/onnx-test-suite/node-1.10.1.txt index 46c03344f0..176414276e 100644 --- a/harness/onnx-test-suite/node-1.10.1.txt +++ b/harness/onnx-test-suite/node-1.10.1.txt @@ -299,6 +299,7 @@ test_matmul_3d test_matmul_4d test_matmulinteger test_max_example +test_max_float16 test_max_float32 test_max_float64 test_max_int16 @@ -328,6 +329,7 @@ test_mean_example test_mean_one_input test_mean_two_inputs test_min_example +test_min_float16 test_min_float32 test_min_float64 test_min_int16 @@ -527,6 +529,7 @@ test_reshape_reordered_last_dims input:data test_reshape_zero_and_negative_dim input:data test_reshape_zero_dim input:data test_resize_upsample_scales_linear_align_corners input:X not-nnef +test_resize_upsample_scales_nearest not-nnef not-typable test_rnn_seq_length test_round test_scan9_sum diff --git a/harness/onnx-test-suite/node-1.7.0.txt b/harness/onnx-test-suite/node-1.7.0.txt index 352edfea42..b3ffa70a39 100644 --- a/harness/onnx-test-suite/node-1.7.0.txt +++ b/harness/onnx-test-suite/node-1.7.0.txt @@ -269,6 +269,7 @@ test_matmul_3d test_matmul_4d test_matmulinteger test_max_example +test_max_float16 test_max_float32 test_max_float64 test_max_int16 @@ -298,6 +299,7 @@ test_mean_example test_mean_one_input test_mean_two_inputs test_min_example +test_min_float16 test_min_float32 test_min_float64 test_min_int16 diff --git a/harness/onnx-test-suite/node-1.8.1.txt b/harness/onnx-test-suite/node-1.8.1.txt index 40b63657b8..a90b29d363 100644 --- a/harness/onnx-test-suite/node-1.8.1.txt +++ b/harness/onnx-test-suite/node-1.8.1.txt @@ -284,6 +284,7 @@ test_matmul_3d test_matmul_4d test_matmulinteger test_max_example +test_max_float16 test_max_float32 test_max_float64 test_max_int16 @@ -313,6 +314,7 @@ test_mean_example test_mean_one_input test_mean_two_inputs test_min_example +test_min_float16 test_min_float32 test_min_float64 test_min_int16 diff --git a/harness/onnx-test-suite/node-1.9.0.txt b/harness/onnx-test-suite/node-1.9.0.txt index 8120ff19df..bf02c4ce13 100644 --- a/harness/onnx-test-suite/node-1.9.0.txt +++ b/harness/onnx-test-suite/node-1.9.0.txt @@ -290,6 +290,7 @@ test_matmul_3d test_matmul_4d test_matmulinteger test_max_example +test_max_float16 test_max_float32 test_max_float64 test_max_int16 @@ -319,6 +320,7 @@ test_mean_example test_mean_one_input test_mean_two_inputs test_min_example +test_min_float16 test_min_float32 test_min_float64 test_min_int16 diff --git a/onnx/src/ops/cumsum.rs b/onnx/src/ops/cumsum.rs index 094b619395..c9eedde7f9 100644 --- a/onnx/src/ops/cumsum.rs +++ b/onnx/src/ops/cumsum.rs @@ -39,19 +39,23 @@ impl Expansion for CumSum { let axis = model.outlet_fact(inputs[1])?.konst.as_ref().context("Axis expected to be a const")?; let axis = axis.cast_to_scalar::()?; - let data = model.outlet_fact(inputs[0])?; - let axis = if axis < 0 { (axis + data.rank() as i64) as usize } else { axis as usize }; + let data = model.outlet_fact(inputs[0])?.clone(); let mut var_shape = data.shape.clone(); + let axis = if axis < 0 { (axis + data.rank() as i64) as usize } else { axis as usize }; + let zero = model.add_const( + format!("{}.zero", prefix), + Tensor::zero_dt(data.datum_type, &[])?.into_arc_tensor(), + )?; var_shape.set(axis, 1.to_dim()); - let var_shape = var_shape.as_concrete().context("Expect shapes to be known")?; + let init = model.wire_node( + format!("{}.init", prefix), + tract_core::ops::array::MultiBroadcastTo::new(var_shape.clone().into()), + &[zero], + )?[0]; let chunk = if self.reverse { -1 } else { 1 }; let input_mapping = vec![ scan::InputMapping::Scan { slot: 0, axis, chunk }, - scan::InputMapping::State { - initializer: scan::StateInitializer::Value( - Tensor::zero_dt(data.datum_type, var_shape)?.into_arc_tensor(), - ), - }, + scan::InputMapping::State { initializer: scan::StateInitializer::FromInput(1) }, ]; let output_mapping = vec![ scan::OutputMapping { @@ -82,7 +86,7 @@ impl Expansion for CumSum { body.set_output_outlets(&[sum, sum])?; } let scan = scan::Scan::new(body, input_mapping, output_mapping, None, 0)?; - model.wire_node(prefix, scan, &inputs[0..1]) + model.wire_node(prefix, scan, &[inputs[0], init]) } fn rules<'r, 'p: 'r, 's: 'r>(