From c15b2e1ee87dd8a70db2033120db70fe54ed59de Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Fri, 9 Aug 2024 18:43:08 +0100 Subject: [PATCH 01/26] `hugr-model` import, export and text format draft. --- Cargo.toml | 6 +- hugr-core/Cargo.toml | 9 +- hugr-core/src/builder.rs | 2 +- hugr-core/src/export.rs | 729 ++++++++++++++++++++ hugr-core/src/hugr/ident.rs | 22 + hugr-core/src/import.rs | 987 +++++++++++++++++++++++++++ hugr-core/src/lib.rs | 2 + hugr-core/src/types.rs | 4 +- hugr-core/tests/fixtures/model-1.edn | 38 ++ hugr-core/tests/model.rs | 13 + hugr-model/Cargo.toml | 23 + hugr-model/src/lib.rs | 5 + hugr-model/src/v0/mod.rs | 652 ++++++++++++++++++ hugr-model/src/v0/text/hugr.pest | 99 +++ hugr-model/src/v0/text/mod.rs | 6 + hugr-model/src/v0/text/parse.rs | 682 ++++++++++++++++++ hugr-model/src/v0/text/print.rs | 620 +++++++++++++++++ 17 files changed, 3895 insertions(+), 4 deletions(-) create mode 100644 hugr-core/src/export.rs create mode 100644 hugr-core/src/import.rs create mode 100644 hugr-core/tests/fixtures/model-1.edn create mode 100644 hugr-core/tests/model.rs create mode 100644 hugr-model/Cargo.toml create mode 100644 hugr-model/src/lib.rs create mode 100644 hugr-model/src/v0/mod.rs create mode 100644 hugr-model/src/v0/text/hugr.pest create mode 100644 hugr-model/src/v0/text/mod.rs create mode 100644 hugr-model/src/v0/text/parse.rs create mode 100644 hugr-model/src/v0/text/print.rs diff --git a/Cargo.toml b/Cargo.toml index 788d6df22..e16575caa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ lto = "thin" [workspace] resolver = "2" -members = ["hugr", "hugr-core", "hugr-passes", "hugr-cli"] +members = ["hugr", "hugr-core", "hugr-passes", "hugr-cli", "hugr-model"] [workspace.package] rust-version = "1.75" @@ -62,6 +62,10 @@ clap-verbosity-flag = "2.2.0" assert_cmd = "2.0.14" assert_fs = "1.1.1" predicates = "3.1.0" +tinyvec = { version = "1.8.0", features = ["alloc", "serde"] } +indexmap = "2.3.0" +fxhash = "0.2.1" +bumpalo = { version = "3.16.0", features = ["collections"] } [profile.dev.package] insta.opt-level = 3 diff --git a/hugr-core/Cargo.toml b/hugr-core/Cargo.toml index 40d54b4a8..5732cf85b 100644 --- a/hugr-core/Cargo.toml +++ b/hugr-core/Cargo.toml @@ -32,7 +32,7 @@ serde = { workspace = true, features = ["derive", "rc"] } serde_yaml = { workspace = true, optional = true } typetag = { workspace = true } smol_str = { workspace = true, features = ["serde"] } -derive_more = { workspace = true, features=["display", "from"]} +derive_more = { workspace = true, features = ["display", "from"] } itertools = { workspace = true } html-escape = { workspace = true } bitvec = { workspace = true, features = ["serde"] } @@ -46,6 +46,12 @@ paste = { workspace = true } strum = { workspace = true } strum_macros = { workspace = true } semver = { version = "1.0.23", features = ["serde"] } +hugr-model = { path = "../hugr-model" } +indexmap.workspace = true +tinyvec.workspace = true +fxhash.workspace = true +ascent = "0.6.0" +bumpalo = { workspace = true } [dev-dependencies] rstest = { workspace = true } @@ -61,3 +67,4 @@ regex-syntax = { workspace = true } # Required for documentation examples hugr = { path = "../hugr" } +serde_yaml = "0.9.34" diff --git a/hugr-core/src/builder.rs b/hugr-core/src/builder.rs index 38a5334b3..f2743f8eb 100644 --- a/hugr-core/src/builder.rs +++ b/hugr-core/src/builder.rs @@ -258,7 +258,7 @@ pub(crate) mod test { dataflow_builder.finish_with_outputs(w) } - pub(super) fn build_main( + pub(crate) fn build_main( signature: PolyFuncType, f: impl FnOnce(FunctionBuilder<&mut Hugr>) -> Result>, BuildError>, ) -> Result { diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs new file mode 100644 index 000000000..d612d4270 --- /dev/null +++ b/hugr-core/src/export.rs @@ -0,0 +1,729 @@ +//! Exporting HUGR graphs to their `hugr-model` representation. +use crate::{ + extension::ExtensionSet, + ops::OpType, + types::{ + type_param::{TypeArgVariable, TypeParam}, + type_row::TypeRowBase, + CustomType, EdgeKind, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType, + TypeArg, TypeBase, TypeEnum, + }, + Direction, Hugr, HugrView, IncomingPort, Node, Port, PortIndex, +}; +use bumpalo::{collections::Vec as BumpVec, Bump}; +use hugr_model::v0::{self as model}; +use indexmap::IndexSet; +use smol_str::ToSmolStr; + +pub(crate) const OP_FUNC_CALL_INDIRECT: &'static str = "func.call-indirect"; +pub(crate) const OP_ADT_TAG: &'static str = "adt.make-tag"; + +const TERM_PARAM_TUPLE: &'static str = "param.tuple"; + +/// Export a [`Hugr`] graph to its representation in the model. +pub fn export_hugr<'a>(hugr: &'a Hugr, bump: &'a Bump) -> model::Module<'a> { + let mut ctx = Context::new(hugr, bump); + ctx.export_root(); + ctx.module +} + +/// State for converting a HUGR graph to its representation in the model. +struct Context<'a> { + /// The HUGR graph to convert. + hugr: &'a Hugr, + /// The module that is being built. + module: model::Module<'a>, + /// Mapping from ports to link indices. + /// This only includes the minimum port among groups of linked ports. + links: IndexSet<(Node, Port)>, + bump: &'a Bump, +} + +impl<'a> Context<'a> { + pub fn new(hugr: &'a Hugr, bump: &'a Bump) -> Self { + // let mut node_to_id = FxHashMap::default(); + // node_to_id.reserve(hugr.node_count()); + + let mut module = model::Module::default(); + module.nodes.reserve(hugr.node_count()); + + Self { + hugr, + module, + bump, + links: IndexSet::new(), + } + } + + pub fn export_root(&mut self) { + let hugr_children = self.hugr.children(self.hugr.root()); + let mut children = BumpVec::with_capacity_in(hugr_children.len(), self.bump); + + for child in self.hugr.children(self.hugr.root()) { + children.push(self.export_node(child)); + } + + let root = self.module.insert_region(model::Region { + kind: model::RegionKind::DataFlow, + sources: &[], + targets: &[], + children: children.into_bump_slice(), + meta: &[], + }); + + self.module.root = root; + } + + /// Returns the edge id for a given port, creating a new edge if necessary. + /// + /// Any two ports that are linked will be represented by the same link. + fn get_link_id(&mut self, node: Node, port: Port) -> model::LinkId { + // To ensure that linked ports are represented by the same edge, we take the minimum port + // among all the linked ports, including the one we started with. + let linked_ports = self.hugr.linked_ports(node, port); + let all_ports = std::iter::once((node, port)).chain(linked_ports); + let repr = all_ports.min().unwrap(); + let edge = self.links.insert_full(repr).0 as _; + model::LinkId(edge) + } + + pub fn make_ports(&mut self, node: Node, direction: Direction) -> &'a [model::Port<'a>] { + let ports = self.hugr.node_ports(node, direction); + let mut model_ports = BumpVec::with_capacity_in(ports.len(), self.bump); + + for port in ports { + if let Some(model_port) = self.make_port(node, port) { + model_ports.push(model_port); + } + } + + model_ports.into_bump_slice() + } + + pub fn make_port(&mut self, node: Node, port: impl Into) -> Option> { + let port: Port = port.into(); + let op_type = self.hugr.get_optype(node); + + let r#type = match op_type.port_kind(port)? { + EdgeKind::ControlFlow => { + // TODO: This should ideally be reported by the op itself + let types: Vec<_> = match (op_type, port.direction()) { + (OpType::DataflowBlock(block), Direction::Incoming) => { + block.inputs.iter().map(|t| self.export_type(t)).collect() + } + (OpType::DataflowBlock(block), Direction::Outgoing) => { + let mut types = Vec::new(); + types.extend( + (&block.sum_rows[port.index()]) + .iter() + .map(|t| self.export_type(t)), + ); + types.extend(block.other_outputs.iter().map(|t| self.export_type(t))); + types + } + (OpType::ExitBlock(block), Direction::Incoming) => block + .cfg_outputs + .iter() + .map(|t| self.export_type(t)) + .collect(), + (OpType::ExitBlock(_), Direction::Outgoing) => vec![], + _ => unreachable!("unexpected control flow port on non-control-flow op"), + }; + + let types = self.bump.alloc_slice_copy(&types); + let values = self.module.insert_term(model::Term::List { + items: types, + tail: None, + }); + self.module.insert_term(model::Term::Control { values }) + } + EdgeKind::Value(r#type) => self.export_type(&r#type), + EdgeKind::Const(_) => return None, + EdgeKind::Function(_) => return None, + EdgeKind::StateOrder => return None, + }; + + let link = model::LinkRef::Id(self.get_link_id(node, port)); + + Some(model::Port { + r#type: Some(r#type), + link, + meta: &[], + }) + } + + /// Get the node that declares or defines the function associated with the given + /// node via the static input. Returns `None` if the node is not connected to a function. + fn connected_function(&self, node: Node) -> Option { + let func_node = self.hugr.static_source(node)?; + + match self.hugr.get_optype(func_node) { + OpType::FuncDecl(_) => Some(func_node), + OpType::FuncDefn(_) => Some(func_node), + _ => None, + } + } + + /// Get the name of a function definition or declaration node. Returns `None` if not + /// one of those operations. + fn get_func_name(&self, func_node: Node) -> Option<&'a str> { + match self.hugr.get_optype(func_node) { + OpType::FuncDecl(func_decl) => Some(&func_decl.name), + OpType::FuncDefn(func_defn) => Some(&func_defn.name), + _ => None, + } + } + + pub fn export_node(&mut self, node: Node) -> model::NodeId { + let inputs = self.make_ports(node, Direction::Incoming); + let outputs = self.make_ports(node, Direction::Outgoing); + let mut params: &[_] = &[]; + let mut regions: &[_] = &[]; + + fn make_custom(name: &'static str) -> model::Operation { + model::Operation::Custom { + name: model::GlobalRef::Named(name), + } + } + + let operation = match self.hugr.get_optype(node) { + OpType::Module(_) => todo!("this should be an error"), + + OpType::Input(_) => { + panic!("input nodes should have been handled by the region export") + } + + OpType::Output(_) => { + panic!("output nodes should have been handled by the region export") + } + + OpType::DFG(_) => { + regions = self.bump.alloc_slice_copy(&[self.export_dfg(node)]); + model::Operation::Dfg + } + + OpType::CFG(_) => { + regions = self.bump.alloc_slice_copy(&[self.export_cfg(node)]); + model::Operation::Cfg + } + + OpType::ExitBlock(_) => { + panic!("exit blocks should have been handled by the region export") + } + + OpType::Case(_) => { + todo!("case nodes should have been handled by the region export") + } + + OpType::DataflowBlock(_) => { + regions = self.bump.alloc_slice_copy(&[self.export_dfg(node)]); + model::Operation::Block + } + + OpType::FuncDefn(func) => { + let name = self.get_func_name(node).unwrap(); + let (params, func) = self.export_poly_func_type(&func.signature); + let decl = self.bump.alloc(model::FuncDecl { name, params, func }); + regions = self.bump.alloc_slice_copy(&[self.export_dfg(node)]); + model::Operation::DefineFunc { decl } + } + + OpType::FuncDecl(func) => { + let name = self.get_func_name(node).unwrap(); + let (params, func) = self.export_poly_func_type(&func.signature); + let decl = self.bump.alloc(model::FuncDecl { name, params, func }); + model::Operation::DeclareFunc { decl } + } + + OpType::AliasDecl(alias) => { + // TODO: We should support aliases with different types and with parameters + let r#type = self.module.insert_term(model::Term::Type); + let decl = self.bump.alloc(model::AliasDecl { + name: &alias.name, + params: &[], + r#type, + }); + model::Operation::DeclareAlias { decl } + } + + OpType::AliasDefn(alias) => { + let value = self.export_type(&alias.definition); + // TODO: We should support aliases with different types and with parameters + let r#type = self.module.insert_term(model::Term::Type); + let decl = self.bump.alloc(model::AliasDecl { + name: &alias.name, + params: &[], + r#type, + }); + model::Operation::DefineAlias { decl, value } + } + + OpType::Call(call) => { + // TODO: If the node is not connected to a function, we should do better than panic. + let node = self.connected_function(node).unwrap(); + let name = model::GlobalRef::Named(self.get_func_name(node).unwrap()); + + let mut args = BumpVec::new_in(self.bump); + args.extend(call.type_args.iter().map(|arg| self.export_type_arg(arg))); + let args = args.into_bump_slice(); + + let func = self + .module + .insert_term(model::Term::ApplyFull { name, args }); + model::Operation::CallFunc { func } + } + + OpType::LoadFunction(load) => { + // TODO: If the node is not connected to a function, we should do better than panic. + let node = self.connected_function(node).unwrap(); + let name = model::GlobalRef::Named(self.get_func_name(node).unwrap()); + + let mut args = BumpVec::new_in(self.bump); + args.extend(load.type_args.iter().map(|arg| self.export_type_arg(arg))); + let args = args.into_bump_slice(); + + let func = self + .module + .insert_term(model::Term::ApplyFull { name, args }); + model::Operation::LoadFunc { func } + } + + OpType::Const(_) => todo!("Export const nodes?"), + OpType::LoadConstant(_) => todo!("Export load constant?"), + + OpType::CallIndirect(_) => make_custom(OP_FUNC_CALL_INDIRECT), + + OpType::Tag(_) => make_custom(OP_ADT_TAG), + + OpType::TailLoop(op) => { + regions = self.bump.alloc_slice_copy(&[self.export_dfg(node)]); + model::Operation::TailLoop { + inputs: self.export_type_row(&op.just_inputs), + outputs: self.export_type_row(&op.just_outputs), + rest: self.export_type_row(&op.rest), + extensions: self.export_ext_set(&op.extension_delta), + } + } + + OpType::Conditional(op) => { + let mut types = BumpVec::new_in(self.bump); + types.extend(op.sum_rows.iter().map(|l| self.export_type_row(l))); + let types = types.into_bump_slice(); + let sum_rows = model::Term::List { + items: &types, + tail: None, + }; + regions = self.export_conditional_regions(node); + model::Operation::Conditional { + cases: self.module.insert_term(sum_rows), + context: self.export_type_row(&op.other_inputs), + outputs: self.export_type_row(&op.outputs), + extensions: self.export_ext_set(&op.extension_delta), + } + } + + // Opaque/extension operations should in the future support having multiple different + // regions of potentially different kinds. At the moment, we check if the node has any + // children, in which case we create a dataflow region with those children. + OpType::ExtensionOp(op) => { + let name = + self.bump + .alloc_str(&format!("{}.{}", op.def().extension(), op.def().name())); + let name = model::GlobalRef::Named(name); + + params = self + .bump + .alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_type_arg(arg))); + + if let Some(region) = self.export_dfg_if_present(node) { + regions = self.bump.alloc_slice_copy(&[region]); + } + + model::Operation::Custom { name } + } + + OpType::OpaqueOp(op) => { + let name = self + .bump + .alloc_str(&format!("{}.{}", op.extension(), op.op_name())); + let name = model::GlobalRef::Named(name); + + params = self + .bump + .alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_type_arg(arg))); + + if let Some(region) = self.export_dfg_if_present(node) { + regions = self.bump.alloc_slice_copy(&[region]); + } + + model::Operation::Custom { name } + } + }; + + self.module.insert_node(model::Node { + operation, + inputs, + outputs, + params, + regions, + meta: &[], + }) + } + + /// Create a region from the given node's children, if it has any. + /// + /// See [`Self::export_dfg`]. + pub fn export_dfg_if_present(&mut self, node: Node) -> Option { + if self.hugr.children(node).next().is_none() { + None + } else { + Some(self.export_dfg(node)) + } + } + + /// Creates a data flow region from the given node's children. + /// + /// `Input` and `Output` nodes are used to determine the source and target ports of the region. + pub fn export_dfg(&mut self, node: Node) -> model::RegionId { + let mut children = self.hugr.children(node); + + // The first child is an `Input` node, which we use to determine the region's sources. + let input_node = children.next().unwrap(); + assert!(matches!(self.hugr.get_optype(input_node), OpType::Input(_))); + let sources = self.make_ports(input_node, Direction::Outgoing); + + // The second child is an `Output` node, which we use to determine the region's targets. + let output_node = children.next().unwrap(); + assert!(matches!( + self.hugr.get_optype(output_node), + OpType::Output(_) + )); + let targets = self.make_ports(output_node, Direction::Incoming); + + // Export the remaining children of the node. + let mut region_children = BumpVec::with_capacity_in(children.len(), self.bump); + + for child in children { + region_children.push(self.export_node(child)); + } + + self.module.insert_region(model::Region { + kind: model::RegionKind::DataFlow, + sources, + targets, + children: region_children.into_bump_slice(), + meta: &[], + }) + } + + /// Creates a control flow region from the given node's children. + pub fn export_cfg(&mut self, node: Node) -> model::RegionId { + let mut children = self.hugr.children(node); + + // The first child is the entry block. + // The entry block does have a dataflow subgraph, so we must still export it later. + // We create a source port on the control flow region and connect it to the + // first input port of the exported entry block. + let entry_block = children.next().unwrap(); + + assert!(matches!( + self.hugr.get_optype(entry_block), + OpType::DataflowBlock(_) + )); + + let source = self.make_port(entry_block, IncomingPort::from(0)).unwrap(); + + // The second child is the exit block. + // Contrary to the entry block, the exit block does not have a dataflow subgraph. + // We therefore do not export the block itself, but simply use its output ports + // as the target ports of the control flow region. + let exit_block = children.next().unwrap(); + + assert!(matches!( + self.hugr.get_optype(exit_block), + OpType::ExitBlock(_) + )); + + let targets = self.make_ports(exit_block, Direction::Incoming); + + // Now we export the child nodes, including the entry block. + let mut region_children = BumpVec::with_capacity_in(children.len() + 1, self.bump); + + region_children.push(self.export_node(entry_block)); + for child in children { + region_children.push(self.export_node(child)); + } + + self.module.insert_region(model::Region { + kind: model::RegionKind::DataFlow, + sources: self.bump.alloc_slice_copy(&[source]), + targets, + children: region_children.into_bump_slice(), + meta: &[], + }) + } + + /// Export the `Case` node children of a `Conditional` node as data flow regions. + pub fn export_conditional_regions(&mut self, node: Node) -> &'a [model::RegionId] { + let children = self.hugr.children(node); + let mut regions = BumpVec::with_capacity_in(children.len(), self.bump); + + for child in children { + assert!(matches!(self.hugr.get_optype(child), OpType::Case(_))); + regions.push(self.export_dfg(child)); + } + + regions.into_bump_slice() + } + + pub fn export_poly_func_type( + &mut self, + t: &PolyFuncTypeBase, + ) -> (&'a [model::Param<'a>], model::TermId) { + let mut params = BumpVec::with_capacity_in(t.params().len(), self.bump); + + for (i, param) in t.params().iter().enumerate() { + let name = self.bump.alloc_str(&i.to_string()); + let r#type = self.export_type_param(param); + let param = model::Param::Implicit { name, r#type }; + params.push(param) + } + + let body = self.export_func_type(t.body()); + + (params.into_bump_slice(), body) + } + + pub fn export_type(&mut self, t: &TypeBase) -> model::TermId { + self.export_type_enum(t.as_type_enum()) + } + + pub fn export_type_enum(&mut self, t: &TypeEnum) -> model::TermId { + match t { + TypeEnum::Extension(ext) => self.export_custom_type(ext), + TypeEnum::Alias(alias) => { + let name = model::GlobalRef::Named(self.bump.alloc_str(alias.name())); + let args = &[]; + self.module + .insert_term(model::Term::ApplyFull { name, args }) + } + TypeEnum::Function(func) => self.export_func_type(func), + TypeEnum::Variable(index, _) => { + // This ignores the type bound for now + self.module + .insert_term(model::Term::Var(model::LocalRef::Index(*index as _))) + } + TypeEnum::RowVar(rv) => self.export_row_var(rv.as_rv()), + TypeEnum::Sum(sum) => self.export_sum_type(sum), + } + } + + pub fn export_func_type(&mut self, t: &FuncTypeBase) -> model::TermId { + let inputs = self.export_type_row(t.input()); + let outputs = self.export_type_row(t.output()); + let extensions = self.export_ext_set(&t.extension_reqs); + self.module.insert_term(model::Term::FuncType { + inputs, + outputs, + extensions, + }) + } + + pub fn export_custom_type(&mut self, t: &CustomType) -> model::TermId { + let name = format!("{}.{}", t.extension(), t.name()); + let name = model::GlobalRef::Named(self.bump.alloc_str(&name)); + + let args = self + .bump + .alloc_slice_fill_iter(t.args().iter().map(|p| self.export_type_arg(p))); + let term = model::Term::ApplyFull { name, args }; + self.module.insert_term(term) + } + + pub fn export_type_arg(&mut self, t: &TypeArg) -> model::TermId { + match t { + TypeArg::Type { ty } => self.export_type(ty), + TypeArg::BoundedNat { n } => self.module.insert_term(model::Term::Nat(*n)), + TypeArg::String { arg } => self.module.insert_term(model::Term::Str(arg.into())), + TypeArg::Sequence { elems } => { + // For now we assume that the sequence is meant to be a list. + let items = self + .bump + .alloc_slice_fill_iter(elems.iter().map(|elem| self.export_type_arg(elem))); + self.module + .insert_term(model::Term::List { items, tail: None }) + } + TypeArg::Extensions { es } => self.export_ext_set(es), + TypeArg::Variable { v } => self.export_type_arg_var(v), + } + } + + pub fn export_type_arg_var(&mut self, var: &TypeArgVariable) -> model::TermId { + self.module + .insert_term(model::Term::Var(model::LocalRef::Index(var.index() as _))) + } + + pub fn export_row_var(&mut self, t: &RowVariable) -> model::TermId { + self.module + .insert_term(model::Term::Var(model::LocalRef::Index(t.0 as _))) + } + + pub fn export_sum_type(&mut self, t: &SumType) -> model::TermId { + match t { + SumType::Unit { size } => { + let items = self.bump.alloc_slice_fill_iter((0..*size).map(|_| { + self.module.insert_term(model::Term::List { + items: &[], + tail: None, + }) + })); + let list = model::Term::List { items, tail: None }; + let variants = self.module.insert_term(list); + self.module.insert_term(model::Term::Adt { variants }) + } + SumType::General { rows } => { + let items = self + .bump + .alloc_slice_fill_iter(rows.iter().map(|row| self.export_type_row(row))); + let list = model::Term::List { items, tail: None }; + let variants = { self.module.insert_term(list) }; + self.module.insert_term(model::Term::Adt { variants }) + } + } + } + + pub fn export_type_row(&mut self, t: &TypeRowBase) -> model::TermId { + let mut items = BumpVec::with_capacity_in(t.len(), self.bump); + items.extend(t.iter().map(|row| self.export_type(row))); + let items = items.into_bump_slice(); + self.module + .insert_term(model::Term::List { items, tail: None }) + } + + pub fn export_type_param(&mut self, t: &TypeParam) -> model::TermId { + match t { + // This ignores the type bound for now. + TypeParam::Type { .. } => self.module.insert_term(model::Term::Type), + // This ignores the type bound for now. + TypeParam::BoundedNat { .. } => self.module.insert_term(model::Term::NatType), + TypeParam::String => self.module.insert_term(model::Term::StrType), + TypeParam::List { param } => { + let item_type = self.export_type_param(param); + self.module.insert_term(model::Term::ListType { item_type }) + } + TypeParam::Tuple { params } => { + let items = self.bump.alloc_slice_fill_iter( + params.iter().map(|param| self.export_type_param(param)), + ); + let types = self + .module + .insert_term(model::Term::List { items, tail: None }); + self.module.insert_term(model::Term::ApplyFull { + name: model::GlobalRef::Named(TERM_PARAM_TUPLE), + args: self.bump.alloc_slice_copy(&[types]), + }) + } + TypeParam::Extensions => { + let term = model::Term::ExtSetType; + self.module.insert_term(term) + } + } + } + + pub fn export_ext_set(&mut self, t: &ExtensionSet) -> model::TermId { + // Extension sets with variables are encoded using a hack: a variable in the + // extension set is represented by converting its index into a string. + // Until we have a better representation for extension sets, we therefore + // need to try and parse each extension as a number to determine if it is + // a variable or an extension. + let mut extensions = Vec::new(); + let mut variables = Vec::new(); + + for ext in t.iter() { + if let Ok(index) = ext.parse::() { + variables.push({ + self.module + .insert_term(model::Term::Var(model::LocalRef::Index(index as _))) + }); + } else { + extensions.push(ext.to_smolstr()); + } + } + + // Extension sets in the model support at most one variable. This is a + // deliberate limitation so that extension sets behave like polymorphic rows. + // The type theory of such rows and how to apply them to model (co)effects + // is well understood. + // + // Extension sets in `hugr-core` at this point have no such restriction. + // However, it appears that so far we never actually use extension sets with + // multiple variables, except for extension sets that are generated through + // property testing. + let rest = match variables.as_slice() { + [] => None, + [var] => Some(*var), + _ => { + // TODO: We won't need this anymore once we have a core representation + + // that ensures that extension sets have at most one variable. + panic!("Extension set with multiple variables") + } + }; + + let mut extensions = BumpVec::with_capacity_in(extensions.len(), self.bump); + extensions.extend(t.iter().map(|ext| self.bump.alloc_str(ext) as &str)); + let extensions = extensions.into_bump_slice(); + + self.module + .insert_term(model::Term::ExtSet { extensions, rest }) + } +} + +#[cfg(test)] +mod test { + use rstest::{fixture, rstest}; + + use crate::{ + builder::{Dataflow, DataflowSubContainer}, + extension::prelude::QB_T, + std_extensions::arithmetic::float_types, + type_row, + types::Signature, + utils::test_quantum_extension::{self, cx_gate, h_gate}, + Hugr, + }; + + #[fixture] + fn test_simple_circuit() -> Hugr { + crate::builder::test::build_main( + Signature::new_endo(type_row![QB_T, QB_T]) + .with_extension_delta(test_quantum_extension::EXTENSION_ID) + .with_extension_delta(float_types::EXTENSION_ID) + .into(), + |mut f_build| { + let wires: Vec<_> = f_build.input_wires().collect(); + let mut linear = f_build.as_circuit(wires); + + assert_eq!(linear.n_wires(), 2); + + linear + .append(h_gate(), [0])? + .append(cx_gate(), [0, 1])? + .append(cx_gate(), [1, 0])?; + + let outs = linear.finish(); + f_build.finish_with_outputs(outs) + }, + ) + .unwrap() + } + + #[rstest] + #[case(test_simple_circuit())] + fn test_export(#[case] hugr: Hugr) { + use bumpalo::Bump; + let bump = Bump::new(); + let _model = super::export_hugr(&hugr, &bump); + // TODO check the model + } +} diff --git a/hugr-core/src/hugr/ident.rs b/hugr-core/src/hugr/ident.rs index f818e2af6..105131d49 100644 --- a/hugr-core/src/hugr/ident.rs +++ b/hugr-core/src/hugr/ident.rs @@ -39,6 +39,28 @@ impl IdentList { } } + /// Split off the last component of the path, returning the prefix and suffix. + /// + /// # Example + /// + /// ``` + /// # use hugr_core::hugr::IdentList; + /// assert_eq!( + /// IdentList::new("foo.bar.baz").unwrap().split_last(), + /// Some((IdentList::new_unchecked("foo.bar"), "baz".into())) + /// ); + /// assert_eq!( + /// IdentList::new("foo").unwrap().split_last(), + /// None + /// ); + /// ``` + pub fn split_last(&self) -> Option<(IdentList, SmolStr)> { + let (prefix, suffix) = self.0.rsplit_once('.')?; + let prefix = Self::new_unchecked(prefix); + let suffix = suffix.into(); + Some((prefix, suffix)) + } + /// Create a new [IdentList] *without* doing the well-formedness check. /// This is a backdoor to be used sparingly, as we rely upon callers to /// validate names themselves. In tests, instead the [crate::const_extension_ids] diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs new file mode 100644 index 000000000..a09803b2d --- /dev/null +++ b/hugr-core/src/import.rs @@ -0,0 +1,987 @@ +//! Importing HUGR graphs from their `hugr-model` representation. +//! +//! **Warning**: This module is still under development and is expected to change. +//! It is included in the library to allow for early experimentation, and for +//! the core and model to converge incrementally. +use crate::{ + export::OP_FUNC_CALL_INDIRECT, + extension::{ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError}, + hugr::{HugrMut, IdentList}, + ops::{ + AliasDecl, AliasDefn, Call, CallIndirect, Case, Conditional, FuncDecl, FuncDefn, Input, + LoadFunction, Module, OpType, OpaqueOp, Output, TailLoop, DFG, + }, + types::{ + type_param::TypeParam, type_row::TypeRowBase, CustomType, FuncTypeBase, MaybeRV, NoRV, + PolyFuncType, PolyFuncTypeBase, RowVariable, Signature, TypeArg, TypeBase, TypeBound, + TypeRow, + }, + Direction, Hugr, HugrView, Node, Port, +}; +use fxhash::FxHashMap; +use hugr_model::v0::{self as model, GlobalRef}; +use indexmap::IndexMap; +use itertools::Either; +use smol_str::{SmolStr, ToSmolStr}; +use thiserror::Error; + +/// Error during import. +#[derive(Debug, Clone, Error)] +pub enum ImportError { + /// The model contains a feature that is not supported by the importer yet. + /// Errors of this kind are expected to be removed as the model format and + /// the core HUGR representation converge. + #[error("currently unsupported: {0}")] + Unsupported(String), + /// The model contains implicit information that has not yet been inferred. + /// This includes wildcards and application of functions with implicit parameters. + #[error("uninferred implicit: {0}")] + Uninferred(String), + /// A signature mismatch was detected during import. + #[error("signature error: {0}")] + Signature(#[from] SignatureError), + /// The model is not well-formed. + #[error("validate error: {0}")] + Model(#[from] model::ModelError), +} + +/// Helper macro to create an `ImportError::Unsupported` error with a formatted message. +macro_rules! error_unsupported { + ($($e:expr),*) => { ImportError::Unsupported(format!($($e),*)) } +} + +/// Helper macro to create an `ImportError::Uninferred` error with a formatted message. +macro_rules! error_uninferred { + ($($e:expr),*) => { ImportError::Uninferred(format!($($e),*)) } +} + +/// Import a `hugr` module from its model representation. +pub fn import_hugr( + module: &model::Module, + extensions: &ExtensionRegistry, +) -> Result { + let names = Names::new(module)?; + + // TODO: Module should know about the number of edges, so that we can use a vector here. + // For now we use a hashmap, which will be slower. + let edge_ports = FxHashMap::default(); + + let mut ctx = Context { + module, + names, + hugr: Hugr::new(OpType::Module(Module {})), + link_ports: edge_ports, + static_edges: Vec::new(), + extensions, + nodes: FxHashMap::default(), + local_variables: IndexMap::default(), + }; + + ctx.import_root()?; + ctx.link_ports()?; + ctx.link_static_ports()?; + + Ok(ctx.hugr) +} + +struct Context<'a> { + /// The module being imported. + module: &'a model::Module<'a>, + + names: Names<'a>, + + /// The HUGR graph being constructed. + hugr: Hugr, + + /// The ports that are part of each link. This is used to connect the ports at the end of the + /// import process. + link_ports: FxHashMap, Vec<(Node, Port)>>, + + /// Pairs of nodes that should be connected by a static edge. + /// These are collected during the import process and connected at the end. + static_edges: Vec<(model::NodeId, model::NodeId)>, + + // /// The `(Node, Port)` pairs for each `PortId` in the module. + // imported_ports: Vec>, + /// The ambient extension registry to use for importing. + extensions: &'a ExtensionRegistry, + + /// A map from `NodeId` to the imported `Node`. + nodes: FxHashMap, + + /// The types of the local variables that are currently in scope. + local_variables: IndexMap<&'a str, model::TermId>, +} + +impl<'a> Context<'a> { + fn get_port_types(&mut self, ports: &[model::Port]) -> Result { + let types = ports + .iter() + .map(|port| match port.r#type { + Some(r#type) => self.import_type(r#type), + None => return Err(error_uninferred!("port type")), + }) + .collect::, _>>()?; + + Ok(types.into()) + } + + /// Get the signature of the node with the given `NodeId`, using the type information + /// attached to the node's ports in the module. + fn get_node_signature(&mut self, node: model::NodeId) -> Result { + let node = self.get_node(node)?; + let inputs = self.get_port_types(node.inputs)?; + let outputs = self.get_port_types(node.outputs)?; + // This creates a signature with empty extension set. + Ok(Signature::new(inputs, outputs)) + } + + /// Get the node with the given `NodeId`, or return an error if it does not exist. + #[inline] + fn get_node(&self, node_id: model::NodeId) -> Result<&'a model::Node<'a>, ImportError> { + self.module + .get_node(node_id) + .ok_or_else(|| model::ModelError::NodeNotFound(node_id).into()) + } + + /// Get the term with the given `TermId`, or return an error if it does not exist. + #[inline] + fn get_term(&self, term_id: model::TermId) -> Result<&'a model::Term<'a>, ImportError> { + self.module + .get_term(term_id) + .ok_or_else(|| model::ModelError::TermNotFound(term_id).into()) + } + + /// Get the region with the given `RegionId`, or return an error if it does not exist. + #[inline] + fn get_region(&self, region_id: model::RegionId) -> Result<&'a model::Region<'a>, ImportError> { + self.module + .get_region(region_id) + .ok_or_else(|| model::ModelError::RegionNotFound(region_id).into()) + } + + /// Looks up a [`LocalRef`] within the current scope and returns its index and type. + fn resolve_local_ref( + &self, + local_ref: &model::LocalRef, + ) -> Result<(usize, model::TermId), ImportError> { + let term = match local_ref { + model::LocalRef::Index(index) => self + .local_variables + .get_index(*index as usize) + .map(|(_, term)| (*index as usize, *term)), + model::LocalRef::Named(name) => self + .local_variables + .get_full(name) + .map(|(index, _, term)| (index, *term)), + }; + + term.ok_or_else(|| model::ModelError::InvalidLocal(local_ref.to_string()).into()) + } + + fn make_node( + &mut self, + node_id: model::NodeId, + op: OpType, + parent: Node, + ) -> Result { + let node = self.hugr.add_node_with_parent(parent, op); + self.nodes.insert(node_id, node); + + let node_data = self.get_node(node_id)?; + self.record_links(node, Direction::Incoming, node_data.inputs); + self.record_links(node, Direction::Outgoing, node_data.outputs); + Ok(node) + } + + /// Associate links with the ports of the given node in the given direction. + fn record_links(&mut self, node: Node, direction: Direction, ports: &'a [model::Port]) { + let optype = self.hugr.get_optype(node); + let port_count = optype.port_count(direction); + assert!(ports.len() <= port_count); + + for (model_port, port) in ports.iter().zip(self.hugr.node_ports(node, direction)) { + self.link_ports + .entry(model_port.link.clone()) + .or_default() + .push((node, port)); + } + } + + fn make_input_node( + &mut self, + parent: Node, + ports: &'a [model::Port], + ) -> Result { + let types = self.get_port_types(ports)?; + let node = self + .hugr + .add_node_with_parent(parent, OpType::Input(Input { types })); + self.record_links(node, Direction::Outgoing, ports); + Ok(node) + } + + fn make_output_node( + &mut self, + parent: Node, + ports: &'a [model::Port], + ) -> Result { + let types = self.get_port_types(ports)?; + let node = self + .hugr + .add_node_with_parent(parent, OpType::Output(Output { types })); + self.record_links(node, Direction::Incoming, ports); + Ok(node) + } + + /// Link up the ports in the hugr graph, according to the connectivity information that + /// has been gathered in the `link_ports` map. + fn link_ports(&mut self) -> Result<(), ImportError> { + // For each edge, we group the ports by their direction. We reuse the `inputs` and + // `outputs` vectors to avoid unnecessary allocations. + let mut inputs = Vec::new(); + let mut outputs = Vec::new(); + + for (link_id, link_ports) in std::mem::take(&mut self.link_ports) { + // Skip the edge if it doesn't have any ports. + if link_ports.is_empty() { + continue; + } + + for (node, port) in link_ports { + match port.as_directed() { + Either::Left(input) => inputs.push((node, input)), + Either::Right(output) => outputs.push((node, output)), + } + } + + if inputs.is_empty() || outputs.is_empty() { + return Err(error_unsupported!( + "link {:?} is missing either an input or an output port", + link_id + )); + } + + // We connect the first output to all the inputs, and the first input to all the outputs + // (except the first one, which we already connected to the first input). This should + // result in the hugr having a (hyper)edge that connects all the ports. + // There should be a better way to do this. + for (node, port) in inputs.iter() { + self.hugr.connect(outputs[0].0, outputs[0].1, *node, *port); + } + + for (node, port) in outputs.iter().skip(1) { + self.hugr.connect(*node, *port, inputs[0].0, inputs[0].1); + } + + inputs.clear(); + outputs.clear(); + } + + Ok(()) + } + + fn link_static_ports(&mut self) -> Result<(), ImportError> { + for (src_id, dst_id) in std::mem::take(&mut self.static_edges) { + // None of these lookups should fail given how we constructed `static_edges`. + let src = self.nodes[&src_id]; + let dst = self.nodes[&dst_id]; + let src_port = self.hugr.get_optype(src).static_output_port().unwrap(); + let dst_port = self.hugr.get_optype(dst).static_input_port().unwrap(); + self.hugr.connect(src, src_port, dst, dst_port); + } + + Ok(()) + } + + fn with_local_socpe( + &mut self, + f: impl FnOnce(&mut Self) -> Result, + ) -> Result { + let previous = std::mem::take(&mut self.local_variables); + let result = f(self); + self.local_variables = previous; + result + } + + fn resolve_global_ref( + &self, + global_ref: &model::GlobalRef, + ) -> Result { + match global_ref { + model::GlobalRef::Direct(node_id) => Ok(*node_id), + model::GlobalRef::Named(name) => { + let item = self + .names + .items + .get(name) + .ok_or_else(|| model::ModelError::InvalidGlobal(global_ref.to_string()))?; + + match item { + NamedItem::FuncDecl(node) => Ok(*node), + NamedItem::FuncDefn(node) => Ok(*node), + } + } + } + } + + fn get_global_name(&self, global_ref: model::GlobalRef<'a>) -> Result<&'a str, ImportError> { + match global_ref { + model::GlobalRef::Direct(node_id) => { + let node_data = self.get_node(node_id)?; + + let name = match node_data.operation { + model::Operation::DefineFunc { decl } => decl.name, + model::Operation::DeclareFunc { decl } => decl.name, + model::Operation::DefineAlias { decl, .. } => decl.name, + model::Operation::DeclareAlias { decl } => decl.name, + _ => { + return Err(model::ModelError::InvalidGlobal(global_ref.to_string()).into()); + } + }; + + Ok(name) + } + model::GlobalRef::Named(name) => Ok(name), + } + } + + fn get_func_signature( + &mut self, + func_node: model::NodeId, + ) -> Result { + let decl = match self.get_node(func_node)?.operation { + model::Operation::DefineFunc { decl } => decl, + model::Operation::DeclareFunc { decl } => decl, + _ => return Err(model::ModelError::UnexpectedOperation(func_node).into()), + }; + + self.import_poly_func_type(*decl, |_, signature| Ok(signature)) + } + + /// Import the root region of the module. + fn import_root(&mut self) -> Result<(), ImportError> { + let region_data = self.get_region(self.module.root)?; + + for node in region_data.children { + self.import_node(*node, self.hugr.root())?; + } + + Ok(()) + } + + fn import_node(&mut self, node_id: model::NodeId, parent: Node) -> Result { + let node_data = self.get_node(node_id)?; + + match node_data.operation { + model::Operation::Dfg => { + let signature = self.get_node_signature(node_id)?; + let optype = OpType::DFG(DFG { signature }); + let node = self.make_node(node_id, optype, parent)?; + + let [region] = node_data.regions else { + return Err(model::ModelError::InvalidRegions(node_id).into()); + }; + + self.import_dfg_region(node_id, *region, node)?; + + Ok(node) + } + + // TODO: Implement support for importing control flow graphs. + model::Operation::Cfg => Err(error_unsupported!("`cfg` nodes")), + + model::Operation::Block => Err(model::ModelError::UnexpectedOperation(node_id).into()), + + model::Operation::DefineFunc { decl } => { + self.import_poly_func_type(*decl, |ctx, signature| { + let optype = OpType::FuncDefn(FuncDefn { + name: decl.name.to_string(), + signature, + }); + + let node = ctx.make_node(node_id, optype, parent)?; + + let [region] = node_data.regions else { + return Err(model::ModelError::InvalidRegions(node_id).into()); + }; + + ctx.import_dfg_region(node_id, *region, node)?; + + Ok(node) + }) + } + + model::Operation::DeclareFunc { decl } => { + self.import_poly_func_type(*decl, |ctx, signature| { + let optype = OpType::FuncDecl(FuncDecl { + name: decl.name.to_string(), + signature, + }); + + let node = ctx.make_node(node_id, optype, parent)?; + + Ok(node) + }) + } + + model::Operation::CallFunc { func } => { + let model::Term::ApplyFull { name, args } = self.get_term(func)? else { + return Err(model::ModelError::TypeError(func).into()); + }; + + let func_node = self.resolve_global_ref(name)?; + let func_sig = self.get_func_signature(func_node)?; + + let type_args = args + .iter() + .map(|term| self.import_type_arg(*term)) + .collect::, _>>()?; + + self.static_edges.push((func_node, node_id)); + let optype = OpType::Call(Call::try_new(func_sig, type_args, &self.extensions)?); + + self.make_node(node_id, optype, parent) + } + + model::Operation::LoadFunc { func } => { + let model::Term::ApplyFull { name, args } = self.get_term(func)? else { + return Err(model::ModelError::TypeError(func).into()); + }; + + let func_node = self.resolve_global_ref(name)?; + let func_sig = self.get_func_signature(func_node)?; + + let type_args = args + .iter() + .map(|term| self.import_type_arg(*term)) + .collect::, _>>()?; + + self.static_edges.push((func_node, node_id)); + + let optype = OpType::LoadFunction(LoadFunction::try_new( + func_sig, + type_args, + &self.extensions, + )?); + + self.make_node(node_id, optype, parent) + } + + model::Operation::TailLoop { + inputs, + outputs, + rest, + extensions, + } => { + let just_inputs = self.import_type_row(inputs)?; + let just_outputs = self.import_type_row(outputs)?; + let rest = self.import_type_row(rest)?; + let extension_delta = self.import_extension_set(extensions)?; + + let optype = OpType::TailLoop(TailLoop { + just_inputs, + just_outputs, + rest, + extension_delta, + }); + + let node = self.make_node(node_id, optype, parent)?; + + let [region] = node_data.regions else { + return Err(model::ModelError::InvalidRegions(node_id).into()); + }; + + self.import_dfg_region(node_id, *region, node)?; + Ok(node) + } + + model::Operation::Conditional { + cases, + context, + outputs, + extensions, + } => { + let sum_rows: Vec<_> = self.import_type_rows(cases)?; + let other_inputs = self.import_type_row(context)?; + let outputs = self.import_type_row(outputs)?; + let extension_delta = self.import_extension_set(extensions)?; + + let optype = OpType::Conditional(Conditional { + sum_rows, + other_inputs, + outputs, + extension_delta, + }); + + let node = self.make_node(node_id, optype, parent)?; + + for region in node_data.regions { + let region_data = self.get_region(*region)?; + + let source_types = self.get_port_types(region_data.sources)?; + let target_types = self.get_port_types(region_data.targets)?; + let signature = FuncTypeBase::new(source_types, target_types); + + let case_node = self + .hugr + .add_node_with_parent(node, OpType::Case(Case { signature })); + + self.import_dfg_region(node_id, *region, case_node)?; + } + + Ok(node) + } + + model::Operation::CustomFull { + name: GlobalRef::Named(name), + } if name == OP_FUNC_CALL_INDIRECT => { + let signature = self.get_node_signature(node_id)?; + let optype = OpType::CallIndirect(CallIndirect { signature }); + self.make_node(node_id, optype, parent) + } + + model::Operation::CustomFull { name } => { + let signature = self.get_node_signature(node_id)?; + let args = node_data + .params + .iter() + .map(|param| self.import_type_arg(*param)) + .collect::, _>>()?; + + let name = match name { + GlobalRef::Direct(_) => { + return Err(error_unsupported!( + "custom operation with direct reference to declaring node" + )) + } + GlobalRef::Named(name) => name, + }; + + let (extension, name) = self.import_custom_name(name)?; + + let optype = OpType::OpaqueOp(OpaqueOp::new( + extension, + name, + String::default(), + args, + signature, + )); + + let node = self.make_node(node_id, optype, parent)?; + + match node_data.regions { + [] => {} + [region] => self.import_dfg_region(node_id, *region, node)?, + _ => return Err(error_unsupported!("multiple regions in custom operation")), + } + + Ok(node) + } + + model::Operation::Custom { .. } => Err(error_unsupported!( + "custom operation with implicit parameters" + )), + + model::Operation::DefineAlias { decl, value } => self.with_local_socpe(|ctx| { + if !decl.params.is_empty() { + return Err(error_unsupported!( + "parameters or constraints in alias definition" + )); + } + + let optype = OpType::AliasDefn(AliasDefn { + name: decl.name.to_smolstr(), + definition: ctx.import_type(value)?, + }); + + ctx.make_node(node_id, optype, parent) + }), + + model::Operation::DeclareAlias { decl } => self.with_local_socpe(|ctx| { + if !decl.params.is_empty() { + return Err(error_unsupported!( + "parameters or constraints in alias declaration" + )); + } + + let optype = OpType::AliasDecl(AliasDecl { + name: decl.name.to_smolstr(), + bound: TypeBound::Copyable, + }); + + ctx.make_node(node_id, optype, parent) + }), + } + } + + fn import_dfg_region( + &mut self, + node_id: model::NodeId, + region: model::RegionId, + node: Node, + ) -> Result<(), ImportError> { + let region_data = self.get_region(region)?; + + if !matches!(region_data.kind, model::RegionKind::DataFlow) { + return Err(model::ModelError::InvalidRegions(node_id).into()); + } + + self.make_input_node(node, region_data.sources)?; + self.make_output_node(node, region_data.targets)?; + + for child in region_data.children { + self.import_node(*child, node)?; + } + + Ok(()) + } + + fn import_poly_func_type( + &mut self, + decl: model::FuncDecl<'a>, + in_scope: impl FnOnce(&mut Self, PolyFuncTypeBase) -> Result, + ) -> Result { + self.with_local_socpe(|ctx| { + let mut imported_params = Vec::with_capacity(decl.params.len()); + + for param in decl.params { + // TODO: `PolyFuncType` should be able to handle constraints + // and distinguish between implicit and explicit parameters. + match param { + model::Param::Implicit { name, r#type } => { + imported_params.push(ctx.import_type_param(*r#type)?); + ctx.local_variables.insert(name, *r#type); + } + model::Param::Explicit { name, r#type } => { + imported_params.push(ctx.import_type_param(*r#type)?); + ctx.local_variables.insert(name, *r#type); + } + model::Param::Constraint { constraint: _ } => { + return Err(error_unsupported!("constraints")); + } + } + } + + let body = ctx.import_func_type::(decl.func)?; + in_scope(ctx, PolyFuncTypeBase::new(imported_params, body)) + }) + } + + /// Import a [`TypeParam`] from a term that represents a static type. + fn import_type_param(&mut self, term_id: model::TermId) -> Result { + match self.get_term(term_id)? { + model::Term::Wildcard => Err(error_uninferred!("wildcard")), + + model::Term::Type => { + // As part of the migration from `TypeBound`s to constraints, we pretend that all + // `TypeBound`s are copyable. + Ok(TypeParam::Type { + b: TypeBound::Copyable, + }) + } + + model::Term::StaticType => Err(error_unsupported!("`type` as `TypeParam`")), + model::Term::Constraint => Err(error_unsupported!("`constraint` as `TypeParam`")), + model::Term::Var(_) => Err(error_unsupported!("type variable as `TypeParam`")), + model::Term::Apply { .. } => Err(error_unsupported!("custom type as `TypeParam`")), + model::Term::ApplyFull { .. } => Err(error_unsupported!("custom type as `TypeParam`")), + + model::Term::Quote { .. } => Err(error_unsupported!("`(quote ...)` as `TypeParam`")), + model::Term::FuncType { .. } => Err(error_unsupported!("`(fn ...)` as `TypeParam`")), + + model::Term::ListType { item_type } => { + let param = Box::new(self.import_type_param(*item_type)?); + Ok(TypeParam::List { param }) + } + + model::Term::StrType => Ok(TypeParam::String), + model::Term::ExtSetType => Ok(TypeParam::Extensions), + + // TODO: What do we do about the bounds on naturals? + model::Term::NatType => todo!(), + + model::Term::Nat(_) + | model::Term::Str(_) + | model::Term::List { .. } + | model::Term::ExtSet { .. } + | model::Term::Adt { .. } + | model::Term::Control { .. } => Err(model::ModelError::TypeError(term_id).into()), + + model::Term::ControlType => { + Err(error_unsupported!("type of control types as `TypeArg`")) + } + } + } + + /// Import a `TypeArg` froma term that represents a static type or value. + fn import_type_arg(&mut self, term_id: model::TermId) -> Result { + match self.get_term(term_id)? { + model::Term::Wildcard => Err(error_uninferred!("wildcard")), + model::Term::Apply { .. } => { + Err(error_uninferred!("application with implicit parameters")) + } + + model::Term::Var(var) => { + let (index, var_type) = self.resolve_local_ref(var)?; + let decl = self.import_type_param(var_type)?; + Ok(TypeArg::new_var_use(index, decl)) + } + + model::Term::List { .. } => { + let elems = self + .import_closed_list(term_id)? + .iter() + .map(|item| self.import_type_arg(*item)) + .collect::>()?; + + Ok(TypeArg::Sequence { elems }) + } + + model::Term::Str(value) => Ok(TypeArg::String { + arg: value.clone().into(), + }), + + model::Term::Quote { .. } => Ok(TypeArg::Type { + ty: self.import_type(term_id)?, + }), + model::Term::Nat(value) => Ok(TypeArg::BoundedNat { n: *value }), + model::Term::ExtSet { .. } => Ok(TypeArg::Extensions { + es: self.import_extension_set(term_id)?, + }), + + model::Term::StrType => Err(error_unsupported!("`str` as `TypeArg`")), + model::Term::NatType => Err(error_unsupported!("`nat` as `TypeArg`")), + model::Term::ListType { .. } => Err(error_unsupported!("`(list ...)` as `TypeArg`")), + model::Term::ExtSetType => Err(error_unsupported!("`ext-set` as `TypeArg`")), + model::Term::Type => Err(error_unsupported!("`type` as `TypeArg`")), + model::Term::ApplyFull { .. } => Err(error_unsupported!("custom types as `TypeArg`")), + model::Term::Constraint => Err(error_unsupported!("`constraint` as `TypeArg`")), + model::Term::StaticType => Err(error_unsupported!("`static` as `TypeArg`")), + model::Term::ControlType => Err(error_unsupported!("`ctrl` as `TypeArg`")), + + model::Term::FuncType { .. } + | model::Term::Adt { .. } + | model::Term::Control { .. } => Err(model::ModelError::TypeError(term_id).into()), + } + } + + fn import_extension_set( + &mut self, + term_id: model::TermId, + ) -> Result { + match self.get_term(term_id)? { + model::Term::Wildcard => Err(error_uninferred!("wildcard")), + + model::Term::Var(var) => { + let mut es = ExtensionSet::new(); + let (index, _) = self.resolve_local_ref(var)?; + es.insert_type_var(index); + Ok(es) + } + + model::Term::ExtSet { extensions, rest } => { + let mut es = match rest { + Some(rest) => self.import_extension_set(*rest)?, + None => ExtensionSet::new(), + }; + + for ext in extensions.iter() { + let ext_ident = IdentList::new(*ext) + .map_err(|_| model::ModelError::MalformedName(ext.to_smolstr()))?; + es.insert(&ext_ident); + } + + Ok(es) + } + _ => Err(model::ModelError::TypeError(term_id).into()), + } + } + + /// Import a `Type` from a term that represents a runtime type. + fn import_type( + &mut self, + term_id: model::TermId, + ) -> Result, ImportError> { + match self.get_term(term_id)? { + model::Term::Wildcard => Err(error_uninferred!("wildcard")), + model::Term::Apply { .. } => { + Err(error_uninferred!("application with implicit parameters")) + } + + model::Term::ApplyFull { name, args } => { + let args = args + .iter() + .map(|arg| self.import_type_arg(*arg)) + .collect::, _>>()?; + + let name = self.get_global_name(*name)?; + let (extension, id) = self.import_custom_name(&name)?; + + Ok(TypeBase::new_extension(CustomType::new( + id, + args, + extension, + // As part of the migration from `TypeBound`s to constraints, we pretend that all + // `TypeBound`s are copyable. + TypeBound::Copyable, + ))) + } + + model::Term::Var(var) => { + // We pretend that all `TypeBound`s are copyable. + let (index, _) = self.resolve_local_ref(var)?; + Ok(TypeBase::new_var_use(index, TypeBound::Copyable)) + } + + model::Term::FuncType { .. } => { + let func_type = self.import_func_type::(term_id)?; + Ok(TypeBase::new_function(func_type)) + } + + model::Term::Adt { variants } => { + let variants = self.import_closed_list(*variants)?; + let variants = variants + .iter() + .map(|variant| self.import_type_row::(*variant)) + .collect::, _>>()?; + Ok(TypeBase::new_sum(variants)) + } + + // The following terms are not runtime types, but the core `Type` only contains runtime types. + // We therefore report a type error here. + model::Term::ListType { .. } + | model::Term::StrType + | model::Term::NatType + | model::Term::ExtSetType + | model::Term::StaticType + | model::Term::Type + | model::Term::Constraint + | model::Term::Quote { .. } + | model::Term::Str(_) + | model::Term::ExtSet { .. } + | model::Term::List { .. } + | model::Term::Control { .. } + | model::Term::ControlType + | model::Term::Nat(_) => Err(model::ModelError::TypeError(term_id).into()), + } + } + + fn import_func_type( + &mut self, + term_id: model::TermId, + ) -> Result, ImportError> { + let term = self.get_term(term_id)?; + + let model::Term::FuncType { + inputs, + outputs, + extensions: _, + } = term + else { + return Err(model::ModelError::TypeError(term_id).into()); + }; + + let inputs = self.import_type_row::(*inputs)?; + let outputs = self.import_type_row::(*outputs)?; + // TODO: extensions + Ok(FuncTypeBase::new(inputs, outputs)) + } + + fn import_closed_list( + &mut self, + mut term_id: model::TermId, + ) -> Result, ImportError> { + let mut list_items = Vec::new(); + + loop { + match self.get_term(term_id)? { + model::Term::Var(_) => return Err(error_unsupported!("open lists")), + model::Term::List { items, tail } => { + list_items.extend(items.iter()); + + match tail { + Some(tail) => term_id = *tail, + None => break, + } + } + _ => return Err(model::ModelError::TypeError(term_id).into()), + } + } + + Ok(list_items) + } + + fn import_type_row( + &mut self, + term_id: model::TermId, + ) -> Result, ImportError> { + let items = self + .import_closed_list(term_id)? + .iter() + .map(|item| self.import_type(*item)) + .collect::, _>>()?; + + Ok(items.into()) + } + + fn import_type_rows( + &mut self, + term_id: model::TermId, + ) -> Result>, ImportError> { + let items = self + .import_closed_list(term_id)? + .iter() + .map(|item| self.import_type_row(*item)) + .collect::, _>>()?; + Ok(items) + } + + fn import_custom_name(&self, symbol: &'a str) -> Result<(ExtensionId, SmolStr), ImportError> { + let qualified_name = ExtensionId::new(symbol) + .map_err(|_| model::ModelError::MalformedName(symbol.to_smolstr()))?; + + let (extension, id) = qualified_name + .split_last() + .ok_or_else(|| model::ModelError::MalformedName(symbol.to_smolstr()))?; + + Ok((extension, id.into())) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +enum NamedItem { + FuncDecl(model::NodeId), + FuncDefn(model::NodeId), +} + +struct Names<'a> { + items: FxHashMap<&'a str, NamedItem>, +} + +impl<'a> Names<'a> { + pub fn new(module: &model::Module<'a>) -> Result { + let mut items = FxHashMap::default(); + + for (node_id, node_data) in module.nodes.iter().enumerate() { + let node_id = model::NodeId(node_id as _); + + let item = match node_data.operation { + model::Operation::DefineFunc { decl } => { + Some((decl.name, NamedItem::FuncDecl(node_id))) + } + model::Operation::DeclareFunc { decl } => { + Some((decl.name, NamedItem::FuncDefn(node_id))) + } + _ => None, + }; + + if let Some((name, item)) = item { + // TODO: Deal with duplicates + items.insert(name, item); + } + } + + Ok(Self { items }) + } +} diff --git a/hugr-core/src/lib.rs b/hugr-core/src/lib.rs index 6bd2a262d..1221c5572 100644 --- a/hugr-core/src/lib.rs +++ b/hugr-core/src/lib.rs @@ -10,8 +10,10 @@ pub mod builder; pub mod core; +pub mod export; pub mod extension; pub mod hugr; +pub mod import; pub mod macros; pub mod ops; pub mod std_extensions; diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index 039c1b6e0..5afab2294 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -8,7 +8,7 @@ mod serialize; mod signature; pub mod type_param; pub mod type_row; -use row_var::MaybeRV; +pub(crate) use row_var::MaybeRV; pub use row_var::{NoRV, RowVariable}; pub use crate::ops::constant::{ConstTypeError, CustomCheckFailure}; @@ -16,7 +16,9 @@ use crate::types::type_param::check_type_arg; use crate::utils::display_list_with_separator; pub use check::SumTypeError; pub use custom::CustomType; +pub(crate) use poly_func::PolyFuncTypeBase; pub use poly_func::{PolyFuncType, PolyFuncTypeRV}; +pub(crate) use signature::FuncTypeBase; pub use signature::{FuncValueType, Signature}; use smol_str::SmolStr; pub use type_param::TypeArg; diff --git a/hugr-core/tests/fixtures/model-1.edn b/hugr-core/tests/fixtures/model-1.edn new file mode 100644 index 000000000..b0ca37d8b --- /dev/null +++ b/hugr-core/tests/fixtures/model-1.edn @@ -0,0 +1,38 @@ +(hugr 0) + +; NOTE: The @ in front of the names indicates that their implicit arguments are +; explicitly given as well. This is necessary everywhere at the moment +; since we do not have inference for implicit arguments yet. + +; NOTE: Every port in this file has been annotated with its type. This is quite +; verbose, but it is necessary currently until we have inference. + +(define-alias local.int type (@ arithmetic.int.types.int)) + +(define-func example.add + [(@ arithmetic.int.types.int) (@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext) + (dfg + [(%0 (@ arithmetic.int.types.int)) (%1 (@ arithmetic.int.types.int))] + [(%2 (@ arithmetic.int.types.int))] + ((@ arithmetic.int.iadd) [(%0 (@ arithmetic.int.types.int)) (%1 (@ arithmetic.int.types.int))] [(%2 (@ arithmetic.int.types.int))]))) + +(declare-func example.callee + [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext) + (meta doc.title "Callee") + (meta doc.description "This is a function declaration.")) + +(define-func example.caller + [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext) + (meta doc.title "Caller") + (meta doc.description "This defines a function that calls the function which we declared earlier.") + (dfg + [(%3 (@ arithmetic.int.types.int))] + [(%4 (@ arithmetic.int.types.int))] + (call (@ example.callee) [(%3 (@ arithmetic.int.types.int))] [(%4 (@ arithmetic.int.types.int))]))) + +(define-func example.swap + ; The types of the values to be swapped are passed as implicit parameters. + (forall ?a type) + (forall ?b type) + [?a ?b] [?b ?a] (ext) + (dfg [(%a ?a) (%b ?b)] [(%b ?b) (%a ?a)])) diff --git a/hugr-core/tests/model.rs b/hugr-core/tests/model.rs new file mode 100644 index 000000000..2cefca464 --- /dev/null +++ b/hugr-core/tests/model.rs @@ -0,0 +1,13 @@ +use hugr::std_extensions::std_reg; +use hugr_core::{export::export_hugr, import::import_hugr}; +use hugr_model::v0 as model; + +#[test] +pub fn test_import_export() { + let bump = bumpalo::Bump::new(); + let parsed_module = model::text::parse(include_str!("fixtures/model-1.edn"), &bump).unwrap(); + let extensions = std_reg(); + let hugr = import_hugr(&parsed_module.module, &extensions).unwrap(); + let roundtrip = export_hugr(&hugr, &bump); + panic!("{}:", model::text::print_to_string(&roundtrip, 80).unwrap()); +} diff --git a/hugr-model/Cargo.toml b/hugr-model/Cargo.toml new file mode 100644 index 000000000..5c4f79a8c --- /dev/null +++ b/hugr-model/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "hugr-model" +version = "0.1.0" +rust-version.workspace = true +edition.workspace = true +homepage.workspace = true +repository.workspace = true +license.workspace = true + +[dependencies] +beef = "0.5.2" +bumpalo = { workspace = true } +fxhash.workspace = true +indexmap.workspace = true +pest = "2.7.12" +pest_derive = "2.7.12" +pretty = "0.12.3" +smol_str = { workspace = true, features = ["serde"] } +thiserror.workspace = true +tinyvec.workspace = true + +[lints] +workspace = true diff --git a/hugr-model/src/lib.rs b/hugr-model/src/lib.rs new file mode 100644 index 000000000..6c161ca37 --- /dev/null +++ b/hugr-model/src/lib.rs @@ -0,0 +1,5 @@ +//! The data model of the HUGR intermediate representation. +//! This crate defines data structures that capture the structure of a HUGR graph and +//! all its associated information in a form that can be stored on disk. The data structures +//! are not designed for efficient traversal or modification, but for simplicity and serialization. +pub mod v0; diff --git a/hugr-model/src/v0/mod.rs b/hugr-model/src/v0/mod.rs new file mode 100644 index 000000000..43bb4ade7 --- /dev/null +++ b/hugr-model/src/v0/mod.rs @@ -0,0 +1,652 @@ +//! Version 0 (unstable). +//! +//! **Warning**: This module is still under development and is expected to change. +//! It is included in the library to allow for early experimentation, and for +//! the core and model to converge incrementally. +//! +//! +//! # Terms +//! +//! Terms form a meta language that is used to describe types, parameters and metadata that +//! are known statically. To allow types to be parameterized by values, types and values +//! are treated uniformly as terms, enabling a restricted form of dependent typing. +//! The type system is extensible and can be used to declaratively encode the desired shape +//! of operation parameters and metadata. Type constraints can be used to express more complex +//! validation rules. +//! +//! # Tabling +//! +//! Instead of directly nesting structures, we store them in tables and refer to them +//! by their index in the table. This allows us to attach additional data to the structures +//! without changing the data structure itself. This can be used, for example, to keep track +//! of metadata that has been parsed from its generic representation as a term into a more +//! specific in-memory representation. +//! +//! The tabling is also used for deduplication of terms. In practice, many terms will share +//! the same subterms, and we can save memory and validation time by storing them only once. +//! However we allow non-deduplicated terms for cases in which terms carry additional identity +//! over just their structure. For instance, structurally identical terms could originate +//! from different locations in a text file and therefore should be treated differently when +//! locating type errors. +//! +//! # Plain Data +//! +//! All types in the hugr model are plain data. This means that they can be serialized and +//! deserialized without loss of information. This is important for the model to be able to +//! serve as a stable interchange format between different tools and versions of the library. +//! +//! # Arena Allocation +//! +//! Since we intend to use the model data structures as an intermediary to convert between +//! different representations (such as text, binary or in-memory), we use arena allocation +//! to efficiently allocate and free the parts of the data structure that isn't directly stored +//! in the tables. For that purpose, we use the `'a` lifetime parameter to indicate the +//! lifetime of the arena. +use smol_str::SmolStr; +use thiserror::Error; + +pub mod text; + +macro_rules! define_index { + ($(#[$meta:meta])* $vis:vis struct $name:ident;) => { + #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] + #[repr(transparent)] + $(#[$meta])* + $vis struct $name(pub u32); + + impl $name { + /// Create a new index. + /// + /// # Panics + /// + /// Panics if the index is 2^32 or larger. + pub fn new(index: usize) -> Self { + assert!(index < u32::MAX as usize, "index out of bounds"); + Self(index as u32) + } + + /// Returns the index as a `usize` to conveniently use it as a slice index. + #[inline] + pub fn index(self) -> usize { + self.0 as usize + } + + /// Convert a slice of this index type into a slice of `u32`s. + pub fn unwrap_slice(slice: &[Self]) -> &[u32] { + // SAFETY: This type is just a newtype around `u32`. + unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const u32, slice.len()) } + } + + /// Convert a slice of `u32`s into a slice of this index type. + pub fn wrap_slice(slice: &[u32]) -> &[Self] { + // SAFETY: This type is just a newtype around `u32`. + unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const Self, slice.len()) } + } + } + }; +} + +define_index! { + /// Index of a node in a hugr graph. + pub struct NodeId; +} + +define_index! { + /// Index of a link in a hugr graph. + pub struct LinkId; +} + +define_index! { + /// Index of a region in a hugr graph. + pub struct RegionId; +} + +define_index! { + /// Index of a term in a hugr graph. + pub struct TermId; +} + +/// A module consisting of a hugr graph together with terms. +#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)] +pub struct Module<'a> { + /// The id of the root region. + pub root: RegionId, + /// Table of [`Node`]s. + pub nodes: Vec>, + /// Table of [`Region`]s. + pub region: Vec>, + /// Table of [`Term`]s. + pub terms: Vec>, +} + +impl<'a> Module<'a> { + /// Return the node data for a given node id. + #[inline] + pub fn get_node(&self, node_id: NodeId) -> Option<&Node<'a>> { + self.nodes.get(node_id.0 as usize) + } + + /// Insert a new node into the module and return its id. + pub fn insert_node(&mut self, node: Node<'a>) -> NodeId { + let id = NodeId(self.nodes.len() as u32); + self.nodes.push(node); + id + } + + /// Return the term data for a given term id. + #[inline] + pub fn get_term(&self, term_id: TermId) -> Option<&Term<'a>> { + self.terms.get(term_id.0 as usize) + } + + /// Insert a new term into the module and return its id. + pub fn insert_term(&mut self, term: Term<'a>) -> TermId { + let id = TermId(self.terms.len() as u32); + self.terms.push(term); + id + } + + /// Return the region data for a given region id. + #[inline] + pub fn get_region(&self, region_id: RegionId) -> Option<&Region<'a>> { + self.region.get(region_id.0 as usize) + } + + /// Insert a new region into the module and return its id. + pub fn insert_region(&mut self, region: Region<'a>) -> RegionId { + let id = RegionId(self.region.len() as u32); + self.region.push(region); + id + } +} + +/// Nodes in the hugr graph. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Node<'a> { + /// The operation that the node performs. + pub operation: Operation<'a>, + /// The input ports of the node. + pub inputs: &'a [Port<'a>], + /// The output ports of the node. + pub outputs: &'a [Port<'a>], + /// The parameters of the node. + pub params: &'a [TermId], + /// The regions of the node. + pub regions: &'a [RegionId], + /// The meta information attached to the node. + pub meta: &'a [MetaItem<'a>], +} + +/// Operations that nodes can perform. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum Operation<'a> { + /// Data flow graphs. + Dfg, + /// Control flow graphs. + Cfg, + /// Basic blocks. + Block, + /// Function definitions. + DefineFunc { + /// The declaration of the function to be defined. + decl: &'a FuncDecl<'a>, + }, + /// Function declarations. + DeclareFunc { + /// The function to be declared. + decl: &'a FuncDecl<'a>, + }, + /// Function calls. + CallFunc { + /// The function to be called. + func: TermId, + }, + /// Function constants. + LoadFunc { + /// The function to be loaded. + func: TermId, + }, + /// Custom operation. + /// + /// The implicit parameters of the operation are left out. + Custom { + /// The name of the custom operation. + name: GlobalRef<'a>, + }, + /// Custom operation. + /// + /// The implicit parameters of the operation are included. + CustomFull { + /// The name of the custom operation. + name: GlobalRef<'a>, + }, + /// Alias definitions. + DefineAlias { + /// The declaration of the alias to be defined. + decl: &'a AliasDecl<'a>, + /// The value of the alias. + value: TermId, + }, + + /// Alias declarations. + DeclareAlias { + /// The alias to be declared. + decl: &'a AliasDecl<'a>, + }, + + /// Tail controlled loop. + /// Nodes with this operation contain a dataflow graph that is executed in a loop. + /// The loop body is executed at least once, producing a result that indicates whether + /// to continue the loop or return the result. + /// + /// # Port Types + /// + /// - **Inputs**: `inputs` + `rest` + /// - **Outputs**: `outputs` + `rest` + /// - **Sources**: `inputs` + `rest` + /// - **Targets**: `(adt [inputs outputs])` + `rest` + TailLoop { + // TODO: These can be determined by the port types? + /// Types of the values that are passed as inputs to the loop, and are returned + /// by the loop body when the loop is continued. + /// + /// **Type**: `(list type)` + inputs: TermId, + /// Types of the values that are produced at the end of the loop body when the loop + /// should be ended. + /// + /// **Type**: `(list type)` + outputs: TermId, + /// Types of the values that are passed as inputs to the loop, to each iteration and + /// are then returned at the end of the loop. + /// + /// **Type**: `(list type)` + rest: TermId, + /// + /// + /// **Type**: `ext-set` + extensions: TermId, + }, + + /// Conditional operation. + /// + /// # Port types + /// + /// - **Inputs**: `[(adt inputs)]` + `context` + /// - **Outputs**: `outputs` + Conditional { + /// Port types for each case of the conditional. + /// + /// **Type**: `(list (list type))` + cases: TermId, + /// Port types for additional inputs to the conditional. + /// + /// **Type**: `(list type)` + context: TermId, + /// Port types for the outputs of each case. + /// + /// **Type**: `(list type)` + outputs: TermId, + /// + /// + /// **Type**: `ext-set` + extensions: TermId, + }, +} + +/// A region in the hugr. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Region<'a> { + /// The kind of the region. See [`RegionKind`] for details. + pub kind: RegionKind, + /// The source ports of the region. + pub sources: &'a [Port<'a>], + /// The target ports of the region. + pub targets: &'a [Port<'a>], + /// The nodes in the region. The order of the nodes is not significant. + pub children: &'a [NodeId], + /// The metadata attached to the region. + pub meta: &'a [MetaItem<'a>], +} + +/// The kind of a region. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum RegionKind { + /// Data flow region. + DataFlow, + /// Control flow region. + ControlFlow, +} + +/// A port attached to a [`Node`] or [`Region`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct Port<'a> { + /// The link that the port is connected to. + pub link: LinkRef<'a>, + /// The type of the port. + pub r#type: Option, + /// Metadata attached to the port. + pub meta: &'a [MetaItem<'a>], +} + +/// A function declaration. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct FuncDecl<'a> { + /// The name of the function to be declared. + pub name: &'a str, + /// The static parameters of the function. + pub params: &'a [Param<'a>], + /// The type of the function. + pub func: TermId, +} + +/// An alias declaration. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct AliasDecl<'a> { + /// The name of the alias to be declared. + pub name: &'a str, + /// The static parameters of the alias. + pub params: &'a [Param<'a>], + /// The type of the alias. + pub r#type: TermId, +} + +/// A metadata item. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct MetaItem<'a> { + /// Name of the metadata item. + pub name: &'a str, + /// Value of the metadata item. + pub value: TermId, +} + +/// A reference to a global variable. +/// +/// Global variables are defined in nodes. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum GlobalRef<'a> { + /// Reference to the global that is defined by the given node. + Direct(NodeId), + /// Reference to the global with the given name. + Named(&'a str), +} + +impl std::fmt::Display for GlobalRef<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + GlobalRef::Direct(id) => write!(f, ":{}", id.index()), + GlobalRef::Named(name) => write!(f, "{}", name), + } + } +} + +/// A reference to a local variable. +/// +/// Local variables are defined as parameters to nodes. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum LocalRef<'a> { + /// Reference to the local variable by its parameter index. + Index(u16), + /// Reference to the local variable by its name. + Named(&'a str), +} + +impl std::fmt::Display for LocalRef<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + LocalRef::Index(index) => write!(f, "?:{}", index), + LocalRef::Named(name) => write!(f, "?{}", name), + } + } +} + +/// A reference to a link. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum LinkRef<'a> { + /// Reference to the link by its id. + Id(LinkId), + /// Reference to the link by its name. + Named(&'a str), +} + +impl std::fmt::Display for LinkRef<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + LinkRef::Id(id) => write!(f, "%:{})", id.index()), + LinkRef::Named(name) => write!(f, "%{}", name), + } + } +} + +/// A term in the compile time meta language. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum Term<'a> { + /// Standin for any term. + Wildcard, + + /// The type of runtime types. + /// + /// `type : static` + Type, + + /// The type of static types. + /// + /// `static : static` + StaticType, + + /// The type of constraints. + /// + /// `constraint : static` + Constraint, + + /// A local variable. + Var(LocalRef<'a>), + + /// A symbolic function application. + /// + /// `(GLOBAL ARG-0 ... ARG-n)` + Apply { + // TODO: Should the name be replaced with the id of the node that defines + // the function to be applied? This could be a type, alias or function. + /// The name of the term. + name: GlobalRef<'a>, + /// Arguments to the function, covering only the explicit parameters. + args: &'a [TermId], + }, + + /// A symbolic function application with all arguments applied. + /// + /// `(@GLOBAL ARG-0 ... ARG-n)` + ApplyFull { + /// The name of the function to apply. + name: GlobalRef<'a>, + /// Arguments to the function, covering both implicit and explicit parameters. + args: &'a [TermId], + }, + + /// Quote a runtime type as a static type. + /// + /// `(quote T) : static` where `T : type`. + Quote { + /// The runtime type to be quoted. + /// + /// **Type:** `type` + r#type: TermId, + }, + + /// A list, with an optional tail. + /// + /// - `[ITEM-0 ... ITEM-n] : (list T)` where `T : static`, `ITEM-i : T`. + /// - `[ITEM-0 ... ITEM-n . TAIL] : (list item-type)` where `T : static`, `ITEM-i : T`, `TAIL : (list T)`. + List { + /// The items in the list. + /// + /// `item-i : item-type` + items: &'a [TermId], + /// The tail of the list. + /// + /// `tail : (list item-type)` + tail: Option, + }, + + /// The type of lists, given a type for the items. + /// + /// `(list T) : static` where `T : static`. + ListType { + /// The type of the items in the list. + /// + /// `item_type : static` + item_type: TermId, + }, + + /// A literal string. + /// + /// `"STRING" : str` + Str(SmolStr), + + /// The type of literal strings. + /// + /// `str : static` + StrType, + + /// A literal natural number. + /// + /// `N : nat` + Nat(u64), + + /// The type of literal natural numbers. + /// + /// `nat : static` + NatType, + + /// Extension set. + /// + /// - `(ext EXT-0 ... EXT-n) : ext-set` + /// - `(ext EXT-0 ... EXT-n . REST) : ext-set` where `REST : ext-set`. + ExtSet { + /// The items in the extension set. + extensions: &'a [&'a str], + /// The rest of the extension set. + rest: Option, + }, + + /// The type of extension sets. + /// + /// `ext-set : static` + ExtSetType, + + /// An algebraic data type. + /// + /// `(adt VARIANTS) : type` where `VARIANTS : (list (list type))`. + Adt { + /// List of variants in the algrebaic data type. + /// Each of the variants is itself a list of runtime types. + variants: TermId, + }, + + /// The type of functions, given lists of input and output types and an extension set. + FuncType { + /// The input types of the function, given as a list of runtime types. + /// + /// `inputs : (list type)` + inputs: TermId, + /// The output types of the function, given as a list of runtime types. + /// + /// `outputs : (list type)` + outputs: TermId, + /// The set of extensions that the function requires to be present in + /// order to be called. + /// + /// `extensions : ext-set` + extensions: TermId, + }, + + /// Control flow. + /// + /// `(ctrl VALUES) : ctrl` where `VALUES : (list type)`. + Control { + /// List of values. + values: TermId, + }, + + /// Type of control flow edges. + /// + /// `ctrl : static` + ControlType, +} + +impl<'a> Default for Term<'a> { + fn default() -> Self { + Self::Wildcard + } +} + +/// A parameter to a function or alias. +/// +/// Parameter names must be unique within a parameter list. +/// Implicit and explicit parameters share a namespace. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Param<'a> { + /// An implicit parameter that should be inferred. + Implicit { + /// The name of the parameter. + name: &'a str, + /// The type of the parameter. + /// + /// This must be a term of type `static`. + r#type: TermId, + }, + /// An explicit parameter that should always be provided. + Explicit { + /// The name of the parameter. + name: &'a str, + /// The type of the parameter. + /// + /// This must be a term of type `static`. + r#type: TermId, + }, + /// A constraint that should be satisfied by other parameters in a parameter list. + Constraint { + /// The constraint to be satisfied. + /// + /// This must be a term of type `constraint`. + constraint: TermId, + }, +} + +/// Errors that can occur when traversing and interpreting the model. +#[derive(Debug, Clone, Error)] +pub enum ModelError { + /// There is a reference to a node that does not exist. + #[error("node not found: {0:?}")] + NodeNotFound(NodeId), + /// There is a reference to a term that does not exist. + #[error("term not found: {0:?}")] + TermNotFound(TermId), + /// There is a reference to a region that does not exist. + #[error("region not found: {0:?}")] + RegionNotFound(RegionId), + /// There is a local reference that does not resolve. + #[error("local variable invalid: {0:?}")] + InvalidLocal(String), + /// There is a global reference that does not resolve to a node + /// that defines a global variable. + #[error("global variable invalid: {0:?}")] + InvalidGlobal(String), + /// The model contains an operation in a place where it is not allowed. + #[error("unexpected operation on node: {0:?}")] + UnexpectedOperation(NodeId), + /// There is a term that is not well-typed. + #[error("type error in term: {0:?}")] + TypeError(TermId), + /// There is a node whose regions are not well-formed according to the node's operation. + #[error("node has invalid regions: {0:?}")] + InvalidRegions(NodeId), + /// There is a name that is not well-formed. + #[error("malformed name: {0}")] + MalformedName(SmolStr), + /// There is a condition node that lacks a case for a tag or + /// defines two cases for the same tag. + #[error("condition node is malformed: {0:?}")] + MalformedCondition(NodeId), +} diff --git a/hugr-model/src/v0/text/hugr.pest b/hugr-model/src/v0/text/hugr.pest new file mode 100644 index 000000000..212a0d053 --- /dev/null +++ b/hugr-model/src/v0/text/hugr.pest @@ -0,0 +1,99 @@ +WHITESPACE = _{ " " | "\t" | "\r" | "\n" } +COMMENT = _{ ";" ~ (!("\n") ~ ANY)* ~ "\n" } +identifier = @{ (ASCII_ALPHA | "_" | "-") ~ (ASCII_ALPHANUMERIC | "_" | "-")* } +symbol = @{ identifier ~ ("." ~ identifier)+ } + +string = @{ "\"" ~ (!("\"") ~ ANY)* ~ "\"" } +list_tail = { "." } + +module = { "(" ~ "hugr" ~ "0" ~ ")" ~ meta* ~ node* ~ EOI } + +meta = { "(" ~ "meta" ~ symbol ~ term ~ ")" } + +edge_name = @{ "%" ~ (ASCII_ALPHANUMERIC | "_" | "-")* } +port = { edge_name | ("(" ~ edge_name ~ term ~ meta* ~ ")") } +port_list = { "[" ~ port* ~ "]" } +port_lists = _{ port_list ~ port_list } + +node = { + node_dfg + | node_cfg + | node_block + | node_define_func + | node_declare_func + | node_call_func + | node_define_alias + | node_declare_alias + | node_tail_loop + | node_cond + | node_custom +} + +node_dfg = { "(" ~ "dfg" ~ port_lists? ~ meta* ~ region* ~ ")" } +node_cfg = { "(" ~ "cfg" ~ port_lists? ~ meta* ~ region* ~ ")" } +node_block = { "(" ~ "block" ~ port_lists? ~ meta* ~ region* ~ ")" } +node_define_func = { "(" ~ "define-func" ~ func_header ~ meta* ~ region* ~ ")" } +node_declare_func = { "(" ~ "declare-func" ~ func_header ~ meta* ~ ")" } +node_call_func = { "(" ~ "call" ~ term ~ port_lists? ~ meta* ~ ")" } +node_define_alias = { "(" ~ "define-alias" ~ alias_header ~ term ~ meta* ~ ")" } +node_declare_alias = { "(" ~ "declare-alias" ~ alias_header ~ meta* ~ ")" } +node_tail_loop = { "(" ~ "tail-loop" ~ port_lists? ~ meta* ~ region* ~ ")" } +node_cond = { "(" ~ "cond" ~ port_lists? ~ meta* ~ region* ~ ")" } +node_custom = { "(" ~ (term_apply | term_apply_full) ~ port_lists? ~ meta* ~ region* ~ ")" } + +func_header = { symbol ~ param* ~ term ~ term ~ term } +alias_header = { symbol ~ param* ~ term } + +param = { param_implicit | param_explicit | param_constraint } + +param_implicit = { "(" ~ "forall" ~ term_var ~ term ~ ")" } +param_explicit = { "(" ~ "param" ~ term_var ~ term ~ ")" } +param_constraint = { "(" ~ "where" ~ term ~ ")" } + +region = { region_dfg | region_cfg } +region_dfg = { "(" ~ "dfg" ~ port_lists? ~ meta* ~ node* ~ ")" } +region_cfg = { "(" ~ "cfg" ~ port_lists? ~ meta* ~ node* ~ ")" } + +term = { + term_wildcard + | term_type + | term_static + | term_constraint + | term_var + | term_quote + | term_list + | term_list_type + | term_str + | term_str_type + | term_nat + | term_nat_type + | term_ext_set + | term_ext_set_type + | term_adt + | term_func_type + | term_ctrl + | term_ctrl_type + | term_apply_full + | term_apply +} + +term_wildcard = { "_" } +term_type = { "type" } +term_static = { "static" } +term_constraint = { "constraint" } +term_var = { "?" ~ identifier } +term_apply_full = { ("(" ~ "@" ~ symbol ~ term* ~ ")") } +term_apply = { symbol | ("(" ~ symbol ~ term* ~ ")") } +term_quote = { "(" ~ "quote" ~ term ~ ")" } +term_list = { "[" ~ term* ~ (list_tail ~ term)? ~ "]" } +term_list_type = { "(" ~ "list" ~ term ~ ")" } +term_str = { string } +term_str_type = { "str" } +term_nat = { (ASCII_DIGIT)+ } +term_nat_type = { "nat" } +term_ext_set = { "(" ~ "ext" ~ identifier* ~ (list_tail ~ term)? ~ ")" } +term_ext_set_type = { "ext-set" } +term_adt = { "(" ~ "adt" ~ term ~ ")" } +term_func_type = { "(" ~ "fn" ~ term ~ term ~ term ~ ")" } +term_ctrl = { "(" ~ "ctrl" ~ term ~ ")" } +term_ctrl_type = { "ctrl" } diff --git a/hugr-model/src/v0/text/mod.rs b/hugr-model/src/v0/text/mod.rs new file mode 100644 index 000000000..a070bf7e5 --- /dev/null +++ b/hugr-model/src/v0/text/mod.rs @@ -0,0 +1,6 @@ +//! The HUGR text representation. +mod parse; +mod print; + +pub use parse::{parse, ParseError, ParsedModule}; +pub use print::print_to_string; diff --git a/hugr-model/src/v0/text/parse.rs b/hugr-model/src/v0/text/parse.rs new file mode 100644 index 000000000..98bfe1689 --- /dev/null +++ b/hugr-model/src/v0/text/parse.rs @@ -0,0 +1,682 @@ +use bumpalo::Bump; +use indexmap::IndexSet; +use pest::{ + iterators::{Pair, Pairs}, + Parser, RuleType, +}; +use smol_str::{SmolStr, ToSmolStr}; +use thiserror::Error; + +use crate::v0::{ + AliasDecl, FuncDecl, GlobalRef, LinkRef, LocalRef, MetaItem, Module, Node, NodeId, Operation, + Param, Port, Region, RegionId, RegionKind, Term, TermId, +}; + +mod pest_parser { + use pest_derive::Parser; + + // NOTE: The pest derive macro generates a `Rule` enum. We do not want this to be + // part of the public API, and so we hide it within this private module. + + #[derive(Parser)] + #[grammar = "v0/text/hugr.pest"] + pub struct HugrParser; +} + +use pest_parser::{HugrParser, Rule}; + +/// A parsed HUGR module. +/// +/// This consists of the module itself, together with additional information that was +/// extracted from the text format. +#[derive(Debug, Clone)] +pub struct ParsedModule<'a> { + /// The parsed module. + pub module: Module<'a>, + /// The names of the edges. + pub edges: Vec, + // TODO: Spans +} + +/// Parses a HUGR module from its text representation. +pub fn parse<'a>(input: &'a str, bump: &'a Bump) -> Result, ParseError> { + let mut context = ParseContext::new(bump); + let mut pairs = HugrParser::parse(Rule::module, input).map_err(ParseError)?; + context.parse_module(pairs.next().unwrap())?; + + Ok(ParsedModule { + module: context.module, + edges: context.edge_names.into_iter().collect(), + }) +} + +struct ParseContext<'a> { + module: Module<'a>, + edge_names: IndexSet, + bump: &'a Bump, +} + +impl<'a> ParseContext<'a> { + fn new(bump: &'a Bump) -> Self { + Self { + module: Module::default(), + edge_names: IndexSet::default(), + bump, + } + } + + fn parse_module(&mut self, pair: Pair<'a, Rule>) -> ParseResult<()> { + debug_assert!(matches!(pair.as_rule(), Rule::module)); + let mut inner = pair.into_inner(); + let meta = self.parse_meta(&mut inner)?; + + let children = self.parse_nodes(&mut inner)?; + + let root_region = self.module.insert_region(Region { + kind: RegionKind::DataFlow, + sources: &[], + targets: &[], + children: self.bump.alloc_slice_copy(&children), + meta, + }); + + self.module.root = root_region; + + // TODO: Root region metadata + // self.module + // .node_meta + // .extend(meta.into_iter().map(|meta| (root, meta))); + + Ok(()) + } + + fn parse_term(&mut self, pair: Pair<'a, Rule>) -> ParseResult { + debug_assert!(matches!(pair.as_rule(), Rule::term)); + let pair = pair.into_inner().next().unwrap(); + let rule = pair.as_rule(); + let mut inner = pair.into_inner(); + + let term = match rule { + Rule::term_wildcard => Term::Wildcard, + Rule::term_type => Term::Type, + Rule::term_static => Term::StaticType, + Rule::term_constraint => Term::Constraint, + Rule::term_str_type => Term::StrType, + Rule::term_nat_type => Term::NatType, + Rule::term_ctrl_type => Term::ControlType, + Rule::term_ext_set_type => Term::ExtSetType, + + Rule::term_var => { + let name_token = inner.next().unwrap(); + let name = name_token.as_str(); + Term::Var(LocalRef::Named(name)) + } + + Rule::term_apply => { + let name = GlobalRef::Named(self.parse_symbol(&mut inner)?); + let mut args = Vec::new(); + + for token in inner { + args.push(self.parse_term(token)?); + } + + Term::Apply { + name, + args: self.bump.alloc_slice_copy(&args), + } + } + + Rule::term_apply_full => { + let name = GlobalRef::Named(self.parse_symbol(&mut inner)?); + let mut args = Vec::new(); + + for token in inner { + args.push(self.parse_term(token)?); + } + + Term::ApplyFull { + name, + args: self.bump.alloc_slice_copy(&args), + } + } + + Rule::term_quote => { + let r#type = self.parse_term(inner.next().unwrap())?; + Term::Quote { r#type } + } + + Rule::term_list => { + let mut items = Vec::new(); + let mut tail = None; + + for token in filter_rule(&mut inner, Rule::term) { + items.push(self.parse_term(token)?); + } + + if inner.next().is_some() { + let token = inner.next().unwrap(); + tail = Some(self.parse_term(token)?); + } + + Term::List { + items: self.bump.alloc_slice_copy(&items), + tail, + } + } + + Rule::term_list_type => { + let item_type = self.parse_term(inner.next().unwrap())?; + Term::ListType { item_type } + } + + Rule::term_str => { + // TODO: Escaping? + let value = inner.next().unwrap().as_str().to_smolstr(); + Term::Str(value) + } + + Rule::term_nat => { + let value = inner.next().unwrap().as_str().parse().unwrap(); + Term::Nat(value) + } + + Rule::term_ext_set => { + let mut extensions = Vec::new(); + let mut rest = None; + + for token in filter_rule(&mut inner, Rule::identifier) { + extensions.push(token.as_str()); + } + + if inner.next().is_some() { + let token = inner.next().unwrap(); + rest = Some(self.parse_term(token)?); + } + + Term::ExtSet { + extensions: self.bump.alloc_slice_copy(&extensions), + rest, + } + } + + Rule::term_adt => { + let variants = self.parse_term(inner.next().unwrap())?; + Term::Adt { variants } + } + + Rule::term_func_type => { + let inputs = self.parse_term(inner.next().unwrap())?; + let outputs = self.parse_term(inner.next().unwrap())?; + let extensions = self.parse_term(inner.next().unwrap())?; + Term::FuncType { + inputs, + outputs, + extensions, + } + } + + Rule::term_ctrl => { + let values = self.parse_term(inner.next().unwrap())?; + Term::Control { values } + } + + r => unreachable!("term: {:?}", r), + }; + + Ok(self.module.insert_term(term)) + } + + fn parse_node(&mut self, pair: Pair<'a, Rule>) -> ParseResult { + debug_assert!(matches!(pair.as_rule(), Rule::node)); + let pair = pair.into_inner().next().unwrap(); + let rule = pair.as_rule(); + + let mut inner = pair.into_inner(); + + let node = match rule { + Rule::node_dfg => { + let inputs = self.parse_port_list(&mut inner)?; + let outputs = self.parse_port_list(&mut inner)?; + let meta = self.parse_meta(&mut inner)?; + let regions = self.parse_regions(&mut inner)?; + Node { + operation: Operation::Dfg, + inputs, + outputs, + params: &[], + regions, + meta, + } + } + + Rule::node_cfg => { + let inputs = self.parse_port_list(&mut inner)?; + let outputs = self.parse_port_list(&mut inner)?; + let meta = self.parse_meta(&mut inner)?; + let regions = self.parse_regions(&mut inner)?; + Node { + operation: Operation::Cfg, + inputs, + outputs, + params: &[], + regions, + meta, + } + } + + Rule::node_block => { + let inputs = self.parse_port_list(&mut inner)?; + let outputs = self.parse_port_list(&mut inner)?; + let meta = self.parse_meta(&mut inner)?; + let regions = self.parse_regions(&mut inner)?; + Node { + operation: Operation::Block, + inputs, + outputs, + params: &[], + regions, + meta, + } + } + + Rule::node_define_func => { + let decl = self.parse_func_header(inner.next().unwrap())?; + let meta = self.parse_meta(&mut inner)?; + let regions = self.parse_regions(&mut inner)?; + Node { + operation: Operation::DefineFunc { decl }, + inputs: &[], + outputs: &[], + params: &[], + regions, + meta, + } + } + + Rule::node_declare_func => { + let decl = self.parse_func_header(inner.next().unwrap())?; + let meta = self.parse_meta(&mut inner)?; + Node { + operation: Operation::DeclareFunc { decl }, + inputs: &[], + outputs: &[], + params: &[], + regions: &[], + meta, + } + } + + Rule::node_call_func => { + let func = self.parse_term(inner.next().unwrap())?; + let inputs = self.parse_port_list(&mut inner)?; + let outputs = self.parse_port_list(&mut inner)?; + let meta = self.parse_meta(&mut inner)?; + Node { + operation: Operation::CallFunc { func }, + inputs, + outputs, + params: &[], + regions: &[], + meta, + } + } + + Rule::node_define_alias => { + let decl = self.parse_alias_header(inner.next().unwrap())?; + let value = self.parse_term(inner.next().unwrap())?; + let meta = self.parse_meta(&mut inner)?; + Node { + operation: Operation::DefineAlias { decl, value }, + inputs: &[], + outputs: &[], + params: &[], + regions: &[], + meta, + } + } + + Rule::node_declare_alias => { + let decl = self.parse_alias_header(inner.next().unwrap())?; + let meta = self.parse_meta(&mut inner)?; + Node { + operation: Operation::DeclareAlias { decl }, + inputs: &[], + outputs: &[], + params: &[], + regions: &[], + meta, + } + } + + Rule::node_custom => { + let op = inner.next().unwrap(); + debug_assert!(matches!( + op.as_rule(), + Rule::term_apply | Rule::term_apply_full + )); + let op_rule = op.as_rule(); + let mut op_inner = op.into_inner(); + + let name = GlobalRef::Named(self.parse_symbol(&mut op_inner)?); + + let mut params = Vec::new(); + + for token in filter_rule(&mut inner, Rule::term) { + params.push(self.parse_term(token)?); + } + + let operation = match op_rule { + Rule::term_apply_full => Operation::CustomFull { name }, + Rule::term_apply => Operation::Custom { name }, + _ => unreachable!(), + }; + + let inputs = self.parse_port_list(&mut inner)?; + let outputs = self.parse_port_list(&mut inner)?; + let meta = self.parse_meta(&mut inner)?; + let regions = self.parse_regions(&mut inner)?; + Node { + operation, + inputs, + outputs, + params: self.bump.alloc_slice_copy(¶ms), + regions, + meta, + } + } + + Rule::node_tail_loop => { + let inputs = self.module.insert_term(Term::Wildcard); + let outputs = self.module.insert_term(Term::Wildcard); + let rest = self.module.insert_term(Term::Wildcard); + let extensions = self.module.insert_term(Term::Wildcard); + let operation = Operation::TailLoop { + inputs, + outputs, + rest, + extensions, + }; + let inputs = self.parse_port_list(&mut inner)?; + let outputs = self.parse_port_list(&mut inner)?; + let meta = self.parse_meta(&mut inner)?; + let regions = self.parse_regions(&mut inner)?; + Node { + operation, + inputs, + outputs, + params: &[], + regions, + meta, + } + } + + Rule::node_cond => { + let cases = self.module.insert_term(Term::Wildcard); + let context = self.module.insert_term(Term::Wildcard); + let outputs = self.module.insert_term(Term::Wildcard); + let extensions = self.module.insert_term(Term::Wildcard); + let operation = Operation::Conditional { + cases, + context, + outputs, + extensions, + }; + let inputs = self.parse_port_list(&mut inner)?; + let outputs = self.parse_port_list(&mut inner)?; + let meta = self.parse_meta(&mut inner)?; + let regions = self.parse_regions(&mut inner)?; + Node { + operation, + inputs, + outputs, + params: &[], + regions, + meta, + } + } + _ => unreachable!(), + }; + + let node_id = self.module.insert_node(node); + + Ok(node_id) + } + + fn parse_regions(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<&'a [RegionId]> { + let mut regions = Vec::new(); + for pair in filter_rule(pairs, Rule::region) { + regions.push(self.parse_region(pair)?); + } + Ok(self.bump.alloc_slice_copy(®ions)) + } + + fn parse_region(&mut self, pair: Pair<'a, Rule>) -> ParseResult { + debug_assert!(matches!(pair.as_rule(), Rule::region)); + let pair = pair.into_inner().next().unwrap(); + let rule = pair.as_rule(); + let mut inner = pair.into_inner(); + + let kind = match rule { + Rule::region_cfg => RegionKind::ControlFlow, + Rule::region_dfg => RegionKind::DataFlow, + _ => unreachable!(), + }; + + let sources = self.parse_port_list(&mut inner)?; + let targets = self.parse_port_list(&mut inner)?; + let meta = self.parse_meta(&mut inner)?; + let children = self.parse_nodes(&mut inner)?; + + Ok(self.module.insert_region(Region { + kind, + sources, + targets, + children, + meta, + })) + + // TODO: Attach region meta + } + + fn parse_nodes(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<&'a [NodeId]> { + let mut nodes = Vec::new(); + + for pair in filter_rule(pairs, Rule::node) { + nodes.push(self.parse_node(pair)?); + } + + Ok(self.bump.alloc_slice_copy(&nodes)) + } + + fn parse_func_header(&mut self, pair: Pair<'a, Rule>) -> ParseResult<&'a FuncDecl<'a>> { + debug_assert!(matches!(pair.as_rule(), Rule::func_header)); + + let mut inner = pair.into_inner(); + let name = self.parse_symbol(&mut inner)?; + let params = self.parse_params(&mut inner)?; + + let inputs = self.parse_term(inner.next().unwrap())?; + let outputs = self.parse_term(inner.next().unwrap())?; + + // TODO: This is subtly broken: + let extensions = match inner.peek().map(|p| p.as_rule()) { + Some(Rule::term_ext_set) => self.parse_term(inner.next().unwrap())?, + _ => self.module.insert_term(Term::ExtSet { + extensions: &[], + rest: None, + }), + }; + + // Assemble the inputs, outputs and extensions into a function type. + let func = self.module.insert_term(Term::FuncType { + inputs, + outputs, + extensions, + }); + + Ok(self.bump.alloc(FuncDecl { name, params, func })) + } + + fn parse_alias_header(&mut self, pair: Pair<'a, Rule>) -> ParseResult<&'a AliasDecl<'a>> { + debug_assert!(matches!(pair.as_rule(), Rule::alias_header)); + + let mut inner = pair.into_inner(); + let name = self.parse_symbol(&mut inner)?; + let params = self.parse_params(&mut inner)?; + let r#type = self.parse_term(inner.next().unwrap())?; + + Ok(self.bump.alloc(AliasDecl { + name, + params, + r#type, + })) + } + + fn parse_params(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<&'a [Param<'a>]> { + let mut params = Vec::new(); + + for pair in filter_rule(pairs, Rule::param) { + let param = pair.into_inner().next().unwrap(); + + let param = match param.as_rule() { + Rule::param_implicit => { + let mut inner = param.into_inner(); + let name = &inner.next().unwrap().as_str()[1..]; + let r#type = self.parse_term(inner.next().unwrap())?; + Param::Implicit { name, r#type } + } + Rule::param_explicit => { + let mut inner = param.into_inner(); + let name = &inner.next().unwrap().as_str()[1..]; + let r#type = self.parse_term(inner.next().unwrap())?; + Param::Explicit { name, r#type } + } + Rule::param_constraint => { + let mut inner = param.into_inner(); + let constraint = self.parse_term(inner.next().unwrap())?; + Param::Constraint { constraint } + } + _ => unreachable!(), + }; + + params.push(param); + } + + Ok(self.bump.alloc_slice_copy(¶ms)) + } + + fn parse_port_list(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<&'a [Port<'a>]> { + let Some(Rule::port_list) = pairs.peek().map(|p| p.as_rule()) else { + return Ok(&[]); + }; + + let pair = pairs.next().unwrap(); + let mut inner = pair.into_inner(); + let mut ports = Vec::new(); + + while let Some(token) = inner.next() { + let port = self.parse_port(token)?; + ports.push(port); + } + + Ok(self.bump.alloc_slice_copy(&ports)) + } + + fn parse_port(&mut self, pair: Pair<'a, Rule>) -> ParseResult> { + debug_assert!(matches!(pair.as_rule(), Rule::port)); + + let mut inner = pair.into_inner(); + + let link = LinkRef::Named(inner.next().unwrap().as_str()); + + let mut r#type = None; + let mut meta = &[] as &[MetaItem<'a>]; + + if inner.peek().is_some() { + r#type = Some(self.parse_term(inner.next().unwrap())?); + meta = self.parse_meta(&mut inner)?; + } + + Ok(Port { link, r#type, meta }) + } + + fn parse_meta(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<&'a [MetaItem<'a>]> { + let mut items = Vec::new(); + + for meta in filter_rule(pairs, Rule::meta) { + let mut inner = meta.into_inner(); + let name = self.parse_symbol(&mut inner)?; + let value = self.parse_term(inner.next().unwrap())?; + items.push(MetaItem { name, value }) + } + + Ok(self.bump.alloc_slice_copy(&items)) + } + + fn parse_symbol(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<&'a str> { + let pair = pairs.next().unwrap(); + if let Rule::symbol = pair.as_rule() { + Ok(pair.as_str()) + } else { + unreachable!("expected a symbol"); + } + } +} + +/// Draw from a pest pair iterator only the pairs that match a given rule. +/// +/// This is similar to a `take_while`, except that it does not take the iterator +/// by value and so lets us continue using it after the filter. +#[inline] +fn filter_rule<'a, 'i, R: RuleType>( + pairs: &'a mut Pairs<'i, R>, + rule: R, +) -> impl Iterator> + 'a { + std::iter::from_fn(move || { + let peek = pairs.peek()?; + if peek.as_rule() == rule { + Some(pairs.next().unwrap()) + } else { + None + } + }) +} + +/// An error that occurred during parsing. +#[derive(Debug, Clone, Error)] +#[error("{0}")] +pub struct ParseError(pest::error::Error); + +impl ParseError { + /// Line of the error in the input string. + pub fn line(&self) -> usize { + use pest::error::LineColLocation; + match self.0.line_col { + LineColLocation::Pos((line, _)) => line, + LineColLocation::Span((line, _), _) => line, + } + } + + /// Column of the error in the input string. + pub fn column(&self) -> usize { + use pest::error::LineColLocation; + match self.0.line_col { + LineColLocation::Pos((_, col)) => col, + LineColLocation::Span((_, col), _) => col, + } + } + + /// Location of the error in the input string in bytes. + pub fn location(&self) -> usize { + use pest::error::InputLocation; + match self.0.location { + InputLocation::Pos(offset) => offset, + InputLocation::Span((offset, _)) => offset, + } + } +} + +// NOTE: `ParseError` does not implement `From>` so that +// pest does not become part of the public API. + +type ParseResult = Result; diff --git a/hugr-model/src/v0/text/print.rs b/hugr-model/src/v0/text/print.rs new file mode 100644 index 000000000..2721760b5 --- /dev/null +++ b/hugr-model/src/v0/text/print.rs @@ -0,0 +1,620 @@ +use pretty::{Arena, DocAllocator, RefDoc}; +use std::borrow::Cow; + +use crate::v0::{ + GlobalRef, LinkRef, LocalRef, MetaItem, ModelError, Module, NodeId, Operation, Param, Port, + RegionId, RegionKind, Term, TermId, +}; + +type PrintError = ModelError; +type PrintResult = Result; + +// TODO: Print tail-loop nodes +// TODO: Print conditional and case nodes + +/// Pretty-print a module to a string. +pub fn print_to_string(module: &Module, width: usize) -> PrintResult { + let arena = Arena::new(); + let doc = PrintContext::create_doc(&arena, module)?; + let mut out = String::new(); + let _ = doc.render_fmt(width, &mut out); + Ok(out) +} + +struct PrintContext<'p, 'a: 'p> { + /// The arena in which to allocate the pretty-printed documents. + arena: &'p Arena<'p>, + /// The module to be printed. + module: &'a Module<'a>, + /// Parts of the document to be concatenated. + docs: Vec>, + /// Stack of indices into `docs` denoting the current nesting. + docs_stack: Vec, + /// The names of local variables that are in scope. + locals: Vec<&'a str>, +} + +impl<'p, 'a: 'p> PrintContext<'p, 'a> { + fn create_doc(arena: &'p Arena<'p>, module: &'a Module) -> PrintResult> { + let mut this = Self { + arena, + module, + docs: Vec::new(), + docs_stack: Vec::new(), + locals: Vec::new(), + }; + + this.print_parens(|this| { + this.print_text("hugr"); + this.print_text("0"); + }); + + this.print_root()?; + Ok(this.finish()) + } + + fn finish(self) -> RefDoc<'p> { + let sep = self + .arena + .concat([self.arena.hardline(), self.arena.hardline()]); + self.arena.intersperse(self.docs, sep).into_doc() + } + + fn print_text(&mut self, text: impl Into>) { + self.docs.push(self.arena.text(text).into_doc()); + } + + /// Print a delimited sequence of elements. + /// + /// See [`print_group`], [`print_parens`], and [`print_brackets`]. + fn print_delimited( + &mut self, + start: &'static str, + end: &'static str, + nesting: isize, + f: impl FnOnce(&mut Self) -> T, + ) -> T { + self.docs_stack.push(self.docs.len()); + let result = f(self); + let docs = self.docs.drain(self.docs_stack.pop().unwrap()..); + let doc = self.arena.concat([ + self.arena.text(start), + self.arena + .intersperse(docs, self.arena.line()) + .nest(nesting) + .group(), + self.arena.text(end), + ]); + self.docs.push(doc.into_doc()); + result + } + + /// Print a sequence of elements that are preferrably laid out together in the same line. + #[inline] + fn print_group(&mut self, f: impl FnOnce(&mut Self) -> T) -> T { + self.print_delimited("", "", 0, f) + } + + /// Print a sequence of elements in a parenthesized list. + #[inline] + fn print_parens(&mut self, f: impl FnOnce(&mut Self) -> T) -> T { + self.print_delimited("(", ")", 2, f) + } + + /// Print a sequence of elements in a bracketed list. + #[inline] + fn print_brackets(&mut self, f: impl FnOnce(&mut Self) -> T) -> T { + self.print_delimited("[", "]", 1, f) + } + + fn print_root(&mut self) -> PrintResult<()> { + let root_id = self.module.root; + let root_data = self + .module + .get_region(root_id) + .ok_or_else(|| PrintError::RegionNotFound(root_id))?; + + self.print_meta(root_data.meta)?; + self.print_nodes(root_id)?; + Ok(()) + } + + fn with_local_scope( + &mut self, + params: &'a [Param<'a>], + f: impl FnOnce(&mut Self) -> PrintResult, + ) -> PrintResult { + let locals = std::mem::take(&mut self.locals); + + for param in params { + match param { + Param::Implicit { name, .. } => self.locals.push(name), + Param::Explicit { name, .. } => self.locals.push(name), + Param::Constraint { .. } => {} + } + } + + let result = f(self); + self.locals = locals; + result + } + + fn print_node(&mut self, node_id: NodeId) -> PrintResult<()> { + let node_data = self + .module + .get_node(node_id) + .ok_or_else(|| PrintError::NodeNotFound(node_id))?; + + self.print_parens(|this| match &node_data.operation { + Operation::Dfg => { + this.print_group(|this| { + this.print_text("dfg"); + this.print_port_list(&node_data.inputs)?; + this.print_port_list(&node_data.outputs) + })?; + this.print_meta(node_data.meta)?; + this.print_regions(&node_data.regions) + } + Operation::Cfg => { + this.print_group(|this| { + this.print_text("cfg"); + this.print_port_list(&node_data.inputs)?; + this.print_port_list(&node_data.outputs) + })?; + this.print_meta(node_data.meta)?; + this.print_regions(&node_data.regions) + } + Operation::Block => { + this.print_group(|this| { + this.print_text("block"); + this.print_port_list(&node_data.inputs)?; + this.print_port_list(&node_data.outputs) + })?; + this.print_meta(node_data.meta)?; + this.print_regions(&node_data.regions) + } + + Operation::DefineFunc { decl } => this.with_local_scope(decl.params, |this| { + this.print_group(|this| { + this.print_text("define-func"); + this.print_text(decl.name); + }); + + for param in decl.params { + this.print_param(*param)?; + } + + match self.module.get_term(decl.func) { + Some(Term::FuncType { + inputs, + outputs, + extensions, + }) => { + this.print_group(|this| { + this.print_term(*inputs)?; + this.print_term(*outputs)?; + this.print_term(*extensions) + })?; + } + Some(_) => return Err(PrintError::TypeError(decl.func)), + None => return Err(PrintError::TermNotFound(decl.func)), + } + + this.print_meta(node_data.meta)?; + this.print_regions(&node_data.regions) + }), + + Operation::DeclareFunc { decl } => this.with_local_scope(decl.params, |this| { + this.print_group(|this| { + this.print_text("declare-func"); + this.print_text(decl.name); + }); + + for param in decl.params { + this.print_param(*param)?; + } + + match self.module.get_term(decl.func) { + Some(Term::FuncType { + inputs, + outputs, + extensions, + }) => { + this.print_group(|this| { + this.print_term(*inputs)?; + this.print_term(*outputs)?; + this.print_term(*extensions) + })?; + } + Some(_) => return Err(PrintError::TypeError(decl.func)), + None => return Err(PrintError::TermNotFound(decl.func)), + } + + this.print_meta(node_data.meta)?; + Ok(()) + }), + + Operation::CallFunc { func } => { + this.print_group(|this| { + this.print_text("call"); + this.print_term(*func)?; + this.print_port_list(&node_data.inputs)?; + this.print_port_list(&node_data.outputs) + })?; + this.print_meta(node_data.meta)?; + Ok(()) + } + + Operation::LoadFunc { func } => { + this.print_group(|this| { + this.print_text("load-func"); + this.print_term(*func)?; + this.print_port_list(&node_data.inputs)?; + this.print_port_list(&node_data.outputs) + })?; + this.print_meta(node_data.meta)?; + Ok(()) + } + + Operation::Custom { name } => { + this.print_group(|this| { + if node_data.params.is_empty() { + this.print_global_ref(*name)?; + } else { + this.print_parens(|this| { + this.print_global_ref(*name)?; + + for param in node_data.params { + this.print_term(*param)?; + } + + Ok(()) + })?; + } + + this.print_port_list(&node_data.inputs)?; + this.print_port_list(&node_data.outputs) + })?; + this.print_meta(node_data.meta)?; + this.print_regions(&node_data.regions) + } + + Operation::CustomFull { name } => { + this.print_group(|this| { + this.print_parens(|this| { + this.print_text("@"); + this.print_global_ref(*name)?; + + for param in node_data.params { + this.print_term(*param)?; + } + + Ok(()) + })?; + + this.print_port_list(&node_data.inputs)?; + this.print_port_list(&node_data.outputs) + })?; + this.print_meta(node_data.meta)?; + this.print_regions(&node_data.regions) + } + + Operation::DefineAlias { decl, value } => this.with_local_scope(decl.params, |this| { + this.print_group(|this| { + this.print_text("define-alias"); + this.print_text(decl.name); + }); + + for param in decl.params { + this.print_param(*param)?; + } + + this.print_term(decl.r#type)?; + this.print_term(*value)?; + this.print_meta(node_data.meta)?; + Ok(()) + }), + Operation::DeclareAlias { decl } => this.with_local_scope(decl.params, |this| { + this.print_group(|this| { + this.print_text("declare-alias"); + this.print_text(decl.name); + }); + + for param in decl.params { + this.print_param(*param)?; + } + + this.print_term(decl.r#type)?; + this.print_meta(node_data.meta)?; + Ok(()) + }), + + Operation::TailLoop { + inputs, + outputs, + rest, + extensions, + } => todo!(), + Operation::Conditional { + cases, + context, + outputs, + extensions, + } => todo!(), + }) + } + + fn print_regions(&mut self, regions: &'a [RegionId]) -> PrintResult<()> { + for region in regions { + self.print_region(*region)?; + } + Ok(()) + } + + fn print_region(&mut self, region: RegionId) -> PrintResult<()> { + let region_data = self + .module + .get_region(region) + .ok_or(PrintError::RegionNotFound(region))?; + + self.print_parens(|this| { + match region_data.kind { + RegionKind::DataFlow => { + this.print_text("dfg"); + } + RegionKind::ControlFlow => { + this.print_text("cfg"); + } + }; + + if !region_data.sources.is_empty() || !region_data.targets.is_empty() { + this.print_port_list(®ion_data.sources)?; + this.print_port_list(®ion_data.targets)?; + } + + this.print_meta(region_data.meta)?; + this.print_nodes(region) + }) + } + + fn print_nodes(&mut self, region: RegionId) -> PrintResult<()> { + let region_data = self + .module + .get_region(region) + .ok_or(PrintError::RegionNotFound(region))?; + + for node_id in region_data.children { + self.print_node(*node_id)?; + } + + Ok(()) + } + + fn print_port_list(&mut self, ports: &'a [Port<'a>]) -> PrintResult<()> { + self.print_brackets(|this| { + for port in ports { + if port.r#type.is_some() || !port.meta.is_empty() { + this.print_parens(|this| { + this.print_link_ref(port.link); + + match port.r#type { + Some(r#type) => this.print_term(r#type)?, + None => this.print_text("_"), + }; + + this.print_meta(port.meta)?; + Ok(()) + })?; + } else { + this.print_link_ref(port.link); + } + } + + Ok(()) + }) + } + + fn print_link_ref(&mut self, link_ref: LinkRef<'a>) { + match link_ref { + LinkRef::Id(link_id) => self.print_text(format!("%{}", link_id.0)), + LinkRef::Named(name) => self.print_text(format!("%{}", name)), + } + } + + fn print_param(&mut self, param: Param<'a>) -> PrintResult<()> { + self.print_parens(|this| match param { + Param::Implicit { name, r#type } => { + this.print_text("forall"); + this.print_text(format!("?{}", name)); + this.print_term(r#type) + } + Param::Explicit { name, r#type } => { + this.print_text("param"); + this.print_text(format!("?{}", name)); + this.print_term(r#type) + } + Param::Constraint { constraint } => { + this.print_text("where"); + this.print_term(constraint) + } + }) + } + + fn print_term(&mut self, term_id: TermId) -> PrintResult<()> { + let term_data = self + .module + .get_term(term_id) + .ok_or_else(|| PrintError::TermNotFound(term_id))?; + + match term_data { + Term::Wildcard => { + self.print_text("_"); + Ok(()) + } + Term::Type => { + self.print_text("type"); + Ok(()) + } + Term::StaticType => { + self.print_text("static"); + Ok(()) + } + Term::Constraint => { + self.print_text("constraint"); + Ok(()) + } + Term::Var(local_ref) => self.print_local_ref(*local_ref), + Term::Apply { name, args } => { + if args.is_empty() { + self.print_global_ref(*name)?; + } else { + self.print_parens(|this| { + this.print_global_ref(*name)?; + for arg in args.iter() { + this.print_term(*arg)?; + } + Ok(()) + })?; + } + + Ok(()) + } + Term::ApplyFull { name, args } => self.print_parens(|this| { + this.print_text("@"); + this.print_global_ref(*name)?; + for arg in args.iter() { + this.print_term(*arg)?; + } + + Ok(()) + }), + Term::Quote { r#type } => self.print_parens(|this| { + this.print_text("quote"); + this.print_term(*r#type) + }), + Term::List { items, tail } => self.print_brackets(|this| { + for item in items.iter() { + this.print_term(*item)?; + } + if let Some(tail) = tail { + this.print_text("."); + this.print_term(*tail)?; + } + Ok(()) + }), + Term::ListType { item_type } => self.print_parens(|this| { + this.print_text("list"); + this.print_term(*item_type) + }), + Term::Str(str) => { + // TODO: escape + self.print_text("\""); + self.print_text(str.as_ref()); + self.print_text("\""); + Ok(()) + } + Term::StrType => { + self.print_text("str"); + Ok(()) + } + Term::Nat(n) => { + self.print_text(n.to_string()); + Ok(()) + } + Term::NatType => { + self.print_text("nat"); + Ok(()) + } + Term::ExtSet { extensions, rest } => self.print_parens(|this| { + this.print_text("ext"); + for extension in *extensions { + this.print_text(*extension); + } + if let Some(rest) = rest { + this.print_text("."); + this.print_term(*rest)?; + } + Ok(()) + }), + Term::ExtSetType => { + self.print_text("ext-set"); + Ok(()) + } + Term::Adt { variants } => self.print_parens(|this| { + this.print_text("adt"); + this.print_term(*variants) + }), + Term::FuncType { + inputs, + outputs, + extensions, + } => self.print_parens(|this| { + this.print_text("fn"); + this.print_term(*inputs)?; + this.print_term(*outputs)?; + this.print_term(*extensions) + }), + Term::Control { values } => self.print_parens(|this| { + this.print_text("ctrl"); + this.print_term(*values) + }), + Term::ControlType => { + self.print_text("ctrl"); + Ok(()) + } + } + } + + fn print_local_ref(&mut self, local_ref: LocalRef<'a>) -> PrintResult<()> { + let name = match local_ref { + LocalRef::Index(i) => { + let Some(name) = self.locals.get(i as usize) else { + return Err(PrintError::InvalidLocal(local_ref.to_string())); + }; + + name + } + LocalRef::Named(name) => name, + }; + + self.print_text(format!("?{}", name)); + Ok(()) + } + + fn print_global_ref(&mut self, global_ref: GlobalRef<'a>) -> PrintResult<()> { + match global_ref { + GlobalRef::Direct(node_id) => { + let node_data = self + .module + .get_node(node_id) + .ok_or_else(|| PrintError::NodeNotFound(node_id))?; + + let name = match &node_data.operation { + Operation::DefineFunc { decl } => decl.name, + Operation::DeclareFunc { decl } => decl.name, + Operation::DefineAlias { decl, .. } => decl.name, + Operation::DeclareAlias { decl } => decl.name, + _ => return Err(PrintError::UnexpectedOperation(node_id)), + }; + + self.print_text(name) + } + + GlobalRef::Named(symbol) => self.print_text(symbol), + } + + Ok(()) + } + + fn print_meta(&mut self, meta: &'a [MetaItem<'a>]) -> PrintResult<()> { + for item in meta { + self.print_parens(|this| { + this.print_text("meta"); + this.print_text(item.name); + this.print_term(item.value) + })?; + } + + Ok(()) + } +} From 8f3773f39c3cd1bc3e851cd50c13b5f2dcfb17a1 Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Wed, 2 Oct 2024 12:22:14 +0100 Subject: [PATCH 02/26] Import CFGs. --- hugr-core/src/import.rs | 177 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 169 insertions(+), 8 deletions(-) diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index a09803b2d..3099b6bbe 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -8,13 +8,14 @@ use crate::{ extension::{ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError}, hugr::{HugrMut, IdentList}, ops::{ - AliasDecl, AliasDefn, Call, CallIndirect, Case, Conditional, FuncDecl, FuncDefn, Input, - LoadFunction, Module, OpType, OpaqueOp, Output, TailLoop, DFG, + AliasDecl, AliasDefn, Call, CallIndirect, Case, Conditional, DataflowBlock, ExitBlock, + FuncDecl, FuncDefn, Input, LoadFunction, Module, OpType, OpaqueOp, Output, Tag, TailLoop, + CFG, DFG, }, types::{ type_param::TypeParam, type_row::TypeRowBase, CustomType, FuncTypeBase, MaybeRV, NoRV, - PolyFuncType, PolyFuncTypeBase, RowVariable, Signature, TypeArg, TypeBase, TypeBound, - TypeRow, + PolyFuncType, PolyFuncTypeBase, RowVariable, Signature, Type, TypeArg, TypeBase, TypeBound, + TypeEnum, TypeRow, TypeRowRV, }, Direction, Hugr, HugrView, Node, Port, }; @@ -384,15 +385,24 @@ impl<'a> Context<'a> { }; self.import_dfg_region(node_id, *region, node)?; - Ok(node) } - // TODO: Implement support for importing control flow graphs. - model::Operation::Cfg => Err(error_unsupported!("`cfg` nodes")), + model::Operation::Cfg => { + let signature = self.get_node_signature(node_id)?; + let optype = OpType::CFG(CFG { signature }); + let node = self.make_node(node_id, optype, parent)?; - model::Operation::Block => Err(model::ModelError::UnexpectedOperation(node_id).into()), + let [region] = node_data.regions else { + return Err(model::ModelError::InvalidRegions(node_id).into()); + }; + self.import_cfg_region(node_id, *region, node)?; + Ok(node) + } + + model::Operation::Block => self.import_cfg_block(node_id, parent), + // Err(model::ModelError::UnexpectedOperation(node_id).into())}, model::Operation::DefineFunc { decl } => { self.import_poly_func_type(*decl, |ctx, signature| { let optype = OpType::FuncDefn(FuncDefn { @@ -637,6 +647,157 @@ impl<'a> Context<'a> { Ok(()) } + /// Create the entry block for a control flow region. + /// + /// Since the core hugr does not have explicit entry blocks yet, we create a dataflow block + /// that simply forwards its inputs to its outputs. + fn make_entry_node( + &mut self, + parent: Node, + ports: &'a [model::Port<'a>], + ) -> Result { + let types = self.get_port_types(ports)?; + + let node = self.hugr.add_node_with_parent( + parent, + OpType::DataflowBlock(DataflowBlock { + inputs: types.clone(), + other_outputs: TypeRow::default(), + sum_rows: vec![types.clone()], + extension_delta: ExtensionSet::default(), + }), + ); + + let node_input = self.hugr.add_node_with_parent( + node, + OpType::Input(Input { + types: types.clone(), + }), + ); + + let node_output = self.hugr.add_node_with_parent( + node, + OpType::Output(Output { + types: vec![Type::new_sum([types.clone()])].into(), + }), + ); + + let node_tag = self.hugr.add_node_with_parent( + node, + OpType::Tag(Tag { + tag: 0, + variants: vec![types], + }), + ); + + // Connect the input node to the tag node + let input_outputs = self.hugr.node_outputs(node_input); + let tag_inputs = self.hugr.node_inputs(node_tag); + + for (a, b) in input_outputs.zip(tag_inputs) { + self.hugr.connect(node_input, a, node_tag, b); + } + + // Connect the tag node to the output node + let tag_outputs = self.hugr.node_outputs(node_tag); + let output_inputs = self.hugr.node_inputs(node_output); + + for (a, b) in tag_outputs.zip(output_inputs) { + self.hugr.connect(node_tag, a, node_output, b); + } + + Ok(node) + } + + fn make_exit_node( + &mut self, + parent: Node, + ports: &'a [model::Port<'a>], + ) -> Result { + let cfg_outputs = self.get_port_types(ports)?; + let node = self + .hugr + .add_node_with_parent(parent, OpType::ExitBlock(ExitBlock { cfg_outputs })); + self.record_links(node, Direction::Outgoing, ports); + Ok(node) + } + + fn import_cfg_region( + &mut self, + node_id: model::NodeId, + region: model::RegionId, + node: Node, + ) -> Result<(), ImportError> { + let region_data = self.get_region(region)?; + + if !matches!(region_data.kind, model::RegionKind::DataFlow) { + return Err(model::ModelError::InvalidRegions(node_id).into()); + } + + let node_entry = self.make_entry_node(node, region_data.sources)?; + + for child in region_data.children { + self.import_node(*child, node)?; + } + + let node_exit = self.make_exit_node(node, region_data.targets)?; + + let entry_outputs = self.hugr.node_outputs(node_entry); + let first_block = self.hugr.children(node).nth(1).unwrap(); + let first_block_inputs = self.hugr.node_inputs(first_block); + + for (a, b) in entry_outputs.zip(first_block_inputs) { + self.hugr.connect(node_entry, a, node_exit, b); + } + + Ok(()) + } + + fn import_cfg_block( + &mut self, + node_id: model::NodeId, + parent: Node, + ) -> Result { + let node_data = self.get_node(node_id)?; + assert!(matches!(node_data.operation, model::Operation::Block)); + + let [region] = node_data.regions else { + return Err(model::ModelError::InvalidRegions(node_id).into()); + }; + let region_data = self.get_region(*region)?; + + let inputs = self.get_port_types(region_data.sources)?; + + let Some((targets_first, targets_rest)) = region_data.targets.split_first() else { + return Err(model::ModelError::InvalidRegions(node_id).into()); + }; + + let sum_rows: Vec<_> = { + let Some(term) = targets_first.r#type else { + return Err(error_uninferred!("port type")); + }; + + let model::Term::Adt { variants } = self.get_term(term)? else { + return Err(model::ModelError::TypeError(term).into()); + }; + + self.import_type_rows(*variants)? + }; + + let other_outputs = self.get_port_types(targets_rest)?; + + let optype = OpType::DataflowBlock(DataflowBlock { + inputs, + other_outputs, + sum_rows, + extension_delta: ExtensionSet::new(), + }); + let node = self.make_node(node_id, optype, parent)?; + + self.import_dfg_region(node_id, *region, node)?; + Ok(node) + } + fn import_poly_func_type( &mut self, decl: model::FuncDecl<'a>, From 4e49b902eae9b6bed9c7d7b626a74c6c62ac47f7 Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Wed, 2 Oct 2024 12:26:50 +0100 Subject: [PATCH 03/26] Made the model import/export test into an insta test. --- devenv.nix | 1 + hugr-core/tests/model.rs | 3 +- .../tests/snapshots/model__import_export.snap | 37 +++++++++++++++++++ 3 files changed, 40 insertions(+), 1 deletion(-) create mode 100644 hugr-core/tests/snapshots/model__import_export.snap diff --git a/devenv.nix b/devenv.nix index b7458a6a2..92c31d948 100644 --- a/devenv.nix +++ b/devenv.nix @@ -21,6 +21,7 @@ in # cargo-llvm-cov is currently marked broken on nixpkgs unstable pkgs-stable.cargo-llvm-cov pkgs.graphviz + pkgs.cargo-insta ] ++ lib.optionals pkgs.stdenv.isDarwin (with pkgs.darwin.apple_sdk; [ diff --git a/hugr-core/tests/model.rs b/hugr-core/tests/model.rs index 2cefca464..658a60fd1 100644 --- a/hugr-core/tests/model.rs +++ b/hugr-core/tests/model.rs @@ -9,5 +9,6 @@ pub fn test_import_export() { let extensions = std_reg(); let hugr = import_hugr(&parsed_module.module, &extensions).unwrap(); let roundtrip = export_hugr(&hugr, &bump); - panic!("{}:", model::text::print_to_string(&roundtrip, 80).unwrap()); + let roundtrip_str = model::text::print_to_string(&roundtrip, 80).unwrap(); + insta::assert_snapshot!(roundtrip_str); } diff --git a/hugr-core/tests/snapshots/model__import_export.snap b/hugr-core/tests/snapshots/model__import_export.snap new file mode 100644 index 000000000..8bc728ecc --- /dev/null +++ b/hugr-core/tests/snapshots/model__import_export.snap @@ -0,0 +1,37 @@ +--- +source: hugr-core/tests/model.rs +expression: roundtrip_str +--- +(hugr 0) + +(define-alias local.int type (@ arithmetic.int.types.int)) + +(define-func example.add + [(@ arithmetic.int.types.int) (@ arithmetic.int.types.int)] + [(@ arithmetic.int.types.int)] + (ext) + (dfg + [(%0 (@ arithmetic.int.types.int)) (%1 (@ arithmetic.int.types.int))] + [(%2 (@ arithmetic.int.types.int))] + (arithmetic.int.iadd + [(%0 (@ arithmetic.int.types.int)) (%1 (@ arithmetic.int.types.int))] + [(%2 (@ arithmetic.int.types.int))]))) + +(declare-func example.callee + [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext)) + +(define-func example.caller + [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext) + (dfg + [(%3 (@ arithmetic.int.types.int))] + [(%4 (@ arithmetic.int.types.int))] + (call + (@ example.callee) + [(%3 (@ arithmetic.int.types.int))] + [(%4 (@ arithmetic.int.types.int))]))) + +(define-func example.swap + (forall ?0 type) + (forall ?1 type) + [?0 ?1] [?1 ?0] (ext) + (dfg [(%5 ?0) (%6 ?1)] [(%6 ?1) (%5 ?0)])) From b7305fc40047a4c1034323ea360c3637a8e4e9eb Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Wed, 2 Oct 2024 12:55:08 +0100 Subject: [PATCH 04/26] Import tail-loop and cond by inferring types from port types. --- hugr-core/src/export.rs | 18 +-- hugr-core/src/import.rs | 197 ++++++++++++++++++-------------- hugr-model/src/v0/mod.rs | 44 +------ hugr-model/src/v0/text/parse.rs | 31 +---- hugr-model/src/v0/text/print.rs | 29 ++--- 5 files changed, 139 insertions(+), 180 deletions(-) diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index d612d4270..1058cfe33 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -297,12 +297,7 @@ impl<'a> Context<'a> { OpType::TailLoop(op) => { regions = self.bump.alloc_slice_copy(&[self.export_dfg(node)]); - model::Operation::TailLoop { - inputs: self.export_type_row(&op.just_inputs), - outputs: self.export_type_row(&op.just_outputs), - rest: self.export_type_row(&op.rest), - extensions: self.export_ext_set(&op.extension_delta), - } + model::Operation::TailLoop } OpType::Conditional(op) => { @@ -314,12 +309,7 @@ impl<'a> Context<'a> { tail: None, }; regions = self.export_conditional_regions(node); - model::Operation::Conditional { - cases: self.module.insert_term(sum_rows), - context: self.export_type_row(&op.other_inputs), - outputs: self.export_type_row(&op.outputs), - extensions: self.export_ext_set(&op.extension_delta), - } + model::Operation::Conditional } // Opaque/extension operations should in the future support having multiple different @@ -544,7 +534,9 @@ impl<'a> Context<'a> { match t { TypeArg::Type { ty } => self.export_type(ty), TypeArg::BoundedNat { n } => self.module.insert_term(model::Term::Nat(*n)), - TypeArg::String { arg } => self.module.insert_term(model::Term::Str(arg.into())), + TypeArg::String { arg } => self + .module + .insert_term(model::Term::Str(self.bump.alloc_str(arg))), TypeArg::Sequence { elems } => { // For now we assume that the sequence is meant to be a list. let items = self diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 3099b6bbe..243385555 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -15,7 +15,7 @@ use crate::{ types::{ type_param::TypeParam, type_row::TypeRowBase, CustomType, FuncTypeBase, MaybeRV, NoRV, PolyFuncType, PolyFuncTypeBase, RowVariable, Signature, Type, TypeArg, TypeBase, TypeBound, - TypeEnum, TypeRow, TypeRowRV, + TypeRow, }, Direction, Hugr, HugrView, Node, Port, }; @@ -402,7 +402,7 @@ impl<'a> Context<'a> { } model::Operation::Block => self.import_cfg_block(node_id, parent), - // Err(model::ModelError::UnexpectedOperation(node_id).into())}, + model::Operation::DefineFunc { decl } => { self.import_poly_func_type(*decl, |ctx, signature| { let optype = OpType::FuncDefn(FuncDefn { @@ -478,70 +478,8 @@ impl<'a> Context<'a> { self.make_node(node_id, optype, parent) } - model::Operation::TailLoop { - inputs, - outputs, - rest, - extensions, - } => { - let just_inputs = self.import_type_row(inputs)?; - let just_outputs = self.import_type_row(outputs)?; - let rest = self.import_type_row(rest)?; - let extension_delta = self.import_extension_set(extensions)?; - - let optype = OpType::TailLoop(TailLoop { - just_inputs, - just_outputs, - rest, - extension_delta, - }); - - let node = self.make_node(node_id, optype, parent)?; - - let [region] = node_data.regions else { - return Err(model::ModelError::InvalidRegions(node_id).into()); - }; - - self.import_dfg_region(node_id, *region, node)?; - Ok(node) - } - - model::Operation::Conditional { - cases, - context, - outputs, - extensions, - } => { - let sum_rows: Vec<_> = self.import_type_rows(cases)?; - let other_inputs = self.import_type_row(context)?; - let outputs = self.import_type_row(outputs)?; - let extension_delta = self.import_extension_set(extensions)?; - - let optype = OpType::Conditional(Conditional { - sum_rows, - other_inputs, - outputs, - extension_delta, - }); - - let node = self.make_node(node_id, optype, parent)?; - - for region in node_data.regions { - let region_data = self.get_region(*region)?; - - let source_types = self.get_port_types(region_data.sources)?; - let target_types = self.get_port_types(region_data.targets)?; - let signature = FuncTypeBase::new(source_types, target_types); - - let case_node = self - .hugr - .add_node_with_parent(node, OpType::Case(Case { signature })); - - self.import_dfg_region(node_id, *region, case_node)?; - } - - Ok(node) - } + model::Operation::TailLoop => self.import_tail_loop(node_id, parent), + model::Operation::Conditional => self.import_conditional(node_id, parent), model::Operation::CustomFull { name: GlobalRef::Named(name), @@ -647,6 +585,113 @@ impl<'a> Context<'a> { Ok(()) } + fn import_adt_and_rest( + &mut self, + node_id: model::NodeId, + ports: &'a [model::Port<'a>], + ) -> Result<(Vec, TypeRow), ImportError> { + let Some((first, rest)) = ports.split_first() else { + return Err(model::ModelError::InvalidRegions(node_id).into()); + }; + + let sum_rows: Vec<_> = { + let Some(term) = first.r#type else { + return Err(error_uninferred!("port type")); + }; + + let model::Term::Adt { variants } = self.get_term(term)? else { + return Err(model::ModelError::TypeError(term).into()); + }; + + self.import_type_rows(*variants)? + }; + + let rest = self.get_port_types(rest)?; + + Ok((sum_rows, rest)) + } + + fn import_tail_loop( + &mut self, + node_id: model::NodeId, + parent: Node, + ) -> Result { + let node_data = self.get_node(node_id)?; + assert!(matches!(node_data.operation, model::Operation::TailLoop)); + + let [region] = node_data.regions else { + return Err(model::ModelError::InvalidRegions(node_id).into()); + }; + let region_data = self.get_region(*region)?; + + let (sum_rows, rest) = self.import_adt_and_rest(node_id, region_data.targets)?; + + let (just_inputs, just_outputs) = { + let mut sum_rows = sum_rows.into_iter(); + + let term = region_data.targets[0].r#type.unwrap(); + + let Some(just_inputs) = sum_rows.next() else { + return Err(model::ModelError::TypeError(term).into()); + }; + + let Some(just_outputs) = sum_rows.next() else { + return Err(model::ModelError::TypeError(term).into()); + }; + + (just_inputs, just_outputs) + }; + + let optype = OpType::TailLoop(TailLoop { + just_inputs, + just_outputs, + rest, + extension_delta: ExtensionSet::new(), + }); + + let node = self.make_node(node_id, optype, parent)?; + + self.import_dfg_region(node_id, *region, node)?; + Ok(node) + } + + fn import_conditional( + &mut self, + node_id: model::NodeId, + parent: Node, + ) -> Result { + let node_data = self.get_node(node_id)?; + assert!(matches!(node_data.operation, model::Operation::Conditional)); + + let (sum_rows, other_inputs) = self.import_adt_and_rest(node_id, node_data.inputs)?; + let outputs = self.get_port_types(node_data.outputs)?; + + let optype = OpType::Conditional(Conditional { + sum_rows, + other_inputs, + outputs, + extension_delta: ExtensionSet::new(), + }); + + let node = self.make_node(node_id, optype, parent)?; + + for region in node_data.regions { + let region_data = self.get_region(*region)?; + + let source_types = self.get_port_types(region_data.sources)?; + let target_types = self.get_port_types(region_data.targets)?; + let signature = FuncTypeBase::new(source_types, target_types); + + let case_node = self + .hugr + .add_node_with_parent(node, OpType::Case(Case { signature })); + + self.import_dfg_region(node_id, *region, case_node)?; + } + + Ok(node) + } + /// Create the entry block for a control flow region. /// /// Since the core hugr does not have explicit entry blocks yet, we create a dataflow block @@ -765,26 +810,8 @@ impl<'a> Context<'a> { return Err(model::ModelError::InvalidRegions(node_id).into()); }; let region_data = self.get_region(*region)?; - let inputs = self.get_port_types(region_data.sources)?; - - let Some((targets_first, targets_rest)) = region_data.targets.split_first() else { - return Err(model::ModelError::InvalidRegions(node_id).into()); - }; - - let sum_rows: Vec<_> = { - let Some(term) = targets_first.r#type else { - return Err(error_uninferred!("port type")); - }; - - let model::Term::Adt { variants } = self.get_term(term)? else { - return Err(model::ModelError::TypeError(term).into()); - }; - - self.import_type_rows(*variants)? - }; - - let other_outputs = self.get_port_types(targets_rest)?; + let (sum_rows, other_outputs) = self.import_adt_and_rest(node_id, region_data.targets)?; let optype = OpType::DataflowBlock(DataflowBlock { inputs, diff --git a/hugr-model/src/v0/mod.rs b/hugr-model/src/v0/mod.rs index 43bb4ade7..57bde9857 100644 --- a/hugr-model/src/v0/mod.rs +++ b/hugr-model/src/v0/mod.rs @@ -245,28 +245,7 @@ pub enum Operation<'a> { /// - **Outputs**: `outputs` + `rest` /// - **Sources**: `inputs` + `rest` /// - **Targets**: `(adt [inputs outputs])` + `rest` - TailLoop { - // TODO: These can be determined by the port types? - /// Types of the values that are passed as inputs to the loop, and are returned - /// by the loop body when the loop is continued. - /// - /// **Type**: `(list type)` - inputs: TermId, - /// Types of the values that are produced at the end of the loop body when the loop - /// should be ended. - /// - /// **Type**: `(list type)` - outputs: TermId, - /// Types of the values that are passed as inputs to the loop, to each iteration and - /// are then returned at the end of the loop. - /// - /// **Type**: `(list type)` - rest: TermId, - /// - /// - /// **Type**: `ext-set` - extensions: TermId, - }, + TailLoop, /// Conditional operation. /// @@ -274,24 +253,7 @@ pub enum Operation<'a> { /// /// - **Inputs**: `[(adt inputs)]` + `context` /// - **Outputs**: `outputs` - Conditional { - /// Port types for each case of the conditional. - /// - /// **Type**: `(list (list type))` - cases: TermId, - /// Port types for additional inputs to the conditional. - /// - /// **Type**: `(list type)` - context: TermId, - /// Port types for the outputs of each case. - /// - /// **Type**: `(list type)` - outputs: TermId, - /// - /// - /// **Type**: `ext-set` - extensions: TermId, - }, + Conditional, } /// A region in the hugr. @@ -502,7 +464,7 @@ pub enum Term<'a> { /// A literal string. /// /// `"STRING" : str` - Str(SmolStr), + Str(&'a str), /// The type of literal strings. /// diff --git a/hugr-model/src/v0/text/parse.rs b/hugr-model/src/v0/text/parse.rs index 98bfe1689..38d8c0dcf 100644 --- a/hugr-model/src/v0/text/parse.rs +++ b/hugr-model/src/v0/text/parse.rs @@ -33,8 +33,6 @@ use pest_parser::{HugrParser, Rule}; pub struct ParsedModule<'a> { /// The parsed module. pub module: Module<'a>, - /// The names of the edges. - pub edges: Vec, // TODO: Spans } @@ -46,13 +44,11 @@ pub fn parse<'a>(input: &'a str, bump: &'a Bump) -> Result, Par Ok(ParsedModule { module: context.module, - edges: context.edge_names.into_iter().collect(), }) } struct ParseContext<'a> { module: Module<'a>, - edge_names: IndexSet, bump: &'a Bump, } @@ -60,7 +56,6 @@ impl<'a> ParseContext<'a> { fn new(bump: &'a Bump) -> Self { Self { module: Module::default(), - edge_names: IndexSet::default(), bump, } } @@ -171,7 +166,7 @@ impl<'a> ParseContext<'a> { Rule::term_str => { // TODO: Escaping? - let value = inner.next().unwrap().as_str().to_smolstr(); + let value = inner.next().unwrap().as_str(); Term::Str(value) } @@ -386,22 +381,12 @@ impl<'a> ParseContext<'a> { } Rule::node_tail_loop => { - let inputs = self.module.insert_term(Term::Wildcard); - let outputs = self.module.insert_term(Term::Wildcard); - let rest = self.module.insert_term(Term::Wildcard); - let extensions = self.module.insert_term(Term::Wildcard); - let operation = Operation::TailLoop { - inputs, - outputs, - rest, - extensions, - }; let inputs = self.parse_port_list(&mut inner)?; let outputs = self.parse_port_list(&mut inner)?; let meta = self.parse_meta(&mut inner)?; let regions = self.parse_regions(&mut inner)?; Node { - operation, + operation: Operation::TailLoop, inputs, outputs, params: &[], @@ -411,22 +396,12 @@ impl<'a> ParseContext<'a> { } Rule::node_cond => { - let cases = self.module.insert_term(Term::Wildcard); - let context = self.module.insert_term(Term::Wildcard); - let outputs = self.module.insert_term(Term::Wildcard); - let extensions = self.module.insert_term(Term::Wildcard); - let operation = Operation::Conditional { - cases, - context, - outputs, - extensions, - }; let inputs = self.parse_port_list(&mut inner)?; let outputs = self.parse_port_list(&mut inner)?; let meta = self.parse_meta(&mut inner)?; let regions = self.parse_regions(&mut inner)?; Node { - operation, + operation: Operation::Conditional, inputs, outputs, params: &[], diff --git a/hugr-model/src/v0/text/print.rs b/hugr-model/src/v0/text/print.rs index 2721760b5..2830976c9 100644 --- a/hugr-model/src/v0/text/print.rs +++ b/hugr-model/src/v0/text/print.rs @@ -329,18 +329,21 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { Ok(()) }), - Operation::TailLoop { - inputs, - outputs, - rest, - extensions, - } => todo!(), - Operation::Conditional { - cases, - context, - outputs, - extensions, - } => todo!(), + Operation::TailLoop => { + this.print_text("tail-loop"); + this.print_port_list(node_data.inputs)?; + this.print_port_list(node_data.outputs)?; + this.print_meta(node_data.meta)?; + this.print_regions(node_data.regions) + } + + Operation::Conditional => { + this.print_text("cond"); + this.print_port_list(node_data.inputs)?; + this.print_port_list(node_data.outputs)?; + this.print_meta(node_data.meta)?; + this.print_regions(node_data.regions) + } }) } @@ -509,7 +512,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { Term::Str(str) => { // TODO: escape self.print_text("\""); - self.print_text(str.as_ref()); + self.print_text(*str); self.print_text("\""); Ok(()) } From 55624b6a35bef52db2095dcc452baf39038de5cb Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Wed, 2 Oct 2024 13:34:49 +0100 Subject: [PATCH 05/26] Fixed lints. --- hugr-core/Cargo.toml | 1 - hugr-core/src/export.rs | 21 ++++-------- hugr-core/src/import.rs | 26 +++++++++++---- hugr-model/Cargo.toml | 1 - hugr-model/src/v0/mod.rs | 6 ++++ hugr-model/src/v0/text/hugr.pest | 3 ++ hugr-model/src/v0/text/parse.rs | 34 +++++++++++++------- hugr-model/src/v0/text/print.rs | 55 +++++++++++++++++--------------- 8 files changed, 86 insertions(+), 61 deletions(-) diff --git a/hugr-core/Cargo.toml b/hugr-core/Cargo.toml index 5732cf85b..0e7749983 100644 --- a/hugr-core/Cargo.toml +++ b/hugr-core/Cargo.toml @@ -50,7 +50,6 @@ hugr-model = { path = "../hugr-model" } indexmap.workspace = true tinyvec.workspace = true fxhash.workspace = true -ascent = "0.6.0" bumpalo = { workspace = true } [dev-dependencies] diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 1058cfe33..6d3783149 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -15,10 +15,8 @@ use hugr_model::v0::{self as model}; use indexmap::IndexSet; use smol_str::ToSmolStr; -pub(crate) const OP_FUNC_CALL_INDIRECT: &'static str = "func.call-indirect"; -pub(crate) const OP_ADT_TAG: &'static str = "adt.make-tag"; - -const TERM_PARAM_TUPLE: &'static str = "param.tuple"; +pub(crate) const OP_FUNC_CALL_INDIRECT: &str = "func.call-indirect"; +const TERM_PARAM_TUPLE: &str = "param.tuple"; /// Export a [`Hugr`] graph to its representation in the model. pub fn export_hugr<'a>(hugr: &'a Hugr, bump: &'a Bump) -> model::Module<'a> { @@ -114,7 +112,7 @@ impl<'a> Context<'a> { (OpType::DataflowBlock(block), Direction::Outgoing) => { let mut types = Vec::new(); types.extend( - (&block.sum_rows[port.index()]) + block.sum_rows[port.index()] .iter() .map(|t| self.export_type(t)), ); @@ -293,21 +291,14 @@ impl<'a> Context<'a> { OpType::CallIndirect(_) => make_custom(OP_FUNC_CALL_INDIRECT), - OpType::Tag(_) => make_custom(OP_ADT_TAG), + OpType::Tag(tag) => model::Operation::Tag { tag: tag.tag as _ }, - OpType::TailLoop(op) => { + OpType::TailLoop(_) => { regions = self.bump.alloc_slice_copy(&[self.export_dfg(node)]); model::Operation::TailLoop } - OpType::Conditional(op) => { - let mut types = BumpVec::new_in(self.bump); - types.extend(op.sum_rows.iter().map(|l| self.export_type_row(l))); - let types = types.into_bump_slice(); - let sum_rows = model::Term::List { - items: &types, - tail: None, - }; + OpType::Conditional(_) => { regions = self.export_conditional_regions(node); model::Operation::Conditional } diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 243385555..8ebbfedc5 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -120,7 +120,7 @@ impl<'a> Context<'a> { .iter() .map(|port| match port.r#type { Some(r#type) => self.import_type(r#type), - None => return Err(error_uninferred!("port type")), + None => Err(error_uninferred!("port type")), }) .collect::, _>>()?; @@ -203,7 +203,7 @@ impl<'a> Context<'a> { for (model_port, port) in ports.iter().zip(self.hugr.node_ports(node, direction)) { self.link_ports - .entry(model_port.link.clone()) + .entry(model_port.link) .or_default() .push((node, port)); } @@ -449,7 +449,7 @@ impl<'a> Context<'a> { .collect::, _>>()?; self.static_edges.push((func_node, node_id)); - let optype = OpType::Call(Call::try_new(func_sig, type_args, &self.extensions)?); + let optype = OpType::Call(Call::try_new(func_sig, type_args, self.extensions)?); self.make_node(node_id, optype, parent) } @@ -472,7 +472,7 @@ impl<'a> Context<'a> { let optype = OpType::LoadFunction(LoadFunction::try_new( func_sig, type_args, - &self.extensions, + self.extensions, )?); self.make_node(node_id, optype, parent) @@ -560,6 +560,18 @@ impl<'a> Context<'a> { ctx.make_node(node_id, optype, parent) }), + + model::Operation::Tag { tag } => { + let (variants, _) = self.import_adt_and_rest(node_id, node_data.outputs)?; + self.make_node( + node_id, + OpType::Tag(Tag { + variants, + tag: tag as _, + }), + parent, + ) + } } } @@ -927,7 +939,7 @@ impl<'a> Context<'a> { } model::Term::Str(value) => Ok(TypeArg::String { - arg: value.clone().into(), + arg: value.to_string(), }), model::Term::Quote { .. } => Ok(TypeArg::Type { @@ -1004,7 +1016,7 @@ impl<'a> Context<'a> { .collect::, _>>()?; let name = self.get_global_name(*name)?; - let (extension, id) = self.import_custom_name(&name)?; + let (extension, id) = self.import_custom_name(name)?; Ok(TypeBase::new_extension(CustomType::new( id, @@ -1133,7 +1145,7 @@ impl<'a> Context<'a> { .split_last() .ok_or_else(|| model::ModelError::MalformedName(symbol.to_smolstr()))?; - Ok((extension, id.into())) + Ok((extension, id)) } } diff --git a/hugr-model/Cargo.toml b/hugr-model/Cargo.toml index 5c4f79a8c..35a012251 100644 --- a/hugr-model/Cargo.toml +++ b/hugr-model/Cargo.toml @@ -8,7 +8,6 @@ repository.workspace = true license.workspace = true [dependencies] -beef = "0.5.2" bumpalo = { workspace = true } fxhash.workspace = true indexmap.workspace = true diff --git a/hugr-model/src/v0/mod.rs b/hugr-model/src/v0/mod.rs index 57bde9857..e3b439edf 100644 --- a/hugr-model/src/v0/mod.rs +++ b/hugr-model/src/v0/mod.rs @@ -254,6 +254,12 @@ pub enum Operation<'a> { /// - **Inputs**: `[(adt inputs)]` + `context` /// - **Outputs**: `outputs` Conditional, + + /// Create an ADT value from a sequence of inputs. + Tag { + /// The tag of the ADT value. + tag: u16, + }, } /// A region in the hugr. diff --git a/hugr-model/src/v0/text/hugr.pest b/hugr-model/src/v0/text/hugr.pest index 212a0d053..cd107501d 100644 --- a/hugr-model/src/v0/text/hugr.pest +++ b/hugr-model/src/v0/text/hugr.pest @@ -2,6 +2,7 @@ WHITESPACE = _{ " " | "\t" | "\r" | "\n" } COMMENT = _{ ";" ~ (!("\n") ~ ANY)* ~ "\n" } identifier = @{ (ASCII_ALPHA | "_" | "-") ~ (ASCII_ALPHANUMERIC | "_" | "-")* } symbol = @{ identifier ~ ("." ~ identifier)+ } +tag = @{ ASCII_NONZERO_DIGIT ~ ASCII_DIGIT* } string = @{ "\"" ~ (!("\"") ~ ANY)* ~ "\"" } list_tail = { "." } @@ -26,6 +27,7 @@ node = { | node_declare_alias | node_tail_loop | node_cond + | node_tag | node_custom } @@ -39,6 +41,7 @@ node_define_alias = { "(" ~ "define-alias" ~ alias_header ~ term ~ meta* ~ ")" node_declare_alias = { "(" ~ "declare-alias" ~ alias_header ~ meta* ~ ")" } node_tail_loop = { "(" ~ "tail-loop" ~ port_lists? ~ meta* ~ region* ~ ")" } node_cond = { "(" ~ "cond" ~ port_lists? ~ meta* ~ region* ~ ")" } +node_tag = { "(" ~ "tag" ~ tag ~ port_lists? ~ meta* ~ region* ~ ")" } node_custom = { "(" ~ (term_apply | term_apply_full) ~ port_lists? ~ meta* ~ region* ~ ")" } func_header = { symbol ~ param* ~ term ~ term ~ term } diff --git a/hugr-model/src/v0/text/parse.rs b/hugr-model/src/v0/text/parse.rs index 38d8c0dcf..50115bb1c 100644 --- a/hugr-model/src/v0/text/parse.rs +++ b/hugr-model/src/v0/text/parse.rs @@ -1,10 +1,8 @@ use bumpalo::Bump; -use indexmap::IndexSet; use pest::{ iterators::{Pair, Pairs}, Parser, RuleType, }; -use smol_str::{SmolStr, ToSmolStr}; use thiserror::Error; use crate::v0::{ @@ -39,7 +37,8 @@ pub struct ParsedModule<'a> { /// Parses a HUGR module from its text representation. pub fn parse<'a>(input: &'a str, bump: &'a Bump) -> Result, ParseError> { let mut context = ParseContext::new(bump); - let mut pairs = HugrParser::parse(Rule::module, input).map_err(ParseError)?; + let mut pairs = + HugrParser::parse(Rule::module, input).map_err(|err| ParseError(Box::new(err)))?; context.parse_module(pairs.next().unwrap())?; Ok(ParsedModule { @@ -71,17 +70,12 @@ impl<'a> ParseContext<'a> { kind: RegionKind::DataFlow, sources: &[], targets: &[], - children: self.bump.alloc_slice_copy(&children), + children, meta, }); self.module.root = root_region; - // TODO: Root region metadata - // self.module - // .node_meta - // .extend(meta.into_iter().map(|meta| (root, meta))); - Ok(()) } @@ -409,6 +403,22 @@ impl<'a> ParseContext<'a> { meta, } } + + Rule::node_tag => { + let tag = inner.next().unwrap().as_str().parse::().unwrap(); + let inputs = self.parse_port_list(&mut inner)?; + let outputs = self.parse_port_list(&mut inner)?; + let meta = self.parse_meta(&mut inner)?; + Node { + operation: Operation::Tag { tag }, + inputs, + outputs, + params: &[], + regions: &[], + meta, + } + } + _ => unreachable!(), }; @@ -546,10 +556,10 @@ impl<'a> ParseContext<'a> { }; let pair = pairs.next().unwrap(); - let mut inner = pair.into_inner(); + let inner = pair.into_inner(); let mut ports = Vec::new(); - while let Some(token) = inner.next() { + for token in inner { let port = self.parse_port(token)?; ports.push(port); } @@ -620,7 +630,7 @@ fn filter_rule<'a, 'i, R: RuleType>( /// An error that occurred during parsing. #[derive(Debug, Clone, Error)] #[error("{0}")] -pub struct ParseError(pest::error::Error); +pub struct ParseError(Box>); impl ParseError { /// Line of the error in the input string. diff --git a/hugr-model/src/v0/text/print.rs b/hugr-model/src/v0/text/print.rs index 2830976c9..3f05386a3 100644 --- a/hugr-model/src/v0/text/print.rs +++ b/hugr-model/src/v0/text/print.rs @@ -9,9 +9,6 @@ use crate::v0::{ type PrintError = ModelError; type PrintResult = Result; -// TODO: Print tail-loop nodes -// TODO: Print conditional and case nodes - /// Pretty-print a module to a string. pub fn print_to_string(module: &Module, width: usize) -> PrintResult { let arena = Arena::new(); @@ -149,29 +146,29 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { Operation::Dfg => { this.print_group(|this| { this.print_text("dfg"); - this.print_port_list(&node_data.inputs)?; - this.print_port_list(&node_data.outputs) + this.print_port_list(node_data.inputs)?; + this.print_port_list(node_data.outputs) })?; this.print_meta(node_data.meta)?; - this.print_regions(&node_data.regions) + this.print_regions(node_data.regions) } Operation::Cfg => { this.print_group(|this| { this.print_text("cfg"); - this.print_port_list(&node_data.inputs)?; - this.print_port_list(&node_data.outputs) + this.print_port_list(node_data.inputs)?; + this.print_port_list(node_data.outputs) })?; this.print_meta(node_data.meta)?; - this.print_regions(&node_data.regions) + this.print_regions(node_data.regions) } Operation::Block => { this.print_group(|this| { this.print_text("block"); - this.print_port_list(&node_data.inputs)?; - this.print_port_list(&node_data.outputs) + this.print_port_list(node_data.inputs)?; + this.print_port_list(node_data.outputs) })?; this.print_meta(node_data.meta)?; - this.print_regions(&node_data.regions) + this.print_regions(node_data.regions) } Operation::DefineFunc { decl } => this.with_local_scope(decl.params, |this| { @@ -201,7 +198,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { } this.print_meta(node_data.meta)?; - this.print_regions(&node_data.regions) + this.print_regions(node_data.regions) }), Operation::DeclareFunc { decl } => this.with_local_scope(decl.params, |this| { @@ -238,8 +235,8 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_group(|this| { this.print_text("call"); this.print_term(*func)?; - this.print_port_list(&node_data.inputs)?; - this.print_port_list(&node_data.outputs) + this.print_port_list(node_data.inputs)?; + this.print_port_list(node_data.outputs) })?; this.print_meta(node_data.meta)?; Ok(()) @@ -249,8 +246,8 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_group(|this| { this.print_text("load-func"); this.print_term(*func)?; - this.print_port_list(&node_data.inputs)?; - this.print_port_list(&node_data.outputs) + this.print_port_list(node_data.inputs)?; + this.print_port_list(node_data.outputs) })?; this.print_meta(node_data.meta)?; Ok(()) @@ -272,11 +269,11 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { })?; } - this.print_port_list(&node_data.inputs)?; - this.print_port_list(&node_data.outputs) + this.print_port_list(node_data.inputs)?; + this.print_port_list(node_data.outputs) })?; this.print_meta(node_data.meta)?; - this.print_regions(&node_data.regions) + this.print_regions(node_data.regions) } Operation::CustomFull { name } => { @@ -292,11 +289,11 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { Ok(()) })?; - this.print_port_list(&node_data.inputs)?; - this.print_port_list(&node_data.outputs) + this.print_port_list(node_data.inputs)?; + this.print_port_list(node_data.outputs) })?; this.print_meta(node_data.meta)?; - this.print_regions(&node_data.regions) + this.print_regions(node_data.regions) } Operation::DefineAlias { decl, value } => this.with_local_scope(decl.params, |this| { @@ -344,6 +341,14 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_meta(node_data.meta)?; this.print_regions(node_data.regions) } + + Operation::Tag { tag } => { + this.print_text("tag"); + this.print_text(format!("{}", tag)); + this.print_port_list(node_data.inputs)?; + this.print_port_list(node_data.outputs)?; + this.print_meta(node_data.meta) + } }) } @@ -371,8 +376,8 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { }; if !region_data.sources.is_empty() || !region_data.targets.is_empty() { - this.print_port_list(®ion_data.sources)?; - this.print_port_list(®ion_data.targets)?; + this.print_port_list(region_data.sources)?; + this.print_port_list(region_data.targets)?; } this.print_meta(region_data.meta)?; From 2a8c34faa5a0d262c70488c1119119e7704b7878 Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Wed, 2 Oct 2024 13:40:21 +0100 Subject: [PATCH 06/26] Feature gate the model in `hugr-core`. --- hugr-core/Cargo.toml | 3 ++- hugr-core/src/lib.rs | 3 ++- hugr-core/tests/model.rs | 4 ++++ hugr-model/src/v0/mod.rs | 1 - 4 files changed, 8 insertions(+), 3 deletions(-) diff --git a/hugr-core/Cargo.toml b/hugr-core/Cargo.toml index 0e7749983..beb1c3634 100644 --- a/hugr-core/Cargo.toml +++ b/hugr-core/Cargo.toml @@ -19,6 +19,7 @@ workspace = true [features] extension_inference = [] declarative = ["serde_yaml"] +model = ["hugr-model"] [dependencies] portgraph = { workspace = true, features = ["serde", "petgraph"] } @@ -46,7 +47,7 @@ paste = { workspace = true } strum = { workspace = true } strum_macros = { workspace = true } semver = { version = "1.0.23", features = ["serde"] } -hugr-model = { path = "../hugr-model" } +hugr-model = { path = "../hugr-model", optional = true } indexmap.workspace = true tinyvec.workspace = true fxhash.workspace = true diff --git a/hugr-core/src/lib.rs b/hugr-core/src/lib.rs index 1221c5572..04ebae64d 100644 --- a/hugr-core/src/lib.rs +++ b/hugr-core/src/lib.rs @@ -7,12 +7,13 @@ // https://github.com/rust-lang/rust/issues/120363 // https://github.com/proptest-rs/proptest/issues/447 #![cfg_attr(test, allow(non_local_definitions))] - pub mod builder; pub mod core; +#[cfg(feature = "model")] pub mod export; pub mod extension; pub mod hugr; +#[cfg(feature = "model")] pub mod import; pub mod macros; pub mod ops; diff --git a/hugr-core/tests/model.rs b/hugr-core/tests/model.rs index 658a60fd1..6180bd69e 100644 --- a/hugr-core/tests/model.rs +++ b/hugr-core/tests/model.rs @@ -1,7 +1,11 @@ +#[cfg(feature = "model")] use hugr::std_extensions::std_reg; +#[cfg(feature = "model")] use hugr_core::{export::export_hugr, import::import_hugr}; +#[cfg(feature = "model")] use hugr_model::v0 as model; +#[cfg(feature = "model")] #[test] pub fn test_import_export() { let bump = bumpalo::Bump::new(); diff --git a/hugr-model/src/v0/mod.rs b/hugr-model/src/v0/mod.rs index e3b439edf..9ce06a456 100644 --- a/hugr-model/src/v0/mod.rs +++ b/hugr-model/src/v0/mod.rs @@ -4,7 +4,6 @@ //! It is included in the library to allow for early experimentation, and for //! the core and model to converge incrementally. //! -//! //! # Terms //! //! Terms form a meta language that is used to describe types, parameters and metadata that From 9eb6a6f7db48891695611b0f4a73ba15c2d7aa06 Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Wed, 2 Oct 2024 13:56:53 +0100 Subject: [PATCH 07/26] Added node and region types. --- hugr-core/Cargo.toml | 1 + hugr-core/src/export.rs | 16 +++++++++- hugr-model/src/v0/mod.rs | 6 ++-- hugr-model/src/v0/text/hugr.pest | 21 ++++++------- hugr-model/src/v0/text/parse.rs | 51 ++++++++++++++++++++++++++++++-- hugr-model/src/v0/text/print.rs | 3 +- 6 files changed, 80 insertions(+), 18 deletions(-) diff --git a/hugr-core/Cargo.toml b/hugr-core/Cargo.toml index beb1c3634..791dd2592 100644 --- a/hugr-core/Cargo.toml +++ b/hugr-core/Cargo.toml @@ -20,6 +20,7 @@ workspace = true extension_inference = [] declarative = ["serde_yaml"] model = ["hugr-model"] +default = ["model"] [dependencies] portgraph = { workspace = true, features = ["serde", "petgraph"] } diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 6d3783149..7a3019660 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -54,6 +54,8 @@ impl<'a> Context<'a> { } pub fn export_root(&mut self) { + let r#type = self.module.insert_term(model::Term::Wildcard); + let hugr_children = self.hugr.children(self.hugr.root()); let mut children = BumpVec::with_capacity_in(hugr_children.len(), self.bump); @@ -67,6 +69,7 @@ impl<'a> Context<'a> { targets: &[], children: children.into_bump_slice(), meta: &[], + r#type, }); self.module.root = root; @@ -146,7 +149,6 @@ impl<'a> Context<'a> { Some(model::Port { r#type: Some(r#type), link, - meta: &[], }) } @@ -177,6 +179,7 @@ impl<'a> Context<'a> { let outputs = self.make_ports(node, Direction::Outgoing); let mut params: &[_] = &[]; let mut regions: &[_] = &[]; + let mut r#type = None; fn make_custom(name: &'static str) -> model::Operation { model::Operation::Custom { @@ -341,6 +344,8 @@ impl<'a> Context<'a> { } }; + let r#type = r#type.unwrap_or_else(|| self.module.insert_term(model::Term::Wildcard)); + self.module.insert_node(model::Node { operation, inputs, @@ -348,6 +353,7 @@ impl<'a> Context<'a> { params, regions, meta: &[], + r#type, }) } @@ -388,12 +394,16 @@ impl<'a> Context<'a> { region_children.push(self.export_node(child)); } + // TODO: We can determine the type of the region + let r#type = self.module.insert_term(model::Term::Wildcard); + self.module.insert_region(model::Region { kind: model::RegionKind::DataFlow, sources, targets, children: region_children.into_bump_slice(), meta: &[], + r#type, }) } @@ -435,12 +445,16 @@ impl<'a> Context<'a> { region_children.push(self.export_node(child)); } + // TODO: We can determine the type of the region + let r#type = self.module.insert_term(model::Term::Wildcard); + self.module.insert_region(model::Region { kind: model::RegionKind::DataFlow, sources: self.bump.alloc_slice_copy(&[source]), targets, children: region_children.into_bump_slice(), meta: &[], + r#type, }) } diff --git a/hugr-model/src/v0/mod.rs b/hugr-model/src/v0/mod.rs index 9ce06a456..b1f586dab 100644 --- a/hugr-model/src/v0/mod.rs +++ b/hugr-model/src/v0/mod.rs @@ -174,6 +174,8 @@ pub struct Node<'a> { pub regions: &'a [RegionId], /// The meta information attached to the node. pub meta: &'a [MetaItem<'a>], + /// The type of the node. + pub r#type: TermId, } /// Operations that nodes can perform. @@ -274,6 +276,8 @@ pub struct Region<'a> { pub children: &'a [NodeId], /// The metadata attached to the region. pub meta: &'a [MetaItem<'a>], + /// The type of the region. + pub r#type: TermId, } /// The kind of a region. @@ -292,8 +296,6 @@ pub struct Port<'a> { pub link: LinkRef<'a>, /// The type of the port. pub r#type: Option, - /// Metadata attached to the port. - pub meta: &'a [MetaItem<'a>], } /// A function declaration. diff --git a/hugr-model/src/v0/text/hugr.pest b/hugr-model/src/v0/text/hugr.pest index cd107501d..b5f8408f9 100644 --- a/hugr-model/src/v0/text/hugr.pest +++ b/hugr-model/src/v0/text/hugr.pest @@ -31,19 +31,20 @@ node = { | node_custom } -node_dfg = { "(" ~ "dfg" ~ port_lists? ~ meta* ~ region* ~ ")" } -node_cfg = { "(" ~ "cfg" ~ port_lists? ~ meta* ~ region* ~ ")" } -node_block = { "(" ~ "block" ~ port_lists? ~ meta* ~ region* ~ ")" } +node_dfg = { "(" ~ "dfg" ~ port_lists? ~ type_hint? ~ meta* ~ region* ~ ")" } +node_cfg = { "(" ~ "cfg" ~ port_lists? ~ type_hint? ~ meta* ~ region* ~ ")" } +node_block = { "(" ~ "block" ~ port_lists? ~ type_hint? ~ meta* ~ region* ~ ")" } node_define_func = { "(" ~ "define-func" ~ func_header ~ meta* ~ region* ~ ")" } node_declare_func = { "(" ~ "declare-func" ~ func_header ~ meta* ~ ")" } -node_call_func = { "(" ~ "call" ~ term ~ port_lists? ~ meta* ~ ")" } +node_call_func = { "(" ~ "call" ~ term ~ port_lists? ~ type_hint? ~ meta* ~ ")" } node_define_alias = { "(" ~ "define-alias" ~ alias_header ~ term ~ meta* ~ ")" } node_declare_alias = { "(" ~ "declare-alias" ~ alias_header ~ meta* ~ ")" } -node_tail_loop = { "(" ~ "tail-loop" ~ port_lists? ~ meta* ~ region* ~ ")" } -node_cond = { "(" ~ "cond" ~ port_lists? ~ meta* ~ region* ~ ")" } -node_tag = { "(" ~ "tag" ~ tag ~ port_lists? ~ meta* ~ region* ~ ")" } -node_custom = { "(" ~ (term_apply | term_apply_full) ~ port_lists? ~ meta* ~ region* ~ ")" } +node_tail_loop = { "(" ~ "tail-loop" ~ port_lists? ~ type_hint? ~ meta* ~ region* ~ ")" } +node_cond = { "(" ~ "cond" ~ port_lists? ~ type_hint? ~ meta* ~ region* ~ ")" } +node_tag = { "(" ~ "tag" ~ tag ~ port_lists? ~ type_hint? ~ meta* ~ region* ~ ")" } +node_custom = { "(" ~ (term_apply | term_apply_full) ~ port_lists? ~ type_hint? ~ meta* ~ region* ~ ")" } +type_hint = { "(" ~ "type" ~ term ~ term ~ term ~ ")" } func_header = { symbol ~ param* ~ term ~ term ~ term } alias_header = { symbol ~ param* ~ term } @@ -54,8 +55,8 @@ param_explicit = { "(" ~ "param" ~ term_var ~ term ~ ")" } param_constraint = { "(" ~ "where" ~ term ~ ")" } region = { region_dfg | region_cfg } -region_dfg = { "(" ~ "dfg" ~ port_lists? ~ meta* ~ node* ~ ")" } -region_cfg = { "(" ~ "cfg" ~ port_lists? ~ meta* ~ node* ~ ")" } +region_dfg = { "(" ~ "dfg" ~ port_lists? ~ type_hint? ~ meta* ~ node* ~ ")" } +region_cfg = { "(" ~ "cfg" ~ port_lists? ~ type_hint? ~ meta* ~ node* ~ ")" } term = { term_wildcard diff --git a/hugr-model/src/v0/text/parse.rs b/hugr-model/src/v0/text/parse.rs index 50115bb1c..7a98c3dbf 100644 --- a/hugr-model/src/v0/text/parse.rs +++ b/hugr-model/src/v0/text/parse.rs @@ -62,6 +62,7 @@ impl<'a> ParseContext<'a> { fn parse_module(&mut self, pair: Pair<'a, Rule>) -> ParseResult<()> { debug_assert!(matches!(pair.as_rule(), Rule::module)); let mut inner = pair.into_inner(); + let r#type = self.module.insert_term(Term::Wildcard); let meta = self.parse_meta(&mut inner)?; let children = self.parse_nodes(&mut inner)?; @@ -72,6 +73,7 @@ impl<'a> ParseContext<'a> { targets: &[], children, meta, + r#type, }); self.module.root = root_region; @@ -226,6 +228,7 @@ impl<'a> ParseContext<'a> { Rule::node_dfg => { let inputs = self.parse_port_list(&mut inner)?; let outputs = self.parse_port_list(&mut inner)?; + let r#type = self.parse_type_hint(&mut inner)?; let meta = self.parse_meta(&mut inner)?; let regions = self.parse_regions(&mut inner)?; Node { @@ -235,12 +238,14 @@ impl<'a> ParseContext<'a> { params: &[], regions, meta, + r#type, } } Rule::node_cfg => { let inputs = self.parse_port_list(&mut inner)?; let outputs = self.parse_port_list(&mut inner)?; + let r#type = self.parse_type_hint(&mut inner)?; let meta = self.parse_meta(&mut inner)?; let regions = self.parse_regions(&mut inner)?; Node { @@ -250,12 +255,14 @@ impl<'a> ParseContext<'a> { params: &[], regions, meta, + r#type, } } Rule::node_block => { let inputs = self.parse_port_list(&mut inner)?; let outputs = self.parse_port_list(&mut inner)?; + let r#type = self.parse_type_hint(&mut inner)?; let meta = self.parse_meta(&mut inner)?; let regions = self.parse_regions(&mut inner)?; Node { @@ -265,11 +272,13 @@ impl<'a> ParseContext<'a> { params: &[], regions, meta, + r#type, } } Rule::node_define_func => { let decl = self.parse_func_header(inner.next().unwrap())?; + let r#type = self.module.insert_term(Term::Wildcard); let meta = self.parse_meta(&mut inner)?; let regions = self.parse_regions(&mut inner)?; Node { @@ -279,11 +288,13 @@ impl<'a> ParseContext<'a> { params: &[], regions, meta, + r#type, } } Rule::node_declare_func => { let decl = self.parse_func_header(inner.next().unwrap())?; + let r#type = self.module.insert_term(Term::Wildcard); let meta = self.parse_meta(&mut inner)?; Node { operation: Operation::DeclareFunc { decl }, @@ -292,6 +303,7 @@ impl<'a> ParseContext<'a> { params: &[], regions: &[], meta, + r#type, } } @@ -299,6 +311,7 @@ impl<'a> ParseContext<'a> { let func = self.parse_term(inner.next().unwrap())?; let inputs = self.parse_port_list(&mut inner)?; let outputs = self.parse_port_list(&mut inner)?; + let r#type = self.parse_type_hint(&mut inner)?; let meta = self.parse_meta(&mut inner)?; Node { operation: Operation::CallFunc { func }, @@ -307,12 +320,14 @@ impl<'a> ParseContext<'a> { params: &[], regions: &[], meta, + r#type, } } Rule::node_define_alias => { let decl = self.parse_alias_header(inner.next().unwrap())?; let value = self.parse_term(inner.next().unwrap())?; + let r#type = self.module.insert_term(Term::Wildcard); let meta = self.parse_meta(&mut inner)?; Node { operation: Operation::DefineAlias { decl, value }, @@ -321,11 +336,13 @@ impl<'a> ParseContext<'a> { params: &[], regions: &[], meta, + r#type, } } Rule::node_declare_alias => { let decl = self.parse_alias_header(inner.next().unwrap())?; + let r#type = self.module.insert_term(Term::Wildcard); let meta = self.parse_meta(&mut inner)?; Node { operation: Operation::DeclareAlias { decl }, @@ -334,6 +351,7 @@ impl<'a> ParseContext<'a> { params: &[], regions: &[], meta, + r#type, } } @@ -362,6 +380,7 @@ impl<'a> ParseContext<'a> { let inputs = self.parse_port_list(&mut inner)?; let outputs = self.parse_port_list(&mut inner)?; + let r#type = self.parse_type_hint(&mut inner)?; let meta = self.parse_meta(&mut inner)?; let regions = self.parse_regions(&mut inner)?; Node { @@ -371,12 +390,14 @@ impl<'a> ParseContext<'a> { params: self.bump.alloc_slice_copy(¶ms), regions, meta, + r#type, } } Rule::node_tail_loop => { let inputs = self.parse_port_list(&mut inner)?; let outputs = self.parse_port_list(&mut inner)?; + let r#type = self.parse_type_hint(&mut inner)?; let meta = self.parse_meta(&mut inner)?; let regions = self.parse_regions(&mut inner)?; Node { @@ -386,12 +407,14 @@ impl<'a> ParseContext<'a> { params: &[], regions, meta, + r#type, } } Rule::node_cond => { let inputs = self.parse_port_list(&mut inner)?; let outputs = self.parse_port_list(&mut inner)?; + let r#type = self.parse_type_hint(&mut inner)?; let meta = self.parse_meta(&mut inner)?; let regions = self.parse_regions(&mut inner)?; Node { @@ -401,6 +424,7 @@ impl<'a> ParseContext<'a> { params: &[], regions, meta, + r#type, } } @@ -408,6 +432,7 @@ impl<'a> ParseContext<'a> { let tag = inner.next().unwrap().as_str().parse::().unwrap(); let inputs = self.parse_port_list(&mut inner)?; let outputs = self.parse_port_list(&mut inner)?; + let r#type = self.parse_type_hint(&mut inner)?; let meta = self.parse_meta(&mut inner)?; Node { operation: Operation::Tag { tag }, @@ -416,6 +441,7 @@ impl<'a> ParseContext<'a> { params: &[], regions: &[], meta, + r#type, } } @@ -449,6 +475,7 @@ impl<'a> ParseContext<'a> { let sources = self.parse_port_list(&mut inner)?; let targets = self.parse_port_list(&mut inner)?; + let r#type = self.parse_type_hint(&mut inner)?; let meta = self.parse_meta(&mut inner)?; let children = self.parse_nodes(&mut inner)?; @@ -458,6 +485,7 @@ impl<'a> ParseContext<'a> { targets, children, meta, + r#type, })) // TODO: Attach region meta @@ -550,6 +578,25 @@ impl<'a> ParseContext<'a> { Ok(self.bump.alloc_slice_copy(¶ms)) } + fn parse_type_hint(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult { + let Some(Rule::type_hint) = pairs.peek().map(|p| p.as_rule()) else { + return Ok(self.module.insert_term(Term::Wildcard)); + }; + + let pair = pairs.next().unwrap(); + let mut inner = pair.into_inner(); + + let inputs = self.parse_term(inner.next().unwrap())?; + let outputs = self.parse_term(inner.next().unwrap())?; + let extensions = self.parse_term(inner.next().unwrap())?; + + Ok(self.module.insert_term(Term::FuncType { + inputs, + outputs, + extensions, + })) + } + fn parse_port_list(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<&'a [Port<'a>]> { let Some(Rule::port_list) = pairs.peek().map(|p| p.as_rule()) else { return Ok(&[]); @@ -575,14 +622,12 @@ impl<'a> ParseContext<'a> { let link = LinkRef::Named(inner.next().unwrap().as_str()); let mut r#type = None; - let mut meta = &[] as &[MetaItem<'a>]; if inner.peek().is_some() { r#type = Some(self.parse_term(inner.next().unwrap())?); - meta = self.parse_meta(&mut inner)?; } - Ok(Port { link, r#type, meta }) + Ok(Port { link, r#type }) } fn parse_meta(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<&'a [MetaItem<'a>]> { diff --git a/hugr-model/src/v0/text/print.rs b/hugr-model/src/v0/text/print.rs index 3f05386a3..bd48cc3f8 100644 --- a/hugr-model/src/v0/text/print.rs +++ b/hugr-model/src/v0/text/print.rs @@ -401,7 +401,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { fn print_port_list(&mut self, ports: &'a [Port<'a>]) -> PrintResult<()> { self.print_brackets(|this| { for port in ports { - if port.r#type.is_some() || !port.meta.is_empty() { + if port.r#type.is_some() { this.print_parens(|this| { this.print_link_ref(port.link); @@ -410,7 +410,6 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { None => this.print_text("_"), }; - this.print_meta(port.meta)?; Ok(()) })?; } else { From deaeea824b818c69cb353deaa34f92ee57cad0d7 Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Wed, 2 Oct 2024 14:13:15 +0100 Subject: [PATCH 08/26] Print type hint. --- hugr-core/Cargo.toml | 1 - hugr-model/src/v0/text/hugr.pest | 2 +- hugr-model/src/v0/text/parse.rs | 12 +----------- hugr-model/src/v0/text/print.rs | 22 ++++++++++++++++++++++ 4 files changed, 24 insertions(+), 13 deletions(-) diff --git a/hugr-core/Cargo.toml b/hugr-core/Cargo.toml index 791dd2592..beb1c3634 100644 --- a/hugr-core/Cargo.toml +++ b/hugr-core/Cargo.toml @@ -20,7 +20,6 @@ workspace = true extension_inference = [] declarative = ["serde_yaml"] model = ["hugr-model"] -default = ["model"] [dependencies] portgraph = { workspace = true, features = ["serde", "petgraph"] } diff --git a/hugr-model/src/v0/text/hugr.pest b/hugr-model/src/v0/text/hugr.pest index b5f8408f9..9cad5664c 100644 --- a/hugr-model/src/v0/text/hugr.pest +++ b/hugr-model/src/v0/text/hugr.pest @@ -44,7 +44,7 @@ node_cond = { "(" ~ "cond" ~ port_lists? ~ type_hint? ~ meta* ~ region* node_tag = { "(" ~ "tag" ~ tag ~ port_lists? ~ type_hint? ~ meta* ~ region* ~ ")" } node_custom = { "(" ~ (term_apply | term_apply_full) ~ port_lists? ~ type_hint? ~ meta* ~ region* ~ ")" } -type_hint = { "(" ~ "type" ~ term ~ term ~ term ~ ")" } +type_hint = { "(" ~ "type" ~ term ~ ")" } func_header = { symbol ~ param* ~ term ~ term ~ term } alias_header = { symbol ~ param* ~ term } diff --git a/hugr-model/src/v0/text/parse.rs b/hugr-model/src/v0/text/parse.rs index 7a98c3dbf..ecd0b3450 100644 --- a/hugr-model/src/v0/text/parse.rs +++ b/hugr-model/src/v0/text/parse.rs @@ -584,17 +584,7 @@ impl<'a> ParseContext<'a> { }; let pair = pairs.next().unwrap(); - let mut inner = pair.into_inner(); - - let inputs = self.parse_term(inner.next().unwrap())?; - let outputs = self.parse_term(inner.next().unwrap())?; - let extensions = self.parse_term(inner.next().unwrap())?; - - Ok(self.module.insert_term(Term::FuncType { - inputs, - outputs, - extensions, - })) + self.parse_term(pair.into_inner().next().unwrap()) } fn parse_port_list(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<&'a [Port<'a>]> { diff --git a/hugr-model/src/v0/text/print.rs b/hugr-model/src/v0/text/print.rs index bd48cc3f8..30d4604f4 100644 --- a/hugr-model/src/v0/text/print.rs +++ b/hugr-model/src/v0/text/print.rs @@ -149,6 +149,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_port_list(node_data.inputs)?; this.print_port_list(node_data.outputs) })?; + this.print_type_hint(node_data.r#type)?; this.print_meta(node_data.meta)?; this.print_regions(node_data.regions) } @@ -158,6 +159,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_port_list(node_data.inputs)?; this.print_port_list(node_data.outputs) })?; + this.print_type_hint(node_data.r#type)?; this.print_meta(node_data.meta)?; this.print_regions(node_data.regions) } @@ -167,6 +169,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_port_list(node_data.inputs)?; this.print_port_list(node_data.outputs) })?; + this.print_type_hint(node_data.r#type)?; this.print_meta(node_data.meta)?; this.print_regions(node_data.regions) } @@ -238,6 +241,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_port_list(node_data.inputs)?; this.print_port_list(node_data.outputs) })?; + this.print_type_hint(node_data.r#type)?; this.print_meta(node_data.meta)?; Ok(()) } @@ -249,6 +253,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_port_list(node_data.inputs)?; this.print_port_list(node_data.outputs) })?; + this.print_type_hint(node_data.r#type)?; this.print_meta(node_data.meta)?; Ok(()) } @@ -272,6 +277,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_port_list(node_data.inputs)?; this.print_port_list(node_data.outputs) })?; + this.print_type_hint(node_data.r#type)?; this.print_meta(node_data.meta)?; this.print_regions(node_data.regions) } @@ -292,6 +298,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_port_list(node_data.inputs)?; this.print_port_list(node_data.outputs) })?; + this.print_type_hint(node_data.r#type)?; this.print_meta(node_data.meta)?; this.print_regions(node_data.regions) } @@ -330,6 +337,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text("tail-loop"); this.print_port_list(node_data.inputs)?; this.print_port_list(node_data.outputs)?; + this.print_type_hint(node_data.r#type)?; this.print_meta(node_data.meta)?; this.print_regions(node_data.regions) } @@ -338,6 +346,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text("cond"); this.print_port_list(node_data.inputs)?; this.print_port_list(node_data.outputs)?; + this.print_type_hint(node_data.r#type)?; this.print_meta(node_data.meta)?; this.print_regions(node_data.regions) } @@ -347,6 +356,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(format!("{}", tag)); this.print_port_list(node_data.inputs)?; this.print_port_list(node_data.outputs)?; + this.print_type_hint(node_data.r#type)?; this.print_meta(node_data.meta) } }) @@ -380,6 +390,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_port_list(region_data.targets)?; } + this.print_type_hint(region_data.r#type)?; this.print_meta(region_data.meta)?; this.print_nodes(region) }) @@ -624,4 +635,15 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { Ok(()) } + + fn print_type_hint(&mut self, term: TermId) -> PrintResult<()> { + if let Some(Term::Wildcard) = self.module.get_term(term) { + return Ok(()); + } + + self.print_parens(|this| { + this.print_text("type-hint"); + this.print_term(term) + }) + } } From 1cf28df34c822fbb74af93fda0ffde7ae16b0d60 Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Wed, 2 Oct 2024 14:17:56 +0100 Subject: [PATCH 09/26] Add README to hugr-model --- hugr-core/README.md | 3 +++ hugr-model/CHANGELOG.md | 1 + hugr-model/README.md | 37 +++++++++++++++++++++++++++++++++++++ 3 files changed, 41 insertions(+) create mode 100644 hugr-model/CHANGELOG.md create mode 100644 hugr-model/README.md diff --git a/hugr-core/README.md b/hugr-core/README.md index d57d3cd70..6c09b05b1 100644 --- a/hugr-core/README.md +++ b/hugr-core/README.md @@ -21,6 +21,9 @@ Please read the [API documentation here][]. Not enabled by default. - `declarative`: Experimental support for declaring extensions in YAML files, support is limited. +- `model` + Import and export from the representation defined in the `hugr-model` crate. + Not enabled by default. ## Recent Changes diff --git a/hugr-model/CHANGELOG.md b/hugr-model/CHANGELOG.md new file mode 100644 index 000000000..825c32f0d --- /dev/null +++ b/hugr-model/CHANGELOG.md @@ -0,0 +1 @@ +# Changelog diff --git a/hugr-model/README.md b/hugr-model/README.md new file mode 100644 index 000000000..7e6035710 --- /dev/null +++ b/hugr-model/README.md @@ -0,0 +1,37 @@ +![](/hugr/assets/hugr_logo.svg) + +hugr-model +=============== + +[![build_status][]](https://github.com/CQCL/hugr/actions) +[![crates][]](https://crates.io/crates/hugr-model) +[![msrv][]](https://github.com/CQCL/hugr) +[![codecov][]](https://codecov.io/gh/CQCL/hugr) + +Experimental data model for `hugr`. +Refer to the [main crate](http://crates.io/crates/hugr) for more information. + +Please read the [API documentation here][]. + +## Experimental Features + +## Recent Changes + +See [CHANGELOG][] for a list of changes. The minimum supported rust +version will only change on major releases. + +## Development + +See [DEVELOPMENT.md](https://github.com/CQCL/hugr/blob/main/DEVELOPMENT.md) for instructions on setting up the development environment. + +## License + +This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http://www.apache.org/licenses/LICENSE-2.0). + + [API documentation here]: https://docs.rs/hugr-core/ + [build_status]: https://github.com/CQCL/hugr/actions/workflows/ci-rs.yml/badge.svg?branch=main + [msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg + [crates]: https://img.shields.io/crates/v/hugr-core + [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov + [LICENSE]: https://github.com/CQCL/hugr/blob/main/LICENCE + [CHANGELOG]: https://github.com/CQCL/hugr/blob/main/hugr-core/CHANGELOG.md From 26441df2f460b1a9e44d83f9552f6debe0e4de07 Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Wed, 2 Oct 2024 14:36:43 +0100 Subject: [PATCH 10/26] Lints. --- hugr-core/src/export.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 7a3019660..40d98286a 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -179,7 +179,6 @@ impl<'a> Context<'a> { let outputs = self.make_ports(node, Direction::Outgoing); let mut params: &[_] = &[]; let mut regions: &[_] = &[]; - let mut r#type = None; fn make_custom(name: &'static str) -> model::Operation { model::Operation::Custom { @@ -344,7 +343,7 @@ impl<'a> Context<'a> { } }; - let r#type = r#type.unwrap_or_else(|| self.module.insert_term(model::Term::Wildcard)); + let r#type = self.module.insert_term(model::Term::Wildcard); self.module.insert_node(model::Node { operation, From ecaf64034c6bb0ef476a47dd853e2fa6ddc0024a Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Wed, 2 Oct 2024 15:42:38 +0100 Subject: [PATCH 11/26] Added cfgs and tail-loops to tests and fixed bugs. --- hugr-core/Cargo.toml | 1 + hugr-core/src/export.rs | 20 +++---- hugr-core/src/import.rs | 57 ++++++++++++++----- hugr-core/tests/fixtures/model-1.edn | 18 ++++++ hugr-core/tests/model.rs | 2 + .../tests/snapshots/model__import_export.snap | 35 ++++++++++++ hugr-model/src/v0/text/hugr.pest | 2 +- 7 files changed, 108 insertions(+), 27 deletions(-) diff --git a/hugr-core/Cargo.toml b/hugr-core/Cargo.toml index beb1c3634..791dd2592 100644 --- a/hugr-core/Cargo.toml +++ b/hugr-core/Cargo.toml @@ -20,6 +20,7 @@ workspace = true extension_inference = [] declarative = ["serde_yaml"] model = ["hugr-model"] +default = ["model"] [dependencies] portgraph = { workspace = true, features = ["serde", "petgraph"] } diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 40d98286a..fa78e63b5 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -409,9 +409,9 @@ impl<'a> Context<'a> { /// Creates a control flow region from the given node's children. pub fn export_cfg(&mut self, node: Node) -> model::RegionId { let mut children = self.hugr.children(node); + let mut region_children = BumpVec::with_capacity_in(children.len() + 1, self.bump); // The first child is the entry block. - // The entry block does have a dataflow subgraph, so we must still export it later. // We create a source port on the control flow region and connect it to the // first input port of the exported entry block. let entry_block = children.next().unwrap(); @@ -422,8 +422,14 @@ impl<'a> Context<'a> { )); let source = self.make_port(entry_block, IncomingPort::from(0)).unwrap(); + region_children.push(self.export_node(entry_block)); + + // Export the remaining children of the node, except for the last one. + for _ in 0..children.len() - 1 { + region_children.push(self.export_node(children.next().unwrap())); + } - // The second child is the exit block. + // The last child is the exit block. // Contrary to the entry block, the exit block does not have a dataflow subgraph. // We therefore do not export the block itself, but simply use its output ports // as the target ports of the control flow region. @@ -436,19 +442,11 @@ impl<'a> Context<'a> { let targets = self.make_ports(exit_block, Direction::Incoming); - // Now we export the child nodes, including the entry block. - let mut region_children = BumpVec::with_capacity_in(children.len() + 1, self.bump); - - region_children.push(self.export_node(entry_block)); - for child in children { - region_children.push(self.export_node(child)); - } - // TODO: We can determine the type of the region let r#type = self.module.insert_term(model::Term::Wildcard); self.module.insert_region(model::Region { - kind: model::RegionKind::DataFlow, + kind: model::RegionKind::ControlFlow, sources: self.bump.alloc_slice_copy(&[source]), targets, children: region_children.into_bump_slice(), diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 8ebbfedc5..ec7a98864 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -711,9 +711,24 @@ impl<'a> Context<'a> { fn make_entry_node( &mut self, parent: Node, + parent_id: model::NodeId, ports: &'a [model::Port<'a>], ) -> Result { - let types = self.get_port_types(ports)?; + let types = { + let [port] = ports else { + return Err(model::ModelError::InvalidRegions(parent_id).into()); + }; + + let Some(port_type) = port.r#type else { + return Err(error_uninferred!("port type")); + }; + + let model::Term::Control { values: types } = self.get_term(port_type)? else { + return Err(model::ModelError::TypeError(port_type).into()); + }; + + self.import_type_row(*types)? + }; let node = self.hugr.add_node_with_parent( parent, @@ -725,6 +740,8 @@ impl<'a> Context<'a> { }), ); + self.record_links(node, Direction::Outgoing, ports); + let node_input = self.hugr.add_node_with_parent( node, OpType::Input(Input { @@ -769,13 +786,29 @@ impl<'a> Context<'a> { fn make_exit_node( &mut self, parent: Node, + parent_id: model::NodeId, ports: &'a [model::Port<'a>], ) -> Result { - let cfg_outputs = self.get_port_types(ports)?; + let cfg_outputs = { + let [port] = ports else { + return Err(model::ModelError::InvalidRegions(parent_id).into()); + }; + + let Some(port_type) = port.r#type else { + return Err(error_uninferred!("port type")); + }; + + let model::Term::Control { values: types } = self.get_term(port_type)? else { + return Err(model::ModelError::TypeError(port_type).into()); + }; + + self.import_type_row(*types)? + }; + let node = self .hugr .add_node_with_parent(parent, OpType::ExitBlock(ExitBlock { cfg_outputs })); - self.record_links(node, Direction::Outgoing, ports); + self.record_links(node, Direction::Incoming, ports); Ok(node) } @@ -787,25 +820,17 @@ impl<'a> Context<'a> { ) -> Result<(), ImportError> { let region_data = self.get_region(region)?; - if !matches!(region_data.kind, model::RegionKind::DataFlow) { + if !matches!(region_data.kind, model::RegionKind::ControlFlow) { return Err(model::ModelError::InvalidRegions(node_id).into()); } - let node_entry = self.make_entry_node(node, region_data.sources)?; + self.make_entry_node(node, node_id, region_data.sources)?; for child in region_data.children { self.import_node(*child, node)?; } - let node_exit = self.make_exit_node(node, region_data.targets)?; - - let entry_outputs = self.hugr.node_outputs(node_entry); - let first_block = self.hugr.children(node).nth(1).unwrap(); - let first_block_inputs = self.hugr.node_inputs(first_block); - - for (a, b) in entry_outputs.zip(first_block_inputs) { - self.hugr.connect(node_entry, a, node_exit, b); - } + self.make_exit_node(node, node_id, region_data.targets)?; Ok(()) } @@ -1105,7 +1130,9 @@ impl<'a> Context<'a> { None => break, } } - _ => return Err(model::ModelError::TypeError(term_id).into()), + _ => { + return Err(model::ModelError::TypeError(term_id).into()); + } } } diff --git a/hugr-core/tests/fixtures/model-1.edn b/hugr-core/tests/fixtures/model-1.edn index b0ca37d8b..9308003ec 100644 --- a/hugr-core/tests/fixtures/model-1.edn +++ b/hugr-core/tests/fixtures/model-1.edn @@ -36,3 +36,21 @@ (forall ?b type) [?a ?b] [?b ?a] (ext) (dfg [(%a ?a) (%b ?b)] [(%b ?b) (%a ?a)])) + +(define-func example.loop + (forall ?a type) + [?a] [?a] (ext) + (dfg [(%5 ?a)] [(%6 ?a)] + (tail-loop [(%5 ?a)] [(%6 ?a)] + (dfg [(%7 ?a)] [(%8 (adt [[?a] [?a]]))] + (tag 0 [(%7 ?a)] [(%8 (adt [[?a] [?a]]))]))))) + +(define-func example.cfg + (forall ?a type) + [?a] [?a] (ext) + (dfg [(%9 ?a)] [(%10 ?a)] + (cfg [(%9 ?a)] [(%10 ?a)] + (cfg [(%13 (ctrl [?a]))] [(%14 (ctrl [?a]))] + (block [(%13 (ctrl [?a]))] [(%14 (ctrl [?a]))] + (dfg [(%11 ?a)] [(%12 (adt [[?a]]))] + (tag 0 [(%11 ?a)] [(%12 (adt [[?a]]))]))))))) diff --git a/hugr-core/tests/model.rs b/hugr-core/tests/model.rs index 6180bd69e..b2c38b46d 100644 --- a/hugr-core/tests/model.rs +++ b/hugr-core/tests/model.rs @@ -11,7 +11,9 @@ pub fn test_import_export() { let bump = bumpalo::Bump::new(); let parsed_module = model::text::parse(include_str!("fixtures/model-1.edn"), &bump).unwrap(); let extensions = std_reg(); + let hugr = import_hugr(&parsed_module.module, &extensions).unwrap(); + let roundtrip = export_hugr(&hugr, &bump); let roundtrip_str = model::text::print_to_string(&roundtrip, 80).unwrap(); insta::assert_snapshot!(roundtrip_str); diff --git a/hugr-core/tests/snapshots/model__import_export.snap b/hugr-core/tests/snapshots/model__import_export.snap index 8bc728ecc..a03b06732 100644 --- a/hugr-core/tests/snapshots/model__import_export.snap +++ b/hugr-core/tests/snapshots/model__import_export.snap @@ -35,3 +35,38 @@ expression: roundtrip_str (forall ?1 type) [?0 ?1] [?1 ?0] (ext) (dfg [(%5 ?0) (%6 ?1)] [(%6 ?1) (%5 ?0)])) + +(define-func example.loop + (forall ?0 type) + [?0] [?0] (ext) + (dfg + [(%7 ?0)] + [(%8 ?0)] + (tail-loop + [(%7 ?0)] + [(%8 ?0)] + (dfg + [(%9 ?0)] + [(%10 (adt [[?0] [?0]]))] + (tag 0 [(%9 ?0)] [(%10 (adt [[?0] [?0]]))]))))) + +(define-func example.cfg + (forall ?0 type) + [?0] [?0] (ext) + (dfg + [(%11 ?0)] + [(%12 ?0)] + (cfg [(%11 ?0)] [(%12 ?0)] + (cfg + [(%13 (ctrl [?0]))] + [(%17 (ctrl [?0]))] + (block [(%13 (ctrl [?0]))] [(%14 (ctrl [?0]))] + (dfg + [(%15 ?0)] + [(%16 (adt [[?0]]))] + (tag 0 [(%15 ?0)] [(%16 (adt [[?0]]))]))) + (block [(%14 (ctrl [?0]))] [(%17 (ctrl [?0]))] + (dfg + [(%18 ?0)] + [(%19 (adt [[?0]]))] + (tag 0 [(%18 ?0)] [(%19 (adt [[?0]]))]))))))) diff --git a/hugr-model/src/v0/text/hugr.pest b/hugr-model/src/v0/text/hugr.pest index 9cad5664c..99e18684c 100644 --- a/hugr-model/src/v0/text/hugr.pest +++ b/hugr-model/src/v0/text/hugr.pest @@ -2,7 +2,7 @@ WHITESPACE = _{ " " | "\t" | "\r" | "\n" } COMMENT = _{ ";" ~ (!("\n") ~ ANY)* ~ "\n" } identifier = @{ (ASCII_ALPHA | "_" | "-") ~ (ASCII_ALPHANUMERIC | "_" | "-")* } symbol = @{ identifier ~ ("." ~ identifier)+ } -tag = @{ ASCII_NONZERO_DIGIT ~ ASCII_DIGIT* } +tag = @{ (ASCII_NONZERO_DIGIT ~ ASCII_DIGIT*) | "0" } string = @{ "\"" ~ (!("\"") ~ ANY)* ~ "\"" } list_tail = { "." } From 6c3703ac2d1ae6b113e7a4be61248c6a19467b11 Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Wed, 2 Oct 2024 15:55:15 +0100 Subject: [PATCH 12/26] Added `cond` node to model import/export test. --- hugr-core/Cargo.toml | 1 - hugr-core/tests/fixtures/model-1.edn | 20 +++++++++++++++++++ .../tests/snapshots/model__import_export.snap | 20 +++++++++++++++++++ 3 files changed, 40 insertions(+), 1 deletion(-) diff --git a/hugr-core/Cargo.toml b/hugr-core/Cargo.toml index 791dd2592..beb1c3634 100644 --- a/hugr-core/Cargo.toml +++ b/hugr-core/Cargo.toml @@ -20,7 +20,6 @@ workspace = true extension_inference = [] declarative = ["serde_yaml"] model = ["hugr-model"] -default = ["model"] [dependencies] portgraph = { workspace = true, features = ["serde", "petgraph"] } diff --git a/hugr-core/tests/fixtures/model-1.edn b/hugr-core/tests/fixtures/model-1.edn index 9308003ec..c43bb4c54 100644 --- a/hugr-core/tests/fixtures/model-1.edn +++ b/hugr-core/tests/fixtures/model-1.edn @@ -54,3 +54,23 @@ (block [(%13 (ctrl [?a]))] [(%14 (ctrl [?a]))] (dfg [(%11 ?a)] [(%12 (adt [[?a]]))] (tag 0 [(%11 ?a)] [(%12 (adt [[?a]]))]))))))) + +(define-func example.cond + [(adt [[] []]) (@ arithmetic.int.types.int)] + [(@ arithmetic.int.types.int)] + (ext) + (dfg + [(%20 (adt [[] []])) (%21 (@ arithmetic.int.types.int))] + [(%22 (@ arithmetic.int.types.int))] + (cond + [(%20 (adt [[] []])) (%21 (@ arithmetic.int.types.int))] + [(%22 (@ arithmetic.int.types.int))] + (dfg + [(%23 (@ arithmetic.int.types.int))] + [(%23 (@ arithmetic.int.types.int))]) + (dfg + [(%24 (@ arithmetic.int.types.int))] + [(%25 (@ arithmetic.int.types.int))] + ((@ arithmetic.int.ineg) + [(%24 (@ arithmetic.int.types.int))] + [(%25 (@ arithmetic.int.types.int))]))))) diff --git a/hugr-core/tests/snapshots/model__import_export.snap b/hugr-core/tests/snapshots/model__import_export.snap index a03b06732..beb7ea0bd 100644 --- a/hugr-core/tests/snapshots/model__import_export.snap +++ b/hugr-core/tests/snapshots/model__import_export.snap @@ -70,3 +70,23 @@ expression: roundtrip_str [(%18 ?0)] [(%19 (adt [[?0]]))] (tag 0 [(%18 ?0)] [(%19 (adt [[?0]]))]))))))) + +(define-func example.cond + [(adt [[] []]) (@ arithmetic.int.types.int)] + [(@ arithmetic.int.types.int)] + (ext) + (dfg + [(%20 (adt [[] []])) (%21 (@ arithmetic.int.types.int))] + [(%22 (@ arithmetic.int.types.int))] + (cond + [(%20 (adt [[] []])) (%21 (@ arithmetic.int.types.int))] + [(%22 (@ arithmetic.int.types.int))] + (dfg + [(%23 (@ arithmetic.int.types.int))] + [(%23 (@ arithmetic.int.types.int))]) + (dfg + [(%24 (@ arithmetic.int.types.int))] + [(%25 (@ arithmetic.int.types.int))] + (arithmetic.int.ineg + [(%24 (@ arithmetic.int.types.int))] + [(%25 (@ arithmetic.int.types.int))]))))) From cff376aab7e51f1a0d84da3f29a6462b4537b62c Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Wed, 2 Oct 2024 16:18:34 +0100 Subject: [PATCH 13/26] Split up the test model files into smaller files. --- hugr-core/tests/fixtures/model-1.edn | 76 --------------- hugr-core/tests/fixtures/model-add.edn | 10 ++ hugr-core/tests/fixtures/model-alias.edn | 5 + hugr-core/tests/fixtures/model-call.edn | 15 +++ hugr-core/tests/fixtures/model-cfg.edn | 11 +++ hugr-core/tests/fixtures/model-cond.edn | 21 +++++ hugr-core/tests/fixtures/model-loop.edn | 9 ++ hugr-core/tests/fixtures/model-params.edn | 8 ++ hugr-core/tests/model.rs | 53 +++++++++-- .../tests/snapshots/model__import_export.snap | 92 ------------------- .../tests/snapshots/model__roundtrip_add.snap | 16 ++++ .../snapshots/model__roundtrip_alias.snap | 9 ++ .../snapshots/model__roundtrip_call.snap | 18 ++++ .../tests/snapshots/model__roundtrip_cfg.snap | 26 ++++++ .../snapshots/model__roundtrip_cond.snap | 25 +++++ .../snapshots/model__roundtrip_loop.snap | 19 ++++ .../snapshots/model__roundtrip_params.snap | 11 +++ 17 files changed, 248 insertions(+), 176 deletions(-) delete mode 100644 hugr-core/tests/fixtures/model-1.edn create mode 100644 hugr-core/tests/fixtures/model-add.edn create mode 100644 hugr-core/tests/fixtures/model-alias.edn create mode 100644 hugr-core/tests/fixtures/model-call.edn create mode 100644 hugr-core/tests/fixtures/model-cfg.edn create mode 100644 hugr-core/tests/fixtures/model-cond.edn create mode 100644 hugr-core/tests/fixtures/model-loop.edn create mode 100644 hugr-core/tests/fixtures/model-params.edn delete mode 100644 hugr-core/tests/snapshots/model__import_export.snap create mode 100644 hugr-core/tests/snapshots/model__roundtrip_add.snap create mode 100644 hugr-core/tests/snapshots/model__roundtrip_alias.snap create mode 100644 hugr-core/tests/snapshots/model__roundtrip_call.snap create mode 100644 hugr-core/tests/snapshots/model__roundtrip_cfg.snap create mode 100644 hugr-core/tests/snapshots/model__roundtrip_cond.snap create mode 100644 hugr-core/tests/snapshots/model__roundtrip_loop.snap create mode 100644 hugr-core/tests/snapshots/model__roundtrip_params.snap diff --git a/hugr-core/tests/fixtures/model-1.edn b/hugr-core/tests/fixtures/model-1.edn deleted file mode 100644 index c43bb4c54..000000000 --- a/hugr-core/tests/fixtures/model-1.edn +++ /dev/null @@ -1,76 +0,0 @@ -(hugr 0) - -; NOTE: The @ in front of the names indicates that their implicit arguments are -; explicitly given as well. This is necessary everywhere at the moment -; since we do not have inference for implicit arguments yet. - -; NOTE: Every port in this file has been annotated with its type. This is quite -; verbose, but it is necessary currently until we have inference. - -(define-alias local.int type (@ arithmetic.int.types.int)) - -(define-func example.add - [(@ arithmetic.int.types.int) (@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext) - (dfg - [(%0 (@ arithmetic.int.types.int)) (%1 (@ arithmetic.int.types.int))] - [(%2 (@ arithmetic.int.types.int))] - ((@ arithmetic.int.iadd) [(%0 (@ arithmetic.int.types.int)) (%1 (@ arithmetic.int.types.int))] [(%2 (@ arithmetic.int.types.int))]))) - -(declare-func example.callee - [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext) - (meta doc.title "Callee") - (meta doc.description "This is a function declaration.")) - -(define-func example.caller - [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext) - (meta doc.title "Caller") - (meta doc.description "This defines a function that calls the function which we declared earlier.") - (dfg - [(%3 (@ arithmetic.int.types.int))] - [(%4 (@ arithmetic.int.types.int))] - (call (@ example.callee) [(%3 (@ arithmetic.int.types.int))] [(%4 (@ arithmetic.int.types.int))]))) - -(define-func example.swap - ; The types of the values to be swapped are passed as implicit parameters. - (forall ?a type) - (forall ?b type) - [?a ?b] [?b ?a] (ext) - (dfg [(%a ?a) (%b ?b)] [(%b ?b) (%a ?a)])) - -(define-func example.loop - (forall ?a type) - [?a] [?a] (ext) - (dfg [(%5 ?a)] [(%6 ?a)] - (tail-loop [(%5 ?a)] [(%6 ?a)] - (dfg [(%7 ?a)] [(%8 (adt [[?a] [?a]]))] - (tag 0 [(%7 ?a)] [(%8 (adt [[?a] [?a]]))]))))) - -(define-func example.cfg - (forall ?a type) - [?a] [?a] (ext) - (dfg [(%9 ?a)] [(%10 ?a)] - (cfg [(%9 ?a)] [(%10 ?a)] - (cfg [(%13 (ctrl [?a]))] [(%14 (ctrl [?a]))] - (block [(%13 (ctrl [?a]))] [(%14 (ctrl [?a]))] - (dfg [(%11 ?a)] [(%12 (adt [[?a]]))] - (tag 0 [(%11 ?a)] [(%12 (adt [[?a]]))]))))))) - -(define-func example.cond - [(adt [[] []]) (@ arithmetic.int.types.int)] - [(@ arithmetic.int.types.int)] - (ext) - (dfg - [(%20 (adt [[] []])) (%21 (@ arithmetic.int.types.int))] - [(%22 (@ arithmetic.int.types.int))] - (cond - [(%20 (adt [[] []])) (%21 (@ arithmetic.int.types.int))] - [(%22 (@ arithmetic.int.types.int))] - (dfg - [(%23 (@ arithmetic.int.types.int))] - [(%23 (@ arithmetic.int.types.int))]) - (dfg - [(%24 (@ arithmetic.int.types.int))] - [(%25 (@ arithmetic.int.types.int))] - ((@ arithmetic.int.ineg) - [(%24 (@ arithmetic.int.types.int))] - [(%25 (@ arithmetic.int.types.int))]))))) diff --git a/hugr-core/tests/fixtures/model-add.edn b/hugr-core/tests/fixtures/model-add.edn new file mode 100644 index 000000000..4749dc0c7 --- /dev/null +++ b/hugr-core/tests/fixtures/model-add.edn @@ -0,0 +1,10 @@ +(hugr 0) + +(define-func example.add + [(@ arithmetic.int.types.int) (@ arithmetic.int.types.int)] + [(@ arithmetic.int.types.int)] + (ext) + (dfg + [(%0 (@ arithmetic.int.types.int)) (%1 (@ arithmetic.int.types.int))] + [(%2 (@ arithmetic.int.types.int))] + ((@ arithmetic.int.iadd) [(%0 (@ arithmetic.int.types.int)) (%1 (@ arithmetic.int.types.int))] [(%2 (@ arithmetic.int.types.int))]))) diff --git a/hugr-core/tests/fixtures/model-alias.edn b/hugr-core/tests/fixtures/model-alias.edn new file mode 100644 index 000000000..2a148c25d --- /dev/null +++ b/hugr-core/tests/fixtures/model-alias.edn @@ -0,0 +1,5 @@ +(hugr 0) + +(declare-alias local.float type) + +(define-alias local.int type (@ arithmetic.int.types.int)) diff --git a/hugr-core/tests/fixtures/model-call.edn b/hugr-core/tests/fixtures/model-call.edn new file mode 100644 index 000000000..ef55548b6 --- /dev/null +++ b/hugr-core/tests/fixtures/model-call.edn @@ -0,0 +1,15 @@ +(hugr 0) + +(declare-func example.callee + [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext) + (meta doc.title "Callee") + (meta doc.description "This is a function declaration.")) + +(define-func example.caller + [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext) + (meta doc.title "Caller") + (meta doc.description "This defines a function that calls the function which we declared earlier.") + (dfg + [(%3 (@ arithmetic.int.types.int))] + [(%4 (@ arithmetic.int.types.int))] + (call (@ example.callee) [(%3 (@ arithmetic.int.types.int))] [(%4 (@ arithmetic.int.types.int))]))) diff --git a/hugr-core/tests/fixtures/model-cfg.edn b/hugr-core/tests/fixtures/model-cfg.edn new file mode 100644 index 000000000..54e895a6b --- /dev/null +++ b/hugr-core/tests/fixtures/model-cfg.edn @@ -0,0 +1,11 @@ +(hugr 0) + +(define-func example.cfg + (forall ?a type) + [?a] [?a] (ext) + (dfg [(%0 ?a)] [(%1 ?a)] + (cfg [(%0 ?a)] [(%1 ?a)] + (cfg [(%2 (ctrl [?a]))] [(%4 (ctrl [?a]))] + (block [(%2 (ctrl [?a]))] [(%4 (ctrl [?a]))] + (dfg [(%5 ?a)] [(%6 (adt [[?a]]))] + (tag 0 [(%5 ?a)] [(%6 (adt [[?a]]))]))))))) diff --git a/hugr-core/tests/fixtures/model-cond.edn b/hugr-core/tests/fixtures/model-cond.edn new file mode 100644 index 000000000..04304d108 --- /dev/null +++ b/hugr-core/tests/fixtures/model-cond.edn @@ -0,0 +1,21 @@ +(hugr 0) + +(define-func example.cond + [(adt [[] []]) (@ arithmetic.int.types.int)] + [(@ arithmetic.int.types.int)] + (ext) + (dfg + [(%0 (adt [[] []])) (%1 (@ arithmetic.int.types.int))] + [(%2 (@ arithmetic.int.types.int))] + (cond + [(%0 (adt [[] []])) (%1 (@ arithmetic.int.types.int))] + [(%2 (@ arithmetic.int.types.int))] + (dfg + [(%3 (@ arithmetic.int.types.int))] + [(%3 (@ arithmetic.int.types.int))]) + (dfg + [(%4 (@ arithmetic.int.types.int))] + [(%5 (@ arithmetic.int.types.int))] + ((@ arithmetic.int.ineg) + [(%4 (@ arithmetic.int.types.int))] + [(%5 (@ arithmetic.int.types.int))]))))) diff --git a/hugr-core/tests/fixtures/model-loop.edn b/hugr-core/tests/fixtures/model-loop.edn new file mode 100644 index 000000000..d35e60578 --- /dev/null +++ b/hugr-core/tests/fixtures/model-loop.edn @@ -0,0 +1,9 @@ +(hugr 0) + +(define-func example.loop + (forall ?a type) + [?a] [?a] (ext) + (dfg [(%0 ?a)] [(%1 ?a)] + (tail-loop [(%0 ?a)] [(%1 ?a)] + (dfg [(%2 ?a)] [(%3 (adt [[?a] [?a]]))] + (tag 0 [(%2 ?a)] [(%3 (adt [[?a] [?a]]))]))))) diff --git a/hugr-core/tests/fixtures/model-params.edn b/hugr-core/tests/fixtures/model-params.edn new file mode 100644 index 000000000..c89dd158f --- /dev/null +++ b/hugr-core/tests/fixtures/model-params.edn @@ -0,0 +1,8 @@ +(hugr 0) + +(define-func example.swap + ; The types of the values to be swapped are passed as implicit parameters. + (forall ?a type) + (forall ?b type) + [?a ?b] [?b ?a] (ext) + (dfg [(%a ?a) (%b ?b)] [(%b ?b) (%a ?a)])) diff --git a/hugr-core/tests/model.rs b/hugr-core/tests/model.rs index b2c38b46d..6a046b0df 100644 --- a/hugr-core/tests/model.rs +++ b/hugr-core/tests/model.rs @@ -6,15 +6,52 @@ use hugr_core::{export::export_hugr, import::import_hugr}; use hugr_model::v0 as model; #[cfg(feature = "model")] -#[test] -pub fn test_import_export() { +fn roundtrip(source: &str) -> String { let bump = bumpalo::Bump::new(); - let parsed_module = model::text::parse(include_str!("fixtures/model-1.edn"), &bump).unwrap(); - let extensions = std_reg(); + let parsed_model = model::text::parse(source, &bump).unwrap(); + let imported_hugr = import_hugr(&parsed_model.module, &std_reg()).unwrap(); + let exported_model = export_hugr(&imported_hugr, &bump); + model::text::print_to_string(&exported_model, 80).unwrap() +} + +#[cfg(feature = "model")] +#[test] +pub fn test_roundtrip_add() { + insta::assert_snapshot!(roundtrip(include_str!("fixtures/model-add.edn"))); +} + +#[cfg(feature = "model")] +#[test] +pub fn test_roundtrip_call() { + insta::assert_snapshot!(roundtrip(include_str!("fixtures/model-call.edn"))); +} - let hugr = import_hugr(&parsed_module.module, &extensions).unwrap(); +#[cfg(feature = "model")] +#[test] +pub fn test_roundtrip_alias() { + insta::assert_snapshot!(roundtrip(include_str!("fixtures/model-alias.edn"))); +} - let roundtrip = export_hugr(&hugr, &bump); - let roundtrip_str = model::text::print_to_string(&roundtrip, 80).unwrap(); - insta::assert_snapshot!(roundtrip_str); +#[cfg(feature = "model")] +#[test] +pub fn test_roundtrip_cfg() { + insta::assert_snapshot!(roundtrip(include_str!("fixtures/model-cfg.edn"))); +} + +#[cfg(feature = "model")] +#[test] +pub fn test_roundtrip_cond() { + insta::assert_snapshot!(roundtrip(include_str!("fixtures/model-cond.edn"))); +} + +#[cfg(feature = "model")] +#[test] +pub fn test_roundtrip_loop() { + insta::assert_snapshot!(roundtrip(include_str!("fixtures/model-loop.edn"))); +} + +#[cfg(feature = "model")] +#[test] +pub fn test_roundtrip_params() { + insta::assert_snapshot!(roundtrip(include_str!("fixtures/model-params.edn"))); } diff --git a/hugr-core/tests/snapshots/model__import_export.snap b/hugr-core/tests/snapshots/model__import_export.snap deleted file mode 100644 index beb7ea0bd..000000000 --- a/hugr-core/tests/snapshots/model__import_export.snap +++ /dev/null @@ -1,92 +0,0 @@ ---- -source: hugr-core/tests/model.rs -expression: roundtrip_str ---- -(hugr 0) - -(define-alias local.int type (@ arithmetic.int.types.int)) - -(define-func example.add - [(@ arithmetic.int.types.int) (@ arithmetic.int.types.int)] - [(@ arithmetic.int.types.int)] - (ext) - (dfg - [(%0 (@ arithmetic.int.types.int)) (%1 (@ arithmetic.int.types.int))] - [(%2 (@ arithmetic.int.types.int))] - (arithmetic.int.iadd - [(%0 (@ arithmetic.int.types.int)) (%1 (@ arithmetic.int.types.int))] - [(%2 (@ arithmetic.int.types.int))]))) - -(declare-func example.callee - [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext)) - -(define-func example.caller - [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext) - (dfg - [(%3 (@ arithmetic.int.types.int))] - [(%4 (@ arithmetic.int.types.int))] - (call - (@ example.callee) - [(%3 (@ arithmetic.int.types.int))] - [(%4 (@ arithmetic.int.types.int))]))) - -(define-func example.swap - (forall ?0 type) - (forall ?1 type) - [?0 ?1] [?1 ?0] (ext) - (dfg [(%5 ?0) (%6 ?1)] [(%6 ?1) (%5 ?0)])) - -(define-func example.loop - (forall ?0 type) - [?0] [?0] (ext) - (dfg - [(%7 ?0)] - [(%8 ?0)] - (tail-loop - [(%7 ?0)] - [(%8 ?0)] - (dfg - [(%9 ?0)] - [(%10 (adt [[?0] [?0]]))] - (tag 0 [(%9 ?0)] [(%10 (adt [[?0] [?0]]))]))))) - -(define-func example.cfg - (forall ?0 type) - [?0] [?0] (ext) - (dfg - [(%11 ?0)] - [(%12 ?0)] - (cfg [(%11 ?0)] [(%12 ?0)] - (cfg - [(%13 (ctrl [?0]))] - [(%17 (ctrl [?0]))] - (block [(%13 (ctrl [?0]))] [(%14 (ctrl [?0]))] - (dfg - [(%15 ?0)] - [(%16 (adt [[?0]]))] - (tag 0 [(%15 ?0)] [(%16 (adt [[?0]]))]))) - (block [(%14 (ctrl [?0]))] [(%17 (ctrl [?0]))] - (dfg - [(%18 ?0)] - [(%19 (adt [[?0]]))] - (tag 0 [(%18 ?0)] [(%19 (adt [[?0]]))]))))))) - -(define-func example.cond - [(adt [[] []]) (@ arithmetic.int.types.int)] - [(@ arithmetic.int.types.int)] - (ext) - (dfg - [(%20 (adt [[] []])) (%21 (@ arithmetic.int.types.int))] - [(%22 (@ arithmetic.int.types.int))] - (cond - [(%20 (adt [[] []])) (%21 (@ arithmetic.int.types.int))] - [(%22 (@ arithmetic.int.types.int))] - (dfg - [(%23 (@ arithmetic.int.types.int))] - [(%23 (@ arithmetic.int.types.int))]) - (dfg - [(%24 (@ arithmetic.int.types.int))] - [(%25 (@ arithmetic.int.types.int))] - (arithmetic.int.ineg - [(%24 (@ arithmetic.int.types.int))] - [(%25 (@ arithmetic.int.types.int))]))))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_add.snap b/hugr-core/tests/snapshots/model__roundtrip_add.snap new file mode 100644 index 000000000..262891119 --- /dev/null +++ b/hugr-core/tests/snapshots/model__roundtrip_add.snap @@ -0,0 +1,16 @@ +--- +source: hugr-core/tests/model.rs +expression: "roundtrip(include_str!(\"fixtures/model-add.edn\"))" +--- +(hugr 0) + +(define-func example.add + [(@ arithmetic.int.types.int) (@ arithmetic.int.types.int)] + [(@ arithmetic.int.types.int)] + (ext) + (dfg + [(%0 (@ arithmetic.int.types.int)) (%1 (@ arithmetic.int.types.int))] + [(%2 (@ arithmetic.int.types.int))] + (arithmetic.int.iadd + [(%0 (@ arithmetic.int.types.int)) (%1 (@ arithmetic.int.types.int))] + [(%2 (@ arithmetic.int.types.int))]))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_alias.snap b/hugr-core/tests/snapshots/model__roundtrip_alias.snap new file mode 100644 index 000000000..467955539 --- /dev/null +++ b/hugr-core/tests/snapshots/model__roundtrip_alias.snap @@ -0,0 +1,9 @@ +--- +source: hugr-core/tests/model.rs +expression: "roundtrip(include_str!(\"fixtures/model-alias.edn\"))" +--- +(hugr 0) + +(declare-alias local.float type) + +(define-alias local.int type (@ arithmetic.int.types.int)) diff --git a/hugr-core/tests/snapshots/model__roundtrip_call.snap b/hugr-core/tests/snapshots/model__roundtrip_call.snap new file mode 100644 index 000000000..da88711c6 --- /dev/null +++ b/hugr-core/tests/snapshots/model__roundtrip_call.snap @@ -0,0 +1,18 @@ +--- +source: hugr-core/tests/model.rs +expression: "roundtrip(include_str!(\"fixtures/model-call.edn\"))" +--- +(hugr 0) + +(declare-func example.callee + [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext)) + +(define-func example.caller + [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext) + (dfg + [(%0 (@ arithmetic.int.types.int))] + [(%1 (@ arithmetic.int.types.int))] + (call + (@ example.callee) + [(%0 (@ arithmetic.int.types.int))] + [(%1 (@ arithmetic.int.types.int))]))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_cfg.snap b/hugr-core/tests/snapshots/model__roundtrip_cfg.snap new file mode 100644 index 000000000..cbcbe3643 --- /dev/null +++ b/hugr-core/tests/snapshots/model__roundtrip_cfg.snap @@ -0,0 +1,26 @@ +--- +source: hugr-core/tests/model.rs +expression: "roundtrip(include_str!(\"fixtures/model-cfg.edn\"))" +--- +(hugr 0) + +(define-func example.cfg + (forall ?0 type) + [?0] [?0] (ext) + (dfg + [(%0 ?0)] + [(%1 ?0)] + (cfg [(%0 ?0)] [(%1 ?0)] + (cfg + [(%2 (ctrl [?0]))] + [(%6 (ctrl [?0]))] + (block [(%2 (ctrl [?0]))] [(%3 (ctrl [?0]))] + (dfg + [(%4 ?0)] + [(%5 (adt [[?0]]))] + (tag 0 [(%4 ?0)] [(%5 (adt [[?0]]))]))) + (block [(%3 (ctrl [?0]))] [(%6 (ctrl [?0]))] + (dfg + [(%7 ?0)] + [(%8 (adt [[?0]]))] + (tag 0 [(%7 ?0)] [(%8 (adt [[?0]]))]))))))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_cond.snap b/hugr-core/tests/snapshots/model__roundtrip_cond.snap new file mode 100644 index 000000000..0fdbc2f91 --- /dev/null +++ b/hugr-core/tests/snapshots/model__roundtrip_cond.snap @@ -0,0 +1,25 @@ +--- +source: hugr-core/tests/model.rs +expression: "roundtrip(include_str!(\"fixtures/model-cond.edn\"))" +--- +(hugr 0) + +(define-func example.cond + [(adt [[] []]) (@ arithmetic.int.types.int)] + [(@ arithmetic.int.types.int)] + (ext) + (dfg + [(%0 (adt [[] []])) (%1 (@ arithmetic.int.types.int))] + [(%2 (@ arithmetic.int.types.int))] + (cond + [(%0 (adt [[] []])) (%1 (@ arithmetic.int.types.int))] + [(%2 (@ arithmetic.int.types.int))] + (dfg + [(%3 (@ arithmetic.int.types.int))] + [(%3 (@ arithmetic.int.types.int))]) + (dfg + [(%4 (@ arithmetic.int.types.int))] + [(%5 (@ arithmetic.int.types.int))] + (arithmetic.int.ineg + [(%4 (@ arithmetic.int.types.int))] + [(%5 (@ arithmetic.int.types.int))]))))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_loop.snap b/hugr-core/tests/snapshots/model__roundtrip_loop.snap new file mode 100644 index 000000000..eb1debbfd --- /dev/null +++ b/hugr-core/tests/snapshots/model__roundtrip_loop.snap @@ -0,0 +1,19 @@ +--- +source: hugr-core/tests/model.rs +expression: "roundtrip(include_str!(\"fixtures/model-loop.edn\"))" +--- +(hugr 0) + +(define-func example.loop + (forall ?0 type) + [?0] [?0] (ext) + (dfg + [(%0 ?0)] + [(%1 ?0)] + (tail-loop + [(%0 ?0)] + [(%1 ?0)] + (dfg + [(%2 ?0)] + [(%3 (adt [[?0] [?0]]))] + (tag 0 [(%2 ?0)] [(%3 (adt [[?0] [?0]]))]))))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_params.snap b/hugr-core/tests/snapshots/model__roundtrip_params.snap new file mode 100644 index 000000000..b146bc706 --- /dev/null +++ b/hugr-core/tests/snapshots/model__roundtrip_params.snap @@ -0,0 +1,11 @@ +--- +source: hugr-core/tests/model.rs +expression: "roundtrip(include_str!(\"fixtures/model-params.edn\"))" +--- +(hugr 0) + +(define-func example.swap + (forall ?0 type) + (forall ?1 type) + [?0 ?1] [?1 ?0] (ext) + (dfg [(%0 ?0) (%1 ?1)] [(%1 ?1) (%0 ?0)])) From df958c9e89e8a3556cb0ea574b680651f48018e2 Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Wed, 2 Oct 2024 19:09:00 +0100 Subject: [PATCH 14/26] `load-func` in model import/export test. --- hugr-core/tests/fixtures/model-alias.edn | 2 ++ hugr-core/tests/fixtures/model-call.edn | 7 +++++++ .../snapshots/model__roundtrip_alias.snap | 2 ++ .../snapshots/model__roundtrip_call.snap | 20 +++++++++++++++++++ hugr-model/src/v0/text/hugr.pest | 2 ++ hugr-model/src/v0/text/parse.rs | 17 ++++++++++++++++ 6 files changed, 50 insertions(+) diff --git a/hugr-core/tests/fixtures/model-alias.edn b/hugr-core/tests/fixtures/model-alias.edn index 2a148c25d..9783b3dbd 100644 --- a/hugr-core/tests/fixtures/model-alias.edn +++ b/hugr-core/tests/fixtures/model-alias.edn @@ -3,3 +3,5 @@ (declare-alias local.float type) (define-alias local.int type (@ arithmetic.int.types.int)) + +(define-alias local.endo type (fn [] [] (ext))) diff --git a/hugr-core/tests/fixtures/model-call.edn b/hugr-core/tests/fixtures/model-call.edn index ef55548b6..a65def91f 100644 --- a/hugr-core/tests/fixtures/model-call.edn +++ b/hugr-core/tests/fixtures/model-call.edn @@ -13,3 +13,10 @@ [(%3 (@ arithmetic.int.types.int))] [(%4 (@ arithmetic.int.types.int))] (call (@ example.callee) [(%3 (@ arithmetic.int.types.int))] [(%4 (@ arithmetic.int.types.int))]))) + +(define-func example.load + [] [(fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))] (ext) + (dfg + [] + [(%5 (fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext)))] + (load-func (@ example.caller) [] [(%5 (fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext)))]))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_alias.snap b/hugr-core/tests/snapshots/model__roundtrip_alias.snap index 467955539..c279c5d6a 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_alias.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_alias.snap @@ -7,3 +7,5 @@ expression: "roundtrip(include_str!(\"fixtures/model-alias.edn\"))" (declare-alias local.float type) (define-alias local.int type (@ arithmetic.int.types.int)) + +(define-alias local.endo type (fn [] [] (ext))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_call.snap b/hugr-core/tests/snapshots/model__roundtrip_call.snap index da88711c6..7e98e6e68 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_call.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_call.snap @@ -16,3 +16,23 @@ expression: "roundtrip(include_str!(\"fixtures/model-call.edn\"))" (@ example.callee) [(%0 (@ arithmetic.int.types.int))] [(%1 (@ arithmetic.int.types.int))]))) + +(define-func example.load + [] + [(fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))] + (ext) + (dfg + [] + [(%2 + (fn + [(@ arithmetic.int.types.int)] + [(@ arithmetic.int.types.int)] + (ext)))] + (load-func + (@ example.caller) + [] + [(%2 + (fn + [(@ arithmetic.int.types.int)] + [(@ arithmetic.int.types.int)] + (ext)))]))) diff --git a/hugr-model/src/v0/text/hugr.pest b/hugr-model/src/v0/text/hugr.pest index 99e18684c..f2b5ad38f 100644 --- a/hugr-model/src/v0/text/hugr.pest +++ b/hugr-model/src/v0/text/hugr.pest @@ -23,6 +23,7 @@ node = { | node_define_func | node_declare_func | node_call_func + | node_load_func | node_define_alias | node_declare_alias | node_tail_loop @@ -37,6 +38,7 @@ node_block = { "(" ~ "block" ~ port_lists? ~ type_hint? ~ meta* ~ region node_define_func = { "(" ~ "define-func" ~ func_header ~ meta* ~ region* ~ ")" } node_declare_func = { "(" ~ "declare-func" ~ func_header ~ meta* ~ ")" } node_call_func = { "(" ~ "call" ~ term ~ port_lists? ~ type_hint? ~ meta* ~ ")" } +node_load_func = { "(" ~ "load-func" ~ term ~ port_lists? ~ type_hint? ~ meta* ~ ")" } node_define_alias = { "(" ~ "define-alias" ~ alias_header ~ term ~ meta* ~ ")" } node_declare_alias = { "(" ~ "declare-alias" ~ alias_header ~ meta* ~ ")" } node_tail_loop = { "(" ~ "tail-loop" ~ port_lists? ~ type_hint? ~ meta* ~ region* ~ ")" } diff --git a/hugr-model/src/v0/text/parse.rs b/hugr-model/src/v0/text/parse.rs index ecd0b3450..5c21ecc1f 100644 --- a/hugr-model/src/v0/text/parse.rs +++ b/hugr-model/src/v0/text/parse.rs @@ -324,6 +324,23 @@ impl<'a> ParseContext<'a> { } } + Rule::node_load_func => { + let func = self.parse_term(inner.next().unwrap())?; + let inputs = self.parse_port_list(&mut inner)?; + let outputs = self.parse_port_list(&mut inner)?; + let r#type = self.parse_type_hint(&mut inner)?; + let meta = self.parse_meta(&mut inner)?; + Node { + operation: Operation::LoadFunc { func }, + inputs, + outputs, + params: &[], + regions: &[], + meta, + r#type, + } + } + Rule::node_define_alias => { let decl = self.parse_alias_header(inner.next().unwrap())?; let value = self.parse_term(inner.next().unwrap())?; From ec1736fab533035082de1196f0bfac3d1a6160c6 Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Thu, 3 Oct 2024 09:57:49 +0100 Subject: [PATCH 15/26] Removed resolved TODOs. --- hugr-core/src/export.rs | 5 ----- hugr-model/src/v0/text/parse.rs | 3 --- 2 files changed, 8 deletions(-) diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index fa78e63b5..b647f7e2c 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -39,9 +39,6 @@ struct Context<'a> { impl<'a> Context<'a> { pub fn new(hugr: &'a Hugr, bump: &'a Bump) -> Self { - // let mut node_to_id = FxHashMap::default(); - // node_to_id.reserve(hugr.node_count()); - let mut module = model::Module::default(); module.nodes.reserve(hugr.node_count()); @@ -658,7 +655,6 @@ impl<'a> Context<'a> { [var] => Some(*var), _ => { // TODO: We won't need this anymore once we have a core representation - // that ensures that extension sets have at most one variable. panic!("Extension set with multiple variables") } @@ -718,6 +714,5 @@ mod test { use bumpalo::Bump; let bump = Bump::new(); let _model = super::export_hugr(&hugr, &bump); - // TODO check the model } } diff --git a/hugr-model/src/v0/text/parse.rs b/hugr-model/src/v0/text/parse.rs index 5c21ecc1f..b3c24a880 100644 --- a/hugr-model/src/v0/text/parse.rs +++ b/hugr-model/src/v0/text/parse.rs @@ -504,8 +504,6 @@ impl<'a> ParseContext<'a> { meta, r#type, })) - - // TODO: Attach region meta } fn parse_nodes(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<&'a [NodeId]> { @@ -528,7 +526,6 @@ impl<'a> ParseContext<'a> { let inputs = self.parse_term(inner.next().unwrap())?; let outputs = self.parse_term(inner.next().unwrap())?; - // TODO: This is subtly broken: let extensions = match inner.peek().map(|p| p.as_rule()) { Some(Rule::term_ext_set) => self.parse_term(inner.next().unwrap())?, _ => self.module.insert_term(Term::ExtSet { From e82b99c3cdf11b3afefdaa6dfeebecc715b262c9 Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Thu, 3 Oct 2024 15:32:41 +0100 Subject: [PATCH 16/26] Feature gate on tests via Cargo.toml --- hugr-core/Cargo.toml | 4 ++++ hugr-core/tests/model.rs | 11 ----------- 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/hugr-core/Cargo.toml b/hugr-core/Cargo.toml index beb1c3634..9d3214d86 100644 --- a/hugr-core/Cargo.toml +++ b/hugr-core/Cargo.toml @@ -21,6 +21,10 @@ extension_inference = [] declarative = ["serde_yaml"] model = ["hugr-model"] +[[test]] +name = "model" +required-features = ["model"] + [dependencies] portgraph = { workspace = true, features = ["serde", "petgraph"] } thiserror = { workspace = true } diff --git a/hugr-core/tests/model.rs b/hugr-core/tests/model.rs index 6a046b0df..f59f463a8 100644 --- a/hugr-core/tests/model.rs +++ b/hugr-core/tests/model.rs @@ -1,11 +1,7 @@ -#[cfg(feature = "model")] use hugr::std_extensions::std_reg; -#[cfg(feature = "model")] use hugr_core::{export::export_hugr, import::import_hugr}; -#[cfg(feature = "model")] use hugr_model::v0 as model; -#[cfg(feature = "model")] fn roundtrip(source: &str) -> String { let bump = bumpalo::Bump::new(); let parsed_model = model::text::parse(source, &bump).unwrap(); @@ -14,43 +10,36 @@ fn roundtrip(source: &str) -> String { model::text::print_to_string(&exported_model, 80).unwrap() } -#[cfg(feature = "model")] #[test] pub fn test_roundtrip_add() { insta::assert_snapshot!(roundtrip(include_str!("fixtures/model-add.edn"))); } -#[cfg(feature = "model")] #[test] pub fn test_roundtrip_call() { insta::assert_snapshot!(roundtrip(include_str!("fixtures/model-call.edn"))); } -#[cfg(feature = "model")] #[test] pub fn test_roundtrip_alias() { insta::assert_snapshot!(roundtrip(include_str!("fixtures/model-alias.edn"))); } -#[cfg(feature = "model")] #[test] pub fn test_roundtrip_cfg() { insta::assert_snapshot!(roundtrip(include_str!("fixtures/model-cfg.edn"))); } -#[cfg(feature = "model")] #[test] pub fn test_roundtrip_cond() { insta::assert_snapshot!(roundtrip(include_str!("fixtures/model-cond.edn"))); } -#[cfg(feature = "model")] #[test] pub fn test_roundtrip_loop() { insta::assert_snapshot!(roundtrip(include_str!("fixtures/model-loop.edn"))); } -#[cfg(feature = "model")] #[test] pub fn test_roundtrip_params() { insta::assert_snapshot!(roundtrip(include_str!("fixtures/model-params.edn"))); From 1d7a9b2958afb6a1f74430c623bd2020b79275a9 Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Thu, 3 Oct 2024 15:39:42 +0100 Subject: [PATCH 17/26] Update hugr-model/Cargo.toml MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Agustín Borgna <121866228+aborgna-q@users.noreply.github.com> --- hugr-model/Cargo.toml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/hugr-model/Cargo.toml b/hugr-model/Cargo.toml index 35a012251..37fc36074 100644 --- a/hugr-model/Cargo.toml +++ b/hugr-model/Cargo.toml @@ -1,6 +1,11 @@ [package] name = "hugr-model" version = "0.1.0" +readme = "README.md" +documentation = "https://docs.rs/hugr-model/" +description = "Data model for Quantinuum's HUGR intermediate representation" +keywords = ["Quantum", "Quantinuum"] +categories = ["compilers"] rust-version.workspace = true edition.workspace = true homepage.workspace = true From 9c34183bc563a461485be6e9baca1f250201dcbe Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Thu, 3 Oct 2024 15:40:03 +0100 Subject: [PATCH 18/26] Update hugr-model/README.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Agustín Borgna <121866228+aborgna-q@users.noreply.github.com> --- hugr-model/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-model/README.md b/hugr-model/README.md index 7e6035710..19e64e5ae 100644 --- a/hugr-model/README.md +++ b/hugr-model/README.md @@ -34,4 +34,4 @@ This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http: [crates]: https://img.shields.io/crates/v/hugr-core [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov [LICENSE]: https://github.com/CQCL/hugr/blob/main/LICENCE - [CHANGELOG]: https://github.com/CQCL/hugr/blob/main/hugr-core/CHANGELOG.md + [CHANGELOG]: https://github.com/CQCL/hugr/blob/main/hugr-model/CHANGELOG.md From 3573d44968a0d78dc76301a93e654aed433d6678 Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Thu, 3 Oct 2024 15:40:22 +0100 Subject: [PATCH 19/26] Update hugr-model/README.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Agustín Borgna <121866228+aborgna-q@users.noreply.github.com> --- hugr-model/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-model/README.md b/hugr-model/README.md index 19e64e5ae..0ea6fdf8f 100644 --- a/hugr-model/README.md +++ b/hugr-model/README.md @@ -28,7 +28,7 @@ See [DEVELOPMENT.md](https://github.com/CQCL/hugr/blob/main/DEVELOPMENT.md) for This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http://www.apache.org/licenses/LICENSE-2.0). - [API documentation here]: https://docs.rs/hugr-core/ + [API documentation here]: https://docs.rs/hugr-model/ [build_status]: https://github.com/CQCL/hugr/actions/workflows/ci-rs.yml/badge.svg?branch=main [msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg [crates]: https://img.shields.io/crates/v/hugr-core From ea5f73422fe2be9e66d2b9cafd350a924dfb703b Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Thu, 3 Oct 2024 15:46:29 +0100 Subject: [PATCH 20/26] Moved dependency features to Cargo.toml's of the individual crates. Also added hugr model to change-filters.yaml --- .github/change-filters.yml | 1 + Cargo.toml | 3 +-- hugr-core/Cargo.toml | 3 +-- hugr-model/Cargo.toml | 3 +-- 4 files changed, 4 insertions(+), 6 deletions(-) diff --git a/.github/change-filters.yml b/.github/change-filters.yml index 9c20268ed..6d3e5dbf6 100644 --- a/.github/change-filters.yml +++ b/.github/change-filters.yml @@ -6,6 +6,7 @@ rust: &rust - "hugr-cli/**" - "hugr-core/**" - "hugr-passes/**" + - "hugr-model/**" - "Cargo.toml" - "specification/schema/**" - ".github/workflows/ci-rs.yml" diff --git a/Cargo.toml b/Cargo.toml index 35c3ccddd..19bcf0222 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -62,10 +62,9 @@ clap-verbosity-flag = "2.2.0" assert_cmd = "2.0.14" assert_fs = "1.1.1" predicates = "3.1.0" -tinyvec = { version = "1.8.0", features = ["alloc", "serde"] } indexmap = "2.3.0" fxhash = "0.2.1" -bumpalo = { version = "3.16.0", features = ["collections"] } +bumpalo = { version = "3.16.0" } [profile.dev.package] insta.opt-level = 3 diff --git a/hugr-core/Cargo.toml b/hugr-core/Cargo.toml index 9d3214d86..9440a8505 100644 --- a/hugr-core/Cargo.toml +++ b/hugr-core/Cargo.toml @@ -53,9 +53,8 @@ strum_macros = { workspace = true } semver = { version = "1.0.23", features = ["serde"] } hugr-model = { path = "../hugr-model", optional = true } indexmap.workspace = true -tinyvec.workspace = true fxhash.workspace = true -bumpalo = { workspace = true } +bumpalo = { workspace = true, features = ["collections"] } [dev-dependencies] rstest = { workspace = true } diff --git a/hugr-model/Cargo.toml b/hugr-model/Cargo.toml index 37fc36074..7f0396a24 100644 --- a/hugr-model/Cargo.toml +++ b/hugr-model/Cargo.toml @@ -13,7 +13,7 @@ repository.workspace = true license.workspace = true [dependencies] -bumpalo = { workspace = true } +bumpalo = { workspace = true, features = ["collections"] } fxhash.workspace = true indexmap.workspace = true pest = "2.7.12" @@ -21,7 +21,6 @@ pest_derive = "2.7.12" pretty = "0.12.3" smol_str = { workspace = true, features = ["serde"] } thiserror.workspace = true -tinyvec.workspace = true [lints] workspace = true From 39a835b306376704d8f643453cc9ea9b2a0f14c8 Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Thu, 3 Oct 2024 16:02:51 +0100 Subject: [PATCH 21/26] Various naming improvements, more explicit define_index! macro. --- hugr-core/src/export.rs | 30 ++++++++++++++++--------- hugr-core/src/import.rs | 8 +++---- hugr-model/src/v0/mod.rs | 37 +++++++++++++++--------------- hugr-model/src/v0/text/parse.rs | 40 ++++++++++++++++++--------------- hugr-model/src/v0/text/print.rs | 38 +++++++++++++++---------------- 5 files changed, 83 insertions(+), 70 deletions(-) diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index b647f7e2c..985511a47 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -66,7 +66,7 @@ impl<'a> Context<'a> { targets: &[], children: children.into_bump_slice(), meta: &[], - r#type, + signature: r#type, }); self.module.root = root; @@ -220,7 +220,11 @@ impl<'a> Context<'a> { OpType::FuncDefn(func) => { let name = self.get_func_name(node).unwrap(); let (params, func) = self.export_poly_func_type(&func.signature); - let decl = self.bump.alloc(model::FuncDecl { name, params, func }); + let decl = self.bump.alloc(model::FuncDecl { + name, + params, + signature: func, + }); regions = self.bump.alloc_slice_copy(&[self.export_dfg(node)]); model::Operation::DefineFunc { decl } } @@ -228,7 +232,11 @@ impl<'a> Context<'a> { OpType::FuncDecl(func) => { let name = self.get_func_name(node).unwrap(); let (params, func) = self.export_poly_func_type(&func.signature); - let decl = self.bump.alloc(model::FuncDecl { name, params, func }); + let decl = self.bump.alloc(model::FuncDecl { + name, + params, + signature: func, + }); model::Operation::DeclareFunc { decl } } @@ -266,7 +274,7 @@ impl<'a> Context<'a> { let func = self .module - .insert_term(model::Term::ApplyFull { name, args }); + .insert_term(model::Term::ApplyFull { global: name, args }); model::Operation::CallFunc { func } } @@ -281,7 +289,7 @@ impl<'a> Context<'a> { let func = self .module - .insert_term(model::Term::ApplyFull { name, args }); + .insert_term(model::Term::ApplyFull { global: name, args }); model::Operation::LoadFunc { func } } @@ -349,7 +357,7 @@ impl<'a> Context<'a> { params, regions, meta: &[], - r#type, + signature: r#type, }) } @@ -399,7 +407,7 @@ impl<'a> Context<'a> { targets, children: region_children.into_bump_slice(), meta: &[], - r#type, + signature: r#type, }) } @@ -448,7 +456,7 @@ impl<'a> Context<'a> { targets, children: region_children.into_bump_slice(), meta: &[], - r#type, + signature: r#type, }) } @@ -494,7 +502,7 @@ impl<'a> Context<'a> { let name = model::GlobalRef::Named(self.bump.alloc_str(alias.name())); let args = &[]; self.module - .insert_term(model::Term::ApplyFull { name, args }) + .insert_term(model::Term::ApplyFull { global: name, args }) } TypeEnum::Function(func) => self.export_func_type(func), TypeEnum::Variable(index, _) => { @@ -525,7 +533,7 @@ impl<'a> Context<'a> { let args = self .bump .alloc_slice_fill_iter(t.args().iter().map(|p| self.export_type_arg(p))); - let term = model::Term::ApplyFull { name, args }; + let term = model::Term::ApplyFull { global: name, args }; self.module.insert_term(term) } @@ -610,7 +618,7 @@ impl<'a> Context<'a> { .module .insert_term(model::Term::List { items, tail: None }); self.module.insert_term(model::Term::ApplyFull { - name: model::GlobalRef::Named(TERM_PARAM_TUPLE), + global: model::GlobalRef::Named(TERM_PARAM_TUPLE), args: self.bump.alloc_slice_copy(&[types]), }) } diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index ec7a98864..6ab46522c 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -436,7 +436,7 @@ impl<'a> Context<'a> { } model::Operation::CallFunc { func } => { - let model::Term::ApplyFull { name, args } = self.get_term(func)? else { + let model::Term::ApplyFull { global: name, args } = self.get_term(func)? else { return Err(model::ModelError::TypeError(func).into()); }; @@ -455,7 +455,7 @@ impl<'a> Context<'a> { } model::Operation::LoadFunc { func } => { - let model::Term::ApplyFull { name, args } = self.get_term(func)? else { + let model::Term::ApplyFull { global: name, args } = self.get_term(func)? else { return Err(model::ModelError::TypeError(func).into()); }; @@ -888,7 +888,7 @@ impl<'a> Context<'a> { } } - let body = ctx.import_func_type::(decl.func)?; + let body = ctx.import_func_type::(decl.signature)?; in_scope(ctx, PolyFuncTypeBase::new(imported_params, body)) }) } @@ -1034,7 +1034,7 @@ impl<'a> Context<'a> { Err(error_uninferred!("application with implicit parameters")) } - model::Term::ApplyFull { name, args } => { + model::Term::ApplyFull { global: name, args } => { let args = args .iter() .map(|arg| self.import_type_arg(*arg)) diff --git a/hugr-model/src/v0/mod.rs b/hugr-model/src/v0/mod.rs index b1f586dab..4ddd1c052 100644 --- a/hugr-model/src/v0/mod.rs +++ b/hugr-model/src/v0/mod.rs @@ -47,8 +47,7 @@ use thiserror::Error; pub mod text; macro_rules! define_index { - ($(#[$meta:meta])* $vis:vis struct $name:ident;) => { - #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] + ($(#[$meta:meta])* $vis:vis struct $name:ident(pub u32);) => { #[repr(transparent)] $(#[$meta])* $vis struct $name(pub u32); @@ -87,22 +86,26 @@ macro_rules! define_index { define_index! { /// Index of a node in a hugr graph. - pub struct NodeId; + #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] + pub struct NodeId(pub u32); } define_index! { /// Index of a link in a hugr graph. - pub struct LinkId; + #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] + pub struct LinkId(pub u32); } define_index! { /// Index of a region in a hugr graph. - pub struct RegionId; + #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] + pub struct RegionId(pub u32); } define_index! { /// Index of a term in a hugr graph. - pub struct TermId; + #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] + pub struct TermId(pub u32); } /// A module consisting of a hugr graph together with terms. @@ -174,8 +177,8 @@ pub struct Node<'a> { pub regions: &'a [RegionId], /// The meta information attached to the node. pub meta: &'a [MetaItem<'a>], - /// The type of the node. - pub r#type: TermId, + /// The signature of the node. + pub signature: TermId, } /// Operations that nodes can perform. @@ -276,8 +279,8 @@ pub struct Region<'a> { pub children: &'a [NodeId], /// The metadata attached to the region. pub meta: &'a [MetaItem<'a>], - /// The type of the region. - pub r#type: TermId, + /// The signature of the region. + pub signature: TermId, } /// The kind of a region. @@ -305,8 +308,8 @@ pub struct FuncDecl<'a> { pub name: &'a str, /// The static parameters of the function. pub params: &'a [Param<'a>], - /// The type of the function. - pub func: TermId, + /// The signature of the function. + pub signature: TermId, } /// An alias declaration. @@ -415,10 +418,8 @@ pub enum Term<'a> { /// /// `(GLOBAL ARG-0 ... ARG-n)` Apply { - // TODO: Should the name be replaced with the id of the node that defines - // the function to be applied? This could be a type, alias or function. - /// The name of the term. - name: GlobalRef<'a>, + /// Reference to the global declaration to apply. + global: GlobalRef<'a>, /// Arguments to the function, covering only the explicit parameters. args: &'a [TermId], }, @@ -427,8 +428,8 @@ pub enum Term<'a> { /// /// `(@GLOBAL ARG-0 ... ARG-n)` ApplyFull { - /// The name of the function to apply. - name: GlobalRef<'a>, + /// Reference to the global declaration to apply. + global: GlobalRef<'a>, /// Arguments to the function, covering both implicit and explicit parameters. args: &'a [TermId], }, diff --git a/hugr-model/src/v0/text/parse.rs b/hugr-model/src/v0/text/parse.rs index b3c24a880..86ce89e56 100644 --- a/hugr-model/src/v0/text/parse.rs +++ b/hugr-model/src/v0/text/parse.rs @@ -73,7 +73,7 @@ impl<'a> ParseContext<'a> { targets: &[], children, meta, - r#type, + signature: r#type, }); self.module.root = root_region; @@ -112,7 +112,7 @@ impl<'a> ParseContext<'a> { } Term::Apply { - name, + global: name, args: self.bump.alloc_slice_copy(&args), } } @@ -126,7 +126,7 @@ impl<'a> ParseContext<'a> { } Term::ApplyFull { - name, + global: name, args: self.bump.alloc_slice_copy(&args), } } @@ -238,7 +238,7 @@ impl<'a> ParseContext<'a> { params: &[], regions, meta, - r#type, + signature: r#type, } } @@ -255,7 +255,7 @@ impl<'a> ParseContext<'a> { params: &[], regions, meta, - r#type, + signature: r#type, } } @@ -272,7 +272,7 @@ impl<'a> ParseContext<'a> { params: &[], regions, meta, - r#type, + signature: r#type, } } @@ -288,7 +288,7 @@ impl<'a> ParseContext<'a> { params: &[], regions, meta, - r#type, + signature: r#type, } } @@ -303,7 +303,7 @@ impl<'a> ParseContext<'a> { params: &[], regions: &[], meta, - r#type, + signature: r#type, } } @@ -320,7 +320,7 @@ impl<'a> ParseContext<'a> { params: &[], regions: &[], meta, - r#type, + signature: r#type, } } @@ -337,7 +337,7 @@ impl<'a> ParseContext<'a> { params: &[], regions: &[], meta, - r#type, + signature: r#type, } } @@ -353,7 +353,7 @@ impl<'a> ParseContext<'a> { params: &[], regions: &[], meta, - r#type, + signature: r#type, } } @@ -368,7 +368,7 @@ impl<'a> ParseContext<'a> { params: &[], regions: &[], meta, - r#type, + signature: r#type, } } @@ -407,7 +407,7 @@ impl<'a> ParseContext<'a> { params: self.bump.alloc_slice_copy(¶ms), regions, meta, - r#type, + signature: r#type, } } @@ -424,7 +424,7 @@ impl<'a> ParseContext<'a> { params: &[], regions, meta, - r#type, + signature: r#type, } } @@ -441,7 +441,7 @@ impl<'a> ParseContext<'a> { params: &[], regions, meta, - r#type, + signature: r#type, } } @@ -458,7 +458,7 @@ impl<'a> ParseContext<'a> { params: &[], regions: &[], meta, - r#type, + signature: r#type, } } @@ -502,7 +502,7 @@ impl<'a> ParseContext<'a> { targets, children, meta, - r#type, + signature: r#type, })) } @@ -541,7 +541,11 @@ impl<'a> ParseContext<'a> { extensions, }); - Ok(self.bump.alloc(FuncDecl { name, params, func })) + Ok(self.bump.alloc(FuncDecl { + name, + params, + signature: func, + })) } fn parse_alias_header(&mut self, pair: Pair<'a, Rule>) -> ParseResult<&'a AliasDecl<'a>> { diff --git a/hugr-model/src/v0/text/print.rs b/hugr-model/src/v0/text/print.rs index 30d4604f4..26a4e4f3a 100644 --- a/hugr-model/src/v0/text/print.rs +++ b/hugr-model/src/v0/text/print.rs @@ -149,7 +149,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_port_list(node_data.inputs)?; this.print_port_list(node_data.outputs) })?; - this.print_type_hint(node_data.r#type)?; + this.print_type_hint(node_data.signature)?; this.print_meta(node_data.meta)?; this.print_regions(node_data.regions) } @@ -159,7 +159,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_port_list(node_data.inputs)?; this.print_port_list(node_data.outputs) })?; - this.print_type_hint(node_data.r#type)?; + this.print_type_hint(node_data.signature)?; this.print_meta(node_data.meta)?; this.print_regions(node_data.regions) } @@ -169,7 +169,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_port_list(node_data.inputs)?; this.print_port_list(node_data.outputs) })?; - this.print_type_hint(node_data.r#type)?; + this.print_type_hint(node_data.signature)?; this.print_meta(node_data.meta)?; this.print_regions(node_data.regions) } @@ -184,7 +184,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_param(*param)?; } - match self.module.get_term(decl.func) { + match self.module.get_term(decl.signature) { Some(Term::FuncType { inputs, outputs, @@ -196,8 +196,8 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_term(*extensions) })?; } - Some(_) => return Err(PrintError::TypeError(decl.func)), - None => return Err(PrintError::TermNotFound(decl.func)), + Some(_) => return Err(PrintError::TypeError(decl.signature)), + None => return Err(PrintError::TermNotFound(decl.signature)), } this.print_meta(node_data.meta)?; @@ -214,7 +214,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_param(*param)?; } - match self.module.get_term(decl.func) { + match self.module.get_term(decl.signature) { Some(Term::FuncType { inputs, outputs, @@ -226,8 +226,8 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_term(*extensions) })?; } - Some(_) => return Err(PrintError::TypeError(decl.func)), - None => return Err(PrintError::TermNotFound(decl.func)), + Some(_) => return Err(PrintError::TypeError(decl.signature)), + None => return Err(PrintError::TermNotFound(decl.signature)), } this.print_meta(node_data.meta)?; @@ -241,7 +241,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_port_list(node_data.inputs)?; this.print_port_list(node_data.outputs) })?; - this.print_type_hint(node_data.r#type)?; + this.print_type_hint(node_data.signature)?; this.print_meta(node_data.meta)?; Ok(()) } @@ -253,7 +253,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_port_list(node_data.inputs)?; this.print_port_list(node_data.outputs) })?; - this.print_type_hint(node_data.r#type)?; + this.print_type_hint(node_data.signature)?; this.print_meta(node_data.meta)?; Ok(()) } @@ -277,7 +277,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_port_list(node_data.inputs)?; this.print_port_list(node_data.outputs) })?; - this.print_type_hint(node_data.r#type)?; + this.print_type_hint(node_data.signature)?; this.print_meta(node_data.meta)?; this.print_regions(node_data.regions) } @@ -298,7 +298,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_port_list(node_data.inputs)?; this.print_port_list(node_data.outputs) })?; - this.print_type_hint(node_data.r#type)?; + this.print_type_hint(node_data.signature)?; this.print_meta(node_data.meta)?; this.print_regions(node_data.regions) } @@ -337,7 +337,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text("tail-loop"); this.print_port_list(node_data.inputs)?; this.print_port_list(node_data.outputs)?; - this.print_type_hint(node_data.r#type)?; + this.print_type_hint(node_data.signature)?; this.print_meta(node_data.meta)?; this.print_regions(node_data.regions) } @@ -346,7 +346,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text("cond"); this.print_port_list(node_data.inputs)?; this.print_port_list(node_data.outputs)?; - this.print_type_hint(node_data.r#type)?; + this.print_type_hint(node_data.signature)?; this.print_meta(node_data.meta)?; this.print_regions(node_data.regions) } @@ -356,7 +356,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(format!("{}", tag)); this.print_port_list(node_data.inputs)?; this.print_port_list(node_data.outputs)?; - this.print_type_hint(node_data.r#type)?; + this.print_type_hint(node_data.signature)?; this.print_meta(node_data.meta) } }) @@ -390,7 +390,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_port_list(region_data.targets)?; } - this.print_type_hint(region_data.r#type)?; + this.print_type_hint(region_data.signature)?; this.print_meta(region_data.meta)?; this.print_nodes(region) }) @@ -482,7 +482,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { Ok(()) } Term::Var(local_ref) => self.print_local_ref(*local_ref), - Term::Apply { name, args } => { + Term::Apply { global: name, args } => { if args.is_empty() { self.print_global_ref(*name)?; } else { @@ -497,7 +497,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { Ok(()) } - Term::ApplyFull { name, args } => self.print_parens(|this| { + Term::ApplyFull { global: name, args } => self.print_parens(|this| { this.print_text("@"); this.print_global_ref(*name)?; for arg in args.iter() { From 280ef39cd3f03e4c4e82c47f6cff6c7a1501306d Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Thu, 3 Oct 2024 16:21:55 +0100 Subject: [PATCH 22/26] More doc strings, feature marked as unstable, and some renamed fields. --- hugr-core/Cargo.toml | 4 ++-- hugr-core/README.md | 4 ++-- hugr-core/src/export.rs | 10 +++++----- hugr-core/src/import.rs | 6 +++--- hugr-core/src/lib.rs | 4 ++-- hugr-model/src/v0/mod.rs | 25 +++++++++++++++++++------ hugr-model/src/v0/text/parse.rs | 4 ++-- hugr-model/src/v0/text/print.rs | 10 +++++----- 8 files changed, 40 insertions(+), 27 deletions(-) diff --git a/hugr-core/Cargo.toml b/hugr-core/Cargo.toml index 9440a8505..50aaf461a 100644 --- a/hugr-core/Cargo.toml +++ b/hugr-core/Cargo.toml @@ -19,11 +19,11 @@ workspace = true [features] extension_inference = [] declarative = ["serde_yaml"] -model = ["hugr-model"] +model_unstable = ["hugr-model"] [[test]] name = "model" -required-features = ["model"] +required-features = ["model_unstable"] [dependencies] portgraph = { workspace = true, features = ["serde", "petgraph"] } diff --git a/hugr-core/README.md b/hugr-core/README.md index 6c09b05b1..9bebc638f 100644 --- a/hugr-core/README.md +++ b/hugr-core/README.md @@ -21,9 +21,9 @@ Please read the [API documentation here][]. Not enabled by default. - `declarative`: Experimental support for declaring extensions in YAML files, support is limited. -- `model` +- `model_unstable` Import and export from the representation defined in the `hugr-model` crate. - Not enabled by default. + Unstable and subject to change. Not enabled by default. ## Recent Changes diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 985511a47..0d49039b9 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -179,7 +179,7 @@ impl<'a> Context<'a> { fn make_custom(name: &'static str) -> model::Operation { model::Operation::Custom { - name: model::GlobalRef::Named(name), + operation: model::GlobalRef::Named(name), } } @@ -317,7 +317,7 @@ impl<'a> Context<'a> { let name = self.bump .alloc_str(&format!("{}.{}", op.def().extension(), op.def().name())); - let name = model::GlobalRef::Named(name); + let operation = model::GlobalRef::Named(name); params = self .bump @@ -327,14 +327,14 @@ impl<'a> Context<'a> { regions = self.bump.alloc_slice_copy(&[region]); } - model::Operation::Custom { name } + model::Operation::Custom { operation } } OpType::OpaqueOp(op) => { let name = self .bump .alloc_str(&format!("{}.{}", op.extension(), op.op_name())); - let name = model::GlobalRef::Named(name); + let operation = model::GlobalRef::Named(name); params = self .bump @@ -344,7 +344,7 @@ impl<'a> Context<'a> { regions = self.bump.alloc_slice_copy(&[region]); } - model::Operation::Custom { name } + model::Operation::Custom { operation } } }; diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 6ab46522c..b5e3c41e8 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -482,14 +482,14 @@ impl<'a> Context<'a> { model::Operation::Conditional => self.import_conditional(node_id, parent), model::Operation::CustomFull { - name: GlobalRef::Named(name), + operation: GlobalRef::Named(name), } if name == OP_FUNC_CALL_INDIRECT => { let signature = self.get_node_signature(node_id)?; let optype = OpType::CallIndirect(CallIndirect { signature }); self.make_node(node_id, optype, parent) } - model::Operation::CustomFull { name } => { + model::Operation::CustomFull { operation } => { let signature = self.get_node_signature(node_id)?; let args = node_data .params @@ -497,7 +497,7 @@ impl<'a> Context<'a> { .map(|param| self.import_type_arg(*param)) .collect::, _>>()?; - let name = match name { + let name = match operation { GlobalRef::Direct(_) => { return Err(error_unsupported!( "custom operation with direct reference to declaring node" diff --git a/hugr-core/src/lib.rs b/hugr-core/src/lib.rs index 04ebae64d..f58f8ef77 100644 --- a/hugr-core/src/lib.rs +++ b/hugr-core/src/lib.rs @@ -9,11 +9,11 @@ #![cfg_attr(test, allow(non_local_definitions))] pub mod builder; pub mod core; -#[cfg(feature = "model")] +#[cfg(feature = "model_unstable")] pub mod export; pub mod extension; pub mod hugr; -#[cfg(feature = "model")] +#[cfg(feature = "model_unstable")] pub mod import; pub mod macros; pub mod ops; diff --git a/hugr-model/src/v0/mod.rs b/hugr-model/src/v0/mod.rs index 4ddd1c052..8e3d19325 100644 --- a/hugr-model/src/v0/mod.rs +++ b/hugr-model/src/v0/mod.rs @@ -212,17 +212,22 @@ pub enum Operation<'a> { }, /// Custom operation. /// - /// The implicit parameters of the operation are left out. + /// The node's parameters correspond to the explicit parameter of the custom operation, + /// leaving out the implicit parameters. Once the declaration of the custom operation + /// becomes known by resolving the reference, the node can be transformed into a [`Operation::CustomFull`] + /// by inferring terms for the implicit parameters or at least filling them in with a wildcard term. Custom { /// The name of the custom operation. - name: GlobalRef<'a>, + operation: GlobalRef<'a>, }, - /// Custom operation. + /// Custom operation with full parameters. /// - /// The implicit parameters of the operation are included. + /// The node's parameters correspond to both the explicit and implicit parameters of the custom operation. + /// Since this can be tedious to write, the [`Operation::Custom`] variant can be used to indicate that + /// the implicit parameters should be inferred. CustomFull { /// The name of the custom operation. - name: GlobalRef<'a>, + operation: GlobalRef<'a>, }, /// Alias definitions. DefineAlias { @@ -416,6 +421,10 @@ pub enum Term<'a> { /// A symbolic function application. /// + /// The arguments of this application cover only the explicit parameters of the referenced declaration, + /// leaving out the implicit parameters. Once the type of the declaration is known, the implicit parameters + /// can be inferred and the term replaced with [`Term::ApplyFull`]. + /// /// `(GLOBAL ARG-0 ... ARG-n)` Apply { /// Reference to the global declaration to apply. @@ -426,6 +435,9 @@ pub enum Term<'a> { /// A symbolic function application with all arguments applied. /// + /// The arguments to this application cover both the implicit and explicit parameters of the referenced declaration. + /// Since this can be tedious to write out, only the explicit parameters can be provided via [`Term::Apply`]. + /// /// `(@GLOBAL ARG-0 ... ARG-n)` ApplyFull { /// Reference to the global declaration to apply. @@ -557,7 +569,8 @@ impl<'a> Default for Term<'a> { /// Implicit and explicit parameters share a namespace. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum Param<'a> { - /// An implicit parameter that should be inferred. + /// An implicit parameter that should be inferred, unless a full application form is used + /// (see [`Term::ApplyFull`] and [`Operation::CustomFull`]). Implicit { /// The name of the parameter. name: &'a str, diff --git a/hugr-model/src/v0/text/parse.rs b/hugr-model/src/v0/text/parse.rs index 86ce89e56..cc1921ae8 100644 --- a/hugr-model/src/v0/text/parse.rs +++ b/hugr-model/src/v0/text/parse.rs @@ -390,8 +390,8 @@ impl<'a> ParseContext<'a> { } let operation = match op_rule { - Rule::term_apply_full => Operation::CustomFull { name }, - Rule::term_apply => Operation::Custom { name }, + Rule::term_apply_full => Operation::CustomFull { operation: name }, + Rule::term_apply => Operation::Custom { operation: name }, _ => unreachable!(), }; diff --git a/hugr-model/src/v0/text/print.rs b/hugr-model/src/v0/text/print.rs index 26a4e4f3a..fe585a976 100644 --- a/hugr-model/src/v0/text/print.rs +++ b/hugr-model/src/v0/text/print.rs @@ -258,13 +258,13 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { Ok(()) } - Operation::Custom { name } => { + Operation::Custom { operation } => { this.print_group(|this| { if node_data.params.is_empty() { - this.print_global_ref(*name)?; + this.print_global_ref(*operation)?; } else { this.print_parens(|this| { - this.print_global_ref(*name)?; + this.print_global_ref(*operation)?; for param in node_data.params { this.print_term(*param)?; @@ -282,11 +282,11 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_regions(node_data.regions) } - Operation::CustomFull { name } => { + Operation::CustomFull { operation } => { this.print_group(|this| { this.print_parens(|this| { this.print_text("@"); - this.print_global_ref(*name)?; + this.print_global_ref(*operation)?; for param in node_data.params { this.print_term(*param)?; From e18066d543a9d97b306e7b887ee56e295fd0f844 Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Fri, 4 Oct 2024 15:05:20 +0100 Subject: [PATCH 23/26] Update hugr-core/src/import.rs Co-authored-by: Seyon Sivarajah --- hugr-core/src/import.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index b5e3c41e8..2975e2bb2 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -934,7 +934,7 @@ impl<'a> Context<'a> { | model::Term::Control { .. } => Err(model::ModelError::TypeError(term_id).into()), model::Term::ControlType => { - Err(error_unsupported!("type of control types as `TypeArg`")) + Err(error_unsupported!("type of control types as `TypeParam`")) } } } From 8070fd1e350ec1580d4c6438154ebbad26c26a6a Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Fri, 4 Oct 2024 16:08:14 +0100 Subject: [PATCH 24/26] Fixed bugs with extension sets, and small improvements. --- hugr-core/src/export.rs | 59 ++++++++++--------- hugr-core/src/import.rs | 31 ++++++---- hugr-core/tests/fixtures/model-call.edn | 9 +-- .../snapshots/model__roundtrip_call.snap | 13 ++-- hugr-model/src/v0/text/hugr.pest | 3 +- hugr-model/src/v0/text/parse.rs | 11 +--- 6 files changed, 68 insertions(+), 58 deletions(-) diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 0d49039b9..440095570 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -13,7 +13,6 @@ use crate::{ use bumpalo::{collections::Vec as BumpVec, Bump}; use hugr_model::v0::{self as model}; use indexmap::IndexSet; -use smol_str::ToSmolStr; pub(crate) const OP_FUNC_CALL_INDIRECT: &str = "func.call-indirect"; const TERM_PARAM_TUPLE: &str = "param.tuple"; @@ -635,41 +634,45 @@ impl<'a> Context<'a> { // Until we have a better representation for extension sets, we therefore // need to try and parse each extension as a number to determine if it is // a variable or an extension. - let mut extensions = Vec::new(); - let mut variables = Vec::new(); + println!("ext set: {:?}", t); + + // NOTE: This overprovisions the capacity since some of the entries of the row + // may be variables. Since we panic when there is more than one variable, this + // may at most waste one slot. That is way better than having to allocate + // a temporary vector. + // + // Also `ExtensionSet` has no way of reporting its size, so we have to count + // the elements by iterating over them... + let capacity = t.iter().count(); + let mut extensions = BumpVec::with_capacity_in(capacity, self.bump); + let mut rest = None; for ext in t.iter() { if let Ok(index) = ext.parse::() { - variables.push({ + // Extension sets in the model support at most one variable. This is a + // deliberate limitation so that extension sets behave like polymorphic rows. + // The type theory of such rows and how to apply them to model (co)effects + // is well understood. + // + // Extension sets in `hugr-core` at this point have no such restriction. + // However, it appears that so far we never actually use extension sets with + // multiple variables, except for extension sets that are generated through + // property testing. + if rest.is_some() { + // TODO: We won't need this anymore once we have a core representation + // that ensures that extension sets have at most one variable. + panic!("Extension set with multiple variables") + } + + rest = Some( self.module - .insert_term(model::Term::Var(model::LocalRef::Index(index as _))) - }); + .insert_term(model::Term::Var(model::LocalRef::Index(index as _))), + ); } else { - extensions.push(ext.to_smolstr()); + extensions.push(self.bump.alloc_str(ext) as &str); } } - // Extension sets in the model support at most one variable. This is a - // deliberate limitation so that extension sets behave like polymorphic rows. - // The type theory of such rows and how to apply them to model (co)effects - // is well understood. - // - // Extension sets in `hugr-core` at this point have no such restriction. - // However, it appears that so far we never actually use extension sets with - // multiple variables, except for extension sets that are generated through - // property testing. - let rest = match variables.as_slice() { - [] => None, - [var] => Some(*var), - _ => { - // TODO: We won't need this anymore once we have a core representation - // that ensures that extension sets have at most one variable. - panic!("Extension set with multiple variables") - } - }; - - let mut extensions = BumpVec::with_capacity_in(extensions.len(), self.bump); - extensions.extend(t.iter().map(|ext| self.bump.alloc_str(ext) as &str)); let extensions = extensions.into_bump_slice(); self.module diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 2975e2bb2..6836eb1f5 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -115,6 +115,7 @@ struct Context<'a> { } impl<'a> Context<'a> { + /// Get the types of the given ports and assemble them into a `TypeRow`. fn get_port_types(&mut self, ports: &[model::Port]) -> Result { let types = ports .iter() @@ -198,8 +199,9 @@ impl<'a> Context<'a> { /// Associate links with the ports of the given node in the given direction. fn record_links(&mut self, node: Node, direction: Direction, ports: &'a [model::Port]) { let optype = self.hugr.get_optype(node); - let port_count = optype.port_count(direction); - assert!(ports.len() <= port_count); + + // NOTE: `OpType::port_count` copies the signature, which significantly slows down the import. + debug_assert!(ports.len() <= optype.port_count(direction)); for (model_port, port) in ports.iter().zip(self.hugr.node_ports(node, direction)) { self.link_ports @@ -508,6 +510,11 @@ impl<'a> Context<'a> { let (extension, name) = self.import_custom_name(name)?; + // TODO: Currently we do not have the description or any other metadata for + // the custom op. This will improve with declarative extensions being able + // to declare operations as a node, in which case the description will be attached + // to that node as metadata. + let optype = OpType::OpaqueOp(OpaqueOp::new( extension, name, @@ -583,7 +590,7 @@ impl<'a> Context<'a> { ) -> Result<(), ImportError> { let region_data = self.get_region(region)?; - if !matches!(region_data.kind, model::RegionKind::DataFlow) { + if region_data.kind != model::RegionKind::DataFlow { return Err(model::ModelError::InvalidRegions(node_id).into()); } @@ -629,7 +636,7 @@ impl<'a> Context<'a> { parent: Node, ) -> Result { let node_data = self.get_node(node_id)?; - assert!(matches!(node_data.operation, model::Operation::TailLoop)); + debug_assert_eq!(node_data.operation, model::Operation::TailLoop); let [region] = node_data.regions else { return Err(model::ModelError::InvalidRegions(node_id).into()); @@ -641,6 +648,7 @@ impl<'a> Context<'a> { let (just_inputs, just_outputs) = { let mut sum_rows = sum_rows.into_iter(); + // NOTE: This can not fail since else `import_adt_and_rest` would have failed before. let term = region_data.targets[0].r#type.unwrap(); let Some(just_inputs) = sum_rows.next() else { @@ -673,7 +681,7 @@ impl<'a> Context<'a> { parent: Node, ) -> Result { let node_data = self.get_node(node_id)?; - assert!(matches!(node_data.operation, model::Operation::Conditional)); + debug_assert_eq!(node_data.operation, model::Operation::Conditional); let (sum_rows, other_inputs) = self.import_adt_and_rest(node_id, node_data.inputs)?; let outputs = self.get_port_types(node_data.outputs)?; @@ -820,7 +828,7 @@ impl<'a> Context<'a> { ) -> Result<(), ImportError> { let region_data = self.get_region(region)?; - if !matches!(region_data.kind, model::RegionKind::ControlFlow) { + if region_data.kind != model::RegionKind::ControlFlow { return Err(model::ModelError::InvalidRegions(node_id).into()); } @@ -841,7 +849,7 @@ impl<'a> Context<'a> { parent: Node, ) -> Result { let node_data = self.get_node(node_id)?; - assert!(matches!(node_data.operation, model::Operation::Block)); + debug_assert_eq!(node_data.operation, model::Operation::Block); let [region] = node_data.regions else { return Err(model::ModelError::InvalidRegions(node_id).into()); @@ -923,8 +931,7 @@ impl<'a> Context<'a> { model::Term::StrType => Ok(TypeParam::String), model::Term::ExtSetType => Ok(TypeParam::Extensions), - // TODO: What do we do about the bounds on naturals? - model::Term::NatType => todo!(), + model::Term::NatType => Ok(TypeParam::max_nat()), model::Term::Nat(_) | model::Term::Str(_) @@ -1101,7 +1108,7 @@ impl<'a> Context<'a> { let model::Term::FuncType { inputs, outputs, - extensions: _, + extensions, } = term else { return Err(model::ModelError::TypeError(term_id).into()); @@ -1109,8 +1116,8 @@ impl<'a> Context<'a> { let inputs = self.import_type_row::(*inputs)?; let outputs = self.import_type_row::(*outputs)?; - // TODO: extensions - Ok(FuncTypeBase::new(inputs, outputs)) + let extensions = self.import_extension_set(*extensions)?; + Ok(FuncTypeBase::new(inputs, outputs).with_extension_delta(extensions)) } fn import_closed_list( diff --git a/hugr-core/tests/fixtures/model-call.edn b/hugr-core/tests/fixtures/model-call.edn index a65def91f..5918543bf 100644 --- a/hugr-core/tests/fixtures/model-call.edn +++ b/hugr-core/tests/fixtures/model-call.edn @@ -1,22 +1,23 @@ (hugr 0) (declare-func example.callee - [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext) + (forall ?ext ext-set) + [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext arithmetic.int . ?ext) (meta doc.title "Callee") (meta doc.description "This is a function declaration.")) (define-func example.caller - [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext) + [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext arithmetic.int) (meta doc.title "Caller") (meta doc.description "This defines a function that calls the function which we declared earlier.") (dfg [(%3 (@ arithmetic.int.types.int))] [(%4 (@ arithmetic.int.types.int))] - (call (@ example.callee) [(%3 (@ arithmetic.int.types.int))] [(%4 (@ arithmetic.int.types.int))]))) + (call (@ example.callee (ext)) [(%3 (@ arithmetic.int.types.int))] [(%4 (@ arithmetic.int.types.int))]))) (define-func example.load [] [(fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))] (ext) (dfg [] [(%5 (fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext)))] - (load-func (@ example.caller) [] [(%5 (fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext)))]))) + (load-func (@ example.caller) [] [(%5 (fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext arithmetic.int)))]))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_call.snap b/hugr-core/tests/snapshots/model__roundtrip_call.snap index 7e98e6e68..40a8ac10b 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_call.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_call.snap @@ -5,15 +5,20 @@ expression: "roundtrip(include_str!(\"fixtures/model-call.edn\"))" (hugr 0) (declare-func example.callee - [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext)) + (forall ?0 ext-set) + [(@ arithmetic.int.types.int)] + [(@ arithmetic.int.types.int)] + (ext arithmetic.int . ?0)) (define-func example.caller - [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext) + [(@ arithmetic.int.types.int)] + [(@ arithmetic.int.types.int)] + (ext arithmetic.int) (dfg [(%0 (@ arithmetic.int.types.int))] [(%1 (@ arithmetic.int.types.int))] (call - (@ example.callee) + (@ example.callee (ext)) [(%0 (@ arithmetic.int.types.int))] [(%1 (@ arithmetic.int.types.int))]))) @@ -35,4 +40,4 @@ expression: "roundtrip(include_str!(\"fixtures/model-call.edn\"))" (fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] - (ext)))]))) + (ext arithmetic.int)))]))) diff --git a/hugr-model/src/v0/text/hugr.pest b/hugr-model/src/v0/text/hugr.pest index f2b5ad38f..37e9dad06 100644 --- a/hugr-model/src/v0/text/hugr.pest +++ b/hugr-model/src/v0/text/hugr.pest @@ -1,6 +1,7 @@ WHITESPACE = _{ " " | "\t" | "\r" | "\n" } COMMENT = _{ ";" ~ (!("\n") ~ ANY)* ~ "\n" } identifier = @{ (ASCII_ALPHA | "_" | "-") ~ (ASCII_ALPHANUMERIC | "_" | "-")* } +ext_name = @{ identifier ~ ("." ~ identifier)* } symbol = @{ identifier ~ ("." ~ identifier)+ } tag = @{ (ASCII_NONZERO_DIGIT ~ ASCII_DIGIT*) | "0" } @@ -97,7 +98,7 @@ term_str = { string } term_str_type = { "str" } term_nat = { (ASCII_DIGIT)+ } term_nat_type = { "nat" } -term_ext_set = { "(" ~ "ext" ~ identifier* ~ (list_tail ~ term)? ~ ")" } +term_ext_set = { "(" ~ "ext" ~ ext_name* ~ (list_tail ~ term)? ~ ")" } term_ext_set_type = { "ext-set" } term_adt = { "(" ~ "adt" ~ term ~ ")" } term_func_type = { "(" ~ "fn" ~ term ~ term ~ term ~ ")" } diff --git a/hugr-model/src/v0/text/parse.rs b/hugr-model/src/v0/text/parse.rs index cc1921ae8..fec67bee8 100644 --- a/hugr-model/src/v0/text/parse.rs +++ b/hugr-model/src/v0/text/parse.rs @@ -175,7 +175,7 @@ impl<'a> ParseContext<'a> { let mut extensions = Vec::new(); let mut rest = None; - for token in filter_rule(&mut inner, Rule::identifier) { + for token in filter_rule(&mut inner, Rule::ext_name) { extensions.push(token.as_str()); } @@ -525,14 +525,7 @@ impl<'a> ParseContext<'a> { let inputs = self.parse_term(inner.next().unwrap())?; let outputs = self.parse_term(inner.next().unwrap())?; - - let extensions = match inner.peek().map(|p| p.as_rule()) { - Some(Rule::term_ext_set) => self.parse_term(inner.next().unwrap())?, - _ => self.module.insert_term(Term::ExtSet { - extensions: &[], - rest: None, - }), - }; + let extensions = self.parse_term(inner.next().unwrap())?; // Assemble the inputs, outputs and extensions into a function type. let func = self.module.insert_term(Term::FuncType { From 14e0e1eb724cd3f0fabc1b36ec87131a9b501d7e Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Fri, 4 Oct 2024 16:33:48 +0100 Subject: [PATCH 25/26] Nits --- hugr-core/src/export.rs | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 440095570..39f025e83 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -49,8 +49,9 @@ impl<'a> Context<'a> { } } + /// Exports the root module of the HUGR graph. pub fn export_root(&mut self) { - let r#type = self.module.insert_term(model::Term::Wildcard); + let signature = self.module.insert_term(model::Term::Wildcard); let hugr_children = self.hugr.children(self.hugr.root()); let mut children = BumpVec::with_capacity_in(hugr_children.len(), self.bump); @@ -64,8 +65,8 @@ impl<'a> Context<'a> { sources: &[], targets: &[], children: children.into_bump_slice(), - meta: &[], - signature: r#type, + meta: &[], // TODO: Export metadata + signature, }); self.module.root = root; @@ -347,7 +348,7 @@ impl<'a> Context<'a> { } }; - let r#type = self.module.insert_term(model::Term::Wildcard); + let signature = self.module.insert_term(model::Term::Wildcard); self.module.insert_node(model::Node { operation, @@ -355,8 +356,8 @@ impl<'a> Context<'a> { outputs, params, regions, - meta: &[], - signature: r#type, + meta: &[], // TODO: Export metadata + signature, }) } @@ -398,15 +399,15 @@ impl<'a> Context<'a> { } // TODO: We can determine the type of the region - let r#type = self.module.insert_term(model::Term::Wildcard); + let signature = self.module.insert_term(model::Term::Wildcard); self.module.insert_region(model::Region { kind: model::RegionKind::DataFlow, sources, targets, children: region_children.into_bump_slice(), - meta: &[], - signature: r#type, + meta: &[], // TODO: Export metadata + signature, }) } @@ -447,15 +448,15 @@ impl<'a> Context<'a> { let targets = self.make_ports(exit_block, Direction::Incoming); // TODO: We can determine the type of the region - let r#type = self.module.insert_term(model::Term::Wildcard); + let signature = self.module.insert_term(model::Term::Wildcard); self.module.insert_region(model::Region { kind: model::RegionKind::ControlFlow, sources: self.bump.alloc_slice_copy(&[source]), targets, children: region_children.into_bump_slice(), - meta: &[], - signature: r#type, + meta: &[], // TODO: Export metadata + signature, }) } From 706b7ac688519a3f01a22d7c67e04bafa0704ac8 Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Fri, 4 Oct 2024 16:56:43 +0100 Subject: [PATCH 26/26] Remove println left in from debugging. --- hugr-core/src/export.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 39f025e83..5fd6932eb 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -635,7 +635,6 @@ impl<'a> Context<'a> { // Until we have a better representation for extension sets, we therefore // need to try and parse each extension as a number to determine if it is // a variable or an extension. - println!("ext set: {:?}", t); // NOTE: This overprovisions the capacity since some of the entries of the row // may be variables. Since we panic when there is more than one variable, this