Skip to content

Commit

Permalink
allow a dilibert use case (#757)
Browse files Browse the repository at this point in the history
* allow a dilibert use case

* enable some onnx tests
  • Loading branch information
kali authored Jul 11, 2022
1 parent 032c95c commit 9d5deff
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 9 deletions.
3 changes: 3 additions & 0 deletions harness/onnx-test-suite/node-1.10.1.txt
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ test_matmul_3d
test_matmul_4d
test_matmulinteger
test_max_example
test_max_float16
test_max_float32
test_max_float64
test_max_int16
Expand Down Expand Up @@ -328,6 +329,7 @@ test_mean_example
test_mean_one_input
test_mean_two_inputs
test_min_example
test_min_float16
test_min_float32
test_min_float64
test_min_int16
Expand Down Expand Up @@ -527,6 +529,7 @@ test_reshape_reordered_last_dims input:data
test_reshape_zero_and_negative_dim input:data
test_reshape_zero_dim input:data
test_resize_upsample_scales_linear_align_corners input:X not-nnef
test_resize_upsample_scales_nearest not-nnef not-typable
test_rnn_seq_length
test_round
test_scan9_sum
Expand Down
2 changes: 2 additions & 0 deletions harness/onnx-test-suite/node-1.7.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ test_matmul_3d
test_matmul_4d
test_matmulinteger
test_max_example
test_max_float16
test_max_float32
test_max_float64
test_max_int16
Expand Down Expand Up @@ -298,6 +299,7 @@ test_mean_example
test_mean_one_input
test_mean_two_inputs
test_min_example
test_min_float16
test_min_float32
test_min_float64
test_min_int16
Expand Down
2 changes: 2 additions & 0 deletions harness/onnx-test-suite/node-1.8.1.txt
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ test_matmul_3d
test_matmul_4d
test_matmulinteger
test_max_example
test_max_float16
test_max_float32
test_max_float64
test_max_int16
Expand Down Expand Up @@ -313,6 +314,7 @@ test_mean_example
test_mean_one_input
test_mean_two_inputs
test_min_example
test_min_float16
test_min_float32
test_min_float64
test_min_int16
Expand Down
2 changes: 2 additions & 0 deletions harness/onnx-test-suite/node-1.9.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ test_matmul_3d
test_matmul_4d
test_matmulinteger
test_max_example
test_max_float16
test_max_float32
test_max_float64
test_max_int16
Expand Down Expand Up @@ -319,6 +320,7 @@ test_mean_example
test_mean_one_input
test_mean_two_inputs
test_min_example
test_min_float16
test_min_float32
test_min_float64
test_min_int16
Expand Down
22 changes: 13 additions & 9 deletions onnx/src/ops/cumsum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,23 @@ impl Expansion for CumSum {
let axis =
model.outlet_fact(inputs[1])?.konst.as_ref().context("Axis expected to be a const")?;
let axis = axis.cast_to_scalar::<i64>()?;
let data = model.outlet_fact(inputs[0])?;
let axis = if axis < 0 { (axis + data.rank() as i64) as usize } else { axis as usize };
let data = model.outlet_fact(inputs[0])?.clone();
let mut var_shape = data.shape.clone();
let axis = if axis < 0 { (axis + data.rank() as i64) as usize } else { axis as usize };
let zero = model.add_const(
format!("{}.zero", prefix),
Tensor::zero_dt(data.datum_type, &[])?.into_arc_tensor(),
)?;
var_shape.set(axis, 1.to_dim());
let var_shape = var_shape.as_concrete().context("Expect shapes to be known")?;
let init = model.wire_node(
format!("{}.init", prefix),
tract_core::ops::array::MultiBroadcastTo::new(var_shape.clone().into()),
&[zero],
)?[0];
let chunk = if self.reverse { -1 } else { 1 };
let input_mapping = vec![
scan::InputMapping::Scan { slot: 0, axis, chunk },
scan::InputMapping::State {
initializer: scan::StateInitializer::Value(
Tensor::zero_dt(data.datum_type, var_shape)?.into_arc_tensor(),
),
},
scan::InputMapping::State { initializer: scan::StateInitializer::FromInput(1) },
];
let output_mapping = vec![
scan::OutputMapping {
Expand Down Expand Up @@ -82,7 +86,7 @@ impl Expansion for CumSum {
body.set_output_outlets(&[sum, sum])?;
}
let scan = scan::Scan::new(body, input_mapping, output_mapping, None, 0)?;
model.wire_node(prefix, scan, &inputs[0..1])
model.wire_node(prefix, scan, &[inputs[0], init])
}

fn rules<'r, 'p: 'r, 's: 'r>(
Expand Down

0 comments on commit 9d5deff

Please sign in to comment.