Skip to content

Commit

Permalink
Merge pull request #48 from robertknight/model-metadata
Browse files Browse the repository at this point in the history
Add initial support for model metadata in RTen models
  • Loading branch information
robertknight authored Feb 4, 2024
2 parents 5562fd2 + b4adf51 commit 82eabb7
Show file tree
Hide file tree
Showing 9 changed files with 955 additions and 49 deletions.
103 changes: 70 additions & 33 deletions rten-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::error::Error;
use std::fs;
use std::time::Instant;

use rten::{Dimension, Input, Model, NodeId, Output, RunOptions};
use rten::{Dimension, Input, Model, ModelMetadata, NodeId, Output, RunOptions};
use rten_tensor::prelude::*;
use rten_tensor::Tensor;

Expand Down Expand Up @@ -37,6 +37,11 @@ fn parse_args() -> Result<Args, lexopt::Error> {
Usage: {bin_name} [OPTIONS] <model>
Args:
<model>
Path to '.rten' model to inspect and run.
Options:
-t, --timing Output timing info
-v, --verbose Enable verbose logging
-h, --help Print help
Expand Down Expand Up @@ -66,28 +71,27 @@ fn format_param_count(n: usize) -> String {
}
}

/// Tool for inspecting converted ONNX models and running them with randomly
/// generated inputs.
///
/// ```
/// tools/convert-onnx.py model.onnx output.rten
/// cargo run -p rten-cli --release output.rten
/// ```
///
/// To get detailed timing information set the `RTEN_TIMING` env var before
/// running. See `docs/profiling.md`.
fn main() -> Result<(), Box<dyn Error>> {
let args = parse_args()?;
let model_bytes = fs::read(args.model)?;
let model = Model::load(&model_bytes)?;
fn print_metadata(metadata: &ModelMetadata) {
fn print_field<T: std::fmt::Display>(name: &str, value: Option<T>) {
if let Some(value) = value {
println!(" {}: {}", name, value);
}
}

println!(
"Model stats: {} inputs, {} outputs, {} params",
model.input_ids().len(),
model.output_ids().len(),
format_param_count(model.total_params()),
);
println!("Metadata:");
print_field("ONNX hash", metadata.onnx_hash());
print_field("Description", metadata.description());
print_field("License", metadata.license());
print_field("Commit", metadata.commit());
print_field("Repository", metadata.code_repository());
print_field("Model repository", metadata.model_repository());
print_field("Run ID", metadata.run_id());
print_field("Run URL", metadata.run_url());
}

/// Generate random inputs for `model` using shape metadata and heuristics,
/// run it, and print details of the output.
fn run_with_random_input(model: &Model, run_opts: RunOptions) -> Result<(), Box<dyn Error>> {
let mut rng = fastrand::Rng::new();

// Generate random ints that are likely to be valid token IDs in a language
Expand Down Expand Up @@ -165,27 +169,21 @@ fn main() -> Result<(), Box<dyn Error>> {
.as_ref()
.and_then(|ni| ni.name())
.unwrap_or("(unnamed)");
println!("Input \"{name}\" resolved shape {:?}", input.shape());
println!(" Input \"{name}\" generated shape {:?}", input.shape());
}

// Run model and summarize outputs.
let start = Instant::now();
let outputs = model.run(
&inputs,
model.output_ids(),
Some(RunOptions {
timing: args.timing,
verbose: args.verbose,
..Default::default()
}),
)?;
let outputs = model.run(&inputs, model.output_ids(), Some(run_opts))?;
let elapsed = start.elapsed().as_millis();

println!();
println!(
"Model returned {} outputs in {:.2}ms",
" Model returned {} outputs in {:.2}ms.",
outputs.len(),
elapsed
);
println!();

let output_names: Vec<String> = model
.output_ids()
Expand All @@ -204,11 +202,50 @@ fn main() -> Result<(), Box<dyn Error>> {
Output::IntTensor(_) => "i32",
};
println!(
"Output {i} \"{name}\" data type {} shape: {:?}",
" Output {i} \"{name}\" data type {} shape: {:?}",
dtype,
output.shape()
);
}

Ok(())
}

/// Tool for inspecting converted ONNX models and running them with randomly
/// generated inputs.
///
/// ```
/// tools/convert-onnx.py model.onnx output.rten
/// cargo run -p rten-cli --release output.rten
/// ```
///
/// To get detailed timing information set the `RTEN_TIMING` env var before
/// running. See `docs/profiling.md`.
fn main() -> Result<(), Box<dyn Error>> {
let args = parse_args()?;
let model_bytes = fs::read(args.model)?;
let model = Model::load(&model_bytes)?;

println!(
"Model summary: {} inputs, {} outputs, {} params",
model.input_ids().len(),
model.output_ids().len(),
format_param_count(model.total_params()),
);
println!();

print_metadata(model.metadata());

println!();
println!("Running model with random inputs...");
run_with_random_input(
&model,
RunOptions {
timing: args.timing,
verbose: args.verbose,
..Default::default()
},
)?;

Ok(())
}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ mod gemm;
mod graph;
mod iter_util;
mod model;
mod model_metadata;
mod number;
mod slice_reductions;
mod timer;
Expand All @@ -44,6 +45,7 @@ pub mod ops;

