Skip to content

Commit

Permalink
accept labels or node names npzed tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Oct 14, 2024
1 parent 490b636 commit c9b2dbb
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions cli/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use std::iter::once;

use crate::params::Parameters;
use tract_hir::internal::*;
use tract_itertools::Itertools;
use tract_libcli::model::Model;

/// Compares the outputs of a node in tract and tensorflow.
Expand All @@ -8,25 +11,25 @@ pub fn check_outputs(got: &[Vec<TValue>], params: &Parameters) -> TractResult<()
// iter over all possible tract model outputs
for (ix, output) in params.tract_model.output_outlets().iter().enumerate() {
// get either name from outlet_label or from node_name
let name = if let Some(label) = params.tract_model.outlet_label(*output) {
label
} else {
params.tract_model.node_name(output.node)
};
// pick expected tensor values for this output
let exp = params.tensors_values.by_name(name);
let lookup_names = params
.tract_model
.outlet_label(*output)
.into_iter()
.chain(once(params.tract_model.node_name(output.node)))
.collect_vec();
let exp = lookup_names.iter().find_map(|name| params.tensors_values.by_name(name));
if exp.is_none() {
if params.assertions.allow_missing_outputs {
warn!("Missing reference output in bundle for {name}");
warn!("Missing reference output in bundle for {}", lookup_names.join(" or "));
continue;
} else {
bail!("Missing reference output in bundle for {name}");
bail!("Missing reference output in bundle for {}", lookup_names.join(" or "));
}
}
let exp = exp.unwrap();
debug!("Output {}, expects {:?}", ix, exp);
let mut exp: TValue = exp.values.as_ref().with_context(|| {
format!("Output {name:?}: found reference info without value: {exp:?}")
format!("Output {lookup_names:?}: found reference info without value: {exp:?}")
})?[0]
.clone();
let got: TValue = if got[ix].len() > 1 {
Expand Down

0 comments on commit c9b2dbb

Please sign in to comment.