Skip to content

Commit

Permalink
more fixes around einsum and opaue facts
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Aug 30, 2024
1 parent 30f2424 commit 16f3383
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
2 changes: 2 additions & 0 deletions core/src/model/typed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,14 @@ impl TypedModel {
"Inconsistent model, output types mismatch. Op says: {:?}, node says: {:?}. {} with inputs {:?}. {}",
output_facts, node.outputs.iter().map(|o| &o.fact).collect::<Vec<_>>(), node, input_facts, node)
}
/* this is not true for regularly packed values
if let Some(k) = node.op_as::<Const>() {
ensure!(
!k.0.datum_type().is_opaque() || k.1.is_some(),
"Node {node} is missing an opaque fact"
);
}
*/
}
for node in &self.nodes {
for (ix, output) in node.outputs.iter().enumerate() {
Expand Down
7 changes: 3 additions & 4 deletions core/src/ops/einsum/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,12 @@ pub fn eval_t<Acc: Datum + Zero + One>(
.iter()
.map(|t| {
if t.datum_type() == Opaque::datum_type() {
dbg!(t);
todo!();
bail!("Unoptimized einsum execution with BlockQuantized input is not implemented.");
} else {
t.shape()
Ok(t.shape())
}
})
.collect();
.collect::<TractResult<_>>()?;
let output_shape = output_shape(expr, &shapes);
let inputs: TVec<Cow<Tensor>> =
inputs.iter().map(|t| t.cast_to::<Acc>()).collect::<TractResult<_>>()?;
Expand Down

0 comments on commit 16f3383

Please sign in to comment.