pub use graph::{Dimension, NodeId, RunOptions};
pub use model::{DefaultOperatorFactory, Model, ModelLoadError, NodeInfo, OpRegistry};
pub use model_metadata::ModelMetadata;
pub use ops::{FloatOperators, Input, Operators, Output};
pub use timer::Timer;
pub use timing::TimingSort;
Expand Down
27 changes: 26 additions & 1 deletion src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use rten_tensor::Tensor;
use smallvec::smallvec;

use crate::graph::{Dimension, Graph, Node, NodeId, RunError, RunOptions};
use crate::model_metadata::ModelMetadata;
use crate::ops;
use crate::ops::{
BoxOrder, CoordTransformMode, DataType, Direction, Input, NearestMode, Operator, Output,
Expand Down Expand Up @@ -52,6 +53,7 @@ pub struct Model {
input_ids: Vec<NodeId>,
output_ids: Vec<NodeId>,
graph: Graph,
metadata: ModelMetadata,
}

/// Provides access to metadata about a graph node.
Expand Down Expand Up @@ -144,6 +146,11 @@ impl Model {
self.graph.get_node(id).map(|node| NodeInfo { node })
}

/// Return metadata about the model.
pub fn metadata(&self) -> &ModelMetadata {
&self.metadata
}

/// Return the IDs of input nodes.
pub fn input_ids(&self) -> &[NodeId] {
&self.input_ids
Expand Down Expand Up @@ -1108,11 +1115,17 @@ fn load_model(data: &[u8], registry: &OpRegistry) -> Result<Model, ModelLoadErro
}
}

let metadata = model
.metadata()
.map(ModelMetadata::deserialize)
.unwrap_or_default();

let model = Model {
node_ids: node_id_from_name,
input_ids,
output_ids,
graph,
metadata,
};
Ok(model)
}
Expand All @@ -1126,7 +1139,7 @@ mod tests {

use crate::graph::{Dimension, RunError};
use crate::model::Model;
use crate::model_builder::{ModelBuilder, OpType};
use crate::model_builder::{MetadataArgs, ModelBuilder, OpType};
use crate::ops;
use crate::ops::{BoxOrder, CoordTransformMode, NearestMode, OpError, ResizeMode, Scalar};

Expand Down Expand Up @@ -1157,6 +1170,10 @@ mod tests {
);
builder.add_operator("relu", OpType::Relu, &[Some(concat_out)], &[output_node]);

builder.add_metadata(MetadataArgs {
onnx_hash: Some("abc".to_string()),
});

builder.finish()
}

Expand Down Expand Up @@ -1198,6 +1215,14 @@ mod tests {
assert_eq!(shape, &[1, 2, 2].map(Dimension::Fixed));
}

#[test]
fn test_metadata() {
let buffer = generate_model_buffer();
let model = Model::load(&buffer).unwrap();
assert_eq!(model.metadata().onnx_hash(), Some("abc"));
assert_eq!(model.metadata().description(), None);
}

#[test]
fn test_input_shape() {
let buffer = generate_model_buffer();
Expand Down
21 changes: 21 additions & 0 deletions src/model_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ pub struct ModelBuilder<'a> {
nodes: Vec<WIPOffset<sg::Node<'a>>>,
input_ids: Vec<u32>,
output_ids: Vec<u32>,
metadata: Option<WIPOffset<sg::Metadata<'a>>>,
}

enum NodeData<'a> {
Expand All @@ -126,6 +127,11 @@ enum NodeData<'a> {
Operator(WIPOffset<sg::OperatorNode<'a>>),
}

/// Arguments for [ModelBuilder::add_metadata].
pub struct MetadataArgs {
pub onnx_hash: Option<String>,
}

struct PadArgs {
pad_mode: sg::PadMode,
pads: Option<Vec<usize>>,
Expand All @@ -152,6 +158,7 @@ impl<'a> ModelBuilder<'a> {
nodes: Vec::new(),
input_ids: Vec::new(),
output_ids: Vec::new(),
metadata: None,
}
}

Expand Down Expand Up @@ -683,6 +690,19 @@ impl<'a> ModelBuilder<'a> {
self.output_ids.push(node_id);
}

/// Add model metadata
pub fn add_metadata(&mut self, metadata: MetadataArgs) {
let hash = metadata
.onnx_hash
.as_ref()
.map(|hash| self.builder.create_string(hash));
let mut meta_builder = sg::MetadataBuilder::new(&mut self.builder);
if let Some(hash) = hash {
meta_builder.add_onnx_hash(hash);
}
self.metadata = Some(meta_builder.finish());
}

/// Finish writing the model data to the buffer and return the buffer's contents.
pub fn finish(mut self) -> Vec<u8> {
let inputs_vec = self.builder.create_vector(&self.input_ids[..]);
Expand All @@ -703,6 +723,7 @@ impl<'a> ModelBuilder<'a> {
&sg::ModelArgs {
schema_version: 1,
graph: Some(graph),
metadata: self.metadata,
},
);

Expand Down
Loading

0 comments on commit 82eabb7

Please sign in to comment.