Skip to content

Commit

Permalink
Add names and shapes of inputs and outputs to rten CLI output
Browse files Browse the repository at this point in the history
This allows quickly inspecting the inputs and outputs of an RTen model. Example
output for a BERT-like model:

```
Model summary: 2 inputs, 2 outputs, 81.5 M params

Inputs
  input_ids: [batch_size, sequence_length]
  attention_mask: [batch_size, sequence_length]

Outputs
  start_logits: [batch_size, sequence_length]
  end_logits: [batch_size, sequence_length]
```
  • Loading branch information
robertknight committed Feb 4, 2024
1 parent 873c85b commit 5e15d62
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions rten-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,36 @@ fn run_with_random_input(model: &Model, run_opts: RunOptions) -> Result<(), Box<
Ok(())
}

/// Format an input or output shape as a `[dim0, dim1, ...]` string, where each
/// dimension is represented by its fixed size or symbolic name.
fn format_shape(shape: &[Dimension]) -> String {
let dims = shape
.iter()
.map(|dim| match dim {
Dimension::Fixed(value) => value.to_string(),
Dimension::Symbolic(name) => name.clone(),
})
.collect::<Vec<_>>()
.join(", ");
format!("[{}]", dims)
}

/// Print a summary of the names and shapes of a list of input or output node IDs.
fn print_input_output_list(model: &Model, node_ids: &[NodeId]) {
for &node_id in node_ids {
let Some(info) = model.node_info(node_id) else {
continue;
};
println!(
" {}: {}",
info.name().unwrap_or("(unnamed)"),
info.shape()
.map(|dims| format_shape(&dims))
.unwrap_or("(unknown shape)".to_string())
);
}
}

/// Tool for inspecting converted ONNX models and running them with randomly
/// generated inputs.
///
Expand All @@ -234,6 +264,14 @@ fn main() -> Result<(), Box<dyn Error>> {
);
println!();

println!("Inputs");
print_input_output_list(&model, model.input_ids());
println!();

println!("Outputs");
print_input_output_list(&model, model.output_ids());
println!();

print_metadata(model.metadata());

println!();
Expand Down

0 comments on commit 5e15d62

Please sign in to comment.