Skip to content

Commit

Permalink
Patch::taps()
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Aug 1, 2023
1 parent 3caac66 commit 5e7b245
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 52 deletions.
51 changes: 29 additions & 22 deletions core/src/model/patch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ use crate::model::*;
#[derive(Clone, Debug)]
pub struct ModelPatch<F, O>
where
F: Fact + Clone + 'static ,
O: Display + Debug + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static
F: Fact + Clone + 'static,
O: Display + Debug + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
{
/// patch label for auditing and debugging
pub context: Vec<String>,
Expand All @@ -36,8 +36,8 @@ where

impl<F, O> Default for ModelPatch<F, O>
where
F: Fact + Clone + 'static ,
O: Display + Debug + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static
F: Fact + Clone + 'static,
O: Display + Debug + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
{
fn default() -> ModelPatch<F, O> {
ModelPatch {
Expand All @@ -54,8 +54,8 @@ where

impl<F, O> Deref for ModelPatch<F, O>
where
F: Fact + Clone + 'static ,
O: Display + Debug + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static
F: Fact + Clone + 'static,
O: Display + Debug + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
{
type Target = Graph<F, O>;
fn deref(&self) -> &Graph<F, O> {
Expand All @@ -65,8 +65,8 @@ where

impl<F, O> DerefMut for ModelPatch<F, O>
where
F: Fact + Clone + 'static ,
O: Display + Debug + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static
F: Fact + Clone + 'static,
O: Display + Debug + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
{
fn deref_mut(&mut self) -> &mut Graph<F, O> {
&mut self.model
Expand All @@ -75,8 +75,8 @@ where

impl<F, O> ModelPatch<F, O>
where
F: Fact + Clone + 'static ,
O: Display + Debug + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static ,
F: Fact + Clone + 'static,
O: Display + Debug + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
Graph<F, O>: SpecialOps<F, O>,
{
pub fn new(s: impl Into<String>) -> Self {
Expand Down Expand Up @@ -109,6 +109,17 @@ where
Ok(id)
}

/// Draw taps from several preexisting nodes of another model.
///
/// Convenience wrapper around `tap_model` for a whole list of outlets:
/// returns one `OutletId` per requested outlet (in order), each usable
/// inside the little "patch" model. Fails fast on the first outlet that
/// cannot be tapped.
pub fn taps<'a>(
&mut self,
model: &Graph<F, O>,
outlets: impl IntoIterator<Item = &'a OutletId>,
) -> TractResult<TVec<OutletId>> {
// Collecting into TractResult<TVec<_>> short-circuits on the first Err.
outlets.into_iter().map(|o| self.tap_model(model, *o)).collect::<TractResult<TVec<_>>>()
}

pub unsafe fn shunt_outside_unchecked(
&mut self,
outlet: OutletId,
Expand Down Expand Up @@ -148,10 +159,7 @@ where
) -> TractResult<ModelPatch<F, O>> {
let mut patch = ModelPatch::default();
let new_op = new_op.into();
let inputs = inputs
.iter()
.map(|i| patch.tap_model(patched_model, *i))
.collect::<TractResult<TVec<_>>>()?;
let inputs = patch.taps(patched_model, inputs)?;
let wires = patch.wire_node(&node.name, new_op, &inputs)?;
for (ix, o) in wires.iter().enumerate() {
patch.shunt_outside(patched_model, OutletId::new(node.id, ix), *o)?;
Expand All @@ -172,8 +180,7 @@ where
} else {
bail!("Non single successor fuse attempt")
};
let inputs = node.inputs.iter().map(|o|
patch.tap_model(patched_model, *o)).collect::<TractResult<TVec<OutletId>>>()?;
let inputs = patch.taps(patched_model, &node.inputs)?;
let output = patch.wire_node(&node.name, new_op.into(), &inputs)?;
patch.shunt_outside(patched_model, succ.id.into(), output[0])?;
Ok(patch)
Expand All @@ -184,10 +191,13 @@ where
patched_model: &Graph<F, O>,
node: &Node<F, O>,
) -> TractResult<Option<ModelPatch<F, O>>> {
if patched_model.outputs.contains(&node.id.into()) && patched_model.outputs.contains(&node.inputs[0]) {
if patched_model.outputs.contains(&node.id.into())
&& patched_model.outputs.contains(&node.inputs[0])
{
Ok(None)
} else {
Self::rewire(patched_model, &node.inputs, &[node.id.into()], &|_p, xs| Ok(xs.into())).map(Some)
Self::rewire(patched_model, &node.inputs, &[node.id.into()], &|_p, xs| Ok(xs.into()))
.map(Some)
}
}

Expand All @@ -199,10 +209,7 @@ where
wiring: &dyn Fn(&mut Self, &[OutletId]) -> TractResult<TVec<OutletId>>,
) -> TractResult<ModelPatch<F, O>> {
let mut patch = ModelPatch::default();
let taps = from
.iter()
.map(|f| patch.tap_model(patched_model, *f))
.collect::<TractResult<TVec<_>>>()?;
let taps = patch.taps(patched_model, from)?;
let news = wiring(&mut patch, &taps)?;
if news.len() != to.len() {
bail!(
Expand Down
18 changes: 3 additions & 15 deletions core/src/ops/cnn/conv/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -575,11 +575,7 @@ impl ConvUnary {
let a = ker.into_shape(&a_shape)?.into_arc_tensor();
let mut patch = TypedModelPatch::new("declutter_as_einsum");
let a = patch.add_const(format!("{name}.filters"), a)?;
let mut inputs = node
.inputs
.iter()
.map(|i| patch.tap_model(model, *i))
.collect::<TractResult<TVec<_>>>()?;
let mut inputs = patch.taps(model, &node.inputs)?;
inputs.insert(0, a);
let mut axes = self.axes_mapping(&input_facts, &output_facts)?.with_extra_input(0)?;
axes = axes.with_extra_axis('0', InOut::In(0), 0)?.with_extra_axis(
Expand Down Expand Up @@ -1043,11 +1039,7 @@ impl TypedOp for ConvUnary {
) -> TractResult<Option<TypedModelPatch>> {
if let DatumType::U8 = self.kernel.datum_type().unquantized() {
let mut patch = TypedModelPatch::default();
let mut inputs = node
.inputs
.iter()
.map(|w| patch.tap_model(model, *w))
.collect::<TractResult<TVec<_>>>()?;
let mut inputs = patch.taps(model, &node.inputs)?;
let new_op = self.kernel_offset_u8_as_i8(&mut inputs, &mut patch)?.unwrap();
let wire = patch.wire_node(&node.name, new_op, &inputs)?;
patch.shunt_outside(model, node.id.into(), wire[0])?;
Expand All @@ -1060,11 +1052,7 @@ impl TypedOp for ConvUnary {
let dt = input_fact.datum_type;
if self.q_params.is_some() {
let mut patch = TypedModelPatch::default();
let inputs = node
.inputs
.iter()
.map(|w| patch.tap_model(model, *w))
.collect::<TractResult<TVec<_>>>()?;
let inputs = patch.taps(model, &node.inputs)?;
let wire = self
.wire_as_quant_im2col(&mut patch, &node.name, &inputs)
.context("in wire_as_quant_im2col")?;
Expand Down
3 changes: 1 addition & 2 deletions core/src/ops/einsum/as_matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ fn einsum_rules(
let prefix: String =
op.axes.iter_all_axes().filter(|a| ![m, k, n].contains(&a.repr)).map(|a| a.repr).collect();
let mut patch = TypedModelPatch::default();
let inputs =
node.inputs.iter().map(|i| patch.tap_model(model, *i)).collect::<TractResult<TVec<_>>>()?;
let inputs = patch.taps(model, &node.inputs)?;
let mut wire = tvec!(inputs[0], inputs[1]);

let a_order_es: String = op.axes.axes(InOut::In(0)).map(|a| a.repr).collect();
Expand Down
9 changes: 3 additions & 6 deletions core/src/ops/einsum/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,7 @@ pub(super) fn inject_k_axis(
let mut new_axes = op.axes.clone();
let name = &node.name;
let mut patch = TypedModelPatch::new("inject k axis");
let mut wire =
node.inputs.iter().map(|i| patch.tap_model(model, *i)).collect::<TractResult<TVec<_>>>()?;
let mut wire = patch.taps(model, &node.inputs)?;
let possible_k_axis =
new_axes.iter_all_axes().find(|a| a.outputs[0].len() == 0).map(|axis| axis.repr);
if let Some(axis) = possible_k_axis {
Expand Down Expand Up @@ -158,8 +157,7 @@ pub(super) fn inject_m_or_n_axis(
});
let name = &node.name;
let mut patch = TypedModelPatch::new("Injecting m or n axis");
let mut wire =
node.inputs.iter().map(|i| patch.tap_model(model, *i)).collect::<TractResult<TVec<_>>>()?;
let mut wire = patch.taps(model, &node.inputs)?;
if let Some(axis) = quasi_m_or_n_axis {
if axis.inputs[input_to_fix].len() == 1 {
let new_axes =
Expand Down Expand Up @@ -228,8 +226,7 @@ fn dequant_output(
) -> TractResult<Option<TypedModelPatch>> {
let name = &node.name;
let mut patch = TypedModelPatch::new("Dequantizing einsum");
let taps: Vec<OutletId> =
node.inputs.iter().map(|i| patch.tap_model(model, *i)).collect::<TractResult<Vec<_>>>()?;
let taps = patch.taps(model, &node.inputs)?;
let [a, b, bias, mut a0, a_scale, mut b0, b_scale, c0, c_scale] = *taps else {
bail!("Expect exactly 9 inputs")
};
Expand Down
3 changes: 1 addition & 2 deletions core/src/ops/matmul/lir_unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -590,8 +590,7 @@ impl LirMatMulUnary {
let before_last = new_op.micro_ops.len() - 1..new_op.micro_ops.len() - 1;
new_op.micro_ops.splice(before_last, fused_micro_op);
new_op.update_trivial_path();
let mut inputs: TVec<OutletId> =
node.inputs.iter().map(|i| patch.tap_model(model, *i)).collect::<TractResult<_>>()?;
let mut inputs = patch.taps(model, &node.inputs)?;
inputs.extend(additional_inputs.iter().cloned());
let output = patch.wire_node(&node.name, new_op, &inputs)?;
patch.shunt_outside(model, succ.id.into(), output[0])?;
Expand Down
6 changes: 1 addition & 5 deletions core/src/optim/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,7 @@ impl super::TypedPass for PushSliceUp {
if let Some(boundaries) = should_slice_output(model, node, axis, &eval_order)? {
let mut splits = tvec!();
let mut patch = TypedModelPatch::new(format!("Slice {node} by {boundaries:?}"));
let inputs = node
.inputs
.iter()
.map(|i| patch.tap_model(model, *i))
.collect::<TractResult<TVec<OutletId>>>()?;
let inputs = patch.taps(model, &node.inputs)?;
let mut start = 0;
let axis_info = invariants.axis((InOut::Out(0), axis)).unwrap();
for end in &boundaries {
Expand Down

0 comments on commit 5e7b245

Please sign in to comment.