From c6cc5f92173e020050faabd9dc1e6d6d3bd7eba5 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Mon, 31 Jul 2023 14:57:19 +0200 Subject: [PATCH] strided slice in core --- core/src/ops/array/strided_slice.rs | 315 ++++++++++++++++++++++++++++ 1 file changed, 315 insertions(+) create mode 100644 core/src/ops/array/strided_slice.rs diff --git a/core/src/ops/array/strided_slice.rs b/core/src/ops/array/strided_slice.rs new file mode 100644 index 0000000000..ac5ad6ab04 --- /dev/null +++ b/core/src/ops/array/strided_slice.rs @@ -0,0 +1,315 @@ +use crate::internal::*; + +#[derive(Debug, Clone, Hash)] +pub struct StridedSlice { + pub optional_axes_input: Option, + pub optional_steps_input: Option, + pub begin_mask: i64, + pub end_mask: i64, + pub shrink_axis_mask: i64, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Dim { + // position of the first element to return + pub begin: TDim, + // position of the first element not to return + pub end: TDim, + pub stride: i32, + pub shrink: bool, +} + +impl Dim { + pub fn soft_len(&self) -> TractResult { + if let Ok(len) = (self.end.clone() - &self.begin).to_isize() { + Ok((((self.stride.abs() - 1) + len.abs() as i32) / self.stride.abs()).to_dim()) + } else if self.stride == 1 { + Ok(self.end.clone() - &self.begin) + } else { + bail!("Streaming dimensions with strides are not supported for now") + } + } +} + +impl StridedSlice { + fn must_shrink(&self, ix: usize) -> bool { + self.shrink_axis_mask & (1 << ix) != 0 + } + fn ignore_begin(&self, ix: usize) -> bool { + self.begin_mask & (1 << ix) != 0 + } + fn ignore_end(&self, ix: usize) -> bool { + self.end_mask & (1 << ix) != 0 + } + pub fn prepare_one_dim( + &self, + ix: usize, + dim: &TDim, + begin: &Tensor, + end: &Tensor, + strides: &[i32], + ) -> TractResult { + // cast bouds to Option, dealing with ignore from mask, and spec shorted than dim + // also for end, magic values in onnx :/ + let mut begin: Option = if ix >= begin.len() { + None + } else { + let begin = begin.cast_to::()?; + begin.as_slice::()?.get(ix).cloned() + }; + + let mut end: Option = if self.ignore_end(ix) || ix >= end.len() { + None + } else if end.datum_type() == i64::datum_type() { + let end = *end.as_slice::()?.get(ix).unwrap(); + if end == std::i64::MAX || end == std::i64::MIN || end == std::i64::MIN + 1 { + None + } else { + Some(end.to_dim()) + } + } else { + let end = end.cast_to::()?; + end.as_slice::()?.get(ix).cloned() + }; + + let stride = strides.get(ix).cloned().unwrap_or(1); + + // deal with negative indexing + fn fix_negative(bound: &mut TDim, dim: &TDim) { + let neg = if let Ok(b) = bound.to_isize() { + b < 0 + } else { + let symbols = bound.symbols(); + if symbols.len() == 1 { + let sym = symbols.into_iter().next().unwrap(); + let values = SymbolValues::default().with(&sym, 100_000_000); + bound.eval(&values).to_isize().unwrap() < 0 + } else { + false + } + }; + if neg { + *bound = bound.clone() + dim; + } + } + if let Some(begin) = begin.as_mut() { + fix_negative(begin, dim) + } + if let Some(end) = end.as_mut() { + fix_negative(end, dim) + } + + if self.must_shrink(ix) { + return Ok(Dim { + begin: begin.clone().unwrap_or_else(|| 0.to_dim()), + end: begin.unwrap_or_else(|| 0.to_dim()) + 1, + stride: 1, + shrink: true, + }); + } + + // must happen after dealing with must_shrink :/ + if self.ignore_begin(ix) { + begin = None; + } + + let mut begin = + begin.unwrap_or_else(|| if stride > 0 { 0.to_dim() } else { dim.clone() - 1 }); + if begin.to_isize().map(|b| b < 0).unwrap_or(false) { + if stride < 0 { + return Ok(Dim { begin: 0.to_dim(), end: 0.to_dim(), stride, shrink: false }); + } else { + begin = 0.to_dim(); + } + } + if let (Ok(b), Ok(d)) = (begin.to_isize(), dim.to_isize()) { + if b > d - 1 { + if stride > 0 { + return Ok(Dim { begin: 0.to_dim(), end: 0.to_dim(), stride, shrink: false }); + } else { + begin = (d - 1).to_dim() + } + } + } + + let mut end = end.unwrap_or_else(|| if stride > 0 { dim.clone() } else { (-1).to_dim() }); + if end.to_isize().map(|e| e < 0).unwrap_or(false) { + if stride > 0 { + return Ok(Dim { begin: 0.to_dim(), end: 0.to_dim(), stride, shrink: false }); + } else { + end = (-1).to_dim(); + } + } + if let (Ok(e), Ok(d)) = (end.to_isize(), dim.to_isize()) { + if e > d - 1 { + if stride > 0 { + end = d.to_dim() + } else { + return Ok(Dim { begin: 0.to_dim(), end: 0.to_dim(), stride, shrink: false }); + } + } + } + Ok(Dim { begin, end, stride, shrink: false }) + } + + fn wire( + &self, + prefix: &str, + target: &mut TypedModel, + inputs: &[OutletId], + ) -> TractResult> { + let params: TVec>> = inputs[1..] + .iter() + .map(|i| Ok(target.outlet_fact(*i)?.konst.clone())) + .collect::>()?; + let input_shape = target.outlet_fact(inputs[0])?.shape.clone(); + let strides: TVec = if let Some(i) = self.optional_steps_input { + let strides = params[i - 1] + .as_ref() + .context("StridedSlice is typable only if stride is a const")? + .cast_to::()?; + strides.as_slice::()?.into() + } else { + tvec![1; input_shape.rank()] + }; + let axes: TVec = if let Some(i) = self.optional_axes_input { + let axes = params[i - 1] + .as_ref() + .context("StridedSlice is typable only if axis is a const")? + .cast_to::()?; + axes.as_slice::()? + .iter() + .map(|&i| if i < 0 { input_shape.rank() as i32 + i } else { i } as usize) + .collect() + } else { + (0..input_shape.rank()).collect() + }; + let mut wire = inputs[0]; + let begin = params[0].as_ref(); + let end = params[1].as_ref(); + for (ix, &axis) in axes.iter().enumerate() { + if let (Some(begin), Some(end)) = (begin, end) { + let d = &input_shape[axis]; + let preped = self.prepare_one_dim(ix, d, begin, end, &strides)?; + let (left, right) = if preped.stride > 0 { + (preped.begin, preped.end) + } else { + (preped.end + 1, preped.begin + 1) + }; + wire = target.wire_node( + format!("{prefix}.slice-axis-{axis}"), + crate::ops::array::Slice::new(axis, left, right), + [wire].as_ref(), + )?[0]; + if preped.stride != 1 { + wire = target.wire_node( + format!("{prefix}.stride-axis-{axis}"), + crate::ops::downsample::Downsample::new(axis, preped.stride as isize, 0), + [wire].as_ref(), + )?[0]; + } + } else if strides[ix] == 1 { + let left = target.wire_node( + format!("{prefix}.slice-axis-{axis}-start"), + crate::ops::array::Slice::new(0, ix, ix + 1), + &[inputs[1]], + )?; + let left = target.wire_node( + format!("{prefix}.slice-axis-{axis}-start-rm-axis"), + AxisOp::Rm(0), + &left, + )?[0]; + let right = target.wire_node( + format!("{prefix}.slice-axis-{axis}-end"), + crate::ops::array::Slice::new(0, ix, ix + 1), + &[inputs[2]], + )?; + let right = target.wire_node( + format!("{prefix}.slice-axis-{axis}-end-rm-axis"), + AxisOp::Rm(0), + &right, + )?[0]; + let sym = target.symbol_table.new_with_prefix("l"); + wire = target.wire_node( + format!("{prefix}.slice-axis-{axis}"), + crate::ops::array::DynSlice::new(axis, sym.to_dim()), + &[wire, left, right], + )?[0]; + } + } + let mut shrink = input_shape + .iter() + .enumerate() + .filter(|(ix, _d)| self.must_shrink(*ix)) + .map(|pair| pair.0) + .collect::>(); + shrink.sort(); + for axis in shrink.iter().rev() { + wire = target.wire_node( + format!("{prefix}.RmDim-{axis}"), + AxisOp::Rm(*axis), + [wire].as_ref(), + )?[0]; + } + target.rename_node(wire.node, prefix)?; + Ok(tvec!(wire)) + } +} + +impl Op for StridedSlice { + fn name(&self) -> Cow { + "StridedSlice".into() + } + + op_as_typed_op!(); +} + +impl EvalOp for StridedSlice { + fn is_stateless(&self) -> bool { + true + } + + fn eval(&self, inputs: TVec) -> TractResult> { + let mut model = TypedModel::default(); + let mut source = tvec!(); + for (ix, input) in inputs.iter().enumerate() { + source.push(model.add_source( + format!("adhoc_input.{}", ix), + input.clone().into_arc_tensor().into(), + )?); + } + let output = self.wire("adhoc", &mut model, &source)?; + model.set_output_outlets(&output)?; + model.into_runnable()?.run(inputs) + } +} + +impl TypedOp for StridedSlice { + fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { + let mut model = TypedModel::default(); + let mut source = tvec!(); + for (ix, input) in inputs.iter().enumerate() { + source.push(model.add_source(format!("adhoc_input.{}", ix), (*input).clone())?); + } + let output = self.wire("adhoc", &mut model, &source)?; + model.set_output_outlets(&output)?; + Ok(tvec!(model.outlet_fact(output[0])?.clone())) + } + + fn declutter( + &self, + model: &TypedModel, + node: &TypedNode, + ) -> TractResult> { + let mut patch = TypedModelPatch::default(); + let mut source = tvec!(); + for &input in &node.inputs { + source.push(patch.tap_model(model, input)?); + } + let output = self.wire(&node.name, &mut patch, &source)?; + patch.shunt_outside(model, node.id.into(), output[0])?; + Ok(Some(patch)) + } + + as_op!(); +}