diff --git a/rten-generate/src/generator.rs b/rten-generate/src/generator.rs index c4205c89..d15dd2e9 100644 --- a/rten-generate/src/generator.rs +++ b/rten-generate/src/generator.rs @@ -764,8 +764,12 @@ mod tests { /// Return a model with a given set of inputs and outputs. fn with_inputs_and_outputs(inputs: &[NodeInfo], outputs: &[NodeInfo]) -> FakeModel { let node_infos = [inputs, outputs].concat(); - let input_ids = (0..inputs.len()).collect(); - let output_ids = (inputs.len()..(inputs.len() + outputs.len())).collect(); + let input_ids = (0..inputs.len()) + .map(|id| NodeId::from_u32(id as u32)) + .collect(); + let output_ids = (inputs.len()..(inputs.len() + outputs.len())) + .map(|id| NodeId::from_u32(id as u32)) + .collect(); FakeModel { input_ids, @@ -796,11 +800,14 @@ mod tests { impl Model for FakeModel { fn find_node(&self, name: &str) -> Option { - self.nodes.iter().position(|info| info.name() == name) + self.nodes + .iter() + .position(|info| info.name() == name) + .map(|pos| NodeId::from_u32(pos as u32)) } fn node_info(&self, id: NodeId) -> Option { - self.nodes.get(id).cloned() + self.nodes.get(id.as_u32() as usize).cloned() } fn input_ids(&self) -> &[NodeId] { diff --git a/src/graph.rs b/src/graph.rs index 82143ac9..d766c6a0 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use std::error::Error; use std::fmt; +use std::num::NonZero; use std::sync::{Arc, Mutex}; use std::time::Duration; @@ -281,8 +282,45 @@ impl Node { } } -/// ID of a node in a [Model](crate::Model) graph. -pub type NodeId = usize; +/// ID of a node in a [`Model`](crate::Model) graph. +/// +/// This is used to identify input and output values as well as internal nodes. +/// +/// Node IDs are u32 values <= `i32::MAX`. +#[derive(Copy, Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub struct NodeId(NonZero); + +impl NodeId { + /// Return the underlying u32 value of the ID. + pub fn as_u32(self) -> u32 { + self.0.get() - 1 + } + + /// Construct a node ID from a u32 value. + /// + /// Panics if the value exceeds `i32::MAX`. + pub fn from_u32(value: u32) -> NodeId { + // Node IDs are limited to `i32::MAX` because the `OperatorNode` type + // in the FlatBuffers schema represents operator input and output IDs + // as `i32`. Negative values are used as a niche to represent missing + // optional inputs. + assert!(value <= i32::MAX as u32); + + // Valid node IDs are in the range `[0, i32::MAX]`, so we store them as + // values in `[1, i32::MAX + 1]` internally and reserve 0 as a niche to + // make `Option` the same size as `NodeId`. + NodeId(unsafe { + // Safety: `value + 1` cannot be zero + NonZero::new_unchecked(value + 1) + }) + } +} + +impl std::fmt::Display for NodeId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.as_u32().fmt(f) + } +} /// Reasons why a graph execution failed #[derive(Eq, PartialEq, Debug)] @@ -368,14 +406,14 @@ impl NodeRefCount { /// Increment ref count of node. If the refcount reaches `u8::MAX` it /// will become "sticky" and never decrement. fn inc(&mut self, id: NodeId) { - let rc = &mut self.rc[id]; + let rc = &mut self.rc[id.as_u32() as usize]; *rc = rc.saturating_add(1); } /// Decrement ref count of node and return new count, or `None` if the /// ref count was already zero. fn dec(&mut self, id: NodeId) -> Option { - let rc = &mut self.rc[id]; + let rc = &mut self.rc[id.as_u32() as usize]; // If the refcount reaches the max value, it becomes sticky. if *rc == u8::MAX { @@ -389,7 +427,7 @@ impl NodeRefCount { } fn count(&self, id: NodeId) -> usize { - self.rc[id] as usize + self.rc[id.as_u32() as usize] as usize } } @@ -674,10 +712,10 @@ impl Graph { } pub fn add_node(&mut self, node: Node) -> NodeId { + let node_id = NodeId::from_u32(self.nodes.len() as u32); self.nodes.push(node); - let node_id = self.nodes.len() - 1; - if let Some(name) = self.nodes[node_id].name() { + if let Some(name) = self.nodes[node_id.as_u32() as usize].name() { self.node_id_from_name.insert(name.to_string(), node_id); } @@ -775,7 +813,10 @@ impl Graph { /// Return an iterator over nodes in the graph. pub fn iter(&self) -> impl Iterator { - self.nodes.iter().enumerate() + self.nodes + .iter() + .enumerate() + .map(|(i, node)| (NodeId::from_u32(i as u32), node)) } /// Return the debug name for a node. @@ -788,7 +829,7 @@ impl Graph { /// Retrieve a node by ID pub fn get_node(&self, id: NodeId) -> Option<&Node> { - self.nodes.get(id) + self.nodes.get(id.as_u32() as usize) } /// Look up a node ID given its unique name @@ -808,7 +849,7 @@ impl Graph { /// Retrieve a node by ID pub fn get_node_mut(&mut self, id: NodeId) -> Option<&mut Node> { - self.nodes.get_mut(id) + self.nodes.get_mut(id.as_u32() as usize) } /// Return the total number of parameters in all constant nodes in this @@ -928,7 +969,7 @@ impl Graph { let inputs_by_id: FxHashMap = inputs.iter().cloned().collect(); let get_value_from_constant_or_input = |node_id: NodeId| -> Option { - match self.nodes.get(node_id) { + match self.nodes.get(node_id.as_u32() as usize) { Some(Node::Constant(constant)) => Some(constant.as_input()), Some(Node::Value(_)) => inputs_by_id.get(&node_id).map(|input| input.as_input()), _ => { @@ -938,7 +979,10 @@ impl Graph { }; let get_value_from_capture = |node_id: NodeId| -> Option { - let name = self.nodes.get(node_id).and_then(|n| n.name())?; + let name = self + .nodes + .get(node_id.as_u32() as usize) + .and_then(|n| n.name())?; captures.as_ref().and_then(|cap| cap.get_input(name)) }; @@ -946,13 +990,13 @@ impl Graph { // when no longer needed. let mut temp_value_refcount = NodeRefCount::with_capacity(self.nodes.len()); for &op_node_id in plan.iter() { - let Some(Node::Operator(op_node)) = self.nodes.get(op_node_id) else { + let Some(Node::Operator(op_node)) = self.nodes.get(op_node_id.as_u32() as usize) else { return Err(RunError::PlanningError( "operator node not found".to_string(), )); }; for node_id in self.operator_dependencies(op_node) { - if let Some(Node::Value(_)) = self.nodes.get(node_id) { + if let Some(Node::Value(_)) = self.nodes.get(node_id.as_u32() as usize) { temp_value_refcount.inc(node_id); } } @@ -984,7 +1028,7 @@ impl Graph { let mut op_start = Instant::now(); for (step, &op_node_id) in plan.iter().enumerate() { - let Some(Node::Operator(op_node)) = self.nodes.get(op_node_id) else { + let Some(Node::Operator(op_node)) = self.nodes.get(op_node_id.as_u32() as usize) else { return Err(RunError::PlanningError( "operator node not found".to_string(), )); @@ -1305,7 +1349,7 @@ impl Graph { // Walk forwards through the plan and prune away steps that cannot be // computed due to missing inputs. for &node_id in plan { - let Some(Node::Operator(op_node)) = self.nodes.get(node_id) else { + let Some(Node::Operator(op_node)) = self.nodes.get(node_id.as_u32() as usize) else { continue; }; @@ -1354,12 +1398,11 @@ impl Graph { inputs: I, include_captures: bool, ) -> FxHashSet { - let mut resolved: FxHashSet = - inputs - .chain(self.nodes.iter().enumerate().filter_map(|(node_id, node)| { - matches!(node, Node::Constant(_)).then_some(node_id) - })) - .collect(); + let mut resolved: FxHashSet = inputs + .chain(self.nodes.iter().enumerate().filter_map(|(node_id, node)| { + matches!(node, Node::Constant(_)).then_some(NodeId::from_u32(node_id as u32)) + })) + .collect(); if include_captures { resolved.extend(self.captures().iter().copied()); @@ -1514,7 +1557,7 @@ mod tests { use smallvec::{smallvec, SmallVec}; use super::{CachedPlan, CaptureEnv}; - use crate::graph::{Dimension, Graph, Node, RunError, RunOptions, TypedConstant}; + use crate::graph::{Dimension, Graph, Node, NodeId, RunError, RunOptions, TypedConstant}; use crate::ops::{ Add, Concat, Conv, Identity, If, InputList, IntoOpResult, Mul, OpError, Operator, Output, OutputList, Relu, Shape, @@ -1943,7 +1986,7 @@ mod tests { #[test] fn test_err_if_invalid_output() { let g = Graph::new(); - let result = g.run(vec![], &[123], None); + let result = g.run(vec![], &[NodeId::from_u32(123)], None); assert_eq!( result.err(), Some(RunError::PlanningError("Missing output 123".to_string())) @@ -1953,7 +1996,7 @@ mod tests { #[test] fn test_err_if_missing_operator_input() { let mut g = Graph::new(); - let (_, output) = g.add_simple_op("op", Relu {}, &[42]); + let (_, output) = g.add_simple_op("op", Relu {}, &[NodeId::from_u32(42)]); let result = g.run(vec![], &[output], None); assert_eq!( result.err(), @@ -2268,21 +2311,27 @@ mod tests { #[test] fn test_cached_plan_matches() { - let input_ids = &[3, 1, 2]; - let output_ids = &[6, 4, 5]; - let op_ids = &[10, 11, 12]; + let input_ids = &[3, 1, 2].map(NodeId::from_u32); + let output_ids = &[6, 4, 5].map(NodeId::from_u32); + let op_ids = &[10, 11, 12].map(NodeId::from_u32); let plan = CachedPlan::new(input_ids, output_ids, op_ids.to_vec()); assert!(plan.matches(input_ids, output_ids)); // Same input and output IDs, different orders. - assert!(plan.matches(&[1, 2, 3], &[4, 5, 6])); - assert!(plan.matches(&[3, 2, 1], &[6, 5, 4])); + assert!(plan.matches( + &[1, 2, 3].map(NodeId::from_u32), + &[4, 5, 6].map(NodeId::from_u32) + )); + assert!(plan.matches( + &[3, 2, 1].map(NodeId::from_u32), + &[6, 5, 4].map(NodeId::from_u32) + )); // Different input and output IDs - assert!(!plan.matches(&[20, 21, 22], output_ids)); - assert!(!plan.matches(input_ids, &[20, 21, 22])); + assert!(!plan.matches(&[20, 21, 22].map(NodeId::from_u32), output_ids)); + assert!(!plan.matches(input_ids, &[20, 21, 22].map(NodeId::from_u32))); } /// A trivial control flow operator which just forwards inputs to a subgraph diff --git a/src/model.rs b/src/model.rs index 18dbed2c..a8492df4 100644 --- a/src/model.rs +++ b/src/model.rs @@ -356,12 +356,12 @@ impl Model { let input_ids: Vec = serialized_graph .inputs() - .map(|ids| ids.iter().map(|id| id as NodeId).collect()) + .map(|ids| ids.iter().map(NodeId::from_u32).collect()) .unwrap_or_default(); let output_ids: Vec = serialized_graph .outputs() - .map(|ids| ids.iter().map(|id| id as NodeId).collect()) + .map(|ids| ids.iter().map(NodeId::from_u32).collect()) .unwrap_or_default(); let mut graph = Graph::with_capacity(node_count); @@ -369,7 +369,7 @@ impl Model { graph.set_output_ids(&output_ids); if let Some(captures) = serialized_graph.captures() { - let captures: Vec = captures.iter().map(|id| id as NodeId).collect(); + let captures: Vec = captures.iter().map(NodeId::from_u32).collect(); graph.set_captures(&captures); } @@ -839,7 +839,7 @@ mod tests { use rten_tensor::prelude::*; use rten_tensor::Tensor; - use crate::graph::{Dimension, RunError}; + use crate::graph::{Dimension, NodeId, RunError}; use crate::model::{Model, ModelOptions}; use crate::model_builder::{ GraphBuilder, IfArgs, MetadataArgs, ModelBuilder, ModelFormat, OpType, @@ -1147,7 +1147,7 @@ mod tests { .load(buffer) .unwrap(); - let result = model.run(vec![], &[output_node as usize], None); + let result = model.run(vec![], &[output_node], None); assert_eq!( result.err(), @@ -1181,7 +1181,7 @@ mod tests { let mut op_outputs = Vec::new(); let mut add_operator = - |builder: &mut GraphBuilder, name: &str, op: OpType, input_nodes: &[Option]| { + |builder: &mut GraphBuilder, name: &str, op: OpType, input_nodes: &[Option]| { let output_name = format!("{}_out", name); let op_output_node = builder.add_value(&output_name, None); builder.add_operator(name, op, input_nodes, &[op_output_node]); @@ -1605,8 +1605,8 @@ mod tests { let result = model .run( vec![ - (input_node as usize, input.view().into()), - (input_bool as usize, input_bool_data.view().into()), + (input_node, input.view().into()), + (input_bool, input_bool_data.view().into()), ], &[output_id], None, @@ -1629,11 +1629,7 @@ mod tests { for output in outputs { let output_id = model.find_node(output).unwrap(); let result = model - .run( - vec![(input_2d as usize, input.view().into())], - &[output_id], - None, - ) + .run(vec![(input_2d, input.view().into())], &[output_id], None) .unwrap(); assert_eq!(result.len(), 1); } @@ -1645,11 +1641,11 @@ mod tests { let result = model .run( vec![ - (range_start_node as usize, start.into()), - (range_limit_node as usize, limit.into()), - (range_delta_node as usize, delta.into()), + (range_start_node, start.into()), + (range_limit_node, limit.into()), + (range_delta_node, delta.into()), ], - &[range_out as usize], + &[range_out], None, ) .unwrap(); @@ -1662,11 +1658,11 @@ mod tests { let result = model .run( vec![ - (where_cond as usize, cond.into()), - (where_x as usize, x.into()), - (where_y as usize, y.into()), + (where_cond, cond.into()), + (where_x, x.into()), + (where_y, y.into()), ], - &[where_out as usize], + &[where_out], None, ) .unwrap(); diff --git a/src/model_builder.rs b/src/model_builder.rs index 0b45d5a6..fd0d98f0 100644 --- a/src/model_builder.rs +++ b/src/model_builder.rs @@ -2,7 +2,7 @@ use flatbuffers::{FlatBufferBuilder, UnionWIPOffset, Vector, WIPOffset}; use rten_tensor::prelude::*; use rten_tensor::TensorView; -use crate::graph::Dimension; +use crate::graph::{Dimension, NodeId}; use crate::header::Header; use crate::number::LeBytes; use crate::ops::{ @@ -229,8 +229,8 @@ pub struct GraphBuilder<'mb, 'a> { tensor_data_builder: Option<&'mb mut TensorDataBuilder>, nodes: Vec>>, - input_ids: Vec, - output_ids: Vec, + input_ids: Vec, + output_ids: Vec, } impl<'mb, 'a> GraphBuilder<'mb, 'a> { @@ -248,7 +248,7 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> { } } - fn add_node(&mut self, name: Option<&str>, data: NodeData) -> u32 { + fn add_node(&mut self, name: Option<&str>, data: NodeData) -> NodeId { let (data_type, union_val) = match data { NodeData::Constant(offset) => (sg::NodeKind::ConstantNode, offset.as_union_value()), NodeData::Value(offset) => (sg::NodeKind::ValueNode, offset.as_union_value()), @@ -261,7 +261,7 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> { }; let node = sg::Node::create(self.builder, &args); self.nodes.push(node); - (self.nodes.len() - 1) as u32 + NodeId::from_u32((self.nodes.len() - 1) as u32) } /// Return a graph builder for a subgraph. @@ -280,7 +280,7 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> { pub fn add_constant( &mut self, input: TensorView, - ) -> u32 { + ) -> NodeId { let shape: Vec = input.shape().iter().map(|&x| x as u32).collect(); let shape_vec = self.builder.create_vector(&shape[..]); let dtype = ::dtype(); @@ -314,7 +314,7 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> { } /// Add a value node to the model - pub fn add_value(&mut self, id: &str, shape: Option<&[Dimension]>) -> u32 { + pub fn add_value(&mut self, id: &str, shape: Option<&[Dimension]>) -> NodeId { let shape = shape.map(|shape| { let dim_vec: Vec<_> = shape .iter() @@ -349,9 +349,9 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> { &mut self, id: &str, op_info: OpType, - inputs: &[Option], - outputs: &[u32], - ) -> u32 { + inputs: &[Option], + outputs: &[NodeId], + ) -> NodeId { // Generate an (op_type, attr_type, attrs) tuple for an operator with // no attributes. macro_rules! op { @@ -832,11 +832,11 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> { let input_ids: Vec = inputs .iter() .map(|&id| match id { - Some(id) => id as i32, + Some(id) => id.as_u32() as i32, None => -1, }) .collect(); - let output_ids: Vec = outputs.iter().map(|&id| id as i32).collect(); + let output_ids: Vec = outputs.iter().map(|&id| id.as_u32() as i32).collect(); let input_vec = self.builder.create_vector(&input_ids); let output_vec = self.builder.create_vector(&output_ids); @@ -854,12 +854,12 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> { } /// Mark a node in the graph as an input. - pub fn add_input(&mut self, node_id: u32) { + pub fn add_input(&mut self, node_id: NodeId) { self.input_ids.push(node_id); } /// Mark a node in the graph as an output. - pub fn add_output(&mut self, node_id: u32) { + pub fn add_output(&mut self, node_id: NodeId) { self.output_ids.push(node_id); } @@ -877,8 +877,11 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> { /// Finish writing this graph to the FlatBuffers buffer. pub fn finish(self) -> WIPOffset> { - let inputs_vec = self.builder.create_vector(&self.input_ids[..]); - let outputs_vec = self.builder.create_vector(&self.output_ids[..]); + let input_ids: Vec<_> = self.input_ids.iter().map(|id| id.as_u32()).collect(); + let output_ids: Vec<_> = self.output_ids.iter().map(|id| id.as_u32()).collect(); + + let inputs_vec = self.builder.create_vector(&input_ids); + let outputs_vec = self.builder.create_vector(&output_ids); let nodes_vec = self.builder.create_vector(&self.nodes[..]); sg::Graph::create( diff --git a/src/optimize.rs b/src/optimize.rs index b0d8dbf4..1776b54b 100644 --- a/src/optimize.rs +++ b/src/optimize.rs @@ -611,7 +611,7 @@ mod tests { use super::{GraphOptimizer, OptimizeError}; use crate::constant_storage::{ArcSlice, ArcTensorView, ConstantStorage}; use crate::downcast::DowncastDyn; - use crate::graph::{CaptureEnv, Constant, Graph, Node}; + use crate::graph::{CaptureEnv, Constant, Graph, Node, NodeId}; use crate::ops::{ Add, Div, Erf, LayerNormalization, MatMul, Mul, Pow, ReduceMean, Sigmoid, Sqrt, Sub, Transpose, @@ -871,7 +871,7 @@ mod tests { fn test_optimize_error() { let mut graph = Graph::new(); let optimizer = GraphOptimizer::new(); - let invalid_id = 123; + let invalid_id = NodeId::from_u32(123); graph.set_input_ids(&[invalid_id]); graph.set_output_ids(&[invalid_id]); let result = optimizer.optimize(graph, None); diff --git a/src/wasm_api.rs b/src/wasm_api.rs index cad6d416..edbceafe 100644 --- a/src/wasm_api.rs +++ b/src/wasm_api.rs @@ -5,7 +5,7 @@ use rten_tensor::prelude::*; use rten_tensor::rng::XorShiftRng; use wasm_bindgen::prelude::*; -use crate::graph::Dimension; +use crate::graph::{Dimension, NodeId}; use crate::model; use crate::ops::{matmul, InputOrOutput, Output}; use crate::tensor_pool::TensorPool; @@ -26,35 +26,45 @@ impl Model { /// Find the ID of a node in the graph from its name. #[wasm_bindgen(js_name = findNode)] - pub fn find_node(&self, name: &str) -> Option { - self.model.find_node(name) + pub fn find_node(&self, name: &str) -> Option { + self.model.find_node(name).map(|id| id.as_u32()) } /// Get metadata about the node with a given ID. /// /// This is useful for getting the input tensor shape expected by the model. #[wasm_bindgen(js_name = nodeInfo)] - pub fn node_info(&self, id: usize) -> Option { - self.model.node_info(id).map(|ni| NodeInfo { - name: ni.name().map(|n| n.to_string()), - shape: ni.shape(), - }) + pub fn node_info(&self, id: u32) -> Option { + self.model + .node_info(NodeId::from_u32(id)) + .map(|ni| NodeInfo { + name: ni.name().map(|n| n.to_string()), + shape: ni.shape(), + }) } /// Return the IDs of input nodes. /// /// Additional details about the nodes can be obtained using `node_info`. #[wasm_bindgen(js_name = inputIds)] - pub fn input_ids(&self) -> Vec { - self.model.input_ids().into() + pub fn input_ids(&self) -> Vec { + self.model + .input_ids() + .iter() + .map(|id| id.as_u32()) + .collect() } /// Return the IDs of output nodes. /// /// Additional details about the nodes can be obtained using `node_info`. #[wasm_bindgen(js_name = outputIds)] - pub fn output_ids(&self) -> Vec { - self.model.output_ids().into() + pub fn output_ids(&self) -> Vec { + self.model + .output_ids() + .iter() + .map(|id| id.as_u32()) + .collect() } /// Execute the model, passing `input` as the tensor values for the node @@ -62,16 +72,18 @@ impl Model { /// specified by `output_ids`. pub fn run( &self, - input_ids: &[usize], + input_ids: &[u32], input: Vec, - output_ids: &[usize], + output_ids: &[u32], ) -> Result, String> { - let inputs: Vec<(usize, InputOrOutput)> = input_ids + let inputs: Vec<(NodeId, InputOrOutput)> = input_ids .iter() .copied() + .map(NodeId::from_u32) .zip(input.iter().map(|tensor| tensor.data.as_input().into())) .collect(); - let result = self.model.run(inputs, output_ids, None); + let output_ids: Vec = output_ids.iter().copied().map(NodeId::from_u32).collect(); + let result = self.model.run(inputs, &output_ids, None); match result { Ok(outputs) => { let mut list = Vec::new();