Skip to content

Commit

Permalink
wip simplifying scan (cumsum test borken)
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed May 30, 2023
1 parent 1530449 commit 1d573c3
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 20 deletions.
6 changes: 3 additions & 3 deletions core/src/ops/array/dyn_slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pub struct DynSlice {
pub axis: usize,
pub start_input: bool,
pub end_input: bool,
pub symbol: Symbol,
pub len: TDim,
}

impl DynSlice {
Expand Down Expand Up @@ -63,8 +63,8 @@ impl EvalOp for DynSlice {

impl TypedOp for DynSlice {
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
let mut fact = inputs[0].clone();
fact.shape.set(self.axis, self.symbol.clone().into());
let mut fact = inputs[0].without_value();
fact.shape.set(self.axis, self.len.clone().into());
Ok(tvec!(fact))
}

Expand Down
4 changes: 2 additions & 2 deletions hir/src/ops/array/strided_slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,10 +297,10 @@ impl Expansion for StridedSlice {
AxisOp::Rm(0),
&right,
)?[0];
let sym = target.symbol_table.new_with_prefix("l");
let len = target.symbol_table.new_with_prefix("len").to_dim();
wire = target.wire_node(
format!("{prefix}.slice-axis-{axis}"),
tract_core::ops::array::DynSlice::new(axis, true, true, sym),
tract_core::ops::array::DynSlice::new(axis, true, true, len),
&[wire, left, right],
)?[0];
}
Expand Down
37 changes: 32 additions & 5 deletions onnx/src/ops/cumsum.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use tract_hir::internal::*;
use tract_hir::tract_core::ops::array::DynSlice;
use tract_hir::tract_core::ops::scan::ScanInfo;

use crate::model::{OnnxOpRegister, ParsingContext};
Expand Down Expand Up @@ -54,11 +55,17 @@ impl Expansion for CumSum {
)?[0];
let chunk = if self.reverse { -1 } else { 1 };
let input_mapping =
vec![scan::InputMapping::Scan(ScanInfo { axis, chunk }), scan::InputMapping::State];
vec![scan::InputMapping::Full, scan::InputMapping::State, scan::InputMapping::State];
// outputs will be
// acc + x (!exclusive)
// acc input (exclusive)
let output_mapping = vec![
scan::OutputMapping {
scan: None,
full_dim_hint: None,
last_value_slot: None,
state: true,
},
scan::OutputMapping {
scan: Some((0, ScanInfo { axis, chunk })),
full_dim_hint: None,
Expand All @@ -74,12 +81,32 @@ impl Expansion for CumSum {
];
let mut body = TypedModel::default();
let var_fact = data.datum_type.fact(var_shape);
let x = body.add_source("scan_input", var_fact.clone())?;
let x = body.add_source("scan_input", data)?;

let i = body.add_source("i", i64::scalar_fact())?;
let one = body.add_const("one", tensor0(1i64))?;
let i_plus_one = body.wire_node("inc_i", tract_core::ops::math::add(), &[i, one])?[0];
let x_slice = body.wire_node(
"x",
DynSlice {
axis,
start_input: true,
end_input: true,
len: 1.to_dim(),
},
&[x, i, i_plus_one],
)?[0];

let acc = body.add_source("acc_input", var_fact)?;
let sum = body.wire_node("add", tract_core::ops::math::add(), &[x, acc])?[0];
body.set_output_outlets(&[sum, acc])?;
dbg!(axis);
dbg!(body.outlet_fact(x));
dbg!(body.outlet_fact(x_slice));
dbg!(body.outlet_fact(acc));
let sum = body.wire_node("add", tract_core::ops::math::add(), &[x_slice, acc])?[0];
body.set_output_outlets(&[i_plus_one, sum, acc])?;
let scan = scan::Scan::new(body, input_mapping, output_mapping, 0, iters)?;
let wires = model.wire_node(prefix, scan, &[inputs[0], init])?;
let zero = model.add_const(format!("{prefix}.zero"), tensor0(0i64))?;
let wires = model.wire_node(prefix, scan, &[inputs[0], zero, init])?;
let output = wires[self.exclusive as usize];
Ok(tvec![output])
}
Expand Down
41 changes: 31 additions & 10 deletions onnx/src/ops/rec/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::fmt::Debug;
use crate::pb::*;
use tract_hir::internal::*;
use tract_hir::tract_core::dyn_clone::{clone_trait_object, DynClone};
use tract_hir::tract_core::ops::array::DynSlice;
use tract_hir::tract_core::ops::scan::ScanInfo;

pub trait WireBody: Debug + DynClone + Send + Sync {
Expand Down Expand Up @@ -117,12 +118,21 @@ impl CommonRec {
// scann inner interface: [chunk=1, batch_size, input_size]
// onnx inner interface: [batch_size, input_size]
outer_inputs.push(x_batch_first);
input_mapping.push(scan::InputMapping::Scan(ScanInfo { axis: 1, chunk }));
let mut x_source_fact = target.outlet_fact(x_batch_first)?.without_value();
// input_mapping.push(scan::InputMapping::Scan(ScanInfo { axis: 1, chunk }));
input_mapping.push(scan::InputMapping::Full);
let x_source_fact = target.outlet_fact(x_batch_first)?.without_value();
let iters = x_source_fact.shape[1].clone();
x_source_fact.shape.set(1, 1.to_dim());
let x_source = body.add_source("x_source", x_source_fact)?;
wire!(Xt = AxisOp::Rm(1), x_source);

input_mapping.push(scan::InputMapping::State);
let zero = target.add_const(format!("{prefix}.zero"), tensor0(0i64))?;
outer_inputs.push(zero);
let i = body.add_source("i", i64::scalar_fact())?;
let one = body.add_const("one", tensor0(1i64))?;
wire!(i_plus_one = tract_core::ops::math::add(), i, one);
let dyn_slice = DynSlice { axis: 1, start_input: true, end_input: true, len: 1.to_dim() };
wire!(x_slice = dyn_slice, x_source, i, i_plus_one);
wire!(Xt = AxisOp::Rm(1), x_slice);

// W: onnx interface: [num_directions, 3*hidden_size, input_size]
// scan interfaces: [3*hidden_size, input_size]
Expand Down Expand Up @@ -229,13 +239,24 @@ impl CommonRec {
};

self.body.wire_body(prefix, &mut body).context("Wiring body")?;
let mut outputs = body.outputs.clone();
outputs.insert(0, i_plus_one);
body.set_output_outlets(&*outputs)?;

let mut output_mapping = vec![scan::OutputMapping {
state: true,
full_dim_hint: None,
last_value_slot: self.optional_y_h_output,
scan: self.optional_y_output.map(|slot| (slot, ScanInfo { axis: 1, chunk })),
}];
let mut output_mapping = vec![
scan::OutputMapping {
state: true,
full_dim_hint: None,
last_value_slot: None,
scan: None,
},
scan::OutputMapping {
state: true,
full_dim_hint: None,
last_value_slot: self.optional_y_h_output,
scan: self.optional_y_output.map(|slot| (slot, ScanInfo { axis: 1, chunk })),
},
];
if self.body.have_extra_c_state() {
output_mapping.push(scan::OutputMapping {
state: true,
Expand Down

0 comments on commit 1d573c3

Please sign in to comment.