Skip to content

Commit

Permalink
refactor: NodeType constructors, adding new_auto (#635)
Browse files Browse the repository at this point in the history
* Rename NodeType::open_extensions to NodeType::new_open
* Rename NodeType::pure to NodeType::new_pure
* Add NodeType::new_auto, which uses Pure for module-ops and Open for
others
* Remove special-case in infer.rs solving some module-ops to empty set
* Switch builder/HugrMut methods from new_open to new_auto
  • Loading branch information
acl-cqc authored Nov 6, 2023
1 parent 1a07cd9 commit b3062e9
Show file tree
Hide file tree
Showing 14 changed files with 65 additions and 59 deletions.
2 changes: 1 addition & 1 deletion src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ pub(crate) mod test {
/// inference. Using DFGBuilder will default to a root node with an open
/// extension variable
pub(crate) fn closed_dfg_root_hugr(signature: FunctionType) -> Hugr {
let mut hugr = Hugr::new(NodeType::pure(ops::DFG {
let mut hugr = Hugr::new(NodeType::new_pure(ops::DFG {
signature: signature.clone(),
}));
hugr.add_op_with_parent(
Expand Down
4 changes: 2 additions & 2 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ pub trait Dataflow: Container {
op: impl Into<OpType>,
input_wires: impl IntoIterator<Item = Wire>,
) -> Result<BuildHandle<DataflowOpID>, BuildError> {
self.add_dataflow_node(NodeType::open_extensions(op), input_wires)
self.add_dataflow_node(NodeType::new_auto(op), input_wires)
}

/// Add a dataflow [`NodeType`] to the sibling graph, wiring up the `input_wires` to the
Expand Down Expand Up @@ -628,7 +628,7 @@ fn add_op_with_wires<T: Dataflow + ?Sized>(
optype: impl Into<OpType>,
inputs: Vec<Wire>,
) -> Result<(Node, usize), BuildError> {
add_node_with_wires(data_builder, NodeType::open_extensions(optype), inputs)
add_node_with_wires(data_builder, NodeType::new_auto(optype), inputs)
}

fn add_node_with_wires<T: Dataflow + ?Sized>(
Expand Down
2 changes: 1 addition & 1 deletion src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ impl CFGBuilder<Hugr> {
signature: signature.clone(),
};

let base = Hugr::new(NodeType::open_extensions(cfg_op));
let base = Hugr::new(NodeType::new_open(cfg_op));
let cfg_node = base.root();
CFGBuilder::create(base, cfg_node, signature.input, signature.output)
}
Expand Down
4 changes: 2 additions & 2 deletions src/builder/conditional.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ impl ConditionalBuilder<Hugr> {
extension_delta,
};
// TODO: Allow input extensions to be specified
let base = Hugr::new(NodeType::open_extensions(op));
let base = Hugr::new(NodeType::new_open(op));
let conditional_node = base.root();

Ok(ConditionalBuilder {
Expand All @@ -194,7 +194,7 @@ impl CaseBuilder<Hugr> {
let op = ops::Case {
signature: signature.clone(),
};
let base = Hugr::new(NodeType::open_extensions(op));
let base = Hugr::new(NodeType::new_open(op));
let root = base.root();
let dfg_builder = DFGBuilder::create_with_io(base, root, signature, None)?;

Expand Down
2 changes: 1 addition & 1 deletion src/builder/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ impl DFGBuilder<Hugr> {
let dfg_op = ops::DFG {
signature: signature.clone(),
};
let base = Hugr::new(NodeType::open_extensions(dfg_op));
let base = Hugr::new(NodeType::new_open(dfg_op));
let root = base.root();
DFGBuilder::create_with_io(base, root, signature, None)
}
Expand Down
2 changes: 1 addition & 1 deletion src/builder/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ impl<T: AsMut<Hugr> + AsRef<Hugr>> ModuleBuilder<T> {
};
self.hugr_mut().replace_op(
f_node,
NodeType::pure(ops::FuncDefn {
NodeType::new_pure(ops::FuncDefn {
name,
signature: signature.clone(),
}),
Expand Down
2 changes: 1 addition & 1 deletion src/builder/tail_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ impl TailLoopBuilder<Hugr> {
rest: inputs_outputs.into(),
};
// TODO: Allow input extensions to be specified
let base = Hugr::new(NodeType::open_extensions(tail_loop.clone()));
let base = Hugr::new(NodeType::new_open(tail_loop.clone()));
let root = base.root();
Self::create_with_io(base, root, &tail_loop)
}
Expand Down
24 changes: 9 additions & 15 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,12 +316,6 @@ impl UnificationContext {
m_output,
node_type.op_signature().extension_reqs,
);
if matches!(
node_type.tag(),
OpTag::Alias | OpTag::Function | OpTag::FuncDefn
) {
self.add_solution(m_input, ExtensionSet::new());
}
}
// We have a solution for everything!
Some(sig) => {
Expand Down Expand Up @@ -723,7 +717,7 @@ mod test {
signature: main_sig,
};

let root_node = NodeType::open_extensions(op);
let root_node = NodeType::new_open(op);
let mut hugr = Hugr::new(root_node);

let input = ops::Input::new(type_row![NAT, NAT]);
Expand Down Expand Up @@ -833,21 +827,21 @@ mod test {
// This generates a solution that causes validation to fail
// because of a missing lift node
fn missing_lift_node() -> Result<(), Box<dyn Error>> {
let mut hugr = Hugr::new(NodeType::pure(ops::DFG {
let mut hugr = Hugr::new(NodeType::new_pure(ops::DFG {
signature: FunctionType::new(type_row![NAT], type_row![NAT])
.with_extension_delta(&ExtensionSet::singleton(&A)),
}));

let input = hugr.add_node_with_parent(
hugr.root(),
NodeType::pure(ops::Input {
NodeType::new_pure(ops::Input {
types: type_row![NAT],
}),
)?;

let output = hugr.add_node_with_parent(
hugr.root(),
NodeType::pure(ops::Output {
NodeType::new_pure(ops::Output {
types: type_row![NAT],
}),
)?;
Expand Down Expand Up @@ -1049,7 +1043,7 @@ mod test {
extension_delta: rs.clone(),
};

let mut hugr = Hugr::new(NodeType::pure(op));
let mut hugr = Hugr::new(NodeType::new_pure(op));
let conditional_node = hugr.root();

let case_op = ops::Case {
Expand Down Expand Up @@ -1084,7 +1078,7 @@ mod test {
fn extension_adding_sequence() -> Result<(), Box<dyn Error>> {
let df_sig = FunctionType::new(type_row![NAT], type_row![NAT]);

let mut hugr = Hugr::new(NodeType::open_extensions(ops::DFG {
let mut hugr = Hugr::new(NodeType::new_open(ops::DFG {
signature: df_sig
.clone()
.with_extension_delta(&ExtensionSet::from_iter([A, B])),
Expand Down Expand Up @@ -1255,7 +1249,7 @@ mod test {
let b = ExtensionSet::singleton(&B);
let c = ExtensionSet::singleton(&C);

let mut hugr = Hugr::new(NodeType::open_extensions(ops::CFG {
let mut hugr = Hugr::new(NodeType::new_open(ops::CFG {
signature: FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&abc),
}));

Expand Down Expand Up @@ -1353,7 +1347,7 @@ mod test {
/// +--------------------+
#[test]
fn multi_entry() -> Result<(), Box<dyn Error>> {
let mut hugr = Hugr::new(NodeType::open_extensions(ops::CFG {
let mut hugr = Hugr::new(NodeType::new_open(ops::CFG {
signature: FunctionType::new(type_row![NAT], type_row![NAT]), // maybe add extensions?
}));
let cfg = hugr.root();
Expand Down Expand Up @@ -1436,7 +1430,7 @@ mod test {
) -> Result<Hugr, Box<dyn Error>> {
let hugr_delta = entry_ext.clone().union(&bb1_ext).union(&bb2_ext);

let mut hugr = Hugr::new(NodeType::open_extensions(ops::CFG {
let mut hugr = Hugr::new(NodeType::new_open(ops::CFG {
signature: FunctionType::new(type_row![NAT], type_row![NAT])
.with_extension_delta(&hugr_delta),
}));
Expand Down
21 changes: 15 additions & 6 deletions src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ impl NodeType {
}

/// Instantiate an OpType with no input extensions
pub fn pure(op: impl Into<OpType>) -> Self {
pub fn new_pure(op: impl Into<OpType>) -> Self {
NodeType {
op: op.into(),
input_extensions: Some(ExtensionSet::new()),
Expand All @@ -91,13 +91,24 @@ impl NodeType {

/// Instantiate an OpType with an unknown set of input extensions
/// (to be inferred later)
pub fn open_extensions(op: impl Into<OpType>) -> Self {
pub fn new_open(op: impl Into<OpType>) -> Self {
NodeType {
op: op.into(),
input_extensions: None,
}
}

/// Instantiate an [OpType] with the default set of input extensions
/// for that OpType.
pub fn new_auto(op: impl Into<OpType>) -> Self {
let op = op.into();
if OpTag::ModuleOp.is_superset(op.tag()) {
Self::new_pure(op)
} else {
Self::new_open(op)
}
}

/// Use the input extensions to calculate the concrete signature of the node
pub fn signature(&self) -> Option<Signature> {
self.input_extensions
Expand All @@ -119,9 +130,7 @@ impl NodeType {
pub fn input_extensions(&self) -> Option<&ExtensionSet> {
self.input_extensions.as_ref()
}
}

impl NodeType {
/// Gets the underlying [OpType] i.e. without any [input_extensions]
///
/// [input_extensions]: NodeType::input_extensions
Expand Down Expand Up @@ -153,7 +162,7 @@ impl OpType {

impl Default for Hugr {
fn default() -> Self {
Self::new(NodeType::pure(crate::ops::Module))
Self::new(NodeType::new_pure(crate::ops::Module))
}
}

Expand Down Expand Up @@ -239,7 +248,7 @@ impl Hugr {

/// Add a node to the graph, with the default conversion from OpType to NodeType
pub(crate) fn add_op(&mut self, op: impl Into<OpType>) -> Node {
self.add_node(NodeType::open_extensions(op))
self.add_node(NodeType::new_auto(op))
}

/// Add a node to the graph.
Expand Down
6 changes: 3 additions & 3 deletions src/hugr/hugrmut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ pub trait HugrMut: HugrMutInternals {
parent: Node,
op: impl Into<OpType>,
) -> Result<Node, HugrError> {
self.add_node_with_parent(parent, NodeType::open_extensions(op))
self.add_node_with_parent(parent, NodeType::new_auto(op))
}

/// Add a node to the graph with a parent in the hierarchy.
Expand Down Expand Up @@ -217,7 +217,7 @@ impl<T: RootTagged<RootHandle = Node> + AsMut<Hugr>> HugrMut for T {
}

fn add_op_before(&mut self, sibling: Node, op: impl Into<OpType>) -> Result<Node, HugrError> {
self.add_node_before(sibling, NodeType::open_extensions(op))
self.add_node_before(sibling, NodeType::new_auto(op))
}

fn add_node_before(&mut self, sibling: Node, nodetype: NodeType) -> Result<Node, HugrError> {
Expand Down Expand Up @@ -620,7 +620,7 @@ mod test {

{
let f_in = hugr
.add_node_with_parent(f, NodeType::pure(ops::Input::new(type_row![NAT])))
.add_node_with_parent(f, NodeType::new_pure(ops::Input::new(type_row![NAT])))
.unwrap();
let f_out = hugr
.add_op_with_parent(f, ops::Output::new(type_row![NAT, NAT]))
Expand Down
9 changes: 3 additions & 6 deletions src/hugr/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,7 @@ impl TryFrom<SerHugrV0> for Hugr {
for node_ser in nodes {
hugr.add_node_with_parent(
node_ser.parent,
match node_ser.input_extensions {
None => NodeType::open_extensions(node_ser.op),
Some(rs) => NodeType::new(node_ser.op, rs),
},
NodeType::new(node_ser.op, node_ser.input_extensions),
)?;
}

Expand Down Expand Up @@ -332,11 +329,11 @@ pub mod test {
let mut h = Hierarchy::new();
let mut op_types = UnmanagedDenseMap::new();

op_types[root] = NodeType::open_extensions(gen_optype(&g, root));
op_types[root] = NodeType::new_open(gen_optype(&g, root));

for n in [a, b, c] {
h.push_child(n, root).unwrap();
op_types[n] = NodeType::pure(gen_optype(&g, n));
op_types[n] = NodeType::new_pure(gen_optype(&g, n));
}

let hg = Hugr {
Expand Down
Loading

0 comments on commit b3062e9

Please sign in to comment.