Skip to content

Commit

Permalink
Merge pull request #381 from robertknight/u32-node-id
Browse files Browse the repository at this point in the history
Convert `NodeId` from an alias for `usize` to u32-sized opaque type
  • Loading branch information
robertknight authored Oct 13, 2024
2 parents 2ee9990 + cecd0ed commit d7a3490
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 91 deletions.
15 changes: 11 additions & 4 deletions rten-generate/src/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -764,8 +764,12 @@ mod tests {
/// Return a model with a given set of inputs and outputs.
fn with_inputs_and_outputs(inputs: &[NodeInfo], outputs: &[NodeInfo]) -> FakeModel {
let node_infos = [inputs, outputs].concat();
let input_ids = (0..inputs.len()).collect();
let output_ids = (inputs.len()..(inputs.len() + outputs.len())).collect();
let input_ids = (0..inputs.len())
.map(|id| NodeId::from_u32(id as u32))
.collect();
let output_ids = (inputs.len()..(inputs.len() + outputs.len()))
.map(|id| NodeId::from_u32(id as u32))
.collect();

FakeModel {
input_ids,
Expand Down Expand Up @@ -796,11 +800,14 @@ mod tests {

impl Model for FakeModel {
fn find_node(&self, name: &str) -> Option<NodeId> {
self.nodes.iter().position(|info| info.name() == name)
self.nodes
.iter()
.position(|info| info.name() == name)
.map(|pos| NodeId::from_u32(pos as u32))
}

fn node_info(&self, id: NodeId) -> Option<NodeInfo> {
self.nodes.get(id).cloned()
self.nodes.get(id.as_u32() as usize).cloned()
}

fn input_ids(&self) -> &[NodeId] {
Expand Down
113 changes: 81 additions & 32 deletions src/graph.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::collections::HashMap;
use std::error::Error;
use std::fmt;
use std::num::NonZero;
use std::sync::{Arc, Mutex};
use std::time::Duration;

Expand Down Expand Up @@ -281,8 +282,45 @@ impl Node {
}
}

/// ID of a node in a [Model](crate::Model) graph.
pub type NodeId = usize;
/// ID of a node in a [`Model`](crate::Model) graph.
///
/// This is used to identify input and output values as well as internal nodes.
///
/// Node IDs are u32 values <= `i32::MAX`.
#[derive(Copy, Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct NodeId(NonZero<u32>);

impl NodeId {
/// Return the underlying u32 value of the ID.
pub fn as_u32(self) -> u32 {
self.0.get() - 1
}

/// Construct a node ID from a u32 value.
///
/// Panics if the value exceeds `i32::MAX`.
pub fn from_u32(value: u32) -> NodeId {
// Node IDs are limited to `i32::MAX` because the `OperatorNode` type
// in the FlatBuffers schema represents operator input and output IDs
// as `i32`. Negative values are used as a niche to represent missing
// optional inputs.
assert!(value <= i32::MAX as u32);

// Valid node IDs are in the range `[0, i32::MAX]`, so we store them as
// values in `[1, i32::MAX + 1]` internally and reserve 0 as a niche to
// make `Option<NodeId>` the same size as `NodeId`.
NodeId(unsafe {
// Safety: `value + 1` cannot be zero
NonZero::new_unchecked(value + 1)
})
}
}

impl std::fmt::Display for NodeId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.as_u32().fmt(f)
}
}

/// Reasons why a graph execution failed
#[derive(Eq, PartialEq, Debug)]
Expand Down Expand Up @@ -368,14 +406,14 @@ impl NodeRefCount {
/// Increment ref count of node. If the refcount reaches `u8::MAX` it
/// will become "sticky" and never decrement.
fn inc(&mut self, id: NodeId) {
let rc = &mut self.rc[id];
let rc = &mut self.rc[id.as_u32() as usize];
*rc = rc.saturating_add(1);
}

/// Decrement ref count of node and return new count, or `None` if the
/// ref count was already zero.
fn dec(&mut self, id: NodeId) -> Option<usize> {
let rc = &mut self.rc[id];
let rc = &mut self.rc[id.as_u32() as usize];

// If the refcount reaches the max value, it becomes sticky.
if *rc == u8::MAX {
Expand All @@ -389,7 +427,7 @@ impl NodeRefCount {
}

fn count(&self, id: NodeId) -> usize {
self.rc[id] as usize
self.rc[id.as_u32() as usize] as usize
}
}

Expand Down Expand Up @@ -674,10 +712,10 @@ impl Graph {
}

pub fn add_node(&mut self, node: Node) -> NodeId {
let node_id = NodeId::from_u32(self.nodes.len() as u32);
self.nodes.push(node);
let node_id = self.nodes.len() - 1;

if let Some(name) = self.nodes[node_id].name() {
if let Some(name) = self.nodes[node_id.as_u32() as usize].name() {
self.node_id_from_name.insert(name.to_string(), node_id);
}

Expand Down Expand Up @@ -775,7 +813,10 @@ impl Graph {

/// Return an iterator over nodes in the graph.
pub fn iter(&self) -> impl Iterator<Item = (NodeId, &Node)> {
self.nodes.iter().enumerate()
self.nodes
.iter()
.enumerate()
.map(|(i, node)| (NodeId::from_u32(i as u32), node))
}

/// Return the debug name for a node.
Expand All @@ -788,7 +829,7 @@ impl Graph {

/// Retrieve a node by ID
pub fn get_node(&self, id: NodeId) -> Option<&Node> {
self.nodes.get(id)
self.nodes.get(id.as_u32() as usize)
}

/// Look up a node ID given its unique name
Expand All @@ -808,7 +849,7 @@ impl Graph {

/// Retrieve a node by ID
pub fn get_node_mut(&mut self, id: NodeId) -> Option<&mut Node> {
self.nodes.get_mut(id)
self.nodes.get_mut(id.as_u32() as usize)
}

/// Return the total number of parameters in all constant nodes in this
Expand Down Expand Up @@ -928,7 +969,7 @@ impl Graph {

let inputs_by_id: FxHashMap<NodeId, InputOrOutput> = inputs.iter().cloned().collect();
let get_value_from_constant_or_input = |node_id: NodeId| -> Option<Input> {
match self.nodes.get(node_id) {
match self.nodes.get(node_id.as_u32() as usize) {
Some(Node::Constant(constant)) => Some(constant.as_input()),
Some(Node::Value(_)) => inputs_by_id.get(&node_id).map(|input| input.as_input()),
_ => {
Expand All @@ -938,21 +979,24 @@ impl Graph {
};

let get_value_from_capture = |node_id: NodeId| -> Option<Input> {
let name = self.nodes.get(node_id).and_then(|n| n.name())?;
let name = self
.nodes
.get(node_id.as_u32() as usize)
.and_then(|n| n.name())?;
captures.as_ref().and_then(|cap| cap.get_input(name))
};

// Count how often each temporary output is used, so we can free them
// when no longer needed.
let mut temp_value_refcount = NodeRefCount::with_capacity(self.nodes.len());
for &op_node_id in plan.iter() {
let Some(Node::Operator(op_node)) = self.nodes.get(op_node_id) else {
let Some(Node::Operator(op_node)) = self.nodes.get(op_node_id.as_u32() as usize) else {
return Err(RunError::PlanningError(
"operator node not found".to_string(),
));
};
for node_id in self.operator_dependencies(op_node) {
if let Some(Node::Value(_)) = self.nodes.get(node_id) {
if let Some(Node::Value(_)) = self.nodes.get(node_id.as_u32() as usize) {
temp_value_refcount.inc(node_id);
}
}
Expand Down Expand Up @@ -984,7 +1028,7 @@ impl Graph {
let mut op_start = Instant::now();

for (step, &op_node_id) in plan.iter().enumerate() {
let Some(Node::Operator(op_node)) = self.nodes.get(op_node_id) else {
let Some(Node::Operator(op_node)) = self.nodes.get(op_node_id.as_u32() as usize) else {
return Err(RunError::PlanningError(
"operator node not found".to_string(),
));
Expand Down Expand Up @@ -1305,7 +1349,7 @@ impl Graph {
// Walk forwards through the plan and prune away steps that cannot be
// computed due to missing inputs.
for &node_id in plan {
let Some(Node::Operator(op_node)) = self.nodes.get(node_id) else {
let Some(Node::Operator(op_node)) = self.nodes.get(node_id.as_u32() as usize) else {
continue;
};

Expand Down Expand Up @@ -1354,12 +1398,11 @@ impl Graph {
inputs: I,
include_captures: bool,
) -> FxHashSet<NodeId> {
let mut resolved: FxHashSet<NodeId> =
inputs
.chain(self.nodes.iter().enumerate().filter_map(|(node_id, node)| {
matches!(node, Node::Constant(_)).then_some(node_id)
}))
.collect();
let mut resolved: FxHashSet<NodeId> = inputs
.chain(self.nodes.iter().enumerate().filter_map(|(node_id, node)| {
matches!(node, Node::Constant(_)).then_some(NodeId::from_u32(node_id as u32))
}))
.collect();

if include_captures {
resolved.extend(self.captures().iter().copied());
Expand Down Expand Up @@ -1514,7 +1557,7 @@ mod tests {
use smallvec::{smallvec, SmallVec};

use super::{CachedPlan, CaptureEnv};
use crate::graph::{Dimension, Graph, Node, RunError, RunOptions, TypedConstant};
use crate::graph::{Dimension, Graph, Node, NodeId, RunError, RunOptions, TypedConstant};
use crate::ops::{
Add, Concat, Conv, Identity, If, InputList, IntoOpResult, Mul, OpError, Operator, Output,
OutputList, Relu, Shape,
Expand Down Expand Up @@ -1943,7 +1986,7 @@ mod tests {
#[test]
fn test_err_if_invalid_output() {
let g = Graph::new();
let result = g.run(vec![], &[123], None);
let result = g.run(vec![], &[NodeId::from_u32(123)], None);
assert_eq!(
result.err(),
Some(RunError::PlanningError("Missing output 123".to_string()))
Expand All @@ -1953,7 +1996,7 @@ mod tests {
#[test]
fn test_err_if_missing_operator_input() {
let mut g = Graph::new();
let (_, output) = g.add_simple_op("op", Relu {}, &[42]);
let (_, output) = g.add_simple_op("op", Relu {}, &[NodeId::from_u32(42)]);
let result = g.run(vec![], &[output], None);
assert_eq!(
result.err(),
Expand Down Expand Up @@ -2268,21 +2311,27 @@ mod tests {

#[test]
fn test_cached_plan_matches() {
let input_ids = &[3, 1, 2];
let output_ids = &[6, 4, 5];
let op_ids = &[10, 11, 12];
let input_ids = &[3, 1, 2].map(NodeId::from_u32);
let output_ids = &[6, 4, 5].map(NodeId::from_u32);
let op_ids = &[10, 11, 12].map(NodeId::from_u32);

let plan = CachedPlan::new(input_ids, output_ids, op_ids.to_vec());

assert!(plan.matches(input_ids, output_ids));

// Same input and output IDs, different orders.
assert!(plan.matches(&[1, 2, 3], &[4, 5, 6]));
assert!(plan.matches(&[3, 2, 1], &[6, 5, 4]));
assert!(plan.matches(
&[1, 2, 3].map(NodeId::from_u32),
&[4, 5, 6].map(NodeId::from_u32)
));
assert!(plan.matches(
&[3, 2, 1].map(NodeId::from_u32),
&[6, 5, 4].map(NodeId::from_u32)
));

// Different input and output IDs
assert!(!plan.matches(&[20, 21, 22], output_ids));
assert!(!plan.matches(input_ids, &[20, 21, 22]));
assert!(!plan.matches(&[20, 21, 22].map(NodeId::from_u32), output_ids));
assert!(!plan.matches(input_ids, &[20, 21, 22].map(NodeId::from_u32)));
}

/// A trivial control flow operator which just forwards inputs to a subgraph
Expand Down
38 changes: 17 additions & 21 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -356,20 +356,20 @@ impl Model {

let input_ids: Vec<NodeId> = serialized_graph
.inputs()
.map(|ids| ids.iter().map(|id| id as NodeId).collect())
.map(|ids| ids.iter().map(NodeId::from_u32).collect())
.unwrap_or_default();

let output_ids: Vec<NodeId> = serialized_graph
.outputs()
.map(|ids| ids.iter().map(|id| id as NodeId).collect())
.map(|ids| ids.iter().map(NodeId::from_u32).collect())
.unwrap_or_default();

let mut graph = Graph::with_capacity(node_count);
graph.set_input_ids(&input_ids);
graph.set_output_ids(&output_ids);

if let Some(captures) = serialized_graph.captures() {
let captures: Vec<NodeId> = captures.iter().map(|id| id as NodeId).collect();
let captures: Vec<NodeId> = captures.iter().map(NodeId::from_u32).collect();
graph.set_captures(&captures);
}

Expand Down Expand Up @@ -839,7 +839,7 @@ mod tests {
use rten_tensor::prelude::*;
use rten_tensor::Tensor;

use crate::graph::{Dimension, RunError};
use crate::graph::{Dimension, NodeId, RunError};
use crate::model::{Model, ModelOptions};
use crate::model_builder::{
GraphBuilder, IfArgs, MetadataArgs, ModelBuilder, ModelFormat, OpType,
Expand Down Expand Up @@ -1147,7 +1147,7 @@ mod tests {
.load(buffer)
.unwrap();

let result = model.run(vec![], &[output_node as usize], None);
let result = model.run(vec![], &[output_node], None);

assert_eq!(
result.err(),
Expand Down Expand Up @@ -1181,7 +1181,7 @@ mod tests {
let mut op_outputs = Vec::new();

let mut add_operator =
|builder: &mut GraphBuilder, name: &str, op: OpType, input_nodes: &[Option<u32>]| {
|builder: &mut GraphBuilder, name: &str, op: OpType, input_nodes: &[Option<NodeId>]| {
let output_name = format!("{}_out", name);
let op_output_node = builder.add_value(&output_name, None);
builder.add_operator(name, op, input_nodes, &[op_output_node]);
Expand Down Expand Up @@ -1605,8 +1605,8 @@ mod tests {
let result = model
.run(
vec![
(input_node as usize, input.view().into()),
(input_bool as usize, input_bool_data.view().into()),
(input_node, input.view().into()),
(input_bool, input_bool_data.view().into()),
],
&[output_id],
None,
Expand All @@ -1629,11 +1629,7 @@ mod tests {
for output in outputs {
let output_id = model.find_node(output).unwrap();
let result = model
.run(
vec![(input_2d as usize, input.view().into())],
&[output_id],
None,
)
.run(vec![(input_2d, input.view().into())], &[output_id], None)
.unwrap();
assert_eq!(result.len(), 1);
}
Expand All @@ -1645,11 +1641,11 @@ mod tests {
let result = model
.run(
vec![
(range_start_node as usize, start.into()),
(range_limit_node as usize, limit.into()),
(range_delta_node as usize, delta.into()),
(range_start_node, start.into()),
(range_limit_node, limit.into()),
(range_delta_node, delta.into()),
],
&[range_out as usize],
&[range_out],
None,
)
.unwrap();
Expand All @@ -1662,11 +1658,11 @@ mod tests {
let result = model
.run(
vec![
(where_cond as usize, cond.into()),
(where_x as usize, x.into()),
(where_y as usize, y.into()),
(where_cond, cond.into()),
(where_x, x.into()),
(where_y, y.into()),
],
&[where_out as usize],
&[where_out],
None,
)
.unwrap();
Expand Down
Loading

0 comments on commit d7a3490

Please sign in to comment.