Skip to content

Commit

Permalink
adjust opaque axes removal
Browse files Browse the repository at this point in the history
  • Loading branch information
mathieupoumeyrolsonos committed Aug 30, 2024
1 parent a2987e1 commit 30f2424
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 11 deletions.
21 changes: 16 additions & 5 deletions core/src/axes/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,17 @@ impl AxisTracking {
let emiter_node = model.node(wire.node);
let mut nodes = vec![];
let (input_facts, output_facts) = model.node_facts(emiter_node.id)?;
let invs = emiter_node
let map = emiter_node
.op
.axes_mapping(&input_facts, &output_facts)
.with_context(|| format!("Computing axes mapping for {emiter_node}"))?;
let info = invs.axis((InOut::Out(wire.slot), axis)).unwrap();
let info = map.axis((InOut::Out(wire.slot), axis)).with_context(|| {
format!(
"Axes mapping for {} is {map}, need output axis {:?} from slot {}",
emiter_node, axis, wire.slot,
)
})?;

if info.inputs.iter().any(|i| i.len() > 0) {
nodes.push((wire.node, info.clone()));
} else {
Expand All @@ -108,8 +114,13 @@ impl AxisTracking {
for succ in &emiter_node.outputs[wire.slot].successors {
let succ_node = model.node(succ.node);
let (input_facts, output_facts) = model.node_facts(succ_node.id)?;
let invs = succ_node.op.axes_mapping(&input_facts, &output_facts)?;
let info = invs.axis((InOut::In(succ.slot), axis)).unwrap();
let map = succ_node.op.axes_mapping(&input_facts, &output_facts)?;
let info = map.axis((InOut::In(succ.slot), axis)).with_context(|| {
format!(
"Axes mapping for {succ_node} is {map}, need input axis {:?} from slot {}",
axis, succ.slot,
)
})?;
if info.outputs.iter().any(|o| o.len() > 0) {
nodes.push((succ_node.id, info.clone()));
} else {
Expand Down Expand Up @@ -177,7 +188,7 @@ pub fn for_model(model: &TypedModel) -> TractResult<AxesMapping> {
.collect::<TractResult<TVec<usize>>>()?;
let mut result = AxesMapping::disconnected_for_ranks(&input_ranks, &output_ranks)?;
for tracking in full_axis_tracking(model)? {
let mut reprs:Vec<char> = vec![];
let mut reprs: Vec<char> = vec![];
for (ix, outlet) in model.input_outlets()?.iter().enumerate() {
if let Some(appearance) = tracking.outlets.get(outlet) {
reprs.push(result.axis((InOut::In(ix), *appearance)).unwrap().repr);
Expand Down
2 changes: 1 addition & 1 deletion core/src/model/typed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ impl TypedModel {
})?
}
}
self.axes_mapping()?;
self.axes_mapping().context("Checking model axes mapping")?;
Ok(())
}

Expand Down
12 changes: 7 additions & 5 deletions core/src/ops/einsum/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,11 +273,13 @@ impl TypedOp for EinSum {
_outputs: &[&TypedFact],
) -> TractResult<AxesMapping> {
let mut axes = self.axes.clone();
for (ix, input) in inputs.iter().enumerate().take(2) {
if input.datum_type.is_opaque() {
while axes.axes(InOut::In(ix)).next().is_some() {
axes = axes.remove_axis_occurency(InOut::In(ix), 0)?;
}
for (slot, i) in inputs.iter().enumerate() {
if i.datum_type.is_opaque()
&& i.opaque_fact.as_ref().is_some_and(|of| of.is::<BlockQuantFact>())
{
axes = axes
.remove_axis_occurency(InOut::In(slot), i.rank())?
.remove_axis_occurency(InOut::In(slot), i.rank())?;
}
}
Ok(axes)
Expand Down

0 comments on commit 30f2424

Please sign in to comment.