From 4db3f5d72d1661bd0ec099a33afdcb075740785a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= Date: Tue, 15 Apr 2025 13:52:31 +0100 Subject: [PATCH 01/18] ci: Run ci checks on PRs to any branch --- .github/workflows/ci-py.yml | 2 +- .github/workflows/ci-rs.yml | 2 +- .github/workflows/pr-title.yml | 2 +- .github/workflows/semver-checks.yml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci-py.yml b/.github/workflows/ci-py.yml index c393c195f..6ef3edce5 100644 --- a/.github/workflows/ci-py.yml +++ b/.github/workflows/ci-py.yml @@ -6,7 +6,7 @@ on: - main pull_request: branches: - - '*' + - '**' merge_group: types: [checks_requested] workflow_dispatch: {} diff --git a/.github/workflows/ci-rs.yml b/.github/workflows/ci-rs.yml index d7c94e3a3..824291a8b 100644 --- a/.github/workflows/ci-rs.yml +++ b/.github/workflows/ci-rs.yml @@ -6,7 +6,7 @@ on: - main pull_request: branches: - - '*' + - '**' merge_group: types: [checks_requested] workflow_dispatch: {} diff --git a/.github/workflows/pr-title.yml b/.github/workflows/pr-title.yml index 333eec96f..f778c1363 100644 --- a/.github/workflows/pr-title.yml +++ b/.github/workflows/pr-title.yml @@ -2,7 +2,7 @@ name: Check Conventional Commits format on: pull_request_target: branches: - - main + - '**' types: - opened - edited diff --git a/.github/workflows/semver-checks.yml b/.github/workflows/semver-checks.yml index e884b2e36..2c410aa85 100644 --- a/.github/workflows/semver-checks.yml +++ b/.github/workflows/semver-checks.yml @@ -2,7 +2,7 @@ name: Rust Semver Checks on: pull_request_target: branches: - - main + - '**' jobs: # Check if changes were made to the relevant files. From 81447ecf83acdba6d50427a64637447899553054 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Tue, 15 Apr 2025 15:58:19 +0100 Subject: [PATCH 02/18] feat!: Allow generic Nodes in HugrMut insert operations (#2075) `insert_hugr`, `insert_from_view`, and `insert_subgraph` were written before we made `Node` a type generic, and incorrectly assumed the return type maps were always `hugr::Node`s. The methods were either unusable or incorrect when using generic `HugrView`s source/targets with non-base node types. This PR fixes that, and additionally allows us us to have `SiblingSubgraph::extract_subgraph` work for generic `HugrViews`. BREAKING CHANGE: Added Node type parameters to extraction operations in `HugrMut`. --- hugr-core/src/builder/build_traits.rs | 2 +- hugr-core/src/hugr/hugrmut.rs | 114 ++++++++++++------- hugr-core/src/hugr/views/sibling_subgraph.rs | 4 +- 3 files changed, 75 insertions(+), 45 deletions(-) diff --git a/hugr-core/src/builder/build_traits.rs b/hugr-core/src/builder/build_traits.rs index f1613895d..e17d172ca 100644 --- a/hugr-core/src/builder/build_traits.rs +++ b/hugr-core/src/builder/build_traits.rs @@ -119,7 +119,7 @@ pub trait Container { } /// Insert a copy of a HUGR as a child of the container. - fn add_hugr_view(&mut self, child: &impl HugrView) -> InsertionResult { + fn add_hugr_view(&mut self, child: &H) -> InsertionResult { let parent = self.container_node(); self.hugr_mut().insert_from_view(parent, child) } diff --git a/hugr-core/src/hugr/hugrmut.rs b/hugr-core/src/hugr/hugrmut.rs index f3ef094be..38eb59222 100644 --- a/hugr-core/src/hugr/hugrmut.rs +++ b/hugr-core/src/hugr/hugrmut.rs @@ -1,13 +1,15 @@ //! Low-level interface for modifying a HUGR. use core::panic; -use std::collections::{BTreeMap, HashMap}; +use std::collections::{BTreeMap, HashMap, HashSet}; use std::sync::Arc; use portgraph::view::{NodeFilter, NodeFiltered}; use portgraph::{LinkMut, PortMut, PortView, SecondaryMap}; +use crate::core::HugrNode; use crate::extension::ExtensionRegistry; +use crate::hugr::internal::HugrInternals; use crate::hugr::views::SiblingSubgraph; use crate::hugr::{HugrView, Node, OpType, RootTagged}; use crate::hugr::{NodeMetadata, Rewrite}; @@ -162,10 +164,10 @@ pub trait HugrMut: HugrMutInternals { /// correspondingly for `Dom` edges) fn copy_descendants( &mut self, - root: Node, - new_parent: Node, + root: Self::Node, + new_parent: Self::Node, subst: Option, - ) -> BTreeMap { + ) -> BTreeMap { panic_invalid_node(self, root); panic_invalid_node(self, new_parent); self.hugr_mut().copy_descendants(root, new_parent, subst) @@ -225,7 +227,7 @@ pub trait HugrMut: HugrMutInternals { /// /// If the root node is not in the graph. #[inline] - fn insert_hugr(&mut self, root: Node, other: Hugr) -> InsertionResult { + fn insert_hugr(&mut self, root: Self::Node, other: Hugr) -> InsertionResult { panic_invalid_node(self, root); self.hugr_mut().insert_hugr(root, other) } @@ -236,7 +238,11 @@ pub trait HugrMut: HugrMutInternals { /// /// If the root node is not in the graph. #[inline] - fn insert_from_view(&mut self, root: Node, other: &impl HugrView) -> InsertionResult { + fn insert_from_view( + &mut self, + root: Self::Node, + other: &H, + ) -> InsertionResult { panic_invalid_node(self, root); self.hugr_mut().insert_from_view(root, other) } @@ -255,12 +261,12 @@ pub trait HugrMut: HugrMutInternals { // TODO: Try to preserve the order when possible? We cannot always ensure // it, since the subgraph may have arbitrary nodes without including their // parent. - fn insert_subgraph( + fn insert_subgraph( &mut self, - root: Node, - other: &impl HugrView, - subgraph: &SiblingSubgraph, - ) -> HashMap { + root: Self::Node, + other: &H, + subgraph: &SiblingSubgraph, + ) -> HashMap { panic_invalid_node(self, root); self.hugr_mut().insert_subgraph(root, other, subgraph) } @@ -307,20 +313,32 @@ pub trait HugrMut: HugrMutInternals { /// Records the result of inserting a Hugr or view /// via [HugrMut::insert_hugr] or [HugrMut::insert_from_view]. -pub struct InsertionResult { +/// +/// Contains a map from the nodes in the source HUGR to the nodes in the +/// target HUGR, using their respective `Node` types. +pub struct InsertionResult { /// The node, after insertion, that was the root of the inserted Hugr. /// /// That is, the value in [InsertionResult::node_map] under the key that was the [HugrView::root] - pub new_root: Node, + pub new_root: TargetN, /// Map from nodes in the Hugr/view that was inserted, to their new /// positions in the Hugr into which said was inserted. - pub node_map: HashMap, + pub node_map: HashMap, } -fn translate_indices( +/// Translate a portgraph node index map into a map from nodes in the source +/// HUGR to nodes in the target HUGR. +/// +/// This is as a helper in `insert_hugr` and `insert_subgraph`, where the source +/// HUGR may be an arbitrary `HugrView` with generic node types. +fn translate_indices( + mut source_node: impl FnMut(portgraph::NodeIndex) -> N, + mut target_node: impl FnMut(portgraph::NodeIndex) -> Node, node_map: HashMap, -) -> impl Iterator { - node_map.into_iter().map(|(k, v)| (k.into(), v.into())) +) -> impl Iterator { + node_map + .into_iter() + .map(move |(k, v)| (source_node(k), target_node(v))) } /// Impl for non-wrapped Hugrs. Overwrites the recursive default-impls to directly use the hugr. @@ -406,7 +424,11 @@ impl + AsMut> HugrMut for T (src_port, dst_port) } - fn insert_hugr(&mut self, root: Node, mut other: Hugr) -> InsertionResult { + fn insert_hugr( + &mut self, + root: Self::Node, + mut other: Hugr, + ) -> InsertionResult { let (new_root, node_map) = insert_hugr_internal(self.as_mut(), root, &other); // Update the optypes and metadata, taking them from the other graph. // @@ -423,11 +445,16 @@ impl + AsMut> HugrMut for T ); InsertionResult { new_root, - node_map: translate_indices(node_map).collect(), + node_map: translate_indices(|n| other.get_node(n), |n| self.get_node(n), node_map) + .collect(), } } - fn insert_from_view(&mut self, root: Node, other: &impl HugrView) -> InsertionResult { + fn insert_from_view( + &mut self, + root: Self::Node, + other: &H, + ) -> InsertionResult { let (new_root, node_map) = insert_hugr_internal(self.as_mut(), root, other); // Update the optypes and metadata, copying them from the other graph. // @@ -444,22 +471,28 @@ impl + AsMut> HugrMut for T ); InsertionResult { new_root, - node_map: translate_indices(node_map).collect(), + node_map: translate_indices(|n| other.get_node(n), |n| self.get_node(n), node_map) + .collect(), } } - fn insert_subgraph( + fn insert_subgraph( &mut self, - root: Node, - other: &impl HugrView, - subgraph: &SiblingSubgraph, - ) -> HashMap { + root: Self::Node, + other: &H, + subgraph: &SiblingSubgraph, + ) -> HashMap { // Create a portgraph view with the explicit list of nodes defined by the subgraph. - let portgraph: NodeFiltered<_, NodeFilter<&[Node]>, &[Node]> = + let context: HashSet = subgraph + .nodes() + .iter() + .map(|&n| other.get_pg_index(n)) + .collect(); + let portgraph: NodeFiltered<_, NodeFilter>, _> = NodeFiltered::new_node_filtered( other.portgraph(), - |node, ctx| ctx.contains(&node.into()), - subgraph.nodes(), + |node, ctx| ctx.contains(&node), + context, ); let node_map = insert_subgraph_internal(self.as_mut(), root, other, &portgraph); // Update the optypes and metadata, copying them from the other graph. @@ -473,25 +506,24 @@ impl + AsMut> HugrMut for T self.use_extensions(exts); } } - translate_indices(node_map).collect() + translate_indices(|n| other.get_node(n), |n| self.get_node(n), node_map).collect() } fn copy_descendants( &mut self, - root: Node, - new_parent: Node, + root: Self::Node, + new_parent: Self::Node, subst: Option, - ) -> BTreeMap { + ) -> BTreeMap { let mut descendants = self.base_hugr().hierarchy.descendants(root.pg_index()); let root2 = descendants.next(); debug_assert_eq!(root2, Some(root.pg_index())); let nodes = Vec::from_iter(descendants); - let node_map = translate_indices( - portgraph::view::Subgraph::with_nodes(&mut self.as_mut().graph, nodes) - .copy_in_parent() - .expect("Is a MultiPortGraph"), - ) - .collect::>(); + let node_map = portgraph::view::Subgraph::with_nodes(&mut self.as_mut().graph, nodes) + .copy_in_parent() + .expect("Is a MultiPortGraph"); + let node_map = translate_indices(|n| self.get_node(n), |n| self.get_node(n), node_map) + .collect::>(); for node in self.children(root).collect::>() { self.set_parent(*node_map.get(&node).unwrap(), new_parent); @@ -563,10 +595,10 @@ fn insert_hugr_internal( /// sibling order in the hierarchy. This is due to the subgraph not necessarily /// having a single root, so the logic for reconstructing the hierarchy is not /// able to just do a BFS. -fn insert_subgraph_internal( +fn insert_subgraph_internal( hugr: &mut Hugr, root: Node, - other: &impl HugrView, + other: &impl HugrView, portgraph: &impl portgraph::LinkView, ) -> HashMap { let node_map = hugr diff --git a/hugr-core/src/hugr/views/sibling_subgraph.rs b/hugr-core/src/hugr/views/sibling_subgraph.rs index a0bf1a3da..c681fafc9 100644 --- a/hugr-core/src/hugr/views/sibling_subgraph.rs +++ b/hugr-core/src/hugr/views/sibling_subgraph.rs @@ -446,16 +446,14 @@ impl SiblingSubgraph { nu_out, )) } -} -impl SiblingSubgraph { /// Create a new Hugr containing only the subgraph. /// /// The new Hugr will contain a [FuncDefn][crate::ops::FuncDefn] root /// with the same signature as the subgraph and the specified `name` pub fn extract_subgraph( &self, - hugr: &impl HugrView, + hugr: &impl HugrView, name: impl Into, ) -> Hugr { let mut builder = FunctionBuilder::new(name, self.signature(hugr)).unwrap(); From ef1cba0a85f803423e9f14450844ad4f7300f1fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Tue, 15 Apr 2025 16:02:17 +0100 Subject: [PATCH 03/18] fix!: Don't expose `HugrMutInternals` (#2071) `HugrMutInternals` is part of the semi-private traits defined in `hugr-core`. While most things get re-exported in `hugr`, we `*Internal` traits require you to explicitly declare a dependency on the `-core` package (as we don't want most users to have to interact with them). For some reason there was a public re-export of the trait in a re-exported module, so it ended up appearing in `hugr` anyways. BREAKING CHANGE: Removed public re-export of `HugrMutInternals` from `hugr`. --- hugr-core/src/hugr/rewrite/simple_replace.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/hugr-core/src/hugr/rewrite/simple_replace.rs b/hugr-core/src/hugr/rewrite/simple_replace.rs index cf7f2922a..b4ec37db1 100644 --- a/hugr-core/src/hugr/rewrite/simple_replace.rs +++ b/hugr-core/src/hugr/rewrite/simple_replace.rs @@ -4,7 +4,6 @@ use std::collections::HashMap; use crate::core::HugrNode; use crate::hugr::hugrmut::InsertionResult; -pub use crate::hugr::internal::HugrMutInternals; use crate::hugr::views::SiblingSubgraph; use crate::hugr::{HugrMut, HugrView, Rewrite}; use crate::ops::{OpTag, OpTrait, OpType}; From fac6c8b92bab8904f47eaa9ebf4581eb1e1f4095 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Tue, 15 Apr 2025 17:10:32 +0100 Subject: [PATCH 04/18] feat!: Mark all Error enums as non_exhaustive (#2056) #2027 ended up being breaking due to adding a new variant to an error enum missing the `non_exhaustive` marker. This (breaking) PR makes sure all error enums have the flag. BREAKING CHANGE: Marked all Error enums as `non_exhaustive` --- hugr-core/src/extension.rs | 1 + hugr-core/src/hugr/serialize/upgrade.rs | 1 + hugr-core/src/import.rs | 2 ++ hugr-model/src/v0/ast/resolve.rs | 1 + hugr-model/src/v0/table/mod.rs | 1 + hugr-passes/src/lower.rs | 1 + hugr-passes/src/non_local.rs | 1 + hugr-passes/src/replace_types/linearize.rs | 1 + hugr-passes/src/validation.rs | 1 + 9 files changed, 10 insertions(+) diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index 408c88e15..b6e059050 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -378,6 +378,7 @@ pub static EMPTY_REG: ExtensionRegistry = ExtensionRegistry { /// TODO: decide on failure modes #[derive(Debug, Clone, Error, PartialEq, Eq)] #[allow(missing_docs)] +#[non_exhaustive] pub enum SignatureError { /// Name mismatch #[error("Definition name ({0}) and instantiation name ({1}) do not match.")] diff --git a/hugr-core/src/hugr/serialize/upgrade.rs b/hugr-core/src/hugr/serialize/upgrade.rs index 2741b6175..ac1ac1eea 100644 --- a/hugr-core/src/hugr/serialize/upgrade.rs +++ b/hugr-core/src/hugr/serialize/upgrade.rs @@ -1,6 +1,7 @@ use thiserror::Error; #[derive(Debug, Error)] +#[non_exhaustive] pub enum UpgradeError { #[error(transparent)] Deserialize(#[from] serde_json::Error), diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 642c84c41..899deb17d 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -35,6 +35,7 @@ use thiserror::Error; /// Error during import. #[derive(Debug, Clone, Error)] +#[non_exhaustive] pub enum ImportError { /// The model contains a feature that is not supported by the importer yet. /// Errors of this kind are expected to be removed as the model format and @@ -75,6 +76,7 @@ pub enum ImportError { /// Import error caused by incorrect order hints. #[derive(Debug, Clone, Error)] +#[non_exhaustive] pub enum OrderHintError { /// Duplicate order hint key in the same region. #[error("duplicate order hint key {0}")] diff --git a/hugr-model/src/v0/ast/resolve.rs b/hugr-model/src/v0/ast/resolve.rs index 2f8a5ba6e..c9be8896b 100644 --- a/hugr-model/src/v0/ast/resolve.rs +++ b/hugr-model/src/v0/ast/resolve.rs @@ -362,6 +362,7 @@ impl<'a> Context<'a> { /// Error that may occur in [`Module::resolve`]. #[derive(Debug, Clone, Error)] +#[non_exhaustive] pub enum ResolveError { /// Unknown variable. #[error("unknown var: {0}")] diff --git a/hugr-model/src/v0/table/mod.rs b/hugr-model/src/v0/table/mod.rs index 756a52c1e..55a4b9889 100644 --- a/hugr-model/src/v0/table/mod.rs +++ b/hugr-model/src/v0/table/mod.rs @@ -456,6 +456,7 @@ pub struct VarId(pub NodeId, pub VarIndex); /// Errors that can occur when traversing and interpreting the model. #[derive(Debug, Clone, Error)] +#[non_exhaustive] pub enum ModelError { /// There is a reference to a node that does not exist. #[error("node not found: {0}")] diff --git a/hugr-passes/src/lower.rs b/hugr-passes/src/lower.rs index 09e02c41d..8f8920967 100644 --- a/hugr-passes/src/lower.rs +++ b/hugr-passes/src/lower.rs @@ -35,6 +35,7 @@ pub fn replace_many_ops>( /// Errors produced by the [`lower_ops`] function. #[derive(Debug, Error)] #[error(transparent)] +#[non_exhaustive] pub enum LowerError { /// Invalid subgraph. #[error("Subgraph formed by node is invalid: {0}")] diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index fca74657b..180e9d6fc 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -23,6 +23,7 @@ pub fn nonlocal_edges(hugr: &H) -> impl Iterator { #[error("Found {} nonlocal edges", .0.len())] Edges(Vec<(N, IncomingPort)>), diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 371798dce..b3fc20da9 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -127,6 +127,7 @@ pub struct CallbackHandler<'a>(#[allow(dead_code)] &'a DelegatingLinearizer); #[derive(Clone, Debug, thiserror::Error, PartialEq, Eq)] #[allow(missing_docs)] +#[non_exhaustive] pub enum LinearizeError { #[error("Need copy/discard op for {_0}")] NeedCopyDiscard(Type), diff --git a/hugr-passes/src/validation.rs b/hugr-passes/src/validation.rs index 5f53f403c..6c3e61fb4 100644 --- a/hugr-passes/src/validation.rs +++ b/hugr-passes/src/validation.rs @@ -25,6 +25,7 @@ pub enum ValidationLevel { #[derive(Error, Debug, PartialEq)] #[allow(missing_docs)] +#[non_exhaustive] pub enum ValidatePassError { #[error("Failed to validate input HUGR: {err}\n{pretty_hugr}")] InputError { From baaca02359a307f8691ab3985313272339a8c494 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 16 Apr 2025 11:11:13 +0100 Subject: [PATCH 05/18] feat!: Handle CallIndirect in Dataflow Analysis (#2059) * PartialValue now has a LoadedFunction variant, created by LoadFunction nodes (only, although other analyses are able to create PartialValues if they want) * This requires adding a type parameter to PartialValue for the type of Node, which gets everywhere :-(. * Use this to handle CallIndirects *with known targets* (it'll be a single known target or none at all) just like other Calls to the same function * deprecate (and ignore) `value_from_function` * Add a new trait `AsConcrete` for the result type of `PartialValue::try_into_concrete` and `PartialSum::try_into_sum` Note almost no change to constant folding (only to drop impl of `value_from_function`) BREAKING CHANGE: in dataflow framework, PartialValue now has additional variant; `try_into_concrete` requires the target type to implement AsConcrete. --- hugr-passes/src/const_fold.rs | 63 ++--- hugr-passes/src/const_fold/test.rs | 6 +- hugr-passes/src/const_fold/value_handle.rs | 23 +- hugr-passes/src/dataflow.rs | 17 +- hugr-passes/src/dataflow/datalog.rs | 171 +++++++++++-- hugr-passes/src/dataflow/partial_value.rs | 267 +++++++++++++-------- hugr-passes/src/dataflow/results.rs | 22 +- hugr-passes/src/dataflow/test.rs | 108 ++++++++- hugr-passes/src/dataflow/value_row.rs | 38 +-- 9 files changed, 492 insertions(+), 223 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 7552ed36f..e73e3cd0e 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -7,15 +7,11 @@ use std::{collections::HashMap, sync::Arc}; use thiserror::Error; use hugr_core::{ - hugr::{ - hugrmut::HugrMut, - views::{DescendantsGraph, ExtractHugr, HierarchyView}, - }, + hugr::hugrmut::HugrMut, ops::{ - constant::OpaqueValue, handle::FuncID, Const, DataflowOpTrait, ExtensionOp, LoadConstant, - OpType, Value, + constant::OpaqueValue, Const, DataflowOpTrait, ExtensionOp, LoadConstant, OpType, Value, }, - types::{EdgeKind, TypeArg}, + types::EdgeKind, HugrView, IncomingPort, Node, NodeIndex, OutgoingPort, PortIndex, Wire, }; use value_handle::ValueHandle; @@ -102,7 +98,7 @@ impl ConstantFoldPass { n, in_vals.iter().map(|(p, v)| { let const_with_dummy_loc = partial_from_const( - &ConstFoldContext(hugr), + &ConstFoldContext, ConstLocation::Field(p.index(), &fresh_node.into()), v, ); @@ -112,7 +108,7 @@ impl ConstantFoldPass { .map_err(|opty| ConstFoldError::InvalidEntryPoint(n, opty))?; } - let results = m.run(ConstFoldContext(hugr), []); + let results = m.run(ConstFoldContext, []); let mb_root_inp = hugr.get_io(hugr.root()).map(|[i, _]| i); let wires_to_break = hugr @@ -131,7 +127,7 @@ impl ConstantFoldPass { n, ip, results - .try_read_wire_concrete::(Wire::new(src, outp)) + .try_read_wire_concrete::(Wire::new(src, outp)) .ok()?, )) }) @@ -205,60 +201,35 @@ pub fn constant_fold_pass(h: &mut H) { c.run(h).unwrap() } -struct ConstFoldContext<'a, H>(&'a H); - -impl std::ops::Deref for ConstFoldContext<'_, H> { - type Target = H; - fn deref(&self) -> &H { - self.0 - } -} +struct ConstFoldContext; -impl> ConstLoader> for ConstFoldContext<'_, H> { - type Node = H::Node; +impl ConstLoader> for ConstFoldContext { + type Node = Node; fn value_from_opaque( &self, - loc: ConstLocation, + loc: ConstLocation, val: &OpaqueValue, - ) -> Option> { + ) -> Option> { Some(ValueHandle::new_opaque(loc, val.clone())) } fn value_from_const_hugr( &self, - loc: ConstLocation, + loc: ConstLocation, h: &hugr_core::Hugr, - ) -> Option> { + ) -> Option> { Some(ValueHandle::new_const_hugr(loc, Box::new(h.clone()))) } - - fn value_from_function( - &self, - node: H::Node, - type_args: &[TypeArg], - ) -> Option> { - if !type_args.is_empty() { - // TODO: substitution across Hugr (https://github.com/CQCL/hugr/issues/709) - return None; - }; - // Returning the function body as a value, here, would be sufficient for inlining IndirectCall - // but not for transforming to a direct Call. - let func = DescendantsGraph::>::try_new(&**self, node).ok()?; - Some(ValueHandle::new_const_hugr( - ConstLocation::Node(node), - Box::new(func.extract_hugr()), - )) - } } -impl> DFContext> for ConstFoldContext<'_, H> { +impl DFContext> for ConstFoldContext { fn interpret_leaf_op( &mut self, - node: H::Node, + node: Node, op: &ExtensionOp, - ins: &[PartialValue>], - outs: &mut [PartialValue>], + ins: &[PartialValue>], + outs: &mut [PartialValue>], ) { let sig = op.signature(); let known_ins = sig diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index b84d65d7d..58e69c568 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -42,8 +42,7 @@ fn value_handling(#[case] k: impl CustomConst + Clone, #[case] eq: bool) { let n = Node::from(portgraph::NodeIndex::new(7)); let st = SumType::new([vec![k.get_type()], vec![]]); let subject_val = Value::sum(0, [k.clone().into()], st).unwrap(); - let temp = Hugr::default(); - let ctx: ConstFoldContext = ConstFoldContext(&temp); + let ctx = ConstFoldContext; let v1 = partial_from_const(&ctx, n, &subject_val); let v1_subfield = { @@ -114,8 +113,7 @@ fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) { v.get_custom_value::().unwrap().value() } let [n, n_a, n_b] = [0, 1, 2].map(portgraph::NodeIndex::new).map(Node::from); - let temp = Hugr::default(); - let mut ctx = ConstFoldContext(&temp); + let mut ctx = ConstFoldContext; let v_a = partial_from_const(&ctx, n_a, &f2c(a)); let v_b = partial_from_const(&ctx, n_b, &f2c(b)); assert_eq!(unwrap_float(v_a.clone()), a); diff --git a/hugr-passes/src/const_fold/value_handle.rs b/hugr-passes/src/const_fold/value_handle.rs index bda7bffd2..e5c99a8e7 100644 --- a/hugr-passes/src/const_fold/value_handle.rs +++ b/hugr-passes/src/const_fold/value_handle.rs @@ -1,16 +1,18 @@ //! Total equality (and hence [AbstractValue] support for [Value]s //! (by adding a source-Node and part unhashable constants) use std::collections::hash_map::DefaultHasher; // Moves into std::hash in Rust 1.76. +use std::convert::Infallible; use std::hash::{Hash, Hasher}; use std::sync::Arc; use hugr_core::core::HugrNode; use hugr_core::ops::constant::OpaqueValue; use hugr_core::ops::Value; +use hugr_core::types::ConstTypeError; use hugr_core::{Hugr, Node}; use itertools::Either; -use crate::dataflow::{AbstractValue, ConstLocation}; +use crate::dataflow::{AbstractValue, AsConcrete, ConstLocation, LoadedFunction, Sum}; /// A custom constant that has been successfully hashed via [TryHash](hugr_core::ops::constant::TryHash) #[derive(Clone, Debug)] @@ -153,9 +155,12 @@ impl Hash for ValueHandle { // Unfortunately we need From for Value to be able to pass // Value's into interpret_leaf_op. So that probably doesn't make sense... -impl From> for Value { - fn from(value: ValueHandle) -> Self { - match value { +impl AsConcrete, N> for Value { + type ValErr = Infallible; + type SumErr = ConstTypeError; + + fn from_value(value: ValueHandle) -> Result { + Ok(match value { ValueHandle::Hashable(HashedConst { val, .. }) | ValueHandle::Unhashable { leaf: Either::Left(val), @@ -169,7 +174,15 @@ impl From> for Value { } => Value::function(Arc::try_unwrap(hugr).unwrap_or_else(|a| a.as_ref().clone())) .map_err(|e| e.to_string()) .unwrap(), - } + }) + } + + fn from_sum(value: Sum) -> Result { + Self::sum(value.tag, value.values, value.st) + } + + fn from_func(func: LoadedFunction) -> Result> { + Err(func) } } diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 43caa9c94..1f7c1ae5a 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -9,7 +9,7 @@ mod results; pub use results::{AnalysisResults, TailLoopTermination}; mod partial_value; -pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum}; +pub use partial_value::{AbstractValue, AsConcrete, LoadedFunction, PartialSum, PartialValue, Sum}; use hugr_core::ops::constant::OpaqueValue; use hugr_core::ops::{ExtensionOp, Value}; @@ -31,8 +31,8 @@ pub trait DFContext: ConstLoader { &mut self, _node: Self::Node, _e: &ExtensionOp, - _ins: &[PartialValue], - _outs: &mut [PartialValue], + _ins: &[PartialValue], + _outs: &mut [PartialValue], ) { } } @@ -55,8 +55,8 @@ impl From for ConstLocation<'_, N> { } /// Trait for loading [PartialValue]s from constant [Value]s in a Hugr. -/// Implementors will likely want to override some/all of [Self::value_from_opaque], -/// [Self::value_from_const_hugr], and [Self::value_from_function]: the defaults +/// Implementors will likely want to override either/both of [Self::value_from_opaque] +/// and [Self::value_from_const_hugr]: the defaults /// are "correct" but maximally conservative (minimally informative). pub trait ConstLoader { /// The type of nodes in the Hugr. @@ -81,6 +81,7 @@ pub trait ConstLoader { /// [FuncDefn]: hugr_core::ops::FuncDefn /// [FuncDecl]: hugr_core::ops::FuncDecl /// [LoadFunction]: hugr_core::ops::LoadFunction + #[deprecated(note = "Automatically handled by Datalog, implementation will be ignored")] fn value_from_function(&self, _node: Self::Node, _type_args: &[TypeArg]) -> Option { None } @@ -94,7 +95,7 @@ pub fn partial_from_const<'a, V, CL: ConstLoader>( cl: &CL, loc: impl Into>, cst: &Value, -) -> PartialValue +) -> PartialValue where CL::Node: 'a, { @@ -120,8 +121,8 @@ where /// A row of inputs to a node contains bottom (can't happen, the node /// can't execute) if any element [contains_bottom](PartialValue::contains_bottom). -pub fn row_contains_bottom<'a, V: AbstractValue + 'a>( - elements: impl IntoIterator>, +pub fn row_contains_bottom<'a, V: 'a, N: 'a>( + elements: impl IntoIterator>, ) -> bool { elements.into_iter().any(PartialValue::contains_bottom) } diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 13e510daf..ad1a99345 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -3,19 +3,22 @@ use std::collections::HashMap; use ascent::lattice::BoundedLattice; +use ascent::Lattice; use itertools::Itertools; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; -use hugr_core::ops::{OpTrait, OpType, TailLoop}; +use hugr_core::ops::{DataflowOpTrait, OpTrait, OpType, TailLoop}; use hugr_core::{HugrView, IncomingPort, OutgoingPort, PortIndex as _, Wire}; use super::value_row::ValueRow; use super::{ partial_from_const, row_contains_bottom, AbstractValue, AnalysisResults, DFContext, - PartialValue, + LoadedFunction, PartialValue, }; -type PV = PartialValue; +type PV = PartialValue; + +type NodeInputs = Vec<(IncomingPort, PV)>; /// Basic structure for performing an analysis. Usage: /// 1. Make a new instance via [Self::new()] @@ -25,10 +28,7 @@ type PV = PartialValue; /// [Self::prepopulate_inputs] can be used on each externally-callable /// [FuncDefn](OpType::FuncDefn) to set all inputs to [PartialValue::Top]. /// 3. Call [Self::run] to produce [AnalysisResults] -pub struct Machine( - H, - HashMap)>>, -); +pub struct Machine(H, HashMap>); impl Machine { /// Create a new Machine to analyse the given Hugr(View) @@ -40,7 +40,7 @@ impl Machine { impl Machine { /// Provide initial values for a wire - these will be `join`d with any computed /// or any value previously prepopulated for the same Wire. - pub fn prepopulate_wire(&mut self, w: Wire, v: PartialValue) { + pub fn prepopulate_wire(&mut self, w: Wire, v: PartialValue) { for (n, inp) in self.0.linked_inputs(w.node(), w.source()) { self.1.entry(n).or_default().push((inp, v.clone())); } @@ -54,7 +54,7 @@ impl Machine { pub fn prepopulate_inputs( &mut self, parent: H::Node, - in_values: impl IntoIterator)>, + in_values: impl IntoIterator)>, ) -> Result<(), OpType> { match self.0.get_optype(parent) { OpType::DataflowBlock(_) | OpType::Case(_) | OpType::FuncDefn(_) => { @@ -102,7 +102,7 @@ impl Machine { pub fn run( mut self, context: impl DFContext, - in_values: impl IntoIterator)>, + in_values: impl IntoIterator)>, ) -> AnalysisResults { let root = self.0.root(); if self.0.get_optype(root).is_module() { @@ -135,10 +135,12 @@ impl Machine { } } +pub(super) type InWire = (N, IncomingPort, PartialValue); + pub(super) fn run_datalog( mut ctx: impl DFContext, hugr: H, - in_wire_value_proto: Vec<(H::Node, IncomingPort, PV)>, + in_wire_value_proto: Vec>, ) -> AnalysisResults { // ascent-(macro-)generated code generates a bunch of warnings, // keep code in here to a minimum. @@ -155,9 +157,9 @@ pub(super) fn run_datalog( relation parent_of_node(H::Node, H::Node); // is parent of relation input_child(H::Node, H::Node); // has 1st child that is its `Input` relation output_child(H::Node, H::Node); // has 2nd child that is its `Output` - lattice out_wire_value(H::Node, OutgoingPort, PV); // produces, on , the value - lattice in_wire_value(H::Node, IncomingPort, PV); // receives, on , the value - lattice node_in_value_row(H::Node, ValueRow); // 's inputs are + lattice out_wire_value(H::Node, OutgoingPort, PV); // produces, on , the value + lattice in_wire_value(H::Node, IncomingPort, PV); // receives, on , the value + lattice node_in_value_row(H::Node, ValueRow); // 's inputs are node(n) <-- for n in hugr.nodes(); @@ -322,6 +324,37 @@ pub(super) fn run_datalog( func_call(call, func), output_child(func, outp), in_wire_value(outp, p, v); + + // CallIndirect -------------------- + lattice indirect_call(H::Node, LatticeWrapper); // is an `IndirectCall` to `FuncDefn` + indirect_call(call, tgt) <-- + node(call), + if let OpType::CallIndirect(_) = hugr.get_optype(*call), + in_wire_value(call, IncomingPort::from(0), v), + let tgt = load_func(v); + + out_wire_value(inp, OutgoingPort::from(p.index()-1), v) <-- + indirect_call(call, lv), + if let LatticeWrapper::Value(func) = lv, + input_child(func, inp), + in_wire_value(call, p, v) + if p.index() > 0; + + out_wire_value(call, OutgoingPort::from(p.index()), v) <-- + indirect_call(call, lv), + if let LatticeWrapper::Value(func) = lv, + output_child(func, outp), + in_wire_value(outp, p, v); + + // Default out-value is Bottom, but if we can't determine the called function, + // assign everything to Top + out_wire_value(call, p, PV::Top) <-- + node(call), + if let OpType::CallIndirect(ci) = hugr.get_optype(*call), + in_wire_value(call, IncomingPort::from(0), v), + // Second alternative below addresses function::Value's: + if matches!(v, PartialValue::Top | PartialValue::Value(_)), + for p in ci.signature().output_ports(); }; let out_wire_values = all_results .out_wire_value @@ -337,13 +370,58 @@ pub(super) fn run_datalog( } } +#[derive(Debug, PartialEq, Eq, Hash, Clone, PartialOrd)] +enum LatticeWrapper { + Bottom, + Value(T), + Top, +} + +impl Lattice for LatticeWrapper { + fn meet_mut(&mut self, other: Self) -> bool { + if *self == other || *self == LatticeWrapper::Bottom || other == LatticeWrapper::Top { + return false; + }; + if *self == LatticeWrapper::Top || other == LatticeWrapper::Bottom { + *self = other; + return true; + }; + // Both are `Value`s and not equal + *self = LatticeWrapper::Bottom; + true + } + + fn join_mut(&mut self, other: Self) -> bool { + if *self == other || *self == LatticeWrapper::Top || other == LatticeWrapper::Bottom { + return false; + }; + if *self == LatticeWrapper::Bottom || other == LatticeWrapper::Top { + *self = other; + return true; + }; + // Both are `Value`s and are not equal + *self = LatticeWrapper::Top; + true + } +} + +fn load_func(v: &PV) -> LatticeWrapper { + match v { + PartialValue::Bottom | PartialValue::PartialSum(_) => LatticeWrapper::Bottom, + PartialValue::LoadedFunction(LoadedFunction { func_node, .. }) => { + LatticeWrapper::Value(*func_node) + } + PartialValue::Value(_) | PartialValue::Top => LatticeWrapper::Top, + } +} + fn propagate_leaf_op( ctx: &mut impl DFContext, hugr: &H, n: H::Node, - ins: &[PV], + ins: &[PV], num_outs: usize, -) -> Option> { +) -> Option> { match hugr.get_optype(n) { // Handle basics here. We could instead leave these to DFContext, // but at least we'd want these impls to be easily reusable. @@ -362,8 +440,7 @@ fn propagate_leaf_op( ins.iter().cloned(), )])), OpType::Input(_) | OpType::Output(_) | OpType::ExitBlock(_) => None, // handled by parent - OpType::Call(_) => None, // handled via Input/Output of FuncDefn - OpType::Const(_) => None, // handled by LoadConstant: + OpType::Call(_) | OpType::CallIndirect(_) => None, // handled via Input/Output of FuncDefn OpType::LoadConstant(load_op) => { assert!(ins.is_empty()); // static edge, so need to find constant let const_node = hugr @@ -380,10 +457,10 @@ fn propagate_leaf_op( .unwrap() .0; // Node could be a FuncDefn or a FuncDecl, so do not pass the node itself - Some(ValueRow::singleton( - ctx.value_from_function(func_node, &load_op.type_args) - .map_or(PV::Top, PV::Value), - )) + Some(ValueRow::singleton(PartialValue::new_load( + func_node, + load_op.type_args.clone(), + ))) } OpType::ExtensionOp(e) => { Some(ValueRow::from_iter(if row_contains_bottom(ins) { @@ -401,6 +478,54 @@ fn propagate_leaf_op( outs })) } - o => todo!("Unhandled: {:?}", o), // At least CallIndirect, and OpType is "non-exhaustive" + // We only call propagate_leaf_op for dataflow op non-containers, + o => todo!("Unhandled: {:?}", o), // and OpType is non-exhaustive + } +} + +#[cfg(test)] +mod test { + use ascent::Lattice; + + use super::LatticeWrapper; + + #[test] + fn latwrap_join() { + for lv in [ + LatticeWrapper::Value(3), + LatticeWrapper::Value(5), + LatticeWrapper::Top, + ] { + let mut subject = LatticeWrapper::Bottom; + assert!(subject.join_mut(lv.clone())); + assert_eq!(subject, lv); + assert!(!subject.join_mut(lv.clone())); + assert_eq!(subject, lv); + assert_eq!( + subject.join_mut(LatticeWrapper::Value(11)), + lv != LatticeWrapper::Top + ); + assert_eq!(subject, LatticeWrapper::Top); + } + } + + #[test] + fn latwrap_meet() { + for lv in [ + LatticeWrapper::Bottom, + LatticeWrapper::Value(3), + LatticeWrapper::Value(5), + ] { + let mut subject = LatticeWrapper::Top; + assert!(subject.meet_mut(lv.clone())); + assert_eq!(subject, lv); + assert!(!subject.meet_mut(lv.clone())); + assert_eq!(subject, lv); + assert_eq!( + subject.meet_mut(LatticeWrapper::Value(11)), + lv != LatticeWrapper::Bottom + ); + assert_eq!(subject, LatticeWrapper::Bottom); + } } } diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index f2a497806..240f4f2d6 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -1,7 +1,7 @@ use ascent::lattice::BoundedLattice; use ascent::Lattice; -use hugr_core::ops::Value; -use hugr_core::types::{ConstTypeError, SumType, Type, TypeEnum, TypeRow}; +use hugr_core::types::{SumType, Type, TypeArg, TypeEnum, TypeRow}; +use hugr_core::Node; use itertools::{zip_eq, Itertools}; use std::cmp::Ordering; use std::collections::HashMap; @@ -51,15 +51,25 @@ pub struct Sum { pub st: SumType, } +/// The output of an [LoadFunction](hugr_core::ops::LoadFunction) - a "pointer" +/// to a function at a specific node, instantiated with the provided type-args. +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct LoadedFunction { + /// The [FuncDefn](hugr_core::ops::FuncDefn) or `FuncDecl`` that was loaded + pub func_node: N, + /// The type arguments provided when loading + pub args: Vec, +} + /// A representation of a value of [SumType], that may have one or more possible tags, /// with a [PartialValue] representation of each element-value of each possible tag. #[derive(PartialEq, Clone, Eq)] -pub struct PartialSum(pub HashMap>>); +pub struct PartialSum(pub HashMap>>); -impl PartialSum { +impl PartialSum { /// New instance for a single known tag. /// (Multi-tag instances can be created via [Self::try_join_mut].) - pub fn new_variant(tag: usize, values: impl IntoIterator>) -> Self { + pub fn new_variant(tag: usize, values: impl IntoIterator>) -> Self { Self(HashMap::from([(tag, Vec::from_iter(values))])) } @@ -75,9 +85,21 @@ impl PartialSum { pv.assert_invariants(); } } + + /// Whether this sum might have the specified tag + pub fn supports_tag(&self, tag: usize) -> bool { + self.0.contains_key(&tag) + } + + /// Can this ever occur at runtime? See [PartialValue::contains_bottom] + pub fn contains_bottom(&self) -> bool { + self.0 + .iter() + .all(|(_tag, elements)| row_contains_bottom(elements)) + } } -impl PartialSum { +impl PartialSum { /// Joins (towards `Top`) self with another [PartialSum]. If successful, returns /// whether `self` has changed. /// @@ -141,12 +163,33 @@ impl PartialSum { } Ok(changed) } +} - /// Whether this sum might have the specified tag - pub fn supports_tag(&self, tag: usize) -> bool { - self.0.contains_key(&tag) - } +/// Trait implemented by value types into which [PartialValue]s can be converted, +/// so long as the PV has no [Top](PartialValue::Top), [Bottom](PartialValue::Bottom) +/// or [PartialSum]s with more than one possible tag. See [PartialSum::try_into_sum] +/// and [PartialValue::try_into_concrete]. +/// +/// `V` is the type of [AbstractValue] from which `Self` can (fallibly) be constructed, +/// `N` is the type of [HugrNode](hugr_core::core::HugrNode) for function pointers +pub trait AsConcrete: Sized { + /// Kind of error raised when creating `Self` from a value `V`, see [Self::from_value] + type ValErr: std::error::Error; + /// Kind of error that may be raised when creating `Self` from a [Sum] of `Self`s, + /// see [Self::from_sum] + type SumErr: std::error::Error; + + /// Convert an abstract value into concrete + fn from_value(val: V) -> Result; + + /// Convert a sum (of concrete values, already recursively converted) into concrete + fn from_sum(sum: Sum) -> Result; + + /// Convert a function pointer into a concrete value + fn from_func(func: LoadedFunction) -> Result>; +} +impl PartialSum { /// Turns this instance into a [Sum] of some "concrete" value type `C`, /// *if* this PartialSum has exactly one possible tag. /// @@ -155,11 +198,11 @@ impl PartialSum { /// If this PartialSum had multiple possible tags; or if `typ` was not a [TypeEnum::Sum] /// supporting the single possible tag with the correct number of elements and no row variables; /// or if converting a child element failed via [PartialValue::try_into_concrete]. - pub fn try_into_sum(self, typ: &Type) -> Result, ExtractValueError> - where - V: TryInto, - Sum: TryInto, - { + #[allow(clippy::type_complexity)] // Since C is a parameter, can't declare type aliases + pub fn try_into_sum>( + self, + typ: &Type, + ) -> Result, ExtractValueError> { if self.0.len() != 1 { return Err(ExtractValueError::MultipleVariants(self)); } @@ -185,22 +228,15 @@ impl PartialSum { num_elements: v.len(), }) } - - /// Can this ever occur at runtime? See [PartialValue::contains_bottom] - pub fn contains_bottom(&self) -> bool { - self.0 - .iter() - .all(|(_tag, elements)| row_contains_bottom(elements)) - } } /// An error converting a [PartialValue] or [PartialSum] into a concrete value type /// via [PartialValue::try_into_concrete] or [PartialSum::try_into_sum] #[derive(Clone, Debug, PartialEq, Eq, Error)] #[allow(missing_docs)] -pub enum ExtractValueError { +pub enum ExtractValueError { #[error("PartialSum value had multiple possible tags: {0}")] - MultipleVariants(PartialSum), + MultipleVariants(PartialSum), #[error("Value contained `Bottom`")] ValueIsBottom, #[error("Value contained `Top`")] @@ -209,6 +245,8 @@ pub enum ExtractValueError { CouldNotConvert(V, #[source] VE), #[error("Could not build Sum from concrete element values")] CouldNotBuildSum(#[source] SE), + #[error("Could not convert into concrete function pointer {0}")] + CouldNotLoadFunction(LoadedFunction), #[error("Expected a SumType with tag {tag} having {num_elements} elements, found {typ}")] BadSumType { typ: Type, @@ -217,14 +255,14 @@ pub enum ExtractValueError { }, } -impl PartialSum { +impl PartialSum { /// If this Sum might have the specified `tag`, get the elements inside that tag. - pub fn variant_values(&self, variant: usize) -> Option>> { + pub fn variant_values(&self, variant: usize) -> Option>> { self.0.get(&variant).cloned() } } -impl PartialOrd for PartialSum { +impl PartialOrd for PartialSum { fn partial_cmp(&self, other: &Self) -> Option { let max_key = self.0.keys().chain(other.0.keys()).copied().max().unwrap(); let (mut keys1, mut keys2) = (vec![0; max_key + 1], vec![0; max_key + 1]); @@ -254,13 +292,13 @@ impl PartialOrd for PartialSum { } } -impl std::fmt::Debug for PartialSum { +impl std::fmt::Debug for PartialSum { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.0.fmt(f) } } -impl Hash for PartialSum { +impl Hash for PartialSum { fn hash(&self, state: &mut H) { for (k, v) in &self.0 { k.hash(state); @@ -273,30 +311,32 @@ impl Hash for PartialSum { /// for use in dataflow analysis, including that an instance may be a [PartialSum] /// of values of the underlying representation #[derive(PartialEq, Clone, Eq, Hash, Debug)] -pub enum PartialValue { +pub enum PartialValue { /// No possibilities known (so far) Bottom, + /// The output of an [LoadFunction](hugr_core::ops::LoadFunction) + LoadedFunction(LoadedFunction), /// A single value (of the underlying representation) Value(V), /// Sum (with at least one, perhaps several, possible tags) of underlying values - PartialSum(PartialSum), + PartialSum(PartialSum), /// Might be more than one distinct value of the underlying type `V` Top, } -impl From for PartialValue { +impl From for PartialValue { fn from(v: V) -> Self { Self::Value(v) } } -impl From> for PartialValue { - fn from(v: PartialSum) -> Self { +impl From> for PartialValue { + fn from(v: PartialSum) -> Self { Self::PartialSum(v) } } -impl PartialValue { +impl PartialValue { fn assert_invariants(&self) { if let Self::PartialSum(ps) = self { ps.assert_invariants(); @@ -312,33 +352,59 @@ impl PartialValue { pub fn new_unit() -> Self { Self::new_variant(0, []) } + + /// New instance of self for a [LoadFunction](hugr_core::ops::LoadFunction) + pub fn new_load(func_node: N, args: impl Into>) -> Self { + Self::LoadedFunction(LoadedFunction { + func_node, + args: args.into(), + }) + } + + /// Tells us whether this value might be a Sum with the specified `tag` + pub fn supports_tag(&self, tag: usize) -> bool { + match self { + PartialValue::Bottom | PartialValue::Value(_) | PartialValue::LoadedFunction(_) => { + false + } + PartialValue::PartialSum(ps) => ps.supports_tag(tag), + PartialValue::Top => true, + } + } + + /// A value contains bottom means that it cannot occur during execution: + /// it may be an artefact during bootstrapping of the analysis, or else + /// the value depends upon a `panic` or a loop that + /// [never terminates](super::TailLoopTermination::NeverBreaks). + pub fn contains_bottom(&self) -> bool { + match self { + PartialValue::Bottom => true, + PartialValue::Top | PartialValue::Value(_) | PartialValue::LoadedFunction(_) => false, + PartialValue::PartialSum(ps) => ps.contains_bottom(), + } + } } -impl PartialValue { +impl PartialValue { /// If this value might be a Sum with the specified `tag`, get the elements inside that tag. /// /// # Panics /// /// if the value is believed, for that tag, to have a number of values other than `len` - pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { + pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { let vals = match self { - PartialValue::Bottom | PartialValue::Value(_) => return None, + PartialValue::Bottom | PartialValue::Value(_) | PartialValue::LoadedFunction(_) => { + return None + } PartialValue::PartialSum(ps) => ps.variant_values(tag)?, PartialValue::Top => vec![PartialValue::Top; len], }; assert_eq!(vals.len(), len); Some(vals) } +} - /// Tells us whether this value might be a Sum with the specified `tag` - pub fn supports_tag(&self, tag: usize) -> bool { - match self { - PartialValue::Bottom | PartialValue::Value(_) => false, - PartialValue::PartialSum(ps) => ps.supports_tag(tag), - PartialValue::Top => true, - } - } - +impl PartialValue { /// Turns this instance into some "concrete" value type `C`, *if* it is a single value, /// or a [Sum](PartialValue::PartialSum) (of a single tag) convertible by /// [PartialSum::try_into_sum]. @@ -348,47 +414,27 @@ impl PartialValue { /// If this PartialValue was `Top` or `Bottom`, or was a [PartialSum](PartialValue::PartialSum) /// that could not be converted into a [Sum] by [PartialSum::try_into_sum] (e.g. if `typ` is /// incorrect), or if that [Sum] could not be converted into a `V2`. - pub fn try_into_concrete(self, typ: &Type) -> Result> - where - V: TryInto, - Sum: TryInto, - { + pub fn try_into_concrete>( + self, + typ: &Type, + ) -> Result> { match self { - Self::Value(v) => v - .clone() - .try_into() - .map_err(|e| ExtractValueError::CouldNotConvert(v.clone(), e)), - Self::PartialSum(ps) => ps - .try_into_sum(typ)? - .try_into() - .map_err(ExtractValueError::CouldNotBuildSum), + Self::Value(v) => { + C::from_value(v.clone()).map_err(|e| ExtractValueError::CouldNotConvert(v, e)) + } + Self::LoadedFunction(lf) => { + C::from_func(lf).map_err(ExtractValueError::CouldNotLoadFunction) + } + Self::PartialSum(ps) => { + C::from_sum(ps.try_into_sum(typ)?).map_err(ExtractValueError::CouldNotBuildSum) + } Self::Top => Err(ExtractValueError::ValueIsTop), Self::Bottom => Err(ExtractValueError::ValueIsBottom), } } - - /// A value contains bottom means that it cannot occur during execution: - /// it may be an artefact during bootstrapping of the analysis, or else - /// the value depends upon a `panic` or a loop that - /// [never terminates](super::TailLoopTermination::NeverBreaks). - pub fn contains_bottom(&self) -> bool { - match self { - PartialValue::Bottom => true, - PartialValue::Top | PartialValue::Value(_) => false, - PartialValue::PartialSum(ps) => ps.contains_bottom(), - } - } } -impl TryFrom> for Value { - type Error = ConstTypeError; - - fn try_from(value: Sum) -> Result { - Self::sum(value.tag, value.values, value.st) - } -} - -impl Lattice for PartialValue { +impl Lattice for PartialValue { fn join_mut(&mut self, other: Self) -> bool { self.assert_invariants(); let mut old_self = Self::Top; @@ -400,13 +446,17 @@ impl Lattice for PartialValue { Some((h3, b)) => (Self::Value(h3), b), None => (Self::Top, true), }, + (Self::LoadedFunction(lf1), Self::LoadedFunction(lf2)) + if lf1.func_node == lf2.func_node => + { + // TODO we should also join the TypeArgs but at the moment these are ignored + (Self::LoadedFunction(lf1), false) + } (Self::PartialSum(mut ps1), Self::PartialSum(ps2)) => match ps1.try_join_mut(ps2) { Ok(ch) => (Self::PartialSum(ps1), ch), Err(_) => (Self::Top, true), }, - (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => { - (Self::Top, true) - } + _ => (Self::Top, true), }; *self = res; ch @@ -423,20 +473,24 @@ impl Lattice for PartialValue { Some((h3, ch)) => (Self::Value(h3), ch), None => (Self::Bottom, true), }, + (Self::LoadedFunction(lf1), Self::LoadedFunction(lf2)) + if lf1.func_node == lf2.func_node => + { + // TODO we should also meet the TypeArgs but at the moment these are ignored + (Self::LoadedFunction(lf1), false) + } (Self::PartialSum(mut ps1), Self::PartialSum(ps2)) => match ps1.try_meet_mut(ps2) { Ok(ch) => (Self::PartialSum(ps1), ch), Err(_) => (Self::Bottom, true), }, - (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => { - (Self::Bottom, true) - } + _ => (Self::Bottom, true), }; *self = res; ch } } -impl BoundedLattice for PartialValue { +impl BoundedLattice for PartialValue { fn top() -> Self { Self::Top } @@ -446,7 +500,7 @@ impl BoundedLattice for PartialValue { } } -impl PartialOrd for PartialValue { +impl PartialOrd for PartialValue { fn partial_cmp(&self, other: &Self) -> Option { use std::cmp::Ordering; match (self, other) { @@ -457,6 +511,9 @@ impl PartialOrd for PartialValue { (Self::Top, _) => Some(Ordering::Greater), (_, Self::Top) => Some(Ordering::Less), (Self::Value(v1), Self::Value(v2)) => (v1 == v2).then_some(Ordering::Equal), + (Self::LoadedFunction(lf1), Self::LoadedFunction(lf2)) => { + (lf1 == lf2).then_some(Ordering::Equal) + } (Self::PartialSum(ps1), Self::PartialSum(ps2)) => ps1.partial_cmp(ps2), _ => None, } @@ -468,19 +525,20 @@ mod test { use std::sync::Arc; use ascent::{lattice::BoundedLattice, Lattice}; + use hugr_core::NodeIndex; use itertools::{zip_eq, Itertools as _}; use prop::sample::subsequence; use proptest::prelude::*; use proptest_recurse::{StrategyExt, StrategySet}; - use super::{AbstractValue, PartialSum, PartialValue}; + use super::{AbstractValue, LoadedFunction, PartialSum, PartialValue}; #[derive(Debug, PartialEq, Eq, Clone)] enum TestSumType { Branch(Vec>>), - /// None => unit, Some => TestValue <= this *usize* - Leaf(Option), + LeafVal(usize), // contains a TestValue <= this usize + LeafPtr(usize), // contains a LoadedFunction with node <= this *usize* } #[derive(Clone, Debug, PartialEq, Eq, Hash)] @@ -509,8 +567,11 @@ mod test { fn check_value(&self, pv: &PartialValue) -> bool { match (self, pv) { (_, PartialValue::Bottom) | (_, PartialValue::Top) => true, - (Self::Leaf(None), _) => pv == &PartialValue::new_unit(), - (Self::Leaf(Some(max)), PartialValue::Value(TestValue(val))) => val <= max, + (Self::LeafVal(max), PartialValue::Value(TestValue(val))) => val <= max, + ( + Self::LeafPtr(max), + PartialValue::LoadedFunction(LoadedFunction { func_node, args }), + ) => args.is_empty() && func_node.index() <= *max, (Self::Branch(sop), PartialValue::PartialSum(ps)) => { for (k, v) in &ps.0 { if *k >= sop.len() { @@ -537,8 +598,11 @@ mod test { fn arbitrary_with(params: Self::Parameters) -> Self::Strategy { fn arb(params: SumTypeParams, set: &mut StrategySet) -> SBoxedStrategy { use proptest::collection::vec; - let int_strat = (0..usize::MAX).prop_map(|i| TestSumType::Leaf(Some(i))); - let leaf_strat = prop_oneof![Just(TestSumType::Leaf(None)), int_strat]; + let leaf_strat = prop_oneof![ + (0..usize::MAX).prop_map(TestSumType::LeafVal), + // This is the maximum value accepted by portgraph::NodeIndex::new + (0..((2usize ^ 31) - 2)).prop_map(TestSumType::LeafPtr) + ]; leaf_strat.prop_mutually_recursive( params.depth as u32, params.desired_size as u32, @@ -605,11 +669,18 @@ mod test { ust: &TestSumType, ) -> impl Strategy> { match ust { - TestSumType::Leaf(None) => Just(PartialValue::new_unit()).boxed(), - TestSumType::Leaf(Some(i)) => (0..*i) + TestSumType::LeafVal(i) => (0..=*i) .prop_map(TestValue) .prop_map(PartialValue::from) .boxed(), + TestSumType::LeafPtr(i) => (0..=*i) + .prop_map(|i| { + PartialValue::LoadedFunction(LoadedFunction { + func_node: portgraph::NodeIndex::new(i).into(), + args: vec![], + }) + }) + .boxed(), TestSumType::Branch(sop) => partial_sum_strat(sop).prop_map(PartialValue::from).boxed(), } } diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index c40f1d87f..c4a94a9e7 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -1,17 +1,19 @@ use std::collections::HashMap; -use hugr_core::{HugrView, IncomingPort, PortIndex, Wire}; +use hugr_core::{HugrView, PortIndex, Wire}; -use super::{partial_value::ExtractValueError, AbstractValue, PartialValue, Sum}; +use super::{ + datalog::InWire, partial_value::ExtractValueError, AbstractValue, AsConcrete, PartialValue, +}; /// Results of a dataflow analysis, packaged with the Hugr for easy inspection. /// Methods allow inspection, specifically [read_out_wire](Self::read_out_wire). pub struct AnalysisResults { pub(super) hugr: H, - pub(super) in_wire_value: Vec<(H::Node, IncomingPort, PartialValue)>, + pub(super) in_wire_value: Vec>, pub(super) case_reachable: Vec<(H::Node, H::Node)>, pub(super) bb_reachable: Vec<(H::Node, H::Node)>, - pub(super) out_wire_values: HashMap, PartialValue>, + pub(super) out_wire_values: HashMap, PartialValue>, } impl AnalysisResults { @@ -21,7 +23,7 @@ impl AnalysisResults { } /// Gets the lattice value computed for the given wire - pub fn read_out_wire(&self, w: Wire) -> Option> { + pub fn read_out_wire(&self, w: Wire) -> Option> { self.out_wire_values.get(&w).cloned() } @@ -84,13 +86,11 @@ impl AnalysisResults { /// `None` if the analysis did not produce a result for that wire, or if /// the Hugr did not have a [Type](hugr_core::types::Type) for the specified wire /// `Some(e)` if [conversion to a concrete value](PartialValue::try_into_concrete) failed with error `e` - pub fn try_read_wire_concrete( + #[allow(clippy::type_complexity)] + pub fn try_read_wire_concrete>( &self, w: Wire, - ) -> Result>> - where - V2: TryFrom + TryFrom, Error = SE>, - { + ) -> Result>> { let v = self.read_out_wire(w).ok_or(None)?; let (_, typ) = self .hugr @@ -116,7 +116,7 @@ pub enum TailLoopTermination { } impl TailLoopTermination { - fn from_control_value(v: &PartialValue) -> Self { + fn from_control_value(v: &PartialValue) -> Self { let (may_continue, may_break) = (v.supports_tag(0), v.supports_tag(1)); if may_break { if may_continue { diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 3af0097f7..1c4b4e439 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -1,10 +1,12 @@ +use std::convert::Infallible; + use ascent::{lattice::BoundedLattice, Lattice}; -use hugr_core::builder::{CFGBuilder, Container, DataflowHugr, ModuleBuilder}; +use hugr_core::builder::{inout_sig, CFGBuilder, Container, DataflowHugr, ModuleBuilder}; use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; use hugr_core::ops::handle::DfgID; -use hugr_core::ops::TailLoop; -use hugr_core::types::TypeRow; +use hugr_core::ops::{CallIndirect, TailLoop}; +use hugr_core::types::{ConstTypeError, TypeRow}; use hugr_core::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{ @@ -19,7 +21,10 @@ use hugr_core::{ use hugr_core::{Hugr, Node, Wire}; use rstest::{fixture, rstest}; -use super::{AbstractValue, ConstLoader, DFContext, Machine, PartialValue, TailLoopTermination}; +use super::{ + AbstractValue, AsConcrete, ConstLoader, DFContext, LoadedFunction, Machine, PartialValue, Sum, + TailLoopTermination, +}; // ------- Minimal implementation of DFContext and AbstractValue ------- #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -35,10 +40,22 @@ impl ConstLoader for TestContext { impl DFContext for TestContext {} // This allows testing creation of tuple/sum Values (only) -impl From for Value { - fn from(v: Void) -> Self { +impl AsConcrete for Value { + type ValErr = Infallible; + + type SumErr = ConstTypeError; + + fn from_value(v: Void) -> Result { match v {} } + + fn from_sum(value: Sum) -> Result { + Self::sum(value.tag, value.values, value.st) + } + + fn from_func(func: LoadedFunction) -> Result> { + Err(func) + } } fn pv_false() -> PartialValue { @@ -295,9 +312,7 @@ fn test_conditional() { let cond_r1: Value = results.try_read_wire_concrete(cond_o1).unwrap(); assert_eq!(cond_r1, Value::false_val()); - assert!(results - .try_read_wire_concrete::(cond_o2) - .is_err()); + assert!(results.try_read_wire_concrete::(cond_o2).is_err()); assert_eq!(results.case_reachable(case1.node()), Some(false)); // arg_pv is variant 1 or 2 only assert_eq!(results.case_reachable(case2.node()), Some(true)); @@ -547,3 +562,78 @@ fn test_module() { ); } } + +#[rstest] +#[case(pv_false(), pv_false())] +#[case(pv_false(), pv_true())] +#[case(pv_true(), pv_false())] +#[case(pv_true(), pv_true())] +fn call_indirect(#[case] inp1: PartialValue, #[case] inp2: PartialValue) { + let b2b = || Signature::new_endo(bool_t()); + let mut dfb = DFGBuilder::new(inout_sig(vec![bool_t(); 3], vec![bool_t(); 2])).unwrap(); + + let [id1, id2] = ["id1", "[id2]"].map(|name| { + let fb = dfb.define_function(name, b2b()).unwrap(); + let [inp] = fb.input_wires_arr(); + fb.finish_with_outputs([inp]).unwrap() + }); + + let [inp_direct, which, inp_indirect] = dfb.input_wires_arr(); + let [res1] = dfb + .call(id1.handle(), &[], [inp_direct]) + .unwrap() + .outputs_arr(); + + // We'll unconditionally load both functions, to demonstrate that it's + // the CallIndirect that matters, not just which functions are loaded. + let lf1 = dfb.load_func(id1.handle(), &[]).unwrap(); + let lf2 = dfb.load_func(id2.handle(), &[]).unwrap(); + let bool_func = || Type::new_function(b2b()); + let mut cond = dfb + .conditional_builder( + (vec![type_row![]; 2], which), + [(bool_func(), lf1), (bool_func(), lf2)], + bool_func().into(), + ) + .unwrap(); + let case_false = cond.case_builder(0).unwrap(); + let [f0, _f1] = case_false.input_wires_arr(); + case_false.finish_with_outputs([f0]).unwrap(); + let case_true = cond.case_builder(1).unwrap(); + let [_f0, f1] = case_true.input_wires_arr(); + case_true.finish_with_outputs([f1]).unwrap(); + let [tgt] = cond.finish_sub_container().unwrap().outputs_arr(); + let [res2] = dfb + .add_dataflow_op(CallIndirect { signature: b2b() }, [tgt, inp_indirect]) + .unwrap() + .outputs_arr(); + let h = dfb.finish_hugr_with_outputs([res1, res2]).unwrap(); + + let run = |which| { + Machine::new(&h).run( + TestContext, + [ + (0.into(), inp1.clone()), + (1.into(), which), + (2.into(), inp2.clone()), + ], + ) + }; + let (w1, w2) = (Wire::new(h.root(), 0), Wire::new(h.root(), 1)); + + // 1. Test with `which` unknown -> second output unknown + let results = run(PartialValue::Top); + assert_eq!(results.read_out_wire(w1), Some(inp1.clone())); + assert_eq!(results.read_out_wire(w2), Some(PartialValue::Top)); + + // 2. Test with `which` selecting second function -> both passthrough + let results = run(pv_true()); + assert_eq!(results.read_out_wire(w1), Some(inp1.clone())); + assert_eq!(results.read_out_wire(w2), Some(inp2.clone())); + + //3. Test with `which` selecting first function -> alias + let results = run(pv_false()); + let out = Some(inp1.join(inp2)); + assert_eq!(results.read_out_wire(w1), out); + assert_eq!(results.read_out_wire(w2), out); +} diff --git a/hugr-passes/src/dataflow/value_row.rs b/hugr-passes/src/dataflow/value_row.rs index 50cf10318..43c842d91 100644 --- a/hugr-passes/src/dataflow/value_row.rs +++ b/hugr-passes/src/dataflow/value_row.rs @@ -5,25 +5,25 @@ use std::{ ops::{Index, IndexMut}, }; -use ascent::{lattice::BoundedLattice, Lattice}; +use ascent::Lattice; use itertools::zip_eq; use super::{AbstractValue, PartialValue}; #[derive(PartialEq, Clone, Debug, Eq, Hash)] -pub(super) struct ValueRow(Vec>); +pub(super) struct ValueRow(Vec>); -impl ValueRow { +impl ValueRow { pub fn new(len: usize) -> Self { - Self(vec![PartialValue::bottom(); len]) + Self(vec![PartialValue::Bottom; len]) } - pub fn set(mut self, idx: usize, v: PartialValue) -> Self { + pub fn set(mut self, idx: usize, v: PartialValue) -> Self { *self.0.get_mut(idx).unwrap() = v; self } - pub fn singleton(v: PartialValue) -> Self { + pub fn singleton(v: PartialValue) -> Self { Self(vec![v]) } @@ -34,25 +34,25 @@ impl ValueRow { &self, variant: usize, len: usize, - ) -> Option>> { + ) -> Option>> { let vals = self[0].variant_values(variant, len)?; Some(vals.into_iter().chain(self.0[1..].to_owned())) } } -impl FromIterator> for ValueRow { - fn from_iter>>(iter: T) -> Self { +impl FromIterator> for ValueRow { + fn from_iter>>(iter: T) -> Self { Self(iter.into_iter().collect()) } } -impl PartialOrd for ValueRow { +impl PartialOrd for ValueRow { fn partial_cmp(&self, other: &Self) -> Option { self.0.partial_cmp(&other.0) } } -impl Lattice for ValueRow { +impl Lattice for ValueRow { fn join_mut(&mut self, other: Self) -> bool { assert_eq!(self.0.len(), other.0.len()); let mut changed = false; @@ -72,30 +72,30 @@ impl Lattice for ValueRow { } } -impl IntoIterator for ValueRow { - type Item = PartialValue; +impl IntoIterator for ValueRow { + type Item = PartialValue; - type IntoIter = > as IntoIterator>::IntoIter; + type IntoIter = > as IntoIterator>::IntoIter; fn into_iter(self) -> Self::IntoIter { self.0.into_iter() } } -impl Index for ValueRow +impl Index for ValueRow where - Vec>: Index, + Vec>: Index, { - type Output = > as Index>::Output; + type Output = > as Index>::Output; fn index(&self, index: Idx) -> &Self::Output { self.0.index(index) } } -impl IndexMut for ValueRow +impl IndexMut for ValueRow where - Vec>: IndexMut, + Vec>: IndexMut, { fn index_mut(&mut self, index: Idx) -> &mut Self::Output { self.0.index_mut(index) From 63477565de0dbfb8027736cf905f6f148e2ddcab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Wed, 16 Apr 2025 14:52:21 +0100 Subject: [PATCH 06/18] feat: Make NodeHandle generic (#2092) Adds a generic node type to the `NodeHandle` type. This is a required change for #2029. drive-by: Implement the "Link the NodeHandles to the OpType" TODO --- hugr-core/src/ops.rs | 16 +++++++- hugr-core/src/ops/handle.rs | 73 ++++++++++++++++++++----------------- 2 files changed, 54 insertions(+), 35 deletions(-) diff --git a/hugr-core/src/ops.rs b/hugr-core/src/ops.rs index 0c7d3bb3f..ce0d44de0 100644 --- a/hugr-core/src/ops.rs +++ b/hugr-core/src/ops.rs @@ -9,6 +9,7 @@ pub mod module; pub mod sum; pub mod tag; pub mod validate; +use crate::core::HugrNode; use crate::extension::resolution::{ collect_op_extension, collect_op_types_extensions, ExtensionCollectionError, }; @@ -20,6 +21,7 @@ use crate::types::{EdgeKind, Signature, Substitution}; use crate::{Direction, OutgoingPort, Port}; use crate::{IncomingPort, PortIndex}; use derive_more::Display; +use handle::NodeHandle; use paste::paste; use portgraph::NodeIndex; @@ -41,7 +43,6 @@ pub use tag::OpTag; #[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)] #[cfg_attr(test, derive(proptest_derive::Arbitrary))] /// The concrete operation types for a node in the HUGR. -// TODO: Link the NodeHandles to the OpType. #[non_exhaustive] #[allow(missing_docs)] #[serde(tag = "op")] @@ -377,6 +378,19 @@ pub trait OpTrait: Sized + Clone { /// Tag identifying the operation. fn tag(&self) -> OpTag; + /// Tries to create a specific [`NodeHandle`] for a node with this operation + /// type. + /// + /// Fails if the operation's [`OpTrait::tag`] does not match the + /// [`NodeHandle::TAG`] of the requested handle. + fn try_node_handle(&self, node: N) -> Option + where + N: HugrNode, + H: NodeHandle + From, + { + H::TAG.is_superset(self.tag()).then(|| node.into()) + } + /// The signature of the operation. /// /// Only dataflow operations have a signature, otherwise returns None. diff --git a/hugr-core/src/ops/handle.rs b/hugr-core/src/ops/handle.rs index d7fe16419..a5a3c294a 100644 --- a/hugr-core/src/ops/handle.rs +++ b/hugr-core/src/ops/handle.rs @@ -1,4 +1,5 @@ //! Handles to nodes in HUGR. +use crate::core::HugrNode; use crate::types::{Type, TypeBound}; use crate::Node; @@ -9,12 +10,12 @@ use super::{AliasDecl, OpTag}; /// Common trait for handles to a node. /// Typically wrappers around [`Node`]. -pub trait NodeHandle: Clone { +pub trait NodeHandle: Clone { /// The most specific operation tag associated with the handle. const TAG: OpTag; /// Index of underlying node. - fn node(&self) -> Node; + fn node(&self) -> N; /// Operation tag for the handle. #[inline] @@ -23,7 +24,7 @@ pub trait NodeHandle: Clone { } /// Cast the handle to a different more general tag. - fn try_cast>(&self) -> Option { + fn try_cast + From>(&self) -> Option { T::TAG.is_superset(Self::TAG).then(|| self.node().into()) } @@ -36,30 +37,30 @@ pub trait NodeHandle: Clone { /// Trait for handles that contain children. /// /// The allowed children handles are defined by the associated type. -pub trait ContainerHandle: NodeHandle { +pub trait ContainerHandle: NodeHandle { /// Handle type for the children of this node. - type ChildrenHandle: NodeHandle; + type ChildrenHandle: NodeHandle; } #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [DataflowOp](crate::ops::dataflow). -pub struct DataflowOpID(Node); +pub struct DataflowOpID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [DFG](crate::ops::DFG) node. -pub struct DfgID(Node); +pub struct DfgID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [CFG](crate::ops::CFG) node. -pub struct CfgID(Node); +pub struct CfgID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a module [Module](crate::ops::Module) node. -pub struct ModuleRootID(Node); +pub struct ModuleRootID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [module op](crate::ops::module) node. -pub struct ModuleID(Node); +pub struct ModuleID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [def](crate::ops::OpType::FuncDefn) @@ -67,7 +68,7 @@ pub struct ModuleID(Node); /// /// The `DEF` const generic is used to indicate whether the function is /// defined or just declared. -pub struct FuncID(Node); +pub struct FuncID(N); #[derive(Debug, Clone, PartialEq, Eq)] /// Handle to an [AliasDefn](crate::ops::OpType::AliasDefn) @@ -75,15 +76,15 @@ pub struct FuncID(Node); /// /// The `DEF` const generic is used to indicate whether the function is /// defined or just declared. -pub struct AliasID { - node: Node, +pub struct AliasID { + node: N, name: SmolStr, bound: TypeBound, } -impl AliasID { +impl AliasID { /// Construct new AliasID - pub fn new(node: Node, name: SmolStr, bound: TypeBound) -> Self { + pub fn new(node: N, name: SmolStr, bound: TypeBound) -> Self { Self { node, name, bound } } @@ -99,27 +100,27 @@ impl AliasID { #[derive(DerFrom, Debug, Clone, PartialEq, Eq)] /// Handle to a [Const](crate::ops::OpType::Const) node. -pub struct ConstID(Node); +pub struct ConstID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [DataflowBlock](crate::ops::DataflowBlock) or [Exit](crate::ops::ExitBlock) node. -pub struct BasicBlockID(Node); +pub struct BasicBlockID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [Case](crate::ops::Case) node. -pub struct CaseID(Node); +pub struct CaseID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [TailLoop](crate::ops::TailLoop) node. -pub struct TailLoopID(Node); +pub struct TailLoopID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [Conditional](crate::ops::Conditional) node. -pub struct ConditionalID(Node); +pub struct ConditionalID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a dataflow container node. -pub struct DataflowParentID(Node); +pub struct DataflowParentID(N); /// Implements the `NodeHandle` trait for a tuple struct that contains just a /// NodeIndex. Takes the name of the struct, and the corresponding OpTag. @@ -131,11 +132,11 @@ macro_rules! impl_nodehandle { impl_nodehandle!($name, $tag, 0); }; ($name:ident, $tag:expr, $node_attr:tt) => { - impl NodeHandle for $name { + impl NodeHandle for $name { const TAG: OpTag = $tag; #[inline] - fn node(&self) -> Node { + fn node(&self) -> N { self.$node_attr } } @@ -156,35 +157,35 @@ impl_nodehandle!(ConstID, OpTag::Const); impl_nodehandle!(BasicBlockID, OpTag::DataflowBlock); -impl NodeHandle for FuncID { +impl NodeHandle for FuncID { const TAG: OpTag = OpTag::Function; #[inline] - fn node(&self) -> Node { + fn node(&self) -> N { self.0 } } -impl NodeHandle for AliasID { +impl NodeHandle for AliasID { const TAG: OpTag = OpTag::Alias; #[inline] - fn node(&self) -> Node { + fn node(&self) -> N { self.node } } -impl NodeHandle for Node { +impl NodeHandle for N { const TAG: OpTag = OpTag::Any; #[inline] - fn node(&self) -> Node { + fn node(&self) -> N { *self } } /// Implements the `ContainerHandle` trait, with the given child handle type. macro_rules! impl_containerHandle { - ($name:path, $children:ident) => { - impl ContainerHandle for $name { - type ChildrenHandle = $children; + ($name:ident, $children:ident) => { + impl ContainerHandle for $name { + type ChildrenHandle = $children; } }; } @@ -197,5 +198,9 @@ impl_containerHandle!(CaseID, DataflowOpID); impl_containerHandle!(ModuleRootID, ModuleID); impl_containerHandle!(CfgID, BasicBlockID); impl_containerHandle!(BasicBlockID, DataflowOpID); -impl_containerHandle!(FuncID, DataflowOpID); -impl_containerHandle!(AliasID, DataflowOpID); +impl ContainerHandle for FuncID { + type ChildrenHandle = DataflowOpID; +} +impl ContainerHandle for AliasID { + type ChildrenHandle = DataflowOpID; +} From 5b43c0d351720a2d1ba66053467d64573bbbb9c6 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 17 Apr 2025 10:38:12 +0100 Subject: [PATCH 07/18] feat!: remove ExtensionValue (#2093) Closes #1595 BREAKING CHANGE: `values` field in `Extension` and `ExtensionValue` struct/class removed in rust and python. Use 0-input ops that return constant values. --- hugr-core/src/extension.rs | 64 +------------------ .../src/extension/resolution/extension.rs | 11 +--- hugr-core/src/hugr/validate/test.rs | 8 +-- hugr-core/src/std_extensions/logic.rs | 26 +------- hugr-py/src/hugr/_serialization/extension.py | 21 ------ hugr-py/src/hugr/ext.py | 42 +----------- .../_json_defs/arithmetic/conversions.json | 1 - .../hugr/std/_json_defs/arithmetic/float.json | 1 - .../_json_defs/arithmetic/float/types.json | 1 - .../hugr/std/_json_defs/arithmetic/int.json | 1 - .../std/_json_defs/arithmetic/int/types.json | 1 - .../std/_json_defs/collections/array.json | 1 - .../hugr/std/_json_defs/collections/list.json | 1 - .../_json_defs/collections/static_array.json | 1 - hugr-py/src/hugr/std/_json_defs/logic.json | 28 -------- hugr-py/src/hugr/std/_json_defs/prelude.json | 1 - hugr-py/src/hugr/std/_json_defs/ptr.json | 1 - specification/schema/hugr_schema_live.json | 30 --------- .../schema/hugr_schema_strict_live.json | 30 --------- .../schema/testing_hugr_schema_live.json | 30 --------- .../testing_hugr_schema_strict_live.json | 30 --------- .../arithmetic/conversions.json | 1 - .../std_extensions/arithmetic/float.json | 1 - .../arithmetic/float/types.json | 1 - .../std_extensions/arithmetic/int.json | 1 - .../std_extensions/arithmetic/int/types.json | 1 - .../std_extensions/collections/array.json | 1 - .../std_extensions/collections/list.json | 1 - .../collections/static_array.json | 1 - specification/std_extensions/logic.json | 28 -------- specification/std_extensions/prelude.json | 1 - specification/std_extensions/ptr.json | 1 - 32 files changed, 7 insertions(+), 361 deletions(-) diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index b6e059050..23238ccfd 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -19,9 +19,8 @@ use derive_more::Display; use thiserror::Error; use crate::hugr::IdentList; -use crate::ops::constant::{ValueName, ValueNameRef}; use crate::ops::custom::{ExtensionOp, OpaqueOp}; -use crate::ops::{self, OpName, OpNameRef}; +use crate::ops::{OpName, OpNameRef}; use crate::types::type_param::{TypeArg, TypeArgError, TypeParam}; use crate::types::RowVariable; use crate::types::{check_typevar_decl, CustomType, Substitution, TypeBound, TypeName}; @@ -497,37 +496,6 @@ impl CustomConcrete for CustomType { } } -/// A constant value provided by a extension. -/// Must be an instance of a type available to the extension. -#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] -pub struct ExtensionValue { - extension: ExtensionId, - name: ValueName, - typed_value: ops::Value, -} - -impl ExtensionValue { - /// Returns a reference to the typed value of this [`ExtensionValue`]. - pub fn typed_value(&self) -> &ops::Value { - &self.typed_value - } - - /// Returns a mutable reference to the typed value of this [`ExtensionValue`]. - pub(super) fn typed_value_mut(&mut self) -> &mut ops::Value { - &mut self.typed_value - } - - /// Returns a reference to the name of this [`ExtensionValue`]. - pub fn name(&self) -> &str { - self.name.as_str() - } - - /// Returns a reference to the extension this [`ExtensionValue`] belongs to. - pub fn extension(&self) -> &ExtensionId { - &self.extension - } -} - /// A unique identifier for a extension. /// /// The actual [`Extension`] is stored externally. @@ -583,8 +551,6 @@ pub struct Extension { pub runtime_reqs: ExtensionSet, /// Types defined by this extension. types: BTreeMap, - /// Static values defined by this extension. - values: BTreeMap, /// Operation declarations with serializable definitions. // Note: serde will serialize this because we configure with `features=["rc"]`. // That will clone anything that has multiple references, but each @@ -608,7 +574,6 @@ impl Extension { version, runtime_reqs: Default::default(), types: Default::default(), - values: Default::default(), operations: Default::default(), } } @@ -680,11 +645,6 @@ impl Extension { self.types.get(type_name) } - /// Allows read-only access to the values in this Extension - pub fn get_value(&self, value_name: &ValueNameRef) -> Option<&ExtensionValue> { - self.values.get(value_name) - } - /// Returns the name of the extension. pub fn name(&self) -> &ExtensionId { &self.name @@ -705,25 +665,6 @@ impl Extension { self.types.iter() } - /// Add a named static value to the extension. - pub fn add_value( - &mut self, - name: impl Into, - typed_value: ops::Value, - ) -> Result<&mut ExtensionValue, ExtensionBuildError> { - let extension_value = ExtensionValue { - extension: self.name.clone(), - name: name.into(), - typed_value, - }; - match self.values.entry(extension_value.name.clone()) { - btree_map::Entry::Occupied(_) => { - Err(ExtensionBuildError::ValueExists(extension_value.name)) - } - btree_map::Entry::Vacant(ve) => Ok(ve.insert(extension_value)), - } - } - /// Instantiate an [`ExtensionOp`] which references an [`OpDef`] in this extension. pub fn instantiate_extension_op( &self, @@ -784,9 +725,6 @@ pub enum ExtensionBuildError { /// Existing [`TypeDef`] #[error("Extension already has an type called {0}.")] TypeDefExists(TypeName), - /// Existing [`ExtensionValue`] - #[error("Extension already has an extension value called {0}.")] - ValueExists(ValueName), } /// A set of extensions identified by their unique [`ExtensionId`]. diff --git a/hugr-core/src/extension/resolution/extension.rs b/hugr-core/src/extension/resolution/extension.rs index 61adc1dea..05c0faf69 100644 --- a/hugr-core/src/extension/resolution/extension.rs +++ b/hugr-core/src/extension/resolution/extension.rs @@ -9,7 +9,7 @@ use std::sync::Arc; use crate::extension::{Extension, ExtensionId, ExtensionRegistry, OpDef, SignatureFunc, TypeDef}; -use super::types_mut::{resolve_signature_exts, resolve_value_exts}; +use super::types_mut::resolve_signature_exts; use super::{ExtensionResolutionError, WeakExtensionRegistry}; impl ExtensionRegistry { @@ -59,14 +59,7 @@ impl Extension { for type_def in self.types.values_mut() { resolve_typedef_exts(&self.name, type_def, extensions, &mut used_extensions)?; } - for val in self.values.values_mut() { - resolve_value_exts( - None, - val.typed_value_mut(), - extensions, - &mut used_extensions, - )?; - } + let ops = mem::take(&mut self.operations); for (op_id, mut op_def) in ops { // TODO: We should be able to clone the definition if needed by using `make_mut`, diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index ecb417ec5..37157020d 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -20,7 +20,6 @@ use crate::ops::handle::NodeHandle; use crate::ops::{self, OpType, Value}; use crate::std_extensions::logic::test::{and_op, or_op}; use crate::std_extensions::logic::LogicOp; -use crate::std_extensions::logic::{self}; use crate::types::type_param::{TypeArg, TypeArgError}; use crate::types::{ CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, Type, TypeBound, TypeRV, @@ -307,12 +306,7 @@ fn test_local_const() { port_kind: EdgeKind::Value(bool_t()) }) ); - let const_op: ops::Const = logic::EXTENSION - .get_value(&logic::TRUE_NAME) - .unwrap() - .typed_value() - .clone() - .into(); + let const_op: ops::Const = ops::Value::from_bool(true).into(); // Second input of Xor from a constant let cst = h.add_node_with_parent(h.root(), const_op); let lcst = h.add_node_with_parent(h.root(), ops::LoadConstant { datatype: bool_t() }); diff --git a/hugr-core/src/std_extensions/logic.rs b/hugr-core/src/std_extensions/logic.rs index fcc8be9d3..20977cb51 100644 --- a/hugr-core/src/std_extensions/logic.rs +++ b/hugr-core/src/std_extensions/logic.rs @@ -124,13 +124,6 @@ pub const VERSION: semver::Version = semver::Version::new(0, 1, 0); fn extension() -> Arc { Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { LogicOp::load_all_ops(extension, extension_ref).unwrap(); - - extension - .add_value(FALSE_NAME, ops::Value::false_val()) - .unwrap(); - extension - .add_value(TRUE_NAME, ops::Value::true_val()) - .unwrap(); }) } @@ -172,12 +165,9 @@ fn read_inputs(consts: &[(IncomingPort, ops::Value)]) -> Option> { pub(crate) mod test { use std::sync::Arc; - use super::{extension, LogicOp, FALSE_NAME, TRUE_NAME}; + use super::{extension, LogicOp}; use crate::{ - extension::{ - prelude::bool_t, - simple_op::{MakeOpDef, MakeRegisteredOp}, - }, + extension::simple_op::{MakeOpDef, MakeRegisteredOp}, ops::{NamedOp, Value}, Extension, }; @@ -207,18 +197,6 @@ pub(crate) mod test { } } - #[test] - fn test_values() { - let r: Arc = extension(); - let false_val = r.get_value(&FALSE_NAME).unwrap(); - let true_val = r.get_value(&TRUE_NAME).unwrap(); - - for v in [false_val, true_val] { - let simpl = v.typed_value().get_type(); - assert_eq!(simpl, bool_t()); - } - } - /// Generate a logic extension "and" operation over [`crate::prelude::bool_t()`] pub(crate) fn and_op() -> LogicOp { LogicOp::And diff --git a/hugr-py/src/hugr/_serialization/extension.py b/hugr-py/src/hugr/_serialization/extension.py index 429bdd785..95e59754e 100644 --- a/hugr-py/src/hugr/_serialization/extension.py +++ b/hugr-py/src/hugr/_serialization/extension.py @@ -8,7 +8,6 @@ from hugr.hugr.base import Hugr from hugr.utils import deser_it -from .ops import Value from .serial_hugr import SerialHugr, serialization_version from .tys import ( ConfiguredBaseModel, @@ -20,7 +19,6 @@ ) if TYPE_CHECKING: - from .ops import Value from .serial_hugr import SerialHugr @@ -62,20 +60,6 @@ def deserialize(self, extension: ext.Extension) -> ext.TypeDef: ) -class ExtensionValue(ConfiguredBaseModel): - extension: ExtensionId - name: str - typed_value: Value - - def deserialize(self, extension: ext.Extension) -> ext.ExtensionValue: - return extension.add_extension_value( - ext.ExtensionValue( - name=self.name, - val=self.typed_value.deserialize(), - ) - ) - - # -------------------------------------- # --------------- OpDef ---------------- # -------------------------------------- @@ -124,7 +108,6 @@ class Extension(ConfiguredBaseModel): name: ExtensionId runtime_reqs: set[ExtensionId] types: dict[str, TypeDef] - values: dict[str, ExtensionValue] operations: dict[str, OpDef] @classmethod @@ -146,10 +129,6 @@ def deserialize(self) -> ext.Extension: assert k == o.name, "Operation name must match key" e.add_op_def(o.deserialize(e)) - for k, v in self.values.items(): - assert k == v.name, "Value name must match key" - e.add_extension_value(v.deserialize(e)) - return e diff --git a/hugr-py/src/hugr/ext.py b/hugr-py/src/hugr/ext.py index 494ea3c69..7bd02f982 100644 --- a/hugr-py/src/hugr/ext.py +++ b/hugr-py/src/hugr/ext.py @@ -8,7 +8,7 @@ from semver import Version import hugr._serialization.extension as ext_s -from hugr import ops, tys, val +from hugr import ops, tys from hugr.utils import ser_it __all__ = [ @@ -18,7 +18,6 @@ "FixedHugr", "OpDefSig", "OpDef", - "ExtensionValue", "Extension", "Version", ] @@ -246,23 +245,6 @@ def instantiate( return ops.ExtOp(self, concrete_signature, list(args or [])) -@dataclass -class ExtensionValue(ExtensionObject): - """A value defined in an :class:`Extension`.""" - - #: The name of the value. - name: str - #: Value payload. - val: val.Value - - def _to_serial(self) -> ext_s.ExtensionValue: - return ext_s.ExtensionValue( - extension=self.get_extension().name, - name=self.name, - typed_value=self.val._to_serial_root(), - ) - - T = TypeVar("T", bound=ops.RegisteredOp) @@ -278,8 +260,6 @@ class Extension: runtime_reqs: set[ExtensionId] = field(default_factory=set) #: Type definitions in the extension. types: dict[str, TypeDef] = field(default_factory=dict) - #: Values defined in the extension. - values: dict[str, ExtensionValue] = field(default_factory=dict) #: Operation definitions in the extension. operations: dict[str, OpDef] = field(default_factory=dict) @@ -295,7 +275,6 @@ def _to_serial(self) -> ext_s.Extension: version=self.version, # type: ignore[arg-type] runtime_reqs=self.runtime_reqs, types={k: v._to_serial() for k, v in self.types.items()}, - values={k: v._to_serial() for k, v in self.values.items()}, operations={k: v._to_serial() for k, v in self.operations.items()}, ) @@ -347,19 +326,6 @@ def add_type_def(self, type_def: TypeDef) -> TypeDef: self.types[type_def.name] = type_def return self.types[type_def.name] - def add_extension_value(self, extension_value: ExtensionValue) -> ExtensionValue: - """Add a value to the extension. - - Args: - extension_value: The value to add. - - Returns: - The added value, now associated with the extension. - """ - extension_value._extension = self - self.values[extension_value.name] = extension_value - return self.values[extension_value.name] - @dataclass class OperationNotFound(NotFound): """Operation not found in extension.""" @@ -406,12 +372,6 @@ def get_type(self, name: str) -> TypeDef: class ValueNotFound(NotFound): """Value not found in extension.""" - def get_value(self, name: str) -> ExtensionValue: - try: - return self.values[name] - except KeyError as e: - raise self.ValueNotFound(name) from e - T = TypeVar("T", bound=ops.RegisteredOp) def register_op( diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/conversions.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/conversions.json index 9c0054354..1d310df25 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/conversions.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/conversions.json @@ -6,7 +6,6 @@ "arithmetic.int.types" ], "types": {}, - "values": {}, "operations": { "bytecast_float64_to_int64": { "extension": "arithmetic.conversions", diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/float.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/float.json index 31ccaaa59..8da056772 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/float.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/float.json @@ -5,7 +5,6 @@ "arithmetic.int.types" ], "types": {}, - "values": {}, "operations": { "fabs": { "extension": "arithmetic.float", diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/float/types.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/float/types.json index 56e35c50b..0c563c474 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/float/types.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/float/types.json @@ -14,6 +14,5 @@ } } }, - "values": {}, "operations": {} } diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/int.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/int.json index 62d0a6663..5b1a81250 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/int.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/int.json @@ -5,7 +5,6 @@ "arithmetic.int.types" ], "types": {}, - "values": {}, "operations": { "iabs": { "extension": "arithmetic.int", diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/int/types.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/int/types.json index 60cf69f63..36df125a6 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/int/types.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/int/types.json @@ -19,6 +19,5 @@ } } }, - "values": {}, "operations": {} } diff --git a/hugr-py/src/hugr/std/_json_defs/collections/array.json b/hugr-py/src/hugr/std/_json_defs/collections/array.json index 21e405151..375e13c72 100644 --- a/hugr-py/src/hugr/std/_json_defs/collections/array.json +++ b/hugr-py/src/hugr/std/_json_defs/collections/array.json @@ -25,7 +25,6 @@ } } }, - "values": {}, "operations": { "discard_empty": { "extension": "collections.array", diff --git a/hugr-py/src/hugr/std/_json_defs/collections/list.json b/hugr-py/src/hugr/std/_json_defs/collections/list.json index 0fbafc638..8a60d3544 100644 --- a/hugr-py/src/hugr/std/_json_defs/collections/list.json +++ b/hugr-py/src/hugr/std/_json_defs/collections/list.json @@ -21,7 +21,6 @@ } } }, - "values": {}, "operations": { "get": { "extension": "collections.list", diff --git a/hugr-py/src/hugr/std/_json_defs/collections/static_array.json b/hugr-py/src/hugr/std/_json_defs/collections/static_array.json index e4669f671..53b8e61c7 100644 --- a/hugr-py/src/hugr/std/_json_defs/collections/static_array.json +++ b/hugr-py/src/hugr/std/_json_defs/collections/static_array.json @@ -19,7 +19,6 @@ } } }, - "values": {}, "operations": { "get": { "extension": "collections.static_array", diff --git a/hugr-py/src/hugr/std/_json_defs/logic.json b/hugr-py/src/hugr/std/_json_defs/logic.json index ad9f02019..ff29d2c21 100644 --- a/hugr-py/src/hugr/std/_json_defs/logic.json +++ b/hugr-py/src/hugr/std/_json_defs/logic.json @@ -3,34 +3,6 @@ "name": "logic", "runtime_reqs": [], "types": {}, - "values": { - "FALSE": { - "extension": "logic", - "name": "FALSE", - "typed_value": { - "v": "Sum", - "tag": 0, - "vs": [], - "typ": { - "s": "Unit", - "size": 2 - } - } - }, - "TRUE": { - "extension": "logic", - "name": "TRUE", - "typed_value": { - "v": "Sum", - "tag": 1, - "vs": [], - "typ": { - "s": "Unit", - "size": 2 - } - } - } - }, "operations": { "And": { "extension": "logic", diff --git a/hugr-py/src/hugr/std/_json_defs/prelude.json b/hugr-py/src/hugr/std/_json_defs/prelude.json index e11ba2388..ec392b155 100644 --- a/hugr-py/src/hugr/std/_json_defs/prelude.json +++ b/hugr-py/src/hugr/std/_json_defs/prelude.json @@ -44,7 +44,6 @@ } } }, - "values": {}, "operations": { "Barrier": { "extension": "prelude", diff --git a/hugr-py/src/hugr/std/_json_defs/ptr.json b/hugr-py/src/hugr/std/_json_defs/ptr.json index 18b1f26b6..614b6aecf 100644 --- a/hugr-py/src/hugr/std/_json_defs/ptr.json +++ b/hugr-py/src/hugr/std/_json_defs/ptr.json @@ -19,7 +19,6 @@ } } }, - "values": {}, "operations": { "New": { "extension": "ptr", diff --git a/specification/schema/hugr_schema_live.json b/specification/schema/hugr_schema_live.json index 9e7d8c40c..ea08dff5b 100644 --- a/specification/schema/hugr_schema_live.json +++ b/specification/schema/hugr_schema_live.json @@ -517,13 +517,6 @@ "title": "Types", "type": "object" }, - "values": { - "additionalProperties": { - "$ref": "#/$defs/ExtensionValue" - }, - "title": "Values", - "type": "object" - }, "operations": { "additionalProperties": { "$ref": "#/$defs/OpDef" @@ -537,7 +530,6 @@ "name", "runtime_reqs", "types", - "values", "operations" ], "title": "Extension", @@ -589,28 +581,6 @@ "title": "ExtensionOp", "type": "object" }, - "ExtensionValue": { - "properties": { - "extension": { - "title": "Extension", - "type": "string" - }, - "name": { - "title": "Name", - "type": "string" - }, - "typed_value": { - "$ref": "#/$defs/Value" - } - }, - "required": [ - "extension", - "name", - "typed_value" - ], - "title": "ExtensionValue", - "type": "object" - }, "ExtensionsArg": { "additionalProperties": true, "properties": { diff --git a/specification/schema/hugr_schema_strict_live.json b/specification/schema/hugr_schema_strict_live.json index 6f436f969..8b65bae94 100644 --- a/specification/schema/hugr_schema_strict_live.json +++ b/specification/schema/hugr_schema_strict_live.json @@ -517,13 +517,6 @@ "title": "Types", "type": "object" }, - "values": { - "additionalProperties": { - "$ref": "#/$defs/ExtensionValue" - }, - "title": "Values", - "type": "object" - }, "operations": { "additionalProperties": { "$ref": "#/$defs/OpDef" @@ -537,7 +530,6 @@ "name", "runtime_reqs", "types", - "values", "operations" ], "title": "Extension", @@ -589,28 +581,6 @@ "title": "ExtensionOp", "type": "object" }, - "ExtensionValue": { - "properties": { - "extension": { - "title": "Extension", - "type": "string" - }, - "name": { - "title": "Name", - "type": "string" - }, - "typed_value": { - "$ref": "#/$defs/Value" - } - }, - "required": [ - "extension", - "name", - "typed_value" - ], - "title": "ExtensionValue", - "type": "object" - }, "ExtensionsArg": { "additionalProperties": false, "properties": { diff --git a/specification/schema/testing_hugr_schema_live.json b/specification/schema/testing_hugr_schema_live.json index bc067d40e..91b121da6 100644 --- a/specification/schema/testing_hugr_schema_live.json +++ b/specification/schema/testing_hugr_schema_live.json @@ -517,13 +517,6 @@ "title": "Types", "type": "object" }, - "values": { - "additionalProperties": { - "$ref": "#/$defs/ExtensionValue" - }, - "title": "Values", - "type": "object" - }, "operations": { "additionalProperties": { "$ref": "#/$defs/OpDef" @@ -537,7 +530,6 @@ "name", "runtime_reqs", "types", - "values", "operations" ], "title": "Extension", @@ -589,28 +581,6 @@ "title": "ExtensionOp", "type": "object" }, - "ExtensionValue": { - "properties": { - "extension": { - "title": "Extension", - "type": "string" - }, - "name": { - "title": "Name", - "type": "string" - }, - "typed_value": { - "$ref": "#/$defs/Value" - } - }, - "required": [ - "extension", - "name", - "typed_value" - ], - "title": "ExtensionValue", - "type": "object" - }, "ExtensionsArg": { "additionalProperties": true, "properties": { diff --git a/specification/schema/testing_hugr_schema_strict_live.json b/specification/schema/testing_hugr_schema_strict_live.json index 47c9778d3..eae6a13a7 100644 --- a/specification/schema/testing_hugr_schema_strict_live.json +++ b/specification/schema/testing_hugr_schema_strict_live.json @@ -517,13 +517,6 @@ "title": "Types", "type": "object" }, - "values": { - "additionalProperties": { - "$ref": "#/$defs/ExtensionValue" - }, - "title": "Values", - "type": "object" - }, "operations": { "additionalProperties": { "$ref": "#/$defs/OpDef" @@ -537,7 +530,6 @@ "name", "runtime_reqs", "types", - "values", "operations" ], "title": "Extension", @@ -589,28 +581,6 @@ "title": "ExtensionOp", "type": "object" }, - "ExtensionValue": { - "properties": { - "extension": { - "title": "Extension", - "type": "string" - }, - "name": { - "title": "Name", - "type": "string" - }, - "typed_value": { - "$ref": "#/$defs/Value" - } - }, - "required": [ - "extension", - "name", - "typed_value" - ], - "title": "ExtensionValue", - "type": "object" - }, "ExtensionsArg": { "additionalProperties": false, "properties": { diff --git a/specification/std_extensions/arithmetic/conversions.json b/specification/std_extensions/arithmetic/conversions.json index 9c0054354..1d310df25 100644 --- a/specification/std_extensions/arithmetic/conversions.json +++ b/specification/std_extensions/arithmetic/conversions.json @@ -6,7 +6,6 @@ "arithmetic.int.types" ], "types": {}, - "values": {}, "operations": { "bytecast_float64_to_int64": { "extension": "arithmetic.conversions", diff --git a/specification/std_extensions/arithmetic/float.json b/specification/std_extensions/arithmetic/float.json index 31ccaaa59..8da056772 100644 --- a/specification/std_extensions/arithmetic/float.json +++ b/specification/std_extensions/arithmetic/float.json @@ -5,7 +5,6 @@ "arithmetic.int.types" ], "types": {}, - "values": {}, "operations": { "fabs": { "extension": "arithmetic.float", diff --git a/specification/std_extensions/arithmetic/float/types.json b/specification/std_extensions/arithmetic/float/types.json index 56e35c50b..0c563c474 100644 --- a/specification/std_extensions/arithmetic/float/types.json +++ b/specification/std_extensions/arithmetic/float/types.json @@ -14,6 +14,5 @@ } } }, - "values": {}, "operations": {} } diff --git a/specification/std_extensions/arithmetic/int.json b/specification/std_extensions/arithmetic/int.json index 62d0a6663..5b1a81250 100644 --- a/specification/std_extensions/arithmetic/int.json +++ b/specification/std_extensions/arithmetic/int.json @@ -5,7 +5,6 @@ "arithmetic.int.types" ], "types": {}, - "values": {}, "operations": { "iabs": { "extension": "arithmetic.int", diff --git a/specification/std_extensions/arithmetic/int/types.json b/specification/std_extensions/arithmetic/int/types.json index 60cf69f63..36df125a6 100644 --- a/specification/std_extensions/arithmetic/int/types.json +++ b/specification/std_extensions/arithmetic/int/types.json @@ -19,6 +19,5 @@ } } }, - "values": {}, "operations": {} } diff --git a/specification/std_extensions/collections/array.json b/specification/std_extensions/collections/array.json index 21e405151..375e13c72 100644 --- a/specification/std_extensions/collections/array.json +++ b/specification/std_extensions/collections/array.json @@ -25,7 +25,6 @@ } } }, - "values": {}, "operations": { "discard_empty": { "extension": "collections.array", diff --git a/specification/std_extensions/collections/list.json b/specification/std_extensions/collections/list.json index 0fbafc638..8a60d3544 100644 --- a/specification/std_extensions/collections/list.json +++ b/specification/std_extensions/collections/list.json @@ -21,7 +21,6 @@ } } }, - "values": {}, "operations": { "get": { "extension": "collections.list", diff --git a/specification/std_extensions/collections/static_array.json b/specification/std_extensions/collections/static_array.json index e4669f671..53b8e61c7 100644 --- a/specification/std_extensions/collections/static_array.json +++ b/specification/std_extensions/collections/static_array.json @@ -19,7 +19,6 @@ } } }, - "values": {}, "operations": { "get": { "extension": "collections.static_array", diff --git a/specification/std_extensions/logic.json b/specification/std_extensions/logic.json index ad9f02019..ff29d2c21 100644 --- a/specification/std_extensions/logic.json +++ b/specification/std_extensions/logic.json @@ -3,34 +3,6 @@ "name": "logic", "runtime_reqs": [], "types": {}, - "values": { - "FALSE": { - "extension": "logic", - "name": "FALSE", - "typed_value": { - "v": "Sum", - "tag": 0, - "vs": [], - "typ": { - "s": "Unit", - "size": 2 - } - } - }, - "TRUE": { - "extension": "logic", - "name": "TRUE", - "typed_value": { - "v": "Sum", - "tag": 1, - "vs": [], - "typ": { - "s": "Unit", - "size": 2 - } - } - } - }, "operations": { "And": { "extension": "logic", diff --git a/specification/std_extensions/prelude.json b/specification/std_extensions/prelude.json index e11ba2388..ec392b155 100644 --- a/specification/std_extensions/prelude.json +++ b/specification/std_extensions/prelude.json @@ -44,7 +44,6 @@ } } }, - "values": {}, "operations": { "Barrier": { "extension": "prelude", diff --git a/specification/std_extensions/ptr.json b/specification/std_extensions/ptr.json index 18b1f26b6..614b6aecf 100644 --- a/specification/std_extensions/ptr.json +++ b/specification/std_extensions/ptr.json @@ -19,7 +19,6 @@ } } }, - "values": {}, "operations": { "New": { "extension": "ptr", From 89c2680912b47950ffd73a7c29a21386fdd0aee7 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 22 Apr 2025 17:16:31 +0100 Subject: [PATCH 08/18] feat!: ComposablePass trait allowing sequencing and validation (#1895) Currently We have several "passes": monomorphization, dead function removal, constant folding. Each has its own code to allow setting a validation level (before and after that pass). This PR adds the ability chain (sequence) passes;, and to add validation before+after any pass or sequence; and commons up validation code. The top-level `constant_fold_pass` (etc.) functions are left as wrappers that do a single pass with validation only in test. I've left ConstFoldPass as always including DCE, but an alternative could be to return a sequence of the two - ATM that means a tuple `(ConstFoldPass, DeadCodeElimPass)`. I also wondered about including a method `add_entry_point` in ComposablePass (e.g. for ConstFoldPass, that means `with_inputs` but no inputs, i.e. all Top). I feel this is not applicable to *all* passes, but near enough. This could be done in a later PR but `add_entry_point` would need a no-op default for that to be a non-breaking change. So if we wouldn't be happy with the no-op default then I could just add it here... Finally...docs are extremely minimal ATM (this is hugr-passes), I am hoping that most of this is reasonably obvious (it doesn't really do a lot!), but please flag anything you think is particularly in need of a doc comment! BREAKING CHANGE: quite a lot of calls to current pass routines will break, specific cases include (a) `with_validation_level` should be done by wrapping a ValidatingPass around the receiver; (b) XXXPass::run() requires `use ...ComposablePass` (however, such calls will cease to do any validation). closes #1832 --- hugr-passes/src/composable.rs | 361 +++++++++++++++++++++ hugr-passes/src/const_fold.rs | 45 +-- hugr-passes/src/const_fold/test.rs | 1 + hugr-passes/src/dead_code.rs | 50 ++- hugr-passes/src/dead_funcs.rs | 77 ++--- hugr-passes/src/lib.rs | 12 +- hugr-passes/src/monomorphize.rs | 92 ++---- hugr-passes/src/replace_types.rs | 105 +++--- hugr-passes/src/replace_types/linearize.rs | 2 +- hugr-passes/src/untuple.rs | 70 ++-- 10 files changed, 550 insertions(+), 265 deletions(-) create mode 100644 hugr-passes/src/composable.rs diff --git a/hugr-passes/src/composable.rs b/hugr-passes/src/composable.rs new file mode 100644 index 000000000..fb3319155 --- /dev/null +++ b/hugr-passes/src/composable.rs @@ -0,0 +1,361 @@ +//! Compiler passes and utilities for composing them + +use std::{error::Error, marker::PhantomData}; + +use hugr_core::hugr::{hugrmut::HugrMut, ValidationError}; +use hugr_core::HugrView; +use itertools::Either; + +/// An optimization pass that can be sequenced with another and/or wrapped +/// e.g. by [ValidatingPass] +pub trait ComposablePass: Sized { + type Error: Error; + type Result; // Would like to default to () but currently unstable + + fn run(&self, hugr: &mut impl HugrMut) -> Result; + + fn map_err( + self, + f: impl Fn(Self::Error) -> E2, + ) -> impl ComposablePass { + ErrMapper::new(self, f) + } + + /// Returns a [ComposablePass] that does "`self` then `other`", so long as + /// `other::Err` can be combined with ours. + fn then>( + self, + other: P, + ) -> impl ComposablePass { + struct Sequence(P1, P2, PhantomData); + impl ComposablePass for Sequence + where + P1: ComposablePass, + P2: ComposablePass, + E: ErrorCombiner, + { + type Error = E; + + type Result = (P1::Result, P2::Result); + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + let res1 = self.0.run(hugr).map_err(E::from_first)?; + let res2 = self.1.run(hugr).map_err(E::from_second)?; + Ok((res1, res2)) + } + } + + Sequence(self, other, PhantomData) + } +} + +/// Trait for combining the error types from two different passes +/// into a single error. +pub trait ErrorCombiner: Error { + fn from_first(a: A) -> Self; + fn from_second(b: B) -> Self; +} + +impl> ErrorCombiner for A { + fn from_first(a: A) -> Self { + a + } + + fn from_second(b: B) -> Self { + b.into() + } +} + +impl ErrorCombiner for Either { + fn from_first(a: A) -> Self { + Either::Left(a) + } + + fn from_second(b: B) -> Self { + Either::Right(b) + } +} + +// Note: in the short term we could wish for two more impls: +// impl ErrorCombiner for E +// impl ErrorCombiner for E +// however, these aren't possible as they conflict with +// impl> ErrorCombiner for A +// when A=E=Infallible, boo :-(. +// However this will become possible, indeed automatic, when Infallible is replaced +// by ! (never_type) as (unlike Infallible) ! converts Into anything + +// ErrMapper ------------------------------ +struct ErrMapper(P, F, PhantomData); + +impl E> ErrMapper { + fn new(pass: P, err_fn: F) -> Self { + Self(pass, err_fn, PhantomData) + } +} + +impl E> ComposablePass for ErrMapper { + type Error = E; + type Result = P::Result; + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + self.0.run(hugr).map_err(&self.1) + } +} + +// ValidatingPass ------------------------------ + +/// Error from a [ValidatingPass] +#[derive(thiserror::Error, Debug)] +pub enum ValidatePassError { + #[error("Failed to validate input HUGR: {err}\n{pretty_hugr}")] + Input { + #[source] + err: ValidationError, + pretty_hugr: String, + }, + #[error("Failed to validate output HUGR: {err}\n{pretty_hugr}")] + Output { + #[source] + err: ValidationError, + pretty_hugr: String, + }, + #[error(transparent)] + Underlying(#[from] E), +} + +/// Runs an underlying pass, but with validation of the Hugr +/// both before and afterwards. +pub struct ValidatingPass

(P, bool); + +impl ValidatingPass

{ + pub fn new_default(underlying: P) -> Self { + // Self(underlying, cfg!(feature = "extension_inference")) + // Sadly, many tests fail with extension inference, hence: + Self(underlying, false) + } + + pub fn new_validating_extensions(underlying: P) -> Self { + Self(underlying, true) + } + + pub fn new(underlying: P, validate_extensions: bool) -> Self { + Self(underlying, validate_extensions) + } + + fn validation_impl( + &self, + hugr: &impl HugrView, + mk_err: impl FnOnce(ValidationError, String) -> ValidatePassError, + ) -> Result<(), ValidatePassError> { + match self.1 { + false => hugr.validate_no_extensions(), + true => hugr.validate(), + } + .map_err(|err| mk_err(err, hugr.mermaid_string())) + } +} + +impl ComposablePass for ValidatingPass

{ + type Error = ValidatePassError; + type Result = P::Result; + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Input { + err, + pretty_hugr, + })?; + let res = self.0.run(hugr).map_err(ValidatePassError::Underlying)?; + self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Output { + err, + pretty_hugr, + })?; + Ok(res) + } +} + +// IfThen ------------------------------ +/// [ComposablePass] that executes a first pass that returns a `bool` +/// result; and then, if-and-only-if that first result was true, +/// executes a second pass +pub struct IfThen(A, B, PhantomData); + +impl, B: ComposablePass, E: ErrorCombiner> + IfThen +{ + /// Make a new instance given the [ComposablePass] to run first + /// and (maybe) second + pub fn new(fst: A, opt_snd: B) -> Self { + Self(fst, opt_snd, PhantomData) + } +} + +impl, B: ComposablePass, E: ErrorCombiner> + ComposablePass for IfThen +{ + type Error = E; + + type Result = Option; + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + let res: bool = self.0.run(hugr).map_err(ErrorCombiner::from_first)?; + res.then(|| self.1.run(hugr).map_err(ErrorCombiner::from_second)) + .transpose() + } +} + +pub(crate) fn validate_if_test( + pass: P, + hugr: &mut impl HugrMut, +) -> Result> { + if cfg!(test) { + ValidatingPass::new_default(pass).run(hugr) + } else { + pass.run(hugr).map_err(ValidatePassError::Underlying) + } +} + +#[cfg(test)] +mod test { + use itertools::{Either, Itertools}; + use std::convert::Infallible; + + use hugr_core::builder::{ + Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, + ModuleBuilder, + }; + use hugr_core::extension::prelude::{ + bool_t, usize_t, ConstUsize, MakeTuple, UnpackTuple, PRELUDE_ID, + }; + use hugr_core::hugr::hugrmut::HugrMut; + use hugr_core::ops::{handle::NodeHandle, Input, OpType, Output, DEFAULT_OPTYPE, DFG}; + use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES; + use hugr_core::types::{Signature, TypeRow}; + use hugr_core::{Hugr, HugrView, IncomingPort}; + + use crate::const_fold::{ConstFoldError, ConstantFoldPass}; + use crate::untuple::{UntupleRecursive, UntupleResult}; + use crate::{DeadCodeElimPass, ReplaceTypes, UntuplePass}; + + use super::{validate_if_test, ComposablePass, IfThen, ValidatePassError, ValidatingPass}; + + #[test] + fn test_then() { + let mut mb = ModuleBuilder::new(); + let id1 = mb + .define_function("id1", Signature::new_endo(usize_t())) + .unwrap(); + let inps = id1.input_wires(); + let id1 = id1.finish_with_outputs(inps).unwrap(); + let id2 = mb + .define_function("id2", Signature::new_endo(usize_t())) + .unwrap(); + let inps = id2.input_wires(); + let id2 = id2.finish_with_outputs(inps).unwrap(); + let hugr = mb.finish_hugr().unwrap(); + + let dce = DeadCodeElimPass::default().with_entry_points([id1.node()]); + let cfold = + ConstantFoldPass::default().with_inputs(id2.node(), [(0, ConstUsize::new(2).into())]); + + cfold.run(&mut hugr.clone()).unwrap(); + + let exp_err = ConstFoldError::InvalidEntryPoint(id2.node(), DEFAULT_OPTYPE); + let r: Result<_, Either> = + dce.clone().then(cfold.clone()).run(&mut hugr.clone()); + assert_eq!(r, Err(Either::Right(exp_err.clone()))); + + let r = dce + .clone() + .map_err(|inf| match inf {}) + .then(cfold.clone()) + .run(&mut hugr.clone()); + assert_eq!(r, Err(exp_err)); + + let r2: Result<_, Either<_, _>> = cfold.then(dce).run(&mut hugr.clone()); + r2.unwrap(); + } + + #[test] + fn test_validation() { + let mut h = Hugr::new(DFG { + signature: Signature::new(usize_t(), bool_t()), + }); + let inp = h.add_node_with_parent( + h.root(), + Input { + types: usize_t().into(), + }, + ); + let outp = h.add_node_with_parent( + h.root(), + Output { + types: bool_t().into(), + }, + ); + h.connect(inp, 0, outp, 0); + let backup = h.clone(); + let err = backup.validate().unwrap_err(); + + let no_inputs: [(IncomingPort, _); 0] = []; + let cfold = ConstantFoldPass::default().with_inputs(backup.root(), no_inputs); + cfold.run(&mut h).unwrap(); + assert_eq!(h, backup); // Did nothing + + let r = ValidatingPass(cfold, false).run(&mut h); + assert!(matches!(r, Err(ValidatePassError::Input { err: e, .. }) if e == err)); + } + + #[test] + fn test_if_then() { + let tr = TypeRow::from(vec![usize_t(); 2]); + + let h = { + let sig = Signature::new_endo(tr.clone()).with_extension_delta(PRELUDE_ID); + let mut fb = FunctionBuilder::new("tupuntup", sig).unwrap(); + let [a, b] = fb.input_wires_arr(); + let tup = fb + .add_dataflow_op(MakeTuple::new(tr.clone()), [a, b]) + .unwrap(); + let untup = fb + .add_dataflow_op(UnpackTuple::new(tr.clone()), tup.outputs()) + .unwrap(); + fb.finish_hugr_with_outputs(untup.outputs()).unwrap() + }; + + let untup = UntuplePass::new(UntupleRecursive::Recursive); + { + // Change usize_t to INT_TYPES[6], and if that did anything (it will!), then Untuple + let mut repl = ReplaceTypes::default(); + let usize_custom_t = usize_t().as_extension().unwrap().clone(); + repl.replace_type(usize_custom_t, INT_TYPES[6].clone()); + let ifthen = IfThen::, _, _>::new(repl, untup.clone()); + + let mut h = h.clone(); + let r = validate_if_test(ifthen, &mut h).unwrap(); + assert_eq!( + r, + Some(UntupleResult { + rewrites_applied: 1 + }) + ); + let [tuple_in, tuple_out] = h.children(h.root()).collect_array().unwrap(); + assert_eq!(h.output_neighbours(tuple_in).collect_vec(), [tuple_out; 2]); + } + + // Change INT_TYPES[5] to INT_TYPES[6]; that won't do anything, so don't Untuple + let mut repl = ReplaceTypes::default(); + let i32_custom_t = INT_TYPES[5].as_extension().unwrap().clone(); + repl.replace_type(i32_custom_t, INT_TYPES[6].clone()); + let ifthen = IfThen::, _, _>::new(repl, untup); + let mut h = h; + let r = validate_if_test(ifthen, &mut h).unwrap(); + assert_eq!(r, None); + assert_eq!(h.children(h.root()).count(), 4); + let mktup = h + .output_neighbours(h.first_child(h.root()).unwrap()) + .next() + .unwrap(); + assert_eq!(h.get_optype(mktup), &OpType::from(MakeTuple::new(tr))); + } +} diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index e73e3cd0e..99ccc180c 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -21,12 +21,11 @@ use crate::dataflow::{ TailLoopTermination, }; use crate::dead_code::{DeadCodeElimPass, PreserveNode}; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::{composable::validate_if_test, ComposablePass}; #[derive(Debug, Clone, Default)] /// A configuration for the Constant Folding pass. pub struct ConstantFoldPass { - validation: ValidationLevel, allow_increase_termination: bool, /// Each outer key Node must be either: /// - a FuncDefn child of the root, if the root is a module; or @@ -34,13 +33,10 @@ pub struct ConstantFoldPass { inputs: HashMap>, } -#[derive(Debug, Error)] +#[derive(Clone, Debug, Error, PartialEq)] #[non_exhaustive] /// Errors produced by [ConstantFoldPass]. pub enum ConstFoldError { - #[error(transparent)] - #[allow(missing_docs)] - ValidationError(#[from] ValidatePassError), /// Error raised when a Node is specified as an entry-point but /// is neither a dataflow parent, nor a [CFG](OpType::CFG), nor /// a [Conditional](OpType::Conditional). @@ -49,12 +45,6 @@ pub enum ConstFoldError { } impl ConstantFoldPass { - /// Sets the validation level used before and after the pass is run - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - /// Allows the pass to remove potentially-non-terminating [TailLoop]s and [CFG] if their /// result (if/when they do terminate) is either known or not needed. /// @@ -86,9 +76,19 @@ impl ConstantFoldPass { .extend(inputs.into_iter().map(|(p, v)| (p.into(), v))); self } +} + +impl ComposablePass for ConstantFoldPass { + type Error = ConstFoldError; + type Result = (); /// Run the Constant Folding pass. - fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), ConstFoldError> { + /// + /// # Errors + /// + /// [ConstFoldError::InvalidEntryPoint] if an entry-point added by [Self::with_inputs] + /// was of an invalid [OpType] + fn run(&self, hugr: &mut impl HugrMut) -> Result<(), ConstFoldError> { let fresh_node = Node::from(portgraph::NodeIndex::new( hugr.nodes().max().map_or(0, |n| n.index() + 1), )); @@ -164,23 +164,10 @@ impl ConstantFoldPass { } }) }) - .run(hugr)?; + .run(hugr) + .map_err(|inf| match inf {})?; // TODO use into_ok when available Ok(()) } - - /// Run the pass using this configuration. - /// - /// # Errors - /// - /// [ConstFoldError::ValidationError] if the Hugr does not validate before/afnerwards - /// (if [Self::validation_level] is set, or in tests) - /// - /// [ConstFoldError::InvalidEntryPoint] if an entry-point added by [Self::with_inputs] - /// was of an invalid OpType - pub fn run(&self, hugr: &mut H) -> Result<(), ConstFoldError> { - self.validation - .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr)) - } } /// Exhaustively apply constant folding to a HUGR. @@ -198,7 +185,7 @@ pub fn constant_fold_pass(h: &mut H) { } else { c }; - c.run(h).unwrap() + validate_if_test(c, h).unwrap() } struct ConstFoldContext; diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index 58e69c568..ff5cd93a5 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -32,6 +32,7 @@ use hugr_core::types::{Signature, SumType, Type, TypeBound, TypeRow, TypeRowRV}; use hugr_core::{type_row, Hugr, HugrView, IncomingPort, Node}; use crate::dataflow::{partial_from_const, DFContext, PartialValue}; +use crate::ComposablePass as _; use super::{constant_fold_pass, ConstFoldContext, ConstantFoldPass, ValueHandle}; diff --git a/hugr-passes/src/dead_code.rs b/hugr-passes/src/dead_code.rs index b714dd6fd..899e30243 100644 --- a/hugr-passes/src/dead_code.rs +++ b/hugr-passes/src/dead_code.rs @@ -1,13 +1,14 @@ //! Pass for removing dead code, i.e. that computes values that are then discarded use hugr_core::{hugr::hugrmut::HugrMut, ops::OpType, Hugr, HugrView, Node}; +use std::convert::Infallible; use std::fmt::{Debug, Formatter}; use std::{ collections::{HashMap, HashSet, VecDeque}, sync::Arc, }; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::ComposablePass; /// Configuration for Dead Code Elimination pass #[derive(Clone)] @@ -18,7 +19,6 @@ pub struct DeadCodeElimPass { /// Callback identifying nodes that must be preserved even if their /// results are not used. Defaults to [PreserveNode::default_for]. preserve_callback: Arc, - validation: ValidationLevel, } impl Default for DeadCodeElimPass { @@ -26,7 +26,6 @@ impl Default for DeadCodeElimPass { Self { entry_points: Default::default(), preserve_callback: Arc::new(PreserveNode::default_for), - validation: ValidationLevel::default(), } } } @@ -39,13 +38,11 @@ impl Debug for DeadCodeElimPass { #[derive(Debug)] struct DCEDebug<'a> { entry_points: &'a Vec, - validation: ValidationLevel, } Debug::fmt( &DCEDebug { entry_points: &self.entry_points, - validation: self.validation, }, f, ) @@ -86,13 +83,6 @@ impl PreserveNode { } impl DeadCodeElimPass { - /// Sets the validation level used before and after the pass is run - #[allow(unused)] - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - /// Allows setting a callback that determines whether a node must be preserved /// (even when its result is not used) pub fn set_preserve_callback(mut self, cb: Arc) -> Self { @@ -146,24 +136,6 @@ impl DeadCodeElimPass { needed } - pub fn run(&self, hugr: &mut impl HugrMut) -> Result<(), ValidatePassError> { - self.validation.run_validated_pass(hugr, |h, _| { - self.run_no_validate(h); - Ok(()) - }) - } - - fn run_no_validate(&self, hugr: &mut impl HugrMut) { - let needed = self.find_needed_nodes(&*hugr); - let remove = hugr - .nodes() - .filter(|n| !needed.contains(n)) - .collect::>(); - for n in remove { - hugr.remove_node(n); - } - } - fn must_preserve( &self, h: &impl HugrView, @@ -185,6 +157,22 @@ impl DeadCodeElimPass { } } +impl ComposablePass for DeadCodeElimPass { + type Error = Infallible; + type Result = (); + + fn run(&self, hugr: &mut impl HugrMut) -> Result<(), Infallible> { + let needed = self.find_needed_nodes(&*hugr); + let remove = hugr + .nodes() + .filter(|n| !needed.contains(n)) + .collect::>(); + for n in remove { + hugr.remove_node(n); + } + Ok(()) + } +} #[cfg(test)] mod test { use std::sync::Arc; @@ -196,6 +184,8 @@ mod test { use hugr_core::{ops::Value, type_row, HugrView}; use itertools::Itertools; + use crate::ComposablePass; + use super::{DeadCodeElimPass, PreserveNode}; #[test] diff --git a/hugr-passes/src/dead_funcs.rs b/hugr-passes/src/dead_funcs.rs index b114a9e42..7071d5335 100644 --- a/hugr-passes/src/dead_funcs.rs +++ b/hugr-passes/src/dead_funcs.rs @@ -10,7 +10,10 @@ use hugr_core::{ }; use petgraph::visit::{Dfs, Walker}; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::{ + composable::{validate_if_test, ValidatePassError}, + ComposablePass, +}; use super::call_graph::{CallGraph, CallGraphNode}; @@ -26,9 +29,6 @@ pub enum RemoveDeadFuncsError { /// The invalid node. node: N, }, - #[error(transparent)] - #[allow(missing_docs)] - ValidationError(#[from] ValidatePassError), } fn reachable_funcs<'a, H: HugrView>( @@ -64,17 +64,10 @@ fn reachable_funcs<'a, H: HugrView>( #[derive(Debug, Clone, Default)] /// A configuration for the Dead Function Removal pass. pub struct RemoveDeadFuncsPass { - validation: ValidationLevel, entry_points: Vec, } impl RemoveDeadFuncsPass { - /// Sets the validation level used before and after the pass is run - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - /// Adds new entry points - these must be [FuncDefn] nodes /// that are children of the [Module] at the root of the Hugr. /// @@ -87,16 +80,32 @@ impl RemoveDeadFuncsPass { self.entry_points.extend(entry_points); self } +} - /// Runs the pass (see [remove_dead_funcs]) with this configuration - pub fn run(&self, hugr: &mut H) -> Result<(), RemoveDeadFuncsError> { - self.validation.run_validated_pass(hugr, |hugr: &mut H, _| { - remove_dead_funcs(hugr, self.entry_points.iter().cloned()) - }) +impl ComposablePass for RemoveDeadFuncsPass { + type Error = RemoveDeadFuncsError; + type Result = (); + fn run(&self, hugr: &mut impl HugrMut) -> Result<(), RemoveDeadFuncsError> { + let reachable = reachable_funcs( + &CallGraph::new(hugr), + hugr, + self.entry_points.iter().cloned(), + )? + .collect::>(); + let unreachable = hugr + .nodes() + .filter(|n| { + OpTag::Function.is_superset(hugr.get_optype(*n).tag()) && !reachable.contains(n) + }) + .collect::>(); + for n in unreachable { + hugr.remove_subtree(n); + } + Ok(()) } } -/// Delete from the Hugr any functions that are not used by either [Call] or +/// Deletes from the Hugr any functions that are not used by either [Call] or /// [LoadFunction] nodes in reachable parts. /// /// For [Module]-rooted Hugrs, `entry_points` may provide a list of entry points, @@ -118,16 +127,11 @@ impl RemoveDeadFuncsPass { pub fn remove_dead_funcs( h: &mut impl HugrMut, entry_points: impl IntoIterator, -) -> Result<(), RemoveDeadFuncsError> { - let reachable = reachable_funcs(&CallGraph::new(h), h, entry_points)?.collect::>(); - let unreachable = h - .nodes() - .filter(|n| OpTag::Function.is_superset(h.get_optype(*n).tag()) && !reachable.contains(n)) - .collect::>(); - for n in unreachable { - h.remove_subtree(n); - } - Ok(()) +) -> Result<(), ValidatePassError> { + validate_if_test( + RemoveDeadFuncsPass::default().with_module_entry_points(entry_points), + h, + ) } #[cfg(test)] @@ -142,7 +146,7 @@ mod test { }; use hugr_core::{extension::prelude::usize_t, types::Signature, HugrView}; - use super::RemoveDeadFuncsPass; + use super::remove_dead_funcs; #[rstest] #[case([], vec![])] // No entry_points removes everything! @@ -182,15 +186,14 @@ mod test { }) .collect::>(); - RemoveDeadFuncsPass::default() - .with_module_entry_points( - entry_points - .into_iter() - .map(|name| *avail_funcs.get(name).unwrap()) - .collect::>(), - ) - .run(&mut hugr) - .unwrap(); + remove_dead_funcs( + &mut hugr, + entry_points + .into_iter() + .map(|name| *avail_funcs.get(name).unwrap()) + .collect::>(), + ) + .unwrap(); let remaining_funcs = hugr .nodes() diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 961c4da47..83ff71b67 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -1,6 +1,8 @@ //! Compilation passes acting on the HUGR program representation. pub mod call_graph; +pub mod composable; +pub use composable::ComposablePass; pub mod const_fold; pub mod dataflow; pub mod dead_code; @@ -21,19 +23,11 @@ pub mod untuple; )] #[allow(deprecated)] pub use monomorphize::remove_polyfuncs; -// TODO: Deprecated re-export. Remove on a breaking release. -#[deprecated( - since = "0.14.1", - note = "Use `hugr_passes::MonomorphizePass` instead." -)] -#[allow(deprecated)] -pub use monomorphize::monomorphize; -pub use monomorphize::{MonomorphizeError, MonomorphizePass}; +pub use monomorphize::{monomorphize, MonomorphizePass}; pub mod replace_types; pub use replace_types::ReplaceTypes; pub mod nest_cfgs; pub mod non_local; -pub mod validation; pub use force_order::{force_order, force_order_by_key}; pub use lower::{lower_ops, replace_many_ops}; pub use non_local::{ensure_no_nonlocal_edges, nonlocal_edges}; diff --git a/hugr-passes/src/monomorphize.rs b/hugr-passes/src/monomorphize.rs index 4f4e9bda2..875ee9355 100644 --- a/hugr-passes/src/monomorphize.rs +++ b/hugr-passes/src/monomorphize.rs @@ -1,5 +1,6 @@ use std::{ collections::{hash_map::Entry, HashMap}, + convert::Infallible, fmt::Write, ops::Deref, }; @@ -12,7 +13,9 @@ use hugr_core::{ use hugr_core::hugr::{hugrmut::HugrMut, Hugr, HugrView, OpType}; use itertools::Itertools as _; -use thiserror::Error; + +use crate::composable::{validate_if_test, ValidatePassError}; +use crate::ComposablePass; /// Replaces calls to polymorphic functions with calls to new monomorphic /// instantiations of the polymorphic ones. @@ -30,26 +33,8 @@ use thiserror::Error; /// children of the root node. We make best effort to ensure that names (derived /// from parent function names and concrete type args) of new functions are unique /// whenever the names of their parents are unique, but this is not guaranteed. -#[deprecated( - since = "0.14.1", - note = "Use `hugr_passes::MonomorphizePass` instead." -)] -// TODO: Deprecated. Remove on a breaking release and rename private `monomorphize_ref` to `monomorphize`. -pub fn monomorphize(mut h: Hugr) -> Hugr { - monomorphize_ref(&mut h); - h -} - -fn monomorphize_ref(h: &mut impl HugrMut) { - let root = h.root(); - // If the root is a polymorphic function, then there are no external calls, so nothing to do - if !is_polymorphic_funcdefn(h.get_optype(root)) { - mono_scan(h, root, None, &mut HashMap::new()); - if !h.get_optype(root).is_module() { - #[allow(deprecated)] // TODO remove in next breaking release and update docs - remove_polyfuncs_ref(h); - } - } +pub fn monomorphize(hugr: &mut impl HugrMut) -> Result<(), ValidatePassError> { + validate_if_test(MonomorphizePass, hugr) } /// Removes any polymorphic [FuncDefn]s from the Hugr. Note that if these have @@ -254,8 +239,6 @@ fn instantiate( mono_tgt } -use crate::validation::{ValidatePassError, ValidationLevel}; - /// Replaces calls to polymorphic functions with calls to new monomorphic /// instantiations of the polymorphic ones. /// @@ -271,38 +254,25 @@ use crate::validation::{ValidatePassError, ValidationLevel}; /// children of the root node. We make best effort to ensure that names (derived /// from parent function names and concrete type args) of new functions are unique /// whenever the names of their parents are unique, but this is not guaranteed. -#[derive(Debug, Clone, Default)] -pub struct MonomorphizePass { - validation: ValidationLevel, -} - -#[derive(Debug, Error)] -#[non_exhaustive] -/// Errors produced by [MonomorphizePass]. -pub enum MonomorphizeError { - #[error(transparent)] - #[allow(missing_docs)] - ValidationError(#[from] ValidatePassError), -} - -impl MonomorphizePass { - /// Sets the validation level used before and after the pass is run. - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - - /// Run the Monomorphization pass. - fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), MonomorphizeError> { - monomorphize_ref(hugr); +#[derive(Debug, Clone)] +pub struct MonomorphizePass; + +impl ComposablePass for MonomorphizePass { + type Error = Infallible; + type Result = (); + + fn run(&self, h: &mut impl HugrMut) -> Result<(), Self::Error> { + let root = h.root(); + // If the root is a polymorphic function, then there are no external calls, so nothing to do + if !is_polymorphic_funcdefn(h.get_optype(root)) { + mono_scan(h, root, None, &mut HashMap::new()); + if !h.get_optype(root).is_module() { + #[allow(deprecated)] // TODO remove in next breaking release and update docs + remove_polyfuncs_ref(h); + } + } Ok(()) } - - /// Run the pass using specified configuration. - pub fn run(&self, hugr: &mut H) -> Result<(), MonomorphizeError> { - self.validation - .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr)) - } } struct TypeArgsList<'a>(&'a [TypeArg]); @@ -387,9 +357,9 @@ mod test { use hugr_core::{Hugr, HugrView, Node}; use rstest::rstest; - use crate::remove_dead_funcs; + use crate::{monomorphize, remove_dead_funcs}; - use super::{is_polymorphic, mangle_inner_func, mangle_name, MonomorphizePass}; + use super::{is_polymorphic, mangle_inner_func, mangle_name}; fn pair_type(ty: Type) -> Type { Type::new_tuple(vec![ty.clone(), ty]) @@ -410,7 +380,7 @@ mod test { let [i1] = dfg_builder.input_wires_arr(); let hugr = dfg_builder.finish_hugr_with_outputs([i1]).unwrap(); let mut hugr2 = hugr.clone(); - MonomorphizePass::default().run(&mut hugr2).unwrap(); + monomorphize(&mut hugr2).unwrap(); assert_eq!(hugr, hugr2); } @@ -472,7 +442,7 @@ mod test { .count(), 3 ); - MonomorphizePass::default().run(&mut hugr)?; + monomorphize(&mut hugr)?; let mono = hugr; mono.validate()?; @@ -493,7 +463,7 @@ mod test { ["double", "main", "triple"] ); let mut mono2 = mono.clone(); - MonomorphizePass::default().run(&mut mono2)?; + monomorphize(&mut mono2)?; assert_eq!(mono2, mono); // Idempotent @@ -601,7 +571,7 @@ mod test { .outputs_arr(); let mut hugr = outer.finish_hugr_with_outputs([e1, e2]).unwrap(); - MonomorphizePass::default().run(&mut hugr).unwrap(); + monomorphize(&mut hugr).unwrap(); let mono_hugr = hugr; mono_hugr.validate().unwrap(); let funcs = list_funcs(&mono_hugr); @@ -662,7 +632,7 @@ mod test { let mono = mono.finish_with_outputs([a, b]).unwrap(); let c = dfg.call(mono.handle(), &[], dfg.input_wires()).unwrap(); let mut hugr = dfg.finish_hugr_with_outputs(c.outputs()).unwrap(); - MonomorphizePass::default().run(&mut hugr)?; + monomorphize(&mut hugr)?; let mono_hugr = hugr; let mut funcs = list_funcs(&mono_hugr); @@ -719,7 +689,7 @@ mod test { module_builder.finish_hugr().unwrap() }; - MonomorphizePass::default().run(&mut hugr).unwrap(); + monomorphize(&mut hugr).unwrap(); remove_dead_funcs(&mut hugr, []).unwrap(); let funcs = list_funcs(&hugr); diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 3ed7337a9..e81a640e3 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -26,7 +26,7 @@ use hugr_core::types::{ }; use hugr_core::{Hugr, HugrView, Node, Wire}; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::ComposablePass; mod linearize; pub use linearize::{CallbackHandler, DelegatingLinearizer, LinearizeError, Linearizer}; @@ -143,7 +143,6 @@ pub struct ReplaceTypes { ParametricType, Arc Result, ReplaceTypesError>>, >, - validation: ValidationLevel, } impl Default for ReplaceTypes { @@ -184,8 +183,6 @@ pub enum ReplaceTypesError { #[error(transparent)] SignatureError(#[from] SignatureError), #[error(transparent)] - ValidationError(#[from] ValidatePassError), - #[error(transparent)] ConstError(#[from] ConstTypeError), #[error(transparent)] LinearizeError(#[from] LinearizeError), @@ -203,16 +200,9 @@ impl ReplaceTypes { param_ops: Default::default(), consts: Default::default(), param_consts: Default::default(), - validation: Default::default(), } } - /// Sets the validation level used before and after the pass is run. - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - /// Configures this instance to replace occurrences of type `src` with `dest`. /// Note that if `src` is an instance of a *parametrized* [TypeDef], this takes /// precedence over [Self::replace_parametrized_type] where the `src`s overlap. Thus, this @@ -323,36 +313,6 @@ impl ReplaceTypes { self.param_consts.insert(src_ty.into(), Arc::new(const_fn)); } - /// Run the pass using specified configuration. - pub fn run(&self, hugr: &mut H) -> Result { - self.validation - .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr)) - } - - fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result { - let mut changed = false; - for n in hugr.nodes().collect::>() { - changed |= self.change_node(hugr, n)?; - let new_dfsig = hugr.get_optype(n).dataflow_signature(); - if let Some(new_sig) = new_dfsig - .filter(|_| changed && n != hugr.root()) - .map(Cow::into_owned) - { - for outp in new_sig.output_ports() { - if !new_sig.out_port_type(outp).unwrap().copyable() { - let targets = hugr.linked_inputs(n, outp).collect::>(); - if targets.len() != 1 { - hugr.disconnect(n, outp); - let src = Wire::new(n, outp); - self.linearize.insert_copy_discard(hugr, src, &targets)?; - } - } - } - } - } - Ok(changed) - } - fn change_node(&self, hugr: &mut impl HugrMut, n: Node) -> Result { match hugr.optype_mut(n) { OpType::FuncDefn(FuncDefn { signature, .. }) @@ -472,11 +432,40 @@ impl ReplaceTypes { false } }), - Value::Function { hugr } => self.run_no_validate(&mut **hugr), + Value::Function { hugr } => self.run(&mut **hugr), } } } +impl ComposablePass for ReplaceTypes { + type Error = ReplaceTypesError; + type Result = bool; + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + let mut changed = false; + for n in hugr.nodes().collect::>() { + changed |= self.change_node(hugr, n)?; + let new_dfsig = hugr.get_optype(n).dataflow_signature(); + if let Some(new_sig) = new_dfsig + .filter(|_| changed && n != hugr.root()) + .map(Cow::into_owned) + { + for outp in new_sig.output_ports() { + if !new_sig.out_port_type(outp).unwrap().copyable() { + let targets = hugr.linked_inputs(n, outp).collect::>(); + if targets.len() != 1 { + hugr.disconnect(n, outp); + let src = Wire::new(n, outp); + self.linearize.insert_copy_discard(hugr, src, &targets)?; + } + } + } + } + } + Ok(changed) + } +} + pub mod handlers; #[derive(Clone, Hash, PartialEq, Eq)] @@ -532,29 +521,26 @@ mod test { use hugr_core::extension::prelude::{ bool_t, option_type, qb_t, usize_t, ConstUsize, UnwrapBuilder, }; - use hugr_core::extension::simple_op::MakeExtensionOp; - use hugr_core::extension::{TypeDefBound, Version}; - - use hugr_core::ops::constant::OpaqueValue; - use hugr_core::ops::{ExtensionOp, NamedOp, OpTrait, OpType, Tag, Value}; - use hugr_core::std_extensions::arithmetic::int_types::ConstInt; - use hugr_core::std_extensions::arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES}; + use hugr_core::extension::{simple_op::MakeExtensionOp, TypeDefBound, Version}; + use hugr_core::hugr::{IdentList, ValidationError}; + use hugr_core::ops::{ + constant::OpaqueValue, ExtensionOp, NamedOp, OpTrait, OpType, Tag, Value, + }; + use hugr_core::std_extensions::arithmetic::conversions::ConvertOpDef; + use hugr_core::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; use hugr_core::std_extensions::collections::array::{ array_type, array_type_def, ArrayOp, ArrayOpDef, ArrayValue, }; use hugr_core::std_extensions::collections::list::{ list_type, list_type_def, ListOp, ListValue, }; - - use hugr_core::hugr::ValidationError; use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeRow}; - use hugr_core::{hugr::IdentList, type_row, Extension, HugrView}; + use hugr_core::{type_row, Extension, HugrView}; use itertools::Itertools; use rstest::rstest; - use crate::validation::ValidatePassError; + use crate::ComposablePass; - use super::ReplaceTypesError; use super::{handlers::list_const, NodeTemplate, ReplaceTypes}; const PACKED_VEC: &str = "PackedVec"; @@ -979,13 +965,16 @@ mod test { let cu = cst.value().downcast_ref::().unwrap(); Ok(ConstInt::new_u(6, cu.value())?.into()) }); + + let mut h = backup.clone(); + repl.run(&mut h).unwrap(); // No validation here assert!( - matches!(repl.run(&mut backup.clone()), Err(ReplaceTypesError::ValidationError(ValidatePassError::OutputError { - err: ValidationError::IncompatiblePorts {from, to, ..}, .. - })) if backup.get_optype(from).is_const() && to == c.node()) + matches!(h.validate(), Err(ValidationError::IncompatiblePorts {from, to, ..}) + if backup.get_optype(from).is_const() && to == c.node()) ); repl.replace_consts_parametrized(array_type_def(), array_const); let mut h = backup; - repl.run(&mut h).unwrap(); // Includes validation + repl.run(&mut h).unwrap(); + h.validate_no_extensions().unwrap(); } } diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 5b4da7184..bc508bd53 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -377,7 +377,7 @@ mod test { use crate::replace_types::handlers::linearize_array; use crate::replace_types::{LinearizeError, NodeTemplate, ReplaceTypesError}; - use crate::ReplaceTypes; + use crate::{ComposablePass, ReplaceTypes}; const LIN_T: &str = "Lin"; diff --git a/hugr-passes/src/untuple.rs b/hugr-passes/src/untuple.rs index dbe04edd1..874fd9ec3 100644 --- a/hugr-passes/src/untuple.rs +++ b/hugr-passes/src/untuple.rs @@ -10,19 +10,19 @@ use hugr_core::hugr::views::SiblingSubgraph; use hugr_core::hugr::SimpleReplacementError; use hugr_core::ops::{NamedOp, OpTrait, OpType}; use hugr_core::types::Type; -use hugr_core::{HugrView, SimpleReplacement}; +use hugr_core::{HugrView, Node, SimpleReplacement}; use itertools::Itertools; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::ComposablePass; /// Configuration enum for the untuple rewrite pass. /// /// Indicates whether the pattern match should traverse the HUGR recursively. #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum UntupleRecursive { - /// Traverse the HUGR recursively. + /// Traverse the HUGR recursively, i.e. consider the entire subtree Recursive, - /// Do not traverse the HUGR recursively. + /// Do not traverse the HUGR recursively, i.e. consider only the sibling subgraph #[default] NonRecursive, } @@ -48,22 +48,20 @@ pub enum UntupleRecursive { pub struct UntuplePass { /// Whether to traverse the HUGR recursively. recursive: UntupleRecursive, - /// The level of validation to perform on the rewrite. - validation: ValidationLevel, + /// Parent node under which to operate; None indicates the Hugr root + parent: Option, } #[derive(Debug, derive_more::Display, derive_more::Error, derive_more::From)] #[non_exhaustive] /// Errors produced by [UntuplePass]. pub enum UntupleError { - /// An error occurred while validating the rewrite. - ValidationError(ValidatePassError), /// Rewriting the circuit failed. RewriteError(SimpleReplacementError), } /// Result type for the untuple pass. -#[derive(Debug, Clone, Copy, Default)] +#[derive(Debug, Clone, Copy, Default, PartialEq)] pub struct UntupleResult { /// Number of `MakeTuple` rewrites applied. pub rewrites_applied: usize, @@ -71,16 +69,16 @@ pub struct UntupleResult { impl UntuplePass { /// Create a new untuple pass with the given configuration. - pub fn new(recursive: UntupleRecursive, validation: ValidationLevel) -> Self { + pub fn new(recursive: UntupleRecursive) -> Self { Self { recursive, - validation, + parent: None, } } - /// Sets the validation level used before and after the pass is run. - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; + /// Sets the parent node to optimize (overwrites any previous setting) + pub fn set_parent(mut self, parent: impl Into>) -> Self { + self.parent = parent.into(); self } @@ -90,31 +88,6 @@ impl UntuplePass { self } - /// Run the pass using specified configuration. - pub fn run( - &self, - hugr: &mut H, - parent: H::Node, - ) -> Result { - self.validation - .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr, parent)) - } - - /// Run the Monomorphization pass. - fn run_no_validate( - &self, - hugr: &mut H, - parent: H::Node, - ) -> Result { - let rewrites = self.find_rewrites(hugr, parent); - let rewrites_applied = rewrites.len(); - // The rewrites are independent, so we can always apply them all. - for rewrite in rewrites { - hugr.apply_rewrite(rewrite)?; - } - Ok(UntupleResult { rewrites_applied }) - } - /// Find tuple pack operations followed by tuple unpack operations /// and generate rewrites to remove them. /// @@ -148,6 +121,22 @@ impl UntuplePass { } } +impl ComposablePass for UntuplePass { + type Error = UntupleError; + + type Result = UntupleResult; + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + let rewrites = self.find_rewrites(hugr, self.parent.unwrap_or(hugr.root())); + let rewrites_applied = rewrites.len(); + // The rewrites are independent, so we can always apply them all. + for rewrite in rewrites { + hugr.apply_rewrite(rewrite)?; + } + Ok(UntupleResult { rewrites_applied }) + } +} + /// Returns true if the given optype is a MakeTuple operation. /// /// Boilerplate required due to https://github.com/CQCL/hugr/issues/1496 @@ -421,7 +410,8 @@ mod test { let parent = hugr.root(); let res = pass - .run(&mut hugr, parent) + .set_parent(parent) + .run(&mut hugr) .unwrap_or_else(|e| panic!("{e}")); assert_eq!(res.rewrites_applied, expected_rewrites); assert_eq!(hugr.children(parent).count(), remaining_nodes); From d8a5d6794526f22bc99d7a5489cbcc2d39e3c59a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 23 Apr 2025 10:55:43 +0100 Subject: [PATCH 09/18] feat!: ReplaceTypes: allow lowering ops into a Call to a function already in the Hugr (#2094) There are two issues: * Errors. The previous NodeTemplates still always work, but the Call one can fail if the Hugr doesn't contain the target function node. ATM there is no channel for reporting that error so I've had to panic. Otherwise it's an even-more-breaking change to add an error type to `NodeTemplate::add()` and `NodeTemplate::add_hugr()`. Should we? (I note `HugrMut::connect` panics if the node isn't there, but could make the `NodeTemplate::add` builder method return a BuildError...and propagate that everywhere of course) * There's a big limitation in `linearize_array` that it'll break if the *element* says it should be copied/discarded via a NodeTemplate::Call, as `linearize_array` puts the elementwise copy/discard function into a *nested Hugr* (`Value::Function`) that won't contain the function. This could be fixed via lifting those to toplevel FuncDefns with name-mangling, but I'd rather leave that for #2086 .... BREAKING CHANGE: Add new variant NodeTemplate::Call; LinearizeError no longer derives Eq. --- hugr-passes/src/replace_types.rs | 234 ++++++++++++++++----- hugr-passes/src/replace_types/handlers.rs | 4 +- hugr-passes/src/replace_types/linearize.rs | 104 +++++++-- 3 files changed, 268 insertions(+), 74 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index e81a640e3..df4c14075 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -15,16 +15,17 @@ use hugr_core::builder::{BuildError, BuildHandle, Dataflow}; use hugr_core::extension::{ExtensionId, OpDef, SignatureError, TypeDef}; use hugr_core::hugr::hugrmut::HugrMut; use hugr_core::ops::constant::{OpaqueValue, Sum}; -use hugr_core::ops::handle::DataflowOpID; +use hugr_core::ops::handle::{DataflowOpID, FuncID}; use hugr_core::ops::{ AliasDefn, Call, CallIndirect, Case, Conditional, Const, DataflowBlock, ExitBlock, ExtensionOp, FuncDecl, FuncDefn, Input, LoadConstant, LoadFunction, OpTrait, OpType, Output, Tag, TailLoop, Value, CFG, DFG, }; use hugr_core::types::{ - ConstTypeError, CustomType, Signature, Transformable, Type, TypeArg, TypeEnum, TypeTransformer, + ConstTypeError, CustomType, Signature, Transformable, Type, TypeArg, TypeEnum, TypeRow, + TypeTransformer, }; -use hugr_core::{Hugr, HugrView, Node, Wire}; +use hugr_core::{Direction, Hugr, HugrView, Node, PortIndex, Wire}; use crate::ComposablePass; @@ -45,21 +46,37 @@ pub enum NodeTemplate { /// Note this will be of limited use before [monomorphization](super::monomorphize()) /// because the new subtree will not be able to use type variables present in the /// parent Hugr or previous op. - // TODO: store also a vec, and update Hugr::validate to take &[TypeParam]s - // (defaulting to empty list) - see https://github.com/CQCL/hugr/issues/709 CompoundOp(Box), - // TODO allow also Call to a Node in the existing Hugr - // (can't see any other way to achieve multiple calls to the same decl. - // So client should add the functions before replacement, then remove unused ones afterwards.) + /// A Call to an existing function. + Call(Node, Vec), } impl NodeTemplate { /// Adds this instance to the specified [HugrMut] as a new node or subtree under a /// given parent, returning the unique new child (of that parent) thus created - pub fn add_hugr(self, hugr: &mut impl HugrMut, parent: Node) -> Node { + /// + /// # Panics + /// + /// * If `parent` is not in the `hugr` + /// + /// # Errors + /// + /// * If `self` is a [Self::Call] and the target Node either + /// * is neither a [FuncDefn] nor a [FuncDecl] + /// * has a [`signature`] which the type-args of the [Self::Call] do not match + /// + /// [`signature`]: hugr_core::types::PolyFuncType + pub fn add_hugr(self, hugr: &mut impl HugrMut, parent: Node) -> Result { match self { - NodeTemplate::SingleOp(op_type) => hugr.add_node_with_parent(parent, op_type), - NodeTemplate::CompoundOp(new_h) => hugr.insert_hugr(parent, *new_h).new_root, + NodeTemplate::SingleOp(op_type) => Ok(hugr.add_node_with_parent(parent, op_type)), + NodeTemplate::CompoundOp(new_h) => Ok(hugr.insert_hugr(parent, *new_h).new_root), + NodeTemplate::Call(target, type_args) => { + let c = call(hugr, target, type_args)?; + let tgt_port = c.called_function_port(); + let n = hugr.add_node_with_parent(parent, c); + hugr.connect(target, 0, n, tgt_port); + Ok(n) + } } } @@ -72,10 +89,15 @@ impl NodeTemplate { match self { NodeTemplate::SingleOp(opty) => dfb.add_dataflow_op(opty, inputs), NodeTemplate::CompoundOp(h) => dfb.add_hugr_with_wires(*h, inputs), + // Really we should check whether func points at a FuncDecl or FuncDefn and create + // the appropriate variety of FuncID but it doesn't matter for the purpose of making a Call. + NodeTemplate::Call(func, type_args) => { + dfb.call(&FuncID::::from(func), &type_args, inputs) + } } } - fn replace(&self, hugr: &mut impl HugrMut, n: Node) { + fn replace(&self, hugr: &mut impl HugrMut, n: Node) -> Result<(), BuildError> { assert_eq!(hugr.children(n).count(), 0); let new_optype = match self.clone() { NodeTemplate::SingleOp(op_type) => op_type, @@ -88,19 +110,57 @@ impl NodeTemplate { } root_opty } + NodeTemplate::Call(func, type_args) => { + let c = call(hugr, func, type_args)?; + let static_inport = c.called_function_port(); + // insert an input for the Call static input + hugr.insert_ports(n, Direction::Incoming, static_inport.index(), 1); + // connect the function to (what will be) the call + hugr.connect(func, 0, n, static_inport); + c.into() + } }; *hugr.optype_mut(n) = new_optype; + Ok(()) } - fn signature(&self) -> Option> { - match self { + fn check_signature( + &self, + inputs: &TypeRow, + outputs: &TypeRow, + ) -> Result<(), Option> { + let sig = match self { NodeTemplate::SingleOp(op_type) => op_type, NodeTemplate::CompoundOp(hugr) => hugr.root_type(), + NodeTemplate::Call(_, _) => return Ok(()), // no way to tell + } + .dataflow_signature(); + if sig.as_deref().map(Signature::io) == Some((inputs, outputs)) { + Ok(()) + } else { + Err(sig.map(Cow::into_owned)) } - .dataflow_signature() } } +fn call>( + h: &H, + func: Node, + type_args: Vec, +) -> Result { + let func_sig = match h.get_optype(func) { + OpType::FuncDecl(fd) => fd.signature.clone(), + OpType::FuncDefn(fd) => fd.signature.clone(), + _ => { + return Err(BuildError::UnexpectedType { + node: func, + op_desc: "func defn/decl", + }) + } + }; + Ok(Call::try_new(func_sig, type_args)?) +} + /// A configuration of what types, ops, and constants should be replaced with what. /// May be applied to a Hugr via [Self::run]. /// @@ -186,6 +246,8 @@ pub enum ReplaceTypesError { ConstError(#[from] ConstTypeError), #[error(transparent)] LinearizeError(#[from] LinearizeError), + #[error("Replacement op for {0} could not be added because {1}")] + AddTemplateError(Node, BuildError), } impl ReplaceTypes { @@ -370,8 +432,11 @@ impl ReplaceTypes { OpType::Const(Const { value, .. }) => self.change_value(value), OpType::ExtensionOp(ext_op) => Ok( + // Copy/discard insertion done by caller if let Some(replacement) = self.op_map.get(&OpHashWrapper::from(&*ext_op)) { - replacement.replace(hugr, n); // Copy/discard insertion done by caller + replacement + .replace(hugr, n) + .map_err(|e| ReplaceTypesError::AddTemplateError(n, e))?; true } else { let def = ext_op.def_arc(); @@ -382,7 +447,9 @@ impl ReplaceTypes { .get(&def.as_ref().into()) .and_then(|rep_fn| rep_fn(&args)) { - replacement.replace(hugr, n); + replacement + .replace(hugr, n) + .map_err(|e| ReplaceTypesError::AddTemplateError(n, e))?; true } else { if ch { @@ -515,24 +582,22 @@ mod test { use std::sync::Arc; use hugr_core::builder::{ - inout_sig, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, - HugrBuilder, ModuleBuilder, SubContainer, TailLoopBuilder, + inout_sig, BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, + FunctionBuilder, HugrBuilder, ModuleBuilder, SubContainer, TailLoopBuilder, }; use hugr_core::extension::prelude::{ - bool_t, option_type, qb_t, usize_t, ConstUsize, UnwrapBuilder, + bool_t, option_type, qb_t, usize_t, ConstUsize, UnwrapBuilder, PRELUDE_ID, }; - use hugr_core::extension::{simple_op::MakeExtensionOp, TypeDefBound, Version}; + use hugr_core::extension::{simple_op::MakeExtensionOp, ExtensionSet, TypeDefBound, Version}; + use hugr_core::hugr::hugrmut::HugrMut; use hugr_core::hugr::{IdentList, ValidationError}; - use hugr_core::ops::{ - constant::OpaqueValue, ExtensionOp, NamedOp, OpTrait, OpType, Tag, Value, - }; - use hugr_core::std_extensions::arithmetic::conversions::ConvertOpDef; + use hugr_core::ops::constant::OpaqueValue; + use hugr_core::ops::{ExtensionOp, NamedOp, OpTrait, OpType, Tag, Value}; + use hugr_core::std_extensions::arithmetic::conversions::{self, ConvertOpDef}; use hugr_core::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; - use hugr_core::std_extensions::collections::array::{ - array_type, array_type_def, ArrayOp, ArrayOpDef, ArrayValue, - }; - use hugr_core::std_extensions::collections::list::{ - list_type, list_type_def, ListOp, ListValue, + use hugr_core::std_extensions::collections::{ + array::{self, array_type, array_type_def, ArrayOp, ArrayOpDef, ArrayValue}, + list::{list_type, list_type_def, ListOp, ListValue}, }; use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeRow}; use hugr_core::{type_row, Extension, HugrView}; @@ -601,30 +666,37 @@ mod test { ) } - fn lowerer(ext: &Arc) -> ReplaceTypes { - fn lowered_read(args: &[TypeArg]) -> Option { - let ty = just_elem_type(args); - let mut dfb = DFGBuilder::new(inout_sig( - vec![array_type(64, ty.clone()), i64_t()], - ty.clone(), - )) + fn lowered_read( + elem_ty: Type, + new: impl Fn(Signature) -> Result, + ) -> T { + let mut dfb = new(Signature::new( + vec![array_type(64, elem_ty.clone()), i64_t()], + elem_ty.clone(), + ) + .with_extension_delta(ExtensionSet::from_iter([ + PRELUDE_ID, + array::EXTENSION_ID, + conversions::EXTENSION_ID, + ]))) + .unwrap(); + let [val, idx] = dfb.input_wires_arr(); + let [idx] = dfb + .add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [idx]) + .unwrap() + .outputs_arr(); + let [opt] = dfb + .add_dataflow_op(ArrayOpDef::get.to_concrete(elem_ty.clone(), 64), [val, idx]) + .unwrap() + .outputs_arr(); + let [res] = dfb + .build_unwrap_sum(1, option_type(Type::from(elem_ty)), opt) .unwrap(); - let [val, idx] = dfb.input_wires_arr(); - let [idx] = dfb - .add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [idx]) - .unwrap() - .outputs_arr(); - let [opt] = dfb - .add_dataflow_op(ArrayOpDef::get.to_concrete(ty.clone(), 64), [val, idx]) - .unwrap() - .outputs_arr(); - let [res] = dfb - .build_unwrap_sum(1, option_type(Type::from(ty.clone())), opt) - .unwrap(); - Some(NodeTemplate::CompoundOp(Box::new( - dfb.finish_hugr_with_outputs([res]).unwrap(), - ))) - } + dfb.set_outputs([res]).unwrap(); + dfb + } + + fn lowerer(ext: &Arc) -> ReplaceTypes { let pv = ext.get_type(PACKED_VEC).unwrap(); let mut lw = ReplaceTypes::default(); lw.replace_type(pv.instantiate([bool_t().into()]).unwrap(), i64_t()); @@ -640,7 +712,13 @@ mod test { .into(), ), ); - lw.replace_parametrized_op(ext.get_op(READ).unwrap().as_ref(), Box::new(lowered_read)); + lw.replace_parametrized_op(ext.get_op(READ).unwrap().as_ref(), |type_args| { + Some(NodeTemplate::CompoundOp(Box::new( + lowered_read(just_elem_type(type_args).clone(), DFGBuilder::new) + .finish_hugr() + .unwrap(), + ))) + }); lw } @@ -977,4 +1055,52 @@ mod test { repl.run(&mut h).unwrap(); h.validate_no_extensions().unwrap(); } + + #[test] + fn op_to_call() { + let e = ext(); + let pv = e.get_type(PACKED_VEC).unwrap(); + let inner = pv.instantiate([usize_t().into()]).unwrap(); + let outer = pv + .instantiate([Type::new_extension(inner.clone()).into()]) + .unwrap(); + let mut dfb = DFGBuilder::new(inout_sig(vec![outer.into(), i64_t()], usize_t())).unwrap(); + let [outer, idx] = dfb.input_wires_arr(); + let [inner] = dfb + .add_dataflow_op(read_op(&e, inner.clone().into()), [outer, idx]) + .unwrap() + .outputs_arr(); + let res = dfb + .add_dataflow_op(read_op(&e, usize_t()), [inner, idx]) + .unwrap(); + let mut h = dfb.finish_hugr_with_outputs(res.outputs()).unwrap(); + let read_func = h + .insert_hugr( + h.root(), + lowered_read(Type::new_var_use(0, TypeBound::Copyable), |sig| { + FunctionBuilder::new( + "lowered_read", + PolyFuncType::new([TypeBound::Copyable.into()], sig), + ) + }) + .finish_hugr() + .unwrap(), + ) + .new_root; + + let mut lw = lowerer(&e); + lw.replace_parametrized_op(e.get_op(READ).unwrap().as_ref(), move |args| { + Some(NodeTemplate::Call(read_func, args.to_owned())) + }); + lw.run(&mut h).unwrap(); + + assert_eq!(h.output_neighbours(read_func).count(), 2); + let ext_op_names = h + .nodes() + .filter_map(|n| h.get_optype(n).as_extension_op()) + .map(|e| e.def().name()) + .sorted() + .collect_vec(); + assert_eq!(ext_op_names, ["get", "itousize", "panic",]); + } } diff --git a/hugr-passes/src/replace_types/handlers.rs b/hugr-passes/src/replace_types/handlers.rs index e835a2d9b..b6e6e6780 100644 --- a/hugr-passes/src/replace_types/handlers.rs +++ b/hugr-passes/src/replace_types/handlers.rs @@ -92,7 +92,7 @@ pub fn linearize_array( let [to_discard] = dfb.input_wires_arr(); lin.copy_discard_op(ty, 0)? .add(&mut dfb, [to_discard]) - .unwrap(); + .map_err(|e| LinearizeError::NestedTemplateError(ty.clone(), e))?; let ret = dfb.add_load_value(Value::unary_unit_sum()); dfb.finish_hugr_with_outputs([ret]).unwrap() }; @@ -162,7 +162,7 @@ pub fn linearize_array( let mut copies = lin .copy_discard_op(ty, num_outports)? .add(&mut dfb, [elem]) - .unwrap() + .map_err(|e| LinearizeError::NestedTemplateError(ty.clone(), e))? .outputs(); let copy0 = copies.next().unwrap(); // We'll return this directly diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index bc508bd53..5c4a4a707 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -1,10 +1,9 @@ -use std::borrow::Cow; use std::iter::repeat; use std::{collections::HashMap, sync::Arc}; use hugr_core::builder::{ - inout_sig, ConditionalBuilder, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, - HugrBuilder, + inout_sig, BuildError, ConditionalBuilder, DFGBuilder, Dataflow, DataflowHugr, + DataflowSubContainer, HugrBuilder, }; use hugr_core::extension::{SignatureError, TypeDef}; use hugr_core::std_extensions::collections::array::array_type_def; @@ -76,9 +75,11 @@ pub trait Linearizer { tgt_parent, }); } + let typ = typ.clone(); // Stop borrowing hugr in order to add_hugr to it let copy_discard_op = self - .copy_discard_op(typ, targets.len())? - .add_hugr(hugr, src_parent); + .copy_discard_op(&typ, targets.len())? + .add_hugr(hugr, src_parent) + .map_err(|e| LinearizeError::NestedTemplateError(typ, e))?; for (n, (tgt_node, tgt_port)) in targets.iter().enumerate() { hugr.connect(copy_discard_op, n, *tgt_node, *tgt_port); } @@ -133,7 +134,7 @@ impl Default for DelegatingLinearizer { // rather than passing a &DelegatingLinearizer directly) pub struct CallbackHandler<'a>(#[allow(dead_code)] &'a DelegatingLinearizer); -#[derive(Clone, Debug, thiserror::Error, PartialEq, Eq)] +#[derive(Clone, Debug, thiserror::Error, PartialEq)] #[allow(missing_docs)] #[non_exhaustive] pub enum LinearizeError { @@ -163,6 +164,10 @@ pub enum LinearizeError { /// Neither does linearization make sense for copyable types #[error("Type {_0} is copyable")] CopyableType(Type), + /// Error may be returned by a callback for e.g. a container because it could + /// not generate a [NodeTemplate] because of a problem with an element + #[error("Could not generate NodeTemplate for contained type {0} because {1}")] + NestedTemplateError(Type, BuildError), } impl DelegatingLinearizer { @@ -185,8 +190,10 @@ impl DelegatingLinearizer { /// /// * [LinearizeError::CopyableType] If `typ` is /// [Copyable](hugr_core::types::TypeBound::Copyable) - /// * [LinearizeError::WrongSignature] if `copy` or `discard` do not have the - /// expected inputs or outputs + /// * [LinearizeError::WrongSignature] if `copy` or `discard` do not have the expected + /// inputs or outputs (for [NodeTemplate::SingleOp] and [NodeTemplate::CompoundOp] + /// only: the signature for a [NodeTemplate::Call] cannot be checked until it is used + /// in a Hugr). pub fn register_simple( &mut self, cty: CustomType, @@ -230,18 +237,12 @@ impl DelegatingLinearizer { } fn check_sig(tmpl: &NodeTemplate, typ: &Type, num_outports: usize) -> Result<(), LinearizeError> { - let sig = tmpl.signature(); - if sig.as_ref().is_some_and(|sig| { - sig.io() == (&typ.clone().into(), &vec![typ.clone(); num_outports].into()) - }) { - Ok(()) - } else { - Err(LinearizeError::WrongSignature { + tmpl.check_signature(&typ.clone().into(), &vec![typ.clone(); num_outports].into()) + .map_err(|sig| LinearizeError::WrongSignature { typ: typ.clone(), num_outports, - sig: sig.map(Cow::into_owned), + sig, }) - } } impl Linearizer for DelegatingLinearizer { @@ -353,7 +354,10 @@ mod test { use std::iter::successors; use std::sync::Arc; - use hugr_core::builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer}; + use hugr_core::builder::{ + inout_sig, BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, + HugrBuilder, + }; use hugr_core::extension::prelude::{option_type, usize_t}; use hugr_core::extension::simple_op::MakeExtensionOp; @@ -768,4 +772,68 @@ mod test { )); assert_eq!(copy_sig.input[2..], copy_sig.output[1..]); } + + #[test] + fn call_ok_except_in_array() { + let (e, _) = ext_lowerer(); + let lin_ct = e.get_type(LIN_T).unwrap().instantiate([]).unwrap(); + let lin_t: Type = lin_ct.clone().into(); + + // A simple Hugr that discards a usize_t, with a "drop" function + let mut dfb = DFGBuilder::new(inout_sig(usize_t(), type_row![])).unwrap(); + let discard_fn = { + let mut fb = dfb + .define_function( + "drop", + Signature::new(lin_t.clone(), type_row![]) + .with_extension_delta(e.name().clone()), + ) + .unwrap(); + let ins = fb.input_wires(); + fb.add_dataflow_op( + ExtensionOp::new(e.get_op("discard").unwrap().clone(), []).unwrap(), + ins, + ) + .unwrap(); + fb.finish_with_outputs([]).unwrap() + } + .node(); + let backup = dfb.finish_hugr().unwrap(); + + let mut lower_discard_to_call = ReplaceTypes::default(); + // The `copy_fn` here will break completely, but we don't use it + lower_discard_to_call + .linearizer() + .register_simple( + lin_ct.clone(), + NodeTemplate::Call(backup.root(), vec![]), + NodeTemplate::Call(discard_fn, vec![]), + ) + .unwrap(); + + // Ok to lower usize_t to lin_t and call that function + { + let mut lowerer = lower_discard_to_call.clone(); + lowerer.replace_type(usize_t().as_extension().unwrap().clone(), lin_t.clone()); + let mut h = backup.clone(); + lowerer.run(&mut h).unwrap(); + assert_eq!(h.output_neighbours(discard_fn).count(), 1); + } + + // But if we lower usize_t to array, the call will fail + lower_discard_to_call.replace_type( + usize_t().as_extension().unwrap().clone(), + array_type(4, lin_ct.into()), + ); + let r = lower_discard_to_call.run(&mut backup.clone()); + assert!(matches!( + r, + Err(ReplaceTypesError::LinearizeError( + LinearizeError::NestedTemplateError( + nested_t, + BuildError::UnexpectedType { node, .. } + ) + )) if nested_t == lin_t && node == discard_fn + )); + } } From b209709ecb59742275189d99d018562759b17e6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Wed, 23 Apr 2025 17:32:27 +0100 Subject: [PATCH 10/18] feat!: Hugrmut on generic nodes (#2111) - Allows `HugrMut` to be implemented for `HugrView`s with arbitrary node types - Removes `HugrMutInternals::hugr_mut(&mut self) -> &mut Hugr`, it can be implemented for more complex types. This is required for #1926, but I haven't touched the read-only side yet. - Added a `Node` associated type to `Rewrite`. All existing rewrites only implement `Rewrite` for now, expanding their type is left for a separate PR. drive-by: Fix a couple bugs in rewrite implementations that assumed that `SiblingMut` contained transitive children. BREAKING CHANGE: `HugrMut` is now implemented generically for any `HugrView::Node` type. BREAKING CHANGE: `SiblingMut` has a new type parameter for the wrapped hugr type. --- hugr-core/src/builder/build_traits.rs | 2 +- hugr-core/src/hugr/hugrmut.rs | 212 ++++-------- hugr-core/src/hugr/internal.rs | 320 +++++++----------- hugr-core/src/hugr/rewrite.rs | 21 +- hugr-core/src/hugr/rewrite/consts.rs | 6 +- hugr-core/src/hugr/rewrite/inline_call.rs | 3 +- hugr-core/src/hugr/rewrite/inline_dfg.rs | 6 +- hugr-core/src/hugr/rewrite/insert_identity.rs | 6 +- hugr-core/src/hugr/rewrite/outline_cfg.rs | 38 +-- hugr-core/src/hugr/rewrite/replace.rs | 5 +- hugr-core/src/hugr/rewrite/simple_replace.rs | 3 +- hugr-core/src/hugr/views.rs | 17 +- hugr-core/src/hugr/views/descendants.rs | 6 +- hugr-core/src/hugr/views/impls.rs | 266 ++++++++++++--- hugr-core/src/hugr/views/root_checked.rs | 73 ++-- hugr-core/src/hugr/views/sibling.rs | 134 +++++--- .../src/utils/inline_constant_functions.rs | 4 +- hugr-passes/src/composable.rs | 45 ++- hugr-passes/src/const_fold.rs | 5 +- hugr-passes/src/dataflow/partial_value.rs | 2 +- hugr-passes/src/dead_code.rs | 3 +- hugr-passes/src/dead_funcs.rs | 5 +- hugr-passes/src/force_order.rs | 4 +- hugr-passes/src/lower.rs | 4 +- hugr-passes/src/merge_bbs.rs | 17 +- hugr-passes/src/monomorphize.rs | 13 +- hugr-passes/src/nest_cfgs.rs | 8 +- hugr-passes/src/replace_types.rs | 17 +- hugr-passes/src/replace_types/linearize.rs | 2 +- hugr-passes/src/untuple.rs | 4 +- 30 files changed, 676 insertions(+), 575 deletions(-) diff --git a/hugr-core/src/builder/build_traits.rs b/hugr-core/src/builder/build_traits.rs index e17d172ca..58c15c54a 100644 --- a/hugr-core/src/builder/build_traits.rs +++ b/hugr-core/src/builder/build_traits.rs @@ -153,7 +153,7 @@ pub trait Container { where ExtensionRegistry: Extend, { - self.hugr_mut().extensions_mut().extend(registry); + self.hugr_mut().use_extensions(registry); } } diff --git a/hugr-core/src/hugr/hugrmut.rs b/hugr-core/src/hugr/hugrmut.rs index 38eb59222..bf9a4cad0 100644 --- a/hugr-core/src/hugr/hugrmut.rs +++ b/hugr-core/src/hugr/hugrmut.rs @@ -1,7 +1,7 @@ //! Low-level interface for modifying a HUGR. use core::panic; -use std::collections::{BTreeMap, HashMap, HashSet}; +use std::collections::{BTreeMap, HashMap, HashSet, VecDeque}; use std::sync::Arc; use portgraph::view::{NodeFilter, NodeFiltered}; @@ -11,14 +11,13 @@ use crate::core::HugrNode; use crate::extension::ExtensionRegistry; use crate::hugr::internal::HugrInternals; use crate::hugr::views::SiblingSubgraph; -use crate::hugr::{HugrView, Node, OpType, RootTagged}; +use crate::hugr::{HugrView, Node, OpType}; use crate::hugr::{NodeMetadata, Rewrite}; use crate::ops::OpTrait; use crate::types::Substitution; use crate::{Extension, Hugr, IncomingPort, OutgoingPort, Port, PortIndex}; use super::internal::HugrMutInternals; -use super::NodeMetadataMap; /// Functions for low-level building of a HUGR. pub trait HugrMut: HugrMutInternals { @@ -27,14 +26,9 @@ pub trait HugrMut: HugrMutInternals { /// # Panics /// /// If the node is not in the graph. - fn get_metadata_mut(&mut self, node: Node, key: impl AsRef) -> &mut NodeMetadata { + fn get_metadata_mut(&mut self, node: Self::Node, key: impl AsRef) -> &mut NodeMetadata { panic_invalid_node(self, node); - let node_meta = self - .hugr_mut() - .metadata - .get_mut(node.pg_index()) - .get_or_insert_with(Default::default); - node_meta + self.node_metadata_map_mut(node) .entry(key.as_ref()) .or_insert(serde_json::Value::Null) } @@ -46,7 +40,7 @@ pub trait HugrMut: HugrMutInternals { /// If the node is not in the graph. fn set_metadata( &mut self, - node: Node, + node: Self::Node, key: impl AsRef, metadata: impl Into, ) { @@ -59,30 +53,10 @@ pub trait HugrMut: HugrMutInternals { /// # Panics /// /// If the node is not in the graph. - fn remove_metadata(&mut self, node: Node, key: impl AsRef) { + fn remove_metadata(&mut self, node: Self::Node, key: impl AsRef) { panic_invalid_node(self, node); - let node_meta = self.hugr_mut().metadata.get_mut(node.pg_index()); - if let Some(node_meta) = node_meta { - node_meta.remove(key.as_ref()); - } - } - - /// Retrieve the complete metadata map for a node. - fn take_node_metadata(&mut self, node: Self::Node) -> Option { - if !self.valid_node(node) { - return None; - } - self.hugr_mut().metadata.take(node.pg_index()) - } - - /// Overwrite the complete metadata map for a node. - /// - /// # Panics - /// - /// If the node is not in the graph. - fn overwrite_node_metadata(&mut self, node: Node, metadata: Option) { - panic_invalid_node(self, node); - self.hugr_mut().metadata.set(node.pg_index(), metadata); + let node_meta = self.node_metadata_map_mut(node); + node_meta.remove(key.as_ref()); } /// Add a node to the graph with a parent in the hierarchy. @@ -92,11 +66,7 @@ pub trait HugrMut: HugrMutInternals { /// # Panics /// /// If the parent is not in the graph. - #[inline] - fn add_node_with_parent(&mut self, parent: Node, op: impl Into) -> Node { - panic_invalid_node(self, parent); - self.hugr_mut().add_node_with_parent(parent, op) - } + fn add_node_with_parent(&mut self, parent: Self::Node, op: impl Into) -> Self::Node; /// Add a node to the graph as the previous sibling of another node. /// @@ -105,11 +75,7 @@ pub trait HugrMut: HugrMutInternals { /// # Panics /// /// If the sibling is not in the graph, or if the sibling is the root node. - #[inline] - fn add_node_before(&mut self, sibling: Node, nodetype: impl Into) -> Node { - panic_invalid_non_root(self, sibling); - self.hugr_mut().add_node_before(sibling, nodetype) - } + fn add_node_before(&mut self, sibling: Self::Node, nodetype: impl Into) -> Self::Node; /// Add a node to the graph as the next sibling of another node. /// @@ -118,11 +84,7 @@ pub trait HugrMut: HugrMutInternals { /// # Panics /// /// If the sibling is not in the graph, or if the sibling is the root node. - #[inline] - fn add_node_after(&mut self, sibling: Node, op: impl Into) -> Node { - panic_invalid_non_root(self, sibling); - self.hugr_mut().add_node_after(sibling, op) - } + fn add_node_after(&mut self, sibling: Self::Node, op: impl Into) -> Self::Node; /// Remove a node from the graph and return the node weight. /// Note that if the node has children, they are not removed; this leaves @@ -131,24 +93,14 @@ pub trait HugrMut: HugrMutInternals { /// # Panics /// /// If the node is not in the graph, or if the node is the root node. - #[inline] - fn remove_node(&mut self, node: Node) -> OpType { - panic_invalid_non_root(self, node); - self.hugr_mut().remove_node(node) - } + fn remove_node(&mut self, node: Self::Node) -> OpType; /// Remove a node from the graph, along with all its descendants in the hierarchy. /// /// # Panics /// /// If the node is not in the graph, or is the root (this would leave an empty Hugr). - fn remove_subtree(&mut self, node: Node) { - panic_invalid_non_root(self, node); - while let Some(ch) = self.first_child(node) { - self.remove_subtree(ch) - } - self.hugr_mut().remove_node(node); - } + fn remove_subtree(&mut self, node: Self::Node); /// Copies the strict descendants of `root` to under the `new_parent`, optionally applying a /// [Substitution] to the [OpType]s of the copied nodes. @@ -167,29 +119,20 @@ pub trait HugrMut: HugrMutInternals { root: Self::Node, new_parent: Self::Node, subst: Option, - ) -> BTreeMap { - panic_invalid_node(self, root); - panic_invalid_node(self, new_parent); - self.hugr_mut().copy_descendants(root, new_parent, subst) - } + ) -> BTreeMap; /// Connect two nodes at the given ports. /// /// # Panics /// /// If either node is not in the graph or if the ports are invalid. - #[inline] fn connect( &mut self, - src: Node, + src: Self::Node, src_port: impl Into, - dst: Node, + dst: Self::Node, dst_port: impl Into, - ) { - panic_invalid_node(self, src); - panic_invalid_node(self, dst); - self.hugr_mut().connect(src, src_port, dst, dst_port); - } + ); /// Disconnects all edges from the given port. /// @@ -198,11 +141,7 @@ pub trait HugrMut: HugrMutInternals { /// # Panics /// /// If the node is not in the graph, or if the port is invalid. - #[inline] - fn disconnect(&mut self, node: Node, port: impl Into) { - panic_invalid_node(self, node); - self.hugr_mut().disconnect(node, port); - } + fn disconnect(&mut self, node: Self::Node, port: impl Into); /// Adds a non-dataflow edge between two nodes. The kind is given by the /// operation's [`OpTrait::other_input`] or [`OpTrait::other_output`]. @@ -215,37 +154,25 @@ pub trait HugrMut: HugrMutInternals { /// # Panics /// /// If the node is not in the graph, or if the port is invalid. - fn add_other_edge(&mut self, src: Node, dst: Node) -> (OutgoingPort, IncomingPort) { - panic_invalid_node(self, src); - panic_invalid_node(self, dst); - self.hugr_mut().add_other_edge(src, dst) - } + fn add_other_edge(&mut self, src: Self::Node, dst: Self::Node) -> (OutgoingPort, IncomingPort); /// Insert another hugr into this one, under a given root node. /// /// # Panics /// /// If the root node is not in the graph. - #[inline] - fn insert_hugr(&mut self, root: Self::Node, other: Hugr) -> InsertionResult { - panic_invalid_node(self, root); - self.hugr_mut().insert_hugr(root, other) - } + fn insert_hugr(&mut self, root: Self::Node, other: Hugr) -> InsertionResult; /// Copy another hugr into this one, under a given root node. /// /// # Panics /// /// If the root node is not in the graph. - #[inline] fn insert_from_view( &mut self, root: Self::Node, other: &H, - ) -> InsertionResult { - panic_invalid_node(self, root); - self.hugr_mut().insert_from_view(root, other) - } + ) -> InsertionResult; /// Copy a subgraph from another hugr into this one, under a given root node. /// @@ -266,13 +193,13 @@ pub trait HugrMut: HugrMutInternals { root: Self::Node, other: &H, subgraph: &SiblingSubgraph, - ) -> HashMap { - panic_invalid_node(self, root); - self.hugr_mut().insert_subgraph(root, other, subgraph) - } + ) -> HashMap; /// Applies a rewrite to the graph. - fn apply_rewrite(&mut self, rw: impl Rewrite) -> Result + fn apply_rewrite( + &mut self, + rw: impl Rewrite, + ) -> Result where Self: Sized, { @@ -286,7 +213,7 @@ pub trait HugrMut: HugrMutInternals { /// /// See [`ExtensionRegistry::register_updated`] for more information. fn use_extension(&mut self, extension: impl Into>) { - self.hugr_mut().extensions.register_updated(extension); + self.extensions_mut().register_updated(extension); } /// Extend the set of extensions used by the hugr with the extensions in the @@ -302,12 +229,7 @@ pub trait HugrMut: HugrMutInternals { where ExtensionRegistry: Extend, { - self.hugr_mut().extensions.extend(registry); - } - - /// Returns a mutable reference to the extension registry for this hugr. - fn extensions_mut(&mut self) -> &mut ExtensionRegistry { - &mut self.hugr_mut().extensions + self.extensions_mut().extend(registry); } } @@ -342,11 +264,10 @@ fn translate_indices( } /// Impl for non-wrapped Hugrs. Overwrites the recursive default-impls to directly use the hugr. -impl + AsMut> HugrMut for T { +impl HugrMut for Hugr { fn add_node_with_parent(&mut self, parent: Node, node: impl Into) -> Node { let node = self.as_mut().add_node(node.into()); - self.as_mut() - .hierarchy + self.hierarchy .push_child(node.pg_index(), parent.pg_index()) .expect("Inserting a newly-created node into the hierarchy should never fail."); node @@ -354,8 +275,7 @@ impl + AsMut> HugrMut for T fn add_node_before(&mut self, sibling: Node, nodetype: impl Into) -> Node { let node = self.as_mut().add_node(nodetype.into()); - self.as_mut() - .hierarchy + self.hierarchy .insert_before(node.pg_index(), sibling.pg_index()) .expect("Inserting a newly-created node into the hierarchy should never fail."); node @@ -363,8 +283,7 @@ impl + AsMut> HugrMut for T fn add_node_after(&mut self, sibling: Node, op: impl Into) -> Node { let node = self.as_mut().add_node(op.into()); - self.as_mut() - .hierarchy + self.hierarchy .insert_after(node.pg_index(), sibling.pg_index()) .expect("Inserting a newly-created node into the hierarchy should never fail."); node @@ -372,9 +291,19 @@ impl + AsMut> HugrMut for T fn remove_node(&mut self, node: Node) -> OpType { panic_invalid_non_root(self, node); - self.as_mut().hierarchy.remove(node.pg_index()); - self.as_mut().graph.remove_node(node.pg_index()); - self.as_mut().op_types.take(node.pg_index()) + self.hierarchy.remove(node.pg_index()); + self.graph.remove_node(node.pg_index()); + self.op_types.take(node.pg_index()) + } + + fn remove_subtree(&mut self, node: Node) { + panic_invalid_non_root(self, node); + let mut queue = VecDeque::new(); + queue.push_back(node); + while let Some(n) = queue.pop_front() { + queue.extend(self.children(n)); + self.remove_node(n); + } } fn connect( @@ -388,8 +317,7 @@ impl + AsMut> HugrMut for T let dst_port = dst_port.into(); panic_invalid_port(self, src, src_port); panic_invalid_port(self, dst, dst_port); - self.as_mut() - .graph + self.graph .link_nodes( src.pg_index(), src_port.index(), @@ -404,11 +332,10 @@ impl + AsMut> HugrMut for T let offset = port.pg_offset(); panic_invalid_port(self, node, port); let port = self - .as_mut() .graph .port_index(node.pg_index(), offset) .expect("The port should exist at this point."); - self.as_mut().graph.unlink_port(port); + self.graph.unlink_port(port); } fn add_other_edge(&mut self, src: Node, dst: Node) -> (OutgoingPort, IncomingPort) { @@ -429,15 +356,15 @@ impl + AsMut> HugrMut for T root: Self::Node, mut other: Hugr, ) -> InsertionResult { - let (new_root, node_map) = insert_hugr_internal(self.as_mut(), root, &other); + let (new_root, node_map) = insert_hugr_internal(self, root, &other); // Update the optypes and metadata, taking them from the other graph. // // No need to compute each node's extensions here, as we merge `other.extensions` directly. for (&node, &new_node) in node_map.iter() { let optype = other.op_types.take(node); - self.as_mut().op_types.set(new_node, optype); + self.op_types.set(new_node, optype); let meta = other.metadata.take(node); - self.as_mut().metadata.set(new_node, meta); + self.metadata.set(new_node, meta); } debug_assert_eq!( Some(&new_root.pg_index()), @@ -455,15 +382,15 @@ impl + AsMut> HugrMut for T root: Self::Node, other: &H, ) -> InsertionResult { - let (new_root, node_map) = insert_hugr_internal(self.as_mut(), root, other); + let (new_root, node_map) = insert_hugr_internal(self, root, other); // Update the optypes and metadata, copying them from the other graph. // // No need to compute each node's extensions here, as we merge `other.extensions` directly. for (&node, &new_node) in node_map.iter() { let nodetype = other.get_optype(other.get_node(node)); - self.as_mut().op_types.set(new_node, nodetype.clone()); + self.op_types.set(new_node, nodetype.clone()); let meta = other.base_hugr().metadata.get(node); - self.as_mut().metadata.set(new_node, meta.clone()); + self.metadata.set(new_node, meta.clone()); } debug_assert_eq!( Some(&new_root.pg_index()), @@ -494,13 +421,13 @@ impl + AsMut> HugrMut for T |node, ctx| ctx.contains(&node), context, ); - let node_map = insert_subgraph_internal(self.as_mut(), root, other, &portgraph); + let node_map = insert_subgraph_internal(self, root, other, &portgraph); // Update the optypes and metadata, copying them from the other graph. for (&node, &new_node) in node_map.iter() { let nodetype = other.get_optype(other.get_node(node)); - self.as_mut().op_types.set(new_node, nodetype.clone()); + self.op_types.set(new_node, nodetype.clone()); let meta = other.base_hugr().metadata.get(node); - self.as_mut().metadata.set(new_node, meta.clone()); + self.metadata.set(new_node, meta.clone()); // Add the required extensions to the registry. if let Ok(exts) = nodetype.used_extensions() { self.use_extensions(exts); @@ -519,7 +446,7 @@ impl + AsMut> HugrMut for T let root2 = descendants.next(); debug_assert_eq!(root2, Some(root.pg_index())); let nodes = Vec::from_iter(descendants); - let node_map = portgraph::view::Subgraph::with_nodes(&mut self.as_mut().graph, nodes) + let node_map = portgraph::view::Subgraph::with_nodes(&mut self.graph, nodes) .copy_in_parent() .expect("Is a MultiPortGraph"); let node_map = translate_indices(|n| self.get_node(n), |n| self.get_node(n), node_map) @@ -538,9 +465,9 @@ impl + AsMut> HugrMut for T (None, op) => op.clone(), (Some(subst), op) => op.substitute(subst), }; - self.as_mut().op_types.set(new_node.pg_index(), new_optype); + self.op_types.set(new_node.pg_index(), new_optype); let meta = self.base_hugr().metadata.get(node.pg_index()).clone(); - self.as_mut().metadata.set(new_node.pg_index(), meta); + self.metadata.set(new_node.pg_index(), meta); } node_map } @@ -624,22 +551,20 @@ fn insert_subgraph_internal( /// Panic if [`HugrView::valid_node`] fails. #[track_caller] pub(super) fn panic_invalid_node(hugr: &H, node: H::Node) { + // TODO: When stacking hugr wrappers, this gets called for every layer. + // Should we `cfg!(debug_assertions)` this? Benchmark and see if it matters. if !hugr.valid_node(node) { - panic!( - "Received an invalid node {node} while mutating a HUGR:\n\n {}", - hugr.mermaid_string() - ); + panic!("Received an invalid node {node} while mutating a HUGR.",); } } /// Panic if [`HugrView::valid_non_root`] fails. #[track_caller] pub(super) fn panic_invalid_non_root(hugr: &H, node: H::Node) { + // TODO: When stacking hugr wrappers, this gets called for every layer. + // Should we `cfg!(debug_assertions)` this? Benchmark and see if it matters. if !hugr.valid_non_root(node) { - panic!( - "Received an invalid non-root node {node} while mutating a HUGR:\n\n {}", - hugr.mermaid_string() - ); + panic!("Received an invalid non-root node {node} while mutating a HUGR.",); } } @@ -651,15 +576,14 @@ pub(super) fn panic_invalid_port( port: impl Into, ) { let port = port.into(); + // TODO: When stacking hugr wrappers, this gets called for every layer. + // Should we `cfg!(debug_assertions)` this? Benchmark and see if it matters. if hugr .portgraph() .port_index(node.pg_index(), port.pg_offset()) .is_none() { - panic!( - "Received an invalid port {port} for node {node} while mutating a HUGR:\n\n {}", - hugr.mermaid_string() - ); + panic!("Received an invalid port {port} for node {node} while mutating a HUGR"); } } diff --git a/hugr-core/src/hugr/internal.rs b/hugr-core/src/hugr/internal.rs index 6dab3adc0..8892c3b11 100644 --- a/hugr-core/src/hugr/internal.rs +++ b/hugr-core/src/hugr/internal.rs @@ -2,19 +2,17 @@ use std::borrow::Cow; use std::ops::Range; -use std::rc::Rc; -use std::sync::Arc; +use std::sync::OnceLock; -use delegate::delegate; use itertools::Itertools; use portgraph::{LinkMut, LinkView, MultiPortGraph, PortMut, PortOffset, PortView}; +use crate::extension::ExtensionRegistry; use crate::ops::handle::NodeHandle; -use crate::ops::{OpTag, OpTrait}; use crate::{Direction, Hugr, Node}; use super::hugrmut::{panic_invalid_node, panic_invalid_non_root}; -use super::{HugrError, OpType, RootTagged}; +use super::{HugrError, NodeMetadataMap, OpType, RootTagged}; /// Trait for accessing the internals of a Hugr(View). /// @@ -46,10 +44,17 @@ pub trait HugrInternals { fn root_node(&self) -> Self::Node; /// Convert a node to a portgraph node index. - fn get_pg_index(&self, node: Self::Node) -> portgraph::NodeIndex; + fn get_pg_index(&self, node: impl NodeHandle) -> portgraph::NodeIndex; /// Convert a portgraph node index to a node. fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node; + + /// Returns a metadata entry associated with a node. + /// + /// # Panics + /// + /// If the node is not in the graph. + fn node_metadata_map(&self, node: Self::Node) -> &NodeMetadataMap; } impl HugrInternals for Hugr { @@ -80,145 +85,41 @@ impl HugrInternals for Hugr { self.root.into() } - fn get_pg_index(&self, node: Self::Node) -> portgraph::NodeIndex { - node.pg_index() + #[inline] + fn get_pg_index(&self, node: impl NodeHandle) -> portgraph::NodeIndex { + node.node().pg_index() } + #[inline] fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node { index.into() } -} - -impl HugrInternals for &T { - type Portgraph<'p> - = T::Portgraph<'p> - where - Self: 'p; - type Node = T::Node; - - delegate! { - to (**self) { - fn portgraph(&self) -> Self::Portgraph<'_>; - fn hierarchy(&self) -> Cow<'_, portgraph::Hierarchy>; - fn base_hugr(&self) -> &Hugr; - fn root_node(&self) -> Self::Node; - fn get_pg_index(&self, node: Self::Node) -> portgraph::NodeIndex; - fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node; - } - } -} -impl HugrInternals for &mut T { - type Portgraph<'p> - = T::Portgraph<'p> - where - Self: 'p; - type Node = T::Node; - - delegate! { - to (**self) { - fn portgraph(&self) -> Self::Portgraph<'_>; - fn hierarchy(&self) -> Cow<'_, portgraph::Hierarchy>; - fn base_hugr(&self) -> &Hugr; - fn root_node(&self) -> Self::Node; - fn get_pg_index(&self, node: Self::Node) -> portgraph::NodeIndex; - fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node; - } - } -} - -impl HugrInternals for Rc { - type Portgraph<'p> - = T::Portgraph<'p> - where - Self: 'p; - type Node = T::Node; - - delegate! { - to (**self) { - fn portgraph(&self) -> Self::Portgraph<'_>; - fn hierarchy(&self) -> Cow<'_, portgraph::Hierarchy>; - fn base_hugr(&self) -> &Hugr; - fn root_node(&self) -> Self::Node; - fn get_pg_index(&self, node: Self::Node) -> portgraph::NodeIndex; - fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node; - } - } -} - -impl HugrInternals for Arc { - type Portgraph<'p> - = T::Portgraph<'p> - where - Self: 'p; - type Node = T::Node; - - delegate! { - to (**self) { - fn portgraph(&self) -> Self::Portgraph<'_>; - fn hierarchy(&self) -> Cow<'_, portgraph::Hierarchy>; - fn base_hugr(&self) -> &Hugr; - fn root_node(&self) -> Self::Node; - fn get_pg_index(&self, node: Self::Node) -> portgraph::NodeIndex; - fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node; - } - } -} - -impl HugrInternals for Box { - type Portgraph<'p> - = T::Portgraph<'p> - where - Self: 'p; - type Node = T::Node; - - delegate! { - to (**self) { - fn portgraph(&self) -> Self::Portgraph<'_>; - fn hierarchy(&self) -> Cow<'_, portgraph::Hierarchy>; - fn base_hugr(&self) -> &Hugr; - fn root_node(&self) -> Self::Node; - fn get_pg_index(&self, node: Self::Node) -> portgraph::NodeIndex; - fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node; - } + fn node_metadata_map(&self, node: Self::Node) -> &NodeMetadataMap { + static EMPTY: OnceLock = OnceLock::new(); + panic_invalid_node(self, node); + let map = self.metadata.get(node.pg_index()).as_ref(); + map.unwrap_or(EMPTY.get_or_init(Default::default)) } } -impl HugrInternals for Cow<'_, T> { - type Portgraph<'p> - = T::Portgraph<'p> - where - Self: 'p; - type Node = T::Node; - - delegate! { - to self.as_ref() { - fn portgraph(&self) -> Self::Portgraph<'_>; - fn hierarchy(&self) -> Cow<'_, portgraph::Hierarchy>; - fn base_hugr(&self) -> &Hugr; - fn root_node(&self) -> Self::Node; - fn get_pg_index(&self, node: Self::Node) -> portgraph::NodeIndex; - fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node; - } - } -} /// Trait for accessing the mutable internals of a Hugr(Mut). /// /// Specifically, this trait lets you apply arbitrary modifications that may /// invalidate the HUGR. -pub trait HugrMutInternals: RootTagged { - /// Returns the Hugr at the base of a chain of views. - fn hugr_mut(&mut self) -> &mut Hugr; +pub trait HugrMutInternals: RootTagged { + /// Set root node of the HUGR. + /// + /// This should be an existing node in the HUGR. Most operations use the + /// root node as a starting point for traversal. + fn set_root(&mut self, root: Self::Node); /// Set the number of ports on a node. This may invalidate the node's `PortIndex`. /// /// # Panics /// /// If the node is not in the graph. - fn set_num_ports(&mut self, node: Node, incoming: usize, outgoing: usize) { - panic_invalid_node(self, node); - self.hugr_mut().set_num_ports(node, incoming, outgoing) - } + fn set_num_ports(&mut self, node: Self::Node, incoming: usize, outgoing: usize); /// Alter the number of ports on a node and returns a range with the new /// port offsets, if any. This may invalidate the node's `PortIndex`. @@ -231,10 +132,7 @@ pub trait HugrMutInternals: RootTagged { /// # Panics /// /// If the node is not in the graph. - fn add_ports(&mut self, node: Node, direction: Direction, amount: isize) -> Range { - panic_invalid_node(self, node); - self.hugr_mut().add_ports(node, direction, amount) - } + fn add_ports(&mut self, node: Self::Node, direction: Direction, amount: isize) -> Range; /// Insert `amount` new ports for a node, starting at `index`. The /// `direction` parameter specifies whether to add ports to the incoming or @@ -247,14 +145,11 @@ pub trait HugrMutInternals: RootTagged { /// If the node is not in the graph. fn insert_ports( &mut self, - node: Node, + node: Self::Node, direction: Direction, index: usize, amount: usize, - ) -> Range { - panic_invalid_node(self, node); - self.hugr_mut().insert_ports(node, direction, index, amount) - } + ) -> Range; /// Sets the parent of a node. /// @@ -263,11 +158,7 @@ pub trait HugrMutInternals: RootTagged { /// # Panics /// /// If either the node or the parent is not in the graph. - fn set_parent(&mut self, node: Node, parent: Node) { - panic_invalid_node(self, parent); - panic_invalid_non_root(self, node); - self.hugr_mut().set_parent(node, parent); - } + fn set_parent(&mut self, node: Self::Node, parent: Self::Node); /// Move a node in the hierarchy to be the subsequent sibling of another /// node. @@ -279,11 +170,7 @@ pub trait HugrMutInternals: RootTagged { /// # Panics /// /// If either node is not in the graph, or if it is a root. - fn move_after_sibling(&mut self, node: Node, after: Node) { - panic_invalid_non_root(self, node); - panic_invalid_non_root(self, after); - self.hugr_mut().move_after_sibling(node, after); - } + fn move_after_sibling(&mut self, node: Self::Node, after: Self::Node); /// Move a node in the hierarchy to be the prior sibling of another node. /// @@ -294,11 +181,7 @@ pub trait HugrMutInternals: RootTagged { /// # Panics /// /// If either node is not in the graph, or if it is a root. - fn move_before_sibling(&mut self, node: Node, before: Node) { - panic_invalid_non_root(self, node); - panic_invalid_non_root(self, before); - self.hugr_mut().move_before_sibling(node, before) - } + fn move_before_sibling(&mut self, node: Self::Node, before: Self::Node); /// Replace the OpType at node and return the old OpType. /// In general this invalidates the ports, which may need to be resized to @@ -306,7 +189,8 @@ pub trait HugrMutInternals: RootTagged { /// /// Returns the old OpType. /// - /// TODO: Add a version which ignores input extensions + /// If the module root is set to a non-module operation the hugr will + /// become invalid. /// /// # Errors /// @@ -316,48 +200,68 @@ pub trait HugrMutInternals: RootTagged { /// # Panics /// /// If the node is not in the graph. - fn replace_op(&mut self, node: Node, op: impl Into) -> Result { - panic_invalid_node(self, node); - let op = op.into(); - if node == self.root() && !Self::RootHandle::TAG.is_superset(op.tag()) { - return Err(HugrError::InvalidTag { - required: Self::RootHandle::TAG, - actual: op.tag(), - }); - } - self.hugr_mut().replace_op(node, op) - } + fn replace_op(&mut self, node: Self::Node, op: impl Into) -> Result; /// Gets a mutable reference to the optype. /// /// Changing this may invalidate the ports, which may need to be resized to /// match the OpType signature. /// - /// Will panic for the root node unless [`Self::RootHandle`](RootTagged::RootHandle) - /// is [OpTag::Any], as mutation could invalidate the bound. - fn optype_mut(&mut self, node: Node) -> &mut OpType { - if Self::RootHandle::TAG.is_superset(OpTag::Any) { - panic_invalid_node(self, node); - } else { - panic_invalid_non_root(self, node); - } - self.hugr_mut().op_types.get_mut(node.pg_index()) - } + /// Mutating the root node operation may invalidate the root tag. + /// + /// Mutating the module root into a non-module operation will invalidate the hugr. + /// + /// # Panics + /// + /// If the node is not in the graph. + fn optype_mut(&mut self, node: Self::Node) -> &mut OpType; + + /// Returns a metadata entry associated with a node. + /// + /// # Panics + /// + /// If the node is not in the graph. + fn node_metadata_map_mut(&mut self, node: Self::Node) -> &mut NodeMetadataMap; + + /// Returns a mutable reference to the extension registry for this hugr, + /// containing all extensions required to define the operations and types in + /// the hugr. + fn extensions_mut(&mut self) -> &mut ExtensionRegistry; } /// Impl for non-wrapped Hugrs. Overwrites the recursive default-impls to directly use the hugr. -impl + AsMut> HugrMutInternals for T { - fn hugr_mut(&mut self) -> &mut Hugr { - self.as_mut() +impl HugrMutInternals for Hugr { + fn set_root(&mut self, root: Node) { + panic_invalid_node(self, root); + self.root = self.get_pg_index(root); } #[inline] fn set_num_ports(&mut self, node: Node, incoming: usize, outgoing: usize) { - self.hugr_mut() - .graph + panic_invalid_node(self, node); + self.graph .set_num_ports(node.pg_index(), incoming, outgoing, |_, _| {}) } + fn add_ports(&mut self, node: Node, direction: Direction, amount: isize) -> Range { + panic_invalid_node(self, node); + let mut incoming = self.graph.num_inputs(node.pg_index()); + let mut outgoing = self.graph.num_outputs(node.pg_index()); + let increment = |num: &mut usize| { + let new = num.saturating_add_signed(amount); + let range = *num..new; + *num = new; + range + }; + let range = match direction { + Direction::Incoming => increment(&mut incoming), + Direction::Outgoing => increment(&mut outgoing), + }; + self.graph + .set_num_ports(node.pg_index(), incoming, outgoing, |_, _| {}); + range + } + fn insert_ports( &mut self, node: Node, @@ -365,6 +269,7 @@ impl + AsMut> HugrMutInterna index: usize, amount: usize, ) -> Range { + panic_invalid_node(self, node); let old_num_ports = self.base_hugr().graph.num_ports(node.pg_index(), direction); self.add_ports(node, direction, amount as isize); @@ -383,10 +288,9 @@ impl + AsMut> HugrMutInterna .port_links(from_port_index) .map(|(_, to_subport)| to_subport.port()) .collect_vec(); - self.hugr_mut().graph.unlink_port(from_port_index); + self.graph.unlink_port(from_port_index); for linked_port_index in linked_ports { let _ = self - .hugr_mut() .graph .link_ports(to_port_index, linked_port_index) .expect("Ports exist"); @@ -395,53 +299,55 @@ impl + AsMut> HugrMutInterna index..index + amount } - fn add_ports(&mut self, node: Node, direction: Direction, amount: isize) -> Range { - let mut incoming = self.hugr_mut().graph.num_inputs(node.pg_index()); - let mut outgoing = self.hugr_mut().graph.num_outputs(node.pg_index()); - let increment = |num: &mut usize| { - let new = num.saturating_add_signed(amount); - let range = *num..new; - *num = new; - range - }; - let range = match direction { - Direction::Incoming => increment(&mut incoming), - Direction::Outgoing => increment(&mut outgoing), - }; - self.hugr_mut() - .graph - .set_num_ports(node.pg_index(), incoming, outgoing, |_, _| {}); - range - } - fn set_parent(&mut self, node: Node, parent: Node) { - self.hugr_mut().hierarchy.detach(node.pg_index()); - self.hugr_mut() - .hierarchy + panic_invalid_node(self, parent); + panic_invalid_node(self, node); + self.hierarchy.detach(node.pg_index()); + self.hierarchy .push_child(node.pg_index(), parent.pg_index()) .expect("Inserting a newly-created node into the hierarchy should never fail."); } fn move_after_sibling(&mut self, node: Node, after: Node) { - self.hugr_mut().hierarchy.detach(node.pg_index()); - self.hugr_mut() - .hierarchy + panic_invalid_non_root(self, node); + panic_invalid_non_root(self, after); + self.hierarchy.detach(node.pg_index()); + self.hierarchy .insert_after(node.pg_index(), after.pg_index()) .expect("Inserting a newly-created node into the hierarchy should never fail."); } fn move_before_sibling(&mut self, node: Node, before: Node) { - self.hugr_mut().hierarchy.detach(node.pg_index()); - self.hugr_mut() - .hierarchy + panic_invalid_non_root(self, node); + panic_invalid_non_root(self, before); + self.hierarchy.detach(node.pg_index()); + self.hierarchy .insert_before(node.pg_index(), before.pg_index()) .expect("Inserting a newly-created node into the hierarchy should never fail."); } fn replace_op(&mut self, node: Node, op: impl Into) -> Result { + panic_invalid_node(self, node); // We know RootHandle=Node here so no need to check Ok(std::mem::replace(self.optype_mut(node), op.into())) } + + fn optype_mut(&mut self, node: Self::Node) -> &mut OpType { + panic_invalid_node(self, node); + let node = self.get_pg_index(node); + self.op_types.get_mut(node) + } + + fn node_metadata_map_mut(&mut self, node: Self::Node) -> &mut NodeMetadataMap { + panic_invalid_node(self, node); + self.metadata + .get_mut(node.pg_index()) + .get_or_insert_with(Default::default) + } + + fn extensions_mut(&mut self) -> &mut ExtensionRegistry { + &mut self.extensions + } } #[cfg(test)] diff --git a/hugr-core/src/hugr/rewrite.rs b/hugr-core/src/hugr/rewrite.rs index 7c4374b65..d2b0fe14d 100644 --- a/hugr-core/src/hugr/rewrite.rs +++ b/hugr-core/src/hugr/rewrite.rs @@ -9,7 +9,8 @@ mod port_types; pub mod replace; pub mod simple_replace; -use crate::{Hugr, HugrView, Node}; +use crate::core::HugrNode; +use crate::{Hugr, HugrView}; pub use port_types::{BoundaryPort, HostPort, ReplacementPort}; pub use simple_replace::{SimpleReplacement, SimpleReplacementError}; @@ -17,6 +18,8 @@ use super::HugrMut; /// An operation that can be applied to mutate a Hugr pub trait Rewrite { + /// The node type used by the target Hugr. + type Node: HugrNode; /// The type of Error with which this Rewrite may fail type Error: std::error::Error; /// The type returned on successful application of the rewrite. @@ -29,7 +32,7 @@ pub trait Rewrite { /// Checks whether the rewrite would succeed on the specified Hugr. /// If this call succeeds, [self.apply] should also succeed on the same `h` /// If this calls fails, [self.apply] would fail with the same error. - fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error>; + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error>; /// Mutate the specified Hugr, or fail with an error. /// Returns [`Self::ApplyResult`] if successful. @@ -39,14 +42,17 @@ pub trait Rewrite { /// May panic if-and-only-if `h` would have failed [Hugr::validate]; that is, /// implementations may begin with `assert!(h.validate())`, with `debug_assert!(h.validate())` /// being preferred. - fn apply(self, h: &mut impl HugrMut) -> Result; + fn apply( + self, + h: &mut impl HugrMut, + ) -> Result; /// Returns a set of nodes referenced by the rewrite. Modifying any of these /// nodes will invalidate it. /// /// Two `impl Rewrite`s can be composed if their invalidation sets are /// disjoint. - fn invalidation_set(&self) -> impl Iterator; + fn invalidation_set(&self) -> impl Iterator; } /// Wraps any rewrite into a transaction (i.e. that has no effect upon failure) @@ -57,15 +63,16 @@ pub struct Transactional { // Note we might like to constrain R to Rewrite but this // is not yet supported, https://github.com/rust-lang/rust/issues/92827 impl Rewrite for Transactional { + type Node = R::Node; type Error = R::Error; type ApplyResult = R::ApplyResult; const UNCHANGED_ON_FAILURE: bool = true; - fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { self.underlying.verify(h) } - fn apply(self, h: &mut impl HugrMut) -> Result { + fn apply(self, h: &mut impl HugrMut) -> Result { if R::UNCHANGED_ON_FAILURE { return self.underlying.apply(h); } @@ -86,7 +93,7 @@ impl Rewrite for Transactional { } #[inline] - fn invalidation_set(&self) -> impl Iterator { + fn invalidation_set(&self) -> impl Iterator { self.underlying.invalidation_set() } } diff --git a/hugr-core/src/hugr/rewrite/consts.rs b/hugr-core/src/hugr/rewrite/consts.rs index c112dfc57..ac657bf91 100644 --- a/hugr-core/src/hugr/rewrite/consts.rs +++ b/hugr-core/src/hugr/rewrite/consts.rs @@ -25,6 +25,7 @@ pub enum RemoveError { } impl Rewrite for RemoveLoadConstant { + type Node = Node; type Error = RemoveError; // The Const node the LoadConstant was connected to. @@ -50,7 +51,7 @@ impl Rewrite for RemoveLoadConstant { Ok(()) } - fn apply(self, h: &mut impl HugrMut) -> Result { + fn apply(self, h: &mut impl HugrMut) -> Result { self.verify(h)?; let node = self.0; let source = h @@ -73,6 +74,7 @@ impl Rewrite for RemoveLoadConstant { pub struct RemoveConst(pub Node); impl Rewrite for RemoveConst { + type Node = Node; type Error = RemoveError; // The parent of the Const node. @@ -94,7 +96,7 @@ impl Rewrite for RemoveConst { Ok(()) } - fn apply(self, h: &mut impl HugrMut) -> Result { + fn apply(self, h: &mut impl HugrMut) -> Result { self.verify(h)?; let node = self.0; let parent = h diff --git a/hugr-core/src/hugr/rewrite/inline_call.rs b/hugr-core/src/hugr/rewrite/inline_call.rs index 9af9cd70a..6b1e7a958 100644 --- a/hugr-core/src/hugr/rewrite/inline_call.rs +++ b/hugr-core/src/hugr/rewrite/inline_call.rs @@ -33,6 +33,7 @@ impl InlineCall { } impl Rewrite for InlineCall { + type Node = Node; type ApplyResult = (); type Error = InlineCallError; fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { @@ -51,7 +52,7 @@ impl Rewrite for InlineCall { Ok(()) } - fn apply(self, h: &mut impl HugrMut) -> Result<(), Self::Error> { + fn apply(self, h: &mut impl HugrMut) -> Result<(), Self::Error> { self.verify(h)?; // Now we know we have a Call to a FuncDefn. let orig_func = h.static_source(self.0).unwrap(); diff --git a/hugr-core/src/hugr/rewrite/inline_dfg.rs b/hugr-core/src/hugr/rewrite/inline_dfg.rs index a8a09e0cc..8988df170 100644 --- a/hugr-core/src/hugr/rewrite/inline_dfg.rs +++ b/hugr-core/src/hugr/rewrite/inline_dfg.rs @@ -23,6 +23,7 @@ pub enum InlineDFGError { impl Rewrite for InlineDFG { /// Returns the removed nodes: the DFG, and its Input and Output children. + type Node = Node; type ApplyResult = [Node; 3]; type Error = InlineDFGError; @@ -39,7 +40,10 @@ impl Rewrite for InlineDFG { Ok(()) } - fn apply(self, h: &mut impl crate::hugr::HugrMut) -> Result { + fn apply( + self, + h: &mut impl crate::hugr::HugrMut, + ) -> Result { self.verify(h)?; let n = self.0.node(); let (oth_in, oth_out) = { diff --git a/hugr-core/src/hugr/rewrite/insert_identity.rs b/hugr-core/src/hugr/rewrite/insert_identity.rs index 2114be8fd..bde43413b 100644 --- a/hugr-core/src/hugr/rewrite/insert_identity.rs +++ b/hugr-core/src/hugr/rewrite/insert_identity.rs @@ -48,6 +48,7 @@ pub enum IdentityInsertionError { } impl Rewrite for IdentityInsertion { + type Node = Node; type Error = IdentityInsertionError; /// The inserted node. type ApplyResult = Node; @@ -65,7 +66,10 @@ impl Rewrite for IdentityInsertion { unimplemented!() } - fn apply(self, h: &mut impl HugrMut) -> Result { + fn apply( + self, + h: &mut impl HugrMut, + ) -> Result { let kind = h.get_optype(self.post_node).port_kind(self.post_port); let Some(EdgeKind::Value(ty)) = kind else { return Err(IdentityInsertionError::InvalidPortKind(kind)); diff --git a/hugr-core/src/hugr/rewrite/outline_cfg.rs b/hugr-core/src/hugr/rewrite/outline_cfg.rs index 7294bfcad..a76dbc6ee 100644 --- a/hugr-core/src/hugr/rewrite/outline_cfg.rs +++ b/hugr-core/src/hugr/rewrite/outline_cfg.rs @@ -6,14 +6,12 @@ use thiserror::Error; use crate::builder::{BlockBuilder, Container, Dataflow, SubContainer}; use crate::extension::ExtensionSet; -use crate::hugr::internal::HugrMutInternals; use crate::hugr::rewrite::Rewrite; -use crate::hugr::views::sibling::SiblingMut; use crate::hugr::{HugrMut, HugrView}; use crate::ops; use crate::ops::controlflow::BasicBlock; use crate::ops::dataflow::DataflowOpTrait; -use crate::ops::handle::{BasicBlockID, CfgID, NodeHandle}; +use crate::ops::handle::NodeHandle; use crate::ops::{DataflowBlock, OpType}; use crate::PortIndex; use crate::{type_row, Node}; @@ -95,6 +93,7 @@ impl OutlineCfg { } impl Rewrite for OutlineCfg { + type Node = Node; type Error = OutlineCfgError; /// The newly-created basic block, and the [CFG] node inside it /// @@ -185,8 +184,19 @@ impl Rewrite for OutlineCfg { let inner_exit = { // These operations do not fit within any CSG/SiblingMut // so we need to access the Hugr directly. - let h = h.hugr_mut(); - let inner_exit = h.children(cfg_node).exactly_one().ok().unwrap(); + // + // TODO: This is a temporary hack that won't be needed once Hugr Root Pointers get implemented. + // The commented line below are the correct ones, but they don't work yet. + // https://github.com/CQCL/hugr/issues/2029 + let hierarchy = h.hierarchy(); + let inner_exit = hierarchy + .children(h.get_pg_index(cfg_node)) + .exactly_one() + .ok() + .unwrap(); + let inner_exit = h.get_node(inner_exit); + //let inner_exit = h.children(cfg_node).exactly_one().ok().unwrap(); + // Entry node must be first h.move_before_sibling(entry, inner_exit); // And remaining nodes @@ -200,12 +210,7 @@ impl Rewrite for OutlineCfg { }; // 4(b). Reconnect exit edge to the new exit node within the inner CFG - // Use nested SiblingMut's in case the outer `h` is only a SiblingMut itself. - let mut in_bb_view: SiblingMut<'_, BasicBlockID> = - SiblingMut::try_new(h, new_block).unwrap(); - let mut in_cfg_view: SiblingMut<'_, CfgID> = - SiblingMut::try_new(&mut in_bb_view, cfg_node).unwrap(); - in_cfg_view.connect(exit, exit_port, inner_exit, 0); + h.connect(exit, exit_port, inner_exit, 0); Ok((new_block, cfg_node)) } @@ -252,10 +257,9 @@ mod test { HugrBuilder, ModuleBuilder, }; use crate::extension::prelude::usize_t; - use crate::hugr::views::sibling::SiblingMut; use crate::hugr::HugrMut; use crate::ops::constant::Value; - use crate::ops::handle::{BasicBlockID, CfgID, ConstID, NodeHandle}; + use crate::ops::handle::{BasicBlockID, ConstID, NodeHandle}; use crate::types::Signature; use crate::{Hugr, HugrView, Node}; use cool_asserts::assert_matches; @@ -457,11 +461,7 @@ mod test { h.output_neighbours(tail).collect::>(), HashSet::from([head, exit_node]) ); - outline_cfg_check_parents( - &mut SiblingMut::<'_, CfgID>::try_new(&mut h, cfg).unwrap(), - cfg, - vec![head, tail], - ); + outline_cfg_check_parents(&mut h, cfg, vec![head, tail]); h.validate().unwrap(); } @@ -491,7 +491,7 @@ mod test { } fn outline_cfg_check_parents( - h: &mut impl HugrMut, + h: &mut impl HugrMut, cfg: Node, blocks: Vec, ) -> (Node, Node, Node) { diff --git a/hugr-core/src/hugr/rewrite/replace.rs b/hugr-core/src/hugr/rewrite/replace.rs index 55c07d680..c2659cc5a 100644 --- a/hugr-core/src/hugr/rewrite/replace.rs +++ b/hugr-core/src/hugr/rewrite/replace.rs @@ -222,6 +222,7 @@ impl Replacement { } } impl Rewrite for Replacement { + type Node = Node; type Error = ReplaceError; /// Map from Node in replacement to corresponding Node in the result Hugr @@ -282,7 +283,7 @@ impl Rewrite for Replacement { Ok(()) } - fn apply(self, h: &mut impl HugrMut) -> Result { + fn apply(self, h: &mut impl HugrMut) -> Result { let parent = self.check_parent(h)?; // Calculate removed nodes here. (Does not include transfers, so enumerates only // nodes we are going to remove, individually, anyway; so no *asymptotic* speed penalty) @@ -343,7 +344,7 @@ impl Rewrite for Replacement { } fn transfer_edges<'a>( - h: &mut impl HugrMut, + h: &mut impl HugrMut, edges: impl Iterator, trans_src: impl Fn(Node) -> Result, trans_tgt: impl Fn(Node) -> Result, diff --git a/hugr-core/src/hugr/rewrite/simple_replace.rs b/hugr-core/src/hugr/rewrite/simple_replace.rs index b4ec37db1..5d3716dc0 100644 --- a/hugr-core/src/hugr/rewrite/simple_replace.rs +++ b/hugr-core/src/hugr/rewrite/simple_replace.rs @@ -275,6 +275,7 @@ impl SimpleReplacement { } impl Rewrite for SimpleReplacement { + type Node = Node; type Error = SimpleReplacementError; type ApplyResult = Vec<(Node, OpType)>; const UNCHANGED_ON_FAILURE: bool = true; @@ -283,7 +284,7 @@ impl Rewrite for SimpleReplacement { self.is_valid_rewrite(h) } - fn apply(self, h: &mut impl HugrMut) -> Result { + fn apply(self, h: &mut impl HugrMut) -> Result { self.is_valid_rewrite(h)?; let parent = self.subgraph.get_parent(h); diff --git a/hugr-core/src/hugr/views.rs b/hugr-core/src/hugr/views.rs index 09805d1f8..eb8059577 100644 --- a/hugr-core/src/hugr/views.rs +++ b/hugr-core/src/hugr/views.rs @@ -485,7 +485,7 @@ pub trait RootTagged: HugrView { /// /// The handle is guaranteed to be able to contain the operation returned by /// [`HugrView::root_type`]. - type RootHandle: NodeHandle; + type RootHandle: NodeHandle; } /// A common trait for views of a HUGR hierarchical subgraph. @@ -515,7 +515,8 @@ pub trait ExtractHugr: HugrView + Sized { } } -fn check_tag( +/// Check that the node in a HUGR can be represented by the required tag. +fn check_tag, N>( hugr: &impl HugrView, node: N, ) -> Result<(), HugrError> { @@ -527,18 +528,6 @@ fn check_tag( Ok(()) } -impl RootTagged for Hugr { - type RootHandle = Node; -} - -impl RootTagged for &Hugr { - type RootHandle = Node; -} - -impl RootTagged for &mut Hugr { - type RootHandle = Node; -} - // Explicit implementation to avoid cloning the Hugr. impl ExtractHugr for Hugr { fn extract_hugr(self) -> Hugr { diff --git a/hugr-core/src/hugr/views/descendants.rs b/hugr-core/src/hugr/views/descendants.rs index 6f87027ef..28a7d9f2d 100644 --- a/hugr-core/src/hugr/views/descendants.rs +++ b/hugr-core/src/hugr/views/descendants.rs @@ -179,7 +179,7 @@ where } #[inline] - fn get_pg_index(&self, node: Node) -> portgraph::NodeIndex { + fn get_pg_index(&self, node: impl NodeHandle) -> portgraph::NodeIndex { self.hugr.get_pg_index(node) } @@ -187,6 +187,10 @@ where fn get_node(&self, index: portgraph::NodeIndex) -> Node { self.hugr.get_node(index) } + + fn node_metadata_map(&self, node: Self::Node) -> &crate::hugr::NodeMetadataMap { + self.hugr.node_metadata_map(node) + } } #[cfg(test)] diff --git a/hugr-core/src/hugr/views/impls.rs b/hugr-core/src/hugr/views/impls.rs index 2cfc70104..928acba20 100644 --- a/hugr-core/src/hugr/views/impls.rs +++ b/hugr-core/src/hugr/views/impls.rs @@ -1,119 +1,285 @@ +//! Implementation of the core hugr traits for different wrappers of a `Hugr`. + use std::{borrow::Cow, rc::Rc, sync::Arc}; -use delegate::delegate; -use itertools::Either; +use super::HugrView; +use super::RootTagged; +use crate::hugr::internal::{HugrInternals, HugrMutInternals}; +use crate::hugr::HugrMut; +use crate::Hugr; +use crate::Node; -use super::{render::RenderConfig, HugrView, RootChecked}; -use crate::{ - extension::ExtensionRegistry, - hugr::{NodeMetadata, NodeMetadataMap, ValidationError}, - ops::OpType, - types::{PolyFuncType, Signature, Type}, - Direction, Hugr, IncomingPort, OutgoingPort, Port, -}; +macro_rules! hugr_internal_methods { + // The extra ident here is because invocations of the macro cannot pass `self` as argument + ($arg:ident, $e:expr) => { + delegate::delegate! { + to ({let $arg=self; $e}) { + fn portgraph(&self) -> Self::Portgraph<'_>; + fn hierarchy(&self) -> Cow<'_, portgraph::Hierarchy>; + fn base_hugr(&self) -> &crate::Hugr; + fn root_node(&self) -> Self::Node; + fn get_pg_index(&self, node: impl crate::ops::handle::NodeHandle) -> portgraph::NodeIndex; + fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node; + fn node_metadata_map(&self, node: Self::Node) -> &crate::hugr::NodeMetadataMap; + } + } + }; +} +pub(crate) use hugr_internal_methods; macro_rules! hugr_view_methods { // The extra ident here is because invocations of the macro cannot pass `self` as argument ($arg:ident, $e:expr) => { - delegate! { + delegate::delegate! { to ({let $arg=self; $e}) { fn root(&self) -> Self::Node; - fn root_type(&self) -> &OpType; + fn root_type(&self) -> &crate::ops::OpType; fn contains_node(&self, node: Self::Node) -> bool; fn valid_node(&self, node: Self::Node) -> bool; fn valid_non_root(&self, node: Self::Node) -> bool; fn get_parent(&self, node: Self::Node) -> Option; - fn get_optype(&self, node: Self::Node) -> &OpType; - fn get_metadata(&self, node: Self::Node, key: impl AsRef) -> Option<&NodeMetadata>; - fn get_node_metadata(&self, node: Self::Node) -> Option<&NodeMetadataMap>; + fn get_optype(&self, node: Self::Node) -> &crate::ops::OpType; + fn get_metadata(&self, node: Self::Node, key: impl AsRef) -> Option<&crate::hugr::NodeMetadata>; + fn get_node_metadata(&self, node: Self::Node) -> Option<&crate::hugr::NodeMetadataMap>; fn node_count(&self) -> usize; fn edge_count(&self) -> usize; fn nodes(&self) -> impl Iterator + Clone; - fn node_ports(&self, node: Self::Node, dir: Direction) -> impl Iterator + Clone; - fn node_outputs(&self, node: Self::Node) -> impl Iterator + Clone; - fn node_inputs(&self, node: Self::Node) -> impl Iterator + Clone; - fn all_node_ports(&self, node: Self::Node) -> impl Iterator + Clone; + fn node_ports(&self, node: Self::Node, dir: crate::Direction) -> impl Iterator + Clone; + fn node_outputs(&self, node: Self::Node) -> impl Iterator + Clone; + fn node_inputs(&self, node: Self::Node) -> impl Iterator + Clone; + fn all_node_ports(&self, node: Self::Node) -> impl Iterator + Clone; fn linked_ports( &self, node: Self::Node, - port: impl Into, - ) -> impl Iterator + Clone; + port: impl Into, + ) -> impl Iterator + Clone; fn all_linked_ports( &self, node: Self::Node, - dir: Direction, - ) -> Either< - impl Iterator, - impl Iterator, + dir: crate::Direction, + ) -> itertools::Either< + impl Iterator, + impl Iterator, >; - fn all_linked_outputs(&self, node: Self::Node) -> impl Iterator; - fn all_linked_inputs(&self, node: Self::Node) -> impl Iterator; - fn single_linked_port(&self, node: Self::Node, port: impl Into) -> Option<(Self::Node, Port)>; - fn single_linked_output(&self, node: Self::Node, port: impl Into) -> Option<(Self::Node, OutgoingPort)>; - fn single_linked_input(&self, node: Self::Node, port: impl Into) -> Option<(Self::Node, IncomingPort)>; - fn linked_outputs(&self, node: Self::Node, port: impl Into) -> impl Iterator; - fn linked_inputs(&self, node: Self::Node, port: impl Into) -> impl Iterator; - fn node_connections(&self, node: Self::Node, other: Self::Node) -> impl Iterator + Clone; - fn is_linked(&self, node: Self::Node, port: impl Into) -> bool; - fn num_ports(&self, node: Self::Node, dir: Direction) -> usize; + fn all_linked_outputs(&self, node: Self::Node) -> impl Iterator; + fn all_linked_inputs(&self, node: Self::Node) -> impl Iterator; + fn single_linked_port(&self, node: Self::Node, port: impl Into) -> Option<(Self::Node, crate::Port)>; + fn single_linked_output(&self, node: Self::Node, port: impl Into) -> Option<(Self::Node, crate::OutgoingPort)>; + fn single_linked_input(&self, node: Self::Node, port: impl Into) -> Option<(Self::Node, crate::IncomingPort)>; + fn linked_outputs(&self, node: Self::Node, port: impl Into) -> impl Iterator; + fn linked_inputs(&self, node: Self::Node, port: impl Into) -> impl Iterator; + fn node_connections(&self, node: Self::Node, other: Self::Node) -> impl Iterator + Clone; + fn is_linked(&self, node: Self::Node, port: impl Into) -> bool; + fn num_ports(&self, node: Self::Node, dir: crate::Direction) -> usize; fn num_inputs(&self, node: Self::Node) -> usize; fn num_outputs(&self, node: Self::Node) -> usize; fn children(&self, node: Self::Node) -> impl DoubleEndedIterator + Clone; fn first_child(&self, node: Self::Node) -> Option; - fn neighbours(&self, node: Self::Node, dir: Direction) -> impl Iterator + Clone; + fn neighbours(&self, node: Self::Node, dir: crate::Direction) -> impl Iterator + Clone; fn input_neighbours(&self, node: Self::Node) -> impl Iterator + Clone; fn output_neighbours(&self, node: Self::Node) -> impl Iterator + Clone; fn all_neighbours(&self, node: Self::Node) -> impl Iterator + Clone; fn get_io(&self, node: Self::Node) -> Option<[Self::Node; 2]>; - fn inner_function_type(&self) -> Option>; - fn poly_func_type(&self) -> Option; + fn inner_function_type(&self) -> Option>; + fn poly_func_type(&self) -> Option; // TODO: cannot use delegate here. `PetgraphWrapper` is a thin // wrapper around `Self`, so falling back to the default impl // should be harmless. // fn as_petgraph(&self) -> PetgraphWrapper<'_, Self>; fn mermaid_string(&self) -> String; - fn mermaid_string_with_config(&self, config: RenderConfig) -> String; + fn mermaid_string_with_config(&self, config: crate::hugr::views::render::RenderConfig) -> String; fn dot_string(&self) -> String; fn static_source(&self, node: Self::Node) -> Option; - fn static_targets(&self, node: Self::Node) -> Option>; - fn signature(&self, node: Self::Node) -> Option>; - fn value_types(&self, node: Self::Node, dir: Direction) -> impl Iterator; - fn in_value_types(&self, node: Self::Node) -> impl Iterator; - fn out_value_types(&self, node: Self::Node) -> impl Iterator; - fn extensions(&self) -> &ExtensionRegistry; - fn validate(&self) -> Result<(), ValidationError>; - fn validate_no_extensions(&self) -> Result<(), ValidationError>; + fn static_targets(&self, node: Self::Node) -> Option>; + fn signature(&self, node: Self::Node) -> Option>; + fn value_types(&self, node: Self::Node, dir: crate::Direction) -> impl Iterator; + fn in_value_types(&self, node: Self::Node) -> impl Iterator; + fn out_value_types(&self, node: Self::Node) -> impl Iterator; + fn extensions(&self) -> &crate::extension::ExtensionRegistry; + fn validate(&self) -> Result<(), crate::hugr::ValidationError>; + fn validate_no_extensions(&self) -> Result<(), crate::hugr::ValidationError>; } } } } +pub(crate) use hugr_view_methods; + +macro_rules! hugr_mut_internal_methods { + // The extra ident here is because invocations of the macro cannot pass `self` as argument + ($arg:ident, $e:expr) => { + delegate::delegate! { + to ({let $arg=self; $e}) { + fn set_root(&mut self, root: Self::Node); + fn set_num_ports(&mut self, node: Self::Node, incoming: usize, outgoing: usize); + fn add_ports(&mut self, node: Self::Node, direction: crate::Direction, amount: isize) -> std::ops::Range; + fn insert_ports(&mut self, node: Self::Node, direction: crate::Direction, index: usize, amount: usize) -> std::ops::Range; + fn set_parent(&mut self, node: Self::Node, parent: Self::Node); + fn move_after_sibling(&mut self, node: Self::Node, after: Self::Node); + fn move_before_sibling(&mut self, node: Self::Node, before: Self::Node); + fn replace_op(&mut self, node: Self::Node, op: impl Into) -> Result; + fn optype_mut(&mut self, node: Self::Node) -> &mut crate::ops::OpType; + fn node_metadata_map_mut(&mut self, node: Self::Node) -> &mut crate::hugr::NodeMetadataMap; + fn extensions_mut(&mut self) -> &mut crate::extension::ExtensionRegistry; + } + } + }; +} +pub(crate) use hugr_mut_internal_methods; + +macro_rules! hugr_mut_methods { + // The extra ident here is because invocations of the macro cannot pass `self` as argument + ($arg:ident, $e:expr) => { + delegate::delegate! { + to ({let $arg=self; $e}) { + fn add_node_with_parent(&mut self, parent: Self::Node, op: impl Into) -> Self::Node; + fn add_node_before(&mut self, sibling: Self::Node, nodetype: impl Into) -> Self::Node; + fn add_node_after(&mut self, sibling: Self::Node, op: impl Into) -> Self::Node; + fn remove_node(&mut self, node: Self::Node) -> crate::ops::OpType; + fn remove_subtree(&mut self, node: Self::Node); + fn copy_descendants(&mut self, root: Self::Node, new_parent: Self::Node, subst: Option) -> std::collections::BTreeMap; + fn connect(&mut self, src: Self::Node, src_port: impl Into, dst: Self::Node, dst_port: impl Into); + fn disconnect(&mut self, node: Self::Node, port: impl Into); + fn add_other_edge(&mut self, src: Self::Node, dst: Self::Node) -> (crate::OutgoingPort, crate::IncomingPort); + fn insert_hugr(&mut self, root: Self::Node, other: crate::Hugr) -> crate::hugr::hugrmut::InsertionResult; + fn insert_from_view(&mut self, root: Self::Node, other: &Other) -> crate::hugr::hugrmut::InsertionResult; + fn insert_subgraph(&mut self, root: Self::Node, other: &Other, subgraph: &crate::hugr::views::SiblingSubgraph) -> std::collections::HashMap; + } + } + }; +} +pub(crate) use hugr_mut_methods; + +// -------- Base Hugr implementation +impl RootTagged for Hugr { + type RootHandle = Node; +} + +// -------- Immutable borrow +impl HugrInternals for &T { + type Portgraph<'p> + = T::Portgraph<'p> + where + Self: 'p; + type Node = T::Node; + hugr_internal_methods! {this, *this} +} impl HugrView for &T { hugr_view_methods! {this, *this} } +impl RootTagged for &T { + type RootHandle = T::RootHandle; +} + +// -------- Mutable borrow +impl HugrInternals for &mut T { + type Portgraph<'p> + = T::Portgraph<'p> + where + Self: 'p; + type Node = T::Node; + hugr_internal_methods! {this, &**this} +} impl HugrView for &mut T { hugr_view_methods! {this, &**this} } +impl RootTagged for &mut T { + type RootHandle = T::RootHandle; +} +impl HugrMutInternals for &mut T { + hugr_mut_internal_methods! {this, &mut **this} +} +impl HugrMut for &mut T { + hugr_mut_methods! {this, &mut **this} +} + +// -------- Rc +impl HugrInternals for Rc { + type Portgraph<'p> + = T::Portgraph<'p> + where + Self: 'p; + type Node = T::Node; + hugr_internal_methods! {this, this.as_ref()} +} impl HugrView for Rc { hugr_view_methods! {this, this.as_ref()} } +impl RootTagged for Rc { + type RootHandle = T::RootHandle; +} + +// -------- Arc +impl HugrInternals for Arc { + type Portgraph<'p> + = T::Portgraph<'p> + where + Self: 'p; + type Node = T::Node; + hugr_internal_methods! {this, this.as_ref()} +} impl HugrView for Arc { hugr_view_methods! {this, this.as_ref()} } +impl RootTagged for Arc { + type RootHandle = T::RootHandle; +} + +// -------- Box +impl HugrInternals for Box { + type Portgraph<'p> + = T::Portgraph<'p> + where + Self: 'p; + type Node = T::Node; + hugr_internal_methods! {this, this.as_ref()} +} impl HugrView for Box { hugr_view_methods! {this, this.as_ref()} } +impl RootTagged for Box { + type RootHandle = T::RootHandle; +} +impl HugrMutInternals for Box { + hugr_mut_internal_methods! {this, this.as_mut()} +} +impl HugrMut for Box { + hugr_mut_methods! {this, this.as_mut()} +} +// -------- Cow +impl HugrInternals for Cow<'_, T> { + type Portgraph<'p> + = T::Portgraph<'p> + where + Self: 'p; + type Node = T::Node; + + hugr_internal_methods! {this, this.as_ref()} +} impl HugrView for Cow<'_, T> { hugr_view_methods! {this, this.as_ref()} } - -impl, Root> HugrView for RootChecked { - hugr_view_methods! {this, this.as_ref()} +impl RootTagged for Cow<'_, T> { + type RootHandle = T::RootHandle; +} +impl HugrMutInternals for Cow<'_, T> +where + T: HugrMutInternals + ToOwned, + ::Owned: HugrMutInternals, +{ + hugr_mut_internal_methods! {this, this.to_mut()} +} +impl HugrMut for Cow<'_, T> +where + T: HugrMut + ToOwned, + ::Owned: HugrMut, +{ + hugr_mut_methods! {this, this.to_mut()} } #[cfg(test)] diff --git a/hugr-core/src/hugr/views/root_checked.rs b/hugr-core/src/hugr/views/root_checked.rs index ba214241a..e0dcf3eb7 100644 --- a/hugr-core/src/hugr/views/root_checked.rs +++ b/hugr-core/src/hugr/views/root_checked.rs @@ -1,22 +1,20 @@ use std::borrow::Cow; use std::marker::PhantomData; -use delegate::delegate; -use portgraph::MultiPortGraph; - use crate::hugr::internal::{HugrInternals, HugrMutInternals}; use crate::hugr::{HugrError, HugrMut}; use crate::ops::handle::NodeHandle; +use crate::ops::OpTrait; use crate::{Hugr, Node}; -use super::{check_tag, RootTagged}; +use super::{check_tag, HugrView, RootTagged}; /// A view of the whole Hugr. /// (Just provides static checking of the type of the root node) #[derive(Clone)] pub struct RootChecked(H, PhantomData); -impl, Root: NodeHandle> RootChecked { +impl> RootChecked { /// Create a hierarchical view of a whole HUGR /// /// # Errors @@ -49,26 +47,21 @@ impl RootChecked<&mut Hugr, Root> { } } -impl, Root> HugrInternals for RootChecked { +impl HugrInternals for RootChecked { type Portgraph<'p> - = &'p MultiPortGraph + = H::Portgraph<'p> where Self: 'p; - type Node = Node; - - delegate! { - to self.as_ref() { - fn portgraph(&self) -> Self::Portgraph<'_>; - fn hierarchy(&self) -> Cow<'_, portgraph::Hierarchy>; - fn base_hugr(&self) -> &Hugr; - fn root_node(&self) -> Node; - fn get_pg_index(&self, node: Node) -> portgraph::NodeIndex; - fn get_node(&self, index: portgraph::NodeIndex) -> Node; - } - } + type Node = H::Node; + + super::impls::hugr_internal_methods! {this, &this.0} } -impl, Root: NodeHandle> RootTagged for RootChecked { +impl HugrView for RootChecked { + super::impls::hugr_view_methods! {this, &this.0} +} + +impl> RootTagged for RootChecked { type RootHandle = Root; } @@ -78,17 +71,41 @@ impl, Root> AsRef for RootChecked { } } -impl, Root> HugrMutInternals for RootChecked -where - Root: NodeHandle, -{ - #[inline(always)] - fn hugr_mut(&mut self) -> &mut Hugr { - self.0.hugr_mut() +impl> HugrMutInternals for RootChecked { + fn replace_op( + &mut self, + node: Self::Node, + op: impl Into, + ) -> Result { + let op = op.into(); + if node == self.root() && !Root::TAG.is_superset(op.tag()) { + return Err(HugrError::InvalidTag { + required: Root::TAG, + actual: op.tag(), + }); + } + self.0.replace_op(node, op) + } + + delegate::delegate! { + to (&mut self.0) { + fn set_root(&mut self, root: Self::Node); + fn set_num_ports(&mut self, node: Self::Node, incoming: usize, outgoing: usize); + fn add_ports(&mut self, node: Self::Node, direction: crate::Direction, amount: isize) -> std::ops::Range; + fn insert_ports(&mut self, node: Self::Node, direction: crate::Direction, index: usize, amount: usize) -> std::ops::Range; + fn set_parent(&mut self, node: Self::Node, parent: Self::Node); + fn move_after_sibling(&mut self, node: Self::Node, after: Self::Node); + fn move_before_sibling(&mut self, node: Self::Node, before: Self::Node); + fn optype_mut(&mut self, node: Self::Node) -> &mut crate::ops::OpType; + fn node_metadata_map_mut(&mut self, node: Self::Node) -> &mut crate::hugr::NodeMetadataMap; + fn extensions_mut(&mut self) -> &mut crate::extension::ExtensionRegistry; + } } } -impl, Root: NodeHandle> HugrMut for RootChecked {} +impl> HugrMut for RootChecked { + super::impls::hugr_mut_methods! {this, &mut this.0} +} #[cfg(test)] mod test { diff --git a/hugr-core/src/hugr/views/sibling.rs b/hugr-core/src/hugr/views/sibling.rs index f93b14cb4..4d15a9c48 100644 --- a/hugr-core/src/hugr/views/sibling.rs +++ b/hugr-core/src/hugr/views/sibling.rs @@ -6,8 +6,9 @@ use itertools::{Either, Itertools}; use portgraph::{LinkView, MultiPortGraph, PortView}; use crate::hugr::internal::HugrMutInternals; -use crate::hugr::{HugrError, HugrMut}; +use crate::hugr::{HugrError, HugrMut, NodeMetadataMap}; use crate::ops::handle::NodeHandle; +use crate::ops::OpTrait; use crate::{Direction, Hugr, Node, Port}; use super::{check_tag, ExtractHugr, HierarchyView, HugrInternals, HugrView, RootTagged}; @@ -212,7 +213,7 @@ where } #[inline] - fn get_pg_index(&self, node: Node) -> portgraph::NodeIndex { + fn get_pg_index(&self, node: impl NodeHandle) -> portgraph::NodeIndex { self.hugr.get_pg_index(node) } @@ -220,6 +221,11 @@ where fn get_node(&self, index: portgraph::NodeIndex) -> Node { self.hugr.get_node(index) } + + #[inline] + fn node_metadata_map(&self, node: Self::Node) -> &NodeMetadataMap { + self.hugr.node_metadata_map(node) + } } /// Mutable view onto a HUGR sibling graph. @@ -233,101 +239,113 @@ where /// [HugrView] methods may be slower than for an immutable [SiblingGraph] /// as the latter may cache information about the graph connectivity, /// whereas (in order to ease mutation) this does not. -pub struct SiblingMut<'g, Root = Node> { +pub struct SiblingMut<'g, H: HugrView, Root = Node> { /// The chosen root node. - root: Node, + root: H::Node, /// The rest of the HUGR. - hugr: &'g mut Hugr, + hugr: &'g mut H, /// The operation type of the root node. _phantom: std::marker::PhantomData, } -impl<'g, Root: NodeHandle> SiblingMut<'g, Root> { +impl<'g, H: HugrMut, Root: NodeHandle> SiblingMut<'g, H, Root> { /// Create a new SiblingMut from a base. /// Equivalent to [HierarchyView::try_new] but takes a *mutable* reference. - pub fn try_new(hugr: &'g mut Base, root: Node) -> Result { - if root == hugr.root() && !Base::RootHandle::TAG.is_superset(Root::TAG) { + pub fn try_new(hugr: &'g mut H, root: H::Node) -> Result { + if root == hugr.root() && !H::RootHandle::TAG.is_superset(Root::TAG) { return Err(HugrError::InvalidTag { - required: Base::RootHandle::TAG, + required: H::RootHandle::TAG, actual: Root::TAG, }); } check_tag::(hugr, root)?; Ok(Self { - hugr: hugr.hugr_mut(), + hugr, root, _phantom: std::marker::PhantomData, }) } } -impl ExtractHugr for SiblingMut<'_, Root> {} +impl> ExtractHugr for SiblingMut<'_, H, Root> {} -impl<'g, Root: NodeHandle> HugrInternals for SiblingMut<'g, Root> { +impl<'g, H: HugrMut, Root: NodeHandle> HugrInternals for SiblingMut<'g, H, Root> { type Portgraph<'p> = FlatRegionGraph<'p> where 'g: 'p, Root: 'p; - type Node = Node; + type Node = H::Node; + #[inline] fn portgraph(&self) -> Self::Portgraph<'_> { FlatRegionGraph::new( &self.base_hugr().graph, &self.base_hugr().hierarchy, - self.root.pg_index(), + self.get_pg_index(self.root), ) } + #[inline] fn base_hugr(&self) -> &Hugr { - self.hugr + self.hugr.base_hugr() } - fn root_node(&self) -> Node { + #[inline] + fn root_node(&self) -> Self::Node { self.root } #[inline] - fn get_pg_index(&self, node: Node) -> portgraph::NodeIndex { + fn get_pg_index(&self, node: impl NodeHandle) -> portgraph::NodeIndex { self.hugr.get_pg_index(node) } #[inline] - fn get_node(&self, index: portgraph::NodeIndex) -> Node { + fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node { self.hugr.get_node(index) } + + #[inline] + fn node_metadata_map(&self, node: Self::Node) -> &NodeMetadataMap { + self.hugr.node_metadata_map(node) + } } -impl HugrView for SiblingMut<'_, Root> { +impl> HugrView for SiblingMut<'_, H, Root> { impl_base_members! {} - fn contains_node(&self, node: Node) -> bool { + fn contains_node(&self, node: H::Node) -> bool { // Don't call self.get_parent(). That requires valid_node(node) // which infinitely-recurses back here. - node == self.root || self.base_hugr().get_parent(node) == Some(self.root) + node == self.root || self.hugr.get_parent(node) == Some(self.root) } - fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator + Clone { - self.base_hugr().node_ports(node, dir) + fn node_ports(&self, node: Self::Node, dir: Direction) -> impl Iterator + Clone { + self.hugr.node_ports(node, dir) } - fn all_node_ports(&self, node: Node) -> impl Iterator + Clone { - self.base_hugr().all_node_ports(node) + fn all_node_ports(&self, node: Self::Node) -> impl Iterator + Clone { + self.hugr.all_node_ports(node) } fn linked_ports( &self, - node: Node, + node: Self::Node, port: impl Into, - ) -> impl Iterator + Clone { + ) -> impl Iterator + Clone { self.hugr .linked_ports(node, port) .filter(|(n, _)| self.contains_node(*n)) } - fn node_connections(&self, node: Node, other: Node) -> impl Iterator + Clone { + fn node_connections( + &self, + node: Self::Node, + other: Self::Node, + ) -> impl Iterator + Clone { match self.contains_node(node) && self.contains_node(other) { // The nodes are not in the sibling graph false => Either::Left(iter::empty()), @@ -336,34 +354,66 @@ impl HugrView for SiblingMut<'_, Root> { } } - fn num_ports(&self, node: Node, dir: Direction) -> usize { - self.base_hugr().num_ports(node, dir) + fn num_ports(&self, node: Self::Node, dir: Direction) -> usize { + self.hugr.num_ports(node, dir) } - fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator + Clone { + fn neighbours( + &self, + node: Self::Node, + dir: Direction, + ) -> impl Iterator + Clone { self.hugr .neighbours(node, dir) .filter(|n| self.contains_node(*n)) } - fn all_neighbours(&self, node: Node) -> impl Iterator + Clone { + fn all_neighbours(&self, node: Self::Node) -> impl Iterator + Clone { self.hugr .all_neighbours(node) .filter(|n| self.contains_node(*n)) } } -impl RootTagged for SiblingMut<'_, Root> { +impl> RootTagged for SiblingMut<'_, H, Root> { type RootHandle = Root; } -impl HugrMutInternals for SiblingMut<'_, Root> { - fn hugr_mut(&mut self) -> &mut Hugr { - self.hugr +impl> HugrMutInternals for SiblingMut<'_, H, Root> { + fn replace_op( + &mut self, + node: Self::Node, + op: impl Into, + ) -> Result { + let op = op.into(); + if node == self.root() && !Root::TAG.is_superset(op.tag()) { + return Err(HugrError::InvalidTag { + required: Root::TAG, + actual: op.tag(), + }); + } + self.hugr.replace_op(node, op) + } + + delegate::delegate! { + to (&mut *self.hugr) { + fn set_root(&mut self, root: Self::Node); + fn set_num_ports(&mut self, node: Self::Node, incoming: usize, outgoing: usize); + fn add_ports(&mut self, node: Self::Node, direction: crate::Direction, amount: isize) -> std::ops::Range; + fn insert_ports(&mut self, node: Self::Node, direction: crate::Direction, index: usize, amount: usize) -> std::ops::Range; + fn set_parent(&mut self, node: Self::Node, parent: Self::Node); + fn move_after_sibling(&mut self, node: Self::Node, after: Self::Node); + fn move_before_sibling(&mut self, node: Self::Node, before: Self::Node); + fn optype_mut(&mut self, node: Self::Node) -> &mut crate::ops::OpType; + fn node_metadata_map_mut(&mut self, node: Self::Node) -> &mut crate::hugr::NodeMetadataMap; + fn extensions_mut(&mut self) -> &mut crate::extension::ExtensionRegistry; + } } } -impl HugrMut for SiblingMut<'_, Root> {} +impl> HugrMut for SiblingMut<'_, H, Root> { + super::impls::hugr_mut_methods! {this, &mut *this.hugr} +} #[cfg(test)] mod test { @@ -475,7 +525,7 @@ mod test { let mut def_region_hugr = hugr.clone(); let mut inner_region_hugr = hugr.clone(); - test_properties::( + test_properties::>( &hugr, def, inner, @@ -526,7 +576,7 @@ mod test { let root = simple_dfg_hugr.root(); let signature = simple_dfg_hugr.inner_function_type().unwrap().into_owned(); - let sib_mut = SiblingMut::::try_new(&mut simple_dfg_hugr, root); + let sib_mut = SiblingMut::<_, CfgID>::try_new(&mut simple_dfg_hugr, root); assert_eq!( sib_mut.err(), Some(HugrError::InvalidTag { @@ -535,7 +585,7 @@ mod test { }) ); - let mut sib_mut = SiblingMut::::try_new(&mut simple_dfg_hugr, root).unwrap(); + let mut sib_mut = SiblingMut::<_, DfgID>::try_new(&mut simple_dfg_hugr, root).unwrap(); let bad_nodetype: OpType = crate::ops::CFG { signature }.into(); assert_eq!( sib_mut.replace_op(sib_mut.root(), bad_nodetype.clone()), @@ -560,7 +610,7 @@ mod test { .unwrap() .into_owned(), }; - let mut sib_mut = SiblingMut::::try_new(&mut simple_dfg_hugr, root).unwrap(); + let mut sib_mut = SiblingMut::<_, DfgID>::try_new(&mut simple_dfg_hugr, root).unwrap(); // As expected, we cannot replace the root with a Case assert_eq!( sib_mut.replace_op(root, case_nodetype), @@ -570,7 +620,7 @@ mod test { }) ); - let nested_sib_mut = SiblingMut::::try_new(&mut sib_mut, root); + let nested_sib_mut = SiblingMut::<_, DataflowParentID>::try_new(&mut sib_mut, root); assert!(nested_sib_mut.is_err()); } diff --git a/hugr-llvm/src/utils/inline_constant_functions.rs b/hugr-llvm/src/utils/inline_constant_functions.rs index 28e664b97..a55072a99 100644 --- a/hugr-llvm/src/utils/inline_constant_functions.rs +++ b/hugr-llvm/src/utils/inline_constant_functions.rs @@ -11,12 +11,12 @@ fn const_fn_name(konst_n: Node) -> String { format!("const_fun_{}", konst_n.index()) } -pub fn inline_constant_functions(hugr: &mut impl HugrMut) -> Result<()> { +pub fn inline_constant_functions(hugr: &mut impl HugrMut) -> Result<()> { while inline_constant_functions_impl(hugr)? {} Ok(()) } -fn inline_constant_functions_impl(hugr: &mut impl HugrMut) -> Result { +fn inline_constant_functions_impl(hugr: &mut impl HugrMut) -> Result { let mut const_funs = vec![]; for n in hugr.nodes() { diff --git a/hugr-passes/src/composable.rs b/hugr-passes/src/composable.rs index fb3319155..ad8ff1ec0 100644 --- a/hugr-passes/src/composable.rs +++ b/hugr-passes/src/composable.rs @@ -2,6 +2,7 @@ use std::{error::Error, marker::PhantomData}; +use hugr_core::core::HugrNode; use hugr_core::hugr::{hugrmut::HugrMut, ValidationError}; use hugr_core::HugrView; use itertools::Either; @@ -9,36 +10,40 @@ use itertools::Either; /// An optimization pass that can be sequenced with another and/or wrapped /// e.g. by [ValidatingPass] pub trait ComposablePass: Sized { + type Node: HugrNode; type Error: Error; type Result; // Would like to default to () but currently unstable - fn run(&self, hugr: &mut impl HugrMut) -> Result; + fn run(&self, hugr: &mut impl HugrMut) -> Result; fn map_err( self, f: impl Fn(Self::Error) -> E2, - ) -> impl ComposablePass { + ) -> impl ComposablePass { ErrMapper::new(self, f) } /// Returns a [ComposablePass] that does "`self` then `other`", so long as /// `other::Err` can be combined with ours. - fn then>( + fn then, E: ErrorCombiner>( self, other: P, - ) -> impl ComposablePass { + ) -> impl ComposablePass { struct Sequence(P1, P2, PhantomData); impl ComposablePass for Sequence where P1: ComposablePass, - P2: ComposablePass, + P2: ComposablePass, E: ErrorCombiner, { + type Node = P1::Node; type Error = E; - type Result = (P1::Result, P2::Result); - fn run(&self, hugr: &mut impl HugrMut) -> Result { + fn run( + &self, + hugr: &mut impl HugrMut, + ) -> Result { let res1 = self.0.run(hugr).map_err(E::from_first)?; let res2 = self.1.run(hugr).map_err(E::from_second)?; Ok((res1, res2)) @@ -95,10 +100,11 @@ impl E> ErrMapper { } impl E> ComposablePass for ErrMapper { + type Node = P::Node; type Error = E; type Result = P::Result; - fn run(&self, hugr: &mut impl HugrMut) -> Result { + fn run(&self, hugr: &mut impl HugrMut) -> Result { self.0.run(hugr).map_err(&self.1) } } @@ -157,10 +163,11 @@ impl ValidatingPass

{ } impl ComposablePass for ValidatingPass

{ + type Node = P::Node; type Error = ValidatePassError; type Result = P::Result; - fn run(&self, hugr: &mut impl HugrMut) -> Result { + fn run(&self, hugr: &mut impl HugrMut) -> Result { self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Input { err, pretty_hugr, @@ -180,8 +187,11 @@ impl ComposablePass for ValidatingPass

{ /// executes a second pass pub struct IfThen(A, B, PhantomData); -impl, B: ComposablePass, E: ErrorCombiner> - IfThen +impl< + A: ComposablePass, + B: ComposablePass, + E: ErrorCombiner, + > IfThen { /// Make a new instance given the [ComposablePass] to run first /// and (maybe) second @@ -190,14 +200,17 @@ impl, B: ComposablePass, E: ErrorCombiner, B: ComposablePass, E: ErrorCombiner> - ComposablePass for IfThen +impl< + A: ComposablePass, + B: ComposablePass, + E: ErrorCombiner, + > ComposablePass for IfThen { + type Node = A::Node; type Error = E; - type Result = Option; - fn run(&self, hugr: &mut impl HugrMut) -> Result { + fn run(&self, hugr: &mut impl HugrMut) -> Result { let res: bool = self.0.run(hugr).map_err(ErrorCombiner::from_first)?; res.then(|| self.1.run(hugr).map_err(ErrorCombiner::from_second)) .transpose() @@ -206,7 +219,7 @@ impl, B: ComposablePass, E: ErrorCombiner( pass: P, - hugr: &mut impl HugrMut, + hugr: &mut impl HugrMut, ) -> Result> { if cfg!(test) { ValidatingPass::new_default(pass).run(hugr) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 99ccc180c..b406ae894 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -79,6 +79,7 @@ impl ConstantFoldPass { } impl ComposablePass for ConstantFoldPass { + type Node = Node; type Error = ConstFoldError; type Result = (); @@ -88,7 +89,7 @@ impl ComposablePass for ConstantFoldPass { /// /// [ConstFoldError::InvalidEntryPoint] if an entry-point added by [Self::with_inputs] /// was of an invalid [OpType] - fn run(&self, hugr: &mut impl HugrMut) -> Result<(), ConstFoldError> { + fn run(&self, hugr: &mut impl HugrMut) -> Result<(), ConstFoldError> { let fresh_node = Node::from(portgraph::NodeIndex::new( hugr.nodes().max().map_or(0, |n| n.index() + 1), )); @@ -175,7 +176,7 @@ impl ComposablePass for ConstantFoldPass { /// /// [FuncDefn]: hugr_core::ops::OpType::FuncDefn /// [Module]: hugr_core::ops::OpType::Module -pub fn constant_fold_pass(h: &mut H) { +pub fn constant_fold_pass>(h: &mut H) { let c = ConstantFoldPass::default(); let c = if h.get_optype(h.root()).is_module() { let no_inputs: [(IncomingPort, _); 0] = []; diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 240f4f2d6..f7b8a171c 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -52,7 +52,7 @@ pub struct Sum { } /// The output of an [LoadFunction](hugr_core::ops::LoadFunction) - a "pointer" -/// to a function at a specific node, instantiated with the provided type-args. +/// to a function at a specific node, instantiated with the provided type-args. #[derive(Clone, Debug, Hash, PartialEq, Eq)] pub struct LoadedFunction { /// The [FuncDefn](hugr_core::ops::FuncDefn) or `FuncDecl`` that was loaded diff --git a/hugr-passes/src/dead_code.rs b/hugr-passes/src/dead_code.rs index 899e30243..d92fed134 100644 --- a/hugr-passes/src/dead_code.rs +++ b/hugr-passes/src/dead_code.rs @@ -158,10 +158,11 @@ impl DeadCodeElimPass { } impl ComposablePass for DeadCodeElimPass { + type Node = Node; type Error = Infallible; type Result = (); - fn run(&self, hugr: &mut impl HugrMut) -> Result<(), Infallible> { + fn run(&self, hugr: &mut impl HugrMut) -> Result<(), Infallible> { let needed = self.find_needed_nodes(&*hugr); let remove = hugr .nodes() diff --git a/hugr-passes/src/dead_funcs.rs b/hugr-passes/src/dead_funcs.rs index 7071d5335..d1714eac9 100644 --- a/hugr-passes/src/dead_funcs.rs +++ b/hugr-passes/src/dead_funcs.rs @@ -83,9 +83,10 @@ impl RemoveDeadFuncsPass { } impl ComposablePass for RemoveDeadFuncsPass { + type Node = Node; type Error = RemoveDeadFuncsError; type Result = (); - fn run(&self, hugr: &mut impl HugrMut) -> Result<(), RemoveDeadFuncsError> { + fn run(&self, hugr: &mut impl HugrMut) -> Result<(), RemoveDeadFuncsError> { let reachable = reachable_funcs( &CallGraph::new(hugr), hugr, @@ -125,7 +126,7 @@ impl ComposablePass for RemoveDeadFuncsPass { /// [LoadFunction]: hugr_core::ops::OpType::LoadFunction /// [Module]: hugr_core::ops::OpType::Module pub fn remove_dead_funcs( - h: &mut impl HugrMut, + h: &mut impl HugrMut, entry_points: impl IntoIterator, ) -> Result<(), ValidatePassError> { validate_if_test( diff --git a/hugr-passes/src/force_order.rs b/hugr-passes/src/force_order.rs index 689479b95..ad40e2164 100644 --- a/hugr-passes/src/force_order.rs +++ b/hugr-passes/src/force_order.rs @@ -36,7 +36,7 @@ use petgraph::{ /// there is no path from `n2` to `n1` (otherwise this would invalidate `hugr`). /// Nodes of equal rank will be ordered arbitrarily, although that arbitrary /// order is deterministic. -pub fn force_order( +pub fn force_order>( hugr: &mut H, root: Node, rank: impl Fn(&H, Node) -> i64, @@ -46,7 +46,7 @@ pub fn force_order( /// As [force_order], but allows a generic [Ord] choice for the result of the /// `rank` function. -pub fn force_order_by_key( +pub fn force_order_by_key, K: Ord>( hugr: &mut H, root: Node, rank: impl Fn(&H, Node) -> K, diff --git a/hugr-passes/src/lower.rs b/hugr-passes/src/lower.rs index 8f8920967..8de6c00a2 100644 --- a/hugr-passes/src/lower.rs +++ b/hugr-passes/src/lower.rs @@ -15,7 +15,7 @@ use thiserror::Error; /// /// Returns a [`HugrError`] if any replacement fails. pub fn replace_many_ops>( - hugr: &mut impl HugrMut, + hugr: &mut impl HugrMut, mapping: impl Fn(&OpType) -> Option, ) -> Result, HugrError> { let replacements = hugr @@ -54,7 +54,7 @@ pub enum LowerError { /// /// Returns a [`LowerError`] if the lowered HUGR is invalid or if any rewrite fails. pub fn lower_ops( - hugr: &mut impl HugrMut, + hugr: &mut impl HugrMut, lowering: impl Fn(&OpType) -> Option, ) -> Result, LowerError> { let replacements = hugr diff --git a/hugr-passes/src/merge_bbs.rs b/hugr-passes/src/merge_bbs.rs index aeabc26ce..d1731107d 100644 --- a/hugr-passes/src/merge_bbs.rs +++ b/hugr-passes/src/merge_bbs.rs @@ -16,8 +16,8 @@ use hugr_core::{Hugr, HugrView, Node}; /// Merge any basic blocks that are direct children of the specified CFG /// i.e. where a basic block B has a single successor B' whose only predecessor /// is B, B and B' can be combined. -pub fn merge_basic_blocks(cfg: &mut impl HugrMut) { - let mut worklist = cfg.nodes().collect::>(); +pub fn merge_basic_blocks(cfg: &mut impl HugrMut) { + let mut worklist = cfg.children(cfg.root()).collect::>(); while let Some(n) = worklist.pop() { // Consider merging n with its successor let Ok(succ) = cfg.output_neighbours(n).exactly_one() else { @@ -33,13 +33,11 @@ pub fn merge_basic_blocks(cfg: &mut impl HugrMut) { continue; }; let (rep, merge_bb, dfgs) = mk_rep(cfg, n, succ); - let node_map = cfg.hugr_mut().apply_rewrite(rep).unwrap(); + let node_map = cfg.apply_rewrite(rep).unwrap(); let merged_bb = *node_map.get(&merge_bb).unwrap(); for dfg_id in dfgs { let n_id = *node_map.get(&dfg_id).unwrap(); - cfg.hugr_mut() - .apply_rewrite(InlineDFG(n_id.into())) - .unwrap(); + cfg.apply_rewrite(InlineDFG(n_id.into())).unwrap(); } worklist.push(merged_bb); } @@ -160,12 +158,12 @@ mod test { use std::sync::Arc; use hugr_core::extension::prelude::PRELUDE_ID; + use hugr_core::hugr::views::RootChecked; use itertools::Itertools; use rstest::rstest; use hugr_core::builder::{endo_sig, inout_sig, CFGBuilder, DFGWrapper, Dataflow, HugrBuilder}; use hugr_core::extension::prelude::{qb_t, usize_t, ConstUsize}; - use hugr_core::hugr::views::sibling::SiblingMut; use hugr_core::ops::constant::Value; use hugr_core::ops::handle::CfgID; use hugr_core::ops::{LoadConstant, OpTrait, OpType}; @@ -254,7 +252,7 @@ mod test { let mut h = h.finish_hugr()?; let r = h.root(); - merge_basic_blocks(&mut SiblingMut::::try_new(&mut h, r)?); + merge_basic_blocks(&mut RootChecked::<_, CfgID>::try_new(&mut h).unwrap()); h.validate().unwrap(); assert_eq!(r, h.root()); assert!(matches!(h.get_optype(r), OpType::CFG(_))); @@ -348,8 +346,7 @@ mod test { h.branch(&bb3, 0, &h.exit_block())?; let mut h = h.finish_hugr()?; - let root = h.root(); - merge_basic_blocks(&mut SiblingMut::try_new(&mut h, root)?); + merge_basic_blocks(&mut RootChecked::<_, CfgID>::try_new(&mut h).unwrap()); h.validate()?; // Should only be one BB left diff --git a/hugr-passes/src/monomorphize.rs b/hugr-passes/src/monomorphize.rs index 875ee9355..3164702d8 100644 --- a/hugr-passes/src/monomorphize.rs +++ b/hugr-passes/src/monomorphize.rs @@ -33,7 +33,9 @@ use crate::ComposablePass; /// children of the root node. We make best effort to ensure that names (derived /// from parent function names and concrete type args) of new functions are unique /// whenever the names of their parents are unique, but this is not guaranteed. -pub fn monomorphize(hugr: &mut impl HugrMut) -> Result<(), ValidatePassError> { +pub fn monomorphize( + hugr: &mut impl HugrMut, +) -> Result<(), ValidatePassError> { validate_if_test(MonomorphizePass, hugr) } @@ -56,7 +58,7 @@ pub fn remove_polyfuncs(mut h: Hugr) -> Hugr { since = "0.14.1", note = "Use hugr_passes::RemoveDeadFuncsPass instead" )] -fn remove_polyfuncs_ref(h: &mut impl HugrMut) { +fn remove_polyfuncs_ref(h: &mut impl HugrMut) { let mut pfs_to_delete = Vec::new(); let mut to_scan = Vec::from_iter(h.children(h.root())); while let Some(n) = to_scan.pop() { @@ -92,7 +94,7 @@ type Instantiations = HashMap, Node>>; /// Optionally copies the subtree into a new location whilst applying a substitution. /// The subtree should be monomorphic after the substitution (if provided) has been applied. fn mono_scan( - h: &mut impl HugrMut, + h: &mut impl HugrMut, parent: Node, mut subst_into: Option<&mut Instantiating>, cache: &mut Instantiations, @@ -160,7 +162,7 @@ fn mono_scan( } fn instantiate( - h: &mut impl HugrMut, + h: &mut impl HugrMut, poly_func: Node, type_args: Vec, mono_sig: Signature, @@ -258,10 +260,11 @@ fn instantiate( pub struct MonomorphizePass; impl ComposablePass for MonomorphizePass { + type Node = Node; type Error = Infallible; type Result = (); - fn run(&self, h: &mut impl HugrMut) -> Result<(), Self::Error> { + fn run(&self, h: &mut impl HugrMut) -> Result<(), Self::Error> { let root = h.root(); // If the root is a polymorphic function, then there are no external calls, so nothing to do if !is_polymorphic_funcdefn(h.get_optype(root)) { diff --git a/hugr-passes/src/nest_cfgs.rs b/hugr-passes/src/nest_cfgs.rs index 9baf250f9..1c4928e12 100644 --- a/hugr-passes/src/nest_cfgs.rs +++ b/hugr-passes/src/nest_cfgs.rs @@ -51,7 +51,7 @@ use hugr_core::hugr::{hugrmut::HugrMut, Rewrite, RootTagged}; use hugr_core::ops::handle::{BasicBlockID, CfgID}; use hugr_core::ops::OpTag; use hugr_core::ops::OpTrait; -use hugr_core::{Direction, Hugr}; +use hugr_core::{Direction, Hugr, Node}; /// A "view" of a CFG in a Hugr which allows basic blocks in the underlying CFG to be split into /// multiple blocks in the view (or merged together). @@ -155,7 +155,7 @@ pub fn transform_cfg_to_nested( pub fn transform_all_cfgs(h: &mut Hugr) { let mut node_stack = Vec::from([h.root()]); while let Some(n) = node_stack.pop() { - if let Ok(s) = SiblingMut::::try_new(h, n) { + if let Ok(s) = SiblingMut::<_, CfgID>::try_new(h, n) { transform_cfg_to_nested(&mut IdentityCfgMap::new(s)); } node_stack.extend(h.children(n)) @@ -246,7 +246,7 @@ impl CfgNodeMap for IdentityCfgMap { } } -impl CfgNester for IdentityCfgMap { +impl> CfgNester for IdentityCfgMap { fn nest_sese_region( &mut self, entry_edge: (H::Node, H::Node), @@ -760,7 +760,7 @@ pub(crate) mod test { // Again, there's no need for a view of a region here, but check that the // transformation still works when we can only directly mutate the top level let root = h.root(); - let m = SiblingMut::::try_new(&mut h, root).unwrap(); + let m = SiblingMut::<_, CfgID>::try_new(&mut h, root).unwrap(); transform_cfg_to_nested(&mut IdentityCfgMap::new(m)); h.validate().unwrap(); assert_eq!(1, depth(&h, entry)); diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index df4c14075..d33234126 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -66,7 +66,11 @@ impl NodeTemplate { /// * has a [`signature`] which the type-args of the [Self::Call] do not match /// /// [`signature`]: hugr_core::types::PolyFuncType - pub fn add_hugr(self, hugr: &mut impl HugrMut, parent: Node) -> Result { + pub fn add_hugr( + self, + hugr: &mut impl HugrMut, + parent: Node, + ) -> Result { match self { NodeTemplate::SingleOp(op_type) => Ok(hugr.add_node_with_parent(parent, op_type)), NodeTemplate::CompoundOp(new_h) => Ok(hugr.insert_hugr(parent, *new_h).new_root), @@ -97,7 +101,7 @@ impl NodeTemplate { } } - fn replace(&self, hugr: &mut impl HugrMut, n: Node) -> Result<(), BuildError> { + fn replace(&self, hugr: &mut impl HugrMut, n: Node) -> Result<(), BuildError> { assert_eq!(hugr.children(n).count(), 0); let new_optype = match self.clone() { NodeTemplate::SingleOp(op_type) => op_type, @@ -375,7 +379,11 @@ impl ReplaceTypes { self.param_consts.insert(src_ty.into(), Arc::new(const_fn)); } - fn change_node(&self, hugr: &mut impl HugrMut, n: Node) -> Result { + fn change_node( + &self, + hugr: &mut impl HugrMut, + n: Node, + ) -> Result { match hugr.optype_mut(n) { OpType::FuncDefn(FuncDefn { signature, .. }) | OpType::FuncDecl(FuncDecl { signature, .. }) => signature.body_mut().transform(self), @@ -505,10 +513,11 @@ impl ReplaceTypes { } impl ComposablePass for ReplaceTypes { + type Node = Node; type Error = ReplaceTypesError; type Result = bool; - fn run(&self, hugr: &mut impl HugrMut) -> Result { + fn run(&self, hugr: &mut impl HugrMut) -> Result { let mut changed = false; for n in hugr.nodes().collect::>() { changed |= self.change_node(hugr, n)?; diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 5c4a4a707..321ec194f 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -49,7 +49,7 @@ pub trait Linearizer { /// if `src` is not a valid Wire (does not identify a dataflow out-port) fn insert_copy_discard( &self, - hugr: &mut impl HugrMut, + hugr: &mut impl HugrMut, src: Wire, targets: &[(Node, IncomingPort)], ) -> Result<(), LinearizeError> { diff --git a/hugr-passes/src/untuple.rs b/hugr-passes/src/untuple.rs index 874fd9ec3..d074bed0f 100644 --- a/hugr-passes/src/untuple.rs +++ b/hugr-passes/src/untuple.rs @@ -122,11 +122,11 @@ impl UntuplePass { } impl ComposablePass for UntuplePass { + type Node = Node; type Error = UntupleError; - type Result = UntupleResult; - fn run(&self, hugr: &mut impl HugrMut) -> Result { + fn run(&self, hugr: &mut impl HugrMut) -> Result { let rewrites = self.find_rewrites(hugr, self.parent.unwrap_or(hugr.root())); let rewrites_applied = rewrites.len(); // The rewrites are independent, so we can always apply them all. From d91dbe6481a201d24be0c0814b70fcdd1d5dbc45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Mon, 28 Apr 2025 13:55:33 +0100 Subject: [PATCH 11/18] feat!: Removed model_unstable feature flag (#2120) Re-created from #2113, targeting the release branch instead. BREAKING CHANGE: Downstream crates need to remove the model_unstable feature flag when referencing hugr or hugr-core. --------- Co-authored-by: Lukas Heidemann --- .github/workflows/ci-rs.yml | 43 +++++++++---------- hugr-core/Cargo.toml | 4 +- hugr-core/README.md | 20 ++++----- hugr-core/src/envelope.rs | 33 +------------- hugr-core/src/lib.rs | 2 - .../std_extensions/arithmetic/float_types.rs | 1 - .../std_extensions/arithmetic/int_types.rs | 1 - .../src/std_extensions/collections/array.rs | 1 - hugr/Cargo.toml | 3 +- hugr/benches/benchmarks/hugr.rs | 23 ++++------ release-plz.toml | 4 +- 11 files changed, 42 insertions(+), 93 deletions(-) diff --git a/.github/workflows/ci-rs.yml b/.github/workflows/ci-rs.yml index 824291a8b..56c093e8b 100644 --- a/.github/workflows/ci-rs.yml +++ b/.github/workflows/ci-rs.yml @@ -6,7 +6,7 @@ on: - main pull_request: branches: - - '**' + - "**" merge_group: types: [checks_requested] workflow_dispatch: {} @@ -25,7 +25,6 @@ env: LLVM_VERSION: "14.0" LLVM_FEATURE_NAME: "14-0" - jobs: # Check if changes were made to the relevant files. # Always returns true if running on the default branch, to ensure all changes are thoroughly checked. @@ -43,25 +42,25 @@ jobs: model: ${{ steps.filter.outputs.model == 'true' || steps.override.outputs.out == 'true' }} llvm: ${{ steps.filter.outputs.llvm == 'true' || steps.override.outputs.out == 'true' }} steps: - - uses: actions/checkout@v4 - - name: Override label - id: override - run: | - echo "Label contains run-ci-checks: $OVERRIDE_LABEL" - if [ "$OVERRIDE_LABEL" == "true" ]; then - echo "Overriding due to label 'run-ci-checks'" - echo "out=true" >> $GITHUB_OUTPUT - elif [ "$DEFAULT_BRANCH" == "true" ]; then - echo "Overriding due to running on the default branch" - echo "out=true" >> $GITHUB_OUTPUT - fi - env: - OVERRIDE_LABEL: ${{ github.event_name == 'pull_request' && contains( github.event.pull_request.labels.*.name, 'run-ci-checks') }} - DEFAULT_BRANCH: ${{ github.ref_name == github.event.repository.default_branch }} - - uses: dorny/paths-filter@v3 - id: filter - with: - filters: .github/change-filters.yml + - uses: actions/checkout@v4 + - name: Override label + id: override + run: | + echo "Label contains run-ci-checks: $OVERRIDE_LABEL" + if [ "$OVERRIDE_LABEL" == "true" ]; then + echo "Overriding due to label 'run-ci-checks'" + echo "out=true" >> $GITHUB_OUTPUT + elif [ "$DEFAULT_BRANCH" == "true" ]; then + echo "Overriding due to running on the default branch" + echo "out=true" >> $GITHUB_OUTPUT + fi + env: + OVERRIDE_LABEL: ${{ github.event_name == 'pull_request' && contains( github.event.pull_request.labels.*.name, 'run-ci-checks') }} + DEFAULT_BRANCH: ${{ github.ref_name == github.event.repository.default_branch }} + - uses: dorny/paths-filter@v3 + id: filter + with: + filters: .github/change-filters.yml check: needs: changes @@ -109,7 +108,7 @@ jobs: - name: Override criterion with the CodSpeed harness run: cargo add --dev codspeed-criterion-compat --rename criterion --package hugr - name: Build benchmarks - run: cargo codspeed build --profile bench --features extension_inference,declarative,model_unstable,llvm,llvm-test + run: cargo codspeed build --profile bench --features extension_inference,declarative,llvm,llvm-test - name: Run benchmarks uses: CodSpeedHQ/action@v3 with: diff --git a/hugr-core/Cargo.toml b/hugr-core/Cargo.toml index 22e5390fc..1e4fa392f 100644 --- a/hugr-core/Cargo.toml +++ b/hugr-core/Cargo.toml @@ -19,7 +19,6 @@ workspace = true [features] extension_inference = [] declarative = ["serde_yaml"] -model_unstable = ["hugr-model"] zstd = ["dep:zstd"] [lib] @@ -27,10 +26,9 @@ bench = false [[test]] name = "model" -required-features = ["model_unstable"] [dependencies] -hugr-model = { version = "0.19.0", path = "../hugr-model", optional = true } +hugr-model = { version = "0.19.0", path = "../hugr-model" } cgmath = { workspace = true, features = ["serde"] } delegate = { workspace = true } diff --git a/hugr-core/README.md b/hugr-core/README.md index 46cafe16f..379041a5b 100644 --- a/hugr-core/README.md +++ b/hugr-core/README.md @@ -1,7 +1,6 @@ ![](/hugr/assets/hugr_logo.svg) -hugr-core -=============== +# hugr-core [![build_status][]](https://github.com/CQCL/hugr/actions) [![crates][]](https://crates.io/crates/hugr-core) @@ -21,9 +20,6 @@ Please read the [API documentation here][]. Not enabled by default. - `declarative`: Experimental support for declaring extensions in YAML files, support is limited. -- `model_unstable` - Import and export from the representation defined in the `hugr-model` crate. - Unstable and subject to change. Not enabled by default. ## Recent Changes @@ -38,10 +34,10 @@ See [DEVELOPMENT.md](https://github.com/CQCL/hugr/blob/main/DEVELOPMENT.md) for This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http://www.apache.org/licenses/LICENSE-2.0). - [API documentation here]: https://docs.rs/hugr-core/ - [build_status]: https://github.com/CQCL/hugr/actions/workflows/ci-rs.yml/badge.svg?branch=main - [msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg - [crates]: https://img.shields.io/crates/v/hugr-core - [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov - [LICENSE]: https://github.com/CQCL/hugr/blob/main/LICENCE - [CHANGELOG]: https://github.com/CQCL/hugr/blob/main/hugr-core/CHANGELOG.md +[API documentation here]: https://docs.rs/hugr-core/ +[build_status]: https://github.com/CQCL/hugr/actions/workflows/ci-rs.yml/badge.svg?branch=main +[msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg +[crates]: https://img.shields.io/crates/v/hugr-core +[codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov +[LICENSE]: https://github.com/CQCL/hugr/blob/main/LICENCE +[CHANGELOG]: https://github.com/CQCL/hugr/blob/main/hugr-core/CHANGELOG.md diff --git a/hugr-core/src/envelope.rs b/hugr-core/src/envelope.rs index 35ea9c85f..24c348b78 100644 --- a/hugr-core/src/envelope.rs +++ b/hugr-core/src/envelope.rs @@ -55,7 +55,6 @@ use std::io::Write; #[allow(unused_imports)] use itertools::Itertools as _; -#[cfg(feature = "model_unstable")] use crate::import::ImportError; /// Read a HUGR envelope from a reader. @@ -197,19 +196,16 @@ pub enum EnvelopeError { source: PackageEncodingError, }, /// Error importing a HUGR from a hugr-model payload. - #[cfg(feature = "model_unstable")] ModelImport { /// The source error. source: ImportError, }, /// Error reading a HUGR model payload. - #[cfg(feature = "model_unstable")] ModelRead { /// The source error. source: hugr_model::v0::binary::ReadError, }, /// Error writing a HUGR model payload. - #[cfg(feature = "model_unstable")] ModelWrite { /// The source error. source: hugr_model::v0::binary::WriteError, @@ -225,17 +221,9 @@ fn read_impl( match header.format { #[allow(deprecated)] EnvelopeFormat::PackageJson => Ok(Package::from_json_reader(payload, registry)?), - #[cfg(feature = "model_unstable")] EnvelopeFormat::Model | EnvelopeFormat::ModelWithExtensions => { decode_model(payload, registry, header.format) } - #[cfg(not(feature = "model_unstable"))] - EnvelopeFormat::Model | EnvelopeFormat::ModelWithExtensions => { - Err(EnvelopeError::FormatUnsupported { - format: header.format, - feature: Some("model_unstable"), - }) - } } } @@ -246,7 +234,6 @@ fn read_impl( /// - `extension_registry`: An extension registry with additional extensions to use when /// decoding the HUGR, if they are not already included in the package. /// - `format`: The format of the payload. -#[cfg(feature = "model_unstable")] fn decode_model( mut stream: impl BufRead, extension_registry: &ExtensionRegistry, @@ -286,22 +273,13 @@ fn write_impl( match config.format { #[allow(deprecated)] EnvelopeFormat::PackageJson => package.to_json_writer(writer)?, - #[cfg(feature = "model_unstable")] EnvelopeFormat::Model | EnvelopeFormat::ModelWithExtensions => { encode_model(writer, package, config.format)? } - #[cfg(not(feature = "model_unstable"))] - EnvelopeFormat::Model | EnvelopeFormat::ModelWithExtensions => { - return Err(EnvelopeError::FormatUnsupported { - format: config.format, - feature: Some("model_unstable"), - }) - } } Ok(()) } -#[cfg(feature = "model_unstable")] fn encode_model( mut writer: impl Write, package: &Package, @@ -391,7 +369,6 @@ mod tests { //#[case::empty(Package::default())] // Not currently supported #[case::simple(simple_package())] //#[case::multi(multi_module_package())] // Not currently supported - #[cfg(feature = "model_unstable")] fn module_exts_roundtrip(#[case] package: Package) { let mut buffer = Vec::new(); let config = EnvelopeConfig { @@ -417,15 +394,7 @@ mod tests { format: EnvelopeFormat::Model, zstd: None, }; - let res = package.store(&mut buffer, config); - - match cfg!(feature = "model_unstable") { - true => res.unwrap(), - false => { - assert_matches!(res, Err(EnvelopeError::FormatUnsupported { .. })); - return; - } - } + package.store(&mut buffer, config).unwrap(); let (decoded_config, new_package) = read_envelope(BufReader::new(buffer.as_slice()), &PRELUDE_REGISTRY).unwrap(); diff --git a/hugr-core/src/lib.rs b/hugr-core/src/lib.rs index e32b623f2..e5f57d2a8 100644 --- a/hugr-core/src/lib.rs +++ b/hugr-core/src/lib.rs @@ -12,11 +12,9 @@ pub mod builder; pub mod core; pub mod envelope; -#[cfg(feature = "model_unstable")] pub mod export; pub mod extension; pub mod hugr; -#[cfg(feature = "model_unstable")] pub mod import; pub mod macros; pub mod ops; diff --git a/hugr-core/src/std_extensions/arithmetic/float_types.rs b/hugr-core/src/std_extensions/arithmetic/float_types.rs index 3122bf30f..200e9dcbf 100644 --- a/hugr-core/src/std_extensions/arithmetic/float_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/float_types.rs @@ -65,7 +65,6 @@ impl std::ops::Deref for ConstF64 { impl ConstF64 { /// Name of the constructor for creating constant 64bit floats. - #[cfg_attr(not(feature = "model_unstable"), allow(dead_code))] pub(crate) const CTR_NAME: &'static str = "arithmetic.float.const_f64"; /// Create a new [`ConstF64`] diff --git a/hugr-core/src/std_extensions/arithmetic/int_types.rs b/hugr-core/src/std_extensions/arithmetic/int_types.rs index e5d625695..1342dd932 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_types.rs @@ -105,7 +105,6 @@ pub struct ConstInt { impl ConstInt { /// Name of the constructor for creating constant integers. - #[cfg_attr(not(feature = "model_unstable"), allow(dead_code))] pub(crate) const CTR_NAME: &'static str = "arithmetic.int.const"; /// Create a new [`ConstInt`] with a given width and unsigned value diff --git a/hugr-core/src/std_extensions/collections/array.rs b/hugr-core/src/std_extensions/collections/array.rs index 0332ff351..fac12b1bf 100644 --- a/hugr-core/src/std_extensions/collections/array.rs +++ b/hugr-core/src/std_extensions/collections/array.rs @@ -45,7 +45,6 @@ pub struct ArrayValue { impl ArrayValue { /// Name of the constructor for creating constant arrays. - #[cfg_attr(not(feature = "model_unstable"), allow(dead_code))] pub(crate) const CTR_NAME: &'static str = "collections.array.const"; /// Create a new [CustomConst] for an array of values of type `typ`. diff --git a/hugr/Cargo.toml b/hugr/Cargo.toml index c0439a960..3763366ae 100644 --- a/hugr/Cargo.toml +++ b/hugr/Cargo.toml @@ -26,13 +26,12 @@ default = ["zstd"] extension_inference = ["hugr-core/extension_inference"] declarative = ["hugr-core/declarative"] -model_unstable = ["hugr-core/model_unstable", "hugr-model"] llvm = ["hugr-llvm/llvm14-0"] llvm-test = ["hugr-llvm/llvm14-0", "hugr-llvm/test-utils"] zstd = ["hugr-core/zstd"] [dependencies] -hugr-model = { path = "../hugr-model", optional = true, version = "0.19.0" } +hugr-model = { path = "../hugr-model", version = "0.19.0" } hugr-core = { path = "../hugr-core", version = "0.15.3" } hugr-passes = { path = "../hugr-passes", version = "0.15.3" } hugr-llvm = { path = "../hugr-llvm", version = "0.15.3", optional = true } diff --git a/hugr/benches/benchmarks/hugr.rs b/hugr/benches/benchmarks/hugr.rs index 49d73d58e..3635c8d09 100644 --- a/hugr/benches/benchmarks/hugr.rs +++ b/hugr/benches/benchmarks/hugr.rs @@ -24,10 +24,8 @@ impl Serializer for JsonSer { } } -#[cfg(feature = "model_unstable")] struct CapnpSer; -#[cfg(feature = "model_unstable")] impl Serializer for CapnpSer { fn serialize(&self, hugr: &Hugr) -> Vec { let bump = bumpalo::Bump::new(); @@ -90,20 +88,17 @@ fn bench_serialization(c: &mut Criterion) { } group.finish(); - #[cfg(feature = "model_unstable")] - { - let mut group = c.benchmark_group("circuit_roundtrip/capnp"); - group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); - for size in [0, 1, 10, 100, 1000].iter() { - group.bench_with_input(BenchmarkId::from_parameter(size), size, |b, &size| { - let h = circuit(size).0; - b.iter(|| { - black_box(roundtrip(&h, CapnpSer)); - }); + let mut group = c.benchmark_group("circuit_roundtrip/capnp"); + group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); + for size in [0, 1, 10, 100, 1000].iter() { + group.bench_with_input(BenchmarkId::from_parameter(size), size, |b, &size| { + let h = circuit(size).0; + b.iter(|| { + black_box(roundtrip(&h, CapnpSer)); }); - } - group.finish(); + }); } + group.finish(); } criterion_group! { diff --git a/release-plz.toml b/release-plz.toml index 091ca3795..4bc9f7104 100644 --- a/release-plz.toml +++ b/release-plz.toml @@ -63,9 +63,7 @@ version_group = "hugr" [[package]] name = "hugr-model" release = true -# Use a separate version group while the dependency is `-unstable`, -# to avoid breaking releases of the main package. -version_group = "hugr-model" +version_group = "hugr" [[package]] name = "hugr-llvm" From 6ca12580c26d353cbe418f64d866db6600c045eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Mon, 28 Apr 2025 14:06:39 +0100 Subject: [PATCH 12/18] feat!: Remove `RootTagged` from the hugr view trait hierarchy (#2122) Closes #2077 BREAKING CHANGE: Removed `RootTagged` trait. Now `RootChecked` is a non-hugrview wrapper only used to verify inputs where appropriate. --- hugr-core/src/builder/dataflow.rs | 24 +-- hugr-core/src/builder/module.rs | 3 +- hugr-core/src/hugr.rs | 9 +- hugr-core/src/hugr/hugrmut.rs | 6 +- hugr-core/src/hugr/internal.rs | 26 +-- hugr-core/src/hugr/rewrite.rs | 3 +- hugr-core/src/hugr/rewrite/inline_call.rs | 9 +- hugr-core/src/hugr/rewrite/replace.rs | 3 +- hugr-core/src/hugr/validate/test.rs | 26 +-- hugr-core/src/hugr/views.rs | 27 +-- hugr-core/src/hugr/views/descendants.rs | 7 +- hugr-core/src/hugr/views/impls.rs | 28 +-- hugr-core/src/hugr/views/root_checked.rs | 194 +++++++----------- hugr-core/src/hugr/views/sibling.rs | 63 +----- hugr-core/src/hugr/views/sibling_subgraph.rs | 34 ++- hugr-core/src/package.rs | 3 +- .../src/utils/inline_constant_functions.rs | 2 +- hugr-passes/src/half_node.rs | 18 +- hugr-passes/src/lower.rs | 16 +- hugr-passes/src/merge_bbs.rs | 18 +- hugr-passes/src/monomorphize.rs | 4 +- hugr-passes/src/nest_cfgs.rs | 15 +- hugr/src/hugr.rs | 2 +- uv.lock | 2 +- 24 files changed, 191 insertions(+), 351 deletions(-) diff --git a/hugr-core/src/builder/dataflow.rs b/hugr-core/src/builder/dataflow.rs index ebad52085..64c5f5c84 100644 --- a/hugr-core/src/builder/dataflow.rs +++ b/hugr-core/src/builder/dataflow.rs @@ -174,9 +174,7 @@ impl FunctionBuilder { // Update the inner input node let types = new_optype.signature.body().input.clone(); - self.hugr_mut() - .replace_op(inp_node, Input { types }) - .unwrap(); + self.hugr_mut().replace_op(inp_node, Input { types }); let mut new_port = self.hugr_mut().add_ports(inp_node, Direction::Outgoing, 1); let new_port = new_port.next().unwrap(); @@ -211,9 +209,7 @@ impl FunctionBuilder { // Update the inner input node let types = new_optype.signature.body().output.clone(); - self.hugr_mut() - .replace_op(out_node, Output { types }) - .unwrap(); + self.hugr_mut().replace_op(out_node, Output { types }); let mut new_port = self.hugr_mut().add_ports(out_node, Direction::Incoming, 1); let new_port = new_port.next().unwrap(); @@ -250,15 +246,13 @@ impl FunctionBuilder { .expect("FunctionBuilder node must be a FuncDefn"); let signature = old_optype.inner_signature().into_owned(); let name = old_optype.name.clone(); - self.hugr_mut() - .replace_op( - parent, - ops::FuncDefn { - signature: f(signature).into(), - name, - }, - ) - .expect("Could not replace FunctionBuilder operation"); + self.hugr_mut().replace_op( + parent, + ops::FuncDefn { + signature: f(signature).into(), + name, + }, + ); self.hugr().get_optype(parent).as_func_defn().unwrap() } diff --git a/hugr-core/src/builder/module.rs b/hugr-core/src/builder/module.rs index 18390926e..1387c1ec5 100644 --- a/hugr-core/src/builder/module.rs +++ b/hugr-core/src/builder/module.rs @@ -83,8 +83,7 @@ impl + AsRef> ModuleBuilder { .clone(); let body = signature.body().clone(); self.hugr_mut() - .replace_op(f_node, ops::FuncDefn { name, signature }) - .expect("Replacing a FuncDecl node with a FuncDefn should always be valid"); + .replace_op(f_node, ops::FuncDefn { name, signature }); let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body)?; Ok(FunctionBuilder::from_dfg_builder(db)) diff --git a/hugr-core/src/hugr.rs b/hugr-core/src/hugr.rs index 2789ae056..d708f15cd 100644 --- a/hugr-core/src/hugr.rs +++ b/hugr-core/src/hugr.rs @@ -23,7 +23,7 @@ use portgraph::multiportgraph::MultiPortGraph; use portgraph::{Hierarchy, PortMut, PortView, UnmanagedDenseMap}; use thiserror::Error; -pub use self::views::{HugrView, RootTagged}; +pub use self::views::HugrView; use crate::core::NodeIndex; use crate::extension::resolution::{ resolve_op_extensions, resolve_op_types_extensions, ExtensionResolutionError, @@ -367,13 +367,10 @@ pub struct ExtensionError { } /// Errors that can occur while manipulating a Hugr. -/// -/// TODO: Better descriptions, not just re-exporting portgraph errors. #[derive(Debug, Clone, PartialEq, Eq, Error)] #[non_exhaustive] pub enum HugrError { /// The node was not of the required [OpTag] - /// (e.g. to conform to the [RootTagged::RootHandle] of a [HugrView]) #[error("Invalid tag: required a tag in {required} but found {actual}")] #[allow(missing_docs)] InvalidTag { required: OpTag, actual: OpTag }, @@ -671,12 +668,12 @@ mod test { signature: Signature::new_endo(ty).with_extension_delta(result.clone()), }; let mut expected = backup; - expected.replace_op(p, expected_p).unwrap(); + expected.replace_op(p, expected_p); let expected_gp = ops::Conditional { extension_delta: result, ..root_ty }; - expected.replace_op(h.root(), expected_gp).unwrap(); + expected.replace_op(h.root(), expected_gp); assert_eq!(h, expected); } else { diff --git a/hugr-core/src/hugr/hugrmut.rs b/hugr-core/src/hugr/hugrmut.rs index bf9a4cad0..51e92f342 100644 --- a/hugr-core/src/hugr/hugrmut.rs +++ b/hugr-core/src/hugr/hugrmut.rs @@ -156,14 +156,14 @@ pub trait HugrMut: HugrMutInternals { /// If the node is not in the graph, or if the port is invalid. fn add_other_edge(&mut self, src: Self::Node, dst: Self::Node) -> (OutgoingPort, IncomingPort); - /// Insert another hugr into this one, under a given root node. + /// Insert another hugr into this one, under a given parent node. /// /// # Panics /// /// If the root node is not in the graph. fn insert_hugr(&mut self, root: Self::Node, other: Hugr) -> InsertionResult; - /// Copy another hugr into this one, under a given root node. + /// Copy another hugr into this one, under a given parent node. /// /// # Panics /// @@ -174,7 +174,7 @@ pub trait HugrMut: HugrMutInternals { other: &H, ) -> InsertionResult; - /// Copy a subgraph from another hugr into this one, under a given root node. + /// Copy a subgraph from another hugr into this one, under a given parent node. /// /// Sibling order is not preserved. /// diff --git a/hugr-core/src/hugr/internal.rs b/hugr-core/src/hugr/internal.rs index 8892c3b11..58ce066c0 100644 --- a/hugr-core/src/hugr/internal.rs +++ b/hugr-core/src/hugr/internal.rs @@ -12,7 +12,7 @@ use crate::ops::handle::NodeHandle; use crate::{Direction, Hugr, Node}; use super::hugrmut::{panic_invalid_node, panic_invalid_non_root}; -use super::{HugrError, NodeMetadataMap, OpType, RootTagged}; +use super::{HugrView, NodeMetadataMap, OpType}; /// Trait for accessing the internals of a Hugr(View). /// @@ -107,7 +107,7 @@ impl HugrInternals for Hugr { /// /// Specifically, this trait lets you apply arbitrary modifications that may /// invalidate the HUGR. -pub trait HugrMutInternals: RootTagged { +pub trait HugrMutInternals: HugrView { /// Set root node of the HUGR. /// /// This should be an existing node in the HUGR. Most operations use the @@ -189,18 +189,10 @@ pub trait HugrMutInternals: RootTagged { /// /// Returns the old OpType. /// - /// If the module root is set to a non-module operation the hugr will - /// become invalid. - /// - /// # Errors - /// - /// Returns a [`HugrError::InvalidTag`] if this would break the bound - /// (`Self::RootHandle`) on the root node's OpTag. - /// /// # Panics /// /// If the node is not in the graph. - fn replace_op(&mut self, node: Self::Node, op: impl Into) -> Result; + fn replace_op(&mut self, node: Self::Node, op: impl Into) -> OpType; /// Gets a mutable reference to the optype. /// @@ -223,9 +215,10 @@ pub trait HugrMutInternals: RootTagged { /// If the node is not in the graph. fn node_metadata_map_mut(&mut self, node: Self::Node) -> &mut NodeMetadataMap; - /// Returns a mutable reference to the extension registry for this hugr, - /// containing all extensions required to define the operations and types in - /// the hugr. + /// Returns a mutable reference to the extension registry for this HUGR. + /// + /// This set contains all extensions required to define the operations and + /// types in the HUGR. fn extensions_mut(&mut self) -> &mut ExtensionRegistry; } @@ -326,10 +319,9 @@ impl HugrMutInternals for Hugr { .expect("Inserting a newly-created node into the hierarchy should never fail."); } - fn replace_op(&mut self, node: Node, op: impl Into) -> Result { + fn replace_op(&mut self, node: Node, op: impl Into) -> OpType { panic_invalid_node(self, node); - // We know RootHandle=Node here so no need to check - Ok(std::mem::replace(self.optype_mut(node), op.into())) + std::mem::replace(self.optype_mut(node), op.into()) } fn optype_mut(&mut self, node: Self::Node) -> &mut OpType { diff --git a/hugr-core/src/hugr/rewrite.rs b/hugr-core/src/hugr/rewrite.rs index d2b0fe14d..e220864a7 100644 --- a/hugr-core/src/hugr/rewrite.rs +++ b/hugr-core/src/hugr/rewrite.rs @@ -82,8 +82,7 @@ impl Rewrite for Transactional { let r = self.underlying.apply(h); if r.is_err() { // Try to restore backup. - h.replace_op(h.root(), backup.root_type().clone()) - .expect("The root replacement should always match the old root type"); + h.replace_op(h.root(), backup.root_type().clone()); while let Some(child) = h.first_child(h.root()) { h.remove_node(child); } diff --git a/hugr-core/src/hugr/rewrite/inline_call.rs b/hugr-core/src/hugr/rewrite/inline_call.rs index 6b1e7a958..e32373507 100644 --- a/hugr-core/src/hugr/rewrite/inline_call.rs +++ b/hugr-core/src/hugr/rewrite/inline_call.rs @@ -76,7 +76,6 @@ impl Rewrite for InlineCall { let ty_args = h .replace_op(self.0, new_op) - .unwrap() .as_call() .unwrap() .type_args @@ -117,8 +116,7 @@ mod test { ModuleBuilder, }; use crate::extension::prelude::usize_t; - use crate::hugr::views::RootChecked; - use crate::ops::handle::{FuncID, ModuleRootID, NodeHandle}; + use crate::ops::handle::{FuncID, NodeHandle}; use crate::ops::{Input, OpType, Value}; use crate::std_extensions::arithmetic::{ int_ops::{self, IntOpDef}, @@ -179,10 +177,7 @@ mod test { .count(), 1 ); - RootChecked::<_, ModuleRootID>::try_new(&mut hugr) - .unwrap() - .apply_rewrite(InlineCall(call1.node())) - .unwrap(); + hugr.apply_rewrite(InlineCall(call1.node())).unwrap(); hugr.validate().unwrap(); assert_eq!(hugr.output_neighbours(func.node()).collect_vec(), [call2]); assert_eq!(calls(&hugr), [call2]); diff --git a/hugr-core/src/hugr/rewrite/replace.rs b/hugr-core/src/hugr/rewrite/replace.rs index c2659cc5a..0316f9d5b 100644 --- a/hugr-core/src/hugr/rewrite/replace.rs +++ b/hugr-core/src/hugr/rewrite/replace.rs @@ -732,8 +732,7 @@ mod test { // Root node type needs to be that of common parent of the removed nodes: let mut rep2 = rep.clone(); rep2.replacement - .replace_op(rep2.replacement.root(), h.root_type().clone()) - .unwrap(); + .replace_op(rep2.replacement.root(), h.root_type().clone()); assert_eq!( check_same_errors(rep2), ReplaceError::WrongRootNodeTag { diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index 37157020d..a66296c35 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -191,26 +191,23 @@ fn df_children_restrictions() { .unwrap(); // Replace the output operation of the df subgraph with a copy - b.replace_op(output, Noop(usize_t())).unwrap(); + b.replace_op(output, Noop(usize_t())); assert_matches!( b.validate(), Err(ValidationError::InvalidInitialChild { parent, .. }) => assert_eq!(parent, def) ); // Revert it back to an output, but with the wrong number of ports - b.replace_op(output, ops::Output::new(vec![bool_t()])) - .unwrap(); + b.replace_op(output, ops::Output::new(vec![bool_t()])); assert_matches!( b.validate(), Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::IOSignatureMismatch { child, .. }, .. }) => {assert_eq!(parent, def); assert_eq!(child, output.pg_index())} ); - b.replace_op(output, ops::Output::new(vec![bool_t(), bool_t()])) - .unwrap(); + b.replace_op(output, ops::Output::new(vec![bool_t(), bool_t()])); // After fixing the output back, replace the copy with an output op - b.replace_op(copy, ops::Output::new(vec![bool_t(), bool_t()])) - .unwrap(); + b.replace_op(copy, ops::Output::new(vec![bool_t(), bool_t()])); assert_matches!( b.validate(), Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::InternalIOChildren { child, .. }, .. }) @@ -806,8 +803,7 @@ fn cfg_children_restrictions() { ops::CFG { signature: Signature::new(vec![bool_t()], vec![bool_t()]), }, - ) - .unwrap(); + ); assert_matches!( b.validate(), Err(ValidationError::ContainerWithoutChildren { .. }) @@ -869,8 +865,7 @@ fn cfg_children_restrictions() { ops::CFG { signature: Signature::new(vec![qb_t()], vec![bool_t()]), }, - ) - .unwrap(); + ); b.replace_op( block, ops::DataflowBlock { @@ -879,18 +874,15 @@ fn cfg_children_restrictions() { other_outputs: vec![qb_t()].into(), extension_delta: ExtensionSet::new(), }, - ) - .unwrap(); + ); let mut block_children = b.hierarchy.children(block.pg_index()); let block_input = block_children.next().unwrap().into(); let block_output = block_children.next_back().unwrap().into(); - b.replace_op(block_input, ops::Input::new(vec![qb_t()])) - .unwrap(); + b.replace_op(block_input, ops::Input::new(vec![qb_t()])); b.replace_op( block_output, ops::Output::new(vec![Type::new_unit_sum(1), qb_t()]), - ) - .unwrap(); + ); assert_matches!( b.validate(), Err(ValidationError::InvalidEdges { parent, source: EdgeValidationError::CFGEdgeSignatureMismatch { .. }, .. }) diff --git a/hugr-core/src/hugr/views.rs b/hugr-core/src/hugr/views.rs index eb8059577..a154a956f 100644 --- a/hugr-core/src/hugr/views.rs +++ b/hugr-core/src/hugr/views.rs @@ -16,7 +16,7 @@ use std::borrow::Cow; pub use self::petgraph::PetgraphWrapper; use self::render::RenderConfig; pub use descendants::DescendantsGraph; -pub use root_checked::RootChecked; +pub use root_checked::{check_tag, RootCheckable, RootChecked}; pub use sibling::SiblingGraph; pub use sibling_subgraph::SiblingSubgraph; @@ -29,7 +29,6 @@ use super::{ Hugr, HugrError, HugrMut, Node, NodeMetadata, NodeMetadataMap, ValidationError, DEFAULT_OPTYPE, }; use crate::extension::ExtensionRegistry; -use crate::ops::handle::NodeHandle; use crate::ops::{OpParent, OpTag, OpTrait, OpType}; use crate::types::{EdgeKind, PolyFuncType, Signature, Type}; @@ -479,17 +478,8 @@ pub trait HugrView: HugrInternals { } } -/// Trait for views that provides a guaranteed bound on the type of the root node. -pub trait RootTagged: HugrView { - /// The kind of handle that can be used to refer to the root node. - /// - /// The handle is guaranteed to be able to contain the operation returned by - /// [`HugrView::root_type`]. - type RootHandle: NodeHandle; -} - /// A common trait for views of a HUGR hierarchical subgraph. -pub trait HierarchyView<'a>: RootTagged + Sized { +pub trait HierarchyView<'a>: HugrView + Sized { /// Create a hierarchical view of a HUGR given a root node. /// /// # Errors @@ -515,19 +505,6 @@ pub trait ExtractHugr: HugrView + Sized { } } -/// Check that the node in a HUGR can be represented by the required tag. -fn check_tag, N>( - hugr: &impl HugrView, - node: N, -) -> Result<(), HugrError> { - let actual = hugr.get_optype(node).tag(); - let required = Required::TAG; - if !required.is_superset(actual) { - return Err(HugrError::InvalidTag { required, actual }); - } - Ok(()) -} - // Explicit implementation to avoid cloning the Hugr. impl ExtractHugr for Hugr { fn extract_hugr(self) -> Hugr { diff --git a/hugr-core/src/hugr/views/descendants.rs b/hugr-core/src/hugr/views/descendants.rs index 28a7d9f2d..906dea3e4 100644 --- a/hugr-core/src/hugr/views/descendants.rs +++ b/hugr-core/src/hugr/views/descendants.rs @@ -8,7 +8,7 @@ use crate::hugr::HugrError; use crate::ops::handle::NodeHandle; use crate::{Direction, Hugr, Node, Port}; -use super::{check_tag, ExtractHugr, HierarchyView, HugrInternals, HugrView, RootTagged}; +use super::{check_tag, ExtractHugr, HierarchyView, HugrInternals, HugrView}; type RegionGraph<'g> = portgraph::view::Region<'g, &'g MultiPortGraph>; @@ -131,16 +131,13 @@ impl HugrView for DescendantsGraph<'_, Root> { .map(|index| self.get_node(index)) } } -impl RootTagged for DescendantsGraph<'_, Root> { - type RootHandle = Root; -} impl<'a, Root> HierarchyView<'a> for DescendantsGraph<'a, Root> where Root: NodeHandle, { fn try_new(hugr: &'a impl HugrView, root: Node) -> Result { - check_tag::(hugr, root)?; + check_tag::(hugr, root)?; let hugr = hugr.base_hugr(); Ok(Self { root, diff --git a/hugr-core/src/hugr/views/impls.rs b/hugr-core/src/hugr/views/impls.rs index 928acba20..440df9480 100644 --- a/hugr-core/src/hugr/views/impls.rs +++ b/hugr-core/src/hugr/views/impls.rs @@ -3,11 +3,8 @@ use std::{borrow::Cow, rc::Rc, sync::Arc}; use super::HugrView; -use super::RootTagged; use crate::hugr::internal::{HugrInternals, HugrMutInternals}; use crate::hugr::HugrMut; -use crate::Hugr; -use crate::Node; macro_rules! hugr_internal_methods { // The extra ident here is because invocations of the macro cannot pass `self` as argument @@ -116,7 +113,7 @@ macro_rules! hugr_mut_internal_methods { fn set_parent(&mut self, node: Self::Node, parent: Self::Node); fn move_after_sibling(&mut self, node: Self::Node, after: Self::Node); fn move_before_sibling(&mut self, node: Self::Node, before: Self::Node); - fn replace_op(&mut self, node: Self::Node, op: impl Into) -> Result; + fn replace_op(&mut self, node: Self::Node, op: impl Into) -> crate::ops::OpType; fn optype_mut(&mut self, node: Self::Node) -> &mut crate::ops::OpType; fn node_metadata_map_mut(&mut self, node: Self::Node) -> &mut crate::hugr::NodeMetadataMap; fn extensions_mut(&mut self) -> &mut crate::extension::ExtensionRegistry; @@ -149,11 +146,6 @@ macro_rules! hugr_mut_methods { } pub(crate) use hugr_mut_methods; -// -------- Base Hugr implementation -impl RootTagged for Hugr { - type RootHandle = Node; -} - // -------- Immutable borrow impl HugrInternals for &T { type Portgraph<'p> @@ -167,9 +159,6 @@ impl HugrInternals for &T { impl HugrView for &T { hugr_view_methods! {this, *this} } -impl RootTagged for &T { - type RootHandle = T::RootHandle; -} // -------- Mutable borrow impl HugrInternals for &mut T { @@ -184,9 +173,6 @@ impl HugrInternals for &mut T { impl HugrView for &mut T { hugr_view_methods! {this, &**this} } -impl RootTagged for &mut T { - type RootHandle = T::RootHandle; -} impl HugrMutInternals for &mut T { hugr_mut_internal_methods! {this, &mut **this} } @@ -207,9 +193,6 @@ impl HugrInternals for Rc { impl HugrView for Rc { hugr_view_methods! {this, this.as_ref()} } -impl RootTagged for Rc { - type RootHandle = T::RootHandle; -} // -------- Arc impl HugrInternals for Arc { @@ -224,9 +207,6 @@ impl HugrInternals for Arc { impl HugrView for Arc { hugr_view_methods! {this, this.as_ref()} } -impl RootTagged for Arc { - type RootHandle = T::RootHandle; -} // -------- Box impl HugrInternals for Box { @@ -241,9 +221,6 @@ impl HugrInternals for Box { impl HugrView for Box { hugr_view_methods! {this, this.as_ref()} } -impl RootTagged for Box { - type RootHandle = T::RootHandle; -} impl HugrMutInternals for Box { hugr_mut_internal_methods! {this, this.as_mut()} } @@ -264,9 +241,6 @@ impl HugrInternals for Cow<'_, T> { impl HugrView for Cow<'_, T> { hugr_view_methods! {this, this.as_ref()} } -impl RootTagged for Cow<'_, T> { - type RootHandle = T::RootHandle; -} impl HugrMutInternals for Cow<'_, T> where T: HugrMutInternals + ToOwned, diff --git a/hugr-core/src/hugr/views/root_checked.rs b/hugr-core/src/hugr/views/root_checked.rs index e0dcf3eb7..50c9bcf44 100644 --- a/hugr-core/src/hugr/views/root_checked.rs +++ b/hugr-core/src/hugr/views/root_checked.rs @@ -1,20 +1,28 @@ -use std::borrow::Cow; use std::marker::PhantomData; -use crate::hugr::internal::{HugrInternals, HugrMutInternals}; -use crate::hugr::{HugrError, HugrMut}; +use crate::hugr::HugrError; use crate::ops::handle::NodeHandle; -use crate::ops::OpTrait; +use crate::ops::{OpTag, OpTrait}; use crate::{Hugr, Node}; -use super::{check_tag, HugrView, RootTagged}; +use super::HugrView; -/// A view of the whole Hugr. -/// (Just provides static checking of the type of the root node) +/// A wrapper over a Hugr that ensures the root node optype is of the required +/// [`OpTag`]. #[derive(Clone)] -pub struct RootChecked(H, PhantomData); +pub struct RootChecked(H, PhantomData); + +impl> RootChecked { + /// A tag that can contain the operation of the hugr root node. + const TAG: OpTag = Handle::TAG; + + /// Returns the most specific tag that can be applied to the root node. + pub fn tag(&self) -> OpTag { + let tag = self.0.get_optype(self.0.root()).tag(); + debug_assert!(Self::TAG.is_superset(tag)); + tag + } -impl> RootChecked { /// Create a hierarchical view of a whole HUGR /// /// # Errors @@ -22,101 +30,80 @@ impl> RootChecked { /// /// [`OpTag`]: crate::ops::OpTag pub fn try_new(hugr: H) -> Result { - if !H::RootHandle::TAG.is_superset(Root::TAG) { - return Err(HugrError::InvalidTag { - required: H::RootHandle::TAG, - actual: Root::TAG, - }); - } - check_tag::(&hugr, hugr.root())?; + Self::check(&hugr)?; Ok(Self(hugr, PhantomData)) } -} -impl RootChecked { - /// Extracts the underlying (owned) Hugr - pub fn into_hugr(self) -> Hugr { - self.0 + /// Check if a Hugr is valid for the given [`OpTag`]. + /// + /// To check arbitrary nodes, use [`check_tag`]. + pub fn check(hugr: &H) -> Result<(), HugrError> { + check_tag::(hugr, hugr.root())?; + Ok(()) } -} -impl RootChecked<&mut Hugr, Root> { - /// Allows immutably borrowing the underlying mutable reference - pub fn borrow(&self) -> RootChecked<&Hugr, Root> { - RootChecked(&*self.0, PhantomData) + /// Returns a reference to the underlying Hugr. + pub fn hugr(&self) -> &H { + &self.0 } -} -impl HugrInternals for RootChecked { - type Portgraph<'p> - = H::Portgraph<'p> - where - Self: 'p; - type Node = H::Node; - - super::impls::hugr_internal_methods! {this, &this.0} -} - -impl HugrView for RootChecked { - super::impls::hugr_view_methods! {this, &this.0} -} + /// Extracts the underlying Hugr + pub fn into_hugr(self) -> H { + self.0 + } -impl> RootTagged for RootChecked { - type RootHandle = Root; + /// Returns a wrapper over a reference to the underlying Hugr. + pub fn as_ref(&self) -> RootChecked<&H, Handle> { + RootChecked(&self.0, PhantomData) + } } -impl, Root> AsRef for RootChecked { +impl, Handle> AsRef for RootChecked { fn as_ref(&self) -> &Hugr { self.0.as_ref() } } -impl> HugrMutInternals for RootChecked { - fn replace_op( - &mut self, - node: Self::Node, - op: impl Into, - ) -> Result { - let op = op.into(); - if node == self.root() && !Root::TAG.is_superset(op.tag()) { - return Err(HugrError::InvalidTag { - required: Root::TAG, - actual: op.tag(), - }); - } - self.0.replace_op(node, op) +/// A trait for types that can be checked for a specific [`OpTag`] at their root node. +/// +/// This is used mainly specifying function inputs that may either be a [`HugrView`] or an already checked [`RootChecked`]. +pub trait RootCheckable>: Sized { + /// Wrap the Hugr in a [`RootChecked`] if it is valid for the required [`OpTag`]. + /// + /// If `Self` is already a [`RootChecked`], it is a no-op. + fn try_into_checked(self) -> Result, HugrError>; +} +impl> RootCheckable for H { + fn try_into_checked(self) -> Result, HugrError> { + RootChecked::try_new(self) } - - delegate::delegate! { - to (&mut self.0) { - fn set_root(&mut self, root: Self::Node); - fn set_num_ports(&mut self, node: Self::Node, incoming: usize, outgoing: usize); - fn add_ports(&mut self, node: Self::Node, direction: crate::Direction, amount: isize) -> std::ops::Range; - fn insert_ports(&mut self, node: Self::Node, direction: crate::Direction, index: usize, amount: usize) -> std::ops::Range; - fn set_parent(&mut self, node: Self::Node, parent: Self::Node); - fn move_after_sibling(&mut self, node: Self::Node, after: Self::Node); - fn move_before_sibling(&mut self, node: Self::Node, before: Self::Node); - fn optype_mut(&mut self, node: Self::Node) -> &mut crate::ops::OpType; - fn node_metadata_map_mut(&mut self, node: Self::Node) -> &mut crate::hugr::NodeMetadataMap; - fn extensions_mut(&mut self) -> &mut crate::extension::ExtensionRegistry; - } +} +impl> RootCheckable for RootChecked { + fn try_into_checked(self) -> Result, HugrError> { + Ok(self) } } -impl> HugrMut for RootChecked { - super::impls::hugr_mut_methods! {this, &mut this.0} +/// Check that the node in a HUGR can be represented by the required tag. +pub fn check_tag, N>( + hugr: &impl HugrView, + node: N, +) -> Result<(), HugrError> { + let actual = hugr.get_optype(node).tag(); + let required = Required::TAG; + if !required.is_superset(actual) { + return Err(HugrError::InvalidTag { required, actual }); + } + Ok(()) } #[cfg(test)] mod test { use super::RootChecked; - use crate::extension::prelude::MakeTuple; - use crate::extension::ExtensionSet; - use crate::hugr::internal::HugrMutInternals; - use crate::hugr::{HugrError, HugrMut}; - use crate::ops::handle::{BasicBlockID, CfgID, DataflowParentID, DfgID}; - use crate::ops::{DataflowBlock, OpTag, OpType}; - use crate::{ops, type_row, types::Signature, Hugr, HugrView}; + use crate::hugr::HugrError; + use crate::ops::handle::{CfgID, DfgID}; + use crate::ops::{OpTag, OpType}; + use crate::{ops, types::Signature, Hugr}; #[test] fn root_checked() { @@ -125,7 +112,7 @@ mod test { } .into(); let mut h = Hugr::new(root_type.clone()); - let cfg_v = RootChecked::<&Hugr, CfgID>::try_new(&h); + let cfg_v = RootChecked::<_, CfgID>::check(&h); assert_eq!( cfg_v.err(), Some(HugrError::InvalidTag { @@ -133,46 +120,9 @@ mod test { actual: OpTag::Dfg }) ); - let mut dfg_v = RootChecked::<&mut Hugr, DfgID>::try_new(&mut h).unwrap(); - // That is a HugrMutInternal, so we can try: - let root = dfg_v.root(); - let bb: OpType = DataflowBlock { - inputs: type_row![], - other_outputs: type_row![], - sum_rows: vec![type_row![]], - extension_delta: ExtensionSet::new(), - } - .into(); - let r = dfg_v.replace_op(root, bb.clone()); - assert_eq!( - r, - Err(HugrError::InvalidTag { - required: OpTag::Dfg, - actual: ops::OpTag::DataflowBlock - }) - ); - // That didn't do anything: - assert_eq!(dfg_v.get_optype(root), &root_type); - - // Make a RootChecked that allows any DataflowParent - // We won't be able to do this by widening the bound: - assert_eq!( - RootChecked::<_, DataflowParentID>::try_new(dfg_v).err(), - Some(HugrError::InvalidTag { - required: OpTag::Dfg, - actual: OpTag::DataflowParent - }) - ); - - let mut dfp_v = RootChecked::<&mut Hugr, DataflowParentID>::try_new(&mut h).unwrap(); - let r = dfp_v.replace_op(root, bb.clone()); - assert_eq!(r, Ok(root_type)); - assert_eq!(dfp_v.get_optype(root), &bb); - // Just check we can create a nested instance (narrowing the bound) - let mut bb_v = RootChecked::<_, BasicBlockID>::try_new(dfp_v).unwrap(); - - // And it's a HugrMut: - let nodetype = MakeTuple(type_row![]); - bb_v.add_node_with_parent(bb_v.root(), nodetype); + // This should succeed + let dfg_v = RootChecked::<&mut Hugr, DfgID>::try_new(&mut h).unwrap(); + assert!(OpTag::Dfg.is_superset(dfg_v.tag())); + assert_eq!(dfg_v.as_ref().tag(), dfg_v.tag()); } } diff --git a/hugr-core/src/hugr/views/sibling.rs b/hugr-core/src/hugr/views/sibling.rs index 4d15a9c48..ac31d2695 100644 --- a/hugr-core/src/hugr/views/sibling.rs +++ b/hugr-core/src/hugr/views/sibling.rs @@ -11,7 +11,7 @@ use crate::ops::handle::NodeHandle; use crate::ops::OpTrait; use crate::{Direction, Hugr, Node, Port}; -use super::{check_tag, ExtractHugr, HierarchyView, HugrInternals, HugrView, RootTagged}; +use super::{check_tag, ExtractHugr, HierarchyView, HugrInternals, HugrView}; type FlatRegionGraph<'g> = portgraph::view::FlatRegion<'g, &'g MultiPortGraph>; @@ -154,9 +154,6 @@ impl HugrView for SiblingGraph<'_, Root> { .map(|n| self.get_node(n)) } } -impl RootTagged for SiblingGraph<'_, Root> { - type RootHandle = Root; -} impl<'a, Root: NodeHandle> SiblingGraph<'a, Root> { fn new_unchecked(hugr: &'a impl HugrView, root: Node) -> Self { @@ -254,12 +251,6 @@ impl<'g, H: HugrMut, Root: NodeHandle> SiblingMut<'g, H, Root> { /// Create a new SiblingMut from a base. /// Equivalent to [HierarchyView::try_new] but takes a *mutable* reference. pub fn try_new(hugr: &'g mut H, root: H::Node) -> Result { - if root == hugr.root() && !H::RootHandle::TAG.is_superset(Root::TAG) { - return Err(HugrError::InvalidTag { - required: H::RootHandle::TAG, - actual: Root::TAG, - }); - } check_tag::(hugr, root)?; Ok(Self { hugr, @@ -375,22 +366,20 @@ impl> HugrView for SiblingMut<'_, H, Root> } } -impl> RootTagged for SiblingMut<'_, H, Root> { - type RootHandle = Root; -} - impl> HugrMutInternals for SiblingMut<'_, H, Root> { fn replace_op( &mut self, node: Self::Node, op: impl Into, - ) -> Result { + ) -> crate::ops::OpType { let op = op.into(); + // Note: `SiblingMut` will be removed in a subsequent PR, so we just panic here for now. if node == self.root() && !Root::TAG.is_superset(op.tag()) { - return Err(HugrError::InvalidTag { + let err = HugrError::InvalidTag { required: Root::TAG, actual: op.tag(), - }); + }; + panic!("{err}"); } self.hugr.replace_op(node, op) } @@ -424,9 +413,9 @@ mod test { use crate::builder::test::simple_dfg_hugr; use crate::builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}; use crate::extension::prelude::{qb_t, usize_t}; - use crate::ops::handle::{CfgID, DataflowParentID, DfgID, FuncID}; + use crate::ops::handle::{CfgID, DfgID, FuncID}; + use crate::ops::OpType; use crate::ops::{dataflow::IOTrait, Input, OpTag, Output}; - use crate::ops::{OpTrait, OpType}; use crate::types::Signature; use crate::utils::test_quantum_extension::EXTENSION_ID; use crate::IncomingPort; @@ -585,45 +574,13 @@ mod test { }) ); - let mut sib_mut = SiblingMut::<_, DfgID>::try_new(&mut simple_dfg_hugr, root).unwrap(); let bad_nodetype: OpType = crate::ops::CFG { signature }.into(); - assert_eq!( - sib_mut.replace_op(sib_mut.root(), bad_nodetype.clone()), - Err(HugrError::InvalidTag { - required: OpTag::Dfg, - actual: OpTag::Cfg - }) - ); - // In contrast, performing this on the Hugr (where the allowed root type is 'Any') is only detected by validation - simple_dfg_hugr.replace_op(root, bad_nodetype).unwrap(); + // Performing this on the Hugr (where the allowed root type is 'Any') is only detected by validation + simple_dfg_hugr.replace_op(root, bad_nodetype); assert!(simple_dfg_hugr.validate().is_err()); } - #[rstest] - fn sibling_mut_covariance(mut simple_dfg_hugr: Hugr) { - let root = simple_dfg_hugr.root(); - let case_nodetype = crate::ops::Case { - signature: simple_dfg_hugr - .root_type() - .dataflow_signature() - .unwrap() - .into_owned(), - }; - let mut sib_mut = SiblingMut::<_, DfgID>::try_new(&mut simple_dfg_hugr, root).unwrap(); - // As expected, we cannot replace the root with a Case - assert_eq!( - sib_mut.replace_op(root, case_nodetype), - Err(HugrError::InvalidTag { - required: OpTag::Dfg, - actual: OpTag::Case - }) - ); - - let nested_sib_mut = SiblingMut::<_, DataflowParentID>::try_new(&mut sib_mut, root); - assert!(nested_sib_mut.is_err()); - } - #[rstest] fn extract_hugr() -> Result<(), Box> { let (hugr, _def, inner) = make_module_hgr()?; diff --git a/hugr-core/src/hugr/views/sibling_subgraph.rs b/hugr-core/src/hugr/views/sibling_subgraph.rs index c681fafc9..9502d9f6b 100644 --- a/hugr-core/src/hugr/views/sibling_subgraph.rs +++ b/hugr-core/src/hugr/views/sibling_subgraph.rs @@ -22,13 +22,15 @@ use thiserror::Error; use crate::builder::{Container, FunctionBuilder}; use crate::core::HugrNode; use crate::extension::ExtensionSet; -use crate::hugr::{HugrMut, HugrView, RootTagged}; +use crate::hugr::{HugrMut, HugrView}; use crate::ops::dataflow::DataflowOpTrait; use crate::ops::handle::{ContainerHandle, DataflowOpID}; use crate::ops::{NamedOp, OpTag, OpTrait, OpType}; use crate::types::{Signature, Type}; use crate::{Hugr, IncomingPort, Node, OutgoingPort, Port, SimpleReplacement}; +use super::root_checked::RootCheckable; + /// A non-empty convex subgraph of a HUGR sibling graph. /// /// A HUGR region in which all nodes share the same parent. Unlike @@ -95,11 +97,18 @@ impl SiblingSubgraph { /// /// This will return an [`InvalidSubgraph::EmptySubgraph`] error if the /// subgraph is empty. - pub fn try_new_dataflow_subgraph(dfg_graph: &H) -> Result> + pub fn try_new_dataflow_subgraph<'h, H, Root>( + dfg_graph: impl RootCheckable<&'h H, Root>, + ) -> Result> where - H: Clone + RootTagged, - Root: ContainerHandle, + H: 'h + Clone + HugrView, + Root: ContainerHandle, { + let Ok(dfg_graph) = dfg_graph.try_into_checked() else { + return Err(InvalidSubgraph::NonDataflowRegion); + }; + let dfg_graph = dfg_graph.into_hugr(); + let parent = dfg_graph.root(); let nodes = dfg_graph.children(parent).skip(2).collect_vec(); let (inputs, outputs) = get_input_output_ports(dfg_graph); @@ -798,6 +807,9 @@ pub enum InvalidSubgraph { /// An invalid boundary port was found. #[error("Invalid boundary port.")] InvalidBoundary(#[from] InvalidSubgraphBoundary), + /// The hugr region is not a dataflow graph. + #[error("SiblingSubgraphs may only be defined on dataflow regions.")] + NonDataflowRegion, } /// Errors that can occur while constructing a [`SiblingSubgraph`]. @@ -985,7 +997,7 @@ mod tests { fn construct_simple_replacement() -> Result<(), InvalidSubgraph> { let (mut hugr, func_root) = build_hugr().unwrap(); let func: SiblingGraph<'_, FuncID> = SiblingGraph::try_new(&hugr, func_root).unwrap(); - let sub = SiblingSubgraph::try_new_dataflow_subgraph(&func)?; + let sub = SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID>(&func)?; let empty_dfg = { let builder = @@ -1009,7 +1021,7 @@ mod tests { fn test_signature() -> Result<(), InvalidSubgraph> { let (hugr, dfg) = build_hugr().unwrap(); let func: SiblingGraph<'_, FuncID> = SiblingGraph::try_new(&hugr, dfg).unwrap(); - let sub = SiblingSubgraph::try_new_dataflow_subgraph(&func)?; + let sub = SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID>(&func)?; assert_eq!( sub.signature(&func), Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]).with_extension_delta( @@ -1046,7 +1058,7 @@ mod tests { let (hugr, func_root) = build_hugr().unwrap(); let func: SiblingGraph<'_, FuncID> = SiblingGraph::try_new(&hugr, func_root).unwrap(); assert_eq!( - SiblingSubgraph::try_new_dataflow_subgraph(&func) + SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID>(&func) .unwrap() .nodes() .len(), @@ -1162,7 +1174,8 @@ mod tests { let (hugr, func_root) = build_hugr_classical().unwrap(); let func_graph: SiblingGraph<'_, FuncID> = SiblingGraph::try_new(&hugr, func_root).unwrap(); - let func = SiblingSubgraph::try_new_dataflow_subgraph(&func_graph).unwrap(); + let func = + SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID>(&func_graph).unwrap(); let func_defn = hugr.get_optype(func_root).as_func_defn().unwrap(); assert_eq!(func_defn.signature, func.signature(&func_graph).into()); } @@ -1172,7 +1185,8 @@ mod tests { let (hugr, func_root) = build_hugr().unwrap(); let func_graph: SiblingGraph<'_, FuncID> = SiblingGraph::try_new(&hugr, func_root).unwrap(); - let subgraph = SiblingSubgraph::try_new_dataflow_subgraph(&func_graph).unwrap(); + let subgraph = + SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID>(&func_graph).unwrap(); let extracted = subgraph.extract_subgraph(&hugr, "region"); extracted.validate().unwrap(); @@ -1197,7 +1211,7 @@ mod tests { let outw = [outw1].into_iter().chain(outw2); let h = builder.finish_hugr_with_outputs(outw).unwrap(); let view = SiblingGraph::::try_new(&h, h.root()).unwrap(); - let subg = SiblingSubgraph::try_new_dataflow_subgraph(&view).unwrap(); + let subg = SiblingSubgraph::try_new_dataflow_subgraph::<_, DfgID>(&view).unwrap(); assert_eq!(subg.nodes().len(), 2); } diff --git a/hugr-core/src/package.rs b/hugr-core/src/package.rs index 1b96f1ebd..5e1fecdb6 100644 --- a/hugr-core/src/package.rs +++ b/hugr-core/src/package.rs @@ -353,8 +353,7 @@ fn to_module_hugr(mut hugr: Hugr) -> Result { name: "main".to_string(), signature: signature.into_owned().into(), }, - ) - .expect("Hugr accepts any root node"); + ); // Wrap it in a module. let new_root = hugr.add_node(Module::new().into()); diff --git a/hugr-llvm/src/utils/inline_constant_functions.rs b/hugr-llvm/src/utils/inline_constant_functions.rs index a55072a99..1b0931bd2 100644 --- a/hugr-llvm/src/utils/inline_constant_functions.rs +++ b/hugr-llvm/src/utils/inline_constant_functions.rs @@ -69,7 +69,7 @@ fn inline_constant_functions_impl(hugr: &mut impl HugrMut) -> Resul hugr.insert_hugr(func_node, func_hugr); for lcn in load_constant_ns { - hugr.replace_op(lcn, LoadFunction::try_new(polysignature.clone(), [])?)?; + hugr.replace_op(lcn, LoadFunction::try_new(polysignature.clone(), [])?); } any_changes = true; } diff --git a/hugr-passes/src/half_node.rs b/hugr-passes/src/half_node.rs index ca0d9880e..7f332209f 100644 --- a/hugr-passes/src/half_node.rs +++ b/hugr-passes/src/half_node.rs @@ -3,12 +3,10 @@ use std::hash::Hash; use super::nest_cfgs::CfgNodeMap; use hugr_core::hugr::internal::HugrInternals; -use hugr_core::hugr::RootTagged; - +use hugr_core::hugr::views::RootCheckable; use hugr_core::ops::handle::CfgID; use hugr_core::ops::{OpTag, OpTrait}; - -use hugr_core::{Direction, Node}; +use hugr_core::{Direction, HugrView, Node}; /// We provide a view of a cfg where every node has at most one of /// (multiple predecessors, multiple successors). @@ -32,9 +30,12 @@ struct HalfNodeView { exit: H::Node, } -impl> HalfNodeView { +impl HalfNodeView { #[allow(unused)] - pub(crate) fn new(h: H) -> Self { + pub(crate) fn new(h: impl RootCheckable>) -> Self { + let checked = h.try_into_checked().expect("Hugr must be a CFG region"); + let h = checked.into_hugr(); + let (entry, exit) = { let mut children = h.children(h.root()); (children.next().unwrap(), children.next().unwrap()) @@ -64,7 +65,7 @@ impl> HalfNodeView { } } -impl> CfgNodeMap> for HalfNodeView { +impl CfgNodeMap> for HalfNodeView { fn entry_node(&self) -> HalfNode { HalfNode::N(self.entry) } @@ -98,7 +99,6 @@ mod test { use super::super::nest_cfgs::{test::*, EdgeClassifier}; use super::{HalfNode, HalfNodeView}; use hugr_core::builder::BuildError; - use hugr_core::hugr::views::RootChecked; use hugr_core::ops::handle::NodeHandle; use itertools::Itertools; @@ -118,7 +118,7 @@ mod test { // \---<---<---<---<---<---<---<---<---<---/ // Allowing to identify two nested regions (and fixing the problem with an IdentityCfgMap on the same example) - let v = HalfNodeView::new(RootChecked::try_new(&h).unwrap()); + let v = HalfNodeView::new(&h); let edge_classes = EdgeClassifier::get_edge_classes(&v); let HalfNodeView { h: _, entry, exit } = v; diff --git a/hugr-passes/src/lower.rs b/hugr-passes/src/lower.rs index 8de6c00a2..3a3bd5e91 100644 --- a/hugr-passes/src/lower.rs +++ b/hugr-passes/src/lower.rs @@ -1,5 +1,5 @@ use hugr_core::{ - hugr::{hugrmut::HugrMut, views::SiblingSubgraph, HugrError}, + hugr::{hugrmut::HugrMut, views::SiblingSubgraph}, ops::OpType, Hugr, Node, }; @@ -10,14 +10,10 @@ use thiserror::Error; /// New operations must match the signature of the old operations. /// /// Returns a list of the replaced nodes and their old operations. -/// -/// # Errors -/// -/// Returns a [`HugrError`] if any replacement fails. pub fn replace_many_ops>( hugr: &mut impl HugrMut, mapping: impl Fn(&OpType) -> Option, -) -> Result, HugrError> { +) -> Vec<(Node, OpType)> { let replacements = hugr .nodes() .filter_map(|node| { @@ -28,7 +24,10 @@ pub fn replace_many_ops>( replacements .into_iter() - .map(|(node, new_op)| hugr.replace_op(node, new_op).map(|old_op| (node, old_op))) + .map(|(node, new_op)| { + let old_op = hugr.replace_op(node, new_op); + (node, old_op) + }) .collect() } @@ -117,8 +116,7 @@ mod test { } else { None } - }) - .unwrap(); + }); assert_eq!(replaced.len(), 1); let (n, op) = replaced.remove(0); diff --git a/hugr-passes/src/merge_bbs.rs b/hugr-passes/src/merge_bbs.rs index d1731107d..a5de5eb57 100644 --- a/hugr-passes/src/merge_bbs.rs +++ b/hugr-passes/src/merge_bbs.rs @@ -4,11 +4,11 @@ use std::collections::HashMap; use hugr_core::extension::prelude::UnpackTuple; use hugr_core::hugr::hugrmut::HugrMut; +use hugr_core::hugr::views::RootCheckable; use itertools::Itertools; use hugr_core::hugr::rewrite::inline_dfg::InlineDFG; use hugr_core::hugr::rewrite::replace::{NewEdgeKind, NewEdgeSpec, Replacement}; -use hugr_core::hugr::RootTagged; use hugr_core::ops::handle::CfgID; use hugr_core::ops::{DataflowBlock, DataflowParent, Input, Output, DFG}; use hugr_core::{Hugr, HugrView, Node}; @@ -16,7 +16,13 @@ use hugr_core::{Hugr, HugrView, Node}; /// Merge any basic blocks that are direct children of the specified CFG /// i.e. where a basic block B has a single successor B' whose only predecessor /// is B, B and B' can be combined. -pub fn merge_basic_blocks(cfg: &mut impl HugrMut) { +pub fn merge_basic_blocks<'h, H>(cfg: impl RootCheckable<&'h mut H, CfgID>) +where + H: 'h + HugrMut, +{ + let checked = cfg.try_into_checked().expect("Hugr must be a CFG region"); + let cfg = checked.into_hugr(); + let mut worklist = cfg.children(cfg.root()).collect::>(); while let Some(n) = worklist.pop() { // Consider merging n with its successor @@ -44,7 +50,7 @@ pub fn merge_basic_blocks(cfg: &mut impl HugrMut, + cfg: &impl HugrView, pred: Node, succ: Node, ) -> (Replacement, Node, [Node; 2]) { @@ -158,14 +164,12 @@ mod test { use std::sync::Arc; use hugr_core::extension::prelude::PRELUDE_ID; - use hugr_core::hugr::views::RootChecked; use itertools::Itertools; use rstest::rstest; use hugr_core::builder::{endo_sig, inout_sig, CFGBuilder, DFGWrapper, Dataflow, HugrBuilder}; use hugr_core::extension::prelude::{qb_t, usize_t, ConstUsize}; use hugr_core::ops::constant::Value; - use hugr_core::ops::handle::CfgID; use hugr_core::ops::{LoadConstant, OpTrait, OpType}; use hugr_core::types::{Signature, Type, TypeRow}; use hugr_core::{const_extension_ids, type_row, Extension, Hugr, HugrView, Wire}; @@ -252,7 +256,7 @@ mod test { let mut h = h.finish_hugr()?; let r = h.root(); - merge_basic_blocks(&mut RootChecked::<_, CfgID>::try_new(&mut h).unwrap()); + merge_basic_blocks(&mut h); h.validate().unwrap(); assert_eq!(r, h.root()); assert!(matches!(h.get_optype(r), OpType::CFG(_))); @@ -346,7 +350,7 @@ mod test { h.branch(&bb3, 0, &h.exit_block())?; let mut h = h.finish_hugr()?; - merge_basic_blocks(&mut RootChecked::<_, CfgID>::try_new(&mut h).unwrap()); + merge_basic_blocks(&mut h); h.validate()?; // Should only be one BB left diff --git a/hugr-passes/src/monomorphize.rs b/hugr-passes/src/monomorphize.rs index 3164702d8..3ac85a020 100644 --- a/hugr-passes/src/monomorphize.rs +++ b/hugr-passes/src/monomorphize.rs @@ -157,7 +157,7 @@ fn mono_scan( h.disconnect(ch, fn_inp); // No-op if copying+substituting h.connect(new_tgt, fn_out, ch, fn_inp); - h.replace_op(ch, new_op).unwrap(); + h.replace_op(ch, new_op); } } @@ -178,7 +178,7 @@ fn instantiate( name: mangle_inner_func(&outer_name, &fd.name), signature: fd.signature.clone(), }; - h.replace_op(n, fd).unwrap(); + h.replace_op(n, fd); h.move_after_sibling(n, poly_func); } else { to_scan.extend(h.children(n)) diff --git a/hugr-passes/src/nest_cfgs.rs b/hugr-passes/src/nest_cfgs.rs index 1c4928e12..b98d4fb23 100644 --- a/hugr-passes/src/nest_cfgs.rs +++ b/hugr-passes/src/nest_cfgs.rs @@ -46,8 +46,8 @@ use thiserror::Error; use hugr_core::hugr::rewrite::outline_cfg::OutlineCfg; use hugr_core::hugr::views::sibling::SiblingMut; -use hugr_core::hugr::views::{HierarchyView, HugrView, SiblingGraph}; -use hugr_core::hugr::{hugrmut::HugrMut, Rewrite, RootTagged}; +use hugr_core::hugr::views::{HierarchyView, HugrView, RootCheckable, SiblingGraph}; +use hugr_core::hugr::{hugrmut::HugrMut, Rewrite}; use hugr_core::ops::handle::{BasicBlockID, CfgID}; use hugr_core::ops::OpTag; use hugr_core::ops::OpTrait; @@ -219,9 +219,12 @@ pub struct IdentityCfgMap { entry: H::Node, exit: H::Node, } -impl> IdentityCfgMap { +impl IdentityCfgMap { /// Creates an [IdentityCfgMap] for the specified CFG - pub fn new(h: H) -> Self { + pub fn new(h: impl RootCheckable>) -> Self { + let h = h.try_into_checked().expect("Hugr must be a CFG region"); + let h = h.into_hugr(); + // Panic if malformed enough not to have two children let (entry, exit) = h.children(h.root()).take(2).collect_tuple().unwrap(); debug_assert_eq!(h.get_optype(exit).tag(), OpTag::BasicBlockExit); @@ -636,7 +639,7 @@ pub(crate) mod test { let rc = RootChecked::<_, CfgID>::try_new(&mut h).unwrap(); let (entry, exit) = (entry.node(), exit.node()); let (split, merge, head, tail) = (split.node(), merge.node(), head.node(), tail.node()); - let edge_classes = EdgeClassifier::get_edge_classes(&IdentityCfgMap::new(rc.borrow())); + let edge_classes = EdgeClassifier::get_edge_classes(&IdentityCfgMap::new(rc.as_ref())); let [&left, &right] = edge_classes .keys() .filter(|(s, _)| *s == split) @@ -734,7 +737,7 @@ pub(crate) mod test { // There's no need to use a view of a region here but we do so just to check // that we *can* (as we'll need to for "real" module Hugr's) - let v = IdentityCfgMap::new(SiblingGraph::try_new(&h, h.root()).unwrap()); + let v = IdentityCfgMap::new(SiblingGraph::::try_new(&h, h.root()).unwrap()); let edge_classes = EdgeClassifier::get_edge_classes(&v); let IdentityCfgMap { h: _, entry, exit } = v; let [&left, &right] = edge_classes diff --git a/hugr/src/hugr.rs b/hugr/src/hugr.rs index 88c8c8df0..a66de8315 100644 --- a/hugr/src/hugr.rs +++ b/hugr/src/hugr.rs @@ -3,6 +3,6 @@ // Exports everything except the `internal` module. pub use hugr_core::hugr::{ hugrmut, rewrite, serialize, validate, views, Hugr, HugrError, HugrView, IdentList, - InvalidIdentifier, LoadHugrError, NodeMetadata, NodeMetadataMap, OpType, Rewrite, RootTagged, + InvalidIdentifier, LoadHugrError, NodeMetadata, NodeMetadataMap, OpType, Rewrite, SimpleReplacement, SimpleReplacementError, ValidationError, DEFAULT_OPTYPE, }; diff --git a/uv.lock b/uv.lock index 130657231..4f7d6012a 100644 --- a/uv.lock +++ b/uv.lock @@ -277,7 +277,7 @@ wheels = [ [[package]] name = "hugr" -version = "0.11.4" +version = "0.11.5" source = { editable = "hugr-py" } dependencies = [ { name = "graphviz" }, From db4b39f4d9ccb06dfa48c1dab0999141bb3eba98 Mon Sep 17 00:00:00 2001 From: Luca Mondada <72734770+lmondada@users.noreply.github.com> Date: Tue, 29 Apr 2025 15:15:13 +0200 Subject: [PATCH 13/18] feat!: Split Rewrite trait into VerifyPatch and ApplyPatch (#2070) This PR splits the `Rewrite` trait into two (three) traits: - a `VerifyPatch` trait that has the `fn verify` and `fn invalidation_set` functions - a `ApplyPatch` trait that has the `fn apply` function. This inherits `VerifyPatch` and is the "rewriting" trait that should be used in most scenarios. In addition, there is a third trait `ApplyPatchHugrMut` that can be implemented by any patches that can be applied to _any_ `HugrMut` (as opposed to a specific type `H`). This is strictly stronger than `ApplyPatch` and should be implemented instead of `ApplyPatch` where possible (see the docs of the traits). closes #588 closes #2052 BREAKING CHANGE: Replaced the `Rewrite` trait with `Patch`. `Rewrite::ApplyResult` is now `Patch::Outcome`. `Rewrite::verify` was split into a separate trait, and is now `PatchVerification::verify`. BREAKING CHANGE: Renamed `hugr.rewrite` module to `hugr.patch`. BREAKING CHANGE: Changed the type `OutlineCfg::ApplyResult` (now `OutlineCfg::Outcome`) from `(Node, Node)` to `[Node; 2]`. --------- Co-authored-by: Alan Lawrence Co-authored-by: Alan Lawrence --- hugr-core/src/hugr.rs | 4 +- hugr-core/src/hugr/hugrmut.rs | 9 +- hugr-core/src/hugr/patch.rs | 169 ++++++++ .../src/hugr/{rewrite => patch}/consts.rs | 86 ++-- .../hugr/{rewrite => patch}/inline_call.rs | 49 +-- .../src/hugr/{rewrite => patch}/inline_dfg.rs | 38 +- .../{rewrite => patch}/insert_identity.rs | 45 ++- .../hugr/{rewrite => patch}/outline_cfg.rs | 48 ++- .../src/hugr/{rewrite => patch}/port_types.rs | 0 .../src/hugr/{rewrite => patch}/replace.rs | 375 +++++++++++------- .../hugr/{rewrite => patch}/simple_replace.rs | 141 ++++--- hugr-core/src/hugr/views/sibling_subgraph.rs | 4 +- hugr-passes/src/lower.rs | 9 +- hugr-passes/src/merge_bbs.rs | 8 +- hugr-passes/src/nest_cfgs.rs | 10 +- hugr-passes/src/untuple.rs | 2 +- hugr/src/hugr.rs | 4 +- 17 files changed, 650 insertions(+), 351 deletions(-) create mode 100644 hugr-core/src/hugr/patch.rs rename hugr-core/src/hugr/{rewrite => patch}/consts.rs (74%) rename hugr-core/src/hugr/{rewrite => patch}/inline_call.rs (91%) rename hugr-core/src/hugr/{rewrite => patch}/inline_dfg.rs (96%) rename hugr-core/src/hugr/{rewrite => patch}/insert_identity.rs (84%) rename hugr-core/src/hugr/{rewrite => patch}/outline_cfg.rs (96%) rename hugr-core/src/hugr/{rewrite => patch}/port_types.rs (100%) rename hugr-core/src/hugr/{rewrite => patch}/replace.rs (73%) rename hugr-core/src/hugr/{rewrite => patch}/simple_replace.rs (92%) diff --git a/hugr-core/src/hugr.rs b/hugr-core/src/hugr.rs index d708f15cd..7a74b4070 100644 --- a/hugr-core/src/hugr.rs +++ b/hugr-core/src/hugr.rs @@ -4,7 +4,7 @@ pub mod hugrmut; pub(crate) mod ident; pub mod internal; -pub mod rewrite; +pub mod patch; pub mod serialize; pub mod validate; pub mod views; @@ -17,7 +17,7 @@ pub(crate) use self::hugrmut::HugrMut; pub use self::validate::ValidationError; pub use ident::{IdentList, InvalidIdentifier}; -pub use rewrite::{Rewrite, SimpleReplacement, SimpleReplacementError}; +pub use patch::{Patch, SimpleReplacement, SimpleReplacementError}; use portgraph::multiportgraph::MultiPortGraph; use portgraph::{Hierarchy, PortMut, PortView, UnmanagedDenseMap}; diff --git a/hugr-core/src/hugr/hugrmut.rs b/hugr-core/src/hugr/hugrmut.rs index 51e92f342..c58ccbdbc 100644 --- a/hugr-core/src/hugr/hugrmut.rs +++ b/hugr-core/src/hugr/hugrmut.rs @@ -12,7 +12,7 @@ use crate::extension::ExtensionRegistry; use crate::hugr::internal::HugrInternals; use crate::hugr::views::SiblingSubgraph; use crate::hugr::{HugrView, Node, OpType}; -use crate::hugr::{NodeMetadata, Rewrite}; +use crate::hugr::{NodeMetadata, Patch}; use crate::ops::OpTrait; use crate::types::Substitution; use crate::{Extension, Hugr, IncomingPort, OutgoingPort, Port, PortIndex}; @@ -195,11 +195,8 @@ pub trait HugrMut: HugrMutInternals { subgraph: &SiblingSubgraph, ) -> HashMap; - /// Applies a rewrite to the graph. - fn apply_rewrite( - &mut self, - rw: impl Rewrite, - ) -> Result + /// Applies a patch to the graph. + fn apply_patch(&mut self, rw: impl Patch) -> Result where Self: Sized, { diff --git a/hugr-core/src/hugr/patch.rs b/hugr-core/src/hugr/patch.rs new file mode 100644 index 000000000..bc6195eba --- /dev/null +++ b/hugr-core/src/hugr/patch.rs @@ -0,0 +1,169 @@ +//! Rewrite operations on the HUGR - replacement, outlining, etc. + +pub mod consts; +pub mod inline_call; +pub mod inline_dfg; +pub mod insert_identity; +pub mod outline_cfg; +mod port_types; +pub mod replace; +pub mod simple_replace; + +use crate::{Hugr, HugrView}; +pub use port_types::{BoundaryPort, HostPort, ReplacementPort}; +pub use simple_replace::{SimpleReplacement, SimpleReplacementError}; + +use super::HugrMut; + +/// Verify that a patch application would succeed. +pub trait PatchVerification { + /// The type of Error with which this Rewrite may fail + type Error: std::error::Error; + + /// The node type of the HugrView that this patch applies to. + type Node; + + /// Checks whether the rewrite would succeed on the specified Hugr. + /// If this call succeeds, [Patch::apply] should also succeed on the same + /// `h` If this calls fails, [Patch::apply] would fail with the same + /// error. + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error>; + + /// Returns a set of nodes referenced by the rewrite. Modifying any of these + /// nodes will invalidate it. + /// + /// Two `impl Rewrite`s can be composed if their invalidation sets are + /// disjoint. + fn invalidation_set(&self) -> impl Iterator; +} + +/// A patch that can be applied to a mutable Hugr of type `H`. +/// +/// ### When to use +/// +/// Use this trait whenever possible in bounds for the most generality. Note +/// that this will require specifying which type `H` the patch applies to. +/// +/// ### When to implement +/// +/// For patches that work on any `H: HugrMut`, prefer implementing [`PatchHugrMut`] instead. This +/// will automatically implement this trait. +pub trait Patch: PatchVerification { + /// The type returned on successful application of the rewrite. + type Outcome; + + /// If `true`, [Patch::apply]'s of this rewrite guarantee that they do not + /// mutate the Hugr when they return an Err. If `false`, there is no + /// guarantee; the Hugr should be assumed invalid when Err is returned. + const UNCHANGED_ON_FAILURE: bool; + + /// Mutate the specified Hugr, or fail with an error. + /// + /// Returns [`Self::Outcome`] if successful. + /// If [Patch::UNCHANGED_ON_FAILURE] is true, then `h` must be unchanged if + /// Err is returned. See also [PatchVerification::verify] + /// + /// # Panics + /// + /// May panic if-and-only-if `h` would have failed [Hugr::validate]; that + /// is, implementations may begin with `assert!(h.validate())`, with + /// `debug_assert!(h.validate())` being preferred. + fn apply(self, h: &mut H) -> Result; +} + +/// A patch that can be applied to any [`HugrMut`]. +/// +/// This trait is a generalisation of [`Patch`] in that it guarantees that +/// the patch can be applied to any type implementing [`HugrMut`]. +/// +/// ### When to use +/// +/// Prefer using the more general [`Patch`] trait in bounds where the +/// type `H` is known. Resort to this trait if patches must be applicable to +/// any [`HugrMut`] instance. +/// +/// ### When to implement +/// +/// Always implement this trait when possible, to define how a patch is applied +/// to any type implementing [`HugrMut`]. A blanket implementation ensures that +/// any type implementing this trait also implements [`Patch`]. +pub trait PatchHugrMut: PatchVerification { + /// The type returned on successful application of the rewrite. + type Outcome; + + /// If `true`, [self.apply]'s of this rewrite guarantee that they do not + /// mutate the Hugr when they return an Err. If `false`, there is no + /// guarantee; the Hugr should be assumed invalid when Err is returned. + const UNCHANGED_ON_FAILURE: bool; + + /// Mutate the specified Hugr, or fail with an error. + /// + /// Returns [`Self::Outcome`] if successful. + /// If [self.unchanged_on_failure] is true, then `h` must be unchanged if + /// Err is returned. See also [self.verify] + /// # Panics + /// May panic if-and-only-if `h` would have failed [Hugr::validate]; that + /// is, implementations may begin with `assert!(h.validate())`, with + /// `debug_assert!(h.validate())` being preferred. + fn apply_hugr_mut( + self, + h: &mut impl HugrMut, + ) -> Result; +} + +impl> Patch for R { + type Outcome = R::Outcome; + const UNCHANGED_ON_FAILURE: bool = R::UNCHANGED_ON_FAILURE; + + fn apply(self, h: &mut H) -> Result { + self.apply_hugr_mut(h) + } +} + +/// Wraps any rewrite into a transaction (i.e. that has no effect upon failure) +pub struct Transactional { + underlying: R, +} + +impl PatchVerification for Transactional { + type Error = R::Error; + type Node = R::Node; + + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { + self.underlying.verify(h) + } + + #[inline] + fn invalidation_set(&self) -> impl Iterator { + self.underlying.invalidation_set() + } +} + +// Note we might like to constrain R to Rewrite but +// this is not yet supported, https://github.com/rust-lang/rust/issues/92827 +impl PatchHugrMut for Transactional { + type Outcome = R::Outcome; + const UNCHANGED_ON_FAILURE: bool = true; + + fn apply_hugr_mut( + self, + h: &mut impl HugrMut, + ) -> Result { + if R::UNCHANGED_ON_FAILURE { + return self.underlying.apply_hugr_mut(h); + } + // Try to backup just the contents of this HugrMut. + let mut backup = Hugr::new(h.root_type().clone()); + backup.insert_from_view(backup.root(), h); + let r = self.underlying.apply_hugr_mut(h); + if r.is_err() { + // Try to restore backup. + h.replace_op(h.root(), backup.root_type().clone()); + while let Some(child) = h.first_child(h.root()) { + h.remove_node(child); + } + h.insert_hugr(h.root(), backup); + } + r + } +} diff --git a/hugr-core/src/hugr/rewrite/consts.rs b/hugr-core/src/hugr/patch/consts.rs similarity index 74% rename from hugr-core/src/hugr/rewrite/consts.rs rename to hugr-core/src/hugr/patch/consts.rs index ac657bf91..6d0c011fe 100644 --- a/hugr-core/src/hugr/rewrite/consts.rs +++ b/hugr-core/src/hugr/patch/consts.rs @@ -2,11 +2,11 @@ use std::iter; -use crate::{hugr::HugrMut, HugrView, Node}; +use crate::{core::HugrNode, hugr::HugrMut, HugrView, Node}; use itertools::Itertools; use thiserror::Error; -use super::Rewrite; +use super::{PatchHugrMut, PatchVerification}; /// Remove a [`crate::ops::LoadConstant`] node with no consumers. #[derive(Debug, Clone)] @@ -15,25 +15,20 @@ pub struct RemoveLoadConstant(pub N); /// Error from an [`RemoveConst`] or [`RemoveLoadConstant`] operation. #[derive(Debug, Clone, Error, PartialEq, Eq)] #[non_exhaustive] -pub enum RemoveError { +pub enum RemoveError { /// Invalid node. #[error("Node is invalid (either not in HUGR or not correct operation).")] - InvalidNode(Node), + InvalidNode(N), /// Node in use. #[error("Node: {0} has non-zero outgoing connections.")] - ValueUsed(Node), + ValueUsed(N), } -impl Rewrite for RemoveLoadConstant { - type Node = Node; - type Error = RemoveError; +impl PatchVerification for RemoveLoadConstant { + type Error = RemoveError; + type Node = N; - // The Const node the LoadConstant was connected to. - type ApplyResult = Node; - - const UNCHANGED_ON_FAILURE: bool = true; - - fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { let node = self.0; if (!h.contains_node(node)) || (!h.get_optype(node).is_load_constant()) { @@ -51,7 +46,18 @@ impl Rewrite for RemoveLoadConstant { Ok(()) } - fn apply(self, h: &mut impl HugrMut) -> Result { + fn invalidation_set(&self) -> impl Iterator { + iter::once(self.0) + } +} + +impl PatchHugrMut for RemoveLoadConstant { + /// The [`Const`](crate::ops::Const) node the [`LoadConstant`](crate::ops::LoadConstant) was + /// connected to. + type Outcome = N; + + const UNCHANGED_ON_FAILURE: bool = true; + fn apply_hugr_mut(self, h: &mut impl HugrMut) -> Result { self.verify(h)?; let node = self.0; let source = h @@ -63,26 +69,17 @@ impl Rewrite for RemoveLoadConstant { Ok(source) } - - fn invalidation_set(&self) -> impl Iterator { - iter::once(self.0) - } } /// Remove a [`crate::ops::Const`] node with no outputs. #[derive(Debug, Clone)] -pub struct RemoveConst(pub Node); - -impl Rewrite for RemoveConst { - type Node = Node; - type Error = RemoveError; +pub struct RemoveConst(pub N); - // The parent of the Const node. - type ApplyResult = Node; +impl PatchVerification for RemoveConst { + type Node = N; + type Error = RemoveError; - const UNCHANGED_ON_FAILURE: bool = true; - - fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { let node = self.0; if (!h.contains_node(node)) || (!h.get_optype(node).is_const()) { @@ -96,7 +93,18 @@ impl Rewrite for RemoveConst { Ok(()) } - fn apply(self, h: &mut impl HugrMut) -> Result { + fn invalidation_set(&self) -> impl Iterator { + iter::once(self.0) + } +} + +impl PatchHugrMut for RemoveConst { + // The parent of the Const node. + type Outcome = N; + + const UNCHANGED_ON_FAILURE: bool = true; + + fn apply_hugr_mut(self, h: &mut impl HugrMut) -> Result { self.verify(h)?; let node = self.0; let parent = h @@ -106,10 +114,6 @@ impl Rewrite for RemoveConst { Ok(parent) } - - fn invalidation_set(&self) -> impl Iterator { - iter::once(self.0) - } } #[cfg(test)] @@ -144,12 +148,12 @@ mod test { let tup_node = tup.node(); // can't remove invalid node assert_eq!( - h.apply_rewrite(RemoveConst(tup_node)), + h.apply_patch(RemoveConst(tup_node)), Err(RemoveError::InvalidNode(tup_node)) ); assert_eq!( - h.apply_rewrite(RemoveLoadConstant(tup_node)), + h.apply_patch(RemoveLoadConstant(tup_node)), Err(RemoveError::InvalidNode(tup_node)) ); let load_1_node = load_1.node(); @@ -172,7 +176,7 @@ mod test { // can't remove nodes in use assert_eq!( - h.apply_rewrite(remove_1.clone()), + h.apply_patch(remove_1.clone()), Err(RemoveError::ValueUsed(load_1_node)) ); @@ -180,20 +184,20 @@ mod test { h.remove_node(tup_node); // remove first load - let reported_con_node = h.apply_rewrite(remove_1)?; + let reported_con_node = h.apply_patch(remove_1)?; assert_eq!(reported_con_node, con_node); // still can't remove const, in use by second load assert_eq!( - h.apply_rewrite(remove_con.clone()), + h.apply_patch(remove_con.clone()), Err(RemoveError::ValueUsed(con_node)) ); // remove second use - let reported_con_node = h.apply_rewrite(remove_2)?; + let reported_con_node = h.apply_patch(remove_2)?; assert_eq!(reported_con_node, con_node); // remove const - assert_eq!(h.apply_rewrite(remove_con)?, h.root()); + assert_eq!(h.apply_patch(remove_con)?, h.root()); assert_eq!(h.node_count(), 4); assert!(h.validate().is_ok()); diff --git a/hugr-core/src/hugr/rewrite/inline_call.rs b/hugr-core/src/hugr/patch/inline_call.rs similarity index 91% rename from hugr-core/src/hugr/rewrite/inline_call.rs rename to hugr-core/src/hugr/patch/inline_call.rs index e32373507..0619d373e 100644 --- a/hugr-core/src/hugr/rewrite/inline_call.rs +++ b/hugr-core/src/hugr/patch/inline_call.rs @@ -2,41 +2,41 @@ //! into a DFG which replaces the Call node. use derive_more::{Display, Error}; +use crate::core::HugrNode; use crate::ops::{DataflowParent, OpType, DFG}; use crate::types::Substitution; use crate::{Direction, HugrView, Node}; -use super::{HugrMut, Rewrite}; +use super::{HugrMut, PatchHugrMut, PatchVerification}; /// Rewrite to inline a [Call](OpType::Call) to a known [FuncDefn](OpType::FuncDefn) -pub struct InlineCall(Node); +pub struct InlineCall(N); /// Error in performing [InlineCall] rewrite. #[derive(Clone, Debug, Display, Error, PartialEq)] #[non_exhaustive] -pub enum InlineCallError { +pub enum InlineCallError { /// The specified Node was not a [Call](OpType::Call) #[display("Node to inline {_0} expected to be a Call but actually {_1}")] - NotCallNode(Node, OpType), + NotCallNode(N, OpType), /// The node was a Call, but the target was not a [FuncDefn](OpType::FuncDefn) /// - presumably a [FuncDecl](OpType::FuncDecl), if the Hugr is valid. #[display("Call targetted node {_0} which must be a FuncDefn but was {_1}")] - CallTargetNotFuncDefn(Node, OpType), + CallTargetNotFuncDefn(N, OpType), } -impl InlineCall { +impl InlineCall { /// Create a new instance that will inline the specified node /// (i.e. that should be a [Call](OpType::Call)) - pub fn new(node: Node) -> Self { + pub fn new(node: N) -> Self { Self(node) } } -impl Rewrite for InlineCall { - type Node = Node; - type ApplyResult = (); - type Error = InlineCallError; - fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { +impl PatchVerification for InlineCall { + type Error = InlineCallError; + type Node = N; + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { let call_ty = h.get_optype(self.0); if !call_ty.is_call() { return Err(InlineCallError::NotCallNode(self.0, call_ty.clone())); @@ -52,7 +52,14 @@ impl Rewrite for InlineCall { Ok(()) } - fn apply(self, h: &mut impl HugrMut) -> Result<(), Self::Error> { + fn invalidation_set(&self) -> impl Iterator { + Some(self.0).into_iter() + } +} + +impl PatchHugrMut for InlineCall { + type Outcome = (); + fn apply_hugr_mut(self, h: &mut impl HugrMut) -> Result<(), Self::Error> { self.verify(h)?; // Now we know we have a Call to a FuncDefn. let orig_func = h.static_source(self.0).unwrap(); @@ -99,10 +106,6 @@ impl Rewrite for InlineCall { /// Failure only occurs if the node is not a Call, or the target not a FuncDefn. /// (Any later failure means an invalid Hugr and `panic`.) const UNCHANGED_ON_FAILURE: bool = true; - - fn invalidation_set(&self) -> impl Iterator { - Some(self.0).into_iter() - } } #[cfg(test)] @@ -177,7 +180,7 @@ mod test { .count(), 1 ); - hugr.apply_rewrite(InlineCall(call1.node())).unwrap(); + hugr.apply_patch(InlineCall(call1.node())).unwrap(); hugr.validate().unwrap(); assert_eq!(hugr.output_neighbours(func.node()).collect_vec(), [call2]); assert_eq!(calls(&hugr), [call2]); @@ -190,7 +193,7 @@ mod test { .count(), 1 ); - hugr.apply_rewrite(InlineCall(call2.node())).unwrap(); + hugr.apply_patch(InlineCall(call2.node())).unwrap(); hugr.validate().unwrap(); assert_eq!(hugr.output_neighbours(func.node()).next(), None); assert_eq!(calls(&hugr), []); @@ -225,7 +228,7 @@ mod test { let func = func.node(); let mut call = call.node(); for i in 2..10 { - hugr.apply_rewrite(InlineCall(call))?; + hugr.apply_patch(InlineCall(call))?; hugr.validate().unwrap(); assert_eq!(extension_ops(&hugr).len(), i); let v = calls(&hugr); @@ -264,7 +267,7 @@ mod test { let h = modb.finish_hugr().unwrap(); let mut h2 = h.clone(); assert_eq!( - h2.apply_rewrite(InlineCall(call.node())), + h2.apply_patch(InlineCall(call.node())), Err(InlineCallError::CallTargetNotFuncDefn( decl.node(), h.get_optype(decl.node()).clone() @@ -277,7 +280,7 @@ mod test { .try_into() .unwrap(); assert_eq!( - h2.apply_rewrite(InlineCall(inp)), + h2.apply_patch(InlineCall(inp)), Err(InlineCallError::NotCallNode( inp, Input { @@ -314,7 +317,7 @@ mod test { hugr.output_neighbours(inner.node()).collect::>(), [call1.node(), call2.node()] ); - hugr.apply_rewrite(InlineCall::new(call1.node()))?; + hugr.apply_patch(InlineCall::new(call1.node()))?; assert_eq!( hugr.output_neighbours(inner.node()).collect::>(), diff --git a/hugr-core/src/hugr/rewrite/inline_dfg.rs b/hugr-core/src/hugr/patch/inline_dfg.rs similarity index 96% rename from hugr-core/src/hugr/rewrite/inline_dfg.rs rename to hugr-core/src/hugr/patch/inline_dfg.rs index 8988df170..58fd51cbb 100644 --- a/hugr-core/src/hugr/rewrite/inline_dfg.rs +++ b/hugr-core/src/hugr/patch/inline_dfg.rs @@ -2,7 +2,7 @@ //! of the DFG except Input+Output into the DFG's parent, //! and deleting the DFG along with its Input + Output -use super::Rewrite; +use super::{PatchHugrMut, PatchVerification}; use crate::ops::handle::{DfgID, NodeHandle}; use crate::{IncomingPort, Node, OutgoingPort, PortIndex}; @@ -21,13 +21,10 @@ pub enum InlineDFGError { NoParent, } -impl Rewrite for InlineDFG { - /// Returns the removed nodes: the DFG, and its Input and Output children. - type Node = Node; - type ApplyResult = [Node; 3]; +impl PatchVerification for InlineDFG { type Error = InlineDFGError; - const UNCHANGED_ON_FAILURE: bool = true; + type Node = Node; fn verify(&self, h: &impl crate::HugrView) -> Result<(), Self::Error> { let n = self.0.node(); @@ -40,10 +37,21 @@ impl Rewrite for InlineDFG { Ok(()) } - fn apply( + fn invalidation_set(&self) -> impl Iterator { + [self.0.node()].into_iter() + } +} + +impl PatchHugrMut for InlineDFG { + /// The removed nodes: the DFG, and its Input and Output children. + type Outcome = [Node; 3]; + + const UNCHANGED_ON_FAILURE: bool = true; + + fn apply_hugr_mut( self, h: &mut impl crate::hugr::HugrMut, - ) -> Result { + ) -> Result { self.verify(h)?; let n = self.0.node(); let (oth_in, oth_out) = { @@ -124,10 +132,6 @@ impl Rewrite for InlineDFG { h.remove_node(n); Ok([n, input, output]) } - - fn invalidation_set(&self) -> impl Iterator { - [self.0.node()].into_iter() - } } #[cfg(test)] @@ -142,7 +146,7 @@ mod test { }; use crate::extension::prelude::qb_t; use crate::extension::ExtensionSet; - use crate::hugr::rewrite::inline_dfg::InlineDFGError; + use crate::hugr::patch::inline_dfg::InlineDFGError; use crate::hugr::HugrMut; use crate::ops::handle::{DfgID, NodeHandle}; use crate::ops::{OpType, Value}; @@ -212,13 +216,13 @@ mod test { // Check we can't inline the outer DFG let mut h = outer.clone(); assert_eq!( - h.apply_rewrite(InlineDFG(DfgID::from(h.root()))), + h.apply_patch(InlineDFG(DfgID::from(h.root()))), Err(InlineDFGError::NoParent) ); assert_eq!(h, outer); // unchanged } - outer.apply_rewrite(InlineDFG(*inner.handle()))?; + outer.apply_patch(InlineDFG(*inner.handle()))?; outer.validate()?; assert_eq!(outer.nodes().count(), 7); assert_eq!(find_dfgs(&outer), vec![outer.root()]); @@ -274,7 +278,7 @@ mod test { ] ); - h.apply_rewrite(InlineDFG(*swap.handle()))?; + h.apply_patch(InlineDFG(*swap.handle()))?; assert_eq!(find_dfgs(&h), vec![h.root()]); assert_eq!(h.nodes().count(), 5); // Dfg+I+O let mut ops = extension_ops(&h); @@ -350,7 +354,7 @@ mod test { )?; let mut outer = outer.finish_hugr_with_outputs(cx.outputs())?; - outer.apply_rewrite(InlineDFG(*inner.handle()))?; + outer.apply_patch(InlineDFG(*inner.handle()))?; outer.validate()?; let order_neighbours = |n, d| { let p = outer.get_optype(n).other_port(d).unwrap(); diff --git a/hugr-core/src/hugr/rewrite/insert_identity.rs b/hugr-core/src/hugr/patch/insert_identity.rs similarity index 84% rename from hugr-core/src/hugr/rewrite/insert_identity.rs rename to hugr-core/src/hugr/patch/insert_identity.rs index bde43413b..98ab0ff02 100644 --- a/hugr-core/src/hugr/rewrite/insert_identity.rs +++ b/hugr-core/src/hugr/patch/insert_identity.rs @@ -2,6 +2,7 @@ use std::iter; +use crate::core::HugrNode; use crate::extension::prelude::Noop; use crate::hugr::{HugrMut, Node}; use crate::ops::{OpTag, OpTrait}; @@ -9,22 +10,22 @@ use crate::ops::{OpTag, OpTrait}; use crate::types::EdgeKind; use crate::{HugrView, IncomingPort}; -use super::Rewrite; +use super::{PatchHugrMut, PatchVerification}; use thiserror::Error; /// Specification of a identity-insertion operation. #[derive(Debug, Clone)] -pub struct IdentityInsertion { +pub struct IdentityInsertion { /// The node following the identity to be inserted. - pub post_node: Node, + pub post_node: N, /// The port following the identity to be inserted. pub post_port: IncomingPort, } -impl IdentityInsertion { +impl IdentityInsertion { /// Create a new [`IdentityInsertion`] specification. - pub fn new(post_node: Node, post_port: IncomingPort) -> Self { + pub fn new(post_node: N, post_port: IncomingPort) -> Self { Self { post_node, post_port, @@ -47,12 +48,10 @@ pub enum IdentityInsertionError { InvalidPortKind(Option), } -impl Rewrite for IdentityInsertion { - type Node = Node; +impl PatchVerification for IdentityInsertion { type Error = IdentityInsertionError; - /// The inserted node. - type ApplyResult = Node; - const UNCHANGED_ON_FAILURE: bool = true; + type Node = N; + fn verify(&self, _h: &impl HugrView) -> Result<(), IdentityInsertionError> { /* Assumptions: @@ -66,10 +65,23 @@ impl Rewrite for IdentityInsertion { unimplemented!() } - fn apply( + + #[inline] + fn invalidation_set(&self) -> impl Iterator { + iter::once(self.post_node) + } +} + +impl PatchHugrMut for IdentityInsertion { + /// The inserted node. + type Outcome = N; + + const UNCHANGED_ON_FAILURE: bool = true; + + fn apply_hugr_mut( self, - h: &mut impl HugrMut, - ) -> Result { + h: &mut impl HugrMut, + ) -> Result { let kind = h.get_optype(self.post_node).port_kind(self.post_port); let Some(EdgeKind::Value(ty)) = kind else { return Err(IdentityInsertionError::InvalidPortKind(kind)); @@ -92,11 +104,6 @@ impl Rewrite for IdentityInsertion { h.connect(new_node, 0, self.post_node, self.post_port); Ok(new_node) } - - #[inline] - fn invalidation_set(&self) -> impl Iterator { - iter::once(self.post_node) - } } #[cfg(test)] @@ -122,7 +129,7 @@ mod tests { let rw = IdentityInsertion::new(final_node, final_node_port); - let noop_node = h.apply_rewrite(rw).unwrap(); + let noop_node = h.apply_patch(rw).unwrap(); assert_eq!(h.node_count(), 7); diff --git a/hugr-core/src/hugr/rewrite/outline_cfg.rs b/hugr-core/src/hugr/patch/outline_cfg.rs similarity index 96% rename from hugr-core/src/hugr/rewrite/outline_cfg.rs rename to hugr-core/src/hugr/patch/outline_cfg.rs index a76dbc6ee..0f40615a9 100644 --- a/hugr-core/src/hugr/rewrite/outline_cfg.rs +++ b/hugr-core/src/hugr/patch/outline_cfg.rs @@ -1,4 +1,5 @@ -//! Rewrite for inserting a CFG-node into the hierarchy containing a subsection of an existing CFG +//! Rewrite for inserting a CFG-node into the hierarchy containing a subsection +//! of an existing CFG use std::collections::HashSet; use itertools::Itertools; @@ -6,7 +7,6 @@ use thiserror::Error; use crate::builder::{BlockBuilder, Container, Dataflow, SubContainer}; use crate::extension::ExtensionSet; -use crate::hugr::rewrite::Rewrite; use crate::hugr::{HugrMut, HugrView}; use crate::ops; use crate::ops::controlflow::BasicBlock; @@ -16,6 +16,8 @@ use crate::ops::{DataflowBlock, OpType}; use crate::PortIndex; use crate::{type_row, Node}; +use super::{PatchHugrMut, PatchVerification}; + /// Moves part of a Control-flow Sibling Graph into a new CFG-node /// that is the only child of a new Basic Block in the original CSG. pub struct OutlineCfg { @@ -92,20 +94,30 @@ impl OutlineCfg { } } -impl Rewrite for OutlineCfg { - type Node = Node; +impl PatchVerification for OutlineCfg { type Error = OutlineCfgError; + type Node = Node; + fn verify(&self, h: &impl HugrView) -> Result<(), OutlineCfgError> { + self.compute_entry_exit_outside_extensions(h)?; + Ok(()) + } + + fn invalidation_set(&self) -> impl Iterator { + self.blocks.iter().copied() + } +} + +impl PatchHugrMut for OutlineCfg { /// The newly-created basic block, and the [CFG] node inside it /// /// [CFG]: OpType::CFG - type ApplyResult = (Node, Node); + type Outcome = [Node; 2]; const UNCHANGED_ON_FAILURE: bool = true; - fn verify(&self, h: &impl HugrView) -> Result<(), OutlineCfgError> { - self.compute_entry_exit_outside_extensions(h)?; - Ok(()) - } - fn apply(self, h: &mut impl HugrMut) -> Result<(Node, Node), OutlineCfgError> { + fn apply_hugr_mut( + self, + h: &mut impl HugrMut, + ) -> Result<[Node; 2], OutlineCfgError> { let (entry, exit, outside, extension_delta) = self.compute_entry_exit_outside_extensions(h)?; // 1. Compute signature @@ -212,11 +224,7 @@ impl Rewrite for OutlineCfg { // 4(b). Reconnect exit edge to the new exit node within the inner CFG h.connect(exit, exit_port, inner_exit, 0); - Ok((new_block, cfg_node)) - } - - fn invalidation_set(&self) -> impl Iterator { - self.blocks.iter().copied() + Ok([new_block, cfg_node]) } } @@ -361,22 +369,22 @@ mod test { } = cond_then_loop_cfg; let backup = h.clone(); - let r = h.apply_rewrite(OutlineCfg::new([tail])); + let r = h.apply_patch(OutlineCfg::new([tail])); assert_matches!(r, Err(OutlineCfgError::MultipleExitEdges(_, _))); assert_eq!(h, backup); - let r = h.apply_rewrite(OutlineCfg::new([entry, left, right])); + let r = h.apply_patch(OutlineCfg::new([entry, left, right])); assert_matches!(r, Err(OutlineCfgError::MultipleExitNodes(a,b)) => assert_eq!(HashSet::from([a,b]), HashSet::from_iter([left, right]))); assert_eq!(h, backup); - let r = h.apply_rewrite(OutlineCfg::new([left, right, merge])); + let r = h.apply_patch(OutlineCfg::new([left, right, merge])); assert_matches!(r, Err(OutlineCfgError::MultipleEntryNodes(a,b)) => assert_eq!(HashSet::from([a,b]), HashSet::from([left, right]))); assert_eq!(h, backup); // The entry node implicitly has an extra incoming edge - let r = h.apply_rewrite(OutlineCfg::new([entry, left, right, merge, head])); + let r = h.apply_patch(OutlineCfg::new([entry, left, right, merge, head])); assert_matches!(r, Err(OutlineCfgError::MultipleEntryNodes(a,b)) => assert_eq!(HashSet::from([a,b]), HashSet::from([entry, head]))); assert_eq!(h, backup); @@ -497,7 +505,7 @@ mod test { ) -> (Node, Node, Node) { let mut other_blocks = h.children(cfg).collect::>(); assert!(blocks.iter().all(|b| other_blocks.remove(b))); - let (new_block, new_cfg) = h.apply_rewrite(OutlineCfg::new(blocks.clone())).unwrap(); + let [new_block, new_cfg] = h.apply_patch(OutlineCfg::new(blocks.clone())).unwrap(); for n in other_blocks { assert_eq!(h.get_parent(n), Some(cfg)) diff --git a/hugr-core/src/hugr/rewrite/port_types.rs b/hugr-core/src/hugr/patch/port_types.rs similarity index 100% rename from hugr-core/src/hugr/rewrite/port_types.rs rename to hugr-core/src/hugr/patch/port_types.rs diff --git a/hugr-core/src/hugr/rewrite/replace.rs b/hugr-core/src/hugr/patch/replace.rs similarity index 73% rename from hugr-core/src/hugr/rewrite/replace.rs rename to hugr-core/src/hugr/patch/replace.rs index 0316f9d5b..6f0b0ed65 100644 --- a/hugr-core/src/hugr/rewrite/replace.rs +++ b/hugr-core/src/hugr/patch/replace.rs @@ -5,29 +5,32 @@ use std::collections::{HashMap, HashSet, VecDeque}; use itertools::Itertools; use thiserror::Error; +use crate::core::HugrNode; use crate::hugr::hugrmut::InsertionResult; use crate::hugr::HugrMut; use crate::ops::{OpTag, OpTrait}; use crate::types::EdgeKind; use crate::{Direction, Hugr, HugrView, IncomingPort, Node, OutgoingPort}; -use super::Rewrite; +use super::{PatchHugrMut, PatchVerification}; /// Specifies how to create a new edge. #[derive(Clone, Debug, PartialEq, Eq)] -pub struct NewEdgeSpec { - /// The source of the new edge. For [Replacement::mu_inp] and [Replacement::mu_new], this is in the - /// existing Hugr; for edges in [Replacement::mu_out] this is in the [Replacement::replacement] - pub src: Node, - /// The target of the new edge. For [Replacement::mu_inp], this is in the [Replacement::replacement]; - /// for edges in [Replacement::mu_out] and [Replacement::mu_new], this is in the existing Hugr. - pub tgt: Node, +pub struct NewEdgeSpec { + /// The source of the new edge. For [Replacement::mu_inp] and + /// [Replacement::mu_new], this is in the existing Hugr; for edges in + /// [Replacement::mu_out] this is in the [Replacement::replacement] + pub src: SrcNode, + /// The target of the new edge. For [Replacement::mu_inp], this is in the + /// [Replacement::replacement]; for edges in [Replacement::mu_out] and + /// [Replacement::mu_new], this is in the existing Hugr. + pub tgt: TgtNode, /// The kind of edge to create, and any port specifiers required pub kind: NewEdgeKind, } /// Describes an edge that should be created between two nodes already given -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum NewEdgeKind { /// An [EdgeKind::StateOrder] edge (between DFG nodes only) Order, @@ -54,40 +57,47 @@ pub enum NewEdgeKind { /// Specification of a `Replace` operation #[derive(Debug, Clone, PartialEq)] -pub struct Replacement { +pub struct Replacement { /// The nodes to remove from the existing Hugr (known as Gamma). - /// These must all have a common parent (i.e. be siblings). Called "S" in the spec. - /// Must be non-empty - otherwise there is no parent under which to place [Self::replacement], - /// and there would be no possible [Self::mu_inp], [Self::mu_out] or [Self::adoptions]. - pub removal: Vec, - /// A hugr (not necessarily valid, as it may be missing edges and/or nodes), whose root - /// is the same type as the root of [Self::replacement]. "G" in the spec. + /// These must all have a common parent (i.e. be siblings). Called "S" in + /// the spec. Must be non-empty - otherwise there is no parent under + /// which to place [Self::replacement], and there would be no possible + /// [Self::mu_inp], [Self::mu_out] or [Self::adoptions]. + pub removal: Vec, + /// A hugr (not necessarily valid, as it may be missing edges and/or nodes), + /// whose root is the same type as the root of [Self::replacement]. "G" + /// in the spec. pub replacement: Hugr, - /// Describes how parts of the Hugr that would otherwise be removed should instead be preserved but - /// with new parents amongst the newly-inserted nodes. This is a Map from container nodes in - /// [Self::replacement] that have no children, to container nodes that are descended from [Self::removal]. - /// The keys are the new parents for the children of the values. Note no value may be ancestor or - /// descendant of another. This is "B" in the spec; "R" is the set of descendants of [Self::removal] - /// that are not descendants of values here. - pub adoptions: HashMap, - /// Edges from nodes in the existing Hugr that are not removed ([NewEdgeSpec::src] in Gamma\R) - /// to inserted nodes ([NewEdgeSpec::tgt] in [Self::replacement]). - pub mu_inp: Vec, - /// Edges from inserted nodes ([NewEdgeSpec::src] in [Self::replacement]) to existing nodes not removed - /// ([NewEdgeSpec::tgt] in Gamma \ R). - pub mu_out: Vec, - /// Edges to add between existing nodes (both [NewEdgeSpec::src] and [NewEdgeSpec::tgt] in Gamma \ R). - /// For example, in cases where the source had an edge to a removed node, and the target had an - /// edge from a removed node, this would allow source to be directly connected to target. - pub mu_new: Vec, + /// Describes how parts of the Hugr that would otherwise be removed should + /// instead be preserved but with new parents amongst the newly-inserted + /// nodes. This is a Map from container nodes in [Self::replacement] + /// that have no children, to container nodes that are descended from + /// [Self::removal]. The keys are the new parents for the children of + /// the values. Note no value may be ancestor or descendant of another. + /// This is "B" in the spec; "R" is the set of descendants of + /// [Self::removal] that are not descendants of values here. + pub adoptions: HashMap, + /// Edges from nodes in the existing Hugr that are not removed + /// ([NewEdgeSpec::src] in Gamma\R) to inserted nodes + /// ([NewEdgeSpec::tgt] in [Self::replacement]). + pub mu_inp: Vec>, + /// Edges from inserted nodes ([NewEdgeSpec::src] in [Self::replacement]) to + /// existing nodes not removed ([NewEdgeSpec::tgt] in Gamma \ R). + pub mu_out: Vec>, + /// Edges to add between existing nodes (both [NewEdgeSpec::src] and + /// [NewEdgeSpec::tgt] in Gamma \ R). For example, in cases where the + /// source had an edge to a removed node, and the target had an + /// edge from a removed node, this would allow source to be directly + /// connected to target. + pub mu_new: Vec>, } -impl NewEdgeSpec { - fn check_src( +impl NewEdgeSpec { + fn check_src( &self, - h: &impl HugrView, - err_spec: &NewEdgeSpec, - ) -> Result<(), ReplaceError> { + h: &impl HugrView, + err_spec: impl Fn(Self) -> WhichEdgeSpec, + ) -> Result<(), ReplaceError> { let optype = h.get_optype(self.src); let ok = match self.kind { NewEdgeKind::Order => optype.other_output() == Some(EdgeKind::StateOrder), @@ -103,13 +113,14 @@ impl NewEdgeSpec { } }; ok.then_some(()) - .ok_or_else(|| ReplaceError::BadEdgeKind(Direction::Outgoing, err_spec.clone())) + .ok_or_else(|| ReplaceError::BadEdgeKind(Direction::Outgoing, err_spec(self.clone()))) } - fn check_tgt( + + fn check_tgt( &self, - h: &impl HugrView, - err_spec: &NewEdgeSpec, - ) -> Result<(), ReplaceError> { + h: &impl HugrView, + err_spec: impl Fn(Self) -> WhichEdgeSpec, + ) -> Result<(), ReplaceError> { let optype = h.get_optype(self.tgt); let ok = match self.kind { NewEdgeKind::Order => optype.other_input() == Some(EdgeKind::StateOrder), @@ -126,18 +137,20 @@ impl NewEdgeSpec { ), }; ok.then_some(()) - .ok_or_else(|| ReplaceError::BadEdgeKind(Direction::Incoming, err_spec.clone())) + .ok_or_else(|| ReplaceError::BadEdgeKind(Direction::Incoming, err_spec(self.clone()))) } +} +impl NewEdgeSpec { fn check_existing_edge( &self, - h: &impl HugrView, - legal_src_ancestors: &HashSet, - err_edge: impl Fn() -> NewEdgeSpec, - ) -> Result<(), ReplaceError> { + h: &impl HugrView, + legal_src_ancestors: &HashSet, + err_edge: impl Fn(Self) -> WhichEdgeSpec, + ) -> Result<(), ReplaceError> { if let NewEdgeKind::Static { tgt_pos, .. } | NewEdgeKind::Value { tgt_pos, .. } = self.kind { - let descends_from_legal = |mut descendant: Node| -> bool { + let descends_from_legal = |mut descendant: HostNode| -> bool { while !legal_src_ancestors.contains(&descendant) { let Some(p) = h.get_parent(descendant) else { return false; @@ -150,15 +163,18 @@ impl NewEdgeSpec { .single_linked_output(self.tgt, tgt_pos) .is_some_and(|(src_n, _)| descends_from_legal(src_n)); if !found_incoming { - return Err(ReplaceError::NoRemovedEdge(err_edge())); + return Err(ReplaceError::NoRemovedEdge(err_edge(self.clone()))); }; }; Ok(()) } } -impl Replacement { - fn check_parent(&self, h: &impl HugrView) -> Result { +impl Replacement { + fn check_parent( + &self, + h: &impl HugrView, + ) -> Result> { let parent = self .removal .iter() @@ -168,8 +184,9 @@ impl Replacement { .map_err(|ex_one| ReplaceError::MultipleParents(ex_one.flatten().collect()))? .ok_or(ReplaceError::CantReplaceRoot)?; // If no parent - // Check replacement parent is of same tag. Note we do not require exact equality - // of OpType/Signature, e.g. to ease changing of Input/Output node signatures too. + // Check replacement parent is of same tag. Note we do not require exact + // equality of OpType/Signature, e.g. to ease changing of Input/Output + // node signatures too. let removed = h.get_optype(parent).tag(); let replacement = self.replacement.root_type().tag(); if removed != replacement { @@ -183,8 +200,8 @@ impl Replacement { fn get_removed_nodes( &self, - h: &impl HugrView, - ) -> Result, ReplaceError> { + h: &impl HugrView, + ) -> Result, ReplaceError> { // Check the keys of the transfer map too, the values we'll use imminently self.adoptions.keys().try_for_each(|&n| { (self.replacement.contains_node(n) @@ -193,7 +210,7 @@ impl Replacement { .then_some(()) .ok_or(ReplaceError::InvalidAdoptingParent(n)) })?; - let mut transferred: HashSet = self.adoptions.values().copied().collect(); + let mut transferred: HashSet = self.adoptions.values().copied().collect(); if transferred.len() != self.adoptions.values().len() { return Err(ReplaceError::AdopteesNotSeparateDescendants( self.adoptions @@ -221,98 +238,149 @@ impl Replacement { Ok(removed) } } -impl Rewrite for Replacement { - type Node = Node; - type Error = ReplaceError; - - /// Map from Node in replacement to corresponding Node in the result Hugr - type ApplyResult = HashMap; - const UNCHANGED_ON_FAILURE: bool = false; +impl PatchVerification for Replacement { + type Error = ReplaceError; + type Node = HostNode; - fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { self.check_parent(h)?; let removed = self.get_removed_nodes(h)?; // Edge sources... - for e in self.mu_inp.iter().chain(self.mu_new.iter()) { + for e in self.mu_inp.iter() { if !h.contains_node(e.src) || removed.contains(&e.src) { return Err(ReplaceError::BadEdgeSpec( Direction::Outgoing, - WhichHugr::Retained, - e.clone(), + WhichEdgeSpec::HostToRepl(e.clone()), )); } - e.check_src(h, e)?; + e.check_src(h, WhichEdgeSpec::HostToRepl)?; + } + for e in self.mu_new.iter() { + if !h.contains_node(e.src) || removed.contains(&e.src) { + return Err(ReplaceError::BadEdgeSpec( + Direction::Outgoing, + WhichEdgeSpec::HostToHost(e.clone()), + )); + } + e.check_src(h, WhichEdgeSpec::HostToHost)?; } self.mu_out .iter() .try_for_each(|e| match self.replacement.valid_non_root(e.src) { - true => e.check_src(&self.replacement, e), + true => e.check_src(&self.replacement, WhichEdgeSpec::ReplToHost), false => Err(ReplaceError::BadEdgeSpec( Direction::Outgoing, - WhichHugr::Replacement, - e.clone(), + WhichEdgeSpec::ReplToHost(e.clone()), )), })?; // Edge targets... self.mu_inp .iter() .try_for_each(|e| match self.replacement.valid_non_root(e.tgt) { - true => e.check_tgt(&self.replacement, e), + true => e.check_tgt(&self.replacement, WhichEdgeSpec::HostToRepl), false => Err(ReplaceError::BadEdgeSpec( Direction::Incoming, - WhichHugr::Replacement, - e.clone(), + WhichEdgeSpec::HostToRepl(e.clone()), )), })?; - for e in self.mu_out.iter().chain(self.mu_new.iter()) { + for e in self.mu_out.iter() { if !h.contains_node(e.tgt) || removed.contains(&e.tgt) { return Err(ReplaceError::BadEdgeSpec( Direction::Incoming, - WhichHugr::Retained, - e.clone(), + WhichEdgeSpec::ReplToHost(e.clone()), )); } - e.check_tgt(h, e)?; + e.check_tgt(h, WhichEdgeSpec::ReplToHost)?; // The descendant check is to allow the case where the old edge is nonlocal // from a part of the Hugr being moved (which may require changing source, // depending on where the transplanted portion ends up). While this subsumes - // the first "removed.contains" check, we'll keep that as a common-case fast-path. - e.check_existing_edge(h, &removed, || e.clone())?; + // the first "removed.contains" check, we'll keep that as a common-case + // fast-path. + e.check_existing_edge(h, &removed, WhichEdgeSpec::ReplToHost)?; + } + for e in self.mu_new.iter() { + if !h.contains_node(e.tgt) || removed.contains(&e.tgt) { + return Err(ReplaceError::BadEdgeSpec( + Direction::Incoming, + WhichEdgeSpec::HostToHost(e.clone()), + )); + } + e.check_tgt(h, WhichEdgeSpec::HostToHost)?; + // The descendant check is to allow the case where the old edge is nonlocal + // from a part of the Hugr being moved (which may require changing source, + // depending on where the transplanted portion ends up). While this subsumes + // the first "removed.contains" check, we'll keep that as a common-case + // fast-path. + e.check_existing_edge(h, &removed, WhichEdgeSpec::HostToHost)?; } Ok(()) } - fn apply(self, h: &mut impl HugrMut) -> Result { + fn invalidation_set(&self) -> impl Iterator { + self.removal.iter().copied() + } +} + +impl PatchHugrMut for Replacement { + /// Map from Node in replacement to corresponding Node in the result Hugr + type Outcome = HashMap; + + const UNCHANGED_ON_FAILURE: bool = false; + + fn apply_hugr_mut( + self, + h: &mut impl HugrMut, + ) -> Result { let parent = self.check_parent(h)?; // Calculate removed nodes here. (Does not include transfers, so enumerates only - // nodes we are going to remove, individually, anyway; so no *asymptotic* speed penalty) + // nodes we are going to remove, individually, anyway; so no *asymptotic* speed + // penalty) let to_remove = self.get_removed_nodes(h)?; - // 1. Add all the new nodes. Note this includes replacement.root(), which we don't want. + // 1. Add all the new nodes. Note this includes replacement.root(), which we + // don't want. // TODO what would an error here mean? e.g. malformed self.replacement?? let InsertionResult { new_root, node_map } = h.insert_hugr(parent, self.replacement); // 2. Add new edges from existing to copied nodes according to mu_in - let translate_idx = |n| node_map.get(&n).copied().ok_or(WhichHugr::Replacement); - let kept = |n| { - let keep = !to_remove.contains(&n); - keep.then_some(n).ok_or(WhichHugr::Retained) - }; - transfer_edges(h, self.mu_inp.iter(), kept, translate_idx, None)?; + let translate_idx = |n| node_map.get(&n).copied(); + let kept = |n| (!to_remove.contains(&n)).then_some(n); + transfer_edges( + h, + self.mu_inp.iter(), + kept, + translate_idx, + WhichEdgeSpec::HostToRepl, + None, + )?; // 3. Add new edges from copied to existing nodes according to mu_out, // replacing existing value/static edges incoming to targets - transfer_edges(h, self.mu_out.iter(), translate_idx, kept, Some(&to_remove))?; + transfer_edges( + h, + self.mu_out.iter(), + translate_idx, + kept, + WhichEdgeSpec::ReplToHost, + Some(&to_remove), + )?; // 4. Add new edges between existing nodes according to mu_new, // replacing existing value/static edges incoming to targets. - transfer_edges(h, self.mu_new.iter(), kept, kept, Some(&to_remove))?; + transfer_edges( + h, + self.mu_new.iter(), + kept, + kept, + WhichEdgeSpec::HostToHost, + Some(&to_remove), + )?; // 5. Put newly-added copies into correct places in hierarchy // (these will be correct places after removing nodes) let mut remove_top_sibs = self.removal.iter(); - for new_node in h.children(new_root).collect::>().into_iter() { + for new_node in h.children(new_root).collect::>().into_iter() { if let Some(top_sib) = remove_top_sibs.next() { h.move_before_sibling(new_node, *top_sib); } else { @@ -337,51 +405,53 @@ impl Rewrite for Replacement { }); Ok(node_map) } - - fn invalidation_set(&self) -> impl Iterator { - self.removal.iter().copied() - } } -fn transfer_edges<'a>( - h: &mut impl HugrMut, - edges: impl Iterator, - trans_src: impl Fn(Node) -> Result, - trans_tgt: impl Fn(Node) -> Result, - legal_src_ancestors: Option<&HashSet>, -) -> Result<(), ReplaceError> { +fn transfer_edges<'a, SrcNode, TgtNode, HostNode>( + h: &mut impl HugrMut, + edges: impl Iterator>, + trans_src: impl Fn(SrcNode) -> Option, + trans_tgt: impl Fn(TgtNode) -> Option, + err_spec: impl Fn(NewEdgeSpec) -> WhichEdgeSpec, + legal_src_ancestors: Option<&HashSet>, +) -> Result<(), ReplaceError> +where + SrcNode: 'a + HugrNode, + TgtNode: 'a + HugrNode, + HostNode: 'a + HugrNode, +{ for oe in edges { + let err_spec = err_spec(oe.clone()); let e = NewEdgeSpec { // Translation can only fail for Nodes that are supposed to be in the replacement src: trans_src(oe.src) - .map_err(|h| ReplaceError::BadEdgeSpec(Direction::Outgoing, h, oe.clone()))?, + .ok_or_else(|| ReplaceError::BadEdgeSpec(Direction::Outgoing, err_spec.clone()))?, tgt: trans_tgt(oe.tgt) - .map_err(|h| ReplaceError::BadEdgeSpec(Direction::Incoming, h, oe.clone()))?, - ..oe.clone() + .ok_or_else(|| ReplaceError::BadEdgeSpec(Direction::Incoming, err_spec.clone()))?, + kind: oe.kind, }; if !h.valid_node(e.src) { return Err(ReplaceError::BadEdgeSpec( Direction::Outgoing, - WhichHugr::Retained, - oe.clone(), + err_spec.clone(), )); } if !h.valid_node(e.tgt) { return Err(ReplaceError::BadEdgeSpec( Direction::Incoming, - WhichHugr::Retained, - oe.clone(), + err_spec.clone(), )); }; - e.check_src(h, oe)?; - e.check_tgt(h, oe)?; + let err_spec = |_| err_spec.clone(); + e.check_src(h, err_spec)?; + e.check_tgt(h, err_spec)?; match e.kind { NewEdgeKind::Order => { h.add_other_edge(e.src, e.tgt); } NewEdgeKind::Value { src_pos, tgt_pos } | NewEdgeKind::Static { src_pos, tgt_pos } => { if let Some(legal_src_ancestors) = legal_src_ancestors { - e.check_existing_edge(h, legal_src_ancestors, || oe.clone())?; + e.check_existing_edge(h, legal_src_ancestors, err_spec)?; h.disconnect(e.tgt, tgt_pos); } h.connect(e.src, src_pos, e.tgt, tgt_pos); @@ -395,14 +465,14 @@ fn transfer_edges<'a>( /// Error in a [`Replacement`] #[derive(Clone, Debug, PartialEq, Eq, Error)] #[non_exhaustive] -pub enum ReplaceError { +pub enum ReplaceError { /// The node(s) to replace had no parent i.e. were root(s). // (Perhaps if there is only one node to replace we should be able to?) #[error("Cannot replace the root node of the Hugr")] CantReplaceRoot, /// The nodes to replace did not have a unique common parent #[error("Removed nodes had different parents {0:?}")] - MultipleParents(Vec), + MultipleParents(Vec), /// Replacement root node had different tag from parent of removed nodes #[error("Expected replacement root with tag {removed} but found {replacement}")] WrongRootNodeTag { @@ -411,40 +481,47 @@ pub enum ReplaceError { /// The tag of the root in the replacement Hugr replacement: OpTag, }, - /// Keys in [Replacement::adoptions] were not valid container nodes in [Replacement::replacement] + /// Keys in [Replacement::adoptions] were not valid container nodes in + /// [Replacement::replacement] #[error("Node {0} was not an empty container node in the replacement")] InvalidAdoptingParent(Node), - /// Some values in [Replacement::adoptions] were either descendants of other values, or not - /// descendants of the [Replacement::removal]. The nodes are indicated on a best-effort basis. + /// Some values in [Replacement::adoptions] were either descendants of other + /// values, or not descendants of the [Replacement::removal]. The nodes + /// are indicated on a best-effort basis. #[error("Nodes not free to be moved into new locations: {0:?}")] - AdopteesNotSeparateDescendants(Vec), + AdopteesNotSeparateDescendants(Vec), /// A node at one end of a [NewEdgeSpec] was not found - #[error("{0:?} end of edge {2:?} not found in {1}")] - BadEdgeSpec(Direction, WhichHugr, NewEdgeSpec), - /// The target of the edge was found, but there was no existing edge to replace + #[error("{0:?} end of edge {1:?} not found in {which_hugr}", which_hugr = .1.which_hugr(*.0))] + BadEdgeSpec(Direction, WhichEdgeSpec), + /// The target of the edge was found, but there was no existing edge to + /// replace #[error("Target of edge {0:?} did not have a corresponding incoming edge being removed")] - NoRemovedEdge(NewEdgeSpec), + NoRemovedEdge(WhichEdgeSpec), /// The [NewEdgeKind] was not applicable for the source/target node(s) #[error("The edge kind was not applicable to the {0:?} node: {1:?}")] - BadEdgeKind(Direction, NewEdgeSpec), + BadEdgeKind(Direction, WhichEdgeSpec), } -/// A Hugr or portion thereof that is part of the [Replacement] +/// The three kinds of [NewEdgeSpec] that may appear in a [ReplaceError] #[derive(Clone, Debug, PartialEq, Eq)] -pub enum WhichHugr { - /// The newly-inserted nodes, i.e. the [Replacement::replacement] - Replacement, - /// Nodes in the existing Hugr that are not [Replacement::removal] - /// (or are on the RHS of an entry in [Replacement::adoptions]) - Retained, +pub enum WhichEdgeSpec { + /// An edge from the host Hugr into the replacement, i.e. + /// [Replacement::mu_inp] + HostToRepl(NewEdgeSpec), + /// An edge from the replacement to the host, i.e. [Replacement::mu_out] + ReplToHost(NewEdgeSpec), + /// An edge between two nodes in the host (bypassing the replacement), + /// i.e. [Replacement::mu_new] + HostToHost(NewEdgeSpec), } -impl std::fmt::Display for WhichHugr { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(match self { - Self::Replacement => "replacement Hugr", - Self::Retained => "retained portion of Hugr", - }) +impl WhichEdgeSpec { + fn which_hugr(&self, d: Direction) -> &str { + match (self, d) { + (Self::HostToRepl(_), Direction::Incoming) + | (Self::ReplToHost(_), Direction::Outgoing) => "replacement Hugr", + _ => "retained portion of Hugr", + } } } @@ -462,8 +539,8 @@ mod test { use crate::extension::prelude::{bool_t, usize_t}; use crate::extension::{ExtensionRegistry, PRELUDE}; use crate::hugr::internal::HugrMutInternals; - use crate::hugr::rewrite::replace::WhichHugr; - use crate::hugr::{HugrMut, Rewrite}; + use crate::hugr::patch::PatchVerification; + use crate::hugr::{HugrMut, Patch}; use crate::ops::custom::ExtensionOp; use crate::ops::dataflow::DataflowOpTrait; use crate::ops::handle::{BasicBlockID, ConstID, NodeHandle}; @@ -473,7 +550,7 @@ mod test { use crate::utils::{depth, test_quantum_extension}; use crate::{type_row, Direction, Extension, Hugr, HugrView, OutgoingPort}; - use super::{NewEdgeKind, NewEdgeSpec, ReplaceError, Replacement}; + use super::{NewEdgeKind, NewEdgeSpec, ReplaceError, Replacement, WhichEdgeSpec}; #[test] #[ignore] // FIXME: This needs a rewrite now that `pop` returns an optional value -.-' @@ -520,7 +597,8 @@ mod test { } // Replacement: one BB with two DFGs inside. - // Use Hugr rather than Builder because DFGs must be empty (not even Input/Output). + // Use Hugr rather than Builder because it must be empty (not even + // Input/Output). let mut replacement = Hugr::new(ops::CFG { signature: Signature::new_endo(just_list.clone()), }); @@ -569,7 +647,7 @@ mod test { replacement.connect(r_df2, 1, out, 1); } - h.apply_rewrite(Replacement { + h.apply_patch(Replacement { removal: vec![entry.node(), bb2.node()], replacement, adoptions: HashMap::from([(r_df1.node(), entry.node()), (r_df2.node(), bb2.node())]), @@ -788,7 +866,10 @@ mod test { mu_inp: vec![edge_from_removed.clone()], ..rep.clone() }), - ReplaceError::BadEdgeSpec(Direction::Outgoing, WhichHugr::Retained, edge_from_removed) + ReplaceError::BadEdgeSpec( + Direction::Outgoing, + WhichEdgeSpec::HostToRepl(edge_from_removed) + ) ); let bad_out_edge = NewEdgeSpec { src: h.nodes().max().unwrap(), // not valid in replacement @@ -800,7 +881,7 @@ mod test { mu_out: vec![bad_out_edge.clone()], ..rep.clone() }), - ReplaceError::BadEdgeSpec(Direction::Outgoing, WhichHugr::Replacement, bad_out_edge) + ReplaceError::BadEdgeSpec(Direction::Outgoing, WhichEdgeSpec::ReplToHost(bad_out_edge),) ); let bad_order_edge = NewEdgeSpec { src: cond.node(), @@ -812,7 +893,7 @@ mod test { mu_new: vec![bad_order_edge.clone()], ..rep.clone() }), - ReplaceError::BadEdgeKind(_, e) => assert_eq!(e, bad_order_edge) + ReplaceError::BadEdgeKind(_, e) => assert_eq!(e, WhichEdgeSpec::HostToHost(bad_order_edge)) ); let op = OutgoingPort::from(0); let (tgt, ip) = h.linked_inputs(cond.node(), op).next().unwrap(); @@ -829,7 +910,7 @@ mod test { mu_out: vec![new_out_edge.clone()], ..rep.clone() }), - ReplaceError::BadEdgeKind(Direction::Outgoing, new_out_edge) + ReplaceError::BadEdgeKind(Direction::Outgoing, WhichEdgeSpec::ReplToHost(new_out_edge)) ); } } diff --git a/hugr-core/src/hugr/rewrite/simple_replace.rs b/hugr-core/src/hugr/patch/simple_replace.rs similarity index 92% rename from hugr-core/src/hugr/rewrite/simple_replace.rs rename to hugr-core/src/hugr/patch/simple_replace.rs index 5d3716dc0..e9283644d 100644 --- a/hugr-core/src/hugr/rewrite/simple_replace.rs +++ b/hugr-core/src/hugr/patch/simple_replace.rs @@ -5,7 +5,7 @@ use std::collections::HashMap; use crate::core::HugrNode; use crate::hugr::hugrmut::InsertionResult; use crate::hugr::views::SiblingSubgraph; -use crate::hugr::{HugrMut, HugrView, Rewrite}; +use crate::hugr::{HugrMut, HugrView}; use crate::ops::{OpTag, OpTrait, OpType}; use crate::{Hugr, IncomingPort, Node, OutgoingPort}; @@ -14,7 +14,7 @@ use itertools::Itertools; use thiserror::Error; use super::inline_dfg::InlineDFGError; -use super::{BoundaryPort, HostPort, ReplacementPort}; +use super::{BoundaryPort, HostPort, PatchHugrMut, PatchVerification, ReplacementPort}; /// Specification of a simple replacement operation. /// @@ -28,7 +28,8 @@ pub struct SimpleReplacement { /// A hugr with DFG root (consisting of replacement nodes). replacement: Hugr, /// A map from (target ports of edges from the Input node of `replacement`) - /// to (target ports of edges from nodes not in `subgraph` to nodes in `subgraph`). + /// to (target ports of edges from nodes not in `subgraph` to nodes in + /// `subgraph`). nu_inp: HashMap<(Node, IncomingPort), (HostNode, IncomingPort)>, /// A map from (target ports of edges from nodes in `subgraph` to nodes not /// in `subgraph`) to (input ports of the Output node of `replacement`). @@ -125,7 +126,8 @@ impl SimpleReplacement { }) .map( |(&(rep_inp_node, rep_inp_port), (rem_inp_node, rem_inp_port))| { - // add edge from predecessor of (s_inp_node, s_inp_port) to (new_inp_node, n_inp_port) + // add edge from predecessor of (s_inp_node, s_inp_port) to (new_inp_node, + // n_inp_port) let (rem_inp_pred_node, rem_inp_pred_port) = host .single_linked_output(*rem_inp_node, *rem_inp_port) .unwrap(); @@ -158,8 +160,9 @@ impl SimpleReplacement { > + 'a { let [_, replacement_output_node] = self.get_replacement_io().expect("replacement is a DFG"); - // For each q = self.nu_out[p] such that the predecessor of q is not an Input port, - // there will be an edge from (the new copy of) the predecessor of q to p. + // For each q = self.nu_out[p] such that the predecessor of q is not an Input + // port, there will be an edge from (the new copy of) the predecessor of + // q to p. self.nu_out .iter() .filter_map(move |(&(rem_out_node, rem_out_port), rep_out_port)| { @@ -196,8 +199,8 @@ impl SimpleReplacement { > + 'a { let [_, replacement_output_node] = self.get_replacement_io().expect("replacement is a DFG"); - // For each q = self.nu_out[p1], p0 = self.nu_inp[q], add an edge from the predecessor of p0 - // to p1. + // For each q = self.nu_out[p1], p0 = self.nu_inp[q], add an edge from the + // predecessor of p0 to p1. self.nu_out .iter() .filter_map(move |(&(rem_out_node, rem_out_port), &rep_out_port)| { @@ -245,8 +248,9 @@ impl SimpleReplacement { /// Get all edges that the replacement would add between `host` and /// `self.replacement`. /// - /// This is equivalent to chaining the results of [`Self::incoming_boundary`], - /// [`Self::outgoing_boundary`], and [`Self::host_to_host_boundary`]. + /// This is equivalent to chaining the results of + /// [`Self::incoming_boundary`], [`Self::outgoing_boundary`], and + /// [`Self::host_to_host_boundary`]. /// /// This panics if self.replacement is not a DFG. pub fn all_boundary_edges<'a>( @@ -274,17 +278,35 @@ impl SimpleReplacement { } } -impl Rewrite for SimpleReplacement { - type Node = Node; +impl PatchVerification for SimpleReplacement { type Error = SimpleReplacementError; - type ApplyResult = Vec<(Node, OpType)>; - const UNCHANGED_ON_FAILURE: bool = true; + type Node = HostNode; - fn verify(&self, h: &impl HugrView) -> Result<(), SimpleReplacementError> { + fn verify(&self, h: &impl HugrView) -> Result<(), SimpleReplacementError> { self.is_valid_rewrite(h) } - fn apply(self, h: &mut impl HugrMut) -> Result { + #[inline] + fn invalidation_set(&self) -> impl Iterator { + let subcirc = self.subgraph.nodes().iter().copied(); + let out_neighs = self.nu_out.keys().map(|key| key.0); + subcirc.chain(out_neighs) + } +} + +/// Result of applying a [`SimpleReplacement`]. +pub struct Outcome { + /// Map from Node in replacement to corresponding Node in the result Hugr + pub node_map: HashMap, + /// Nodes removed from the result Hugr and their weights + pub removed_nodes: HashMap, +} + +impl PatchHugrMut for SimpleReplacement { + type Outcome = Outcome; + const UNCHANGED_ON_FAILURE: bool = true; + + fn apply_hugr_mut(self, h: &mut impl HugrMut) -> Result { self.is_valid_rewrite(h)?; let parent = self.subgraph.get_parent(h); @@ -305,13 +327,10 @@ impl Rewrite for SimpleReplacement { } = self; // 2. Insert the replacement as a whole. - let InsertionResult { - new_root, - node_map: index_map, - } = h.insert_hugr(parent, replacement); + let InsertionResult { new_root, node_map } = h.insert_hugr(parent, replacement); // remove the Input and Output nodes from the replacement graph - let replace_children = h.children(new_root).collect::>(); + let replace_children = h.children(new_root).collect::>(); for &io in &replace_children[..2] { h.remove_node(io); } @@ -324,24 +343,22 @@ impl Rewrite for SimpleReplacement { // 3. Insert all boundary edges. for (src, tgt) in boundary_edges { - let (src_node, src_port) = src.map_replacement(&index_map); - let (tgt_node, tgt_port) = tgt.map_replacement(&index_map); + let (src_node, src_port) = src.map_replacement(&node_map); + let (tgt_node, tgt_port) = tgt.map_replacement(&node_map); h.connect(src_node, src_port, tgt_node, tgt_port); } // 4. Remove all nodes in subgraph and edges between them. - Ok(subgraph + let removed_nodes = subgraph .nodes() .iter() .map(|&node| (node, h.remove_node(node))) - .collect()) - } + .collect(); - #[inline] - fn invalidation_set(&self) -> impl Iterator { - let subcirc = self.subgraph.nodes().iter().copied(); - let out_neighs = self.nu_out.keys().map(|key| key.0); - subcirc.chain(out_neighs) + Ok(Outcome { + node_map, + removed_nodes, + }) } } @@ -364,9 +381,10 @@ pub enum SimpleReplacementError { } #[cfg(test)] -pub(in crate::hugr::rewrite) mod test { +pub(in crate::hugr::patch) mod test { use itertools::Itertools; use rstest::{fixture, rstest}; + use std::collections::{HashMap, HashSet}; use crate::builder::test::n_identity; @@ -376,8 +394,9 @@ pub(in crate::hugr::rewrite) mod test { }; use crate::extension::prelude::{bool_t, qb_t}; use crate::extension::ExtensionSet; + use crate::hugr::patch::PatchVerification; use crate::hugr::views::{HugrView, SiblingSubgraph}; - use crate::hugr::{Hugr, HugrMut, Rewrite}; + use crate::hugr::{Hugr, HugrMut, Patch}; use crate::ops::dataflow::DataflowOpTrait; use crate::ops::handle::NodeHandle; use crate::ops::OpTag; @@ -433,7 +452,7 @@ pub(in crate::hugr::rewrite) mod test { } #[fixture] - pub(in crate::hugr::rewrite) fn simple_hugr() -> Hugr { + pub(in crate::hugr::patch) fn simple_hugr() -> Hugr { make_hugr().unwrap() } /// Creates a hugr with a DFG root like the following: @@ -453,7 +472,7 @@ pub(in crate::hugr::rewrite) mod test { } #[fixture] - pub(in crate::hugr::rewrite) fn dfg_hugr() -> Hugr { + pub(in crate::hugr::patch) fn dfg_hugr() -> Hugr { make_dfg_hugr().unwrap() } @@ -473,7 +492,7 @@ pub(in crate::hugr::rewrite) mod test { } #[fixture] - pub(in crate::hugr::rewrite) fn dfg_hugr2() -> Hugr { + pub(in crate::hugr::patch) fn dfg_hugr2() -> Hugr { make_dfg_hugr2().unwrap() } @@ -485,11 +504,12 @@ pub(in crate::hugr::rewrite) mod test { /// └─────────┘ │ ┌─────────┐ /// └────┤ (2) NOT ├── /// └─────────┘ - /// This can be replaced with an empty hugr coping the input to both outputs. + /// This can be replaced with an empty hugr coping the input to both + /// outputs. /// /// Returns the hugr and the nodes of the NOT gates, in order. #[fixture] - pub(in crate::hugr::rewrite) fn dfg_hugr_copy_bools() -> (Hugr, Vec) { + pub(in crate::hugr::patch) fn dfg_hugr_copy_bools() -> (Hugr, Vec) { let mut dfg_builder = DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap(); let [b] = dfg_builder.input_wires_arr(); @@ -516,11 +536,12 @@ pub(in crate::hugr::rewrite) mod test { /// └─────────┘ │ /// └───────────────── /// - /// This can be replaced with a single NOT op, coping the input to the first output. + /// This can be replaced with a single NOT op, coping the input to the first + /// output. /// /// Returns the hugr and the nodes of the NOT ops, in order. #[fixture] - pub(in crate::hugr::rewrite) fn dfg_hugr_half_not_bools() -> (Hugr, Vec) { + pub(in crate::hugr::patch) fn dfg_hugr_half_not_bools() -> (Hugr, Vec) { let mut dfg_builder = DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap(); let [b] = dfg_builder.input_wires_arr(); @@ -682,7 +703,7 @@ pub(in crate::hugr::rewrite) mod test { nu_inp, nu_out, }; - h.apply_rewrite(r).unwrap(); + h.apply_patch(r).unwrap(); // Expect [DFG] to be replaced with: // ┌───┐┌───┐ // ┤ H ├┤ H ├ @@ -736,7 +757,7 @@ pub(in crate::hugr::rewrite) mod test { }) .map(|p| ((output, p), p)) .collect(); - h.apply_rewrite(SimpleReplacement::new( + h.apply_patch(SimpleReplacement::new( SiblingSubgraph::try_from_nodes(removal, &h).unwrap(), replacement, inputs, @@ -788,7 +809,7 @@ pub(in crate::hugr::rewrite) mod test { .map(|p| ((repl_output, p), p)) .collect(); - h.apply_rewrite(SimpleReplacement::new( + h.apply_patch(SimpleReplacement::new( SiblingSubgraph::try_from_nodes(removal, &h).unwrap(), repl, inputs, @@ -800,8 +821,8 @@ pub(in crate::hugr::rewrite) mod test { assert_eq!(h.node_count(), orig.node_count()); } - /// Remove all the NOT gates in [`dfg_hugr_copy_bools`] by connecting the input - /// directly to the outputs. + /// Remove all the NOT gates in [`dfg_hugr_copy_bools`] by connecting the + /// input directly to the outputs. /// /// https://github.com/CQCL/hugr/issues/1190 #[rstest] @@ -822,8 +843,9 @@ pub(in crate::hugr::rewrite) mod test { let subgraph = SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0, output_not_1], &hugr) .unwrap(); - // A map from (target ports of edges from the Input node of `replacement`) to (target ports of - // edges from nodes not in `removal` to nodes in `removal`). + // A map from (target ports of edges from the Input node of `replacement`) to + // (target ports of edges from nodes not in `removal` to nodes in + // `removal`). let nu_inp = [ ( (repl_output, IncomingPort::from(0)), @@ -836,8 +858,8 @@ pub(in crate::hugr::rewrite) mod test { ] .into_iter() .collect(); - // A map from (target ports of edges from nodes in `removal` to nodes not in `removal`) to - // (input ports of the Output node of `replacement`). + // A map from (target ports of edges from nodes in `removal` to nodes not in + // `removal`) to (input ports of the Output node of `replacement`). let nu_out = [ ((output, IncomingPort::from(0)), IncomingPort::from(0)), ((output, IncomingPort::from(1)), IncomingPort::from(1)), @@ -857,8 +879,8 @@ pub(in crate::hugr::rewrite) mod test { assert_eq!(hugr.node_count(), 3); } - /// Remove one of the NOT ops in [`dfg_hugr_half_not_bools`] by connecting the input - /// directly to the output. + /// Remove one of the NOT ops in [`dfg_hugr_half_not_bools`] by connecting + /// the input directly to the output. /// /// https://github.com/CQCL/hugr/issues/1323 #[rstest] @@ -880,8 +902,9 @@ pub(in crate::hugr::rewrite) mod test { let subgraph = SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0], &hugr).unwrap(); - // A map from (target ports of edges from the Input node of `replacement`) to (target ports of - // edges from nodes not in `removal` to nodes in `removal`). + // A map from (target ports of edges from the Input node of `replacement`) to + // (target ports of edges from nodes not in `removal` to nodes in + // `removal`). let nu_inp = [ ( (repl_output, IncomingPort::from(0)), @@ -894,8 +917,8 @@ pub(in crate::hugr::rewrite) mod test { ] .into_iter() .collect(); - // A map from (target ports of edges from nodes in `removal` to nodes not in `removal`) to - // (input ports of the Output node of `replacement`). + // A map from (target ports of edges from nodes in `removal` to nodes not in + // `removal`) to (input ports of the Output node of `replacement`). let nu_out = [ ((output, IncomingPort::from(0)), IncomingPort::from(0)), ((output, IncomingPort::from(1)), IncomingPort::from(1)), @@ -959,9 +982,9 @@ pub(in crate::hugr::rewrite) mod test { assert_eq!(h.node_count(), 6); } - use crate::hugr::rewrite::replace::Replacement; + use crate::hugr::patch::replace::Replacement; fn to_replace(h: &impl HugrView, s: SimpleReplacement) -> Replacement { - use crate::hugr::rewrite::replace::{NewEdgeKind, NewEdgeSpec}; + use crate::hugr::patch::replace::{NewEdgeKind, NewEdgeSpec}; let mut replacement = s.replacement; let (in_, out) = replacement @@ -1018,10 +1041,10 @@ pub(in crate::hugr::rewrite) mod test { } fn apply_simple(h: &mut Hugr, rw: SimpleReplacement) { - h.apply_rewrite(rw).unwrap(); + h.apply_patch(rw).unwrap(); } fn apply_replace(h: &mut Hugr, rw: SimpleReplacement) { - h.apply_rewrite(to_replace(h, rw)).unwrap(); + h.apply_patch(to_replace(h, rw)).unwrap(); } } diff --git a/hugr-core/src/hugr/views/sibling_subgraph.rs b/hugr-core/src/hugr/views/sibling_subgraph.rs index 9502d9f6b..680d58a03 100644 --- a/hugr-core/src/hugr/views/sibling_subgraph.rs +++ b/hugr-core/src/hugr/views/sibling_subgraph.rs @@ -838,7 +838,7 @@ mod tests { use cool_asserts::assert_matches; use crate::builder::inout_sig; - use crate::hugr::Rewrite; + use crate::hugr::Patch; use crate::ops::Const; use crate::std_extensions::arithmetic::float_types::{self, ConstF64}; use crate::std_extensions::logic::{self, LogicOp}; @@ -1011,7 +1011,7 @@ mod tests { assert_eq!(rep.subgraph().nodes().len(), 4); assert_eq!(hugr.node_count(), 8); // Module + Def + In + CX + Rz + Const + LoadConst + Out - hugr.apply_rewrite(rep).unwrap(); + hugr.apply_patch(rep).unwrap(); assert_eq!(hugr.node_count(), 4); // Module + Def + In + Out Ok(()) diff --git a/hugr-passes/src/lower.rs b/hugr-passes/src/lower.rs index 3a3bd5e91..7e68e600a 100644 --- a/hugr-passes/src/lower.rs +++ b/hugr-passes/src/lower.rs @@ -4,6 +4,7 @@ use hugr_core::{ Hugr, Node, }; +use itertools::Itertools; use thiserror::Error; /// Replace all operations in a HUGR according to a mapping. @@ -69,9 +70,11 @@ pub fn lower_ops( .map(|(node, replacement)| { let subcirc = SiblingSubgraph::from_node(node, hugr); let rw = subcirc.create_simple_replacement(hugr, replacement)?; - let mut repls = hugr.apply_rewrite(rw)?; - debug_assert_eq!(repls.len(), 1); - Ok(repls.remove(0)) + let removed_nodes = hugr.apply_patch(rw)?.removed_nodes; + Ok(removed_nodes + .into_iter() + .exactly_one() + .expect("removed exactly one node")) }) .collect() } diff --git a/hugr-passes/src/merge_bbs.rs b/hugr-passes/src/merge_bbs.rs index a5de5eb57..5c76ba51d 100644 --- a/hugr-passes/src/merge_bbs.rs +++ b/hugr-passes/src/merge_bbs.rs @@ -7,8 +7,8 @@ use hugr_core::hugr::hugrmut::HugrMut; use hugr_core::hugr::views::RootCheckable; use itertools::Itertools; -use hugr_core::hugr::rewrite::inline_dfg::InlineDFG; -use hugr_core::hugr::rewrite::replace::{NewEdgeKind, NewEdgeSpec, Replacement}; +use hugr_core::hugr::patch::inline_dfg::InlineDFG; +use hugr_core::hugr::patch::replace::{NewEdgeKind, NewEdgeSpec, Replacement}; use hugr_core::ops::handle::CfgID; use hugr_core::ops::{DataflowBlock, DataflowParent, Input, Output, DFG}; use hugr_core::{Hugr, HugrView, Node}; @@ -39,11 +39,11 @@ where continue; }; let (rep, merge_bb, dfgs) = mk_rep(cfg, n, succ); - let node_map = cfg.apply_rewrite(rep).unwrap(); + let node_map = cfg.apply_patch(rep).unwrap(); let merged_bb = *node_map.get(&merge_bb).unwrap(); for dfg_id in dfgs { let n_id = *node_map.get(&dfg_id).unwrap(); - cfg.apply_rewrite(InlineDFG(n_id.into())).unwrap(); + cfg.apply_patch(InlineDFG(n_id.into())).unwrap(); } worklist.push(merged_bb); } diff --git a/hugr-passes/src/nest_cfgs.rs b/hugr-passes/src/nest_cfgs.rs index b98d4fb23..6e9df7f1a 100644 --- a/hugr-passes/src/nest_cfgs.rs +++ b/hugr-passes/src/nest_cfgs.rs @@ -44,10 +44,10 @@ use std::hash::Hash; use itertools::Itertools; use thiserror::Error; -use hugr_core::hugr::rewrite::outline_cfg::OutlineCfg; +use hugr_core::hugr::patch::outline_cfg::OutlineCfg; use hugr_core::hugr::views::sibling::SiblingMut; use hugr_core::hugr::views::{HierarchyView, HugrView, RootCheckable, SiblingGraph}; -use hugr_core::hugr::{hugrmut::HugrMut, Rewrite}; +use hugr_core::hugr::{hugrmut::HugrMut, Patch}; use hugr_core::ops::handle::{BasicBlockID, CfgID}; use hugr_core::ops::OpTag; use hugr_core::ops::OpTrait; @@ -260,7 +260,7 @@ impl> CfgNester for IdentityCfgMap { assert!([entry_edge.0, entry_edge.1, exit_edge.0, exit_edge.1] .iter() .all(|n| self.h.get_parent(*n) == Some(self.h.root()))); - let (new_block, new_cfg) = OutlineCfg::new(blocks).apply(&mut self.h).unwrap(); + let [new_block, new_cfg] = OutlineCfg::new(blocks).apply(&mut self.h).unwrap(); debug_assert!([entry_edge.0, exit_edge.1] .iter() .all(|n| self.h.get_parent(*n) == Some(self.h.root()))); @@ -579,7 +579,7 @@ pub(crate) mod test { }; use hugr_core::extension::{prelude::usize_t, ExtensionSet}; - use hugr_core::hugr::rewrite::insert_identity::{IdentityInsertion, IdentityInsertionError}; + use hugr_core::hugr::patch::insert_identity::{IdentityInsertion, IdentityInsertionError}; use hugr_core::hugr::views::RootChecked; use hugr_core::ops::handle::{ConstID, NodeHandle}; use hugr_core::ops::Value; @@ -830,7 +830,7 @@ pub(crate) mod test { let rw = IdentityInsertion::new(final_node, final_node_input); - let apply_result = h.apply_rewrite(rw); + let apply_result = h.apply_patch(rw); assert_eq!( apply_result, Err(IdentityInsertionError::InvalidPortKind(Some( diff --git a/hugr-passes/src/untuple.rs b/hugr-passes/src/untuple.rs index d074bed0f..00af101dc 100644 --- a/hugr-passes/src/untuple.rs +++ b/hugr-passes/src/untuple.rs @@ -131,7 +131,7 @@ impl ComposablePass for UntuplePass { let rewrites_applied = rewrites.len(); // The rewrites are independent, so we can always apply them all. for rewrite in rewrites { - hugr.apply_rewrite(rewrite)?; + hugr.apply_patch(rewrite)?; } Ok(UntupleResult { rewrites_applied }) } diff --git a/hugr/src/hugr.rs b/hugr/src/hugr.rs index a66de8315..0bd8f64ff 100644 --- a/hugr/src/hugr.rs +++ b/hugr/src/hugr.rs @@ -2,7 +2,7 @@ // Exports everything except the `internal` module. pub use hugr_core::hugr::{ - hugrmut, rewrite, serialize, validate, views, Hugr, HugrError, HugrView, IdentList, - InvalidIdentifier, LoadHugrError, NodeMetadata, NodeMetadataMap, OpType, Rewrite, + hugrmut, patch, serialize, validate, views, Hugr, HugrError, HugrView, IdentList, + InvalidIdentifier, LoadHugrError, NodeMetadata, NodeMetadataMap, OpType, Patch, SimpleReplacement, SimpleReplacementError, ValidationError, DEFAULT_OPTYPE, }; From c15d8ac624d6cd56b3ea5423dd0124a68ced1f7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Tue, 29 Apr 2025 15:10:37 +0100 Subject: [PATCH 14/18] feat!: Bump MSRV to 1.85 (#2136) Bumps the minimum supported rust version to 1.85, ~and migrates to the 2024 edition~. Most of the changes here are automated. The diff is quite noisy due to some changes in formatting, and me using this opportunity to auto--fix some optional clippy lints. It may be easier to check the changes per-commit. EDIT: Edition change has been left for a separate PR, as it's quite noisy. BREAKING CHANGE: Bumped MSRV to 1.85 --- .github/workflows/ci-rs.yml | 10 ++++------ .pre-commit-config.yaml | 2 +- Cargo.toml | 2 +- DEVELOPMENT.md | 2 +- hugr-cli/README.md | 2 +- hugr-core/README.md | 2 +- hugr-llvm/README.md | 2 +- hugr-model/README.md | 2 +- hugr-passes/README.md | 2 +- hugr-passes/src/replace_types/linearize.rs | 3 +-- hugr-passes/src/untuple.rs | 2 +- hugr/README.md | 2 +- justfile | 2 +- 13 files changed, 16 insertions(+), 19 deletions(-) diff --git a/.github/workflows/ci-rs.yml b/.github/workflows/ci-rs.yml index 56c093e8b..4fe5d244f 100644 --- a/.github/workflows/ci-rs.yml +++ b/.github/workflows/ci-rs.yml @@ -233,7 +233,7 @@ jobs: id: toolchain uses: dtolnay/rust-toolchain@master with: - toolchain: "1.75" + toolchain: "1.85" - name: Install nightly toolchain uses: dtolnay/rust-toolchain@master with: @@ -252,12 +252,10 @@ jobs: cargo binstall cargo-minimal-versions --force - name: Pin transitive dependencies not compatible with our MSRV # Add new dependencies as needed if the check fails due to - # "package `XXX` cannot be built because it requires rustc YYY or newer, while the currently active rustc version is 1.75.0" + # "package `XXX` cannot be built because it requires rustc YYY or newer, while the currently active rustc version is 1.85.0" run: | - rm Cargo.lock - cargo add -p hugr half@2.4.1 - cargo add -p hugr litemap@0.7.4 - cargo add -p hugr zerofrom@0.1.5 + # rm Cargo.lock + # cargo add -p hugr half@2.4.1 - name: Build with no features run: cargo minimal-versions --direct test --verbose --no-default-features --no-run - name: Tests with no features diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 569ccfec1..4fe582d93 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -79,7 +79,7 @@ repos: # built into a binary build (without using `maturin`) # # This feature list should be kept in sync with the `hugr-py/pyproject.toml` - entry: cargo test --workspace --exclude 'hugr-py' --features 'hugr/extension_inference hugr/declarative hugr/model_unstable hugr/llvm hugr/llvm-test hugr/zstd' + entry: cargo test --workspace --exclude 'hugr-py' --features 'hugr/extension_inference hugr/declarative hugr/llvm hugr/llvm-test hugr/zstd' language: system files: \.rs$ pass_filenames: false diff --git a/Cargo.toml b/Cargo.toml index 3031df1e7..c72326a33 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,7 @@ members = [ default-members = ["hugr", "hugr-core", "hugr-passes", "hugr-cli", "hugr-model"] [workspace.package] -rust-version = "1.75" +rust-version = "1.85" edition = "2021" homepage = "https://github.com/CQCL/hugr" repository = "https://github.com/CQCL/hugr" diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index 4659f96c7..6d9465140 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -29,7 +29,7 @@ shell by setting up [direnv](https://devenv.sh/automatic-shell-activation/). To setup the environment manually you will need: - Just: https://just.systems/ -- Rust `>=1.75`: https://www.rust-lang.org/tools/install +- Rust `>=1.85`: https://www.rust-lang.org/tools/install - uv `>=0.3`: docs.astral.sh/uv/getting-started/installation - Optional: capnproto `>=1.0`: https://capnproto.org/install.html Required when modifying the `hugr-model` serialization schema. diff --git a/hugr-cli/README.md b/hugr-cli/README.md index 277628d2b..dba9900e2 100644 --- a/hugr-cli/README.md +++ b/hugr-cli/README.md @@ -64,7 +64,7 @@ This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http: [API documentation here]: https://docs.rs/hugr-cli/ [build_status]: https://github.com/CQCL/hugr/actions/workflows/ci-rs.yml/badge.svg?branch=main - [msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg + [msrv]: https://img.shields.io/crates/msrv/hugr-cli [crates]: https://img.shields.io/crates/v/hugr-cli [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov [LICENSE]: https://github.com/CQCL/hugr/blob/main/LICENCE diff --git a/hugr-core/README.md b/hugr-core/README.md index 379041a5b..765d4577b 100644 --- a/hugr-core/README.md +++ b/hugr-core/README.md @@ -36,7 +36,7 @@ This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http: [API documentation here]: https://docs.rs/hugr-core/ [build_status]: https://github.com/CQCL/hugr/actions/workflows/ci-rs.yml/badge.svg?branch=main -[msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg +[msrv]: https://img.shields.io/crates/msrv/hugr-core [crates]: https://img.shields.io/crates/v/hugr-core [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov [LICENSE]: https://github.com/CQCL/hugr/blob/main/LICENCE diff --git a/hugr-llvm/README.md b/hugr-llvm/README.md index 5fd2d3239..6d81cd35d 100644 --- a/hugr-llvm/README.md +++ b/hugr-llvm/README.md @@ -32,7 +32,7 @@ See [DEVELOPMENT](DEVELOPMENT.md) for instructions on setting up the development This project is licensed under Apache License, Version 2.0 ([LICENCE](LICENCE) or ). [build_status]: https://github.com/CQCL/hugr/actions/workflows/ci-rs.yml/badge.svg?branch=main - [msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg + [msrv]: https://img.shields.io/crates/msrv/hugr-llvm [hugr]: https://lib.rs/crates/hugr [inkwell]: https://thedan64.github.io/inkwell/inkwell/index.html [llvm-sys]: https://crates.io/crates/llvm-sys diff --git a/hugr-model/README.md b/hugr-model/README.md index 0ea6fdf8f..be93253eb 100644 --- a/hugr-model/README.md +++ b/hugr-model/README.md @@ -30,7 +30,7 @@ This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http: [API documentation here]: https://docs.rs/hugr-model/ [build_status]: https://github.com/CQCL/hugr/actions/workflows/ci-rs.yml/badge.svg?branch=main - [msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg + [msrv]: https://img.shields.io/crates/msrv/hugr-model [crates]: https://img.shields.io/crates/v/hugr-core [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov [LICENSE]: https://github.com/CQCL/hugr/blob/main/LICENCE diff --git a/hugr-passes/README.md b/hugr-passes/README.md index b9552fe75..b441ed5e7 100644 --- a/hugr-passes/README.md +++ b/hugr-passes/README.md @@ -51,7 +51,7 @@ This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http: [API documentation here]: https://docs.rs/hugr-passes/ [build_status]: https://github.com/CQCL/hugr/actions/workflows/ci-rs.yml/badge.svg?branch=main - [msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg + [msrv]: https://img.shields.io/crates/msrv/hugr-passes [crates]: https://img.shields.io/crates/v/hugr-passes [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov [LICENSE]: https://github.com/CQCL/hugr/blob/main/LICENCE diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 321ec194f..2788a2379 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -1,4 +1,3 @@ -use std::iter::repeat; use std::{collections::HashMap, sync::Arc}; use hugr_core::builder::{ @@ -273,7 +272,7 @@ impl Linearizer for DelegatingLinearizer { let mut elems_for_copy = vec![vec![]; num_outports]; for (inp, ty) in case_b.input_wires().zip_eq(variant.iter()) { let inp_copies = if ty.copyable() { - repeat(inp).take(num_outports).collect::>() + std::iter::repeat_n(inp, num_outports).collect::>() } else { self.copy_discard_op(ty, num_outports)? .add(&mut case_b, [inp]) diff --git a/hugr-passes/src/untuple.rs b/hugr-passes/src/untuple.rs index 00af101dc..b2782e8d9 100644 --- a/hugr-passes/src/untuple.rs +++ b/hugr-passes/src/untuple.rs @@ -247,7 +247,7 @@ fn remove_pack_unpack<'h, T: HugrView>( .add_dataflow_op(op, replacement.input_wires()) .unwrap() .outputs_arr(); - outputs.extend(std::iter::repeat(tuple).take(num_other_outputs)) + outputs.extend(std::iter::repeat_n(tuple, num_other_outputs)) } // These should never fail, as we are defining the replacement ourselves. diff --git a/hugr/README.md b/hugr/README.md index 6ecfc405b..b54d4f62d 100644 --- a/hugr/README.md +++ b/hugr/README.md @@ -51,7 +51,7 @@ This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http: [API documentation here]: https://docs.rs/hugr/ [build_status]: https://github.com/CQCL/hugr/actions/workflows/ci-rs.yml/badge.svg?branch=main - [msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg + [msrv]: https://img.shields.io/crates/msrv/hugr [crates]: https://img.shields.io/crates/v/hugr [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov [LICENSE]: https://github.com/CQCL/hugr/blob/main/LICENCE diff --git a/justfile b/justfile index 61173375b..7b8075f94 100644 --- a/justfile +++ b/justfile @@ -23,7 +23,7 @@ test-rust: HUGR_TEST_SCHEMA=1 cargo test \ --workspace \ --exclude 'hugr-py' \ - --features 'hugr/extension_inference hugr/declarative hugr/model_unstable hugr/llvm hugr/llvm-test hugr/zstd' + --features 'hugr/extension_inference hugr/declarative hugr/llvm hugr/llvm-test hugr/zstd' # Run all python tests. test-python: uv run maturin develop --uv From f0738b11c79a43f44f9454c1f71dcb752ab10feb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Tue, 29 Apr 2025 15:32:27 +0100 Subject: [PATCH 15/18] feat!: Cleanup core trait definitions (#2126) Moves some methods around in `HugrInternals` / `HugrView` that we'll require for #1926 and #2029. Notable changes: - Adds a `hierarchy` method to `HugrInternals` that replaces most calls to `base_hugr`. - `HugrInternals::base_hugr` is now deprecated. It's only used by `Sibling/DescendantGraph`, `validate` and a random use in `DeadCodeElimPass`. - Adds a `HugrInternals::region_portgraph` method that returns a `FlatRegion` portgraph wrapper, and a `HugrView::descendants` call. These lets us replace most uses of `SiblingGraph` and `DescendantGraph`. - Renamed `HugrInternals::{get_pg_index,get_node}` to `to_portgraph_node` and `from_portgraph_node`. This requires some new changes in `portgraph`. I'll make a minor release and update it here before merging. We should be able to remove `base_hugr` after #2029. The deprecation warning here is only temporary. BREAKING CHANGE: Modified multiple core `HugrView` and `HugrInternals` trait methods. See #2126. --- Cargo.lock | 96 ++--- Cargo.toml | 6 +- hugr-core/src/builder/dataflow.rs | 6 +- hugr-core/src/core.rs | 2 +- hugr-core/src/export.rs | 15 +- hugr-core/src/hugr.rs | 52 ++- hugr-core/src/hugr/hugrmut.rs | 218 ++++++------ hugr-core/src/hugr/internal.rs | 114 +++--- hugr-core/src/hugr/patch.rs | 4 +- hugr-core/src/hugr/patch/consts.rs | 4 +- hugr-core/src/hugr/patch/insert_identity.rs | 4 +- hugr-core/src/hugr/patch/outline_cfg.rs | 5 +- hugr-core/src/hugr/patch/replace.rs | 25 +- hugr-core/src/hugr/patch/simple_replace.rs | 12 +- hugr-core/src/hugr/rewrite.rs | 4 +- hugr-core/src/hugr/serialize.rs | 10 +- hugr-core/src/hugr/validate.rs | 44 ++- hugr-core/src/hugr/validate/test.rs | 16 +- hugr-core/src/hugr/views.rs | 329 +++++++++++------- hugr-core/src/hugr/views/descendants.rs | 97 ++++-- hugr-core/src/hugr/views/impls.rs | 59 ++-- hugr-core/src/hugr/views/petgraph.rs | 12 +- hugr-core/src/hugr/views/render.rs | 10 +- hugr-core/src/hugr/views/sibling.rs | 159 ++++++--- hugr-core/src/hugr/views/sibling_subgraph.rs | 8 +- hugr-core/src/ops/constant.rs | 2 +- hugr-llvm/src/emit/ops.rs | 54 ++- ...hugr_call_indirect@pre-mem2reg@llvm14.snap | 8 +- hugr-llvm/src/utils/fat.rs | 6 +- hugr-passes/src/const_fold/test.rs | 8 +- hugr-passes/src/dead_code.rs | 1 + hugr-passes/src/force_order.rs | 60 ++-- hugr-passes/src/lower.rs | 2 +- hugr-passes/src/merge_bbs.rs | 2 +- hugr-passes/src/replace_types.rs | 4 +- 35 files changed, 800 insertions(+), 658 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7198a4ea5..085d2d6a5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -24,7 +24,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", - "getrandom 0.2.15", + "getrandom 0.2.16", "once_cell", "serde", "version_check", @@ -301,9 +301,9 @@ checksum = "38c99613cb3cd7429889a08dfcf651721ca971c86afa30798461f8eee994de47" [[package]] name = "bstr" -version = "1.11.3" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "531a9155a481e2ee699d4f98f43c0ca4ff8ee1bfd55c31e9e98fb29d2b176fe0" +checksum = "234113d19d0d7d613b40e86fb654acf958910802bcceab913a4f9e7cda03b1a4" dependencies = [ "memchr", "regex-automata", @@ -351,9 +351,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.18" +version = "1.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "525046617d8376e3db1deffb079e91cef90a89fc3ca5c185bbf8c9ecdd15cd5c" +checksum = "04da6a0d40b948dfc4fa8f5bbf402b0fc1a64a28dbf7d12ffd683550f2c1b63a" dependencies = [ "jobserver", "libc", @@ -936,9 +936,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if", "libc", @@ -1483,14 +1483,12 @@ dependencies = [ [[package]] name = "insta" -version = "1.42.2" +version = "1.43.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50259abbaa67d11d2bcafc7ba1d094ed7a0c70e3ce893f0d0997f73558cb3084" +checksum = "ab2d11b2f17a45095b8c3603928ba29d7d918d7129d0d0641a36ba73cf07daa6" dependencies = [ "console", - "linked-hash-map", "once_cell", - "pin-project", "serde", "similar", ] @@ -1622,21 +1620,15 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.171" +version = "0.2.172" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c19937216e9d3aa9956d9bb8dfc0b0c8beb6058fc4f7a4dc4d850edf86a237d6" - -[[package]] -name = "linked-hash-map" -version = "0.5.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" +checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" [[package]] name = "linux-raw-sys" -version = "0.9.3" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe7db12097d22ec582439daf8618b8fdd1a7bef6270e9af3b1ebcd30893cf413" +checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12" [[package]] name = "litemap" @@ -1696,9 +1688,9 @@ checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" [[package]] name = "miniz_oxide" -version = "0.8.7" +version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff70ce3e48ae43fa075863cef62e8b43b71a4f2382229920e0df362592919430" +checksum = "3be647b768db090acb35d5ec5db2b0e1f1de11133ca123b9eacf5137868f892a" dependencies = [ "adler2", ] @@ -1937,26 +1929,6 @@ dependencies = [ "serde", ] -[[package]] -name = "pin-project" -version = "1.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "677f1add503faace112b9f1373e43e9e054bfdd22ff1a63c1bc485eaec6a6a8a" -dependencies = [ - "pin-project-internal", -] - -[[package]] -name = "pin-project-internal" -version = "1.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "pin-project-lite" version = "0.2.16" @@ -2011,9 +1983,9 @@ checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e" [[package]] name = "portgraph" -version = "0.14.0" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a9ea69cfb011d5f17af28813ec37a0a9668a063090e14ad75dc5fc07ba01b47" +checksum = "5fdce52d51ec359351ff3c209fafb6f133562abf52d951ce5821c0184798d979" dependencies = [ "bitvec", "delegate", @@ -2029,7 +2001,7 @@ version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" dependencies = [ - "zerocopy 0.8.24", + "zerocopy 0.8.25", ] [[package]] @@ -2094,9 +2066,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.94" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a31971752e70b8b2686d7e46ec17fb38dad4051d94024c88df49b667caea9c84" +checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" dependencies = [ "unicode-ident", ] @@ -2259,7 +2231,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom 0.2.15", + "getrandom 0.2.16", ] [[package]] @@ -2694,9 +2666,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.100" +version = "2.0.101" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b09a44accad81e1ba1cd74a32461ba89dee89095ba17b32f5d03683b1b1fc2a0" +checksum = "8ce2b7fc941b3a24138a0a7cf8e858bfc6a992e7978a068a5c760deb0ed43caf" dependencies = [ "proc-macro2", "quote", @@ -2830,15 +2802,15 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.6.8" +version = "0.6.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" +checksum = "3da5db5a963e24bc68be8b17b6fa82814bb22ee8660f192bb182771d498f09a3" [[package]] name = "toml_edit" -version = "0.22.24" +version = "0.22.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17b4795ff5edd201c7cd6dca065ae59972ce77d1b80fa0a84d94950ece7d1474" +checksum = "10558ed0bd2a1562e630926a2d1f0b98c827da99fabd3fe20920a59642504485" dependencies = [ "indexmap", "toml_datetime", @@ -3418,9 +3390,9 @@ checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" [[package]] name = "winnow" -version = "0.7.6" +version = "0.7.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "63d3fcd9bba44b03821e7d699eeee959f3126dcc4aa8e4ae18ec617c2a5cea10" +checksum = "6cb8234a863ea0e8cd7284fcdd4f145233eb00fee02bbdd9861aec44e6477bc5" dependencies = [ "memchr", ] @@ -3496,11 +3468,11 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.24" +version = "0.8.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2586fea28e186957ef732a5f8b3be2da217d65c5969d4b1e17f973ebbe876879" +checksum = "a1702d9583232ddb9174e01bb7c15a2ab8fb1bc6f227aa1233858c351a3ba0cb" dependencies = [ - "zerocopy-derive 0.8.24", + "zerocopy-derive 0.8.25", ] [[package]] @@ -3516,9 +3488,9 @@ dependencies = [ [[package]] name = "zerocopy-derive" -version = "0.8.24" +version = "0.8.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a996a8f63c5c4448cd959ac1bab0aaa3306ccfd060472f85943ee0750f0169be" +checksum = "28a6e20d751156648aa063f3800b706ee209a32c0b4d9f24be3d980b01be55ef" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index c72326a33..97dad7dea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -58,7 +58,7 @@ regex = "1.10.6" regex-syntax = "0.8.3" rstest = "0.24.0" semver = "1.0.26" -serde = "1.0.195" +serde = "1.0.219" serde_json = "1.0.140" serde_yaml = "0.9.34" smol_str = "0.3.1" @@ -87,8 +87,8 @@ zstd = "0.13.2" # These public dependencies usually require breaking changes downstream, so we # try to be as permissive as possible. pyo3 = ">= 0.23.4, < 0.25" -portgraph = { version = ">= 0.13.3, < 0.15" } -petgraph = { version = ">= 0.7.1, < 0.9", default-features = false } +portgraph = { version = "0.14.1" } +petgraph = { version = ">= 0.8.1, < 0.9", default-features = false } [profile.dev.package] insta.opt-level = 3 diff --git a/hugr-core/src/builder/dataflow.rs b/hugr-core/src/builder/dataflow.rs index 64c5f5c84..b84f3a05a 100644 --- a/hugr-core/src/builder/dataflow.rs +++ b/hugr-core/src/builder/dataflow.rs @@ -506,8 +506,8 @@ pub(crate) mod test { #[rstest] fn dfg_hugr(simple_dfg_hugr: Hugr) { - assert_eq!(simple_dfg_hugr.node_count(), 3); - assert_matches!(simple_dfg_hugr.root_type().tag(), OpTag::Dfg); + assert_eq!(simple_dfg_hugr.num_nodes(), 3); + assert_matches!(simple_dfg_hugr.root_optype().tag(), OpTag::Dfg); } #[test] @@ -533,7 +533,7 @@ pub(crate) mod test { }; let hugr = module_builder.finish_hugr()?; - assert_eq!(hugr.node_count(), 7); + assert_eq!(hugr.num_nodes(), 7); assert_eq!(hugr.get_metadata(hugr.root(), "x"), None); assert_eq!(hugr.get_metadata(dfg_node, "x").cloned(), Some(json!(42))); diff --git a/hugr-core/src/core.rs b/hugr-core/src/core.rs index 03e009bef..cc9da77ab 100644 --- a/hugr-core/src/core.rs +++ b/hugr-core/src/core.rs @@ -83,7 +83,7 @@ pub struct Wire(N, OutgoingPort); impl Node { /// Returns the node as a portgraph `NodeIndex`. #[inline] - pub(crate) fn pg_index(self) -> portgraph::NodeIndex { + pub(crate) fn into_portgraph(self) -> portgraph::NodeIndex { self.index } } diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 09ccf944c..078fe3c27 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -1,4 +1,5 @@ //! Exporting HUGR graphs to their `hugr-model` representation. +use crate::hugr::internal::HugrInternals; use crate::{ extension::{ExtensionId, OpDef, SignatureFunc}, hugr::IdentList, @@ -94,7 +95,7 @@ struct Context<'a> { impl<'a> Context<'a> { pub fn new(hugr: &'a Hugr, bump: &'a Bump) -> Self { let mut module = table::Module::default(); - module.nodes.reserve(hugr.node_count()); + module.nodes.reserve(hugr.num_nodes()); let links = Links::new(hugr); Self { @@ -999,7 +1000,7 @@ impl<'a> Context<'a> { let outer_hugr = std::mem::replace(&mut self.hugr, hugr); let outer_node_to_id = std::mem::take(&mut self.node_to_id); - let region = match hugr.root_type() { + let region = match hugr.root_optype() { OpType::DFG(_) => self.export_dfg(hugr.root(), model::ScopeClosure::Closed), _ => panic!("Value::Function root must be a DFG"), }; @@ -1031,7 +1032,7 @@ impl<'a> Context<'a> { } pub fn export_node_metadata(&mut self, node: Node) -> &'a [table::TermId] { - let metadata_map = self.hugr.get_node_metadata(node); + let metadata_map = self.hugr.node_metadata_map(node); let has_order_edges = { fn is_relevant_node(hugr: &Hugr, node: Node) -> bool { @@ -1049,13 +1050,11 @@ impl<'a> Context<'a> { .any(|(other, _)| is_relevant_node(self.hugr, other)) }; - let meta_capacity = metadata_map.map_or(0, |map| map.len()) + has_order_edges as usize; + let meta_capacity = metadata_map.len() + has_order_edges as usize; let mut meta = BumpVec::with_capacity_in(meta_capacity, self.bump); - if let Some(metadata_map) = metadata_map { - for (name, value) in metadata_map { - meta.push(self.export_json_meta(name, value)); - } + for (name, value) in metadata_map { + meta.push(self.export_json_meta(name, value)); } if has_order_edges { diff --git a/hugr-core/src/hugr.rs b/hugr-core/src/hugr.rs index 7a74b4070..93250b8e3 100644 --- a/hugr-core/src/hugr.rs +++ b/hugr-core/src/hugr.rs @@ -89,6 +89,25 @@ impl Hugr { Self::with_capacity(root_node.into(), 0, 0) } + /// Create a new Hugr, with a single root node and preallocated capacity. + pub fn with_capacity(root_node: OpType, nodes: usize, ports: usize) -> Self { + let mut graph = MultiPortGraph::with_capacity(nodes, ports); + let hierarchy = Hierarchy::new(); + let mut op_types = UnmanagedDenseMap::with_capacity(nodes); + let root = graph.add_node(root_node.input_count(), root_node.output_count()); + let extensions = root_node.used_extensions(); + op_types[root] = root_node; + + Self { + graph, + hierarchy, + root, + op_types, + metadata: UnmanagedDenseMap::with_capacity(nodes), + extensions: extensions.unwrap_or_default(), + } + } + /// Load a Hugr from a json reader. /// /// Validates the Hugr against the provided extension registry, ensuring all @@ -154,7 +173,7 @@ impl Hugr { .map(|ch| Ok((ch, infer(h, ch, remove)?))) .collect::, _>>()?; - let Some(es) = delta_mut(h.op_types.get_mut(node.pg_index())) else { + let Some(es) = delta_mut(h.op_types.get_mut(node.into_portgraph())) else { return Ok(h.get_optype(node).extension_delta()); }; if es.contains(&TO_BE_INFERRED) { @@ -260,31 +279,6 @@ impl Hugr { /// Internal API for HUGRs, not intended for use by users. impl Hugr { - /// Create a new Hugr, with a single root node and preallocated capacity. - pub(crate) fn with_capacity(root_node: OpType, nodes: usize, ports: usize) -> Self { - let mut graph = MultiPortGraph::with_capacity(nodes, ports); - let hierarchy = Hierarchy::new(); - let mut op_types = UnmanagedDenseMap::with_capacity(nodes); - let root = graph.add_node(root_node.input_count(), root_node.output_count()); - let extensions = root_node.used_extensions(); - op_types[root] = root_node; - - Self { - graph, - hierarchy, - root, - op_types, - metadata: UnmanagedDenseMap::with_capacity(nodes), - extensions: extensions.unwrap_or_default(), - } - } - - /// Set the root node of the hugr. - pub(crate) fn set_root(&mut self, root: Node) { - self.hierarchy.detach(self.root); - self.root = root.pg_index(); - } - /// Add a node to the graph. pub(crate) fn add_node(&mut self, nodetype: OpType) -> Node { let node = self @@ -322,7 +316,7 @@ impl Hugr { /// preserve the indices. pub fn canonicalize_nodes(&mut self, mut rekey: impl FnMut(Node, Node)) { // Generate the ordered list of nodes - let mut ordered = Vec::with_capacity(self.node_count()); + let mut ordered = Vec::with_capacity(self.num_nodes()); let root = self.root(); ordered.extend(self.as_mut().canonical_order(root)); @@ -339,8 +333,8 @@ impl Hugr { let target: Node = portgraph::NodeIndex::new(position).into(); if target != source { - let pg_target = target.pg_index(); - let pg_source = source.pg_index(); + let pg_target = target.into_portgraph(); + let pg_source = source.into_portgraph(); self.graph.swap_nodes(pg_target, pg_source); self.op_types.swap(pg_target, pg_source); self.hierarchy.swap_nodes(pg_target, pg_source); diff --git a/hugr-core/src/hugr/hugrmut.rs b/hugr-core/src/hugr/hugrmut.rs index c58ccbdbc..6353820f4 100644 --- a/hugr-core/src/hugr/hugrmut.rs +++ b/hugr-core/src/hugr/hugrmut.rs @@ -18,6 +18,7 @@ use crate::types::Substitution; use crate::{Extension, Hugr, IncomingPort, OutgoingPort, Port, PortIndex}; use super::internal::HugrMutInternals; +use super::views::{panic_invalid_node, panic_invalid_non_root, panic_invalid_port}; /// Functions for low-level building of a HUGR. pub trait HugrMut: HugrMutInternals { @@ -26,12 +27,7 @@ pub trait HugrMut: HugrMutInternals { /// # Panics /// /// If the node is not in the graph. - fn get_metadata_mut(&mut self, node: Self::Node, key: impl AsRef) -> &mut NodeMetadata { - panic_invalid_node(self, node); - self.node_metadata_map_mut(node) - .entry(key.as_ref()) - .or_insert(serde_json::Value::Null) - } + fn get_metadata_mut(&mut self, node: Self::Node, key: impl AsRef) -> &mut NodeMetadata; /// Sets a metadata value associated with a node. /// @@ -43,21 +39,14 @@ pub trait HugrMut: HugrMutInternals { node: Self::Node, key: impl AsRef, metadata: impl Into, - ) { - let entry = self.get_metadata_mut(node, key); - *entry = metadata.into(); - } + ); /// Remove a metadata entry associated with a node. /// /// # Panics /// /// If the node is not in the graph. - fn remove_metadata(&mut self, node: Self::Node, key: impl AsRef) { - panic_invalid_node(self, node); - let node_meta = self.node_metadata_map_mut(node); - node_meta.remove(key.as_ref()); - } + fn remove_metadata(&mut self, node: Self::Node, key: impl AsRef); /// Add a node to the graph with a parent in the hierarchy. /// @@ -209,9 +198,7 @@ pub trait HugrMut: HugrMutInternals { /// These can be queried using [`HugrView::extensions`]. /// /// See [`ExtensionRegistry::register_updated`] for more information. - fn use_extension(&mut self, extension: impl Into>) { - self.extensions_mut().register_updated(extension); - } + fn use_extension(&mut self, extension: impl Into>); /// Extend the set of extensions used by the hugr with the extensions in the /// registry. @@ -224,10 +211,7 @@ pub trait HugrMut: HugrMutInternals { /// See [`ExtensionRegistry::register_updated`] for more information. fn use_extensions(&mut self, registry: impl IntoIterator) where - ExtensionRegistry: Extend, - { - self.extensions_mut().extend(registry); - } + ExtensionRegistry: Extend; } /// Records the result of inserting a Hugr or view @@ -262,10 +246,33 @@ fn translate_indices( /// Impl for non-wrapped Hugrs. Overwrites the recursive default-impls to directly use the hugr. impl HugrMut for Hugr { + fn get_metadata_mut(&mut self, node: Self::Node, key: impl AsRef) -> &mut NodeMetadata { + panic_invalid_node(self, node); + self.node_metadata_map_mut(node) + .entry(key.as_ref()) + .or_insert(serde_json::Value::Null) + } + + fn set_metadata( + &mut self, + node: Self::Node, + key: impl AsRef, + metadata: impl Into, + ) { + let entry = self.get_metadata_mut(node, key); + *entry = metadata.into(); + } + + fn remove_metadata(&mut self, node: Self::Node, key: impl AsRef) { + panic_invalid_node(self, node); + let node_meta = self.node_metadata_map_mut(node); + node_meta.remove(key.as_ref()); + } + fn add_node_with_parent(&mut self, parent: Node, node: impl Into) -> Node { let node = self.as_mut().add_node(node.into()); self.hierarchy - .push_child(node.pg_index(), parent.pg_index()) + .push_child(node.into_portgraph(), parent.into_portgraph()) .expect("Inserting a newly-created node into the hierarchy should never fail."); node } @@ -273,7 +280,7 @@ impl HugrMut for Hugr { fn add_node_before(&mut self, sibling: Node, nodetype: impl Into) -> Node { let node = self.as_mut().add_node(nodetype.into()); self.hierarchy - .insert_before(node.pg_index(), sibling.pg_index()) + .insert_before(node.into_portgraph(), sibling.into_portgraph()) .expect("Inserting a newly-created node into the hierarchy should never fail."); node } @@ -281,16 +288,16 @@ impl HugrMut for Hugr { fn add_node_after(&mut self, sibling: Node, op: impl Into) -> Node { let node = self.as_mut().add_node(op.into()); self.hierarchy - .insert_after(node.pg_index(), sibling.pg_index()) + .insert_after(node.into_portgraph(), sibling.into_portgraph()) .expect("Inserting a newly-created node into the hierarchy should never fail."); node } fn remove_node(&mut self, node: Node) -> OpType { panic_invalid_non_root(self, node); - self.hierarchy.remove(node.pg_index()); - self.graph.remove_node(node.pg_index()); - self.op_types.take(node.pg_index()) + self.hierarchy.remove(node.into_portgraph()); + self.graph.remove_node(node.into_portgraph()); + self.op_types.take(node.into_portgraph()) } fn remove_subtree(&mut self, node: Node) { @@ -316,9 +323,9 @@ impl HugrMut for Hugr { panic_invalid_port(self, dst, dst_port); self.graph .link_nodes( - src.pg_index(), + src.into_portgraph(), src_port.index(), - dst.pg_index(), + dst.into_portgraph(), dst_port.index(), ) .expect("The ports should exist at this point."); @@ -330,7 +337,7 @@ impl HugrMut for Hugr { panic_invalid_port(self, node, port); let port = self .graph - .port_index(node.pg_index(), offset) + .port_index(node.into_portgraph(), offset) .expect("The port should exist at this point."); self.graph.unlink_port(port); } @@ -364,13 +371,17 @@ impl HugrMut for Hugr { self.metadata.set(new_node, meta); } debug_assert_eq!( - Some(&new_root.pg_index()), - node_map.get(&other.root().pg_index()) + Some(&new_root.into_portgraph()), + node_map.get(&other.root().into_portgraph()) ); InsertionResult { new_root, - node_map: translate_indices(|n| other.get_node(n), |n| self.get_node(n), node_map) - .collect(), + node_map: translate_indices( + |n| other.from_portgraph_node(n), + |n| self.from_portgraph_node(n), + node_map, + ) + .collect(), } } @@ -384,19 +395,26 @@ impl HugrMut for Hugr { // // No need to compute each node's extensions here, as we merge `other.extensions` directly. for (&node, &new_node) in node_map.iter() { - let nodetype = other.get_optype(other.get_node(node)); + let node = other.from_portgraph_node(node); + let nodetype = other.get_optype(node); self.op_types.set(new_node, nodetype.clone()); - let meta = other.base_hugr().metadata.get(node); - self.metadata.set(new_node, meta.clone()); + let meta = other.node_metadata_map(node); + if !meta.is_empty() { + self.metadata.set(new_node, Some(meta.clone())); + } } debug_assert_eq!( - Some(&new_root.pg_index()), - node_map.get(&other.get_pg_index(other.root())) + Some(&new_root.into_portgraph()), + node_map.get(&other.to_portgraph_node(other.root())) ); InsertionResult { new_root, - node_map: translate_indices(|n| other.get_node(n), |n| self.get_node(n), node_map) - .collect(), + node_map: translate_indices( + |n| other.from_portgraph_node(n), + |n| self.from_portgraph_node(n), + node_map, + ) + .collect(), } } @@ -410,7 +428,7 @@ impl HugrMut for Hugr { let context: HashSet = subgraph .nodes() .iter() - .map(|&n| other.get_pg_index(n)) + .map(|&n| other.to_portgraph_node(n)) .collect(); let portgraph: NodeFiltered<_, NodeFilter>, _> = NodeFiltered::new_node_filtered( @@ -421,16 +439,24 @@ impl HugrMut for Hugr { let node_map = insert_subgraph_internal(self, root, other, &portgraph); // Update the optypes and metadata, copying them from the other graph. for (&node, &new_node) in node_map.iter() { - let nodetype = other.get_optype(other.get_node(node)); + let node = other.from_portgraph_node(node); + let nodetype = other.get_optype(node); self.op_types.set(new_node, nodetype.clone()); - let meta = other.base_hugr().metadata.get(node); - self.metadata.set(new_node, meta.clone()); + let meta = other.node_metadata_map(node); + if !meta.is_empty() { + self.metadata.set(new_node, Some(meta.clone())); + } // Add the required extensions to the registry. if let Ok(exts) = nodetype.used_extensions() { self.use_extensions(exts); } } - translate_indices(|n| other.get_node(n), |n| self.get_node(n), node_map).collect() + translate_indices( + |n| other.from_portgraph_node(n), + |n| self.from_portgraph_node(n), + node_map, + ) + .collect() } fn copy_descendants( @@ -439,15 +465,19 @@ impl HugrMut for Hugr { new_parent: Self::Node, subst: Option, ) -> BTreeMap { - let mut descendants = self.base_hugr().hierarchy.descendants(root.pg_index()); + let mut descendants = self.hierarchy.descendants(root.into_portgraph()); let root2 = descendants.next(); - debug_assert_eq!(root2, Some(root.pg_index())); + debug_assert_eq!(root2, Some(root.into_portgraph())); let nodes = Vec::from_iter(descendants); let node_map = portgraph::view::Subgraph::with_nodes(&mut self.graph, nodes) .copy_in_parent() .expect("Is a MultiPortGraph"); - let node_map = translate_indices(|n| self.get_node(n), |n| self.get_node(n), node_map) - .collect::>(); + let node_map = translate_indices( + |n| self.from_portgraph_node(n), + |n| self.from_portgraph_node(n), + node_map, + ) + .collect::>(); for node in self.children(root).collect::>() { self.set_parent(*node_map.get(&node).unwrap(), new_parent); @@ -462,12 +492,25 @@ impl HugrMut for Hugr { (None, op) => op.clone(), (Some(subst), op) => op.substitute(subst), }; - self.op_types.set(new_node.pg_index(), new_optype); - let meta = self.base_hugr().metadata.get(node.pg_index()).clone(); - self.metadata.set(new_node.pg_index(), meta); + self.op_types.set(new_node.into_portgraph(), new_optype); + let meta = self.metadata.get(node.into_portgraph()).clone(); + self.metadata.set(new_node.into_portgraph(), meta); } node_map } + + #[inline] + fn use_extension(&mut self, extension: impl Into>) { + self.extensions_mut().register_updated(extension); + } + + #[inline] + fn use_extensions(&mut self, registry: impl IntoIterator) + where + ExtensionRegistry: Extend, + { + self.extensions_mut().extend(registry); + } } /// Internal implementation of `insert_hugr` and `insert_view` methods for @@ -487,18 +530,20 @@ fn insert_hugr_internal( .graph .insert_graph(&other.portgraph()) .unwrap_or_else(|e| panic!("Internal error while inserting a hugr into another: {e}")); - let other_root = node_map[&other.get_pg_index(other.root())]; + let other_root = node_map[&other.to_portgraph_node(other.root())]; // Update hierarchy and optypes hugr.hierarchy - .push_child(other_root, root.pg_index()) + .push_child(other_root, root.into_portgraph()) .expect("Inserting a newly-created node into the hierarchy should never fail."); for (&node, &new_node) in node_map.iter() { - other.children(other.get_node(node)).for_each(|child| { - hugr.hierarchy - .push_child(node_map[&other.get_pg_index(child)], new_node) - .expect("Inserting a newly-created node into the hierarchy should never fail."); - }); + other + .children(other.from_portgraph_node(node)) + .for_each(|child| { + hugr.hierarchy + .push_child(node_map[&other.to_portgraph_node(child)], new_node) + .expect("Inserting a newly-created node into the hierarchy should never fail."); + }); } // Merge the extension sets. @@ -534,9 +579,9 @@ fn insert_subgraph_internal( // update the hierarchy with their new id. for (&node, &new_node) in node_map.iter() { let new_parent = other - .get_parent(other.get_node(node)) - .and_then(|parent| node_map.get(&other.get_pg_index(parent)).copied()) - .unwrap_or(root.pg_index()); + .get_parent(other.from_portgraph_node(node)) + .and_then(|parent| node_map.get(&other.to_portgraph_node(parent)).copied()) + .unwrap_or(root.into_portgraph()); hugr.hierarchy .push_child(new_node, new_parent) .expect("Inserting a newly-created node into the hierarchy should never fail."); @@ -545,45 +590,6 @@ fn insert_subgraph_internal( node_map } -/// Panic if [`HugrView::valid_node`] fails. -#[track_caller] -pub(super) fn panic_invalid_node(hugr: &H, node: H::Node) { - // TODO: When stacking hugr wrappers, this gets called for every layer. - // Should we `cfg!(debug_assertions)` this? Benchmark and see if it matters. - if !hugr.valid_node(node) { - panic!("Received an invalid node {node} while mutating a HUGR.",); - } -} - -/// Panic if [`HugrView::valid_non_root`] fails. -#[track_caller] -pub(super) fn panic_invalid_non_root(hugr: &H, node: H::Node) { - // TODO: When stacking hugr wrappers, this gets called for every layer. - // Should we `cfg!(debug_assertions)` this? Benchmark and see if it matters. - if !hugr.valid_non_root(node) { - panic!("Received an invalid non-root node {node} while mutating a HUGR.",); - } -} - -/// Panic if [`HugrView::valid_node`] fails. -#[track_caller] -pub(super) fn panic_invalid_port( - hugr: &H, - node: Node, - port: impl Into, -) { - let port = port.into(); - // TODO: When stacking hugr wrappers, this gets called for every layer. - // Should we `cfg!(debug_assertions)` this? Benchmark and see if it matters. - if hugr - .portgraph() - .port_index(node.pg_index(), port.pg_offset()) - .is_none() - { - panic!("Received an invalid port {port} for node {node} while mutating a HUGR"); - } -} - #[cfg(test)] mod test { use crate::extension::PRELUDE; @@ -667,14 +673,14 @@ mod test { fd }); hugr.validate().unwrap(); - assert_eq!(hugr.node_count(), 7); + assert_eq!(hugr.num_nodes(), 7); hugr.remove_subtree(foo); hugr.validate().unwrap(); - assert_eq!(hugr.node_count(), 4); + assert_eq!(hugr.num_nodes(), 4); hugr.remove_subtree(bar); hugr.validate().unwrap(); - assert_eq!(hugr.node_count(), 1); + assert_eq!(hugr.num_nodes(), 1); } } diff --git a/hugr-core/src/hugr/internal.rs b/hugr-core/src/hugr/internal.rs index 58ce066c0..f69d2ad39 100644 --- a/hugr-core/src/hugr/internal.rs +++ b/hugr-core/src/hugr/internal.rs @@ -1,6 +1,5 @@ //! Internal traits, not exposed in the public `hugr` API. -use std::borrow::Cow; use std::ops::Range; use std::sync::OnceLock; @@ -8,11 +7,12 @@ use itertools::Itertools; use portgraph::{LinkMut, LinkView, MultiPortGraph, PortMut, PortOffset, PortView}; use crate::extension::ExtensionRegistry; -use crate::ops::handle::NodeHandle; use crate::{Direction, Hugr, Node}; -use super::hugrmut::{panic_invalid_node, panic_invalid_non_root}; -use super::{HugrView, NodeMetadataMap, OpType}; +use super::views::{panic_invalid_node, panic_invalid_non_root}; +use super::HugrView; +use super::{NodeMetadataMap, OpType}; +use crate::ops::handle::NodeHandle; /// Trait for accessing the internals of a Hugr(View). /// @@ -20,7 +20,7 @@ use super::{HugrView, NodeMetadataMap, OpType}; /// view. pub trait HugrInternals { /// The underlying portgraph view type. - type Portgraph<'p>: LinkView + Clone + 'p + type Portgraph<'p>: LinkView + Clone + 'p where Self: 'p; @@ -30,24 +30,24 @@ pub trait HugrInternals { /// Returns a reference to the underlying portgraph. fn portgraph(&self) -> Self::Portgraph<'_>; + /// Returns a flat portgraph view of a region in the HUGR. + /// + /// This is a subgraph of [`HugrInternals::portgraph`], with a flat hierarchy. + fn region_portgraph( + &self, + parent: Self::Node, + ) -> portgraph::view::FlatRegion<'_, impl LinkView + Clone + '_>; + /// Returns the portgraph [Hierarchy](portgraph::Hierarchy) of the graph /// returned by [`HugrInternals::portgraph`]. - #[inline] - fn hierarchy(&self) -> Cow<'_, portgraph::Hierarchy> { - Cow::Borrowed(&self.base_hugr().hierarchy) - } - - /// Returns the Hugr at the base of a chain of views. - fn base_hugr(&self) -> &Hugr; - - /// Return the root node of this view. - fn root_node(&self) -> Self::Node; + fn hierarchy(&self) -> &portgraph::Hierarchy; /// Convert a node to a portgraph node index. - fn get_pg_index(&self, node: impl NodeHandle) -> portgraph::NodeIndex; + fn to_portgraph_node(&self, node: impl NodeHandle) -> portgraph::NodeIndex; /// Convert a portgraph node index to a node. - fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node; + #[allow(clippy::wrong_self_convention)] + fn from_portgraph_node(&self, index: portgraph::NodeIndex) -> Self::Node; /// Returns a metadata entry associated with a node. /// @@ -55,6 +55,14 @@ pub trait HugrInternals { /// /// If the node is not in the graph. fn node_metadata_map(&self, node: Self::Node) -> &NodeMetadataMap; + + /// Returns the Hugr at the base of a chain of views. + // TODO: This will be removed in a future PR. + #[deprecated( + since = "0.16.0", + note = "This method will be removed in a future PR. Use the individual HugrInternals methods instead." + )] + fn base_hugr(&self) -> &Hugr; } impl HugrInternals for Hugr { @@ -71,34 +79,40 @@ impl HugrInternals for Hugr { } #[inline] - fn hierarchy(&self) -> Cow<'_, portgraph::Hierarchy> { - Cow::Borrowed(&self.hierarchy) + fn region_portgraph( + &self, + parent: Self::Node, + ) -> portgraph::view::FlatRegion<'_, impl LinkView + Clone + '_> { + let pg = self.portgraph(); + let root = self.to_portgraph_node(parent); + portgraph::view::FlatRegion::new_without_root(pg, &self.hierarchy, root) } #[inline] - fn base_hugr(&self) -> &Hugr { - self + fn hierarchy(&self) -> &portgraph::Hierarchy { + &self.hierarchy } #[inline] - fn root_node(&self) -> Self::Node { - self.root.into() + fn base_hugr(&self) -> &Hugr { + self } #[inline] - fn get_pg_index(&self, node: impl NodeHandle) -> portgraph::NodeIndex { - node.node().pg_index() + fn to_portgraph_node(&self, node: impl NodeHandle) -> portgraph::NodeIndex { + node.node().into_portgraph() } #[inline] - fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node { + fn from_portgraph_node(&self, index: portgraph::NodeIndex) -> Self::Node { index.into() } + #[inline] fn node_metadata_map(&self, node: Self::Node) -> &NodeMetadataMap { static EMPTY: OnceLock = OnceLock::new(); panic_invalid_node(self, node); - let map = self.metadata.get(node.pg_index()).as_ref(); + let map = self.metadata.get(node.into_portgraph()).as_ref(); map.unwrap_or(EMPTY.get_or_init(Default::default)) } } @@ -108,10 +122,14 @@ impl HugrInternals for Hugr { /// Specifically, this trait lets you apply arbitrary modifications that may /// invalidate the HUGR. pub trait HugrMutInternals: HugrView { - /// Set root node of the HUGR. + /// Set the node at the root of the HUGR hierarchy. /// - /// This should be an existing node in the HUGR. Most operations use the - /// root node as a starting point for traversal. + /// Any node not reachable from this root should be deleted from the HUGR + /// after this call. + /// + /// # Panics + /// + /// If the node is not in the graph. fn set_root(&mut self, root: Self::Node); /// Set the number of ports on a node. This may invalidate the node's `PortIndex`. @@ -225,21 +243,21 @@ pub trait HugrMutInternals: HugrView { /// Impl for non-wrapped Hugrs. Overwrites the recursive default-impls to directly use the hugr. impl HugrMutInternals for Hugr { fn set_root(&mut self, root: Node) { - panic_invalid_node(self, root); - self.root = self.get_pg_index(root); + self.hierarchy.detach(self.root); + self.root = root.into_portgraph(); } #[inline] fn set_num_ports(&mut self, node: Node, incoming: usize, outgoing: usize) { panic_invalid_node(self, node); self.graph - .set_num_ports(node.pg_index(), incoming, outgoing, |_, _| {}) + .set_num_ports(node.into_portgraph(), incoming, outgoing, |_, _| {}) } fn add_ports(&mut self, node: Node, direction: Direction, amount: isize) -> Range { panic_invalid_node(self, node); - let mut incoming = self.graph.num_inputs(node.pg_index()); - let mut outgoing = self.graph.num_outputs(node.pg_index()); + let mut incoming = self.graph.num_inputs(node.into_portgraph()); + let mut outgoing = self.graph.num_outputs(node.into_portgraph()); let increment = |num: &mut usize| { let new = num.saturating_add_signed(amount); let range = *num..new; @@ -251,7 +269,7 @@ impl HugrMutInternals for Hugr { Direction::Outgoing => increment(&mut outgoing), }; self.graph - .set_num_ports(node.pg_index(), incoming, outgoing, |_, _| {}); + .set_num_ports(node.into_portgraph(), incoming, outgoing, |_, _| {}); range } @@ -263,20 +281,18 @@ impl HugrMutInternals for Hugr { amount: usize, ) -> Range { panic_invalid_node(self, node); - let old_num_ports = self.base_hugr().graph.num_ports(node.pg_index(), direction); + let old_num_ports = self.graph.num_ports(node.into_portgraph(), direction); self.add_ports(node, direction, amount as isize); for swap_from_port in (index..old_num_ports).rev() { let swap_to_port = swap_from_port + amount; let [from_port_index, to_port_index] = [swap_from_port, swap_to_port].map(|p| { - self.base_hugr() - .graph - .port_index(node.pg_index(), PortOffset::new(direction, p)) + self.graph + .port_index(node.into_portgraph(), PortOffset::new(direction, p)) .unwrap() }); let linked_ports = self - .base_hugr() .graph .port_links(from_port_index) .map(|(_, to_subport)| to_subport.port()) @@ -295,27 +311,27 @@ impl HugrMutInternals for Hugr { fn set_parent(&mut self, node: Node, parent: Node) { panic_invalid_node(self, parent); panic_invalid_node(self, node); - self.hierarchy.detach(node.pg_index()); + self.hierarchy.detach(node.into_portgraph()); self.hierarchy - .push_child(node.pg_index(), parent.pg_index()) + .push_child(node.into_portgraph(), parent.into_portgraph()) .expect("Inserting a newly-created node into the hierarchy should never fail."); } fn move_after_sibling(&mut self, node: Node, after: Node) { panic_invalid_non_root(self, node); panic_invalid_non_root(self, after); - self.hierarchy.detach(node.pg_index()); + self.hierarchy.detach(node.into_portgraph()); self.hierarchy - .insert_after(node.pg_index(), after.pg_index()) + .insert_after(node.into_portgraph(), after.into_portgraph()) .expect("Inserting a newly-created node into the hierarchy should never fail."); } fn move_before_sibling(&mut self, node: Node, before: Node) { panic_invalid_non_root(self, node); panic_invalid_non_root(self, before); - self.hierarchy.detach(node.pg_index()); + self.hierarchy.detach(node.into_portgraph()); self.hierarchy - .insert_before(node.pg_index(), before.pg_index()) + .insert_before(node.into_portgraph(), before.into_portgraph()) .expect("Inserting a newly-created node into the hierarchy should never fail."); } @@ -326,14 +342,14 @@ impl HugrMutInternals for Hugr { fn optype_mut(&mut self, node: Self::Node) -> &mut OpType { panic_invalid_node(self, node); - let node = self.get_pg_index(node); + let node = self.to_portgraph_node(node); self.op_types.get_mut(node) } fn node_metadata_map_mut(&mut self, node: Self::Node) -> &mut NodeMetadataMap { panic_invalid_node(self, node); self.metadata - .get_mut(node.pg_index()) + .get_mut(node.into_portgraph()) .get_or_insert_with(Default::default) } diff --git a/hugr-core/src/hugr/patch.rs b/hugr-core/src/hugr/patch.rs index bc6195eba..1744ce760 100644 --- a/hugr-core/src/hugr/patch.rs +++ b/hugr-core/src/hugr/patch.rs @@ -153,12 +153,12 @@ impl PatchHugrMut for Transactional { return self.underlying.apply_hugr_mut(h); } // Try to backup just the contents of this HugrMut. - let mut backup = Hugr::new(h.root_type().clone()); + let mut backup = Hugr::new(h.root_optype().clone()); backup.insert_from_view(backup.root(), h); let r = self.underlying.apply_hugr_mut(h); if r.is_err() { // Try to restore backup. - h.replace_op(h.root(), backup.root_type().clone()); + h.replace_op(h.root(), backup.root_optype().clone()); while let Some(child) = h.first_child(h.root()) { h.remove_node(child); } diff --git a/hugr-core/src/hugr/patch/consts.rs b/hugr-core/src/hugr/patch/consts.rs index 6d0c011fe..eb9142f85 100644 --- a/hugr-core/src/hugr/patch/consts.rs +++ b/hugr-core/src/hugr/patch/consts.rs @@ -144,7 +144,7 @@ mod test { let mut h = build.finish_hugr()?; // nodes are Module, Function, Input, Output, Const, LoadConstant*2, MakeTuple - assert_eq!(h.node_count(), 8); + assert_eq!(h.num_nodes(), 8); let tup_node = tup.node(); // can't remove invalid node assert_eq!( @@ -199,7 +199,7 @@ mod test { // remove const assert_eq!(h.apply_patch(remove_con)?, h.root()); - assert_eq!(h.node_count(), 4); + assert_eq!(h.num_nodes(), 4); assert!(h.validate().is_ok()); Ok(()) } diff --git a/hugr-core/src/hugr/patch/insert_identity.rs b/hugr-core/src/hugr/patch/insert_identity.rs index 98ab0ff02..c1f959ccd 100644 --- a/hugr-core/src/hugr/patch/insert_identity.rs +++ b/hugr-core/src/hugr/patch/insert_identity.rs @@ -118,7 +118,7 @@ mod tests { fn correct_insertion(dfg_hugr: Hugr) { let mut h = dfg_hugr; - assert_eq!(h.node_count(), 6); + assert_eq!(h.num_nodes(), 6); let final_node = h .input_neighbours(h.get_io(h.root()).unwrap()[1]) @@ -131,7 +131,7 @@ mod tests { let noop_node = h.apply_patch(rw).unwrap(); - assert_eq!(h.node_count(), 7); + assert_eq!(h.num_nodes(), 7); let noop: Noop = h.get_optype(noop_node).cast().unwrap(); diff --git a/hugr-core/src/hugr/patch/outline_cfg.rs b/hugr-core/src/hugr/patch/outline_cfg.rs index 0f40615a9..b43b6b4e3 100644 --- a/hugr-core/src/hugr/patch/outline_cfg.rs +++ b/hugr-core/src/hugr/patch/outline_cfg.rs @@ -202,11 +202,11 @@ impl PatchHugrMut for OutlineCfg { // https://github.com/CQCL/hugr/issues/2029 let hierarchy = h.hierarchy(); let inner_exit = hierarchy - .children(h.get_pg_index(cfg_node)) + .children(h.to_portgraph_node(cfg_node)) .exactly_one() .ok() .unwrap(); - let inner_exit = h.get_node(inner_exit); + let inner_exit = h.from_portgraph_node(inner_exit); //let inner_exit = h.children(cfg_node).exactly_one().ok().unwrap(); // Entry node must be first @@ -512,6 +512,7 @@ mod test { } assert_eq!(h.get_parent(new_block), Some(cfg)); assert!(h.get_optype(new_block).is_dataflow_block()); + #[allow(deprecated)] let b = h.base_hugr(); // To cope with `h` potentially being a SiblingMut assert_eq!(b.get_parent(new_cfg), Some(new_block)); for n in blocks { diff --git a/hugr-core/src/hugr/patch/replace.rs b/hugr-core/src/hugr/patch/replace.rs index 6f0b0ed65..183200751 100644 --- a/hugr-core/src/hugr/patch/replace.rs +++ b/hugr-core/src/hugr/patch/replace.rs @@ -7,6 +7,7 @@ use thiserror::Error; use crate::core::HugrNode; use crate::hugr::hugrmut::InsertionResult; +use crate::hugr::views::check_valid_non_root; use crate::hugr::HugrMut; use crate::ops::{OpTag, OpTrait}; use crate::types::EdgeKind; @@ -188,7 +189,7 @@ impl Replacement { // equality of OpType/Signature, e.g. to ease changing of Input/Output // node signatures too. let removed = h.get_optype(parent).tag(); - let replacement = self.replacement.root_type().tag(); + let replacement = self.replacement.root_optype().tag(); if removed != replacement { return Err(ReplaceError::WrongRootNodeTag { removed, @@ -265,25 +266,25 @@ impl PatchVerification for Replacement { } e.check_src(h, WhichEdgeSpec::HostToHost)?; } - self.mu_out - .iter() - .try_for_each(|e| match self.replacement.valid_non_root(e.src) { + self.mu_out.iter().try_for_each(|e| { + match check_valid_non_root(&self.replacement, e.src) { true => e.check_src(&self.replacement, WhichEdgeSpec::ReplToHost), false => Err(ReplaceError::BadEdgeSpec( Direction::Outgoing, WhichEdgeSpec::ReplToHost(e.clone()), )), - })?; + } + })?; // Edge targets... - self.mu_inp - .iter() - .try_for_each(|e| match self.replacement.valid_non_root(e.tgt) { + self.mu_inp.iter().try_for_each(|e| { + match check_valid_non_root(&self.replacement, e.tgt) { true => e.check_tgt(&self.replacement, WhichEdgeSpec::HostToRepl), false => Err(ReplaceError::BadEdgeSpec( Direction::Incoming, WhichEdgeSpec::HostToRepl(e.clone()), )), - })?; + } + })?; for e in self.mu_out.iter() { if !h.contains_node(e.tgt) || removed.contains(&e.tgt) { return Err(ReplaceError::BadEdgeSpec( @@ -430,13 +431,13 @@ where .ok_or_else(|| ReplaceError::BadEdgeSpec(Direction::Incoming, err_spec.clone()))?, kind: oe.kind, }; - if !h.valid_node(e.src) { + if !h.contains_node(e.src) { return Err(ReplaceError::BadEdgeSpec( Direction::Outgoing, err_spec.clone(), )); } - if !h.valid_node(e.tgt) { + if !h.contains_node(e.tgt) { return Err(ReplaceError::BadEdgeSpec( Direction::Incoming, err_spec.clone(), @@ -810,7 +811,7 @@ mod test { // Root node type needs to be that of common parent of the removed nodes: let mut rep2 = rep.clone(); rep2.replacement - .replace_op(rep2.replacement.root(), h.root_type().clone()); + .replace_op(rep2.replacement.root(), h.root_optype().clone()); assert_eq!( check_same_errors(rep2), ReplaceError::WrongRootNodeTag { diff --git a/hugr-core/src/hugr/patch/simple_replace.rs b/hugr-core/src/hugr/patch/simple_replace.rs index e9283644d..3908ba58e 100644 --- a/hugr-core/src/hugr/patch/simple_replace.rs +++ b/hugr-core/src/hugr/patch/simple_replace.rs @@ -766,7 +766,7 @@ pub(in crate::hugr::patch) mod test { .unwrap(); // They should be the same, up to node indices - assert_eq!(h.edge_count(), orig.edge_count()); + assert_eq!(h.num_edges(), orig.num_edges()); } #[test] @@ -818,7 +818,7 @@ pub(in crate::hugr::patch) mod test { .unwrap(); // Nothing changed - assert_eq!(h.node_count(), orig.node_count()); + assert_eq!(h.num_nodes(), orig.num_nodes()); } /// Remove all the NOT gates in [`dfg_hugr_copy_bools`] by connecting the @@ -876,7 +876,7 @@ pub(in crate::hugr::patch) mod test { rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}")); assert_eq!(hugr.validate(), Ok(())); - assert_eq!(hugr.node_count(), 3); + assert_eq!(hugr.num_nodes(), 3); } /// Remove one of the NOT ops in [`dfg_hugr_half_not_bools`] by connecting @@ -935,7 +935,7 @@ pub(in crate::hugr::patch) mod test { rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}")); assert_eq!(hugr.validate(), Ok(())); - assert_eq!(hugr.node_count(), 4); + assert_eq!(hugr.num_nodes(), 4); } #[rstest] @@ -974,12 +974,12 @@ pub(in crate::hugr::patch) mod test { let rewrite = SimpleReplacement::new(subgraph, replacement, nu_inp, nu_out); - assert_eq!(h.node_count(), 4); + assert_eq!(h.num_nodes(), 4); rewrite.apply(&mut h).unwrap_or_else(|e| panic!("{e}")); h.validate().unwrap_or_else(|e| panic!("{e}")); - assert_eq!(h.node_count(), 6); + assert_eq!(h.num_nodes(), 6); } use crate::hugr::patch::replace::Replacement; diff --git a/hugr-core/src/hugr/rewrite.rs b/hugr-core/src/hugr/rewrite.rs index e220864a7..76dc93ab1 100644 --- a/hugr-core/src/hugr/rewrite.rs +++ b/hugr-core/src/hugr/rewrite.rs @@ -77,12 +77,12 @@ impl Rewrite for Transactional { return self.underlying.apply(h); } // Try to backup just the contents of this HugrMut. - let mut backup = Hugr::new(h.root_type().clone()); + let mut backup = Hugr::new(h.root_optype().clone()); backup.insert_from_view(backup.root(), h); let r = self.underlying.apply(h); if r.is_err() { // Try to restore backup. - h.replace_op(h.root(), backup.root_type().clone()); + h.replace_op(h.root(), backup.root_optype().clone()); while let Some(child) = h.first_child(h.root()) { h.remove_node(child); } diff --git a/hugr-core/src/hugr/serialize.rs b/hugr-core/src/hugr/serialize.rs index 906084d55..5e4922157 100644 --- a/hugr-core/src/hugr/serialize.rs +++ b/hugr-core/src/hugr/serialize.rs @@ -157,13 +157,13 @@ impl TryFrom<&Hugr> for SerHugrLatest { fn try_from(hugr: &Hugr) -> Result { // We compact the operation nodes during the serialization process, // and ignore the copy nodes. - let mut node_rekey: HashMap = HashMap::with_capacity(hugr.node_count()); + let mut node_rekey: HashMap = HashMap::with_capacity(hugr.num_nodes()); for (order, node) in hugr.canonical_order(hugr.root()).enumerate() { node_rekey.insert(node, portgraph::NodeIndex::new(order).into()); } - let mut nodes = vec![None; hugr.node_count()]; - let mut metadata = vec![None; hugr.node_count()]; + let mut nodes = vec![None; hugr.num_nodes()]; + let mut metadata = vec![None; hugr.num_nodes()]; for n in hugr.nodes() { let parent = node_rekey[&hugr.get_parent(n).unwrap_or(n)]; let opt = hugr.get_optype(n); @@ -172,7 +172,7 @@ impl TryFrom<&Hugr> for SerHugrLatest { parent, op: opt.clone(), }); - metadata[new_node].clone_from(hugr.metadata.get(n.pg_index())); + metadata[new_node].clone_from(hugr.metadata.get(n.into_portgraph())); } let nodes = nodes .into_iter() @@ -251,7 +251,7 @@ impl TryFrom for Hugr { } let unwrap_offset = |node: Node, offset, dir, hugr: &Hugr| -> Result { - if !hugr.graph.contains_node(node.pg_index()) { + if !hugr.graph.contains_node(node.into_portgraph()) { return Err(HUGRSerializationError::UnknownEdgeNode { node }); } let offset = match offset { diff --git a/hugr-core/src/hugr/validate.rs b/hugr-core/src/hugr/validate.rs index 3b04ccd86..3690ec947 100644 --- a/hugr-core/src/hugr/validate.rs +++ b/hugr-core/src/hugr/validate.rs @@ -20,7 +20,7 @@ use crate::types::EdgeKind; use crate::{Direction, Hugr, Node, Port}; use super::internal::HugrInternals; -use super::views::{HierarchyView, HugrView, SiblingGraph}; +use super::views::HugrView; use super::ExtensionError; /// Structure keeping track of pre-computed information used in the validation @@ -31,7 +31,7 @@ use super::ExtensionError; struct ValidationContext<'a> { hugr: &'a Hugr, /// Dominator tree for each CFG region, using the container node as index. - dominators: HashMap>, + dominators: HashMap>, } impl Hugr { @@ -138,10 +138,10 @@ impl<'a> ValidationContext<'a> { /// /// The results of this computation should be cached in `self.dominators`. /// We don't do it here to avoid mutable borrows. - fn compute_dominator(&self, parent: Node) -> Dominators { - let region: SiblingGraph = SiblingGraph::try_new(self.hugr, parent).unwrap(); + fn compute_dominator(&self, parent: Node) -> Dominators { + let region = self.hugr.region_portgraph(parent); let entry_node = self.hugr.children(parent).next().unwrap(); - dominators::simple_fast(®ion.as_petgraph(), entry_node) + dominators::simple_fast(®ion, entry_node.into_portgraph()) } /// Check the constraints on a single node. @@ -163,7 +163,7 @@ impl<'a> ValidationContext<'a> { for dir in Direction::BOTH { // Check that we have the correct amount of ports and edges. - let num_ports = self.hugr.graph.num_ports(node.pg_index(), dir); + let num_ports = self.hugr.graph.num_ports(node.into_portgraph(), dir); if num_ports != op_type.port_count(dir) { return Err(ValidationError::WrongNumberOfPorts { node, @@ -316,7 +316,7 @@ impl<'a> ValidationContext<'a> { fn validate_children(&self, node: Node, op_type: &OpType) -> Result<(), ValidationError> { let flags = op_type.validity_flags(); - if self.hugr.hierarchy().child_count(node.pg_index()) > 0 { + if self.hugr.hierarchy().child_count(node.into_portgraph()) > 0 { if flags.allowed_children.is_empty() { return Err(ValidationError::NonContainerWithChildren { node, @@ -352,7 +352,8 @@ impl<'a> ValidationContext<'a> { } } // Additional validations running over the full list of children optypes - let children_optypes = all_children.map(|c| (c.pg_index(), self.hugr.get_optype(c))); + let children_optypes = + all_children.map(|c| (c.into_portgraph(), self.hugr.get_optype(c))); if let Err(source) = op_type.validate_op_children(children_optypes) { return Err(ValidationError::InvalidChildren { parent: node, @@ -363,9 +364,9 @@ impl<'a> ValidationContext<'a> { // Additional validations running over the edges of the contained graph if let Some(edge_check) = flags.edge_check { - for source in self.hugr.hierarchy().children(node.pg_index()) { + for source in self.hugr.hierarchy().children(node.into_portgraph()) { for target in self.hugr.graph.output_neighbours(source) { - if self.hugr.hierarchy.parent(target) != Some(node.pg_index()) { + if self.hugr.hierarchy.parent(target) != Some(node.into_portgraph()) { continue; } let source_op = self.hugr.get_optype(source.into()); @@ -411,16 +412,16 @@ impl<'a> ValidationContext<'a> { /// Inter-graph edges are ignored. Only internal dataflow, constant, or /// state order edges are considered. fn validate_children_dag(&self, parent: Node, op_type: &OpType) -> Result<(), ValidationError> { - if !self.hugr.hierarchy.has_children(parent.pg_index()) { + if !self.hugr.hierarchy.has_children(parent.into_portgraph()) { // No children, nothing to do return Ok(()); }; - let region: SiblingGraph = SiblingGraph::try_new(self.hugr, parent).unwrap(); - let postorder = Topo::new(®ion.as_petgraph()); + let region = self.hugr.region_portgraph(parent); + let postorder = Topo::new(®ion); let nodes_visited = postorder - .iter(®ion.as_petgraph()) - .filter(|n| *n != parent) + .iter(®ion) + .filter(|n| *n != parent.into_portgraph()) .count(); let node_count = self.hugr.children(parent).count(); if nodes_visited != node_count { @@ -500,7 +501,7 @@ impl<'a> ValidationContext<'a> { // Must have an order edge. self.hugr .graph - .get_connections(from.pg_index(), ancestor.pg_index()) + .get_connections(from.into_portgraph(), ancestor.into_portgraph()) .find(|&(p, _)| { let offset = self.hugr.graph.port_offset(p).unwrap(); from_optype.port_kind(offset) == Some(EdgeKind::StateOrder) @@ -537,8 +538,8 @@ impl<'a> ValidationContext<'a> { } }; if !dominator_tree - .dominators(ancestor) - .is_some_and(|mut ds| ds.any(|n| n == from_parent)) + .dominators(ancestor.into_portgraph()) + .is_some_and(|mut ds| ds.any(|n| n == from_parent.into_portgraph())) { return Err(InterGraphEdgeError::NonDominatedAncestor { from, @@ -616,7 +617,12 @@ impl<'a> ValidationContext<'a> { // Root nodes are ignored, as they cannot have connected edges. if node != self.hugr.root() { for dir in Direction::BOTH { - for (i, port_index) in self.hugr.graph.ports(node.pg_index(), dir).enumerate() { + for (i, port_index) in self + .hugr + .graph + .ports(node.into_portgraph(), dir) + .enumerate() + { let port = Port::new(dir, i); self.validate_port(node, port, port_index, op_type, var_decls)?; } diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index a66296c35..236f40e3f 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -103,7 +103,7 @@ fn invalid_root() { ); // Fix the root - b.root = module.pg_index(); + b.root = module.into_portgraph(); b.remove_node(root); assert_eq!(b.validate(), Ok(())); } @@ -142,7 +142,7 @@ fn children_restrictions() { let root = b.root(); let (_input, copy, _output) = b .hierarchy - .children(def.pg_index()) + .children(def.into_portgraph()) .map_into() .collect_tuple() .unwrap(); @@ -185,7 +185,7 @@ fn df_children_restrictions() { let (mut b, def) = make_simple_hugr(2); let (_input, output, copy) = b .hierarchy - .children(def.pg_index()) + .children(def.into_portgraph()) .map_into() .collect_tuple() .unwrap(); @@ -202,7 +202,7 @@ fn df_children_restrictions() { assert_matches!( b.validate(), Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::IOSignatureMismatch { child, .. }, .. }) - => {assert_eq!(parent, def); assert_eq!(child, output.pg_index())} + => {assert_eq!(parent, def); assert_eq!(child, output.into_portgraph())} ); b.replace_op(output, ops::Output::new(vec![bool_t(), bool_t()])); @@ -211,7 +211,7 @@ fn df_children_restrictions() { assert_matches!( b.validate(), Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::InternalIOChildren { child, .. }, .. }) - => {assert_eq!(parent, def); assert_eq!(child, copy.pg_index())} + => {assert_eq!(parent, def); assert_eq!(child, copy.into_portgraph())} ); } @@ -791,7 +791,7 @@ fn cfg_children_restrictions() { let (mut b, def) = make_simple_hugr(1); let (_input, _output, copy) = b .hierarchy - .children(def.pg_index()) + .children(def.into_portgraph()) .map_into() .collect_tuple() .unwrap(); @@ -855,7 +855,7 @@ fn cfg_children_restrictions() { assert_matches!( b.validate(), Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::InternalExitChildren { child, .. }, .. }) - => {assert_eq!(parent, cfg); assert_eq!(child, exit2.pg_index())} + => {assert_eq!(parent, cfg); assert_eq!(child, exit2.into_portgraph())} ); b.remove_node(exit2); @@ -875,7 +875,7 @@ fn cfg_children_restrictions() { extension_delta: ExtensionSet::new(), }, ); - let mut block_children = b.hierarchy.children(block.pg_index()); + let mut block_children = b.hierarchy.children(block.into_portgraph()); let block_input = block_children.next().unwrap().into(); let block_output = block_children.next_back().unwrap().into(); b.replace_op(block_input, ops::Input::new(vec![qb_t()])); diff --git a/hugr-core/src/hugr/views.rs b/hugr-core/src/hugr/views.rs index a154a956f..f9eedd548 100644 --- a/hugr-core/src/hugr/views.rs +++ b/hugr-core/src/hugr/views.rs @@ -24,10 +24,8 @@ use itertools::Itertools; use portgraph::render::{DotFormat, MermaidFormat}; use portgraph::{LinkView, PortView}; -use super::internal::HugrInternals; -use super::{ - Hugr, HugrError, HugrMut, Node, NodeMetadata, NodeMetadataMap, ValidationError, DEFAULT_OPTYPE, -}; +use super::internal::{HugrInternals, HugrMutInternals}; +use super::{Hugr, HugrError, HugrMut, Node, NodeMetadata, ValidationError}; use crate::extension::ExtensionRegistry; use crate::ops::{OpParent, OpTag, OpTrait, OpType}; @@ -40,85 +38,67 @@ use itertools::Either; /// For end users we intend this to be superseded by region-specific APIs. pub trait HugrView: HugrInternals { /// Return the root node of this view. - #[inline] - fn root(&self) -> Self::Node { - self.root_node() - } + fn root(&self) -> Self::Node; - /// Return the type of the HUGR root node. + /// Return the optype of the HUGR root node. #[inline] - fn root_type(&self) -> &OpType { + fn root_optype(&self) -> &OpType { let node_type = self.get_optype(self.root()); - // Sadly no way to do this at present - // debug_assert!(Self::RootHandle::can_hold(node_type.tag())); node_type } - /// Returns whether the node exists. + /// Returns `true` if the node exists in the HUGR. fn contains_node(&self, node: Self::Node) -> bool; - /// Validates that a node is valid in the graph. - #[inline] - fn valid_node(&self, node: Self::Node) -> bool { - self.contains_node(node) - } - - /// Validates that a node is a valid root descendant in the graph. - /// - /// To include the root node use [`HugrView::valid_node`] instead. - #[inline] - fn valid_non_root(&self, node: Self::Node) -> bool { - self.root() != node && self.valid_node(node) - } - /// Returns the parent of a node. - #[inline] - fn get_parent(&self, node: Self::Node) -> Option { - if !self.valid_non_root(node) { - return None; - }; - self.base_hugr() - .hierarchy - .parent(self.get_pg_index(node)) - .map(|index| self.get_node(index)) - } - - /// Returns the operation type of a node. - #[inline] - fn get_optype(&self, node: Self::Node) -> &OpType { - match self.contains_node(node) { - true => self.base_hugr().op_types.get(self.get_pg_index(node)), - false => &DEFAULT_OPTYPE, - } - } + fn get_parent(&self, node: Self::Node) -> Option; /// Returns the metadata associated with a node. #[inline] fn get_metadata(&self, node: Self::Node, key: impl AsRef) -> Option<&NodeMetadata> { match self.contains_node(node) { - true => self.get_node_metadata(node)?.get(key.as_ref()), + true => self.node_metadata_map(node).get(key.as_ref()), false => None, } } - /// Retrieve the complete metadata map for a node. - fn get_node_metadata(&self, node: Self::Node) -> Option<&NodeMetadataMap> { - if !self.valid_node(node) { - return None; - } - self.base_hugr() - .metadata - .get(self.get_pg_index(node)) - .as_ref() - } + /// Returns the operation type of a node. + /// + /// # Panics + /// + /// If the node is not in the graph. + fn get_optype(&self, node: Self::Node) -> &OpType; + + /// Returns the number of nodes in the HUGR. + fn num_nodes(&self) -> usize; - /// Returns the number of nodes in the hugr. - fn node_count(&self) -> usize; + /// Returns the number of edges in the HUGR. + fn num_edges(&self) -> usize; + + /// Number of ports in node for a given direction. + fn num_ports(&self, node: Self::Node, dir: Direction) -> usize; - /// Returns the number of edges in the hugr. - fn edge_count(&self) -> usize; + /// Number of inputs to a node. + /// Shorthand for [`num_ports`][HugrView::num_ports]`(node, Direction::Incoming)`. + #[inline] + fn num_inputs(&self, node: Self::Node) -> usize { + self.num_ports(node, Direction::Incoming) + } - /// Iterates over the nodes in the port graph. + /// Number of outputs from a node. + /// Shorthand for [`num_ports`][HugrView::num_ports]`(node, Direction::Outgoing)`. + #[inline] + fn num_outputs(&self, node: Self::Node) -> usize { + self.num_ports(node, Direction::Outgoing) + } + + /// Iterates over the all the nodes in the HUGR. + /// + /// This iterator returns every node in the HUGR, including those that are + /// not descendants from the root node. + /// + /// See [`HugrView::descendants`] and [`HugrView::children`] for more specific + /// iterators. fn nodes(&self) -> impl Iterator + Clone; /// Iterator over ports of node in a given direction. @@ -260,26 +240,15 @@ pub trait HugrView: HugrInternals { self.linked_ports(node, port).next().is_some() } - /// Number of ports in node for a given direction. - fn num_ports(&self, node: Self::Node, dir: Direction) -> usize; - - /// Number of inputs to a node. - /// Shorthand for [`num_ports`][HugrView::num_ports]`(node, Direction::Incoming)`. - #[inline] - fn num_inputs(&self, node: Self::Node) -> usize { - self.num_ports(node, Direction::Incoming) - } - - /// Number of outputs from a node. - /// Shorthand for [`num_ports`][HugrView::num_ports]`(node, Direction::Outgoing)`. - #[inline] - fn num_outputs(&self, node: Self::Node) -> usize { - self.num_ports(node, Direction::Outgoing) - } - - /// Return iterator over the direct children of node. + /// Returns an iterator over the direct children of node. fn children(&self, node: Self::Node) -> impl DoubleEndedIterator + Clone; + /// Returns an iterator over all the descendants of a node, + /// including the node itself. + /// + /// Yields the node itself first, followed by its children in breath-first order. + fn descendants(&self, node: Self::Node) -> impl Iterator + Clone; + /// Returns the first child of the specified node (if it is a parent). /// Useful because `x.children().next()` leaves x borrowed. fn first_child(&self, node: Self::Node) -> Option { @@ -334,13 +303,13 @@ pub trait HugrView: HugrInternals { /// In contrast to [`poly_func_type`][HugrView::poly_func_type], this /// method always return a concrete [`Signature`]. fn inner_function_type(&self) -> Option> { - self.root_type().inner_function_type() + self.root_optype().inner_function_type() } /// Returns the function type defined by this HUGR, i.e. `Some` iff the root is /// a [`FuncDecl`][crate::ops::FuncDecl] or [`FuncDefn`][crate::ops::FuncDefn]. fn poly_func_type(&self) -> Option { - match self.root_type() { + match self.root_optype() { OpType::FuncDecl(decl) => Some(decl.signature.clone()), OpType::FuncDefn(defn) => Some(defn.signature.clone()), _ => None, @@ -363,13 +332,7 @@ pub trait HugrView: HugrInternals { /// /// For a more detailed representation, use the [`HugrView::dot_string`] /// format instead. - fn mermaid_string(&self) -> String { - self.mermaid_string_with_config(RenderConfig { - node_indices: true, - port_offsets_in_edges: true, - type_labels_in_edges: true, - }) - } + fn mermaid_string(&self) -> String; /// Return the mermaid representation of the underlying hierarchical graph. /// @@ -378,35 +341,14 @@ pub trait HugrView: HugrInternals { /// /// For a more detailed representation, use the [`HugrView::dot_string`] /// format instead. - fn mermaid_string_with_config(&self, config: RenderConfig) -> String { - let hugr = self.base_hugr(); - let graph = self.portgraph(); - graph - .mermaid_format() - .with_hierarchy(&hugr.hierarchy) - .with_node_style(render::node_style(self, config)) - .with_edge_style(render::edge_style(self, config)) - .finish() - } + fn mermaid_string_with_config(&self, config: RenderConfig) -> String; /// Return the graphviz representation of the underlying graph and hierarchy side by side. /// /// For a simpler representation, use the [`HugrView::mermaid_string`] format instead. fn dot_string(&self) -> String where - Self: Sized, - { - let hugr = self.base_hugr(); - let graph = self.portgraph(); - let config = RenderConfig::default(); - graph - .dot_format() - .with_hierarchy(&hugr.hierarchy) - .with_node_style(render::node_style(self, config)) - .with_port_style(render::port_style(self, config)) - .with_edge_style(render::edge_style(self, config)) - .finish() - } + Self: Sized; /// If a node has a static input, return the source node. fn static_source(&self, node: Self::Node) -> Option { @@ -453,10 +395,9 @@ pub trait HugrView: HugrInternals { /// Returns the set of extensions used by the HUGR. /// - /// This set may contain extensions that are no longer required by the HUGR. - fn extensions(&self) -> &ExtensionRegistry { - &self.base_hugr().extensions - } + /// This set contains all extensions required to define the operations and + /// types in the HUGR. + fn extensions(&self) -> &ExtensionRegistry; /// Check the validity of the underlying HUGR. /// @@ -465,6 +406,7 @@ pub trait HugrView: HugrInternals { /// See [`HugrView::validate_no_extensions`] for a version that doesn't check /// extension requirements. fn validate(&self) -> Result<(), ValidationError> { + #[allow(deprecated)] self.base_hugr().validate() } @@ -474,6 +416,7 @@ pub trait HugrView: HugrInternals { /// /// For a more thorough check, use [`HugrView::validate`]. fn validate_no_extensions(&self) -> Result<(), ValidationError> { + #[allow(deprecated)] self.base_hugr().validate_no_extensions() } } @@ -526,18 +469,48 @@ impl ExtractHugr for &mut Hugr { impl HugrView for Hugr { #[inline] - fn contains_node(&self, node: Node) -> bool { - self.graph.contains_node(node.pg_index()) + fn root(&self) -> Self::Node { + self.root.into() + } + + #[inline] + fn contains_node(&self, node: Self::Node) -> bool { + self.graph.contains_node(node.into_portgraph()) + } + + #[inline] + fn get_parent(&self, node: Self::Node) -> Option { + if !check_valid_non_root(self, node) { + return None; + }; + self.hierarchy + .parent(self.to_portgraph_node(node)) + .map(|index| self.from_portgraph_node(index)) + } + + #[inline] + fn get_optype(&self, node: Node) -> &OpType { + // TODO: This currently fails because some methods get the optype of + // e.g. a parent outside a region view. We should be able to re-enable + // this once we add hugr entrypoints. + //panic_invalid_node(self, node); + self.op_types.get(self.to_portgraph_node(node)) + } + + #[inline] + fn num_nodes(&self) -> usize { + self.portgraph().node_count() } #[inline] - fn node_count(&self) -> usize { - self.graph.node_count() + fn num_edges(&self) -> usize { + self.portgraph().link_count() } #[inline] - fn edge_count(&self) -> usize { - self.graph.link_count() + fn num_ports(&self, node: Self::Node, dir: Direction) -> usize { + self.portgraph() + .num_ports(self.to_portgraph_node(node), dir) } #[inline] @@ -547,12 +520,16 @@ impl HugrView for Hugr { #[inline] fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator + Clone { - self.graph.port_offsets(node.pg_index(), dir).map_into() + self.graph + .port_offsets(node.into_portgraph(), dir) + .map_into() } #[inline] fn all_node_ports(&self, node: Node) -> impl Iterator + Clone { - self.graph.all_port_offsets(node.pg_index()).map_into() + self.graph + .all_port_offsets(node.into_portgraph()) + .map_into() } #[inline] @@ -565,7 +542,7 @@ impl HugrView for Hugr { let port = self .graph - .port_index(node.pg_index(), port.pg_offset()) + .port_index(node.into_portgraph(), port.pg_offset()) .unwrap(); self.graph.port_links(port).map(|(_, link)| { let port = link.port(); @@ -578,30 +555,72 @@ impl HugrView for Hugr { #[inline] fn node_connections(&self, node: Node, other: Node) -> impl Iterator + Clone { self.graph - .get_connections(node.pg_index(), other.pg_index()) + .get_connections(node.into_portgraph(), other.into_portgraph()) .map(|(p1, p2)| { [p1, p2].map(|link| self.graph.port_offset(link.port()).unwrap().into()) }) } #[inline] - fn num_ports(&self, node: Node, dir: Direction) -> usize { - self.graph.num_ports(node.pg_index(), dir) + fn children(&self, node: Self::Node) -> impl DoubleEndedIterator + Clone { + self.hierarchy + .children(self.to_portgraph_node(node)) + .map(|n| self.from_portgraph_node(n)) } #[inline] - fn children(&self, node: Node) -> impl DoubleEndedIterator + Clone { - self.hierarchy.children(node.pg_index()).map_into() + fn descendants(&self, node: Self::Node) -> impl Iterator + Clone { + self.hierarchy + .descendants(self.to_portgraph_node(node)) + .map(|n| self.from_portgraph_node(n)) } #[inline] fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator + Clone { - self.graph.neighbours(node.pg_index(), dir).map_into() + self.graph.neighbours(node.into_portgraph(), dir).map_into() } #[inline] fn all_neighbours(&self, node: Node) -> impl Iterator + Clone { - self.graph.all_neighbours(node.pg_index()).map_into() + self.graph.all_neighbours(node.into_portgraph()).map_into() + } + + fn mermaid_string(&self) -> String { + self.mermaid_string_with_config(RenderConfig { + node_indices: true, + port_offsets_in_edges: true, + type_labels_in_edges: true, + }) + } + + fn mermaid_string_with_config(&self, config: RenderConfig) -> String { + let graph = self.portgraph(); + graph + .mermaid_format() + .with_hierarchy(&self.hierarchy) + .with_node_style(render::node_style(self, config)) + .with_edge_style(render::edge_style(self, config)) + .finish() + } + + fn dot_string(&self) -> String + where + Self: Sized, + { + let graph = self.portgraph(); + let config = RenderConfig::default(); + graph + .dot_format() + .with_hierarchy(&self.hierarchy) + .with_node_style(render::node_style(self, config)) + .with_port_style(render::port_style(self, config)) + .with_edge_style(render::edge_style(self, config)) + .finish() + } + + #[inline] + fn extensions(&self) -> &ExtensionRegistry { + &self.extensions } } @@ -630,7 +649,7 @@ where hugr: &impl HugrView, ) -> impl Iterator { self.filter(move |(n, p)| { - let kind = hugr.get_optype(*n).port_kind(*p); + let kind = HugrView::get_optype(hugr, *n).port_kind(*p); predicate(kind) }) } @@ -642,3 +661,47 @@ where P: Into + Copy, { } + +/// Returns `true` if the node exists in the graph and is not the module at the hierarchy root. +pub(super) fn check_valid_non_root(hugr: &H, node: H::Node) -> bool { + hugr.contains_node(node) && node != hugr.root() +} + +/// Panic if [`HugrView::contains_node`] fails. +#[track_caller] +pub(super) fn panic_invalid_node(hugr: &H, node: H::Node) { + // TODO: When stacking hugr wrappers, this gets called for every layer. + // Should we `cfg!(debug_assertions)` this? Benchmark and see if it matters. + if !hugr.contains_node(node) { + panic!("Received an invalid node {node}.",); + } +} + +/// Panic if [`check_valid_non_root`] fails. +#[track_caller] +pub(super) fn panic_invalid_non_root(hugr: &H, node: H::Node) { + // TODO: When stacking hugr wrappers, this gets called for every layer. + // Should we `cfg!(debug_assertions)` this? Benchmark and see if it matters. + if !check_valid_non_root(hugr, node) { + panic!("Received an invalid non-root node {node}.",); + } +} + +/// Panic if [`HugrView::valid_node`] fails. +#[track_caller] +pub(super) fn panic_invalid_port( + hugr: &H, + node: Node, + port: impl Into, +) { + let port = port.into(); + // TODO: When stacking hugr wrappers, this gets called for every layer. + // Should we `cfg!(debug_assertions)` this? Benchmark and see if it matters. + if hugr + .portgraph() + .port_index(node.into_portgraph(), port.pg_offset()) + .is_none() + { + panic!("Received an invalid port {port} for node {node} while mutating a HUGR"); + } +} diff --git a/hugr-core/src/hugr/views/descendants.rs b/hugr-core/src/hugr/views/descendants.rs index 906dea3e4..e3ba29e2c 100644 --- a/hugr-core/src/hugr/views/descendants.rs +++ b/hugr-core/src/hugr/views/descendants.rs @@ -41,37 +41,44 @@ pub struct DescendantsGraph<'g, Root = Node> { _phantom: std::marker::PhantomData, } impl HugrView for DescendantsGraph<'_, Root> { + #[inline] + fn root(&self) -> Self::Node { + self.root + } + #[inline] fn contains_node(&self, node: Node) -> bool { - self.graph.contains_node(self.get_pg_index(node)) + self.graph.contains_node(self.to_portgraph_node(node)) } #[inline] - fn node_count(&self) -> usize { + fn num_nodes(&self) -> usize { self.graph.node_count() } #[inline] - fn edge_count(&self) -> usize { + fn num_edges(&self) -> usize { self.graph.link_count() } #[inline] fn nodes(&self) -> impl Iterator + Clone { - self.graph.nodes_iter().map(|index| self.get_node(index)) + self.graph + .nodes_iter() + .map(|index| self.from_portgraph_node(index)) } #[inline] fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator + Clone { self.graph - .port_offsets(self.get_pg_index(node), dir) + .port_offsets(self.to_portgraph_node(node), dir) .map_into() } #[inline] fn all_node_ports(&self, node: Node) -> impl Iterator + Clone { self.graph - .all_port_offsets(self.get_pg_index(node)) + .all_port_offsets(self.to_portgraph_node(node)) .map_into() } @@ -82,19 +89,19 @@ impl HugrView for DescendantsGraph<'_, Root> { ) -> impl Iterator + Clone { let port = self .graph - .port_index(self.get_pg_index(node), port.into().pg_offset()) + .port_index(self.to_portgraph_node(node), port.into().pg_offset()) .unwrap(); self.graph.port_links(port).map(|(_, link)| { let port: PortIndex = link.into(); let node = self.graph.port_node(port).unwrap(); let offset = self.graph.port_offset(port).unwrap(); - (self.get_node(node), offset.into()) + (self.from_portgraph_node(node), offset.into()) }) } fn node_connections(&self, node: Node, other: Node) -> impl Iterator + Clone { self.graph - .get_connections(self.get_pg_index(node), self.get_pg_index(other)) + .get_connections(self.to_portgraph_node(node), self.to_portgraph_node(other)) .map(|(p1, p2)| { [p1, p2].map(|link| { let offset = self.graph.port_offset(link).unwrap(); @@ -105,30 +112,46 @@ impl HugrView for DescendantsGraph<'_, Root> { #[inline] fn num_ports(&self, node: Node, dir: Direction) -> usize { - self.graph.num_ports(self.get_pg_index(node), dir) + self.graph.num_ports(self.to_portgraph_node(node), dir) } #[inline] fn children(&self, node: Node) -> impl DoubleEndedIterator + Clone { - let children = match self.graph.contains_node(self.get_pg_index(node)) { - true => self.base_hugr().hierarchy.children(self.get_pg_index(node)), + let hierarchy = self.hierarchy(); + let children = match self.graph.contains_node(self.to_portgraph_node(node)) { + true => hierarchy.children(self.to_portgraph_node(node)), false => portgraph::hierarchy::Children::default(), }; - children.map(|index| self.get_node(index)) + children.map(move |index| { + let _ = hierarchy; + self.from_portgraph_node(index) + }) } #[inline] fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator + Clone { self.graph - .neighbours(self.get_pg_index(node), dir) - .map(|index| self.get_node(index)) + .neighbours(self.to_portgraph_node(node), dir) + .map(|index| self.from_portgraph_node(index)) } #[inline] fn all_neighbours(&self, node: Node) -> impl Iterator + Clone { self.graph - .all_neighbours(self.get_pg_index(node)) - .map(|index| self.get_node(index)) + .all_neighbours(self.to_portgraph_node(node)) + .map(|index| self.from_portgraph_node(index)) + } + + delegate::delegate! { + to (&self.hugr) { + fn get_parent(&self, node: Self::Node) -> Option; + fn get_optype(&self, node: Self::Node) -> &crate::ops::OpType; + fn descendants(&self, node: Self::Node) -> impl Iterator + Clone; + fn mermaid_string(&self) -> String; + fn mermaid_string_with_config(&self, config: crate::hugr::views::render::RenderConfig) -> String; + fn dot_string(&self) -> String; + fn extensions(&self) -> &crate::extension::ExtensionRegistry; + } } } @@ -138,10 +161,11 @@ where { fn try_new(hugr: &'a impl HugrView, root: Node) -> Result { check_tag::(hugr, root)?; + #[allow(deprecated)] let hugr = hugr.base_hugr(); Ok(Self { root, - graph: RegionGraph::new(&hugr.graph, &hugr.hierarchy, hugr.get_pg_index(root)), + graph: RegionGraph::new(&hugr.graph, &hugr.hierarchy, hugr.to_portgraph_node(root)), hugr, _phantom: std::marker::PhantomData, }) @@ -166,28 +190,39 @@ where &self.graph } - fn base_hugr(&self) -> &Hugr { - self.hugr + #[inline] + fn region_portgraph( + &self, + parent: Self::Node, + ) -> portgraph::view::FlatRegion< + '_, + impl portgraph::view::LinkView + Clone + '_, + > { + self.hugr.region_portgraph(parent) } #[inline] - fn root_node(&self) -> Node { - self.root + fn hierarchy(&self) -> &portgraph::Hierarchy { + self.hugr.hierarchy() } #[inline] - fn get_pg_index(&self, node: impl NodeHandle) -> portgraph::NodeIndex { - self.hugr.get_pg_index(node) + fn to_portgraph_node(&self, node: impl NodeHandle) -> portgraph::NodeIndex { + self.hugr.to_portgraph_node(node) } #[inline] - fn get_node(&self, index: portgraph::NodeIndex) -> Node { - self.hugr.get_node(index) + fn from_portgraph_node(&self, index: portgraph::NodeIndex) -> Node { + self.hugr.from_portgraph_node(index) } fn node_metadata_map(&self, node: Self::Node) -> &crate::hugr::NodeMetadataMap { self.hugr.node_metadata_map(node) } + + fn base_hugr(&self) -> &Hugr { + self.hugr + } } #[cfg(test)] @@ -245,7 +280,7 @@ pub(super) mod test { let region: DescendantsGraph = DescendantsGraph::try_new(&hugr, def)?; let def_io = region.get_io(def).unwrap(); - assert_eq!(region.node_count(), 7); + assert_eq!(region.num_nodes(), 7); assert!(region.nodes().all(|n| n == def || hugr.get_parent(n) == Some(def) || hugr.get_parent(n) == Some(inner))); @@ -265,8 +300,8 @@ pub(super) mod test { inner_region.inner_function_type().map(Cow::into_owned), Some(Signature::new(vec![usize_t()], vec![usize_t()])) ); - assert_eq!(inner_region.node_count(), 3); - assert_eq!(inner_region.edge_count(), 1); + assert_eq!(inner_region.num_nodes(), 3); + assert_eq!(inner_region.num_edges(), 1); assert_eq!(inner_region.children(inner).count(), 2); assert_eq!(inner_region.children(hugr.root()).count(), 0); assert_eq!( @@ -315,8 +350,8 @@ pub(super) mod test { let region: DescendantsGraph = DescendantsGraph::try_new(&hugr, def)?; - assert_eq!(region.node_count(), extracted.node_count()); - assert_eq!(region.root_type(), extracted.root_type()); + assert_eq!(region.num_nodes(), extracted.num_nodes()); + assert_eq!(region.root_optype(), extracted.root_optype()); Ok(()) } diff --git a/hugr-core/src/hugr/views/impls.rs b/hugr-core/src/hugr/views/impls.rs index 440df9480..6cd1d7631 100644 --- a/hugr-core/src/hugr/views/impls.rs +++ b/hugr-core/src/hugr/views/impls.rs @@ -12,12 +12,13 @@ macro_rules! hugr_internal_methods { delegate::delegate! { to ({let $arg=self; $e}) { fn portgraph(&self) -> Self::Portgraph<'_>; - fn hierarchy(&self) -> Cow<'_, portgraph::Hierarchy>; - fn base_hugr(&self) -> &crate::Hugr; - fn root_node(&self) -> Self::Node; - fn get_pg_index(&self, node: impl crate::ops::handle::NodeHandle) -> portgraph::NodeIndex; - fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node; + fn region_portgraph(&self, parent: Self::Node) -> portgraph::view::FlatRegion<'_, impl portgraph::view::LinkView + Clone + '_>; + fn hierarchy(&self) -> &portgraph::Hierarchy; + fn to_portgraph_node(&self, node: impl crate::ops::handle::NodeHandle) -> portgraph::NodeIndex; + fn from_portgraph_node(&self, index: portgraph::NodeIndex) -> Self::Node; fn node_metadata_map(&self, node: Self::Node) -> &crate::hugr::NodeMetadataMap; + #[allow(deprecated)] + fn base_hugr(&self) -> &crate::Hugr; } } }; @@ -30,34 +31,23 @@ macro_rules! hugr_view_methods { delegate::delegate! { to ({let $arg=self; $e}) { fn root(&self) -> Self::Node; - fn root_type(&self) -> &crate::ops::OpType; + fn root_optype(&self) -> &crate::ops::OpType; fn contains_node(&self, node: Self::Node) -> bool; - fn valid_node(&self, node: Self::Node) -> bool; - fn valid_non_root(&self, node: Self::Node) -> bool; fn get_parent(&self, node: Self::Node) -> Option; - fn get_optype(&self, node: Self::Node) -> &crate::ops::OpType; fn get_metadata(&self, node: Self::Node, key: impl AsRef) -> Option<&crate::hugr::NodeMetadata>; - fn get_node_metadata(&self, node: Self::Node) -> Option<&crate::hugr::NodeMetadataMap>; - fn node_count(&self) -> usize; - fn edge_count(&self) -> usize; + fn get_optype(&self, node: Self::Node) -> &crate::ops::OpType; + fn num_nodes(&self) -> usize; + fn num_edges(&self) -> usize; + fn num_ports(&self, node: Self::Node, dir: crate::Direction) -> usize; + fn num_inputs(&self, node: Self::Node) -> usize; + fn num_outputs(&self, node: Self::Node) -> usize; fn nodes(&self) -> impl Iterator + Clone; fn node_ports(&self, node: Self::Node, dir: crate::Direction) -> impl Iterator + Clone; fn node_outputs(&self, node: Self::Node) -> impl Iterator + Clone; fn node_inputs(&self, node: Self::Node) -> impl Iterator + Clone; fn all_node_ports(&self, node: Self::Node) -> impl Iterator + Clone; - fn linked_ports( - &self, - node: Self::Node, - port: impl Into, - ) -> impl Iterator + Clone; - fn all_linked_ports( - &self, - node: Self::Node, - dir: crate::Direction, - ) -> itertools::Either< - impl Iterator, - impl Iterator, - >; + fn linked_ports(&self, node: Self::Node, port: impl Into) -> impl Iterator + Clone; + fn all_linked_ports(&self, node: Self::Node, dir: crate::Direction) -> itertools::Either, impl Iterator>; fn all_linked_outputs(&self, node: Self::Node) -> impl Iterator; fn all_linked_inputs(&self, node: Self::Node) -> impl Iterator; fn single_linked_port(&self, node: Self::Node, port: impl Into) -> Option<(Self::Node, crate::Port)>; @@ -67,31 +57,19 @@ macro_rules! hugr_view_methods { fn linked_inputs(&self, node: Self::Node, port: impl Into) -> impl Iterator; fn node_connections(&self, node: Self::Node, other: Self::Node) -> impl Iterator + Clone; fn is_linked(&self, node: Self::Node, port: impl Into) -> bool; - fn num_ports(&self, node: Self::Node, dir: crate::Direction) -> usize; - fn num_inputs(&self, node: Self::Node) -> usize; - fn num_outputs(&self, node: Self::Node) -> usize; fn children(&self, node: Self::Node) -> impl DoubleEndedIterator + Clone; + fn descendants(&self, node: Self::Node) -> impl Iterator + Clone; fn first_child(&self, node: Self::Node) -> Option; fn neighbours(&self, node: Self::Node, dir: crate::Direction) -> impl Iterator + Clone; fn input_neighbours(&self, node: Self::Node) -> impl Iterator + Clone; fn output_neighbours(&self, node: Self::Node) -> impl Iterator + Clone; fn all_neighbours(&self, node: Self::Node) -> impl Iterator + Clone; - fn get_io(&self, node: Self::Node) -> Option<[Self::Node; 2]>; - fn inner_function_type(&self) -> Option>; - fn poly_func_type(&self) -> Option; - // TODO: cannot use delegate here. `PetgraphWrapper` is a thin - // wrapper around `Self`, so falling back to the default impl - // should be harmless. - // fn as_petgraph(&self) -> PetgraphWrapper<'_, Self>; fn mermaid_string(&self) -> String; fn mermaid_string_with_config(&self, config: crate::hugr::views::render::RenderConfig) -> String; fn dot_string(&self) -> String; fn static_source(&self, node: Self::Node) -> Option; fn static_targets(&self, node: Self::Node) -> Option>; - fn signature(&self, node: Self::Node) -> Option>; fn value_types(&self, node: Self::Node, dir: crate::Direction) -> impl Iterator; - fn in_value_types(&self, node: Self::Node) -> impl Iterator; - fn out_value_types(&self, node: Self::Node) -> impl Iterator; fn extensions(&self) -> &crate::extension::ExtensionRegistry; fn validate(&self) -> Result<(), crate::hugr::ValidationError>; fn validate_no_extensions(&self) -> Result<(), crate::hugr::ValidationError>; @@ -128,6 +106,9 @@ macro_rules! hugr_mut_methods { ($arg:ident, $e:expr) => { delegate::delegate! { to ({let $arg=self; $e}) { + fn get_metadata_mut(&mut self, node: Self::Node, key: impl AsRef) -> &mut crate::hugr::NodeMetadata; + fn set_metadata(&mut self, node: Self::Node, key: impl AsRef, metadata: impl Into); + fn remove_metadata(&mut self, node: Self::Node, key: impl AsRef); fn add_node_with_parent(&mut self, parent: Self::Node, op: impl Into) -> Self::Node; fn add_node_before(&mut self, sibling: Self::Node, nodetype: impl Into) -> Self::Node; fn add_node_after(&mut self, sibling: Self::Node, op: impl Into) -> Self::Node; @@ -140,6 +121,8 @@ macro_rules! hugr_mut_methods { fn insert_hugr(&mut self, root: Self::Node, other: crate::Hugr) -> crate::hugr::hugrmut::InsertionResult; fn insert_from_view(&mut self, root: Self::Node, other: &Other) -> crate::hugr::hugrmut::InsertionResult; fn insert_subgraph(&mut self, root: Self::Node, other: &Other, subgraph: &crate::hugr::views::SiblingSubgraph) -> std::collections::HashMap; + fn use_extension(&mut self, extension: impl Into>); + fn use_extensions(&mut self, registry: impl IntoIterator) where crate::extension::ExtensionRegistry: Extend; } } }; diff --git a/hugr-core/src/hugr/views/petgraph.rs b/hugr-core/src/hugr/views/petgraph.rs index 17c3e0062..22da47f0a 100644 --- a/hugr-core/src/hugr/views/petgraph.rs +++ b/hugr-core/src/hugr/views/petgraph.rs @@ -55,7 +55,7 @@ where T: HugrView, { fn node_count(&self) -> usize { - HugrView::node_count(self.hugr) + HugrView::num_nodes(self.hugr) } } @@ -64,15 +64,15 @@ where T: HugrView, { fn node_bound(&self) -> usize { - HugrView::node_count(self.hugr) + HugrView::num_nodes(self.hugr) } fn to_index(&self, ix: Self::NodeId) -> usize { - self.hugr.get_pg_index(ix).into() + self.hugr.to_portgraph_node(ix).into() } fn from_index(&self, ix: usize) -> Self::NodeId { - self.hugr.get_node(portgraph::NodeIndex::new(ix)) + self.hugr.from_portgraph_node(portgraph::NodeIndex::new(ix)) } } @@ -81,7 +81,7 @@ where T: HugrView, { fn edge_count(&self) -> usize { - HugrView::edge_count(self.hugr) + HugrView::num_edges(self.hugr) } } @@ -233,7 +233,7 @@ mod test { assert_eq!(wrapper.node_bound(), 5); assert_eq!(wrapper.edge_count(), 7); - let cx1_index = cx1.node().pg_index().index(); + let cx1_index = cx1.node().into_portgraph().index(); assert_eq!(wrapper.to_index(cx1.node()), cx1_index); assert_eq!(wrapper.from_index(cx1_index), cx1.node()); diff --git a/hugr-core/src/hugr/views/render.rs b/hugr-core/src/hugr/views/render.rs index ecb8549c0..43530e4c1 100644 --- a/hugr-core/src/hugr/views/render.rs +++ b/hugr-core/src/hugr/views/render.rs @@ -36,7 +36,7 @@ pub(super) fn node_style( config: RenderConfig, ) -> Box NodeStyle + '_> { fn node_name(h: &H, n: NodeIndex) -> String { - match h.get_optype(h.get_node(n)) { + match h.get_optype(h.from_portgraph_node(n)) { OpType::FuncDecl(f) => format!("FuncDecl: \"{}\"", f.name), OpType::FuncDefn(f) => format!("FuncDefn: \"{}\"", f.name), op => op.name().to_string(), @@ -45,14 +45,14 @@ pub(super) fn node_style( if config.node_indices { Box::new(move |n| { - NodeStyle::Box(format!( + NodeStyle::boxed(format!( "({ni}) {name}", ni = n.index(), name = node_name(h, n) )) }) } else { - Box::new(move |n| NodeStyle::Box(node_name(h, n))) + Box::new(move |n| NodeStyle::boxed(node_name(h, n))) } } @@ -64,7 +64,7 @@ pub(super) fn port_style( let graph = h.portgraph(); Box::new(move |port| { let node = graph.port_node(port).unwrap(); - let optype = h.get_optype(h.get_node(node)); + let optype = h.get_optype(h.from_portgraph_node(node)); let offset = graph.port_offset(port).unwrap(); match optype.port_kind(offset).unwrap() { EdgeKind::Function(pf) => PortStyle::new(html_escape::encode_text(&format!("{}", pf))), @@ -95,7 +95,7 @@ pub(super) fn edge_style( let graph = h.portgraph(); Box::new(move |src, tgt| { let src_node = graph.port_node(src).unwrap(); - let src_optype = h.get_optype(h.get_node(src_node)); + let src_optype = h.get_optype(h.from_portgraph_node(src_node)); let src_offset = graph.port_offset(src).unwrap(); let tgt_offset = graph.port_offset(tgt).unwrap(); diff --git a/hugr-core/src/hugr/views/sibling.rs b/hugr-core/src/hugr/views/sibling.rs index ac31d2695..44e29ab1a 100644 --- a/hugr-core/src/hugr/views/sibling.rs +++ b/hugr-core/src/hugr/views/sibling.rs @@ -51,15 +51,19 @@ pub struct SiblingGraph<'g, Root = Node> { macro_rules! impl_base_members { () => { #[inline] - fn node_count(&self) -> usize { - self.base_hugr() - .hierarchy - .child_count(self.get_pg_index(self.root)) + fn root(&self) -> Self::Node { + self.root + } + + #[inline] + fn num_nodes(&self) -> usize { + self.hierarchy() + .child_count(self.to_portgraph_node(self.root)) + 1 } #[inline] - fn edge_count(&self) -> usize { + fn num_edges(&self) -> usize { // Faster implementation than filtering all the nodes in the internal graph. self.nodes() .map(|n| self.output_neighbours(n).count()) @@ -70,10 +74,9 @@ macro_rules! impl_base_members { fn nodes(&self) -> impl Iterator + Clone { // Faster implementation than filtering all the nodes in the internal graph. let children = self - .base_hugr() - .hierarchy - .children(self.get_pg_index(self.root)) - .map(|n| self.get_node(n)); + .hierarchy() + .children(self.to_portgraph_node(self.root)) + .map(|n| self.from_portgraph_node(n)); iter::once(self.root).chain(children) } @@ -83,10 +86,41 @@ macro_rules! impl_base_members { ) -> impl DoubleEndedIterator + Clone { // Same as SiblingGraph let children = match node == self.root { - true => self.base_hugr().hierarchy.children(self.get_pg_index(node)), + true => self.hierarchy().children(self.to_portgraph_node(node)), false => portgraph::hierarchy::Children::default(), }; - children.map(|n| self.get_node(n)) + children.map(|n| self.from_portgraph_node(n)) + } + + fn get_optype(&self, node: Self::Node) -> &crate::ops::OpType { + self.hugr.get_optype(node) + } + + fn extensions(&self) -> &crate::extension::ExtensionRegistry { + self.hugr.extensions() + } + + fn get_parent(&self, node: Self::Node) -> Option { + match self.hugr.get_parent(node) { + Some(parent) if parent == self.root => Some(self.root), + _ => None, + } + } + + fn descendants(&self, node: Self::Node) -> impl Iterator + Clone { + if node == self.root { + Either::Left(self.hugr.descendants(node)) + } else { + Either::Right(iter::empty()) + } + } + + delegate::delegate! { + to (&self.hugr) { + fn mermaid_string(&self) -> String; + fn mermaid_string_with_config(&self, config: crate::hugr::views::render::RenderConfig) -> String; + fn dot_string(&self) -> String; + } } }; } @@ -96,20 +130,20 @@ impl HugrView for SiblingGraph<'_, Root> { #[inline] fn contains_node(&self, node: Node) -> bool { - self.graph.contains_node(self.get_pg_index(node)) + self.graph.contains_node(self.to_portgraph_node(node)) } #[inline] fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator + Clone { self.graph - .port_offsets(self.get_pg_index(node), dir) + .port_offsets(self.to_portgraph_node(node), dir) .map_into() } #[inline] fn all_node_ports(&self, node: Node) -> impl Iterator + Clone { self.graph - .all_port_offsets(self.get_pg_index(node)) + .all_port_offsets(self.to_portgraph_node(node)) .map_into() } @@ -120,47 +154,52 @@ impl HugrView for SiblingGraph<'_, Root> { ) -> impl Iterator + Clone { let port = self .graph - .port_index(self.get_pg_index(node), port.into().pg_offset()) + .port_index(self.to_portgraph_node(node), port.into().pg_offset()) .unwrap(); self.graph.port_links(port).map(|(_, link)| { let node = self.graph.port_node(link).unwrap(); let offset = self.graph.port_offset(link).unwrap(); - (self.get_node(node), offset.into()) + (self.from_portgraph_node(node), offset.into()) }) } fn node_connections(&self, node: Node, other: Node) -> impl Iterator + Clone { self.graph - .get_connections(self.get_pg_index(node), self.get_pg_index(other)) + .get_connections(self.to_portgraph_node(node), self.to_portgraph_node(other)) .map(|(p1, p2)| [p1, p2].map(|link| self.graph.port_offset(link).unwrap().into())) } #[inline] fn num_ports(&self, node: Node, dir: Direction) -> usize { - self.graph.num_ports(self.get_pg_index(node), dir) + self.graph.num_ports(self.to_portgraph_node(node), dir) } #[inline] fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator + Clone { self.graph - .neighbours(self.get_pg_index(node), dir) - .map(|n| self.get_node(n)) + .neighbours(self.to_portgraph_node(node), dir) + .map(|n| self.from_portgraph_node(n)) } #[inline] fn all_neighbours(&self, node: Node) -> impl Iterator + Clone { self.graph - .all_neighbours(self.get_pg_index(node)) - .map(|n| self.get_node(n)) + .all_neighbours(self.to_portgraph_node(node)) + .map(|n| self.from_portgraph_node(n)) } } impl<'a, Root: NodeHandle> SiblingGraph<'a, Root> { fn new_unchecked(hugr: &'a impl HugrView, root: Node) -> Self { + #[allow(deprecated)] let hugr = hugr.base_hugr(); Self { root, - graph: FlatRegionGraph::new(&hugr.graph, &hugr.hierarchy, hugr.get_pg_index(root)), + graph: FlatRegionGraph::new_with_root( + &hugr.graph, + &hugr.hierarchy, + hugr.to_portgraph_node(root), + ), hugr, _phantom: std::marker::PhantomData, } @@ -173,7 +212,7 @@ where { fn try_new(hugr: &'a impl HugrView, root: Node) -> Result { assert!( - hugr.valid_node(root), + hugr.contains_node(root), "Cannot create a sibling graph from an invalid node {}.", root ); @@ -200,23 +239,34 @@ where } #[inline] - fn base_hugr(&self) -> &Hugr { - self.hugr + fn region_portgraph( + &self, + parent: Self::Node, + ) -> portgraph::view::FlatRegion< + '_, + impl portgraph::view::LinkView + Clone + '_, + > { + self.hugr.region_portgraph(parent) } #[inline] - fn root_node(&self) -> Node { - self.root + fn hierarchy(&self) -> &portgraph::Hierarchy { + self.hugr.hierarchy() } #[inline] - fn get_pg_index(&self, node: impl NodeHandle) -> portgraph::NodeIndex { - self.hugr.get_pg_index(node) + fn base_hugr(&self) -> &Hugr { + self.hugr } #[inline] - fn get_node(&self, index: portgraph::NodeIndex) -> Node { - self.hugr.get_node(index) + fn to_portgraph_node(&self, node: impl NodeHandle) -> portgraph::NodeIndex { + self.hugr.to_portgraph_node(node) + } + + #[inline] + fn from_portgraph_node(&self, index: portgraph::NodeIndex) -> Node { + self.hugr.from_portgraph_node(index) } #[inline] @@ -272,37 +322,50 @@ impl<'g, H: HugrMut, Root: NodeHandle> HugrInternals for SiblingMut<'g, #[inline] fn portgraph(&self) -> Self::Portgraph<'_> { - FlatRegionGraph::new( + FlatRegionGraph::new_with_root( + #[allow(deprecated)] &self.base_hugr().graph, - &self.base_hugr().hierarchy, - self.get_pg_index(self.root), + self.hierarchy(), + self.to_portgraph_node(self.root), ) } #[inline] - fn base_hugr(&self) -> &Hugr { - self.hugr.base_hugr() + fn region_portgraph( + &self, + parent: Self::Node, + ) -> portgraph::view::FlatRegion< + '_, + impl portgraph::view::LinkView + Clone + '_, + > { + self.hugr.region_portgraph(parent) } #[inline] - fn root_node(&self) -> Self::Node { - self.root + fn hierarchy(&self) -> &portgraph::Hierarchy { + self.hugr.hierarchy() } #[inline] - fn get_pg_index(&self, node: impl NodeHandle) -> portgraph::NodeIndex { - self.hugr.get_pg_index(node) + fn to_portgraph_node(&self, node: impl NodeHandle) -> portgraph::NodeIndex { + self.hugr.to_portgraph_node(node) } #[inline] - fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node { - self.hugr.get_node(index) + fn from_portgraph_node(&self, index: portgraph::NodeIndex) -> Self::Node { + self.hugr.from_portgraph_node(index) } #[inline] fn node_metadata_map(&self, node: Self::Node) -> &NodeMetadataMap { self.hugr.node_metadata_map(node) } + + #[inline] + fn base_hugr(&self) -> &Hugr { + #[allow(deprecated)] + self.hugr.base_hugr() + } } impl> HugrView for SiblingMut<'_, H, Root> { @@ -435,7 +498,7 @@ mod test { { let def_io = region.get_io(def).unwrap(); - assert_eq!(region.node_count(), 5); + assert_eq!(region.num_nodes(), 5); assert_eq!(region.portgraph().node_count(), 5); assert!(region.nodes().all(|n| n == def || hugr.get_parent(n) == Some(def) @@ -455,8 +518,8 @@ mod test { inner_region.inner_function_type().map(Cow::into_owned), Some(Signature::new(vec![usize_t()], vec![usize_t()])) ); - assert_eq!(inner_region.node_count(), 3); - assert_eq!(inner_region.edge_count(), 1); + assert_eq!(inner_region.num_nodes(), 3); + assert_eq!(inner_region.num_edges(), 1); assert_eq!(inner_region.children(inner).count(), 2); assert_eq!(inner_region.children(hugr.root()).count(), 0); assert_eq!( @@ -591,8 +654,8 @@ mod test { let region: SiblingGraph = SiblingGraph::try_new(&hugr, inner)?; - assert_eq!(region.node_count(), extracted.node_count()); - assert_eq!(region.root_type(), extracted.root_type()); + assert_eq!(region.num_nodes(), extracted.num_nodes()); + assert_eq!(region.root_optype(), extracted.root_optype()); Ok(()) } diff --git a/hugr-core/src/hugr/views/sibling_subgraph.rs b/hugr-core/src/hugr/views/sibling_subgraph.rs index 680d58a03..7fd2b9f54 100644 --- a/hugr-core/src/hugr/views/sibling_subgraph.rs +++ b/hugr-core/src/hugr/views/sibling_subgraph.rs @@ -194,7 +194,7 @@ impl SiblingSubgraph { let subpg = Subgraph::new_subgraph(pg.clone(), make_boundary(hugr, &inputs, &outputs)); let nodes = subpg .nodes_iter() - .map(|index| hugr.get_node(index)) + .map(|index| hugr.from_portgraph_node(index)) .collect_vec(); validate_subgraph(hugr, &nodes, &inputs, &outputs)?; @@ -525,7 +525,7 @@ fn make_boundary<'a, N: HugrNode>( ) -> Boundary { let to_pg_index = |n: N, p: Port| { hugr.portgraph() - .port_index(hugr.get_pg_index(n), p.pg_offset()) + .port_index(hugr.to_portgraph_node(n), p.pg_offset()) .unwrap() }; Boundary::new( @@ -1010,9 +1010,9 @@ mod tests { assert_eq!(rep.subgraph().nodes().len(), 4); - assert_eq!(hugr.node_count(), 8); // Module + Def + In + CX + Rz + Const + LoadConst + Out + assert_eq!(hugr.num_nodes(), 8); // Module + Def + In + CX + Rz + Const + LoadConst + Out hugr.apply_patch(rep).unwrap(); - assert_eq!(hugr.node_count(), 4); // Module + Def + In + Out + assert_eq!(hugr.num_nodes(), 4); // Module + Def + In + Out Ok(()) } diff --git a/hugr-core/src/ops/constant.rs b/hugr-core/src/ops/constant.rs index 794e6eaaa..6aad904cb 100644 --- a/hugr-core/src/ops/constant.rs +++ b/hugr-core/src/ops/constant.rs @@ -364,7 +364,7 @@ pub enum ConstTypeError { /// Hugrs (even functions) inside Consts must be monomorphic fn mono_fn_type(h: &Hugr) -> Result, ConstTypeError> { let err = || ConstTypeError::NotMonomorphicFunction { - hugr_root_type: h.root_type().clone(), + hugr_root_type: h.root_optype().clone(), }; if let Some(pf) = h.poly_func_type() { match pf.try_into() { diff --git a/hugr-llvm/src/emit/ops.rs b/hugr-llvm/src/emit/ops.rs index 9cb6f9b10..76bb2bb09 100644 --- a/hugr-llvm/src/emit/ops.rs +++ b/hugr-llvm/src/emit/ops.rs @@ -5,7 +5,6 @@ use hugr_core::ops::{ }; use hugr_core::Node; use hugr_core::{ - hugr::views::SiblingGraph, types::{SumType, Type, TypeEnum}, HugrView, NodeIndex, }; @@ -71,34 +70,33 @@ where debug_assert!(i.out_value_types().count() == self.inputs.as_ref().unwrap().len()); debug_assert!(o.in_value_types().count() == self.outputs.as_ref().unwrap().len()); - let region: SiblingGraph = node.try_new_hierarchy_view().unwrap(); - Topo::new(®ion.as_petgraph()) - .iter(®ion.as_petgraph()) - .filter(|x| (*x != node.node())) - .map(|x| node.hugr().fat_optype(x)) - .try_for_each(|node| { - let inputs_rmb = context.node_ins_rmb(node)?; - let inputs = inputs_rmb.read(context.builder(), [])?; - let outputs = context.node_outs_rmb(node)?.promise(); - match node.as_ref() { - OpType::Input(_) => { - let i = self.take_input()?; - outputs.finish(context.builder(), i) - } - OpType::Output(_) => { - let o = self.take_output()?; - o.finish(context.builder(), inputs) - } - _ => emit_optype( - context, - EmitOpArgs { - node, - inputs, - outputs, - }, - ), + let region_graph = node.hugr().region_portgraph(node.node()); + let topo = Topo::new(®ion_graph); + for n in topo.iter(®ion_graph) { + let node = node.hugr().fat_optype(node.hugr().from_portgraph_node(n)); + let inputs_rmb = context.node_ins_rmb(node)?; + let inputs = inputs_rmb.read(context.builder(), [])?; + let outputs = context.node_outs_rmb(node)?.promise(); + match node.as_ref() { + OpType::Input(_) => { + let i = self.take_input()?; + outputs.finish(context.builder(), i)?; } - }) + OpType::Output(_) => { + let o = self.take_output()?; + o.finish(context.builder(), inputs)?; + } + _ => emit_optype( + context, + EmitOpArgs { + node, + inputs, + outputs, + }, + )?, + } + } + Ok(()) } } diff --git a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__emit_hugr_call_indirect@pre-mem2reg@llvm14.snap b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__emit_hugr_call_indirect@pre-mem2reg@llvm14.snap index b3283ee1b..124f36b53 100644 --- a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__emit_hugr_call_indirect@pre-mem2reg@llvm14.snap +++ b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__emit_hugr_call_indirect@pre-mem2reg@llvm14.snap @@ -20,14 +20,14 @@ entry_block: ; preds = %alloca_block define i1 @_hl.main_unary.6(i1 %0) { alloca_block: %"0" = alloca i1, align 1 - %"7_0" = alloca i1, align 1 %"9_0" = alloca i1 (i1)*, align 8 + %"7_0" = alloca i1, align 1 %"10_0" = alloca i1, align 1 br label %entry_block entry_block: ; preds = %alloca_block - store i1 %0, i1* %"7_0", align 1 store i1 (i1)* @_hl.main_unary.6, i1 (i1)** %"9_0", align 8 + store i1 %0, i1* %"7_0", align 1 %"9_01" = load i1 (i1)*, i1 (i1)** %"9_0", align 8 %"7_02" = load i1, i1* %"7_0", align 1 %1 = call i1 %"9_01"(i1 %"7_02") @@ -42,17 +42,17 @@ define { i1, i1 } @_hl.main_binary.11(i1 %0, i1 %1) { alloca_block: %"0" = alloca i1, align 1 %"1" = alloca i1, align 1 + %"14_0" = alloca { i1, i1 } (i1, i1)*, align 8 %"12_0" = alloca i1, align 1 %"12_1" = alloca i1, align 1 - %"14_0" = alloca { i1, i1 } (i1, i1)*, align 8 %"15_0" = alloca i1, align 1 %"15_1" = alloca i1, align 1 br label %entry_block entry_block: ; preds = %alloca_block + store { i1, i1 } (i1, i1)* @_hl.main_binary.11, { i1, i1 } (i1, i1)** %"14_0", align 8 store i1 %0, i1* %"12_0", align 1 store i1 %1, i1* %"12_1", align 1 - store { i1, i1 } (i1, i1)* @_hl.main_binary.11, { i1, i1 } (i1, i1)** %"14_0", align 8 %"14_01" = load { i1, i1 } (i1, i1)*, { i1, i1 } (i1, i1)** %"14_0", align 8 %"12_02" = load i1, i1* %"12_0", align 1 %"12_13" = load i1, i1* %"12_1", align 1 diff --git a/hugr-llvm/src/utils/fat.rs b/hugr-llvm/src/utils/fat.rs index dec866b4e..5deeb4bf0 100644 --- a/hugr-llvm/src/utils/fat.rs +++ b/hugr-llvm/src/utils/fat.rs @@ -47,7 +47,7 @@ where /// Note that while we do check the type of the node's `get_optype`, we /// do not verify that it is actually equal to `ot`. pub fn new(hugr: &'hugr H, node: H::Node, #[allow(unused)] ot: &OT) -> Self { - assert!(hugr.valid_node(node)); + assert!(hugr.contains_node(node)); assert!(TryInto::<&OT>::try_into(hugr.get_optype(node)).is_ok()); // We don't actually check `ot == hugr.get_optype(node)` so as to not require OT: PartialEq` Self { @@ -63,7 +63,7 @@ where /// If the node is invalid, or if its `get_optype` is not `OT`, returns /// `None`. pub fn try_new(hugr: &'hugr H, node: H::Node) -> Option { - (hugr.valid_node(node)).then_some(())?; + (hugr.contains_node(node)).then_some(())?; Some(Self::new( hugr, node, @@ -99,7 +99,7 @@ impl<'hugr, H: HugrView + ?Sized> FatNode<'hugr, OpType, H, H::Node> { /// /// Panics if the node is not valid in the [Hugr]. pub fn new_optype(hugr: &'hugr H, node: H::Node) -> Self { - assert!(hugr.valid_node(node)); + assert!(hugr.contains_node(node)); FatNode::new(hugr, node, hugr.get_optype(node)) } diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index ff5cd93a5..3a296fc0b 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -160,7 +160,7 @@ fn test_big() { .unwrap(); let mut h = build.finish_hugr_with_outputs(to_int.outputs()).unwrap(); - assert_eq!(h.node_count(), 8); + assert_eq!(h.num_nodes(), 8); constant_fold_pass(&mut h); @@ -333,7 +333,7 @@ fn test_const_fold_to_nonfinite() { assert_fully_folded_with(&h0, |v| { v.get_custom_value::().unwrap().value() == 1.0 }); - assert_eq!(h0.node_count(), 5); + assert_eq!(h0.num_nodes(), 5); // HUGR computing 1.0 / 0.0 let mut build = DFGBuilder::new(noargfn(vec![float64_type()])).unwrap(); @@ -342,7 +342,7 @@ fn test_const_fold_to_nonfinite() { let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap(); let mut h1 = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); constant_fold_pass(&mut h1); - assert_eq!(h1.node_count(), 8); + assert_eq!(h1.num_nodes(), 8); } #[test] @@ -1362,7 +1362,7 @@ fn test_tail_loop_unknown() { constant_fold_pass(&mut h); // Must keep the loop, even though we know the output, in case the output doesn't happen - assert_eq!(h.node_count(), 12); + assert_eq!(h.num_nodes(), 12); let tl = h .nodes() .filter(|n| h.get_optype(*n).is_tail_loop()) diff --git a/hugr-passes/src/dead_code.rs b/hugr-passes/src/dead_code.rs index d92fed134..25f6cf798 100644 --- a/hugr-passes/src/dead_code.rs +++ b/hugr-passes/src/dead_code.rs @@ -145,6 +145,7 @@ impl DeadCodeElimPass { if let Some(res) = cache.get(&n) { return *res; } + #[allow(deprecated)] let res = match self.preserve_callback.as_ref()(h.base_hugr(), n) { PreserveNode::MustKeep => true, PreserveNode::CanRemoveIgnoringChildren => false, diff --git a/hugr-passes/src/force_order.rs b/hugr-passes/src/force_order.rs index ad40e2164..ec59ccefd 100644 --- a/hugr-passes/src/force_order.rs +++ b/hugr-passes/src/force_order.rs @@ -2,11 +2,7 @@ use std::{cmp::Reverse, collections::BinaryHeap, iter}; use hugr_core::{ - hugr::{ - hugrmut::HugrMut, - views::{DescendantsGraph, HierarchyView, SiblingGraph}, - HugrError, - }, + hugr::{hugrmut::HugrMut, HugrError}, ops::{NamedOp, OpTag, OpTrait}, types::EdgeKind, HugrView as _, Node, @@ -51,34 +47,42 @@ pub fn force_order_by_key, K: Ord>( root: Node, rank: impl Fn(&H, Node) -> K, ) -> Result<(), HugrError> { - let dataflow_parents = DescendantsGraph::::try_new(hugr, root)? - .nodes() + let dataflow_parents = hugr + .descendants(root) .filter(|n| hugr.get_optype(*n).tag() <= OpTag::DataflowParent) .collect_vec(); for dp in dataflow_parents { // we filter out the input and output nodes from the topological sort let [i, o] = hugr.get_io(dp).unwrap(); - let rank = |n| rank(hugr, n); - let sg = SiblingGraph::::try_new(hugr, dp)?; - let petgraph = NodeFiltered::from_fn(sg.as_petgraph(), |x| x != dp && x != i && x != o); - let ordered_nodes = ForceOrder::new(&petgraph, &rank) - .iter(&petgraph) - .filter(|&x| { - let expected_edge = Some(EdgeKind::StateOrder); - let optype = hugr.get_optype(x); - if optype.other_input() == expected_edge || optype.other_output() == expected_edge { - assert_eq!( - optype.other_input(), - optype.other_output(), - "Optype does not have both input and output order edge: {}", - optype.name() - ); - true - } else { - false - } - }) - .collect_vec(); + let ordered_nodes = { + let rank = |n| rank(hugr, hugr.from_portgraph_node(n)); + let sg = hugr.region_portgraph(dp); + let petgraph = NodeFiltered::from_fn(&sg, |x| { + let x = hugr.from_portgraph_node(x); + x != dp && x != i && x != o + }); + ForceOrder::new(&petgraph, &rank) + .iter(&petgraph) + .map(|x| hugr.from_portgraph_node(x)) + .filter(|&x| { + let expected_edge = Some(EdgeKind::StateOrder); + let optype = hugr.get_optype(x); + if optype.other_input() == expected_edge + || optype.other_output() == expected_edge + { + assert_eq!( + optype.other_input(), + optype.other_output(), + "Optype does not have both input and output order edge: {}", + optype.name() + ); + true + } else { + false + } + }) + .collect_vec() + }; // we iterate over the topologically sorted nodes, prepending the input // node and suffixing the output node. diff --git a/hugr-passes/src/lower.rs b/hugr-passes/src/lower.rs index 7e68e600a..403e3d84b 100644 --- a/hugr-passes/src/lower.rs +++ b/hugr-passes/src/lower.rs @@ -141,6 +141,6 @@ mod test { }); assert_eq!(lowered.unwrap().len(), 1); - assert_eq!(h.node_count(), 3); // DFG, input, output + assert_eq!(h.num_nodes(), 3); // DFG, input, output } } diff --git a/hugr-passes/src/merge_bbs.rs b/hugr-passes/src/merge_bbs.rs index 5c76ba51d..170ff3789 100644 --- a/hugr-passes/src/merge_bbs.rs +++ b/hugr-passes/src/merge_bbs.rs @@ -59,7 +59,7 @@ fn mk_rep( let succ_sig = succ_ty.inner_signature(); // Make a Hugr with just a single CFG root node having the same signature. - let mut replacement: Hugr = Hugr::new(cfg.root_type().clone()); + let mut replacement: Hugr = Hugr::new(cfg.root_optype().clone()); let merged = replacement.add_node_with_parent(replacement.root(), { let mut merged_block = DataflowBlock { diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index d33234126..25249f5ae 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -135,7 +135,7 @@ impl NodeTemplate { ) -> Result<(), Option> { let sig = match self { NodeTemplate::SingleOp(op_type) => op_type, - NodeTemplate::CompoundOp(hugr) => hugr.root_type(), + NodeTemplate::CompoundOp(hugr) => hugr.root_optype(), NodeTemplate::Call(_, _) => return Ok(()), // no way to tell } .dataflow_signature(); @@ -1012,7 +1012,7 @@ mod test { // list -> read -> usz just becomes list -> read -> qb // list> -> read> -> opt becomes list -> get -> opt assert_eq!( - h.root_type().dataflow_signature().unwrap().io(), + h.root_optype().dataflow_signature().unwrap().io(), ( &vec![list_type(qb_t()); 2].into(), &vec![qb_t(), option_type(qb_t()).into()].into() From 1369d79cc3d30109bdee6e20ce8ca0d222b4b7b2 Mon Sep 17 00:00:00 2001 From: Luca Mondada <72734770+lmondada@users.noreply.github.com> Date: Tue, 29 Apr 2025 17:16:58 +0200 Subject: [PATCH 16/18] chore: Remove stray rewrite.rs file (#2142) Oupsie, during one of the merge conflict resolutions I must have forgotten to remove the old `rewrite.rs` file. As you can see in the `hugr-core/src/hugr.rs` file, this is no longer a module and thus the file should be deleted. It has been renamed to `patch.rs` in #2070 --- hugr-core/src/hugr/rewrite.rs | 98 ----------------------------------- 1 file changed, 98 deletions(-) delete mode 100644 hugr-core/src/hugr/rewrite.rs diff --git a/hugr-core/src/hugr/rewrite.rs b/hugr-core/src/hugr/rewrite.rs deleted file mode 100644 index 76dc93ab1..000000000 --- a/hugr-core/src/hugr/rewrite.rs +++ /dev/null @@ -1,98 +0,0 @@ -//! Rewrite operations on the HUGR - replacement, outlining, etc. - -pub mod consts; -pub mod inline_call; -pub mod inline_dfg; -pub mod insert_identity; -pub mod outline_cfg; -mod port_types; -pub mod replace; -pub mod simple_replace; - -use crate::core::HugrNode; -use crate::{Hugr, HugrView}; -pub use port_types::{BoundaryPort, HostPort, ReplacementPort}; -pub use simple_replace::{SimpleReplacement, SimpleReplacementError}; - -use super::HugrMut; - -/// An operation that can be applied to mutate a Hugr -pub trait Rewrite { - /// The node type used by the target Hugr. - type Node: HugrNode; - /// The type of Error with which this Rewrite may fail - type Error: std::error::Error; - /// The type returned on successful application of the rewrite. - type ApplyResult; - - /// If `true`, [self.apply]'s of this rewrite guarantee that they do not mutate the Hugr when they return an Err. - /// If `false`, there is no guarantee; the Hugr should be assumed invalid when Err is returned. - const UNCHANGED_ON_FAILURE: bool; - - /// Checks whether the rewrite would succeed on the specified Hugr. - /// If this call succeeds, [self.apply] should also succeed on the same `h` - /// If this calls fails, [self.apply] would fail with the same error. - fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error>; - - /// Mutate the specified Hugr, or fail with an error. - /// Returns [`Self::ApplyResult`] if successful. - /// If [self.unchanged_on_failure] is true, then `h` must be unchanged if Err is returned. - /// See also [self.verify] - /// # Panics - /// May panic if-and-only-if `h` would have failed [Hugr::validate]; that is, - /// implementations may begin with `assert!(h.validate())`, with `debug_assert!(h.validate())` - /// being preferred. - fn apply( - self, - h: &mut impl HugrMut, - ) -> Result; - - /// Returns a set of nodes referenced by the rewrite. Modifying any of these - /// nodes will invalidate it. - /// - /// Two `impl Rewrite`s can be composed if their invalidation sets are - /// disjoint. - fn invalidation_set(&self) -> impl Iterator; -} - -/// Wraps any rewrite into a transaction (i.e. that has no effect upon failure) -pub struct Transactional { - underlying: R, -} - -// Note we might like to constrain R to Rewrite but this -// is not yet supported, https://github.com/rust-lang/rust/issues/92827 -impl Rewrite for Transactional { - type Node = R::Node; - type Error = R::Error; - type ApplyResult = R::ApplyResult; - const UNCHANGED_ON_FAILURE: bool = true; - - fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { - self.underlying.verify(h) - } - - fn apply(self, h: &mut impl HugrMut) -> Result { - if R::UNCHANGED_ON_FAILURE { - return self.underlying.apply(h); - } - // Try to backup just the contents of this HugrMut. - let mut backup = Hugr::new(h.root_optype().clone()); - backup.insert_from_view(backup.root(), h); - let r = self.underlying.apply(h); - if r.is_err() { - // Try to restore backup. - h.replace_op(h.root(), backup.root_optype().clone()); - while let Some(child) = h.first_child(h.root()) { - h.remove_node(child); - } - h.insert_from_view(h.root(), &backup); - } - r - } - - #[inline] - fn invalidation_set(&self) -> impl Iterator { - self.underlying.invalidation_set() - } -} From 6f7775ba7d39a0a6e7f9e82aaae7dab169f06700 Mon Sep 17 00:00:00 2001 From: Kartik Singhal Date: Wed, 30 Apr 2025 07:51:35 -0500 Subject: [PATCH 17/18] chore(hugr-llvm): upgrade to inkwell 0.6.0 (#2128) Part of https://github.com/quantinuum-dev/hugrverse/issues/158 --- Cargo.lock | 8 ++++---- DEVELOPMENT.md | 8 ++++---- hugr-llvm/Cargo.toml | 2 +- hugr-llvm/README.md | 2 +- hugr-llvm/src/extension/collections/static_array.rs | 8 ++++++++ hugr-llvm/src/sum.rs | 2 ++ hugr-llvm/src/sum/layout.rs | 3 +++ 7 files changed, 23 insertions(+), 10 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 085d2d6a5..b05cba1e8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1458,9 +1458,9 @@ checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" [[package]] name = "inkwell" -version = "0.5.0" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40fb405537710d51f6bdbc8471365ddd4cd6d3a3c3ad6e0c8291691031ba94b2" +checksum = "e67349bd7578d4afebbe15eaa642a80b884e8623db74b1716611b131feb1deef" dependencies = [ "either", "inkwell_internals", @@ -1472,9 +1472,9 @@ dependencies = [ [[package]] name = "inkwell_internals" -version = "0.10.0" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9dd28cfd4cfba665d47d31c08a6ba637eed16770abca2eccbbc3ca831fef1e44" +checksum = "f365c8de536236cfdebd0ba2130de22acefed18b1fb99c32783b3840aec5fb46" dependencies = [ "proc-macro2", "quote", diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index 6d9465140..d9f19ed64 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -28,10 +28,10 @@ shell by setting up [direnv](https://devenv.sh/automatic-shell-activation/). To setup the environment manually you will need: -- Just: https://just.systems/ -- Rust `>=1.85`: https://www.rust-lang.org/tools/install -- uv `>=0.3`: docs.astral.sh/uv/getting-started/installation -- Optional: capnproto `>=1.0`: https://capnproto.org/install.html +- Just: +- Rust `>=1.85`: +- uv `>=0.3`: +- Optional: capnproto `>=1.0`: Required when modifying the `hugr-model` serialization schema. - Optional: llvm `== 14.0`. The "llvm" feature (backed by the sub-crate `hugr-llvm`) requires LLVM installed. We use the rust bindings diff --git a/hugr-llvm/Cargo.toml b/hugr-llvm/Cargo.toml index 677a82a31..bdfc63f5a 100644 --- a/hugr-llvm/Cargo.toml +++ b/hugr-llvm/Cargo.toml @@ -23,7 +23,7 @@ llvm14-0 = ["inkwell/llvm14-0"] [dependencies] -inkwell = { version = "0.5.0", default-features = false } +inkwell = { version = "0.6.0", default-features = false } hugr-core = { path = "../hugr-core", version = "0.15.3" } anyhow = "1.0.98" itertools.workspace = true diff --git a/hugr-llvm/README.md b/hugr-llvm/README.md index 6d81cd35d..988a650dd 100644 --- a/hugr-llvm/README.md +++ b/hugr-llvm/README.md @@ -25,7 +25,7 @@ version will only change on major releases. ## Developing hugr-llvm -See [DEVELOPMENT](DEVELOPMENT.md) for instructions on setting up the development environment. +See [DEVELOPMENT](../DEVELOPMENT.md) for instructions on setting up the development environment. ## License diff --git a/hugr-llvm/src/extension/collections/static_array.rs b/hugr-llvm/src/extension/collections/static_array.rs index 7d3ac5f5c..7f59bff82 100644 --- a/hugr-llvm/src/extension/collections/static_array.rs +++ b/hugr-llvm/src/extension/collections/static_array.rs @@ -58,6 +58,7 @@ fn value_is_const<'c>(value: impl BasicValue<'c>) -> bool { BasicValueEnum::PointerValue(v) => v.is_const(), BasicValueEnum::StructValue(v) => v.is_const(), BasicValueEnum::VectorValue(v) => v.is_const(), + BasicValueEnum::ScalableVectorValue(v) => v.is_const(), } } @@ -109,6 +110,13 @@ fn const_array<'c>( .collect_vec() .as_slice(), ), + BasicTypeEnum::ScalableVectorType(t) => t.const_array( + values + .into_iter() + .map(|x| x.as_basic_value_enum().into_scalable_vector_value()) + .collect_vec() + .as_slice(), + ), } } diff --git a/hugr-llvm/src/sum.rs b/hugr-llvm/src/sum.rs index c2b9a0475..381e09469 100644 --- a/hugr-llvm/src/sum.rs +++ b/hugr-llvm/src/sum.rs @@ -47,6 +47,7 @@ fn basic_type_undef<'c>(t: impl BasicType<'c>) -> BasicValueEnum<'c> { BasicTypeEnum::PointerType(t) => t.get_undef().as_basic_value_enum(), BasicTypeEnum::StructType(t) => t.get_undef().as_basic_value_enum(), BasicTypeEnum::VectorType(t) => t.get_undef().as_basic_value_enum(), + BasicTypeEnum::ScalableVectorType(t) => t.get_undef().as_basic_value_enum(), } } @@ -60,6 +61,7 @@ fn basic_type_poison<'c>(t: impl BasicType<'c>) -> BasicValueEnum<'c> { BasicTypeEnum::PointerType(t) => t.get_poison().as_basic_value_enum(), BasicTypeEnum::StructType(t) => t.get_poison().as_basic_value_enum(), BasicTypeEnum::VectorType(t) => t.get_poison().as_basic_value_enum(), + BasicTypeEnum::ScalableVectorType(t) => t.get_poison().as_basic_value_enum(), } } diff --git a/hugr-llvm/src/sum/layout.rs b/hugr-llvm/src/sum/layout.rs index fd67a3240..d016de851 100644 --- a/hugr-llvm/src/sum/layout.rs +++ b/hugr-llvm/src/sum/layout.rs @@ -45,6 +45,9 @@ fn size_of_type<'c>(t: impl BasicType<'c>) -> Option { BasicTypeEnum::PointerType(t) => t.size_of().get_zero_extended_constant(), BasicTypeEnum::StructType(t) => t.size_of().and_then(|x| x.get_zero_extended_constant()), BasicTypeEnum::VectorType(t) => t.size_of().and_then(|x| x.get_zero_extended_constant()), + BasicTypeEnum::ScalableVectorType(t) => { + t.size_of().and_then(|x| x.get_zero_extended_constant()) + } } } From da4b16e78e1a7c89c652c1f1f10e7f748339da79 Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Tue, 15 Apr 2025 10:14:19 +0100 Subject: [PATCH 18/18] Removed runtime extension sets. --- .github/workflows/ci-rs.yml | 2 +- .pre-commit-config.yaml | 7 +- hugr-core/Cargo.toml | 1 - hugr-core/README.md | 4 - hugr-core/src/builder.rs | 17 +- hugr-core/src/builder/build_traits.rs | 82 +---- hugr-core/src/builder/cfg.rs | 115 +------ hugr-core/src/builder/circuit.rs | 56 ++- hugr-core/src/builder/conditional.rs | 38 +-- hugr-core/src/builder/dataflow.rs | 23 +- hugr-core/src/builder/module.rs | 5 +- hugr-core/src/builder/tail_loop.rs | 26 +- hugr-core/src/export.rs | 10 +- hugr-core/src/extension.rs | 73 +--- hugr-core/src/extension/declarative.rs | 8 +- .../src/extension/declarative/signature.rs | 7 +- hugr-core/src/extension/op_def.rs | 56 +-- hugr-core/src/extension/prelude.rs | 30 +- .../src/extension/prelude/unwrap_builder.rs | 6 +- hugr-core/src/extension/resolution.rs | 4 - hugr-core/src/extension/resolution/test.rs | 46 +-- hugr-core/src/extension/resolution/types.rs | 2 - .../src/extension/resolution/types_mut.rs | 2 - hugr-core/src/hugr.rs | 318 +----------------- hugr-core/src/hugr/hugrmut.rs | 4 +- hugr-core/src/hugr/internal.rs | 3 +- hugr-core/src/hugr/patch/consts.rs | 6 +- hugr-core/src/hugr/patch/inline_call.rs | 19 +- hugr-core/src/hugr/patch/inline_dfg.rs | 12 +- hugr-core/src/hugr/patch/outline_cfg.rs | 33 +- hugr-core/src/hugr/patch/replace.rs | 27 +- hugr-core/src/hugr/patch/simple_replace.rs | 13 +- hugr-core/src/hugr/serialize/test.rs | 20 +- hugr-core/src/hugr/serialize/upgrade/test.rs | 1 - hugr-core/src/hugr/validate.rs | 59 +--- hugr-core/src/hugr/validate/test.rs | 315 ++--------------- hugr-core/src/hugr/views.rs | 15 - hugr-core/src/hugr/views/descendants.rs | 14 +- hugr-core/src/hugr/views/impls.rs | 1 - hugr-core/src/hugr/views/sibling.rs | 7 +- hugr-core/src/hugr/views/sibling_subgraph.rs | 64 +--- hugr-core/src/import.rs | 8 +- hugr-core/src/ops.rs | 8 +- hugr-core/src/ops/constant.rs | 29 +- hugr-core/src/ops/constant/custom.rs | 48 +-- hugr-core/src/ops/controlflow.rs | 66 +--- hugr-core/src/ops/custom.rs | 6 +- hugr-core/src/ops/dataflow.rs | 8 +- hugr-core/src/package.rs | 3 - .../std_extensions/arithmetic/conversions.rs | 8 +- .../std_extensions/arithmetic/float_ops.rs | 3 +- .../std_extensions/arithmetic/float_types.rs | 6 +- .../src/std_extensions/arithmetic/int_ops.rs | 9 +- .../std_extensions/arithmetic/int_types.rs | 6 +- .../src/std_extensions/collections/array.rs | 7 +- .../collections/array/array_repeat.rs | 41 +-- .../collections/array/array_scan.rs | 69 +--- .../collections/array/op_builder.rs | 10 +- .../src/std_extensions/collections/list.rs | 7 +- .../collections/static_array.rs | 17 +- hugr-core/src/std_extensions/ptr.rs | 5 +- hugr-core/src/types/poly_func.rs | 1 - hugr-core/src/types/signature.rs | 37 +- hugr-core/src/types/type_param.rs | 47 +-- hugr-llvm/src/emit/ops/cfg.rs | 9 +- hugr-llvm/src/emit/test.rs | 27 +- hugr-llvm/src/extension/collections/array.rs | 41 +-- hugr-passes/Cargo.toml | 3 - hugr-passes/README.md | 12 +- hugr-passes/src/composable.rs | 33 +- hugr-passes/src/const_fold/test.rs | 5 +- hugr-passes/src/dataflow/test.rs | 12 +- hugr-passes/src/dead_code.rs | 6 +- hugr-passes/src/force_order.rs | 2 +- hugr-passes/src/lower.rs | 2 +- hugr-passes/src/merge_bbs.rs | 10 +- hugr-passes/src/monomorphize.rs | 43 +-- hugr-passes/src/nest_cfgs.rs | 8 +- hugr-passes/src/non_local.rs | 8 +- hugr-passes/src/replace_types.rs | 3 +- hugr-passes/src/replace_types/handlers.rs | 21 +- hugr-passes/src/replace_types/linearize.rs | 11 +- hugr-passes/src/untuple.rs | 40 +-- hugr-passes/src/validation.rs | 11 +- hugr-py/src/hugr/_serialization/extension.py | 6 +- hugr-py/src/hugr/_serialization/ops.py | 24 +- hugr-py/src/hugr/_serialization/tys.py | 34 +- hugr-py/src/hugr/ext.py | 16 - hugr-py/src/hugr/ops.py | 11 +- .../_json_defs/arithmetic/conversions.json | 40 +-- .../hugr/std/_json_defs/arithmetic/float.json | 60 ++-- .../_json_defs/arithmetic/float/types.json | 1 - .../hugr/std/_json_defs/arithmetic/int.json | 141 +++----- .../std/_json_defs/arithmetic/int/types.json | 1 - .../std/_json_defs/collections/array.json | 31 +- .../hugr/std/_json_defs/collections/list.json | 19 +- .../_json_defs/collections/static_array.json | 7 +- hugr-py/src/hugr/std/_json_defs/logic.json | 16 +- hugr-py/src/hugr/std/_json_defs/prelude.json | 25 +- hugr-py/src/hugr/std/_json_defs/ptr.json | 10 +- hugr-py/src/hugr/std/int.py | 2 +- hugr-py/src/hugr/tys.py | 58 +--- hugr-py/tests/serialization/test_extension.py | 6 +- hugr-py/tests/test_custom.py | 2 +- hugr-py/tests/test_tys.py | 4 - hugr/Cargo.toml | 1 - hugr/README.md | 7 +- hugr/benches/benchmarks/hugr/examples.rs | 9 +- justfile | 2 +- specification/hugr.md | 65 ---- specification/schema/hugr_schema_live.json | 81 ----- .../schema/hugr_schema_strict_live.json | 81 ----- .../schema/testing_hugr_schema_live.json | 81 ----- .../testing_hugr_schema_strict_live.json | 81 ----- .../arithmetic/conversions.json | 40 +-- .../std_extensions/arithmetic/float.json | 60 ++-- .../arithmetic/float/types.json | 1 - .../std_extensions/arithmetic/int.json | 141 +++----- .../std_extensions/arithmetic/int/types.json | 1 - .../std_extensions/collections/array.json | 31 +- .../std_extensions/collections/list.json | 19 +- .../collections/static_array.json | 7 +- specification/std_extensions/logic.json | 16 +- specification/std_extensions/prelude.json | 25 +- specification/std_extensions/ptr.json | 10 +- 125 files changed, 591 insertions(+), 2989 deletions(-) diff --git a/.github/workflows/ci-rs.yml b/.github/workflows/ci-rs.yml index 4fe5d244f..c6814fc60 100644 --- a/.github/workflows/ci-rs.yml +++ b/.github/workflows/ci-rs.yml @@ -108,7 +108,7 @@ jobs: - name: Override criterion with the CodSpeed harness run: cargo add --dev codspeed-criterion-compat --rename criterion --package hugr - name: Build benchmarks - run: cargo codspeed build --profile bench --features extension_inference,declarative,llvm,llvm-test + run: cargo codspeed build --profile bench --features declarative,llvm,llvm-test - name: Run benchmarks uses: CodSpeedHQ/action@v3 with: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4fe582d93..b6e481bd5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -79,7 +79,7 @@ repos: # built into a binary build (without using `maturin`) # # This feature list should be kept in sync with the `hugr-py/pyproject.toml` - entry: cargo test --workspace --exclude 'hugr-py' --features 'hugr/extension_inference hugr/declarative hugr/llvm hugr/llvm-test hugr/zstd' + entry: cargo test --workspace --exclude 'hugr-py' --features 'hugr/declarative hugr/llvm hugr/llvm-test hugr/zstd' language: system files: \.rs$ pass_filenames: false @@ -100,10 +100,7 @@ repos: - id: py-test name: pytest description: Run python tests - # We need to rebuild `hugr-cli` without the `extension_inference` feature - # to avoid test errors. - # TODO: Remove this once the issue is fixed. - entry: sh -c "cargo build -p hugr-cli && uv run pytest" + entry: sh -c "uv run pytest" language: system files: \.py$ pass_filenames: false diff --git a/hugr-core/Cargo.toml b/hugr-core/Cargo.toml index 1e4fa392f..8da686ee8 100644 --- a/hugr-core/Cargo.toml +++ b/hugr-core/Cargo.toml @@ -17,7 +17,6 @@ categories = ["compilers"] workspace = true [features] -extension_inference = [] declarative = ["serde_yaml"] zstd = ["dep:zstd"] diff --git a/hugr-core/README.md b/hugr-core/README.md index 765d4577b..0e15305f1 100644 --- a/hugr-core/README.md +++ b/hugr-core/README.md @@ -14,10 +14,6 @@ Please read the [API documentation here][]. ## Experimental Features -- `extension_inference`: - Experimental feature which allows automatic inference of which extra extensions - are required at runtime by a HUGR when validating it. - Not enabled by default. - `declarative`: Experimental support for declaring extensions in YAML files, support is limited. diff --git a/hugr-core/src/builder.rs b/hugr-core/src/builder.rs index 056690e0a..9f7a219a7 100644 --- a/hugr-core/src/builder.rs +++ b/hugr-core/src/builder.rs @@ -42,7 +42,7 @@ //! let _dfg_handle = { //! let mut dfg = module_builder.define_function( //! "main", -//! Signature::new_endo(bool_t()).with_extension_delta(logic::EXTENSION_ID), +//! Signature::new_endo(bool_t()), //! )?; //! //! // Get the wires from the function inputs. @@ -59,8 +59,7 @@ //! let _circuit_handle = { //! let mut dfg = module_builder.define_function( //! "circuit", -//! Signature::new_endo(vec![bool_t(), bool_t()]) -//! .with_extension_delta(logic::EXTENSION_ID), +//! Signature::new_endo(vec![bool_t(), bool_t()]), //! )?; //! let mut circuit = dfg.as_circuit(dfg.input_wires()); //! @@ -89,7 +88,7 @@ use thiserror::Error; use crate::extension::simple_op::OpLoadError; -use crate::extension::{SignatureError, TO_BE_INFERRED}; +use crate::extension::SignatureError; use crate::hugr::ValidationError; use crate::ops::handle::{BasicBlockID, CfgID, ConditionalID, DfgID, FuncID, TailLoopID}; use crate::ops::{NamedOp, OpType}; @@ -123,16 +122,14 @@ pub use conditional::{CaseBuilder, ConditionalBuilder}; mod circuit; pub use circuit::{CircuitBuildError, CircuitBuilder}; -/// Return a FunctionType with the same input and output types (specified) -/// whose extension delta, when used in a non-FuncDefn container, will be inferred. +/// Return a FunctionType with the same input and output types (specified). pub fn endo_sig(types: impl Into) -> Signature { - Signature::new_endo(types).with_extension_delta(TO_BE_INFERRED) + Signature::new_endo(types) } -/// Return a FunctionType with the specified input and output types -/// whose extension delta, when used in a non-FuncDefn container, will be inferred. +/// Return a FunctionType with the specified input and output types. pub fn inout_sig(inputs: impl Into, outputs: impl Into) -> Signature { - Signature::new(inputs, outputs).with_extension_delta(TO_BE_INFERRED) + Signature::new(inputs, outputs) } #[derive(Debug, Clone, PartialEq, Error)] diff --git a/hugr-core/src/builder/build_traits.rs b/hugr-core/src/builder/build_traits.rs index 58c15c54a..ba366c117 100644 --- a/hugr-core/src/builder/build_traits.rs +++ b/hugr-core/src/builder/build_traits.rs @@ -20,7 +20,7 @@ use crate::{ types::EdgeKind, }; -use crate::extension::{ExtensionRegistry, ExtensionSet, TO_BE_INFERRED}; +use crate::extension::ExtensionRegistry; use crate::types::{PolyFuncType, Signature, Type, TypeArg, TypeRow}; use itertools::Itertools; @@ -319,10 +319,7 @@ pub trait Dataflow: Container { inputs: impl IntoIterator, ) -> Result, BuildError> { let (types, input_wires): (Vec, Vec) = inputs.into_iter().unzip(); - self.dfg_builder( - Signature::new_endo(types).with_extension_delta(TO_BE_INFERRED), - input_wires, - ) + self.dfg_builder(Signature::new_endo(types), input_wires) } /// Return a builder for a [`crate::ops::CFG`] node, @@ -330,7 +327,6 @@ pub trait Dataflow: Container { /// The `inputs` must be an iterable over pairs of the type of the input and /// the corresponding wire. /// The `output_types` are the types of the outputs. - /// The Extension delta will be inferred. /// /// # Errors /// @@ -340,27 +336,6 @@ pub trait Dataflow: Container { &mut self, inputs: impl IntoIterator, output_types: TypeRow, - ) -> Result, BuildError> { - self.cfg_builder_exts(inputs, output_types, TO_BE_INFERRED) - } - - /// Return a builder for a [`crate::ops::CFG`] node, - /// i.e. a nested controlflow subgraph. - /// The `inputs` must be an iterable over pairs of the type of the input and - /// the corresponding wire. - /// The `output_types` are the types of the outputs. - /// `extension_delta` is explicitly specified. Alternatively - /// [cfg_builder](Self::cfg_builder) may be used to infer it. - /// - /// # Errors - /// - /// This function will return an error if there is an error when building - /// the CFG node. - fn cfg_builder_exts( - &mut self, - inputs: impl IntoIterator, - output_types: TypeRow, - extension_delta: impl Into, ) -> Result, BuildError> { let (input_types, input_wires): (Vec, Vec) = inputs.into_iter().unzip(); @@ -369,8 +344,7 @@ pub trait Dataflow: Container { let (cfg_node, _) = add_node_with_wires( self, ops::CFG { - signature: Signature::new(inputs.clone(), output_types.clone()) - .with_extension_delta(extension_delta), + signature: Signature::new(inputs.clone(), output_types.clone()), }, input_wires, )?; @@ -449,7 +423,6 @@ pub trait Dataflow: Container { /// The `inputs` must be an iterable over pairs of the type of the input and /// the corresponding wire. /// The `output_types` are the types of the outputs. - /// The extension delta will be inferred. /// /// # Errors /// @@ -461,27 +434,6 @@ pub trait Dataflow: Container { just_inputs: impl IntoIterator, inputs_outputs: impl IntoIterator, just_out_types: TypeRow, - ) -> Result, BuildError> { - self.tail_loop_builder_exts(just_inputs, inputs_outputs, just_out_types, TO_BE_INFERRED) - } - - /// Return a builder for a [`crate::ops::TailLoop`] node. - /// The `inputs` must be an iterable over pairs of the type of the input and - /// the corresponding wire. - /// The `output_types` are the types of the outputs. - /// `extension_delta` explicitly specified. Alternatively - /// [tail_loop_builder](Self::tail_loop_builder) may be used to infer it. - /// - /// # Errors - /// - /// This function will return an error if there is an error when building - /// the [`ops::TailLoop`] node. - fn tail_loop_builder_exts( - &mut self, - just_inputs: impl IntoIterator, - inputs_outputs: impl IntoIterator, - just_out_types: TypeRow, - extension_delta: impl Into, ) -> Result, BuildError> { let (input_types, mut input_wires): (Vec, Vec) = just_inputs.into_iter().unzip(); @@ -493,7 +445,6 @@ pub trait Dataflow: Container { just_inputs: input_types.into(), just_outputs: just_out_types, rest: rest_types.into(), - extension_delta: extension_delta.into(), }; // TODO: Make input extensions a parameter let (loop_node, _) = add_node_with_wires(self, tail_loop.clone(), input_wires)?; @@ -507,41 +458,17 @@ pub trait Dataflow: Container { /// /// The `other_inputs` must be an iterable over pairs of the type of the input and /// the corresponding wire. - /// The `output_types` are the types of the outputs. Extension delta will be inferred. - /// - /// # Errors - /// - /// This function will return an error if there is an error when building - /// the Conditional node. - fn conditional_builder( - &mut self, - sum_input: (impl IntoIterator, Wire), - other_inputs: impl IntoIterator, - output_types: TypeRow, - ) -> Result, BuildError> { - self.conditional_builder_exts(sum_input, other_inputs, output_types, TO_BE_INFERRED) - } - - /// Return a builder for a [`crate::ops::Conditional`] node. - /// `sum_rows` and `sum_wire` define the type of the Sum - /// variants and the wire carrying the Sum respectively. - /// - /// The `other_inputs` must be an iterable over pairs of the type of the input and - /// the corresponding wire. /// The `output_types` are the types of the outputs. - /// `extension_delta` is explicitly specified. Alternatively - /// [conditional_builder](Self::conditional_builder) may be used to infer it. /// /// # Errors /// /// This function will return an error if there is an error when building /// the Conditional node. - fn conditional_builder_exts( + fn conditional_builder( &mut self, (sum_rows, sum_wire): (impl IntoIterator, Wire), other_inputs: impl IntoIterator, output_types: TypeRow, - extension_delta: impl Into, ) -> Result, BuildError> { let mut input_wires = vec![sum_wire]; let (input_types, rest_input_wires): (Vec, Vec) = @@ -558,7 +485,6 @@ pub trait Dataflow: Container { sum_rows, other_inputs: inputs, outputs: output_types, - extension_delta: extension_delta.into(), }, input_wires, )?; diff --git a/hugr-core/src/builder/cfg.rs b/hugr-core/src/builder/cfg.rs index 81c7d7269..0aadc047b 100644 --- a/hugr-core/src/builder/cfg.rs +++ b/hugr-core/src/builder/cfg.rs @@ -5,9 +5,8 @@ use super::{ BasicBlockID, BuildError, CfgID, Container, Dataflow, HugrBuilder, Wire, }; -use crate::extension::TO_BE_INFERRED; use crate::ops::{self, handle::NodeHandle, DataflowBlock, DataflowParent, ExitBlock, OpType}; -use crate::{extension::ExtensionSet, types::Signature}; +use crate::types::Signature; use crate::{hugr::views::HugrView, types::TypeRow}; use crate::Node; @@ -106,7 +105,6 @@ use crate::{hugr::HugrMut, type_row, Hugr}; /// let hugr = cfg_builder.finish_hugr()?; /// Ok(hugr) /// }; -/// #[cfg(not(feature = "extension_inference"))] /// assert!(make_cfg().is_ok()); /// ``` #[derive(Debug, PartialEq)] @@ -157,10 +155,7 @@ impl CFGBuilder { } impl HugrBuilder for CFGBuilder { - fn finish_hugr(mut self) -> Result { - if cfg!(feature = "extension_inference") { - self.base.infer_extensions(false)?; - } + fn finish_hugr(self) -> Result { self.base.validate()?; Ok(self.base) } @@ -192,7 +187,7 @@ impl + AsRef> CFGBuilder { /// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs` /// and `outputs` and the variants of the branching Sum value - /// specified by `sum_rows`. Extension delta will be inferred. + /// specified by `sum_rows`. /// /// # Errors /// @@ -203,36 +198,12 @@ impl + AsRef> CFGBuilder { sum_rows: impl IntoIterator, other_outputs: TypeRow, ) -> Result, BuildError> { - self.block_builder_exts(inputs, sum_rows, other_outputs, TO_BE_INFERRED) - } - - /// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs` - /// and `outputs` and the variants of the branching Sum value - /// specified by `sum_rows`. Extension delta will be inferred. - /// - /// # Errors - /// - /// This function will return an error if there is an error adding the node. - pub fn block_builder_exts( - &mut self, - inputs: TypeRow, - sum_rows: impl IntoIterator, - other_outputs: TypeRow, - extension_delta: impl Into, - ) -> Result, BuildError> { - self.any_block_builder( - inputs, - extension_delta.into(), - sum_rows, - other_outputs, - false, - ) + self.any_block_builder(inputs, sum_rows, other_outputs, false) } fn any_block_builder( &mut self, inputs: TypeRow, - extension_delta: ExtensionSet, sum_rows: impl IntoIterator, other_outputs: TypeRow, entry: bool, @@ -242,7 +213,6 @@ impl + AsRef> CFGBuilder { inputs: inputs.clone(), other_outputs: other_outputs.clone(), sum_rows, - extension_delta, }); let parent = self.container_node(); let block_n = if entry { @@ -257,9 +227,9 @@ impl + AsRef> CFGBuilder { BlockBuilder::create(self.hugr_mut(), block_n) } - /// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs` - /// and `outputs` and `extension_delta` explicitly specified, plus a UnitSum type - /// (a Sum of `n_cases` unit types) to select the successor. + /// Return a builder for a non-entry [`DataflowBlock`] child graph with + /// `inputs` and `outputs` , plus a UnitSum type (a Sum of `n_cases` unit + /// types) to select the successor. /// /// # Errors /// @@ -269,17 +239,15 @@ impl + AsRef> CFGBuilder { signature: Signature, n_cases: usize, ) -> Result, BuildError> { - self.block_builder_exts( + self.block_builder( signature.input, vec![type_row![]; n_cases], signature.output, - signature.runtime_reqs, ) } /// Return a builder for the entry [`DataflowBlock`] child graph with `outputs` /// and the variants of the branching Sum value specified by `sum_rows`. - /// Extension delta will be inferred. /// /// # Errors /// @@ -288,35 +256,12 @@ impl + AsRef> CFGBuilder { &mut self, sum_rows: impl IntoIterator, other_outputs: TypeRow, - ) -> Result, BuildError> { - self.entry_builder_exts(sum_rows, other_outputs, TO_BE_INFERRED) - } - - /// Return a builder for the entry [`DataflowBlock`] child graph with `outputs`, - /// the variants of the branching Sum value specified by `sum_rows`, and - /// `extension_delta` explicitly specified. ([entry_builder](Self::entry_builder) - /// may be used to infer.) - /// - /// # Errors - /// - /// This function will return an error if an entry block has already been built. - pub fn entry_builder_exts( - &mut self, - sum_rows: impl IntoIterator, - other_outputs: TypeRow, - extension_delta: impl Into, ) -> Result, BuildError> { let inputs = self .inputs .take() .ok_or(BuildError::EntryBuiltError(self.cfg_node))?; - self.any_block_builder( - inputs, - extension_delta.into(), - sum_rows, - other_outputs, - true, - ) + self.any_block_builder(inputs, sum_rows, other_outputs, true) } /// Return a builder for the entry [`DataflowBlock`] child graph with @@ -333,22 +278,6 @@ impl + AsRef> CFGBuilder { self.entry_builder(vec![type_row![]; n_cases], outputs) } - /// Return a builder for the entry [`DataflowBlock`] child graph with - /// `outputs` and a Sum of `n_cases` unit types, and explicit `extension_delta`. - /// ([simple_entry_builder](Self::simple_entry_builder) may be used to infer.) - /// - /// # Errors - /// - /// This function will return an error if there is an error adding the node. - pub fn simple_entry_builder_exts( - &mut self, - outputs: TypeRow, - n_cases: usize, - extension_delta: impl Into, - ) -> Result, BuildError> { - self.entry_builder_exts(vec![type_row![]; n_cases], outputs, extension_delta) - } - /// Returns the exit block of this [`CFGBuilder`]. pub fn exit_block(&self) -> BasicBlockID { self.exit_node.into() @@ -412,23 +341,10 @@ impl + AsRef> BlockBuilder { impl BlockBuilder { /// Initialize a [`DataflowBlock`] rooted HUGR builder. - /// Extension delta will be inferred. pub fn new( inputs: impl Into, sum_rows: impl IntoIterator, other_outputs: impl Into, - ) -> Result { - Self::new_exts(inputs, sum_rows, other_outputs, TO_BE_INFERRED) - } - - /// Initialize a [`DataflowBlock`] rooted HUGR builder. - /// `extension_delta` is explicitly specified; alternatively, [new](Self::new) - /// may be used to infer it. - pub fn new_exts( - inputs: impl Into, - sum_rows: impl IntoIterator, - other_outputs: impl Into, - extension_delta: impl Into, ) -> Result { let inputs = inputs.into(); let sum_rows: Vec<_> = sum_rows.into_iter().collect(); @@ -437,7 +353,6 @@ impl BlockBuilder { inputs: inputs.clone(), other_outputs: other_outputs.clone(), sum_rows, - extension_delta: extension_delta.into(), }; let base = Hugr::new(op); @@ -507,11 +422,7 @@ pub(crate) mod test { ) -> Result<(), BuildError> { let usize_row: TypeRow = vec![usize_t()].into(); let sum2_variants = vec![usize_row.clone(), usize_row]; - let mut entry_b = cfg_builder.entry_builder_exts( - sum2_variants.clone(), - type_row![], - ExtensionSet::new(), - )?; + let mut entry_b = cfg_builder.entry_builder(sum2_variants.clone(), type_row![])?; let entry = { let [inw] = entry_b.input_wires_arr(); @@ -537,11 +448,7 @@ pub(crate) mod test { let sum_tuple_const = cfg_builder.add_constant(ops::Value::unary_unit_sum()); let sum_variants = vec![type_row![]]; - let mut entry_b = cfg_builder.entry_builder_exts( - sum_variants.clone(), - type_row![], - ExtensionSet::new(), - )?; + let mut entry_b = cfg_builder.entry_builder(sum_variants.clone(), type_row![])?; let [inw] = entry_b.input_wires_arr(); let entry = { let sum = entry_b.load_const(&sum_tuple_const); diff --git a/hugr-core/src/builder/circuit.rs b/hugr-core/src/builder/circuit.rs index 01f5e3e45..eb48b4fbc 100644 --- a/hugr-core/src/builder/circuit.rs +++ b/hugr-core/src/builder/circuit.rs @@ -245,8 +245,8 @@ mod test { use crate::builder::{Container, HugrBuilder, ModuleBuilder}; use crate::extension::prelude::{qb_t, usize_t}; - use crate::extension::{ExtensionId, ExtensionSet}; - use crate::std_extensions::arithmetic::float_types::{self, ConstF64}; + use crate::extension::ExtensionId; + use crate::std_extensions::arithmetic::float_types::ConstF64; use crate::utils::test_quantum_extension::{ self, cx_gate, h_gate, measure, q_alloc, q_discard, rz_f64, }; @@ -260,10 +260,7 @@ mod test { #[test] fn simple_linear() { let build_res = build_main( - Signature::new(vec![qb_t(), qb_t()], vec![qb_t(), qb_t()]) - .with_extension_delta(test_quantum_extension::EXTENSION_ID) - .with_extension_delta(float_types::EXTENSION_ID) - .into(), + Signature::new(vec![qb_t(), qb_t()], vec![qb_t(), qb_t()]).into(), |mut f_build| { let wires = f_build.input_wires().map(Some).collect(); @@ -314,11 +311,7 @@ mod test { Signature::new( vec![qb_t(), qb_t(), usize_t()], vec![qb_t(), qb_t(), bool_t()], - ) - .with_extension_delta(ExtensionSet::from_iter([ - test_quantum_extension::EXTENSION_ID, - my_ext_name, - ])), + ), ) .unwrap(); @@ -351,38 +344,33 @@ mod test { #[test] fn ancillae() { - let build_res = build_main( - Signature::new_endo(qb_t()) - .with_extension_delta(test_quantum_extension::EXTENSION_ID) - .into(), - |mut f_build| { - let mut circ = f_build.as_circuit(f_build.input_wires()); - assert_eq!(circ.n_wires(), 1); + let build_res = build_main(Signature::new_endo(qb_t()).into(), |mut f_build| { + let mut circ = f_build.as_circuit(f_build.input_wires()); + assert_eq!(circ.n_wires(), 1); - let [q0] = circ.tracked_units_arr(); - let [ancilla] = circ.append_with_outputs_arr(q_alloc(), [] as [CircuitUnit; 0])?; - let ancilla = circ.track_wire(ancilla); + let [q0] = circ.tracked_units_arr(); + let [ancilla] = circ.append_with_outputs_arr(q_alloc(), [] as [CircuitUnit; 0])?; + let ancilla = circ.track_wire(ancilla); - assert_ne!(ancilla, 0); - assert_eq!(circ.n_wires(), 2); - assert_eq!(circ.tracked_units_arr(), [q0, ancilla]); + assert_ne!(ancilla, 0); + assert_eq!(circ.n_wires(), 2); + assert_eq!(circ.tracked_units_arr(), [q0, ancilla]); - circ.append(cx_gate(), [q0, ancilla])?; - let [_bit] = circ.append_with_outputs_arr(measure(), [q0])?; + circ.append(cx_gate(), [q0, ancilla])?; + let [_bit] = circ.append_with_outputs_arr(measure(), [q0])?; - let q0 = circ.untrack_wire(q0)?; + let q0 = circ.untrack_wire(q0)?; - assert_eq!(circ.tracked_units_arr(), [ancilla]); + assert_eq!(circ.tracked_units_arr(), [ancilla]); - circ.append_and_consume(q_discard(), [q0])?; + circ.append_and_consume(q_discard(), [q0])?; - let outs = circ.finish(); + let outs = circ.finish(); - assert_eq!(outs.len(), 1); + assert_eq!(outs.len(), 1); - f_build.finish_with_outputs(outs) - }, - ); + f_build.finish_with_outputs(outs) + }); assert_matches!(build_res, Ok(_)); } diff --git a/hugr-core/src/builder/conditional.rs b/hugr-core/src/builder/conditional.rs index 0404abaf3..73670526c 100644 --- a/hugr-core/src/builder/conditional.rs +++ b/hugr-core/src/builder/conditional.rs @@ -1,6 +1,4 @@ -use crate::extension::TO_BE_INFERRED; use crate::hugr::views::HugrView; -use crate::ops::dataflow::DataflowOpTrait; use crate::types::{Signature, TypeRow}; use crate::ops; @@ -16,7 +14,7 @@ use super::{ }; use crate::Node; -use crate::{extension::ExtensionSet, hugr::HugrMut, Hugr}; +use crate::{hugr::HugrMut, Hugr}; use std::collections::HashSet; @@ -107,7 +105,6 @@ impl + AsRef> ConditionalBuilder { .clone() .try_into() .expect("Parent node does not have Conditional optype."); - let extension_delta = cond.signature().runtime_reqs.clone(); let inputs = cond .case_input_row(case) .ok_or(ConditionalBuildError::NotCase { conditional, case })?; @@ -118,8 +115,7 @@ impl + AsRef> ConditionalBuilder { let outputs = cond.outputs; let case_op = ops::Case { - signature: Signature::new(inputs.clone(), outputs.clone()) - .with_extension_delta(extension_delta.clone()), + signature: Signature::new(inputs.clone(), outputs.clone()), }; let case_node = // add case before any existing subsequent cases @@ -134,7 +130,7 @@ impl + AsRef> ConditionalBuilder { let dfg_builder = DFGBuilder::create_with_io( self.hugr_mut(), case_node, - Signature::new(inputs, outputs).with_extension_delta(extension_delta), + Signature::new(inputs, outputs), )?; Ok(CaseBuilder::from_dfg_builder(dfg_builder)) @@ -142,33 +138,18 @@ impl + AsRef> ConditionalBuilder { } impl HugrBuilder for ConditionalBuilder { - fn finish_hugr(mut self) -> Result { - if cfg!(feature = "extension_inference") { - self.base.infer_extensions(false)?; - } + fn finish_hugr(self) -> Result { self.base.validate()?; Ok(self.base) } } impl ConditionalBuilder { - /// Initialize a Conditional rooted HUGR builder, extension delta will be inferred. + /// Initialize a Conditional rooted HUGR builder. pub fn new( sum_rows: impl IntoIterator, other_inputs: impl Into, outputs: impl Into, - ) -> Result { - Self::new_exts(sum_rows, other_inputs, outputs, TO_BE_INFERRED) - } - - /// Initialize a Conditional rooted HUGR builder, - /// `extension_delta` explicitly specified. Alternatively, - /// [new](Self::new) may be used to infer it. - pub fn new_exts( - sum_rows: impl IntoIterator, - other_inputs: impl Into, - outputs: impl Into, - extension_delta: impl Into, ) -> Result { let sum_rows: Vec<_> = sum_rows.into_iter().collect(); let other_inputs = other_inputs.into(); @@ -181,7 +162,6 @@ impl ConditionalBuilder { sum_rows, other_inputs, outputs, - extension_delta: extension_delta.into(), }; let base = Hugr::new(op); let conditional_node = base.root(); @@ -225,12 +205,8 @@ mod test { #[test] fn basic_conditional() -> Result<(), BuildError> { - let mut conditional_b = ConditionalBuilder::new_exts( - [type_row![], type_row![]], - vec![usize_t()], - vec![usize_t()], - ExtensionSet::new(), - )?; + let mut conditional_b = + ConditionalBuilder::new([type_row![], type_row![]], vec![usize_t()], vec![usize_t()])?; n_identity(conditional_b.case_builder(1)?)?; n_identity(conditional_b.case_builder(0)?)?; diff --git a/hugr-core/src/builder/dataflow.rs b/hugr-core/src/builder/dataflow.rs index b84f3a05a..4e66f857f 100644 --- a/hugr-core/src/builder/dataflow.rs +++ b/hugr-core/src/builder/dataflow.rs @@ -82,10 +82,7 @@ impl DFGBuilder { } impl HugrBuilder for DFGBuilder { - fn finish_hugr(mut self) -> Result { - if cfg!(feature = "extension_inference") { - self.base.infer_extensions(false)?; - } + fn finish_hugr(self) -> Result { self.base.validate()?; Ok(self.base) } @@ -418,19 +415,15 @@ pub(crate) mod test { #[test] fn simple_inter_graph_edge() { let builder = || -> Result { - let mut f_build = FunctionBuilder::new( - "main", - Signature::new(vec![bool_t()], vec![bool_t()]).with_prelude(), - )?; + let mut f_build = + FunctionBuilder::new("main", Signature::new(vec![bool_t()], vec![bool_t()]))?; let [i1] = f_build.input_wires_arr(); let noop = f_build.add_dataflow_op(Noop(bool_t()), [i1])?; let i1 = noop.out_wire(0); - let mut nested = f_build.dfg_builder( - Signature::new(type_row![], vec![bool_t()]).with_prelude(), - [], - )?; + let mut nested = + f_build.dfg_builder(Signature::new(type_row![], vec![bool_t()]), [])?; let id = nested.add_dataflow_op(Noop(bool_t()), [i1])?; @@ -445,10 +438,8 @@ pub(crate) mod test { #[test] fn add_inputs_outputs() { let builder = || -> Result<(Hugr, Node), BuildError> { - let mut f_build = FunctionBuilder::new( - "main", - Signature::new(vec![bool_t()], vec![bool_t()]).with_prelude(), - )?; + let mut f_build = + FunctionBuilder::new("main", Signature::new(vec![bool_t()], vec![bool_t()]))?; let f_node = f_build.container_node(); let [i0] = f_build.input_wires_arr(); diff --git a/hugr-core/src/builder/module.rs b/hugr-core/src/builder/module.rs index 1387c1ec5..a77f01e5f 100644 --- a/hugr-core/src/builder/module.rs +++ b/hugr-core/src/builder/module.rs @@ -50,10 +50,7 @@ impl Default for ModuleBuilder { } impl HugrBuilder for ModuleBuilder { - fn finish_hugr(mut self) -> Result { - if cfg!(feature = "extension_inference") { - self.0.infer_extensions(false)?; - } + fn finish_hugr(self) -> Result { self.0.validate()?; Ok(self.0) } diff --git a/hugr-core/src/builder/tail_loop.rs b/hugr-core/src/builder/tail_loop.rs index fd6fb03b8..2baa0bcd5 100644 --- a/hugr-core/src/builder/tail_loop.rs +++ b/hugr-core/src/builder/tail_loop.rs @@ -1,4 +1,3 @@ -use crate::extension::{ExtensionSet, TO_BE_INFERRED}; use crate::ops::{self, DataflowOpTrait}; use crate::hugr::views::HugrView; @@ -72,29 +71,15 @@ impl + AsRef> TailLoopBuilder { impl TailLoopBuilder { /// Initialize new builder for a [`ops::TailLoop`] rooted HUGR. - /// Extension delta will be inferred. pub fn new( just_inputs: impl Into, inputs_outputs: impl Into, just_outputs: impl Into, - ) -> Result { - Self::new_exts(just_inputs, inputs_outputs, just_outputs, TO_BE_INFERRED) - } - - /// Initialize new builder for a [`ops::TailLoop`] rooted HUGR. - /// `extension_delta` is explicitly specified; alternatively, [new](Self::new) - /// may be used to infer it. - pub fn new_exts( - just_inputs: impl Into, - inputs_outputs: impl Into, - just_outputs: impl Into, - extension_delta: impl Into, ) -> Result { let tail_loop = ops::TailLoop { just_inputs: just_inputs.into(), just_outputs: just_outputs.into(), rest: inputs_outputs.into(), - extension_delta: extension_delta.into(), }; let base = Hugr::new(tail_loop.clone()); let root = base.root(); @@ -109,7 +94,7 @@ mod test { use crate::extension::prelude::bool_t; use crate::{ builder::{DataflowSubContainer, HugrBuilder, ModuleBuilder, SubContainer}, - extension::prelude::{usize_t, ConstUsize, PRELUDE_ID}, + extension::prelude::{usize_t, ConstUsize}, hugr::ValidationError, ops::Value, type_row, @@ -120,8 +105,7 @@ mod test { #[test] fn basic_loop() -> Result<(), BuildError> { let build_result: Result = { - let mut loop_b = - TailLoopBuilder::new_exts(vec![], vec![bool_t()], vec![usize_t()], PRELUDE_ID)?; + let mut loop_b = TailLoopBuilder::new(vec![], vec![bool_t()], vec![usize_t()])?; let [i1] = loop_b.input_wires_arr(); let const_wire = loop_b.add_load_value(ConstUsize::new(1)); @@ -138,10 +122,8 @@ mod test { fn loop_with_conditional() -> Result<(), BuildError> { let build_result = { let mut module_builder = ModuleBuilder::new(); - let mut fbuild = module_builder.define_function( - "main", - Signature::new(vec![bool_t()], vec![usize_t()]).with_prelude(), - )?; + let mut fbuild = module_builder + .define_function("main", Signature::new(vec![bool_t()], vec![usize_t()]))?; let _fdef = { let [b1] = fbuild.input_wires_arr(); let loop_id = { diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 078fe3c27..42e04629b 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -832,7 +832,6 @@ impl<'a> Context<'a> { ); self.make_term(table::Term::List(parts)) } - TypeArg::Extensions { .. } => self.make_term_apply("compat.ext_set", &[]), TypeArg::Variable { v } => self.export_type_arg_var(v), } } @@ -939,7 +938,6 @@ impl<'a> Context<'a> { let types = self.make_term(table::Term::List(parts)); self.make_term_apply(model::CORE_TUPLE_TYPE, &[types]) } - TypeParam::Extensions => self.make_term_apply("compat.ext_set_type", &[]), } } @@ -1175,19 +1173,15 @@ mod test { use crate::{ builder::{Dataflow, DataflowSubContainer}, extension::prelude::qb_t, - std_extensions::arithmetic::float_types, types::Signature, - utils::test_quantum_extension::{self, cx_gate, h_gate}, + utils::test_quantum_extension::{cx_gate, h_gate}, Hugr, }; #[fixture] fn test_simple_circuit() -> Hugr { crate::builder::test::build_main( - Signature::new_endo(vec![qb_t(), qb_t()]) - .with_extension_delta(test_quantum_extension::EXTENSION_ID) - .with_extension_delta(float_types::EXTENSION_ID) - .into(), + Signature::new_endo(vec![qb_t(), qb_t()]).into(), |mut f_build| { let wires: Vec<_> = f_build.input_wires().collect(); let mut linear = f_build.as_circuit(wires); diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index 23238ccfd..4300c74ad 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -23,7 +23,7 @@ use crate::ops::custom::{ExtensionOp, OpaqueOp}; use crate::ops::{OpName, OpNameRef}; use crate::types::type_param::{TypeArg, TypeArgError, TypeParam}; use crate::types::RowVariable; -use crate::types::{check_typevar_decl, CustomType, Substitution, TypeBound, TypeName}; +use crate::types::{CustomType, TypeBound, TypeName}; use crate::types::{Signature, TypeNameRef}; mod const_fold; @@ -547,8 +547,6 @@ pub struct Extension { pub version: Version, /// Unique identifier for the extension. pub name: ExtensionId, - /// Runtime dependencies this extension has on other extensions. - pub runtime_reqs: ExtensionSet, /// Types defined by this extension. types: BTreeMap, /// Operation declarations with serializable definitions. @@ -572,7 +570,6 @@ impl Extension { Self { name, version, - runtime_reqs: Default::default(), types: Default::default(), operations: Default::default(), } @@ -629,12 +626,6 @@ impl Extension { } } - /// Extend the runtime requirements of this extension with another set of extensions. - pub fn add_requirements(&mut self, runtime_reqs: impl Into) { - let reqs = mem::take(&mut self.runtime_reqs); - self.runtime_reqs = reqs.union(runtime_reqs.into()); - } - /// Allows read-only access to the operations in this Extension pub fn get_op(&self, name: &OpNameRef) -> Option<&Arc> { self.operations.get(name) @@ -734,14 +725,6 @@ pub enum ExtensionBuildError { #[display("[{}]", _0.iter().join(", "))] pub struct ExtensionSet(BTreeSet); -/// A special ExtensionId which indicates that the delta of a non-Function -/// container node should be computed by extension inference. -/// -/// See [`infer_extensions`] which lists the container nodes to which this can be applied. -/// -/// [`infer_extensions`]: crate::hugr::Hugr::infer_extensions -pub const TO_BE_INFERRED: ExtensionId = ExtensionId::new_unchecked(".TO_BE_INFERRED"); - impl ExtensionSet { /// Creates a new empty extension set. pub const fn new() -> Self { @@ -753,14 +736,6 @@ impl ExtensionSet { self.0.insert(extension.clone()); } - /// Adds a type var (which must have been declared as a [TypeParam::Extensions]) to this set - pub fn insert_type_var(&mut self, idx: usize) { - // Represent type vars as string representation of variable index. - // This is not a legal IdentList or ExtensionId so should not conflict. - self.0 - .insert(ExtensionId::new_unchecked(idx.to_string().as_str())); - } - /// Returns `true` if the set contains the given extension. pub fn contains(&self, extension: &ExtensionId) -> bool { self.0.contains(extension) @@ -783,14 +758,6 @@ impl ExtensionSet { set } - /// An ExtensionSet containing a single type variable - /// (which must have been declared as a [TypeParam::Extensions]) - pub fn type_var(idx: usize) -> Self { - let mut set = Self::new(); - set.insert_type_var(idx); - set - } - /// Returns the union of two extension sets. pub fn union(mut self, other: Self) -> Self { self.0.extend(other.0); @@ -821,22 +788,6 @@ impl ExtensionSet { pub fn is_empty(&self) -> bool { self.0.is_empty() } - - pub(crate) fn validate(&self, params: &[TypeParam]) -> Result<(), SignatureError> { - self.iter() - .filter_map(as_typevar) - .try_for_each(|var_idx| check_typevar_decl(params, var_idx, &TypeParam::Extensions)) - } - - pub(crate) fn substitute(&self, t: &Substitution) -> Self { - Self::from_iter(self.0.iter().flat_map(|e| match as_typevar(e) { - None => vec![e.clone()], - Some(i) => match t.apply_var(i, &TypeParam::Extensions) { - TypeArg::Extensions{es} => es.iter().cloned().collect::>(), - _ => panic!("value for type var was not extension set - type scheme should be validated first"), - }, - })) - } } impl From for ExtensionSet { @@ -863,16 +814,6 @@ impl<'a> IntoIterator for &'a ExtensionSet { } } -fn as_typevar(e: &ExtensionId) -> Option { - // Type variables are represented as radix-10 numbers, which are illegal - // as standard ExtensionIds. Hence if an ExtensionId starts with a digit, - // we assume it must be a type variable, and fail fast if it isn't. - match e.chars().next() { - Some(c) if c.is_ascii_digit() => Some(str::parse(e).unwrap()), - _ => None, - } -} - impl FromIterator for ExtensionSet { fn from_iter>(iter: I) -> Self { Self(BTreeSet::from_iter(iter)) @@ -967,16 +908,8 @@ pub mod test { type Strategy = BoxedStrategy; fn arbitrary_with(_: Self::Parameters) -> Self::Strategy { - ( - hash_set(0..10usize, 0..3), - hash_set(any::(), 0..3), - ) - .prop_map(|(vars, extensions)| { - ExtensionSet::union_over( - std::iter::once(extensions.into_iter().collect::()) - .chain(vars.into_iter().map(ExtensionSet::type_var)), - ) - }) + hash_set(any::(), 0..3) + .prop_map(|extensions| extensions.into_iter().collect::()) .boxed() } } diff --git a/hugr-core/src/extension/declarative.rs b/hugr-core/src/extension/declarative.rs index 64092981f..14995db27 100644 --- a/hugr-core/src/extension/declarative.rs +++ b/hugr-core/src/extension/declarative.rs @@ -149,9 +149,14 @@ impl ExtensionDeclaration { /// Create an [`Extension`] from this declaration. pub fn make_extension( &self, - imports: &ExtensionSet, + _imports: &ExtensionSet, ctx: DeclarationContext<'_>, ) -> Result, ExtensionDeclarationError> { + // TODO: The imports were previously used as runtime extension + // requirements for the constructed extension. Now that runtime + // extension requirements are removed, they are no longer recorded + // anywhere in the `Extension`. + Extension::try_new_arc( self.name.clone(), // TODO: Get the version as a parameter. @@ -164,7 +169,6 @@ impl ExtensionDeclaration { for o in &self.operations { o.register(ext, ctx, extension_ref)?; } - ext.add_requirements(imports.clone()); Ok(()) }, diff --git a/hugr-core/src/extension/declarative/signature.rs b/hugr-core/src/extension/declarative/signature.rs index b84d56853..e2300956b 100644 --- a/hugr-core/src/extension/declarative/signature.rs +++ b/hugr-core/src/extension/declarative/signature.rs @@ -12,7 +12,7 @@ use serde::{Deserialize, Serialize}; use smol_str::SmolStr; use crate::extension::prelude::PRELUDE_ID; -use crate::extension::{ExtensionSet, SignatureFunc, TypeDef}; +use crate::extension::{SignatureFunc, TypeDef}; use crate::types::type_param::TypeParam; use crate::types::{CustomType, FuncValueType, PolyFuncTypeRV, Type, TypeRowRV}; use crate::Extension; @@ -26,10 +26,6 @@ pub(super) struct SignatureDeclaration { inputs: Vec, /// The outputs of the operation. outputs: Vec, - /// A set of extensions invoked while running this operation. - #[serde(default)] - #[serde(skip_serializing_if = "crate::utils::is_default")] - extensions: ExtensionSet, } impl SignatureDeclaration { @@ -53,7 +49,6 @@ impl SignatureDeclaration { let body = FuncValueType { input: make_type_row(&self.inputs)?, output: make_type_row(&self.outputs)?, - runtime_reqs: self.extensions.clone(), }; let poly_func = PolyFuncTypeRV::new(op_params, body); diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index d5c9a5b5d..48eef663f 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -244,11 +244,7 @@ impl SignatureFunc { // TODO raise warning: https://github.com/CQCL/hugr/issues/1432 SignatureFunc::MissingValidateFunc(ts) => (ts, args), }; - let mut res = pf.instantiate(args)?; - - // Automatically add the extensions where the operation is defined to - // the runtime requirements of the op. - res.runtime_reqs.insert(def.extension.clone()); + let res = pf.instantiate(args)?; // If there are any row variables left, this will fail with an error: res.try_into() @@ -722,8 +718,7 @@ pub(super) mod test { Ok(Signature::new( vec![usize_t(); 3], vec![Type::new_tuple(vec![usize_t(); 3])] - ) - .with_extension_delta(EXT_ID)) + )) ); assert_eq!(def.validate_args(&args, &[]), Ok(())); @@ -733,10 +728,10 @@ pub(super) mod test { let args = [TypeArg::BoundedNat { n: 3 }, tyvar.clone().into()]; assert_eq!( def.compute_signature(&args), - Ok( - Signature::new(tyvars.clone(), vec![Type::new_tuple(tyvars)]) - .with_extension_delta(EXT_ID) - ) + Ok(Signature::new( + tyvars.clone(), + vec![Type::new_tuple(tyvars)] + )) ); def.validate_args(&args, &[TypeBound::Copyable.into()]) .unwrap(); @@ -787,14 +782,11 @@ pub(super) mod test { ), extension_ref, )?; - let tv = Type::new_var_use(1, TypeBound::Copyable); + let tv = Type::new_var_use(0, TypeBound::Copyable); let args = [TypeArg::Type { ty: tv.clone() }]; - let decls = [TypeParam::Extensions, TypeBound::Copyable.into()]; + let decls = [TypeBound::Copyable.into()]; def.validate_args(&args, &decls).unwrap(); - assert_eq!( - def.compute_signature(&args), - Ok(Signature::new_endo(tv).with_extension_delta(EXT_ID)) - ); + assert_eq!(def.compute_signature(&args), Ok(Signature::new_endo(tv))); // But not with an external row variable let arg: TypeArg = TypeRV::new_row_var_use(0, TypeBound::Copyable).into(); assert_eq!( @@ -811,36 +803,6 @@ pub(super) mod test { Ok(()) } - #[test] - fn instantiate_extension_delta() -> Result<(), Box> { - use crate::extension::prelude::bool_t; - - let _ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| { - let params: Vec = vec![TypeParam::Extensions]; - let db_set = ExtensionSet::type_var(0); - let fun_ty = Signature::new_endo(bool_t()).with_extension_delta(db_set); - - let def = ext.add_op( - "SimpleOp".into(), - "".into(), - PolyFuncTypeRV::new(params.clone(), fun_ty), - extension_ref, - )?; - - // Concrete extension set - let es = ExtensionSet::singleton(EXT_ID); - let exp_fun_ty = Signature::new_endo(bool_t()).with_extension_delta(es.clone()); - let args = [TypeArg::Extensions { es }]; - - def.validate_args(&args, ¶ms).unwrap(); - assert_eq!(def.compute_signature(&args), Ok(exp_fun_ty)); - - Ok(()) - })?; - - Ok(()) - } - mod proptest { use std::sync::Weak; diff --git a/hugr-core/src/extension/prelude.rs b/hugr-core/src/extension/prelude.rs index f88a84a0d..b1e78baf8 100644 --- a/hugr-core/src/extension/prelude.rs +++ b/hugr-core/src/extension/prelude.rs @@ -11,7 +11,7 @@ use crate::extension::simple_op::{ try_from_name, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError, }; use crate::extension::{ - ConstFold, ExtensionId, ExtensionSet, OpDef, SignatureError, SignatureFunc, TypeDefBound, + ConstFold, ExtensionId, OpDef, SignatureError, SignatureFunc, TypeDefBound, }; use crate::ops::constant::{CustomCheckFailure, CustomConst, ValueName}; use crate::ops::OpName; @@ -245,10 +245,6 @@ impl CustomConst for ConstString { crate::ops::constant::downcast_equal_consts(self, other) } - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::singleton(PRELUDE_ID) - } - fn get_type(&self) -> Type { string_type() } @@ -438,10 +434,6 @@ impl CustomConst for ConstUsize { crate::ops::constant::downcast_equal_consts(self, other) } - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::singleton(PRELUDE_ID) - } - fn get_type(&self) -> Type { usize_t() } @@ -495,9 +487,6 @@ impl CustomConst for ConstError { crate::ops::constant::downcast_equal_consts(self, other) } - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::singleton(PRELUDE_ID) - } fn get_type(&self) -> Type { error_type() } @@ -555,9 +544,6 @@ impl CustomConst for ConstExternalSymbol { crate::ops::constant::downcast_equal_consts(self, other) } - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::singleton(PRELUDE_ID) - } fn get_type(&self) -> Type { self.typ.clone() } @@ -1068,7 +1054,7 @@ mod test { let optype: OpType = op.clone().into(); assert_eq!( optype.dataflow_signature().unwrap().as_ref(), - &Signature::new_endo(type_row![Type::UNIT]).with_prelude() + &Signature::new_endo(type_row![Type::UNIT]) ); let new_op = Barrier::from_extension_op(optype.as_extension_op().unwrap()).unwrap(); @@ -1121,10 +1107,6 @@ mod test { assert!(error_val.validate().is_ok()); - assert_eq!( - error_val.extension_reqs(), - ExtensionSet::singleton(PRELUDE_ID) - ); assert!(error_val.equal_consts(&ConstError::new(2, "my message"))); assert!(!error_val.equal_consts(&ConstError::new(3, "my message"))); @@ -1181,10 +1163,6 @@ mod test { let string_const: ConstString = ConstString::new("Lorem ipsum".into()); assert_eq!(string_const.name(), "ConstString(\"Lorem ipsum\")"); assert!(string_const.validate().is_ok()); - assert_eq!( - string_const.extension_reqs(), - ExtensionSet::singleton(PRELUDE_ID) - ); assert!(string_const.equal_consts(&ConstString::new("Lorem ipsum".into()))); assert!(!string_const.equal_consts(&ConstString::new("Lorem ispum".into()))); } @@ -1206,10 +1184,6 @@ mod test { assert_eq!(subject.get_type(), Type::UNIT); assert_eq!(subject.name(), "@foo"); assert!(subject.validate().is_ok()); - assert_eq!( - subject.extension_reqs(), - ExtensionSet::singleton(PRELUDE_ID) - ); assert!(subject.equal_consts(&ConstExternalSymbol::new("foo", Type::UNIT, false))); assert!(!subject.equal_consts(&ConstExternalSymbol::new("bar", Type::UNIT, false))); assert!(!subject.equal_consts(&ConstExternalSymbol::new("foo", string_type(), false))); diff --git a/hugr-core/src/extension/prelude/unwrap_builder.rs b/hugr-core/src/extension/prelude/unwrap_builder.rs index 06b4e3939..3817d65c8 100644 --- a/hugr-core/src/extension/prelude/unwrap_builder.rs +++ b/hugr-core/src/extension/prelude/unwrap_builder.rs @@ -111,10 +111,8 @@ mod tests { #[test] fn test_build_unwrap() { - let mut builder = DFGBuilder::new( - Signature::new(Type::from(option_type(bool_t())), bool_t()).with_prelude(), - ) - .unwrap(); + let mut builder = + DFGBuilder::new(Signature::new(Type::from(option_type(bool_t())), bool_t())).unwrap(); let [opt] = builder.input_wires_arr(); diff --git a/hugr-core/src/extension/resolution.rs b/hugr-core/src/extension/resolution.rs index 90eae9422..a08cbfb38 100644 --- a/hugr-core/src/extension/resolution.rs +++ b/hugr-core/src/extension/resolution.rs @@ -9,10 +9,6 @@ //! HUGR nodes and wire types. This is computed from the union of all extension //! required across the HUGR. //! -//! This is distinct from _runtime_ extension requirements, which are defined -//! more granularly in each function signature by the `runtime_reqs` -//! field. See the `extension_inference` feature and related modules for that. -//! //! Note: These procedures are only temporary until `hugr-model` is stabilized. //! Once that happens, hugrs will no longer be directly deserialized using serde //! but instead will be created by the methods in `crate::import`. As these diff --git a/hugr-core/src/extension/resolution/test.rs b/hugr-core/src/extension/resolution/test.rs index 19373b04c..f3ae229ec 100644 --- a/hugr-core/src/extension/resolution/test.rs +++ b/hugr-core/src/extension/resolution/test.rs @@ -11,7 +11,7 @@ use crate::builder::{ Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, ModuleBuilder, }; -use crate::extension::prelude::{bool_t, usize_custom_t, usize_t, ConstUsize, PRELUDE_ID}; +use crate::extension::prelude::{bool_t, usize_custom_t, usize_t, ConstUsize}; use crate::extension::resolution::WeakExtensionRegistry; use crate::extension::resolution::{ resolve_op_extensions, resolve_op_types_extensions, ExtensionCollectionError, @@ -28,7 +28,7 @@ use crate::std_extensions::arithmetic::int_types::{self, int_type}; use crate::std_extensions::collections::list::ListValue; use crate::types::type_param::TypeParam; use crate::types::{PolyFuncType, Signature, Type, TypeArg, TypeBound}; -use crate::{std_extensions, type_row, Extension, Hugr, HugrView}; +use crate::{type_row, Extension, Hugr, HugrView}; #[rstest] #[case::empty(Input { types: type_row![]}, ExtensionRegistry::default())] @@ -158,17 +158,7 @@ fn check_extension_resolution(mut hugr: Hugr) { /// Build a small hugr using the float types extension and check that the extensions are resolved. #[rstest] fn resolve_hugr_extensions_simple() { - let mut build = DFGBuilder::new( - Signature::new(vec![], vec![float64_type()]).with_extension_delta( - [ - PRELUDE_ID.to_owned(), - std_extensions::arithmetic::float_types::EXTENSION_ID.to_owned(), - ] - .into_iter() - .collect::(), - ), - ) - .unwrap(); + let mut build = DFGBuilder::new(Signature::new(vec![], vec![float64_type()])).unwrap(); // A constant op using a non-prelude extension. let f_const = build.add_load_const(Value::extension(ConstF64::new(f64::consts::PI))); @@ -218,7 +208,7 @@ fn resolve_hugr_extensions() { let (ext_b, op_b) = make_extension("dummy.b", "op_b"); let (ext_c, op_c) = make_extension("dummy.c", "op_c"); let (ext_d, op_d) = make_extension("dummy.d", "op_d"); - let (ext_e, op_e) = make_extension("dummy.e", "op_e"); + let (_ext_e, op_e) = make_extension("dummy.e", "op_e"); let mut module = ModuleBuilder::new(); @@ -234,18 +224,7 @@ fn resolve_hugr_extensions() { let mut func = module .define_function( "dummy_fn", - Signature::new(vec![float64_type(), bool_t()], vec![]).with_extension_delta( - [ - ext_a.name(), - ext_b.name(), - ext_c.name(), - ext_d.name(), - ext_e.name(), - ] - .into_iter() - .cloned() - .collect::(), - ), + Signature::new(vec![float64_type(), bool_t()], vec![]), ) .unwrap(); let [func_i0, func_i1] = func.input_wires_arr(); @@ -368,11 +347,7 @@ fn resolve_call() { let dummy_fn = module.declare("called_fn", dummy_fn_sig).unwrap(); let mut func = module - .define_function( - "caller_fn", - Signature::new(vec![], vec![bool_t()]) - .with_extension_delta(ExtensionSet::from_iter(expected_exts.clone())), - ) + .define_function("caller_fn", Signature::new(vec![], vec![bool_t()])) .unwrap(); let _load_func = func.load_func(&dummy_fn, &[generic_type_1]).unwrap(); let call = func.call(&dummy_fn, &[generic_type_2], vec![]).unwrap(); @@ -390,15 +365,10 @@ fn resolve_call() { /// Fail when collecting extensions but the weak pointers are not resolved. #[rstest] fn dropped_weak_extensions() { - let (ext_a, op_a) = make_extension("dummy.a", "op_a"); + let (_ext_a, op_a) = make_extension("dummy.a", "op_a"); let mut func = FunctionBuilder::new( "dummy_fn", - Signature::new(vec![float64_type(), bool_t()], vec![]).with_extension_delta( - [ext_a.name()] - .into_iter() - .cloned() - .collect::(), - ), + Signature::new(vec![float64_type(), bool_t()], vec![]), ) .unwrap(); let [_func_i0, func_i1] = func.input_wires_arr(); diff --git a/hugr-core/src/extension/resolution/types.rs b/hugr-core/src/extension/resolution/types.rs index 6094f0aee..28bd6a12b 100644 --- a/hugr-core/src/extension/resolution/types.rs +++ b/hugr-core/src/extension/resolution/types.rs @@ -131,8 +131,6 @@ pub(crate) fn collect_signature_exts( used_extensions: &mut WeakExtensionRegistry, missing_extensions: &mut ExtensionSet, ) { - // Note that we do not include the signature's `runtime_reqs` here, as those refer - // to _runtime_ requirements that we do not be require to be defined. collect_type_row_exts(&signature.input, used_extensions, missing_extensions); collect_type_row_exts(&signature.output, used_extensions, missing_extensions); } diff --git a/hugr-core/src/extension/resolution/types_mut.rs b/hugr-core/src/extension/resolution/types_mut.rs index d70d6b861..af5803eff 100644 --- a/hugr-core/src/extension/resolution/types_mut.rs +++ b/hugr-core/src/extension/resolution/types_mut.rs @@ -124,8 +124,6 @@ pub(super) fn resolve_signature_exts( extensions: &WeakExtensionRegistry, used_extensions: &mut WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { - // Note that we do not include the signature's `runtime_reqs` here, as those refer - // to _runtime_ requirements that may not be currently present. resolve_type_row_exts(node, &mut signature.input, extensions, used_extensions)?; resolve_type_row_exts(node, &mut signature.output, extensions, used_extensions)?; Ok(()) diff --git a/hugr-core/src/hugr.rs b/hugr-core/src/hugr.rs index 93250b8e3..16152b298 100644 --- a/hugr-core/src/hugr.rs +++ b/hugr-core/src/hugr.rs @@ -29,8 +29,8 @@ use crate::extension::resolution::{ resolve_op_extensions, resolve_op_types_extensions, ExtensionResolutionError, WeakExtensionRegistry, }; -use crate::extension::{ExtensionRegistry, ExtensionSet, TO_BE_INFERRED}; -use crate::ops::{OpTag, OpTrait}; +use crate::extension::{ExtensionRegistry, ExtensionSet}; +use crate::ops::OpTag; pub use crate::ops::{OpType, DEFAULT_OPTYPE}; use crate::{Direction, Node}; @@ -112,9 +112,6 @@ impl Hugr { /// /// Validates the Hugr against the provided extension registry, ensuring all /// operations are resolved. - /// - /// If the feature `extension_inference` is enabled, we will ensure every function - /// correctly specifies the extensions required by its contained ops. pub fn load_json( reader: impl Read, extension_registry: &ExtensionRegistry, @@ -122,87 +119,11 @@ impl Hugr { let mut hugr: Hugr = serde_json::from_reader(reader)?; hugr.resolve_extension_defs(extension_registry)?; - hugr.validate_no_extensions()?; - - if cfg!(feature = "extension_inference") { - hugr.infer_extensions(false)?; - hugr.validate_extensions()?; - } + hugr.validate()?; Ok(hugr) } - /// Infers an extension-delta for any non-function container node - /// whose current [extension_delta] contains [TO_BE_INFERRED]. The inferred delta - /// will be the smallest delta compatible with its children and that includes any - /// other [ExtensionId]s in the current delta. - /// - /// If `remove` is true, for such container nodes *without* [TO_BE_INFERRED], - /// ExtensionIds are removed from the delta if they are *not* used by any child node. - /// - /// The non-function container nodes are: - /// [Case], [CFG], [Conditional], [DataflowBlock], [DFG], [TailLoop] - /// - /// [Case]: crate::ops::Case - /// [CFG]: crate::ops::CFG - /// [Conditional]: crate::ops::Conditional - /// [DataflowBlock]: crate::ops::DataflowBlock - /// [DFG]: crate::ops::DFG - /// [TailLoop]: crate::ops::TailLoop - /// [extension_delta]: crate::ops::OpType::extension_delta - /// [ExtensionId]: crate::extension::ExtensionId - pub fn infer_extensions(&mut self, remove: bool) -> Result<(), ExtensionError> { - fn delta_mut(optype: &mut OpType) -> Option<&mut ExtensionSet> { - match optype { - OpType::DFG(dfg) => Some(&mut dfg.signature.runtime_reqs), - OpType::DataflowBlock(dfb) => Some(&mut dfb.extension_delta), - OpType::TailLoop(tl) => Some(&mut tl.extension_delta), - OpType::CFG(cfg) => Some(&mut cfg.signature.runtime_reqs), - OpType::Conditional(c) => Some(&mut c.extension_delta), - OpType::Case(c) => Some(&mut c.signature.runtime_reqs), - //OpType::Lift(_) // Not ATM: only a single element, and we expect Lift to be removed - //OpType::FuncDefn(_) // Not at present due to the possibility of recursion - _ => None, - } - } - fn infer(h: &mut Hugr, node: Node, remove: bool) -> Result { - let mut child_sets = h - .children(node) - .collect::>() // Avoid borrowing h over recursive call - .into_iter() - .map(|ch| Ok((ch, infer(h, ch, remove)?))) - .collect::, _>>()?; - - let Some(es) = delta_mut(h.op_types.get_mut(node.into_portgraph())) else { - return Ok(h.get_optype(node).extension_delta()); - }; - if es.contains(&TO_BE_INFERRED) { - // Do not remove anything from current delta - any other elements are a lower bound - child_sets.push((node, es.clone())); // "child_sets" now misnamed but we discard fst - } else if remove { - child_sets.iter().try_for_each(|(ch, ch_exts)| { - if !es.is_superset(ch_exts) { - return Err(ExtensionError { - parent: node, - parent_extensions: es.clone(), - child: *ch, - child_extensions: ch_exts.clone(), - }); - } - Ok(()) - })?; - } else { - return Ok(es.clone()); // Can't neither add nor remove, so nothing to do - } - let merged = ExtensionSet::union_over(child_sets.into_iter().map(|(_, e)| e)); - *es = ExtensionSet::singleton(TO_BE_INFERRED).missing_from(&merged); - - Ok(es.clone()) - } - infer(self, self.root(), remove)?; - Ok(()) - } - /// Given a Hugr that has been deserialized, collect all extensions used to /// define the HUGR while resolving all [`OpType::OpaqueOp`] operations into /// [`OpType::ExtensionOp`]s and updating the extension pointer in all @@ -214,11 +135,6 @@ impl Hugr { /// to define the HUGR nodes and wire types. This is computed from the union /// of all extension required across the HUGR. /// - /// This is distinct from _runtime_ extension requirements computed in - /// [`Hugr::infer_extensions`], which are computed more granularly in each - /// function signature by the `runtime_reqs` field and define the set - /// of capabilities required by the runtime to execute each function. - /// /// Updates the internal extension registry with the extensions used in the /// definition. /// @@ -393,73 +309,13 @@ pub enum LoadHugrError { #[cfg(test)] mod test { - use std::sync::Arc; use std::{fs::File, io::BufReader}; - use super::internal::HugrMutInternals; - #[cfg(feature = "extension_inference")] - use super::ValidationError; - use super::{ExtensionError, Hugr, HugrMut, HugrView, Node}; - use crate::extension::{ExtensionId, ExtensionSet, PRELUDE_REGISTRY, TO_BE_INFERRED}; - use crate::ops::{ExtensionOp, OpName}; - use crate::types::type_param::TypeParam; - use crate::types::{ - FuncValueType, PolyFuncTypeRV, Signature, Type, TypeArg, TypeBound, TypeRV, TypeRow, - }; - - use crate::{const_extension_ids, ops, test_file, type_row, Extension}; - use cool_asserts::assert_matches; - use lazy_static::lazy_static; - use rstest::rstest; + use super::{Hugr, HugrView}; + use crate::extension::PRELUDE_REGISTRY; - const_extension_ids! { - pub(crate) const LIFT_EXT_ID: ExtensionId = "LIFT_EXT_ID"; - } - lazy_static! { - /// Tests only extension holding an Op that can add arbitrary extensions to a row. - pub(crate) static ref LIFT_EXT: Arc = { - Extension::new_arc( - LIFT_EXT_ID, - hugr::extension::Version::new(0, 0, 0), - |ext, extension_ref| { - ext.add_op( - OpName::new_inline("Lift"), - "".into(), - PolyFuncTypeRV::new( - vec![TypeParam::Extensions, TypeParam::new_list(TypeBound::Any)], - FuncValueType::new_endo(TypeRV::new_row_var_use(1, TypeBound::Any)) - .with_extension_delta(ExtensionSet::type_var(0)), - ), - extension_ref, - ) - .unwrap(); - }, - ) - }; - } - - pub(crate) fn lift_op( - type_row: impl Into, - extensions: impl Into, - ) -> ExtensionOp { - LIFT_EXT - .instantiate_extension_op( - "Lift", - [ - TypeArg::Extensions { - es: extensions.into(), - }, - TypeArg::Sequence { - elems: type_row - .into() - .iter() - .map(|t| TypeArg::Type { ty: t.clone() }) - .collect(), - }, - ], - ) - .unwrap() - } + use crate::test_file; + use cool_asserts::assert_matches; #[test] fn impls_send_and_sync() { @@ -522,164 +378,4 @@ mod test { ); assert_matches!(&hugr, Ok(_)); } - - const_extension_ids! { - const XA: ExtensionId = "EXT_A"; - const XB: ExtensionId = "EXT_B"; - } - - #[rstest] - #[case([], XA.into())] - #[case([XA], XA.into())] - #[case([XB], ExtensionSet::from_iter([XA, XB]))] - - fn infer_single_delta( - #[case] parent: impl IntoIterator, - #[values(true, false)] remove: bool, // makes no difference when inferring - #[case] result: ExtensionSet, - ) { - let parent = ExtensionSet::from_iter(parent).union(TO_BE_INFERRED.into()); - let (mut h, _) = build_ext_dfg(parent); - h.infer_extensions(remove).unwrap(); - assert_eq!(h, build_ext_dfg(result.union(LIFT_EXT_ID.into())).0); - } - - #[test] - fn infer_removes_from_delta() { - let parent = ExtensionSet::from_iter([XA, XB, LIFT_EXT_ID]); - let mut h = build_ext_dfg(parent.clone()).0; - let backup = h.clone(); - h.infer_extensions(false).unwrap(); - assert_eq!(h, backup); // did nothing - h.infer_extensions(true).unwrap(); - assert_eq!( - h, - build_ext_dfg(ExtensionSet::from_iter([XA, LIFT_EXT_ID])).0 - ); - } - - #[test] - fn infer_bad_remove() { - let (mut h, mid) = build_ext_dfg(XB.into()); - let backup = h.clone(); - h.infer_extensions(false).unwrap(); - assert_eq!(h, backup); // did nothing - let val_res = h.validate(); - let expected_err = ExtensionError { - parent: h.root(), - parent_extensions: XB.into(), - child: mid, - child_extensions: ExtensionSet::from_iter([XA, LIFT_EXT_ID]), - }; - #[cfg(feature = "extension_inference")] - assert_eq!( - val_res, - Err(ValidationError::ExtensionError(expected_err.clone())) - ); - #[cfg(not(feature = "extension_inference"))] - assert!(val_res.is_ok()); - - let inf_res = h.infer_extensions(true); - assert_eq!(inf_res, Err(expected_err)); - } - - fn build_ext_dfg(parent: ExtensionSet) -> (Hugr, Node) { - let ty = Type::new_function(Signature::new_endo(type_row![])); - let mut h = Hugr::new(ops::DFG { - signature: Signature::new_endo(ty.clone()).with_extension_delta(parent.clone()), - }); - let root = h.root(); - let mid = add_inliftout(&mut h, root, ty); - (h, mid) - } - - fn add_inliftout(h: &mut Hugr, p: Node, ty: Type) -> Node { - let inp = h.add_node_with_parent( - p, - ops::Input { - types: ty.clone().into(), - }, - ); - let out = h.add_node_with_parent( - p, - ops::Output { - types: ty.clone().into(), - }, - ); - let mid = h.add_node_with_parent(p, lift_op(ty, XA)); - h.connect(inp, 0, mid, 0); - h.connect(mid, 0, out, 0); - mid - } - - #[rstest] - // Base case success: delta inferred for parent equals grandparent. - #[case([XA], [TO_BE_INFERRED], true, [XA])] - // Success: delta inferred for parent is subset of grandparent - #[case([XA, XB], [TO_BE_INFERRED], true, [XA])] - // Base case failure: infers [XA] for parent but grandparent has disjoint set - #[case([XB], [TO_BE_INFERRED], false, [XA])] - // Failure: as previous, but extra "lower bound" on parent that has no effect - #[case([XB], [XA, TO_BE_INFERRED], false, [XA])] - // Failure: grandparent ok wrt. child but parent specifies extra lower-bound XB - #[case([XA], [XB, TO_BE_INFERRED], false, [XA, XB])] - // Success: grandparent includes extra XB required for parent's "lower bound" - #[case([XA, XB], [XB, TO_BE_INFERRED], true, [XA, XB])] - // Success: grandparent is also inferred so can include 'extra' XB from parent - #[case([TO_BE_INFERRED], [TO_BE_INFERRED, XB], true, [XA, XB])] - // No inference: extraneous XB in parent is removed so all become [XA]. - #[case([XA], [XA, XB], true, [XA])] - fn infer_three_generations( - #[case] grandparent: impl IntoIterator, - #[case] parent: impl IntoIterator, - #[case] success: bool, - #[case] result: impl IntoIterator, - ) { - let ty = Type::new_function(Signature::new_endo(type_row![])); - let grandparent = ExtensionSet::from_iter(grandparent).union(LIFT_EXT_ID.into()); - let parent = ExtensionSet::from_iter(parent).union(LIFT_EXT_ID.into()); - let result = ExtensionSet::from_iter(result).union(LIFT_EXT_ID.into()); - let root_ty = ops::Conditional { - sum_rows: vec![type_row![]], - other_inputs: ty.clone().into(), - outputs: ty.clone().into(), - extension_delta: grandparent.clone(), - }; - let mut h = Hugr::new(root_ty.clone()); - let p = h.add_node_with_parent( - h.root(), - ops::Case { - signature: Signature::new_endo(ty.clone()).with_extension_delta(parent), - }, - ); - add_inliftout(&mut h, p, ty.clone()); - assert!(h.validate_extensions().is_err()); - let backup = h.clone(); - let inf_res = h.infer_extensions(true); - if success { - assert!(inf_res.is_ok()); - let expected_p = ops::Case { - signature: Signature::new_endo(ty).with_extension_delta(result.clone()), - }; - let mut expected = backup; - expected.replace_op(p, expected_p); - let expected_gp = ops::Conditional { - extension_delta: result, - ..root_ty - }; - expected.replace_op(h.root(), expected_gp); - - assert_eq!(h, expected); - } else { - assert_eq!( - inf_res, - Err(ExtensionError { - parent: h.root(), - parent_extensions: grandparent, - child: p, - child_extensions: result - }) - ); - } - } } diff --git a/hugr-core/src/hugr/hugrmut.rs b/hugr-core/src/hugr/hugrmut.rs index 6353820f4..7805d3c67 100644 --- a/hugr-core/src/hugr/hugrmut.rs +++ b/hugr-core/src/hugr/hugrmut.rs @@ -614,9 +614,7 @@ mod test { module, ops::FuncDefn { name: "main".into(), - signature: Signature::new(vec![usize_t()], vec![usize_t(), usize_t()]) - .with_prelude() - .into(), + signature: Signature::new(vec![usize_t()], vec![usize_t(), usize_t()]).into(), }, ); diff --git a/hugr-core/src/hugr/internal.rs b/hugr-core/src/hugr/internal.rs index f69d2ad39..09f234de0 100644 --- a/hugr-core/src/hugr/internal.rs +++ b/hugr-core/src/hugr/internal.rs @@ -372,8 +372,7 @@ mod test { #[test] fn insert_ports() { let (nop, mut hugr) = { - let mut builder = - DFGBuilder::new(Signature::new_endo(Type::UNIT).with_prelude()).unwrap(); + let mut builder = DFGBuilder::new(Signature::new_endo(Type::UNIT)).unwrap(); let [nop_in] = builder.input_wires_arr(); let nop = builder .add_dataflow_op(Noop::new(Type::UNIT), [nop_in]) diff --git a/hugr-core/src/hugr/patch/consts.rs b/hugr-core/src/hugr/patch/consts.rs index eb9142f85..4ddd0b476 100644 --- a/hugr-core/src/hugr/patch/consts.rs +++ b/hugr-core/src/hugr/patch/consts.rs @@ -120,7 +120,6 @@ impl PatchHugrMut for RemoveConst { mod test { use super::*; - use crate::extension::prelude::PRELUDE_ID; use crate::{ builder::{Container, Dataflow, HugrBuilder, ModuleBuilder, SubContainer}, extension::prelude::ConstUsize, @@ -133,10 +132,7 @@ mod test { let mut build = ModuleBuilder::new(); let con_node = build.add_constant(Value::extension(ConstUsize::new(2))); - let mut dfg_build = build.define_function( - "main", - Signature::new_endo(type_row![]).with_extension_delta(PRELUDE_ID.clone()), - )?; + let mut dfg_build = build.define_function("main", Signature::new_endo(type_row![]))?; let load_1 = dfg_build.load_const(&con_node); let load_2 = dfg_build.load_const(&con_node); let tup = dfg_build.make_tuple([load_1, load_2])?; diff --git a/hugr-core/src/hugr/patch/inline_call.rs b/hugr-core/src/hugr/patch/inline_call.rs index 0619d373e..5f31fbc79 100644 --- a/hugr-core/src/hugr/patch/inline_call.rs +++ b/hugr-core/src/hugr/patch/inline_call.rs @@ -121,10 +121,8 @@ mod test { use crate::extension::prelude::usize_t; use crate::ops::handle::{FuncID, NodeHandle}; use crate::ops::{Input, OpType, Value}; - use crate::std_extensions::arithmetic::{ - int_ops::{self, IntOpDef}, - int_types::{self, ConstInt, INT_TYPES}, - }; + use crate::std_extensions::arithmetic::int_types::INT_TYPES; + use crate::std_extensions::arithmetic::{int_ops::IntOpDef, int_types::ConstInt}; use crate::types::{PolyFuncType, Signature, Type, TypeBound}; use crate::{HugrView, Node}; @@ -145,9 +143,7 @@ mod test { fn test_inline() -> Result<(), Box> { let mut mb = ModuleBuilder::new(); let cst3 = mb.add_constant(Value::from(ConstInt::new_u(4, 3)?)); - let sig = Signature::new_endo(INT_TYPES[4].clone()) - .with_extension_delta(int_ops::EXTENSION_ID) - .with_extension_delta(int_types::EXTENSION_ID); + let sig = Signature::new_endo(INT_TYPES[4].clone()); let func = { let mut fb = mb.define_function("foo", sig.clone())?; let c1 = fb.load_const(&cst3); @@ -205,9 +201,7 @@ mod test { #[test] fn test_recursion() -> Result<(), Box> { let mut mb = ModuleBuilder::new(); - let sig = Signature::new_endo(INT_TYPES[5].clone()) - .with_extension_delta(int_ops::EXTENSION_ID) - .with_extension_delta(int_types::EXTENSION_ID); + let sig = Signature::new_endo(INT_TYPES[5].clone()); let (func, rec_call) = { let mut fb = mb.define_function("foo", sig.clone())?; let cst1 = fb.add_load_value(ConstInt::new_u(5, 1)?); @@ -294,10 +288,7 @@ mod test { #[test] fn test_polymorphic() -> Result<(), Box> { let tuple_ty = Type::new_tuple(vec![usize_t(); 2]); - let mut fb = FunctionBuilder::new( - "mkpair", - Signature::new(usize_t(), tuple_ty.clone()).with_prelude(), - )?; + let mut fb = FunctionBuilder::new("mkpair", Signature::new(usize_t(), tuple_ty.clone()))?; let inner = fb.define_function( "id", PolyFuncType::new( diff --git a/hugr-core/src/hugr/patch/inline_dfg.rs b/hugr-core/src/hugr/patch/inline_dfg.rs index 58fd51cbb..c7356f8e0 100644 --- a/hugr-core/src/hugr/patch/inline_dfg.rs +++ b/hugr-core/src/hugr/patch/inline_dfg.rs @@ -145,8 +145,6 @@ mod test { SubContainer, }; use crate::extension::prelude::qb_t; - use crate::extension::ExtensionSet; - use crate::hugr::patch::inline_dfg::InlineDFGError; use crate::hugr::HugrMut; use crate::ops::handle::{DfgID, NodeHandle}; use crate::ops::{OpType, Value}; @@ -175,6 +173,8 @@ mod test { #[case(true)] #[case(false)] fn inline_add_load_const(#[case] nonlocal: bool) -> Result<(), Box> { + use crate::hugr::patch::inline_dfg::InlineDFGError; + let int_ty = &int_types::INT_TYPES[6]; let mut outer = DFGBuilder::new(inout_sig(vec![int_ty.clone(); 2], vec![int_ty.clone()]))?; @@ -334,12 +334,8 @@ mod test { .add_dataflow_op(test_quantum_extension::measure(), r.outputs())? .outputs_arr(); // Node using the boolean. Here we just select between two empty computations. - let mut if_n = inner.conditional_builder_exts( - ([type_row![], type_row![]], b), - [], - type_row![], - ExtensionSet::new(), - )?; + let mut if_n = + inner.conditional_builder(([type_row![], type_row![]], b), [], type_row![])?; if_n.case_builder(0)?.finish_with_outputs([])?; if_n.case_builder(1)?.finish_with_outputs([])?; let if_n = if_n.finish_sub_container()?; diff --git a/hugr-core/src/hugr/patch/outline_cfg.rs b/hugr-core/src/hugr/patch/outline_cfg.rs index b43b6b4e3..b9cafed9e 100644 --- a/hugr-core/src/hugr/patch/outline_cfg.rs +++ b/hugr-core/src/hugr/patch/outline_cfg.rs @@ -6,11 +6,9 @@ use itertools::Itertools; use thiserror::Error; use crate::builder::{BlockBuilder, Container, Dataflow, SubContainer}; -use crate::extension::ExtensionSet; use crate::hugr::{HugrMut, HugrView}; use crate::ops; use crate::ops::controlflow::BasicBlock; -use crate::ops::dataflow::DataflowOpTrait; use crate::ops::handle::NodeHandle; use crate::ops::{DataflowBlock, OpType}; use crate::PortIndex; @@ -33,12 +31,11 @@ impl OutlineCfg { } /// Compute the entry and exit nodes of the CFG which contains - /// [`self.blocks`], along with the output neighbour its parent graph and - /// the combined extension_deltas of all of the blocks. - fn compute_entry_exit_outside_extensions( + /// [`self.blocks`], along with the output neighbour its parent graph. + fn compute_entry_exit( &self, h: &impl HugrView, - ) -> Result<(Node, Node, Node, ExtensionSet), OutlineCfgError> { + ) -> Result<(Node, Node, Node), OutlineCfgError> { let cfg_n = match self .blocks .iter() @@ -50,13 +47,12 @@ impl OutlineCfg { _ => return Err(OutlineCfgError::NotSiblings), }; let o = h.get_optype(cfg_n); - let OpType::CFG(o) = o else { + let OpType::CFG(_) = o else { return Err(OutlineCfgError::ParentNotCfg(cfg_n, o.clone())); }; let cfg_entry = h.children(cfg_n).next().unwrap(); let mut entry = None; let mut exit_succ = None; - let mut extension_delta = ExtensionSet::new(); for &n in self.blocks.iter() { if n == cfg_entry || h.input_neighbours(n) @@ -71,7 +67,6 @@ impl OutlineCfg { } } } - extension_delta = extension_delta.union(o.signature().runtime_reqs.clone()); let external_succs = h.output_neighbours(n).filter(|s| !self.blocks.contains(s)); match external_succs.at_most_one() { Ok(None) => (), // No external successors @@ -87,7 +82,7 @@ impl OutlineCfg { }; } match (entry, exit_succ) { - (Some(e), Some((x, o))) => Ok((e, x, o, extension_delta)), + (Some(e), Some((x, o))) => Ok((e, x, o)), (None, _) => Err(OutlineCfgError::NoEntryNode), (_, None) => Err(OutlineCfgError::NoExitNode), } @@ -98,7 +93,7 @@ impl PatchVerification for OutlineCfg { type Error = OutlineCfgError; type Node = Node; fn verify(&self, h: &impl HugrView) -> Result<(), OutlineCfgError> { - self.compute_entry_exit_outside_extensions(h)?; + self.compute_entry_exit(h)?; Ok(()) } @@ -118,8 +113,7 @@ impl PatchHugrMut for OutlineCfg { self, h: &mut impl HugrMut, ) -> Result<[Node; 2], OutlineCfgError> { - let (entry, exit, outside, extension_delta) = - self.compute_entry_exit_outside_extensions(h)?; + let (entry, exit, outside) = self.compute_entry_exit(h)?; // 1. Compute signature // These panic()s only happen if the Hugr would not have passed validate() let OpType::DataflowBlock(DataflowBlock { inputs, .. }) = h.get_optype(entry) else { @@ -136,17 +130,10 @@ impl PatchHugrMut for OutlineCfg { // 2. new_block contains input node, sub-cfg, exit node all connected let (new_block, cfg_node) = { - let mut new_block_bldr = BlockBuilder::new_exts( - inputs.clone(), - vec![type_row![]], - outputs.clone(), - extension_delta.clone(), - ) - .unwrap(); + let mut new_block_bldr = + BlockBuilder::new(inputs.clone(), vec![type_row![]], outputs.clone()).unwrap(); let wires_in = inputs.iter().cloned().zip(new_block_bldr.input_wires()); - let cfg = new_block_bldr - .cfg_builder_exts(wires_in, outputs, extension_delta) - .unwrap(); + let cfg = new_block_bldr.cfg_builder(wires_in, outputs).unwrap(); let cfg = cfg.finish_sub_container().unwrap(); let unit_sum = new_block_bldr.add_constant(ops::Value::unary_unit_sum()); let pred_wire = new_block_bldr.load_const(&unit_sum); diff --git a/hugr-core/src/hugr/patch/replace.rs b/hugr-core/src/hugr/patch/replace.rs index 183200751..606733543 100644 --- a/hugr-core/src/hugr/patch/replace.rs +++ b/hugr-core/src/hugr/patch/replace.rs @@ -609,21 +609,18 @@ mod test { inputs: vec![listy.clone()].into(), sum_rows: vec![type_row![]], other_outputs: vec![listy.clone()].into(), - extension_delta: list::EXTENSION_ID.into(), }, ); let r_df1 = replacement.add_node_with_parent( r_bb, DFG { - signature: Signature::new(vec![listy.clone()], simple_unary_plus(intermed.clone())) - .with_extension_delta(list::EXTENSION_ID), + signature: Signature::new(vec![listy.clone()], simple_unary_plus(intermed.clone())), }, ); let r_df2 = replacement.add_node_with_parent( r_bb, DFG { - signature: Signature::new(intermed, simple_unary_plus(just_list.clone())) - .with_extension_delta(list::EXTENSION_ID), + signature: Signature::new(intermed, simple_unary_plus(just_list.clone())), }, ); [0, 1] @@ -706,7 +703,7 @@ mod test { }, op_sig.input() ); - h.simple_entry_builder_exts(op_sig.output.clone(), 1, op_sig.runtime_reqs.clone())? + h.simple_entry_builder(op_sig.output.clone(), 1)? } else { h.simple_block_builder(op_sig.into_owned(), 1)? }; @@ -733,25 +730,20 @@ mod test { ext.add_op("baz".into(), "".to_string(), utou.clone(), extension_ref) .unwrap(); }); - let ext_name = ext.name().clone(); let foo = ext.instantiate_extension_op("foo", []).unwrap(); let bar = ext.instantiate_extension_op("bar", []).unwrap(); let baz = ext.instantiate_extension_op("baz", []).unwrap(); let mut registry = test_quantum_extension::REG.clone(); registry.register(ext).unwrap(); - let mut h = DFGBuilder::new( - Signature::new(vec![usize_t(), bool_t()], vec![usize_t()]) - .with_extension_delta(ext_name.clone()), - ) - .unwrap(); + let mut h = + DFGBuilder::new(Signature::new(vec![usize_t(), bool_t()], vec![usize_t()])).unwrap(); let [i, b] = h.input_wires_arr(); let mut cond = h - .conditional_builder_exts( + .conditional_builder( (vec![type_row![]; 2], b), [(usize_t(), i)], vec![usize_t()].into(), - ext_name.clone(), ) .unwrap(); let mut case1 = cond.case_builder(0).unwrap(); @@ -759,12 +751,7 @@ mod test { let case1 = case1.finish_with_outputs(foo.outputs()).unwrap().node(); let mut case2 = cond.case_builder(1).unwrap(); let bar = case2.add_dataflow_op(bar, case2.input_wires()).unwrap(); - let mut baz_dfg = case2 - .dfg_builder( - utou.clone().with_extension_delta(ext_name.clone()), - bar.outputs(), - ) - .unwrap(); + let mut baz_dfg = case2.dfg_builder(utou.clone(), bar.outputs()).unwrap(); let baz = baz_dfg.add_dataflow_op(baz, baz_dfg.input_wires()).unwrap(); let baz_dfg = baz_dfg.finish_with_outputs(baz.outputs()).unwrap(); let case2 = case2.finish_with_outputs(baz_dfg.outputs()).unwrap().node(); diff --git a/hugr-core/src/hugr/patch/simple_replace.rs b/hugr-core/src/hugr/patch/simple_replace.rs index 3908ba58e..245a3cdc0 100644 --- a/hugr-core/src/hugr/patch/simple_replace.rs +++ b/hugr-core/src/hugr/patch/simple_replace.rs @@ -393,7 +393,6 @@ pub(in crate::hugr::patch) mod test { DataflowSubContainer, HugrBuilder, ModuleBuilder, }; use crate::extension::prelude::{bool_t, qb_t}; - use crate::extension::ExtensionSet; use crate::hugr::patch::PatchVerification; use crate::hugr::views::{HugrView, SiblingSubgraph}; use crate::hugr::{Hugr, HugrMut, Patch}; @@ -404,7 +403,7 @@ pub(in crate::hugr::patch) mod test { use crate::std_extensions::logic::test::and_op; use crate::std_extensions::logic::LogicOp; use crate::types::{Signature, Type}; - use crate::utils::test_quantum_extension::{cx_gate, h_gate, EXTENSION_ID}; + use crate::utils::test_quantum_extension::{cx_gate, h_gate}; use crate::{IncomingPort, Node}; use super::SimpleReplacement; @@ -421,12 +420,8 @@ pub(in crate::hugr::patch) mod test { fn make_hugr() -> Result { let mut module_builder = ModuleBuilder::new(); let _f_id = { - let just_q: ExtensionSet = EXTENSION_ID.into(); - let mut func_builder = module_builder.define_function( - "main", - Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]) - .with_extension_delta(just_q.clone()), - )?; + let mut func_builder = module_builder + .define_function("main", Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]))?; let [qb0, qb1, qb2] = func_builder.input_wires_arr(); @@ -462,7 +457,7 @@ pub(in crate::hugr::patch) mod test { /// ┤ H ├┤ X ├ /// └───┘└───┘ fn make_dfg_hugr() -> Result { - let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]).with_prelude())?; + let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?; let [wire0, wire1] = dfg_builder.input_wires_arr(); let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire0])?; let wire3 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?; diff --git a/hugr-core/src/hugr/serialize/test.rs b/hugr-core/src/hugr/serialize/test.rs index 49a7b9321..6848062b7 100644 --- a/hugr-core/src/hugr/serialize/test.rs +++ b/hugr-core/src/hugr/serialize/test.rs @@ -6,10 +6,10 @@ use crate::builder::{ DataflowSubContainer, HugrBuilder, ModuleBuilder, }; use crate::extension::prelude::Noop; -use crate::extension::prelude::{bool_t, qb_t, usize_t, PRELUDE_ID}; +use crate::extension::prelude::{bool_t, qb_t, usize_t}; use crate::extension::simple_op::MakeRegisteredOp; +use crate::extension::test::SimpleOpDef; use crate::extension::ExtensionRegistry; -use crate::extension::{test::SimpleOpDef, ExtensionSet}; use crate::hugr::internal::HugrMutInternals; use crate::hugr::validate::ValidationError; use crate::ops::custom::{ExtensionOp, OpaqueOp, OpaqueOpError}; @@ -300,7 +300,7 @@ fn weighted_hugr_ser() { let t_row = vec![Type::new_sum([vec![usize_t()], vec![qb_t()]])]; let mut f_build = module_builder - .define_function("main", Signature::new(t_row.clone(), t_row).with_prelude()) + .define_function("main", Signature::new(t_row.clone(), t_row)) .unwrap(); let outputs = f_build @@ -324,7 +324,7 @@ fn weighted_hugr_ser() { #[test] fn dfg_roundtrip() -> Result<(), Box> { let tp: Vec = vec![bool_t(); 2]; - let mut dfg = DFGBuilder::new(Signature::new(tp.clone(), tp).with_prelude())?; + let mut dfg = DFGBuilder::new(Signature::new(tp.clone(), tp))?; let mut params: [_; 2] = dfg.input_wires_arr(); for p in params.iter_mut() { *p = dfg @@ -390,8 +390,8 @@ fn opaque_ops() -> Result<(), Box> { #[test] fn function_type() -> Result<(), Box> { - let fn_ty = Type::new_function(Signature::new_endo(vec![bool_t()]).with_prelude()); - let mut bldr = DFGBuilder::new(Signature::new_endo(vec![fn_ty.clone()]).with_prelude())?; + let fn_ty = Type::new_function(Signature::new_endo(vec![bool_t()])); + let mut bldr = DFGBuilder::new(Signature::new_endo(vec![fn_ty.clone()]))?; let op = bldr.add_dataflow_op(Noop(fn_ty), bldr.input_wires())?; let h = bldr.finish_hugr_with_outputs(op.outputs())?; @@ -482,10 +482,8 @@ fn roundtrip_value(#[case] value: Value) { } fn polyfunctype1() -> PolyFuncType { - let mut extension_set = ExtensionSet::new(); - extension_set.insert_type_var(1); - let function_type = Signature::new_endo(type_row![]).with_extension_delta(extension_set); - PolyFuncType::new([TypeParam::max_nat(), TypeParam::Extensions], function_type) + let function_type = Signature::new_endo(type_row![]); + PolyFuncType::new([TypeParam::max_nat()], function_type) } fn polyfunctype2() -> PolyFuncTypeRV { @@ -541,7 +539,7 @@ fn roundtrip_polyfunctype_varlen(#[case] poly_func_type: PolyFuncTypeRV) { #[case(ops::Const::new(Value::function(crate::builder::test::simple_dfg_hugr()).unwrap()))] #[case(ops::Input::new(vec![Type::new_var_use(3,TypeBound::Copyable)]))] #[case(ops::Output::new(vec![Type::new_function(FuncValueType::new_endo(type_row![]))]))] -#[case(ops::Call::try_new(polyfunctype1(), [TypeArg::BoundedNat{n: 1}, TypeArg::Extensions{ es: ExtensionSet::singleton(PRELUDE_ID)} ]).unwrap())] +#[case(ops::Call::try_new(polyfunctype1(), [TypeArg::BoundedNat{n: 1}]).unwrap())] #[case(ops::CallIndirect { signature : Signature::new_endo(vec![bool_t()]) })] fn roundtrip_optype(#[case] optype: impl Into + std::fmt::Debug) { check_testing_roundtrip(NodeSer { diff --git a/hugr-core/src/hugr/serialize/upgrade/test.rs b/hugr-core/src/hugr/serialize/upgrade/test.rs index 5e1d3ee51..e3aa4740b 100644 --- a/hugr-core/src/hugr/serialize/upgrade/test.rs +++ b/hugr-core/src/hugr/serialize/upgrade/test.rs @@ -55,7 +55,6 @@ pub fn hugr_with_named_op() -> Hugr { #[rstest] #[case("empty_hugr", empty_hugr())] #[case("hugr_with_named_op", hugr_with_named_op())] -#[cfg_attr(feature = "extension_inference", ignore = "Fails extension inference")] fn serial_upgrade(#[case] name: String, #[case] hugr: Hugr) { let path = TEST_CASE_DIR.join(format!("{}.json", name)); if !path.exists() { diff --git a/hugr-core/src/hugr/validate.rs b/hugr-core/src/hugr/validate.rs index 3690ec947..31b34b044 100644 --- a/hugr-core/src/hugr/validate.rs +++ b/hugr-core/src/hugr/validate.rs @@ -9,12 +9,12 @@ use petgraph::visit::{Topo, Walker}; use portgraph::{LinkView, PortView}; use thiserror::Error; -use crate::extension::{SignatureError, TO_BE_INFERRED}; +use crate::extension::SignatureError; use crate::ops::constant::ConstTypeError; use crate::ops::custom::{ExtensionOp, OpaqueOpError}; use crate::ops::validate::{ChildrenEdgeData, ChildrenValidationError, EdgeValidationError}; -use crate::ops::{FuncDefn, NamedOp, OpName, OpParent, OpTag, OpTrait, OpType, ValidateOp}; +use crate::ops::{FuncDefn, NamedOp, OpName, OpTag, OpTrait, OpType, ValidateOp}; use crate::types::type_param::TypeParam; use crate::types::EdgeKind; use crate::{Direction, Hugr, Node, Port}; @@ -35,68 +35,15 @@ struct ValidationContext<'a> { } impl Hugr { - /// Check the validity of the HUGR, assuming that it has no open extension - /// variables. - /// TODO: Add a version of validation which allows for open extension - /// variables (see github issue #457) + /// Check the validity of the HUGR. pub fn validate(&self) -> Result<(), ValidationError> { - self.validate_no_extensions()?; - if cfg!(feature = "extension_inference") { - self.validate_extensions()?; - } - Ok(()) - } - - /// Check the validity of the HUGR, but don't check consistency of extension - /// requirements between connected nodes or between parents and children. - pub fn validate_no_extensions(&self) -> Result<(), ValidationError> { let mut validator = ValidationContext::new(self); validator.validate() } - - /// Validate extensions, i.e. that extension deltas from parent nodes are reflected in their children. - pub fn validate_extensions(&self) -> Result<(), ValidationError> { - for parent in self.nodes() { - let parent_op = self.get_optype(parent); - if parent_op.extension_delta().contains(&TO_BE_INFERRED) { - return Err(ValidationError::ExtensionsNotInferred { node: parent }); - } - let parent_extensions = match parent_op.inner_function_type() { - Some(s) => s.runtime_reqs.clone(), - None => match parent_op.tag() { - OpTag::Cfg | OpTag::Conditional => parent_op.extension_delta(), - // ModuleRoot holds but does not execute its children, so allow any extensions - OpTag::ModuleRoot => continue, - _ => { - assert!(self.children(parent).next().is_none(), - "Unknown parent node type {} - not a DataflowParent, Module, Cfg or Conditional", - parent_op); - continue; - } - }, - }; - for child in self.children(parent) { - let child_extensions = self.get_optype(child).extension_delta(); - if !parent_extensions.is_superset(&child_extensions) { - return Err(ExtensionError { - parent, - parent_extensions, - child, - child_extensions, - } - .into()); - } - } - } - Ok(()) - } } impl<'a> ValidationContext<'a> { /// Create a new validation context. - // Allow unused "extension_closure" variable for when - // the "extension_inference" feature is disabled. - #[allow(unused_variables)] pub fn new(hugr: &'a Hugr) -> Self { let dominators = HashMap::new(); Self { hugr, dominators } diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index 236f40e3f..7fec75bce 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -11,8 +11,8 @@ use crate::builder::{ FunctionBuilder, HugrBuilder, ModuleBuilder, SubContainer, }; use crate::extension::prelude::Noop; -use crate::extension::prelude::{bool_t, qb_t, usize_t, PRELUDE_ID}; -use crate::extension::{Extension, ExtensionRegistry, ExtensionSet, TypeDefBound, PRELUDE}; +use crate::extension::prelude::{bool_t, qb_t, usize_t}; +use crate::extension::{Extension, ExtensionRegistry, TypeDefBound, PRELUDE}; use crate::hugr::internal::HugrMutInternals; use crate::hugr::HugrMut; use crate::ops::dataflow::IOTrait; @@ -35,9 +35,7 @@ use crate::{ fn make_simple_hugr(copies: usize) -> (Hugr, Node) { let def_op: OpType = ops::FuncDefn { name: "main".into(), - signature: Signature::new(vec![bool_t()], vec![bool_t(); copies]) - .with_prelude() - .into(), + signature: Signature::new(vec![bool_t()], vec![bool_t(); copies]).into(), } .into(); @@ -119,7 +117,7 @@ fn leaf_root() { #[test] fn dfg_root() { let dfg_op: OpType = ops::DFG { - signature: Signature::new_endo(vec![bool_t()]).with_prelude(), + signature: Signature::new_endo(vec![bool_t()]), } .into(); @@ -217,17 +215,14 @@ fn df_children_restrictions() { #[test] fn test_ext_edge() { - let mut h = closed_dfg_root_hugr( - Signature::new(vec![bool_t(), bool_t()], vec![bool_t()]) - .with_extension_delta(TO_BE_INFERRED), - ); + let mut h = closed_dfg_root_hugr(Signature::new(vec![bool_t(), bool_t()], vec![bool_t()])); let [input, output] = h.get_io(h.root()).unwrap(); // Nested DFG bool_t() -> bool_t() let sub_dfg = h.add_node_with_parent( h.root(), ops::DFG { - signature: Signature::new_endo(vec![bool_t()]).with_extension_delta(TO_BE_INFERRED), + signature: Signature::new_endo(vec![bool_t()]), }, ); // this Xor has its 2nd input unconnected @@ -254,7 +249,6 @@ fn test_ext_edge() { ); //Order edge. This will need metadata indicating its purpose. h.add_other_edge(input, sub_dfg); - h.infer_extensions(false).unwrap(); h.validate().unwrap(); } @@ -289,8 +283,7 @@ fn no_ext_edge_into_func() -> Result<(), Box> { #[test] fn test_local_const() { - let mut h = - closed_dfg_root_hugr(Signature::new_endo(bool_t()).with_extension_delta(TO_BE_INFERRED)); + let mut h = closed_dfg_root_hugr(Signature::new_endo(bool_t())); let [input, output] = h.get_io(h.root()).unwrap(); let and = h.add_node_with_parent(h.root(), and_op()); h.connect(input, 0, and, 0); @@ -312,7 +305,6 @@ fn test_local_const() { h.connect(lcst, 0, and, 1); assert_eq!(h.static_source(lcst), Some(cst)); // There is no edge from Input to LoadConstant, but that's OK: - h.infer_extensions(false).unwrap(); h.validate().unwrap(); } @@ -549,11 +541,7 @@ fn no_polymorphic_consts() -> Result<(), Box> { reg.validate()?; let mut def = FunctionBuilder::new( "myfunc", - PolyFuncType::new( - [BOUND], - Signature::new(vec![], vec![list_of_var.clone()]) - .with_extension_delta(list::EXTENSION_ID), - ), + PolyFuncType::new([BOUND], Signature::new(vec![], vec![list_of_var.clone()])), )?; let empty_list = Value::extension(list::ListValue::new_empty(Type::new_var_use( 0, @@ -646,7 +634,7 @@ fn row_variables() -> Result<(), Box> { "id", PolyFuncType::new( [TypeParam::new_list(TypeBound::Any)], - Signature::new(inner_ft.clone(), ft_usz).with_extension_delta(e.name.clone()), + Signature::new(inner_ft.clone(), ft_usz), ), )?; // All the wires here are carrying higher-order Function values @@ -668,19 +656,15 @@ fn row_variables() -> Result<(), Box> { #[test] fn test_polymorphic_call() -> Result<(), Box> { + // TODO: This tests a function call that is polymorphic in an extension set. + // Should this be rewritten to be polymorphic in something else or removed? + let e = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| { - let params: Vec = vec![ - TypeBound::Any.into(), - TypeParam::Extensions, - TypeBound::Any.into(), - ]; - let evaled_fn = Type::new_function( - Signature::new( - Type::new_var_use(0, TypeBound::Any), - Type::new_var_use(2, TypeBound::Any), - ) - .with_extension_delta(ExtensionSet::type_var(1)), - ); + let params: Vec = vec![TypeBound::Any.into(), TypeBound::Any.into()]; + let evaled_fn = Type::new_function(Signature::new( + Type::new_var_use(0, TypeBound::Any), + Type::new_var_use(1, TypeBound::Any), + )); // Single-input/output version of the higher-order "eval" operation, with extension param. // Note the extension-delta of the eval node includes that of the input function. ext.add_op( @@ -690,9 +674,8 @@ fn test_polymorphic_call() -> Result<(), Box> { params.clone(), Signature::new( vec![evaled_fn, Type::new_var_use(0, TypeBound::Any)], - Type::new_var_use(2, TypeBound::Any), - ) - .with_extension_delta(ExtensionSet::type_var(1)), + Type::new_var_use(1, TypeBound::Any), + ), ), extension_ref, )?; @@ -700,27 +683,23 @@ fn test_polymorphic_call() -> Result<(), Box> { Ok(()) })?; - fn utou(e: impl Into) -> Type { - Type::new_function(Signature::new_endo(usize_t()).with_extension_delta(e.into())) + fn utou() -> Type { + Type::new_function(Signature::new_endo(usize_t())) } let int_pair = Type::new_tuple(vec![usize_t(); 2]); - // Root DFG: applies a function int--PRELUDE-->int to each element of a pair of two ints + // Root DFG: applies a function int-->int to each element of a pair of two ints let mut d = DFGBuilder::new(inout_sig( - vec![utou(PRELUDE_ID), int_pair.clone()], + vec![utou(), int_pair.clone()], vec![int_pair.clone()], ))?; - // ....by calling a function parametrized (int--e-->int, int_pair) -> int_pair + // ....by calling a function (int-->int, int_pair) -> int_pair let f = { - let es = ExtensionSet::type_var(0); let mut f = d.define_function( "two_ints", PolyFuncType::new( - vec![TypeParam::Extensions], - Signature::new(vec![utou(es.clone()), int_pair.clone()], int_pair.clone()) - .with_extension_delta(EXT_ID) - .with_prelude() - .with_extension_delta(es.clone()), + vec![], + Signature::new(vec![utou(), int_pair.clone()], int_pair.clone()), ), )?; let [func, tup] = f.input_wires_arr(); @@ -731,14 +710,7 @@ fn test_polymorphic_call() -> Result<(), Box> { )?; let mut cc = c.case_builder(0)?; let [i1, i2] = cc.input_wires_arr(); - let op = e.instantiate_extension_op( - "eval", - vec![ - usize_t().into(), - TypeArg::Extensions { es }, - usize_t().into(), - ], - )?; + let op = e.instantiate_extension_op("eval", vec![usize_t().into(), usize_t().into()])?; let [f1] = cc.add_dataflow_op(op.clone(), [func, i1])?.outputs_arr(); let [f2] = cc.add_dataflow_op(op, [func, i2])?.outputs_arr(); cc.finish_with_outputs([f1, f2])?; @@ -748,18 +720,10 @@ fn test_polymorphic_call() -> Result<(), Box> { }; let [func, tup] = d.input_wires_arr(); - let call = d.call( - f.handle(), - &[TypeArg::Extensions { - es: ExtensionSet::singleton(PRELUDE_ID), - }], - [func, tup], - )?; + let call = d.call(f.handle(), &[], [func, tup])?; let h = d.finish_hugr_with_outputs(call.outputs())?; let call_ty = h.get_optype(call.node()).dataflow_signature().unwrap(); - let exp_fun_ty = Signature::new(vec![utou(PRELUDE_ID), int_pair.clone()], int_pair) - .with_extension_delta(EXT_ID) - .with_prelude(); + let exp_fun_ty = Signature::new(vec![utou(), int_pair.clone()], int_pair); assert_eq!(call_ty.as_ref(), &exp_fun_ty); Ok(()) } @@ -817,7 +781,6 @@ fn cfg_children_restrictions() { inputs: vec![bool_t()].into(), sum_rows: vec![type_row![]], other_outputs: vec![bool_t()].into(), - extension_delta: ExtensionSet::new(), }, ); let const_op: ops::Const = ops::Value::unit_sum(0, 1).unwrap().into(); @@ -872,7 +835,6 @@ fn cfg_children_restrictions() { inputs: vec![qb_t()].into(), sum_rows: vec![type_row![]], other_outputs: vec![qb_t()].into(), - extension_delta: ExtensionSet::new(), }, ); let mut block_children = b.hierarchy.children(block.into_portgraph()); @@ -899,8 +861,7 @@ fn cfg_connections() -> Result<(), Box> { let mut hugr = CFGBuilder::new(Signature::new_endo(usize_t()))?; let unary_pred = hugr.add_constant(Value::unary_unit_sum()); - let mut entry = - hugr.simple_entry_builder_exts(vec![usize_t()].into(), 1, ExtensionSet::new())?; + let mut entry = hugr.simple_entry_builder(vec![usize_t()].into(), 1)?; let p = entry.load_const(&unary_pred); let ins = entry.input_wires(); let entry = entry.finish_with_outputs(p, ins)?; @@ -944,219 +905,3 @@ fn cfg_entry_io_bug() -> Result<(), Box> { Ok(()) } - -#[cfg(feature = "extension_inference")] -mod extension_tests { - use self::ops::handle::{BasicBlockID, TailLoopID}; - use rstest::rstest; - - use super::*; - use crate::builder::handle::Outputs; - use crate::builder::{BlockBuilder, BuildHandle, CFGBuilder, DFGWrapper, TailLoopBuilder}; - use crate::extension::prelude::PRELUDE_ID; - use crate::extension::ExtensionSet; - use crate::hugr::test::{lift_op, LIFT_EXT_ID}; - use crate::macros::const_extension_ids; - use crate::Wire; - const_extension_ids! { - const XA: ExtensionId = "A"; - const XB: ExtensionId = "BOOL_EXT"; - } - - #[rstest] - #[case::d1(|signature| ops::DFG {signature}.into())] - #[case::f1(|sig: Signature| ops::FuncDefn {name: "foo".to_string(), signature: sig.into()}.into())] - #[case::c1(|signature| ops::Case {signature}.into())] - fn parent_extension_mismatch( - #[case] parent_f: impl Fn(Signature) -> OpType, - #[values(ExtensionSet::new(), XA.into())] parent_extensions: ExtensionSet, - ) { - // Child graph adds extension "XB", but the parent (in all cases) - // declares a different delta, causing a mismatch. - - let parent = parent_f( - Signature::new_endo(usize_t()).with_extension_delta(parent_extensions.clone()), - ); - let mut hugr = Hugr::new(parent); - - let input = hugr.add_node_with_parent( - hugr.root(), - ops::Input { - types: vec![usize_t()].into(), - }, - ); - let output = hugr.add_node_with_parent( - hugr.root(), - ops::Output { - types: vec![usize_t()].into(), - }, - ); - - let lift = hugr.add_node_with_parent(hugr.root(), lift_op(usize_t(), XB)); - - hugr.connect(input, 0, lift, 0); - hugr.connect(lift, 0, output, 0); - - let result = hugr.validate(); - assert_eq!( - result, - Err(ValidationError::ExtensionError(ExtensionError { - parent: hugr.root(), - parent_extensions, - child: lift, - child_extensions: ExtensionSet::from_iter([LIFT_EXT_ID, XB]), - })) - ); - } - - #[rstest] - #[case(XA.into(), false)] - #[case(ExtensionSet::new(), false)] - #[case(ExtensionSet::from_iter([XA, XB]), true)] - fn cfg_extension_mismatch( - #[case] parent_extensions: ExtensionSet, - #[case] success: bool, - ) -> Result<(), BuildError> { - let mut cfg = CFGBuilder::new( - Signature::new_endo(usize_t()).with_extension_delta(parent_extensions.clone()), - )?; - let mut bb = cfg.simple_entry_builder_exts(usize_t().into(), 1, XB)?; - let pred = bb.add_load_value(Value::unary_unit_sum()); - let inputs = bb.input_wires(); - let blk = bb.finish_with_outputs(pred, inputs)?; - let exit = cfg.exit_block(); - cfg.branch(&blk, 0, &exit)?; - let root = cfg.hugr().root(); - let res = cfg.finish_hugr(); - if success { - assert!(res.is_ok()) - } else { - assert_eq!( - res, - Err(ValidationError::ExtensionError(ExtensionError { - parent: root, - parent_extensions, - child: blk.node(), - child_extensions: XB.into() - })) - ); - } - Ok(()) - } - - #[rstest] - #[case(XA.into(), false)] - #[case(ExtensionSet::new(), false)] - #[case(ExtensionSet::from_iter([XA, XB, LIFT_EXT_ID]), true)] - fn conditional_extension_mismatch( - #[case] parent_extensions: ExtensionSet, - #[case] success: bool, - ) { - // Child graph adds extension "XB", but the parent - // declares a different delta, in same cases causing a mismatch. - let parent = ops::Conditional { - sum_rows: vec![type_row![], type_row![]], - other_inputs: vec![usize_t()].into(), - outputs: vec![usize_t()].into(), - extension_delta: parent_extensions.clone(), - }; - let mut hugr = Hugr::new(parent); - - // First case with no delta should be ok in all cases. Second one may not be. - let [_, child] = [None, Some(XB)].map(|case_ext| { - let case_exts = if let Some(ex) = &case_ext { - ExtensionSet::from_iter([ex.clone(), LIFT_EXT_ID]) - } else { - ExtensionSet::new() - }; - let case = hugr.add_node_with_parent( - hugr.root(), - ops::Case { - signature: Signature::new_endo(usize_t()).with_extension_delta(case_exts), - }, - ); - - let input = hugr.add_node_with_parent( - case, - ops::Input { - types: vec![usize_t()].into(), - }, - ); - let output = hugr.add_node_with_parent( - case, - ops::Output { - types: vec![usize_t()].into(), - }, - ); - let res = match case_ext { - None => input, - Some(new_ext) => { - let lift = hugr.add_node_with_parent(case, lift_op(usize_t(), new_ext)); - hugr.connect(input, 0, lift, 0); - lift - } - }; - hugr.connect(res, 0, output, 0); - case - }); - // case is the last-assigned child, i.e. the one that requires 'XB' - let result = hugr.validate(); - let expected = if success { - Ok(()) - } else { - Err(ValidationError::ExtensionError(ExtensionError { - parent: hugr.root(), - parent_extensions, - child, - child_extensions: ExtensionSet::from_iter([XB, LIFT_EXT_ID]), - })) - }; - assert_eq!(result, expected); - } - - #[rstest] - #[case(make_bb, |bb: &mut DFGWrapper<_,_>, outs| bb.make_tuple(outs))] - #[case(make_tailloop, |tl: &mut DFGWrapper<_,_>, outs| tl.make_break(tl.loop_signature().unwrap().clone(), outs))] - fn bb_extension_mismatch( - #[case] dfg_fn: impl Fn(Type, ExtensionSet) -> DFGWrapper, - #[case] make_pred: impl Fn(&mut DFGWrapper, Outputs) -> Result, - // last one includes prelude because `MakeTuple` is in prelude - #[values((ExtensionSet::from_iter([XA,LIFT_EXT_ID]), false), (LIFT_EXT_ID.into(), false), (ExtensionSet::from_iter([XA,XB,LIFT_EXT_ID,PRELUDE_ID]), true))] - parent_exts_success: (ExtensionSet, bool), - ) -> Result<(), BuildError> { - let (parent_extensions, success) = parent_exts_success; - let mut dfg = dfg_fn(usize_t(), parent_extensions.clone()); - let lift = dfg.add_dataflow_op(lift_op(usize_t(), XB), dfg.input_wires())?; - let pred = make_pred(&mut dfg, lift.outputs())?; - let root = dfg.hugr().root(); - let res = dfg.finish_hugr_with_outputs([pred]); - if success { - if res.is_err() { - dbg!(&res); - } - assert!(res.is_ok()) - } else { - assert_eq!( - res, - Err(BuildError::InvalidHUGR(ValidationError::ExtensionError( - ExtensionError { - parent: root, - parent_extensions, - child: lift.node(), - child_extensions: ExtensionSet::from_iter([XB, LIFT_EXT_ID]) - } - ))) - ); - } - Ok(()) - } - - fn make_bb(t: Type, es: ExtensionSet) -> DFGWrapper { - BlockBuilder::new_exts(t.clone(), vec![t.into()], type_row![], es).unwrap() - } - - fn make_tailloop(t: Type, es: ExtensionSet) -> DFGWrapper> { - let row = TypeRow::from(t); - TailLoopBuilder::new_exts(row.clone(), type_row![], row, es).unwrap() - } -} diff --git a/hugr-core/src/hugr/views.rs b/hugr-core/src/hugr/views.rs index f9eedd548..ea414c376 100644 --- a/hugr-core/src/hugr/views.rs +++ b/hugr-core/src/hugr/views.rs @@ -400,25 +400,10 @@ pub trait HugrView: HugrInternals { fn extensions(&self) -> &ExtensionRegistry; /// Check the validity of the underlying HUGR. - /// - /// This includes checking consistency of extension requirements between - /// connected nodes and between parents and children. - /// See [`HugrView::validate_no_extensions`] for a version that doesn't check - /// extension requirements. fn validate(&self) -> Result<(), ValidationError> { #[allow(deprecated)] self.base_hugr().validate() } - - /// Check the validity of the underlying HUGR, but don't check consistency - /// of extension requirements between connected nodes or between parents and - /// children. - /// - /// For a more thorough check, use [`HugrView::validate`]. - fn validate_no_extensions(&self) -> Result<(), ValidationError> { - #[allow(deprecated)] - self.base_hugr().validate_no_extensions() - } } /// A common trait for views of a HUGR hierarchical subgraph. diff --git a/hugr-core/src/hugr/views/descendants.rs b/hugr-core/src/hugr/views/descendants.rs index e3ba29e2c..13dfde8f7 100644 --- a/hugr-core/src/hugr/views/descendants.rs +++ b/hugr-core/src/hugr/views/descendants.rs @@ -236,7 +236,7 @@ pub(super) mod test { use crate::{ builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}, types::Signature, - utils::test_quantum_extension::{h_gate, EXTENSION_ID}, + utils::test_quantum_extension::h_gate, }; use super::*; @@ -249,10 +249,8 @@ pub(super) mod test { let mut module_builder = ModuleBuilder::new(); let (f_id, inner_id) = { - let mut func_builder = module_builder.define_function( - "main", - Signature::new_endo(vec![usize_t(), qb_t()]).with_extension_delta(EXTENSION_ID), - )?; + let mut func_builder = module_builder + .define_function("main", Signature::new_endo(vec![usize_t(), qb_t()]))?; let [int, qb] = func_builder.input_wires_arr(); @@ -288,11 +286,7 @@ pub(super) mod test { assert_eq!( region.poly_func_type(), - Some( - Signature::new_endo(vec![usize_t(), qb_t()]) - .with_extension_delta(EXTENSION_ID) - .into() - ) + Some(Signature::new_endo(vec![usize_t(), qb_t()]).into()) ); let inner_region: DescendantsGraph = DescendantsGraph::try_new(&hugr, inner)?; diff --git a/hugr-core/src/hugr/views/impls.rs b/hugr-core/src/hugr/views/impls.rs index 6cd1d7631..9be352b5e 100644 --- a/hugr-core/src/hugr/views/impls.rs +++ b/hugr-core/src/hugr/views/impls.rs @@ -72,7 +72,6 @@ macro_rules! hugr_view_methods { fn value_types(&self, node: Self::Node, dir: crate::Direction) -> impl Iterator; fn extensions(&self) -> &crate::extension::ExtensionRegistry; fn validate(&self) -> Result<(), crate::hugr::ValidationError>; - fn validate_no_extensions(&self) -> Result<(), crate::hugr::ValidationError>; } } } diff --git a/hugr-core/src/hugr/views/sibling.rs b/hugr-core/src/hugr/views/sibling.rs index 44e29ab1a..fa8378c7a 100644 --- a/hugr-core/src/hugr/views/sibling.rs +++ b/hugr-core/src/hugr/views/sibling.rs @@ -480,7 +480,6 @@ mod test { use crate::ops::OpType; use crate::ops::{dataflow::IOTrait, Input, OpTag, Output}; use crate::types::Signature; - use crate::utils::test_quantum_extension::EXTENSION_ID; use crate::IncomingPort; use super::super::descendants::test::make_module_hgr; @@ -507,11 +506,7 @@ mod test { assert_eq!( region.poly_func_type(), - Some( - Signature::new_endo(vec![usize_t(), qb_t()]) - .with_extension_delta(EXTENSION_ID) - .into() - ) + Some(Signature::new_endo(vec![usize_t(), qb_t()]).into()) ); assert_eq!( diff --git a/hugr-core/src/hugr/views/sibling_subgraph.rs b/hugr-core/src/hugr/views/sibling_subgraph.rs index 7fd2b9f54..b2eba044e 100644 --- a/hugr-core/src/hugr/views/sibling_subgraph.rs +++ b/hugr-core/src/hugr/views/sibling_subgraph.rs @@ -21,7 +21,6 @@ use thiserror::Error; use crate::builder::{Container, FunctionBuilder}; use crate::core::HugrNode; -use crate::extension::ExtensionSet; use crate::hugr::{HugrMut, HugrView}; use crate::ops::dataflow::DataflowOpTrait; use crate::ops::handle::{ContainerHandle, DataflowOpID}; @@ -349,11 +348,7 @@ impl SiblingSubgraph { sig.port_type(p).cloned().expect("must be dataflow edge") }) .collect_vec(); - Signature::new(input, output).with_extension_delta(ExtensionSet::union_over( - self.nodes - .iter() - .map(|n| hugr.get_optype(*n).extension_delta()), - )) + Signature::new(input, output) } /// The parent of the sibling subgraph. @@ -840,10 +835,10 @@ mod tests { use crate::builder::inout_sig; use crate::hugr::Patch; use crate::ops::Const; - use crate::std_extensions::arithmetic::float_types::{self, ConstF64}; - use crate::std_extensions::logic::{self, LogicOp}; + use crate::std_extensions::arithmetic::float_types::ConstF64; + use crate::std_extensions::logic::LogicOp; use crate::type_row; - use crate::utils::test_quantum_extension::{self, cx_gate, rz_f64}; + use crate::utils::test_quantum_extension::{cx_gate, rz_f64}; use crate::{ builder::{ BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, @@ -889,12 +884,7 @@ mod tests { let mut mod_builder = ModuleBuilder::new(); let func = mod_builder.declare( "test", - Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]) - .with_extension_delta(ExtensionSet::from_iter([ - test_quantum_extension::EXTENSION_ID, - float_types::EXTENSION_ID, - ])) - .into(), + Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]).into(), )?; let func_id = { let mut dfg = mod_builder.define_declaration(&func)?; @@ -913,12 +903,7 @@ mod tests { /// A bool to bool hugr with three subsequent NOT gates. fn build_3not_hugr() -> Result<(Hugr, Node), BuildError> { let mut mod_builder = ModuleBuilder::new(); - let func = mod_builder.declare( - "test", - Signature::new_endo(vec![bool_t()]) - .with_extension_delta(logic::EXTENSION_ID) - .into(), - )?; + let func = mod_builder.declare("test", Signature::new_endo(vec![bool_t()]).into())?; let func_id = { let mut dfg = mod_builder.define_declaration(&func)?; let outs1 = dfg.add_dataflow_op(LogicOp::Not, dfg.input_wires())?; @@ -937,9 +922,7 @@ mod tests { let mut mod_builder = ModuleBuilder::new(); let func = mod_builder.declare( "test", - Signature::new(bool_t(), vec![bool_t(), bool_t()]) - .with_extension_delta(logic::EXTENSION_ID) - .into(), + Signature::new(bool_t(), vec![bool_t(), bool_t()]).into(), )?; let func_id = { let mut dfg = mod_builder.define_declaration(&func)?; @@ -957,12 +940,7 @@ mod tests { /// A HUGR with a copy fn build_hugr_classical() -> Result<(Hugr, Node), BuildError> { let mut mod_builder = ModuleBuilder::new(); - let func = mod_builder.declare( - "test", - Signature::new_endo(bool_t()) - .with_extension_delta(logic::EXTENSION_ID) - .into(), - )?; + let func = mod_builder.declare("test", Signature::new_endo(bool_t()).into())?; let func_id = { let mut dfg = mod_builder.define_declaration(&func)?; let in_wire = dfg.input_wires().exactly_one().unwrap(); @@ -1024,12 +1002,7 @@ mod tests { let sub = SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID>(&func)?; assert_eq!( sub.signature(&func), - Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]).with_extension_delta( - ExtensionSet::from_iter([ - test_quantum_extension::EXTENSION_ID, - float_types::EXTENSION_ID, - ]) - ) + Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]) ); Ok(()) } @@ -1218,12 +1191,7 @@ mod tests { #[test] fn test_unconnected() { // test a replacement on a subgraph with a discarded output - let mut b = DFGBuilder::new( - Signature::new(bool_t(), type_row![]) - // .with_prelude() - .with_extension_delta(crate::std_extensions::logic::EXTENSION_ID), - ) - .unwrap(); + let mut b = DFGBuilder::new(Signature::new(bool_t(), type_row![])).unwrap(); let inw = b.input_wires().exactly_one().unwrap(); let not_n = b.add_dataflow_op(LogicOp::Not, [inw]).unwrap(); // Unconnected output, discarded @@ -1234,11 +1202,7 @@ mod tests { assert_eq!(subg.nodes().len(), 1); // TODO create a valid replacement let replacement = { - let mut rep_b = DFGBuilder::new( - Signature::new_endo(bool_t()) - .with_extension_delta(crate::std_extensions::logic::EXTENSION_ID), - ) - .unwrap(); + let mut rep_b = DFGBuilder::new(Signature::new_endo(bool_t())).unwrap(); let inw = rep_b.input_wires().exactly_one().unwrap(); let not_n = rep_b.add_dataflow_op(LogicOp::Not, [inw]).unwrap(); @@ -1253,11 +1217,7 @@ mod tests { #[test] fn single_node_subgraph() { // A hugr with a single NOT operation, with disconnected output. - let mut b = DFGBuilder::new( - Signature::new(bool_t(), type_row![]) - .with_extension_delta(crate::std_extensions::logic::EXTENSION_ID), - ) - .unwrap(); + let mut b = DFGBuilder::new(Signature::new(bool_t(), type_row![])).unwrap(); let inw = b.input_wires().exactly_one().unwrap(); let not_n = b.add_dataflow_op(LogicOp::Not, [inw]).unwrap(); // Unconnected output, discarded diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 899deb17d..ce5971364 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -6,7 +6,7 @@ use std::sync::Arc; use crate::{ - extension::{ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError}, + extension::{ExtensionId, ExtensionRegistry, SignatureError}, hugr::{HugrMut, NodeMetadata}, ops::{ constant::{CustomConst, CustomSerialized, OpaqueValue}, @@ -791,7 +791,6 @@ impl<'a> Context<'a> { just_inputs, just_outputs, rest, - extension_delta: ExtensionSet::new(), }); let node = self.make_node(node_id, optype, parent)?; @@ -819,7 +818,6 @@ impl<'a> Context<'a> { sum_rows, other_inputs, outputs, - extension_delta: ExtensionSet::new(), }); let node = self.make_node(node_id, optype, parent)?; @@ -887,7 +885,6 @@ impl<'a> Context<'a> { inputs: types.clone(), other_outputs: TypeRow::default(), sum_rows: vec![types.clone()], - extension_delta: ExtensionSet::default(), }), ); @@ -988,7 +985,6 @@ impl<'a> Context<'a> { inputs, other_outputs, sum_rows, - extension_delta: ExtensionSet::new(), }); let node = self.make_node(node_id, optype, parent)?; @@ -1491,7 +1487,7 @@ impl<'a> Context<'a> { let runtime_type = self.import_type(runtime_type)?; let value: serde_json::Value = serde_json::from_str(json) .map_err(|_| table::ModelError::TypeError(term_id))?; - let custom_const = CustomSerialized::new(runtime_type, value, ExtensionSet::new()); + let custom_const = CustomSerialized::new(runtime_type, value); let opaque_value = OpaqueValue::new(custom_const); return Ok(Value::Extension { e: opaque_value }); } diff --git a/hugr-core/src/ops.rs b/hugr-core/src/ops.rs index ce0d44de0..5b5dbc420 100644 --- a/hugr-core/src/ops.rs +++ b/hugr-core/src/ops.rs @@ -16,7 +16,7 @@ use crate::extension::resolution::{ use std::borrow::Cow; use crate::extension::simple_op::MakeExtensionOp; -use crate::extension::{ExtensionId, ExtensionRegistry, ExtensionSet}; +use crate::extension::{ExtensionId, ExtensionRegistry}; use crate::types::{EdgeKind, Signature, Substitution}; use crate::{Direction, OutgoingPort, Port}; use crate::{IncomingPort, PortIndex}; @@ -398,12 +398,6 @@ pub trait OpTrait: Sized + Clone { None } - /// The delta between the input extensions specified for a node, - /// and the output extensions calculated for that node - fn extension_delta(&self) -> ExtensionSet { - ExtensionSet::new() - } - /// The edge kind for the non-dataflow inputs of the operation, /// not described by the signature. /// diff --git a/hugr-core/src/ops/constant.rs b/hugr-core/src/ops/constant.rs index 6aad904cb..18f3974d4 100644 --- a/hugr-core/src/ops/constant.rs +++ b/hugr-core/src/ops/constant.rs @@ -8,7 +8,6 @@ use std::hash::{Hash, Hasher}; use super::{NamedOp, OpName, OpTrait, StaticTag}; use super::{OpTag, OpType}; -use crate::extension::ExtensionSet; use crate::types::{CustomType, EdgeKind, Signature, SumType, SumTypeError, Type, TypeRow}; use crate::{Hugr, HugrView}; @@ -81,10 +80,6 @@ impl OpTrait for Const { "Constant value" } - fn extension_delta(&self) -> ExtensionSet { - self.value().extension_reqs() - } - fn tag(&self) -> OpTag { ::TAG } @@ -251,7 +246,6 @@ pub enum Value { /// use serde_json::json; /// /// let expected_json = json!({ -/// "extensions": ["prelude"], /// "typ": usize_t(), /// "value": {'c': "ConstUsize", 'v': 1} /// }); @@ -259,9 +253,8 @@ pub enum Value { /// assert_eq!(&serde_json::to_value(&ev).unwrap(), &expected_json); /// assert_eq!(ev, serde_json::from_value(expected_json).unwrap()); /// -/// let ev = OpaqueValue::new(CustomSerialized::new(usize_t().clone(), serde_json::Value::Null, ExtensionSet::default())); +/// let ev = OpaqueValue::new(CustomSerialized::new(usize_t().clone(), serde_json::Value::Null)); /// let expected_json = json!({ -/// "extensions": [], /// "typ": usize_t(), /// "value": null /// }); @@ -297,8 +290,6 @@ impl OpaqueValue { pub fn get_type(&self) -> Type; /// An identifier of the internal [`CustomConst`]. pub fn name(&self) -> ValueName; - /// The extension(s) defining the internal [`CustomConst`]. - pub fn extension_reqs(&self) -> ExtensionSet; } } } @@ -523,17 +514,6 @@ impl Value { .into() } - /// The extensions required by a [`Value`] - pub fn extension_reqs(&self) -> ExtensionSet { - match self { - Self::Extension { e } => e.extension_reqs().clone(), - Self::Function { .. } => ExtensionSet::new(), // no extensions required to load Hugr (only to run) - Self::Sum(Sum { values, .. }) => { - ExtensionSet::union_over(values.iter().map(|x| x.extension_reqs())) - } - } - } - /// Check the value. pub fn validate(&self) -> Result<(), ConstTypeError> { match self { @@ -631,10 +611,6 @@ pub(crate) mod test { format!("CustomTestValue({:?})", self.0).into() } - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::singleton(self.0.extension().clone()) - } - fn update_extensions( &mut self, extensions: &WeakExtensionRegistry, @@ -849,8 +825,7 @@ pub(crate) mod test { // Dummy extension reference. &Weak::default(), ); - let json_const: Value = - CustomSerialized::new(typ_int.clone(), 6.into(), ex_id.clone()).into(); + let json_const: Value = CustomSerialized::new(typ_int.clone(), 6.into()).into(); let classic_t = Type::new_extension(typ_int.clone()); assert_matches!(classic_t.least_upper_bound(), TypeBound::Copyable); assert_eq!(json_const.get_type(), classic_t); diff --git a/hugr-core/src/ops/constant/custom.rs b/hugr-core/src/ops/constant/custom.rs index 985e15594..6ff1b67aa 100644 --- a/hugr-core/src/ops/constant/custom.rs +++ b/hugr-core/src/ops/constant/custom.rs @@ -13,7 +13,6 @@ use thiserror::Error; use crate::extension::resolution::{ resolve_type_extensions, ExtensionResolutionError, WeakExtensionRegistry, }; -use crate::extension::ExtensionSet; use crate::macros::impl_box_clone; use crate::types::{CustomCheckFailure, Type}; use crate::IncomingPort; @@ -44,7 +43,6 @@ use super::{Value, ValueName}; /// #[typetag::serde] /// impl CustomConst for CC { /// fn name(&self) -> ValueName { "CC".into() } -/// fn extension_reqs(&self) -> ExtensionSet { ExtensionSet::singleton(int_types::EXTENSION_ID) } /// fn get_type(&self) -> Type { int_types::INT_TYPES[5].clone() } /// } /// @@ -61,13 +59,6 @@ pub trait CustomConst: /// An identifier for the constant. fn name(&self) -> ValueName; - /// The extension(s) defining the custom constant - /// (a set to allow, say, a [List] of [USize]) - /// - /// [List]: crate::std_extensions::collections::list::LIST_TYPENAME - /// [USize]: crate::extension::prelude::usize_t - fn extension_reqs(&self) -> ExtensionSet; - /// Check the value. fn validate(&self) -> Result<(), CustomCheckFailure> { Ok(()) @@ -185,7 +176,6 @@ impl_box_clone!(CustomConst, CustomConstBoxClone); pub struct CustomSerialized { typ: Type, value: serde_json::Value, - extensions: ExtensionSet, } #[derive(Debug, Error)] @@ -206,15 +196,10 @@ pub struct DeserializeError { impl CustomSerialized { /// Creates a new [`CustomSerialized`]. - pub fn new( - typ: impl Into, - value: serde_json::Value, - exts: impl Into, - ) -> Self { + pub fn new(typ: impl Into, value: serde_json::Value) -> Self { Self { typ: typ.into(), value, - extensions: exts.into(), } } @@ -240,7 +225,6 @@ impl CustomSerialized { err, payload: cc.clone_box(), })?, - cc.extension_reqs(), ), }) } @@ -259,10 +243,10 @@ impl CustomSerialized { match cc.downcast::() { Ok(x) => Ok(*x), Err(cc) => { - let (typ, extension_reqs) = (cc.get_type(), cc.extension_reqs()); + let typ = cc.get_type(); let value = serialize_custom_const(cc.as_ref()) .map_err(|err| SerializeError { err, payload: cc })?; - Ok(Self::new(typ, value, extension_reqs)) + Ok(Self::new(typ, value)) } } } @@ -313,9 +297,6 @@ impl CustomConst for CustomSerialized { Some(self) == other.downcast_ref() } - fn extension_reqs(&self) -> ExtensionSet { - self.extensions.clone() - } fn update_extensions( &mut self, extensions: &WeakExtensionRegistry, @@ -437,11 +418,8 @@ mod test { // check serialize_custom_const assert_eq!(expected_json, serialize_custom_const(&example.cc).unwrap()); - let expected_custom_serialized = CustomSerialized::new( - example.cc.get_type(), - expected_json, - example.cc.extension_reqs(), - ); + let expected_custom_serialized = + CustomSerialized::new(example.cc.get_type(), expected_json); // check all the try_from/try_into/into variations assert_eq!( @@ -494,11 +472,7 @@ mod test { let inner = example_custom_serialized().1; ( inner.clone(), - CustomSerialized::new( - inner.get_type(), - serialize_custom_const(&inner).unwrap(), - inner.extension_reqs(), - ), + CustomSerialized::new(inner.get_type(), serialize_custom_const(&inner).unwrap()), ) } @@ -545,7 +519,6 @@ mod proptest { use ::proptest::prelude::*; use crate::{ - extension::ExtensionSet, ops::constant::CustomSerialized, proptest::{any_serde_json_value, any_string}, types::Type, @@ -556,7 +529,6 @@ mod proptest { type Strategy = BoxedStrategy; fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { let typ = any::(); - let extensions = any::(); // here we manually construct a serialized `dyn CustomConst`. // The "c" and "v" come from the `typetag::serde` annotation on // `trait CustomConst`. @@ -570,12 +542,8 @@ mod proptest { .collect::>() .into() }); - (typ, value, extensions) - .prop_map(|(typ, value, extensions)| CustomSerialized { - typ, - value, - extensions, - }) + (typ, value) + .prop_map(|(typ, value)| CustomSerialized { typ, value }) .boxed() } } diff --git a/hugr-core/src/ops/controlflow.rs b/hugr-core/src/ops/controlflow.rs index 49728980f..07c04f5c4 100644 --- a/hugr-core/src/ops/controlflow.rs +++ b/hugr-core/src/ops/controlflow.rs @@ -2,7 +2,6 @@ use std::borrow::Cow; -use crate::extension::ExtensionSet; use crate::types::{EdgeKind, Signature, Type, TypeRow}; use crate::Direction; @@ -20,8 +19,6 @@ pub struct TailLoop { pub just_outputs: TypeRow, /// Types that are appended to both input and output pub rest: TypeRow, - /// Extension requirements to execute the body - pub extension_delta: ExtensionSet, } impl_op_name!(TailLoop); @@ -37,9 +34,7 @@ impl DataflowOpTrait for TailLoop { // TODO: Store a cached signature let [inputs, outputs] = [&self.just_inputs, &self.just_outputs].map(|row| row.extend(self.rest.iter())); - Cow::Owned( - Signature::new(inputs, outputs).with_extension_delta(self.extension_delta.clone()), - ) + Cow::Owned(Signature::new(inputs, outputs)) } fn substitute(&self, subst: &crate::types::Substitution) -> Self { @@ -47,7 +42,6 @@ impl DataflowOpTrait for TailLoop { just_inputs: self.just_inputs.substitute(subst), just_outputs: self.just_outputs.substitute(subst), rest: self.rest.substitute(subst), - extension_delta: self.extension_delta.substitute(subst), } } } @@ -80,10 +74,10 @@ impl TailLoop { impl DataflowParent for TailLoop { fn inner_signature(&self) -> Cow<'_, Signature> { // TODO: Store a cached signature - Cow::Owned( - Signature::new(self.body_input_row(), self.body_output_row()) - .with_extension_delta(self.extension_delta.clone()), - ) + Cow::Owned(Signature::new( + self.body_input_row(), + self.body_output_row(), + )) } } @@ -97,8 +91,6 @@ pub struct Conditional { pub other_inputs: TypeRow, /// Output types pub outputs: TypeRow, - /// Extensions used to produce the outputs - pub extension_delta: ExtensionSet, } impl_op_name!(Conditional); @@ -115,10 +107,7 @@ impl DataflowOpTrait for Conditional { inputs .to_mut() .insert(0, Type::new_sum(self.sum_rows.clone())); - Cow::Owned( - Signature::new(inputs, self.outputs.clone()) - .with_extension_delta(self.extension_delta.clone()), - ) + Cow::Owned(Signature::new(inputs, self.outputs.clone())) } fn substitute(&self, subst: &crate::types::Substitution) -> Self { @@ -126,7 +115,6 @@ impl DataflowOpTrait for Conditional { sum_rows: self.sum_rows.iter().map(|r| r.substitute(subst)).collect(), other_inputs: self.other_inputs.substitute(subst), outputs: self.outputs.substitute(subst), - extension_delta: self.extension_delta.substitute(subst), } } } @@ -174,7 +162,6 @@ pub struct DataflowBlock { pub inputs: TypeRow, pub other_outputs: TypeRow, pub sum_rows: Vec, - pub extension_delta: ExtensionSet, } #[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] @@ -213,10 +200,10 @@ impl DataflowParent for DataflowBlock { let sum_type = Type::new_sum(self.sum_rows.clone()); let mut node_outputs = vec![sum_type]; node_outputs.extend_from_slice(&self.other_outputs); - Cow::Owned( - Signature::new(self.inputs.clone(), TypeRow::from(node_outputs)) - .with_extension_delta(self.extension_delta.clone()), - ) + Cow::Owned(Signature::new( + self.inputs.clone(), + TypeRow::from(node_outputs), + )) } } @@ -237,10 +224,6 @@ impl OpTrait for DataflowBlock { Some(EdgeKind::ControlFlow) } - fn extension_delta(&self) -> ExtensionSet { - self.extension_delta.clone() - } - fn non_df_port_count(&self, dir: Direction) -> usize { match dir { Direction::Incoming => 1, @@ -253,7 +236,6 @@ impl OpTrait for DataflowBlock { inputs: self.inputs.substitute(subst), other_outputs: self.other_outputs.substitute(subst), sum_rows: self.sum_rows.iter().map(|r| r.substitute(subst)).collect(), - extension_delta: self.extension_delta.substitute(subst), } } } @@ -343,10 +325,6 @@ impl OpTrait for Case { "A case node inside a conditional" } - fn extension_delta(&self) -> ExtensionSet { - self.signature.runtime_reqs.clone() - } - fn tag(&self) -> OpTag { ::TAG } @@ -373,10 +351,7 @@ impl Case { #[cfg(test)] mod test { use crate::{ - extension::{ - prelude::{qb_t, usize_t, PRELUDE_ID}, - ExtensionSet, - }, + extension::prelude::{qb_t, usize_t}, ops::{Conditional, DataflowOpTrait, DataflowParent}, types::{Signature, Substitution, Type, TypeArg, TypeBound, TypeRV}, }; @@ -391,19 +366,12 @@ mod test { inputs: vec![usize_t(), tv0.clone()].into(), other_outputs: vec![tv0.clone()].into(), sum_rows: vec![usize_t().into(), vec![qb_t(), tv0.clone()].into()], - extension_delta: ExtensionSet::type_var(1), }; - let dfb2 = dfb.substitute(&Substitution::new(&[ - qb_t().into(), - TypeArg::Extensions { - es: PRELUDE_ID.into(), - }, - ])); + let dfb2 = dfb.substitute(&Substitution::new(&[qb_t().into()])); let st = Type::new_sum(vec![vec![usize_t()], vec![qb_t(); 2]]); assert_eq!( dfb2.inner_signature(), Signature::new(vec![usize_t(), qb_t()], vec![st, qb_t()]) - .with_extension_delta(PRELUDE_ID) ); } @@ -414,7 +382,6 @@ mod test { sum_rows: vec![usize_t().into(), tv1.clone().into()], other_inputs: vec![Type::new_tuple(TypeRV::new_row_var_use(0, TypeBound::Any))].into(), outputs: vec![usize_t(), tv1].into(), - extension_delta: ExtensionSet::new(), }; let cond2 = cond.substitute(&Substitution::new(&[ TypeArg::Sequence { @@ -439,21 +406,14 @@ mod test { just_inputs: vec![qb_t(), tv0.clone()].into(), just_outputs: vec![tv0.clone(), qb_t()].into(), rest: vec![tv0.clone()].into(), - extension_delta: ExtensionSet::type_var(1), }; - let tail2 = tail_loop.substitute(&Substitution::new(&[ - usize_t().into(), - TypeArg::Extensions { - es: PRELUDE_ID.into(), - }, - ])); + let tail2 = tail_loop.substitute(&Substitution::new(&[usize_t().into()])); assert_eq!( tail2.signature(), Signature::new( vec![qb_t(), usize_t(), usize_t()], vec![usize_t(), qb_t(), usize_t()] ) - .with_extension_delta(PRELUDE_ID) ); } } diff --git a/hugr-core/src/ops/custom.rs b/hugr-core/src/ops/custom.rs index 6b907c947..5f5a13427 100644 --- a/hugr-core/src/ops/custom.rs +++ b/hugr-core/src/ops/custom.rs @@ -233,7 +233,6 @@ impl OpaqueOp { args: impl Into>, signature: Signature, ) -> Self { - let signature = signature.with_extension_delta(extension.clone()); Self { extension, name: name.into(), @@ -382,10 +381,7 @@ mod test { assert_eq!(op.name(), "res.op"); assert_eq!(DataflowOpTrait::description(&op), "desc"); assert_eq!(op.args(), &[TypeArg::Type { ty: usize_t() }]); - assert_eq!( - op.signature().as_ref(), - &sig.with_extension_delta(op.extension().clone()) - ); + assert_eq!(op.signature().as_ref(), &sig); } #[test] diff --git a/hugr-core/src/ops/dataflow.rs b/hugr-core/src/ops/dataflow.rs index c63c44b87..ba8f81c0c 100644 --- a/hugr-core/src/ops/dataflow.rs +++ b/hugr-core/src/ops/dataflow.rs @@ -4,7 +4,7 @@ use std::borrow::Cow; use super::{impl_op_name, OpTag, OpTrait}; -use crate::extension::{ExtensionSet, SignatureError}; +use crate::extension::SignatureError; use crate::ops::StaticTag; use crate::types::{EdgeKind, PolyFuncType, Signature, Substitution, Type, TypeArg, TypeRow}; use crate::{type_row, IncomingPort}; @@ -151,15 +151,15 @@ impl OpTrait for T { fn description(&self) -> &str { DataflowOpTrait::description(self) } + fn tag(&self) -> OpTag { T::TAG } + fn dataflow_signature(&self) -> Option> { Some(DataflowOpTrait::signature(self)) } - fn extension_delta(&self) -> ExtensionSet { - DataflowOpTrait::signature(self).runtime_reqs.clone() - } + fn other_input(&self) -> Option { DataflowOpTrait::other_input(self) } diff --git a/hugr-core/src/package.rs b/hugr-core/src/package.rs index 5e1fecdb6..a7c48b3a2 100644 --- a/hugr-core/src/package.rs +++ b/hugr-core/src/package.rs @@ -224,9 +224,6 @@ impl Package { // As a fallback, try to load a hugr json. if let Ok(mut hugr) = serde_json::from_value::(val) { hugr.resolve_extension_defs(extension_registry)?; - if cfg!(feature = "extension_inference") { - hugr.infer_extensions(false)?; - } return Ok(Package::from_hugr(hugr)?); } diff --git a/hugr-core/src/std_extensions/arithmetic/conversions.rs b/hugr-core/src/std_extensions/arithmetic/conversions.rs index abeb61ab0..ea1004d92 100644 --- a/hugr-core/src/std_extensions/arithmetic/conversions.rs +++ b/hugr-core/src/std_extensions/arithmetic/conversions.rs @@ -8,7 +8,7 @@ use crate::extension::prelude::sum_with_error; use crate::extension::prelude::{bool_t, string_type, usize_t}; use crate::extension::simple_op::{HasConcrete, HasDef}; use crate::extension::simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError}; -use crate::extension::{ExtensionId, ExtensionSet, OpDef, SignatureError, SignatureFunc}; +use crate::extension::{ExtensionId, OpDef, SignatureError, SignatureFunc}; use crate::ops::OpName; use crate::ops::{custom::ExtensionOp, NamedOp}; use crate::std_extensions::arithmetic::int_ops::int_polytype; @@ -167,12 +167,6 @@ lazy_static! { /// Extension for conversions between integers and floats. pub static ref EXTENSION: Arc = { Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { - extension.add_requirements( - ExtensionSet::from_iter(vec![ - super::int_types::EXTENSION_ID, - super::float_types::EXTENSION_ID, - ])); - ConvertOpDef::load_all_ops(extension, extension_ref).unwrap(); }) }; diff --git a/hugr-core/src/std_extensions/arithmetic/float_ops.rs b/hugr-core/src/std_extensions/arithmetic/float_ops.rs index 08b478535..f61353528 100644 --- a/hugr-core/src/std_extensions/arithmetic/float_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/float_ops.rs @@ -9,7 +9,7 @@ use crate::{ extension::{ prelude::{bool_t, string_type}, simple_op::{MakeOpDef, MakeRegisteredOp, OpLoadError}, - ExtensionId, ExtensionSet, OpDef, SignatureFunc, + ExtensionId, OpDef, SignatureFunc, }, types::Signature, Extension, @@ -111,7 +111,6 @@ lazy_static! { /// Extension for basic float operations. pub static ref EXTENSION: Arc = { Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { - extension.add_requirements(ExtensionSet::singleton(super::int_types::EXTENSION_ID)); FloatOps::load_all_ops(extension, extension_ref).unwrap(); }) }; diff --git a/hugr-core/src/std_extensions/arithmetic/float_types.rs b/hugr-core/src/std_extensions/arithmetic/float_types.rs index 200e9dcbf..b5a741953 100644 --- a/hugr-core/src/std_extensions/arithmetic/float_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/float_types.rs @@ -5,7 +5,7 @@ use std::sync::{Arc, Weak}; use crate::ops::constant::{TryHash, ValueName}; use crate::types::TypeName; use crate::{ - extension::{ExtensionId, ExtensionSet}, + extension::ExtensionId, ops::constant::CustomConst, types::{CustomType, Type, TypeBound}, Extension, @@ -97,10 +97,6 @@ impl CustomConst for ConstF64 { fn equal_consts(&self, _: &dyn CustomConst) -> bool { false } - - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::singleton(EXTENSION_ID) - } } lazy_static! { diff --git a/hugr-core/src/std_extensions/arithmetic/int_ops.rs b/hugr-core/src/std_extensions/arithmetic/int_ops.rs index d0ae7baa7..69939d4e1 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_ops.rs @@ -14,7 +14,7 @@ use crate::types::{FuncValueType, PolyFuncTypeRV, TypeRowRV}; use crate::utils::collect_array; use crate::{ - extension::{ExtensionId, ExtensionSet, SignatureError}, + extension::{ExtensionId, SignatureError}, types::{type_param::TypeArg, Type}, Extension, }; @@ -252,7 +252,6 @@ lazy_static! { /// Extension for basic integer operations. pub static ref EXTENSION: Arc = { Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { - extension.add_requirements(ExtensionSet::singleton(super::int_types::EXTENSION_ID)); IntOpDef::load_all_ops(extension, extension_ref).unwrap(); }) }; @@ -377,7 +376,7 @@ mod test { .unwrap() .signature() .as_ref(), - &Signature::new(int_type(3), int_type(4)).with_extension_delta(EXTENSION_ID) + &Signature::new(int_type(3), int_type(4)) ); assert_eq!( IntOpDef::iwiden_s @@ -386,7 +385,7 @@ mod test { .unwrap() .signature() .as_ref(), - &Signature::new_endo(int_type(3)).with_extension_delta(EXTENSION_ID) + &Signature::new_endo(int_type(3)) ); assert_eq!( IntOpDef::inarrow_s @@ -396,7 +395,6 @@ mod test { .signature() .as_ref(), &Signature::new(int_type(3), sum_ty_with_err(int_type(3))) - .with_extension_delta(EXTENSION_ID) ); assert!( IntOpDef::iwiden_u @@ -414,7 +412,6 @@ mod test { .signature() .as_ref(), &Signature::new(int_type(2), sum_ty_with_err(int_type(1))) - .with_extension_delta(EXTENSION_ID) ); assert!(IntOpDef::inarrow_u diff --git a/hugr-core/src/std_extensions/arithmetic/int_types.rs b/hugr-core/src/std_extensions/arithmetic/int_types.rs index 1342dd932..022f4d61e 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_types.rs @@ -6,7 +6,7 @@ use std::sync::{Arc, Weak}; use crate::ops::constant::ValueName; use crate::types::TypeName; use crate::{ - extension::{ExtensionId, ExtensionSet}, + extension::ExtensionId, ops::constant::CustomConst, types::{ type_param::{TypeArg, TypeArgError, TypeParam}, @@ -184,10 +184,6 @@ impl CustomConst for ConstInt { crate::ops::constant::downcast_equal_consts(self, other) } - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::singleton(EXTENSION_ID) - } - fn get_type(&self) -> Type { int_type(type_arg(self.log_width)) } diff --git a/hugr-core/src/std_extensions/collections/array.rs b/hugr-core/src/std_extensions/collections/array.rs index fac12b1bf..2e7ee5b75 100644 --- a/hugr-core/src/std_extensions/collections/array.rs +++ b/hugr-core/src/std_extensions/collections/array.rs @@ -17,7 +17,7 @@ use crate::extension::resolution::{ WeakExtensionRegistry, }; use crate::extension::simple_op::{MakeOpDef, MakeRegisteredOp}; -use crate::extension::{ExtensionId, ExtensionSet, SignatureError, TypeDef, TypeDefBound}; +use crate::extension::{ExtensionId, SignatureError, TypeDef, TypeDefBound}; use crate::ops::constant::{maybe_hash_values, CustomConst, TryHash, ValueName}; use crate::ops::{ExtensionOp, OpName, Value}; use crate::types::type_param::{TypeArg, TypeParam}; @@ -143,11 +143,6 @@ impl CustomConst for ArrayValue { crate::ops::constant::downcast_equal_consts(self, other) } - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::union_over(self.values.iter().map(Value::extension_reqs)) - .union(EXTENSION_ID.into()) - } - fn update_extensions( &mut self, extensions: &WeakExtensionRegistry, diff --git a/hugr-core/src/std_extensions/collections/array/array_repeat.rs b/hugr-core/src/std_extensions/collections/array/array_repeat.rs index 544866970..a31505cb2 100644 --- a/hugr-core/src/std_extensions/collections/array/array_repeat.rs +++ b/hugr-core/src/std_extensions/collections/array/array_repeat.rs @@ -6,7 +6,7 @@ use std::sync::{Arc, Weak}; use crate::extension::simple_op::{ HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError, }; -use crate::extension::{ExtensionId, ExtensionSet, OpDef, SignatureError, SignatureFunc, TypeDef}; +use crate::extension::{ExtensionId, OpDef, SignatureError, SignatureFunc, TypeDef}; use crate::ops::{ExtensionOp, NamedOp, OpName}; use crate::types::type_param::{TypeArg, TypeParam}; use crate::types::{FuncValueType, PolyFuncTypeRV, Signature, Type, TypeBound}; @@ -42,16 +42,10 @@ impl FromStr for ArrayRepeatDef { impl ArrayRepeatDef { /// To avoid recursion when defining the extension, take the type definition as an argument. fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc { - let params = vec![ - TypeParam::max_nat(), - TypeBound::Any.into(), - TypeParam::Extensions, - ]; + let params = vec![TypeParam::max_nat(), TypeBound::Any.into()]; let n = TypeArg::new_var_use(0, TypeParam::max_nat()); let t = Type::new_var_use(1, TypeBound::Any); - let es = ExtensionSet::type_var(2); - let func = - Type::new_function(Signature::new(vec![], vec![t.clone()]).with_extension_delta(es)); + let func = Type::new_function(Signature::new(vec![], vec![t.clone()])); let array_ty = instantiate_array(array_def, n, t).expect("Array type instantiation failed"); PolyFuncTypeRV::new(params, FuncValueType::new(vec![func], array_ty)).into() } @@ -109,18 +103,12 @@ pub struct ArrayRepeat { pub elem_ty: Type, /// Size of the array. pub size: u64, - /// The extensions required by the function that generates the array elements. - pub extension_reqs: ExtensionSet, } impl ArrayRepeat { /// Creates a new array repeat op. - pub fn new(elem_ty: Type, size: u64, extension_reqs: ExtensionSet) -> Self { - ArrayRepeat { - elem_ty, - size, - extension_reqs, - } + pub fn new(elem_ty: Type, size: u64) -> Self { + ArrayRepeat { elem_ty, size } } } @@ -143,9 +131,6 @@ impl MakeExtensionOp for ArrayRepeat { vec![ TypeArg::BoundedNat { n: self.size }, self.elem_ty.clone().into(), - TypeArg::Extensions { - es: self.extension_reqs.clone(), - }, ] } } @@ -169,8 +154,8 @@ impl HasConcrete for ArrayRepeatDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [TypeArg::BoundedNat { n }, TypeArg::Type { ty }, TypeArg::Extensions { es }] => { - Ok(ArrayRepeat::new(ty.clone(), *n, es.clone())) + [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] => { + Ok(ArrayRepeat::new(ty.clone(), *n)) } _ => Err(SignatureError::InvalidTypeArgs.into()), } @@ -179,7 +164,7 @@ impl HasConcrete for ArrayRepeatDef { #[cfg(test)] mod tests { - use crate::std_extensions::collections::array::{array_type, EXTENSION_ID}; + use crate::std_extensions::collections::array::array_type; use crate::{ extension::prelude::qb_t, ops::{OpTrait, OpType}, @@ -190,7 +175,7 @@ mod tests { #[test] fn test_repeat_def() { - let op = ArrayRepeat::new(qb_t(), 2, ExtensionSet::singleton(EXTENSION_ID)); + let op = ArrayRepeat::new(qb_t(), 2); let optype: OpType = op.clone().into(); let new_op: ArrayRepeat = optype.cast().unwrap(); assert_eq!(new_op, op); @@ -200,8 +185,7 @@ mod tests { fn test_repeat() { let size = 2; let element_ty = qb_t(); - let es = ExtensionSet::singleton(EXTENSION_ID); - let op = ArrayRepeat::new(element_ty.clone(), size, es.clone()); + let op = ArrayRepeat::new(element_ty.clone(), size); let optype: OpType = op.into(); @@ -210,10 +194,7 @@ mod tests { assert_eq!( sig.io(), ( - &vec![Type::new_function( - Signature::new(vec![], vec![qb_t()]).with_extension_delta(es) - )] - .into(), + &vec![Type::new_function(Signature::new(vec![], vec![qb_t()]))].into(), &vec![array_type(size, element_ty.clone())].into(), ) ); diff --git a/hugr-core/src/std_extensions/collections/array/array_scan.rs b/hugr-core/src/std_extensions/collections/array/array_scan.rs index 86a0fe94e..8064a73d0 100644 --- a/hugr-core/src/std_extensions/collections/array/array_scan.rs +++ b/hugr-core/src/std_extensions/collections/array/array_scan.rs @@ -8,7 +8,7 @@ use itertools::Itertools; use crate::extension::simple_op::{ HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError, }; -use crate::extension::{ExtensionId, ExtensionSet, OpDef, SignatureError, SignatureFunc, TypeDef}; +use crate::extension::{ExtensionId, OpDef, SignatureError, SignatureFunc, TypeDef}; use crate::ops::{ExtensionOp, NamedOp, OpName}; use crate::types::type_param::{TypeArg, TypeParam}; use crate::types::{FuncTypeBase, PolyFuncTypeRV, RowVariable, Type, TypeBound, TypeRV}; @@ -51,13 +51,11 @@ impl ArrayScanDef { TypeBound::Any.into(), TypeBound::Any.into(), TypeParam::new_list(TypeBound::Any), - TypeParam::Extensions, ]; let n = TypeArg::new_var_use(0, TypeParam::max_nat()); let t1 = Type::new_var_use(1, TypeBound::Any); let t2 = Type::new_var_use(2, TypeBound::Any); let s = TypeRV::new_row_var_use(3, TypeBound::Any); - let es = ExtensionSet::type_var(4); PolyFuncTypeRV::new( params, FuncTypeBase::::new( @@ -65,13 +63,10 @@ impl ArrayScanDef { instantiate_array(array_def, n.clone(), t1.clone()) .expect("Array type instantiation failed") .into(), - Type::new_function( - FuncTypeBase::::new( - vec![t1.into(), s.clone()], - vec![t2.clone().into(), s.clone()], - ) - .with_extension_delta(es), - ) + Type::new_function(FuncTypeBase::::new( + vec![t1.into(), s.clone()], + vec![t2.clone().into(), s.clone()], + )) .into(), s.clone(), ], @@ -145,25 +140,16 @@ pub struct ArrayScan { pub acc_tys: Vec, /// Size of the array. pub size: u64, - /// The extensions required by the scan function. - pub extension_reqs: ExtensionSet, } impl ArrayScan { /// Creates a new array scan op. - pub fn new( - src_ty: Type, - tgt_ty: Type, - acc_tys: Vec, - size: u64, - extension_reqs: ExtensionSet, - ) -> Self { + pub fn new(src_ty: Type, tgt_ty: Type, acc_tys: Vec, size: u64) -> Self { ArrayScan { src_ty, tgt_ty, acc_tys, size, - extension_reqs, } } } @@ -191,9 +177,6 @@ impl MakeExtensionOp for ArrayScan { TypeArg::Sequence { elems: self.acc_tys.clone().into_iter().map_into().collect(), }, - TypeArg::Extensions { - es: self.extension_reqs.clone(), - }, ] } } @@ -217,7 +200,7 @@ impl HasConcrete for ArrayScanDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [TypeArg::BoundedNat { n }, TypeArg::Type { ty: src_ty }, TypeArg::Type { ty: tgt_ty }, TypeArg::Sequence { elems: acc_tys }, TypeArg::Extensions { es }] => + [TypeArg::BoundedNat { n }, TypeArg::Type { ty: src_ty }, TypeArg::Type { ty: tgt_ty }, TypeArg::Sequence { elems: acc_tys }] => { let acc_tys: Result<_, OpLoadError> = acc_tys .iter() @@ -226,13 +209,7 @@ impl HasConcrete for ArrayScanDef { _ => Err(SignatureError::InvalidTypeArgs.into()), }) .collect(); - Ok(ArrayScan::new( - src_ty.clone(), - tgt_ty.clone(), - acc_tys?, - *n, - es.clone(), - )) + Ok(ArrayScan::new(src_ty.clone(), tgt_ty.clone(), acc_tys?, *n)) } _ => Err(SignatureError::InvalidTypeArgs.into()), } @@ -243,7 +220,7 @@ impl HasConcrete for ArrayScanDef { mod tests { use crate::extension::prelude::usize_t; - use crate::std_extensions::collections::array::{array_type, EXTENSION_ID}; + use crate::std_extensions::collections::array::array_type; use crate::{ extension::prelude::{bool_t, qb_t}, ops::{OpTrait, OpType}, @@ -254,13 +231,7 @@ mod tests { #[test] fn test_scan_def() { - let op = ArrayScan::new( - bool_t(), - qb_t(), - vec![usize_t()], - 2, - ExtensionSet::singleton(EXTENSION_ID), - ); + let op = ArrayScan::new(bool_t(), qb_t(), vec![usize_t()], 2); let optype: OpType = op.clone().into(); let new_op: ArrayScan = optype.cast().unwrap(); assert_eq!(new_op, op); @@ -271,9 +242,8 @@ mod tests { let size = 2; let src_ty = qb_t(); let tgt_ty = bool_t(); - let es = ExtensionSet::singleton(EXTENSION_ID); - let op = ArrayScan::new(src_ty.clone(), tgt_ty.clone(), vec![], size, es.clone()); + let op = ArrayScan::new(src_ty.clone(), tgt_ty.clone(), vec![], size); let optype: OpType = op.into(); let sig = optype.dataflow_signature().unwrap(); @@ -282,9 +252,7 @@ mod tests { ( &vec![ array_type(size, src_ty.clone()), - Type::new_function( - Signature::new(vec![src_ty], vec![tgt_ty.clone()]).with_extension_delta(es) - ) + Type::new_function(Signature::new(vec![src_ty], vec![tgt_ty.clone()])) ] .into(), &vec![array_type(size, tgt_ty)].into(), @@ -299,14 +267,12 @@ mod tests { let tgt_ty = bool_t(); let acc_ty1 = usize_t(); let acc_ty2 = qb_t(); - let es = ExtensionSet::singleton(EXTENSION_ID); let op = ArrayScan::new( src_ty.clone(), tgt_ty.clone(), vec![acc_ty1.clone(), acc_ty2.clone()], size, - es.clone(), ); let optype: OpType = op.into(); let sig = optype.dataflow_signature().unwrap(); @@ -316,13 +282,10 @@ mod tests { ( &vec![ array_type(size, src_ty.clone()), - Type::new_function( - Signature::new( - vec![src_ty, acc_ty1.clone(), acc_ty2.clone()], - vec![tgt_ty.clone(), acc_ty1.clone(), acc_ty2.clone()] - ) - .with_extension_delta(es) - ), + Type::new_function(Signature::new( + vec![src_ty, acc_ty1.clone(), acc_ty2.clone()], + vec![tgt_ty.clone(), acc_ty1.clone(), acc_ty2.clone()] + )), acc_ty1.clone(), acc_ty2.clone() ] diff --git a/hugr-core/src/std_extensions/collections/array/op_builder.rs b/hugr-core/src/std_extensions/collections/array/op_builder.rs index 46338dd43..623443347 100644 --- a/hugr-core/src/std_extensions/collections/array/op_builder.rs +++ b/hugr-core/src/std_extensions/collections/array/op_builder.rs @@ -213,9 +213,7 @@ impl ArrayOpBuilder for D {} #[cfg(test)] mod test { - use crate::extension::prelude::PRELUDE_ID; - use crate::extension::ExtensionSet; - use crate::std_extensions::collections::array::{self, array_type}; + use crate::std_extensions::collections::array::array_type; use crate::{ builder::{DFGBuilder, HugrBuilder}, extension::prelude::{either_type, option_type, usize_t, ConstUsize, UnwrapBuilder as _}, @@ -229,11 +227,7 @@ mod test { #[rstest::fixture] #[default(DFGBuilder)] fn all_array_ops( - #[default(DFGBuilder::new(Signature::new_endo(Type::EMPTY_TYPEROW) - .with_extension_delta(ExtensionSet::from_iter([ - PRELUDE_ID, - array::EXTENSION_ID - ]))).unwrap())] + #[default(DFGBuilder::new(Signature::new_endo(Type::EMPTY_TYPEROW)).unwrap())] mut builder: B, ) -> B { let us0 = builder.add_load_value(ConstUsize::new(0)); diff --git a/hugr-core/src/std_extensions/collections/list.rs b/hugr-core/src/std_extensions/collections/list.rs index 98804bab0..3ffb4d9a0 100644 --- a/hugr-core/src/std_extensions/collections/list.rs +++ b/hugr-core/src/std_extensions/collections/list.rs @@ -25,7 +25,7 @@ use crate::types::{TypeName, TypeRowRV}; use crate::{ extension::{ simple_op::{MakeExtensionOp, OpLoadError}, - ExtensionId, ExtensionSet, SignatureError, TypeDef, TypeDefBound, + ExtensionId, SignatureError, TypeDef, TypeDefBound, }, ops::constant::CustomConst, ops::{custom::ExtensionOp, NamedOp}, @@ -126,11 +126,6 @@ impl CustomConst for ListValue { crate::ops::constant::downcast_equal_consts(self, other) } - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::union_over(self.0.iter().map(Value::extension_reqs)) - .union(EXTENSION_ID.into()) - } - fn update_extensions( &mut self, extensions: &WeakExtensionRegistry, diff --git a/hugr-core/src/std_extensions/collections/static_array.rs b/hugr-core/src/std_extensions/collections/static_array.rs index 9d2259e0b..05e5651a1 100644 --- a/hugr-core/src/std_extensions/collections/static_array.rs +++ b/hugr-core/src/std_extensions/collections/static_array.rs @@ -28,7 +28,7 @@ use crate::{ try_from_name, HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError, }, - ExtensionId, ExtensionSet, OpDef, SignatureError, SignatureFunc, TypeDef, + ExtensionId, OpDef, SignatureError, SignatureFunc, TypeDef, }, ops::{ constant::{maybe_hash_values, CustomConst, TryHash, ValueName}, @@ -128,11 +128,6 @@ impl CustomConst for StaticArrayValue { crate::ops::constant::downcast_equal_consts(self, other) } - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::union_over(self.get_contents().iter().map(Value::extension_reqs)) - .union(EXTENSION_ID.into()) - } - fn update_extensions( &mut self, extensions: &WeakExtensionRegistry, @@ -404,7 +399,7 @@ impl StaticArrayOpBuilder for T {} mod test { use crate::{ builder::{DFGBuilder, DataflowHugr as _}, - extension::prelude::{qb_t, ConstUsize, PRELUDE_ID}, + extension::prelude::{qb_t, ConstUsize}, type_row, }; @@ -419,10 +414,10 @@ mod test { #[test] fn all_ops() { let _ = { - let mut builder = DFGBuilder::new( - Signature::new(type_row![], Type::from(option_type(usize_t()))) - .with_extension_delta(ExtensionSet::from_iter([PRELUDE_ID, EXTENSION_ID])), - ) + let mut builder = DFGBuilder::new(Signature::new( + type_row![], + Type::from(option_type(usize_t())), + )) .unwrap(); let array = builder.add_load_value( StaticArrayValue::try_new( diff --git a/hugr-core/src/std_extensions/ptr.rs b/hugr-core/src/std_extensions/ptr.rs index fc0b1bbb4..6d77ae52d 100644 --- a/hugr-core/src/std_extensions/ptr.rs +++ b/hugr-core/src/std_extensions/ptr.rs @@ -268,10 +268,7 @@ pub(crate) mod test { let in_row = vec![bool_t(), float64_type()]; let hugr = { - let mut builder = DFGBuilder::new( - Signature::new(in_row.clone(), type_row![]).with_extension_delta(EXTENSION_ID), - ) - .unwrap(); + let mut builder = DFGBuilder::new(Signature::new(in_row.clone(), type_row![])).unwrap(); let in_wires: [Wire; 2] = builder.input_wires_arr(); for (ty, w) in in_row.into_iter().zip(in_wires.iter()) { diff --git a/hugr-core/src/types/poly_func.rs b/hugr-core/src/types/poly_func.rs index 67bc7fbf5..885b6bae8 100644 --- a/hugr-core/src/types/poly_func.rs +++ b/hugr-core/src/types/poly_func.rs @@ -277,7 +277,6 @@ pub(crate) mod test { let list_def = list::EXTENSION.get_type(&list::LIST_TYPENAME).unwrap(); let body_type = Signature::new_endo(Type::new_extension(list_def.instantiate([tv])?)); for decl in [ - TypeParam::Extensions, TypeParam::List { param: Box::new(TypeParam::max_nat()), }, diff --git a/hugr-core/src/types/signature.rs b/hugr-core/src/types/signature.rs index 28c39fa08..78965f1b6 100644 --- a/hugr-core/src/types/signature.rs +++ b/hugr-core/src/types/signature.rs @@ -37,8 +37,6 @@ pub struct FuncTypeBase { /// Value outputs of the function. #[cfg_attr(test, proptest(strategy = "any_with::>(params)"))] pub output: TypeRowBase, - /// The extensions the function specifies as required at runtime. - pub runtime_reqs: ExtensionSet, } /// The concept of "signature" in the spec - the edges required to/from a node @@ -55,22 +53,10 @@ pub type Signature = FuncTypeBase; pub type FuncValueType = FuncTypeBase; impl FuncTypeBase { - /// Builder method, add runtime_reqs to a FunctionType - pub fn with_extension_delta(mut self, rs: impl Into) -> Self { - self.runtime_reqs = self.runtime_reqs.union(rs.into()); - self - } - - /// Shorthand for adding the prelude extension to a FunctionType. - pub fn with_prelude(self) -> Self { - self.with_extension_delta(crate::extension::prelude::PRELUDE_ID) - } - pub(crate) fn substitute(&self, tr: &Substitution) -> Self { Self { input: self.input.substitute(tr), output: self.output.substitute(tr), - runtime_reqs: self.runtime_reqs.substitute(tr), } } @@ -79,7 +65,6 @@ impl FuncTypeBase { Self { input: input.into(), output: output.into(), - runtime_reqs: ExtensionSet::new(), } } @@ -117,19 +102,10 @@ impl FuncTypeBase { pub(super) fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { self.input.validate(var_decls)?; - self.output.validate(var_decls)?; - self.runtime_reqs.validate(var_decls) + self.output.validate(var_decls) } /// Returns a registry with the concrete extensions used by this signature. - /// - /// Note that extension type parameters are not included, as they have not - /// been instantiated yet. - /// - /// This method only returns extensions actually used by the types in the - /// signature. The extension deltas added via [`Self::with_extension_delta`] - /// refer to _runtime_ extensions, which may not be in all places that - /// manipulate a HUGR. pub fn used_extensions(&self) -> Result { let mut used = WeakExtensionRegistry::default(); let mut missing = ExtensionSet::new(); @@ -167,7 +143,6 @@ impl Default for FuncTypeBase { Self { input: Default::default(), output: Default::default(), - runtime_reqs: Default::default(), } } } @@ -290,9 +265,6 @@ impl Display for FuncTypeBase { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.input.fmt(f)?; f.write_str(" -> ")?; - if !self.runtime_reqs.is_empty() { - self.runtime_reqs.fmt(f)?; - } self.output.fmt(f) } } @@ -303,7 +275,7 @@ impl TryFrom for Signature { fn try_from(value: FuncValueType) -> Result { let input: TypeRow = value.input.try_into()?; let output: TypeRow = value.output.try_into()?; - Ok(Self::new(input, output).with_extension_delta(value.runtime_reqs)) + Ok(Self::new(input, output)) } } @@ -312,16 +284,13 @@ impl From for FuncValueType { Self { input: value.input.into(), output: value.output.into(), - runtime_reqs: value.runtime_reqs, } } } impl PartialEq> for FuncTypeBase { fn eq(&self, other: &FuncTypeBase) -> bool { - self.input == other.input - && self.output == other.output - && self.runtime_reqs == other.runtime_reqs + self.input == other.input && self.output == other.output } } diff --git a/hugr-core/src/types/type_param.rs b/hugr-core/src/types/type_param.rs index db2efecc6..e8fa28346 100644 --- a/hugr-core/src/types/type_param.rs +++ b/hugr-core/src/types/type_param.rs @@ -15,7 +15,6 @@ use super::{ check_typevar_decl, NoRV, RowVariable, Substitution, Transformable, Type, TypeBase, TypeBound, TypeTransformer, }; -use crate::extension::ExtensionSet; use crate::extension::SignatureError; /// The upper non-inclusive bound of a [`TypeParam::BoundedNat`] @@ -92,10 +91,6 @@ pub enum TypeParam { /// The [TypeParam]s contained in the tuple. params: Vec, }, - /// Argument is a [TypeArg::Extensions]. A set of [ExtensionId]s. - /// - /// [ExtensionId]: crate::extension::ExtensionId - Extensions, } impl TypeParam { @@ -131,7 +126,6 @@ impl TypeParam { (TypeParam::Tuple { params: es1 }, TypeParam::Tuple { params: es2 }) => { es1.len() == es2.len() && es1.iter().zip(es2).all(|(e1, e2)| e1.contains(e2)) } - (TypeParam::Extensions, TypeParam::Extensions) => true, _ => false, } } @@ -184,18 +178,9 @@ pub enum TypeArg { /// List of element types elems: Vec, }, - /// Instance of [TypeParam::Extensions], providing the extension ids. - #[display("Exts({})", { - use itertools::Itertools as _; - es.iter().map(|t|t.to_string()).join(",") - })] - Extensions { - #[allow(missing_docs)] - es: ExtensionSet, - }, /// Variable (used in type schemes or inside polymorphic functions), /// but not a [TypeArg::Type] (not even a row variable i.e. [TypeParam::List] of type) - /// nor [TypeArg::Extensions] - see [TypeArg::new_var_use] + /// - see [TypeArg::new_var_use] #[display("{v}")] Variable { #[allow(missing_docs)] @@ -239,14 +224,7 @@ impl From> for TypeArg { } } -impl From for TypeArg { - fn from(es: ExtensionSet) -> Self { - Self::Extensions { es } - } -} - -/// Variable in a TypeArg, that is neither a [TypeArg::Extensions] -/// nor a single [TypeArg::Type] (i.e. not a [Type::new_var_use] +/// Variable in a TypeArg, that is not a single [TypeArg::Type] (i.e. not a [Type::new_var_use] /// - it might be a [Type::new_row_var_use]). #[derive( Clone, Debug, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize, derive_more::Display, @@ -270,10 +248,6 @@ impl TypeArg { // as a TypeArg::Type because the latter stores a Type i.e. only a single type, // not a RowVariable. TypeParam::Type { b } => Type::new_var_use(idx, b).into(), - // Prevent TypeArg::Variable(idx, TypeParam::Extensions) - TypeParam::Extensions => TypeArg::Extensions { - es: ExtensionSet::type_var(idx), - }, _ => TypeArg::Variable { v: TypeArgVariable { idx, @@ -314,7 +288,6 @@ impl TypeArg { TypeArg::Type { ty } => ty.validate(var_decls), TypeArg::BoundedNat { .. } | TypeArg::String { .. } => Ok(()), TypeArg::Sequence { elems } => elems.iter().try_for_each(|a| a.validate(var_decls)), - TypeArg::Extensions { es: _ } => Ok(()), TypeArg::Variable { v: TypeArgVariable { idx, cached_decl }, } => { @@ -362,9 +335,6 @@ impl TypeArg { }; TypeArg::Sequence { elems } } - TypeArg::Extensions { es } => TypeArg::Extensions { - es: es.substitute(t), - }, TypeArg::Variable { v: TypeArgVariable { idx, cached_decl }, } => t.apply_var(*idx, cached_decl), @@ -377,10 +347,9 @@ impl Transformable for TypeArg { match self { TypeArg::Type { ty } => ty.transform(tr), TypeArg::Sequence { elems } => elems.transform(tr), - TypeArg::BoundedNat { .. } - | TypeArg::String { .. } - | TypeArg::Extensions { .. } - | TypeArg::Variable { .. } => Ok(false), + TypeArg::BoundedNat { .. } | TypeArg::String { .. } | TypeArg::Variable { .. } => { + Ok(false) + } } } } @@ -449,7 +418,6 @@ pub fn check_type_arg(arg: &TypeArg, param: &TypeParam) -> Result<(), TypeArgErr } (TypeArg::String { .. }, TypeParam::String) => Ok(()), - (TypeArg::Extensions { .. }, TypeParam::Extensions) => Ok(()), _ => Err(TypeArgError::TypeMismatch { arg: arg.clone(), param: param.clone(), @@ -659,7 +627,6 @@ mod test { use proptest::prelude::*; use super::super::{TypeArg, TypeArgVariable, TypeParam, UpperBound}; - use crate::extension::ExtensionSet; use crate::proptest::RecursionDepth; use crate::types::{Type, TypeBound}; @@ -680,7 +647,6 @@ mod test { use prop::collection::vec; use prop::strategy::Union; let mut strat = Union::new([ - Just(Self::Extensions).boxed(), Just(Self::String).boxed(), any::().prop_map(|b| Self::Type { b }).boxed(), any::() @@ -711,9 +677,6 @@ mod test { let mut strat = Union::new([ any::().prop_map(|n| Self::BoundedNat { n }).boxed(), any::().prop_map(|arg| Self::String { arg }).boxed(), - any::() - .prop_map(|es| Self::Extensions { es }) - .boxed(), any_with::(depth) .prop_map(|ty| Self::Type { ty }) .boxed(), diff --git a/hugr-llvm/src/emit/ops/cfg.rs b/hugr-llvm/src/emit/ops/cfg.rs index 4d62350be..12f22d2f7 100644 --- a/hugr-llvm/src/emit/ops/cfg.rs +++ b/hugr-llvm/src/emit/ops/cfg.rs @@ -219,7 +219,7 @@ impl<'c, 'hugr, H: HugrView> CfgEmitter<'c, 'hugr, H> { mod test { use hugr_core::builder::{Dataflow, DataflowSubContainer, SubContainer}; use hugr_core::extension::prelude::{self, bool_t}; - use hugr_core::extension::{ExtensionRegistry, ExtensionSet}; + use hugr_core::extension::ExtensionRegistry; use hugr_core::ops::Value; use hugr_core::std_extensions::arithmetic::int_types::{self, INT_TYPES}; use hugr_core::type_row; @@ -239,7 +239,6 @@ mod test { llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_int_extensions); let t1 = INT_TYPES[0].clone(); let t2 = INT_TYPES[1].clone(); - let es = ExtensionSet::from_iter([int_types::EXTENSION_ID, prelude::PRELUDE_ID]); let hugr = SimpleHugrConfig::new() .with_ins(vec![t1.clone(), t2.clone()]) .with_outs(t2.clone()) @@ -250,11 +249,7 @@ mod test { .finish(|mut builder| { let [in1, in2] = builder.input_wires_arr(); let mut cfg_builder = builder - .cfg_builder_exts( - [(t1.clone(), in1), (t2.clone(), in2)], - t2.clone().into(), - es.clone(), - ) + .cfg_builder([(t1.clone(), in1), (t2.clone(), in2)], t2.clone().into()) .unwrap(); // entry block takes (t1,t2) and unconditionally branches to b1 with no other outputs diff --git a/hugr-llvm/src/emit/test.rs b/hugr-llvm/src/emit/test.rs index 3f6977a8c..d53d2ef0c 100644 --- a/hugr-llvm/src/emit/test.rs +++ b/hugr-llvm/src/emit/test.rs @@ -4,13 +4,8 @@ use anyhow::{anyhow, Result}; use hugr_core::builder::{ BuildHandle, Container, DFGWrapper, HugrBuilder, ModuleBuilder, SubContainer, }; -use hugr_core::extension::prelude::PRELUDE_ID; -use hugr_core::extension::{ExtensionRegistry, ExtensionSet}; +use hugr_core::extension::ExtensionRegistry; use hugr_core::ops::handle::FuncID; -use hugr_core::std_extensions::arithmetic::{ - conversions, float_ops, float_types, int_ops, int_types, -}; -use hugr_core::std_extensions::{collections, logic}; use hugr_core::types::TypeRow; use hugr_core::{Hugr, HugrView, Node}; use inkwell::module::Module; @@ -150,23 +145,7 @@ impl SimpleHugrConfig { ) -> Hugr { let mut mod_b = ModuleBuilder::new(); let func_b = mod_b - .define_function( - "main", - HugrFuncType::new(self.ins, self.outs).with_extension_delta( - ExtensionSet::from_iter([ - PRELUDE_ID, - int_types::EXTENSION_ID, - int_ops::EXTENSION_ID, - float_types::EXTENSION_ID, - float_ops::EXTENSION_ID, - conversions::EXTENSION_ID, - logic::EXTENSION_ID, - collections::array::EXTENSION_ID, - collections::list::EXTENSION_ID, - collections::static_array::EXTENSION_ID, - ]), - ), - ) + .define_function("main", HugrFuncType::new(self.ins, self.outs)) .unwrap(); make(func_b, &self.extensions); @@ -265,7 +244,7 @@ mod test_fns { use hugr_core::ops::{CallIndirect, Tag, Value}; use hugr_core::std_extensions::arithmetic::int_ops::{self}; - use hugr_core::std_extensions::arithmetic::int_types::ConstInt; + use hugr_core::std_extensions::arithmetic::int_types::{self, ConstInt}; use hugr_core::std_extensions::STD_REG; use hugr_core::types::{Signature, Type, TypeRow}; use hugr_core::{type_row, Hugr}; diff --git a/hugr-llvm/src/extension/collections/array.rs b/hugr-llvm/src/extension/collections/array.rs index 55dcecefc..0216e9014 100644 --- a/hugr-llvm/src/extension/collections/array.rs +++ b/hugr-llvm/src/extension/collections/array.rs @@ -708,7 +708,6 @@ pub fn emit_scan_op<'c, H: HugrView>( mod test { use hugr_core::builder::Container as _; use hugr_core::extension::prelude::either_type; - use hugr_core::extension::ExtensionSet; use hugr_core::ops::Tag; use hugr_core::std_extensions::collections::array::{self, array_type, ArrayRepeat, ArrayScan}; use hugr_core::std_extensions::STD_REG; @@ -854,16 +853,6 @@ mod test { ]) } - fn exec_extension_set() -> ExtensionSet { - ExtensionSet::from_iter([ - int_types::EXTENSION_ID, - int_ops::EXTENSION_ID, - logic::EXTENSION_ID, - prelude::PRELUDE_ID, - array::EXTENSION_ID, - ]) - } - #[rstest] #[case(0, 1)] #[case(1, 2)] @@ -1223,16 +1212,12 @@ mod test { .with_extensions(exec_registry()) .finish(|mut builder| { let mut func = builder - .define_function( - "foo", - Signature::new(vec![], vec![int_ty.clone()]) - .with_extension_delta(exec_extension_set()), - ) + .define_function("foo", Signature::new(vec![], vec![int_ty.clone()])) .unwrap(); let v = func.add_load_value(ConstInt::new_u(6, value).unwrap()); let func_id = func.finish_with_outputs(vec![v]).unwrap(); let func_v = builder.load_func(func_id.handle(), &[]).unwrap(); - let repeat = ArrayRepeat::new(int_ty.clone(), size, exec_extension_set()); + let repeat = ArrayRepeat::new(int_ty.clone(), size); let arr = builder .add_dataflow_op(repeat, vec![func_v]) .unwrap() @@ -1280,8 +1265,7 @@ mod test { let mut func = builder .define_function( "foo", - Signature::new(vec![int_ty.clone()], vec![int_ty.clone()]) - .with_extension_delta(exec_extension_set()), + Signature::new(vec![int_ty.clone()], vec![int_ty.clone()]), ) .unwrap(); let [elem] = func.input_wires_arr(); @@ -1289,13 +1273,7 @@ mod test { let out = func.add_iadd(6, elem, delta).unwrap(); let func_id = func.finish_with_outputs(vec![out]).unwrap(); let func_v = builder.load_func(func_id.handle(), &[]).unwrap(); - let scan = ArrayScan::new( - int_ty.clone(), - int_ty.clone(), - vec![], - size, - exec_extension_set(), - ); + let scan = ArrayScan::new(int_ty.clone(), int_ty.clone(), vec![], size); let mut arr = builder .add_dataflow_op(scan, [arr, func_v]) .unwrap() @@ -1357,8 +1335,7 @@ mod test { Signature::new( vec![int_ty.clone(), int_ty.clone()], vec![Type::UNIT, int_ty.clone()], - ) - .with_extension_delta(exec_extension_set()), + ), ) .unwrap(); let [elem, acc] = func.input_wires_arr(); @@ -1369,13 +1346,7 @@ mod test { .out_wire(0); let func_id = func.finish_with_outputs(vec![unit, acc]).unwrap(); let func_v = builder.load_func(func_id.handle(), &[]).unwrap(); - let scan = ArrayScan::new( - int_ty.clone(), - Type::UNIT, - vec![int_ty.clone()], - size, - exec_extension_set(), - ); + let scan = ArrayScan::new(int_ty.clone(), Type::UNIT, vec![int_ty.clone()], size); let zero = builder.add_load_value(ConstInt::new_u(6, 0).unwrap()); let sum = builder .add_dataflow_op(scan, [arr, func_v, zero]) diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index 7a2c50367..e4e4f9087 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -26,9 +26,6 @@ paste = { workspace = true } thiserror = { workspace = true } petgraph = { workspace = true } -[features] -extension_inference = ["hugr-core/extension_inference"] - [dev-dependencies] rstest = { workspace = true } proptest = { workspace = true } diff --git a/hugr-passes/README.md b/hugr-passes/README.md index b441ed5e7..c2bca2124 100644 --- a/hugr-passes/README.md +++ b/hugr-passes/README.md @@ -1,7 +1,6 @@ ![](/hugr/assets/hugr_logo.svg) -hugr-passes -=============== +# hugr-passes [![build_status][]](https://github.com/CQCL/hugr/actions) [![crates][]](https://crates.io/crates/hugr-passes) @@ -29,13 +28,6 @@ cargo add hugr-passes Please read the [API documentation here][]. -## Experimental Features - -- `extension_inference`: - Experimental feature which allows automatic inference of which extra extensions - are required at runtime by a HUGR when validating it. - Not enabled by default. - ## Recent Changes See [CHANGELOG][] for a list of changes. The minimum supported rust @@ -55,4 +47,4 @@ This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http: [crates]: https://img.shields.io/crates/v/hugr-passes [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov [LICENSE]: https://github.com/CQCL/hugr/blob/main/LICENCE - [CHANGELOG]: https://github.com/CQCL/hugr/blob/main/hugr-passes/CHANGELOG.md + [CHANGELOG]: https://github.com/CQCL/hugr/blob/main/hugr-passes/CHANGELOG.md \ No newline at end of file diff --git a/hugr-passes/src/composable.rs b/hugr-passes/src/composable.rs index ad8ff1ec0..faf92b8a7 100644 --- a/hugr-passes/src/composable.rs +++ b/hugr-passes/src/composable.rs @@ -132,21 +132,11 @@ pub enum ValidatePassError { /// Runs an underlying pass, but with validation of the Hugr /// both before and afterwards. -pub struct ValidatingPass

(P, bool); +pub struct ValidatingPass

(P); impl ValidatingPass

{ - pub fn new_default(underlying: P) -> Self { - // Self(underlying, cfg!(feature = "extension_inference")) - // Sadly, many tests fail with extension inference, hence: - Self(underlying, false) - } - - pub fn new_validating_extensions(underlying: P) -> Self { - Self(underlying, true) - } - - pub fn new(underlying: P, validate_extensions: bool) -> Self { - Self(underlying, validate_extensions) + pub fn new(underlying: P) -> Self { + Self(underlying) } fn validation_impl( @@ -154,11 +144,8 @@ impl ValidatingPass

{ hugr: &impl HugrView, mk_err: impl FnOnce(ValidationError, String) -> ValidatePassError, ) -> Result<(), ValidatePassError> { - match self.1 { - false => hugr.validate_no_extensions(), - true => hugr.validate(), - } - .map_err(|err| mk_err(err, hugr.mermaid_string())) + hugr.validate() + .map_err(|err| mk_err(err, hugr.mermaid_string())) } } @@ -222,7 +209,7 @@ pub(crate) fn validate_if_test( hugr: &mut impl HugrMut, ) -> Result> { if cfg!(test) { - ValidatingPass::new_default(pass).run(hugr) + ValidatingPass::new(pass).run(hugr) } else { pass.run(hugr).map_err(ValidatePassError::Underlying) } @@ -237,9 +224,7 @@ mod test { Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, ModuleBuilder, }; - use hugr_core::extension::prelude::{ - bool_t, usize_t, ConstUsize, MakeTuple, UnpackTuple, PRELUDE_ID, - }; + use hugr_core::extension::prelude::{bool_t, usize_t, ConstUsize, MakeTuple, UnpackTuple}; use hugr_core::hugr::hugrmut::HugrMut; use hugr_core::ops::{handle::NodeHandle, Input, OpType, Output, DEFAULT_OPTYPE, DFG}; use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES; @@ -315,7 +300,7 @@ mod test { cfold.run(&mut h).unwrap(); assert_eq!(h, backup); // Did nothing - let r = ValidatingPass(cfold, false).run(&mut h); + let r = ValidatingPass(cfold).run(&mut h); assert!(matches!(r, Err(ValidatePassError::Input { err: e, .. }) if e == err)); } @@ -324,7 +309,7 @@ mod test { let tr = TypeRow::from(vec![usize_t(); 2]); let h = { - let sig = Signature::new_endo(tr.clone()).with_extension_delta(PRELUDE_ID); + let sig = Signature::new_endo(tr.clone()); let mut fb = FunctionBuilder::new("tupuntup", sig).unwrap(); let [a, b] = fb.input_wires_arr(); let tup = fb diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index 3a296fc0b..dcdc4df0a 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -3,7 +3,6 @@ use std::collections::HashSet; use hugr_core::ops::handle::NodeHandle; use hugr_core::ops::Const; -use hugr_core::std_extensions::arithmetic::{int_ops, int_types}; use itertools::Itertools; use lazy_static::lazy_static; use rstest::rstest; @@ -1595,9 +1594,7 @@ fn test_module() -> Result<(), Box> { let ad2 = mb.add_alias_def("unused2", INT_TYPES[3].clone())?; let mut main = mb.define_function( "main", - Signature::new(type_row![], vec![INT_TYPES[5].clone(); 2]) - .with_extension_delta(int_types::EXTENSION_ID) - .with_extension_delta(int_ops::EXTENSION_ID), + Signature::new(type_row![], vec![INT_TYPES[5].clone(); 2]), )?; let lc7 = main.load_const(&c7); let lc17 = main.load_const(&c17); diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 1c4b4e439..a67556ce1 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -9,10 +9,7 @@ use hugr_core::ops::{CallIndirect, TailLoop}; use hugr_core::types::{ConstTypeError, TypeRow}; use hugr_core::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, - extension::{ - prelude::{bool_t, UnpackTuple}, - ExtensionSet, - }, + extension::prelude::{bool_t, UnpackTuple}, ops::{handle::NodeHandle, DataflowOpTrait, Tag, Value}, type_row, types::{Signature, SumType, Type}, @@ -176,12 +173,7 @@ fn test_tail_loop_two_iters() { let false_w = builder.add_load_value(Value::false_val()); let tlb = builder - .tail_loop_builder_exts( - [], - [(bool_t(), false_w), (bool_t(), true_w)], - type_row![], - ExtensionSet::new(), - ) + .tail_loop_builder([], [(bool_t(), false_w), (bool_t(), true_w)], type_row![]) .unwrap(); assert_eq!( tlb.loop_signature().unwrap().signature().as_ref(), diff --git a/hugr-passes/src/dead_code.rs b/hugr-passes/src/dead_code.rs index 25f6cf798..69bcfabf6 100644 --- a/hugr-passes/src/dead_code.rs +++ b/hugr-passes/src/dead_code.rs @@ -180,7 +180,7 @@ mod test { use std::sync::Arc; use hugr_core::builder::{CFGBuilder, Container, Dataflow, DataflowSubContainer, HugrBuilder}; - use hugr_core::extension::prelude::{usize_t, ConstUsize, PRELUDE_ID}; + use hugr_core::extension::prelude::{usize_t, ConstUsize}; use hugr_core::ops::{handle::NodeHandle, OpTag, OpTrait}; use hugr_core::types::Signature; use hugr_core::{ops::Value, type_row, HugrView}; @@ -192,9 +192,7 @@ mod test { #[test] fn test_cfg_callback() { - let mut cb = - CFGBuilder::new(Signature::new_endo(type_row![]).with_extension_delta(PRELUDE_ID)) - .unwrap(); + let mut cb = CFGBuilder::new(Signature::new_endo(type_row![])).unwrap(); let cst_unused = cb.add_constant(Value::from(ConstUsize::new(3))); let cst_used_in_dfg = cb.add_constant(Value::from(ConstUsize::new(5))); let cst_used = cb.add_constant(Value::unary_unit_sum()); diff --git a/hugr-passes/src/force_order.rs b/hugr-passes/src/force_order.rs index ec59ccefd..cbb637b2a 100644 --- a/hugr-passes/src/force_order.rs +++ b/hugr-passes/src/force_order.rs @@ -279,7 +279,7 @@ mod test { .iter(&hugr.as_petgraph()) .filter(|n| rank_map.contains_key(n)) .collect_vec(); - hugr.validate_no_extensions().unwrap(); + hugr.validate().unwrap(); topo_sorted } diff --git a/hugr-passes/src/lower.rs b/hugr-passes/src/lower.rs index 403e3d84b..334127bab 100644 --- a/hugr-passes/src/lower.rs +++ b/hugr-passes/src/lower.rs @@ -94,7 +94,7 @@ mod test { #[fixture] fn noop_hugr() -> Hugr { - let mut b = DFGBuilder::new(Signature::new_endo(bool_t()).with_prelude()).unwrap(); + let mut b = DFGBuilder::new(Signature::new_endo(bool_t())).unwrap(); let out = b .add_dataflow_op(Noop::new(bool_t()), [b.input_wires().next().unwrap()]) .unwrap() diff --git a/hugr-passes/src/merge_bbs.rs b/hugr-passes/src/merge_bbs.rs index 170ff3789..2c739c3d7 100644 --- a/hugr-passes/src/merge_bbs.rs +++ b/hugr-passes/src/merge_bbs.rs @@ -62,14 +62,10 @@ fn mk_rep( let mut replacement: Hugr = Hugr::new(cfg.root_optype().clone()); let merged = replacement.add_node_with_parent(replacement.root(), { - let mut merged_block = DataflowBlock { + DataflowBlock { inputs: pred_ty.inputs.clone(), ..succ_ty.clone() - }; - merged_block.extension_delta = merged_block - .extension_delta - .union(pred_ty.extension_delta.clone()); - merged_block + } }); let input = replacement.add_node_with_parent( merged, @@ -225,7 +221,7 @@ mod test { let e = extension(); let tst_op = e.instantiate_extension_op("Test", [])?; let mut h = CFGBuilder::new(inout_sig(loop_variants.clone(), exit_types.clone()))?; - let mut no_b1 = h.simple_entry_builder_exts(loop_variants.clone(), 1, PRELUDE_ID)?; + let mut no_b1 = h.simple_entry_builder(loop_variants.clone(), 1)?; let n = no_b1.add_dataflow_op(Noop::new(qb_t()), no_b1.input_wires())?; let br = unary_unit_sum(&mut no_b1); let no_b1 = no_b1.finish_with_outputs(br, n.outputs())?; diff --git a/hugr-passes/src/monomorphize.rs b/hugr-passes/src/monomorphize.rs index 3ac85a020..cfe2c9514 100644 --- a/hugr-passes/src/monomorphize.rs +++ b/hugr-passes/src/monomorphize.rs @@ -2,7 +2,6 @@ use std::{ collections::{hash_map::Entry, HashMap}, convert::Infallible, fmt::Write, - ops::Deref, }; use hugr_core::{ @@ -300,10 +299,6 @@ fn write_type_arg_str(arg: &TypeArg, f: &mut std::fmt::Formatter<'_>) -> std::fm TypeArg::BoundedNat { n } => f.write_fmt(format_args!("n({n})")), TypeArg::String { arg } => f.write_fmt(format_args!("s({})", escape_dollar(arg))), TypeArg::Sequence { elems } => f.write_fmt(format_args!("seq({})", TypeArgsList(elems))), - TypeArg::Extensions { es } => f.write_fmt(format_args!( - "es({})", - es.iter().map(|x| x.deref()).join(",") - )), // We are monomorphizing. We will never monomorphize to a signature // containing a variable. TypeArg::Variable { .. } => panic!("type_arg_str variable: {arg}"), @@ -338,6 +333,7 @@ mod test { use std::iter; use hugr_core::extension::simple_op::MakeRegisteredOp as _; + use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES; use hugr_core::std_extensions::collections; use hugr_core::std_extensions::collections::array::{array_type_parametric, ArrayOpDef}; use hugr_core::types::type_param::TypeParam; @@ -347,16 +343,10 @@ mod test { Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, ModuleBuilder, }; - use hugr_core::extension::prelude::{ - usize_t, ConstUsize, UnpackTuple, UnwrapBuilder, PRELUDE_ID, - }; - use hugr_core::extension::ExtensionSet; + use hugr_core::extension::prelude::{usize_t, ConstUsize, UnpackTuple, UnwrapBuilder}; use hugr_core::ops::handle::{FuncID, NodeHandle}; use hugr_core::ops::{CallIndirect, DataflowOpTrait as _, FuncDefn, Tag}; - use hugr_core::std_extensions::arithmetic::int_types::{self, INT_TYPES}; - use hugr_core::types::{ - PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeEnum, TypeRow, - }; + use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeEnum}; use hugr_core::{Hugr, HugrView, Node}; use rstest::rstest; @@ -372,10 +362,6 @@ mod test { Type::new_tuple(vec![ty.clone(), ty.clone(), ty]) } - fn prelusig(ins: impl Into, outs: impl Into) -> Signature { - Signature::new(ins, outs).with_extension_delta(PRELUDE_ID) - } - #[test] fn test_null() { let dfg_builder = @@ -411,7 +397,7 @@ mod test { }; let tr = { - let sig = prelusig(tv0(), Type::new_tuple(vec![tv0(); 3])); + let sig = Signature::new(tv0(), Type::new_tuple(vec![tv0(); 3])); let mut fb = mb.define_function( "triple", PolyFuncType::new([TypeBound::Copyable.into()], sig), @@ -428,7 +414,7 @@ mod test { }; let mn = { let outs = vec![triple_type(usize_t()), triple_type(pair_type(usize_t()))]; - let mut fb = mb.define_function("main", prelusig(usize_t(), outs))?; + let mut fb = mb.define_function("main", Signature::new(usize_t(), outs))?; let [elem] = fb.input_wires_arr(); let [res1] = fb .call(tr.handle(), &[usize_t().into()], [elem])? @@ -493,37 +479,30 @@ mod test { let n: u64 = 5; let mut outer = FunctionBuilder::new( "mainish", - prelusig( + Signature::new( array_type_parametric(sa(n), array_type_parametric(sa(2), usize_t()).unwrap()) .unwrap(), vec![usize_t(); 2], - ) - .with_extension_delta(collections::array::EXTENSION_ID), + ), ) .unwrap(); let arr2u = || array_type_parametric(sa(2), usize_t()).unwrap(); let pf1t = PolyFuncType::new( [TypeParam::max_nat()], - prelusig(array_type_parametric(sv(0), arr2u()).unwrap(), usize_t()) - .with_extension_delta(collections::array::EXTENSION_ID), + Signature::new(array_type_parametric(sv(0), arr2u()).unwrap(), usize_t()), ); let mut pf1 = outer.define_function("pf1", pf1t).unwrap(); let pf2t = PolyFuncType::new( [TypeParam::max_nat(), TypeBound::Copyable.into()], - prelusig(vec![array_type_parametric(sv(0), tv(1)).unwrap()], tv(1)) - .with_extension_delta(collections::array::EXTENSION_ID), + Signature::new(vec![array_type_parametric(sv(0), tv(1)).unwrap()], tv(1)), ); let mut pf2 = pf1.define_function("pf2", pf2t).unwrap(); let mono_func = { let mut fb = pf2 - .define_function( - "get_usz", - prelusig(vec![], usize_t()) - .with_extension_delta(collections::array::EXTENSION_ID), - ) + .define_function("get_usz", Signature::new(vec![], usize_t())) .unwrap(); let cst0 = fb.add_load_value(ConstUsize::new(1)); fb.finish_with_outputs([cst0]).unwrap() @@ -706,8 +685,6 @@ mod test { #[case::string(vec!["arg".into()], "$foo$$s(arg)")] #[case::dollar_string(vec!["$arg".into()], "$foo$$s(\\$arg)")] #[case::sequence(vec![vec![0.into(), Type::UNIT.into()].into()], "$foo$$seq($n(0)$t(Unit))")] - #[case::extensionset(vec![ExtensionSet::from_iter([PRELUDE_ID,int_types::EXTENSION_ID]).into()], - "$foo$$es(arithmetic.int.types,prelude)")] // alphabetic ordering of extension names #[should_panic] #[case::typeargvariable(vec![TypeArg::new_var_use(1, TypeParam::String)], "$foo$$v(1)")] diff --git a/hugr-passes/src/nest_cfgs.rs b/hugr-passes/src/nest_cfgs.rs index 6e9df7f1a..3c15ca6f2 100644 --- a/hugr-passes/src/nest_cfgs.rs +++ b/hugr-passes/src/nest_cfgs.rs @@ -577,7 +577,7 @@ pub(crate) mod test { use hugr_core::builder::{ endo_sig, BuildError, CFGBuilder, Container, DataflowSubContainer, HugrBuilder, }; - use hugr_core::extension::{prelude::usize_t, ExtensionSet}; + use hugr_core::extension::prelude::usize_t; use hugr_core::hugr::patch::insert_identity::{IdentityInsertion, IdentityInsertionError}; use hugr_core::hugr::views::RootChecked; @@ -612,11 +612,7 @@ pub(crate) mod test { let const_unit = cfg_builder.add_constant(Value::unary_unit_sum()); let entry = n_identity( - cfg_builder.simple_entry_builder_exts( - vec![usize_t()].into(), - 1, - ExtensionSet::new(), - )?, + cfg_builder.simple_entry_builder(vec![usize_t()].into(), 1)?, &const_unit, )?; let (split, merge) = build_if_then_else_merge(&mut cfg_builder, &pred_const, &const_unit)?; diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 180e9d6fc..a2219d14f 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -54,8 +54,7 @@ mod test { #[test] fn ensures_no_nonlocal_edges() { let hugr = { - let mut builder = - DFGBuilder::new(Signature::new_endo(bool_t()).with_prelude()).unwrap(); + let mut builder = DFGBuilder::new(Signature::new_endo(bool_t())).unwrap(); let [in_w] = builder.input_wires_arr(); let [out_w] = builder .add_dataflow_op(Noop::new(bool_t()), [in_w]) @@ -69,12 +68,11 @@ mod test { #[test] fn find_nonlocal_edges() { let (hugr, edge) = { - let mut builder = - DFGBuilder::new(Signature::new_endo(bool_t()).with_prelude()).unwrap(); + let mut builder = DFGBuilder::new(Signature::new_endo(bool_t())).unwrap(); let [in_w] = builder.input_wires_arr(); let ([out_w], edge) = { let mut dfg_builder = builder - .dfg_builder(Signature::new(type_row![], bool_t()).with_prelude(), []) + .dfg_builder(Signature::new(type_row![], bool_t()), []) .unwrap(); let noop = dfg_builder .add_dataflow_op(Noop::new(bool_t()), [in_w]) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 25249f5ae..45bc25bcf 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -745,8 +745,7 @@ mod test { let inps = fb.input_wires(); let id = fb.finish_with_outputs(inps).unwrap(); - let sig = Signature::new(vec![i64_t(), c_int.clone(), c_bool.clone()], bool_t()) - .with_extension_delta(ext.name.clone()); + let sig = Signature::new(vec![i64_t(), c_int.clone(), c_bool.clone()], bool_t()); let mut fb = mb.define_function("main", sig).unwrap(); let [idx, indices, bools] = fb.input_wires_arr(); let [indices] = fb diff --git a/hugr-passes/src/replace_types/handlers.rs b/hugr-passes/src/replace_types/handlers.rs index b6e6e6780..573188340 100644 --- a/hugr-passes/src/replace_types/handlers.rs +++ b/hugr-passes/src/replace_types/handlers.rs @@ -3,7 +3,6 @@ use hugr_core::builder::{endo_sig, inout_sig, DFGBuilder, Dataflow, DataflowHugr}; use hugr_core::extension::prelude::{option_type, UnwrapBuilder}; -use hugr_core::extension::ExtensionSet; use hugr_core::ops::{constant::OpaqueValue, Value}; use hugr_core::ops::{OpTrait, OpType, Tag}; use hugr_core::std_extensions::arithmetic::conversions::ConvertOpDef; @@ -13,8 +12,8 @@ use hugr_core::std_extensions::collections::array::{ array_type, ArrayOpDef, ArrayRepeat, ArrayScan, ArrayValue, }; use hugr_core::std_extensions::collections::list::ListValue; +use hugr_core::type_row; use hugr_core::types::{SumType, Transformable, Type, TypeArg}; -use hugr_core::{type_row, Hugr, HugrView}; use itertools::Itertools; use super::{ @@ -67,10 +66,6 @@ pub fn array_const( Ok(Some(ArrayValue::new(elem_t, vals).into())) } -fn runtime_reqs(h: &Hugr) -> ExtensionSet { - h.signature(h.root()).unwrap().runtime_reqs.clone() -} - /// Handler for copying/discarding arrays if their elements have become linear. /// Included in [ReplaceTypes::default] and [DelegatingLinearizer::default]. /// @@ -97,7 +92,7 @@ pub fn linearize_array( dfb.finish_hugr_with_outputs([ret]).unwrap() }; // Now array.scan that over the input array to get an array of unit (which can be discarded) - let array_scan = ArrayScan::new(ty.clone(), Type::UNIT, vec![], *n, runtime_reqs(&map_fn)); + let array_scan = ArrayScan::new(ty.clone(), Type::UNIT, vec![], *n); let in_type = array_type(*n, ty.clone()); return Ok(NodeTemplate::CompoundOp(Box::new({ let mut dfb = DFGBuilder::new(inout_sig(in_type, type_row![])).unwrap(); @@ -131,8 +126,7 @@ pub fn linearize_array( .unwrap(); dfb.finish_hugr_with_outputs(none.outputs()).unwrap() }; - let repeats = - vec![ArrayRepeat::new(option_ty.clone(), *n, runtime_reqs(&fn_none)); num_new]; + let repeats = vec![ArrayRepeat::new(option_ty.clone(), *n); num_new]; let fn_none = dfb.add_load_value(Value::function(fn_none).unwrap()); repeats .into_iter() @@ -212,7 +206,6 @@ pub fn linearize_array( .chain(vec![option_array; num_new]) .collect(), *n, - runtime_reqs(©_elem), ); let copy_elem = dfb.add_load_value(Value::function(copy_elem).unwrap()); @@ -240,13 +233,7 @@ pub fn linearize_array( dfb.finish_hugr_with_outputs([val]).unwrap() }; - let unwrap_scan = ArrayScan::new( - option_ty.clone(), - ty.clone(), - vec![], - *n, - runtime_reqs(&unwrap_elem), - ); + let unwrap_scan = ArrayScan::new(option_ty.clone(), ty.clone(), vec![], *n); let unwrap_elem = dfb.add_load_value(Value::function(unwrap_elem).unwrap()); let out_arrays = std::iter::once(out_array1) diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 2788a2379..81324dbee 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -622,10 +622,7 @@ mod test { NodeTemplate::SingleOp(copy3.clone()), NodeTemplate::SingleOp(discard.clone().into()), ); - let sig3 = Some( - Signature::new(lin_t.clone(), vec![lin_t.clone(); 3]) - .with_extension_delta(ext.name().clone()), - ); + let sig3 = Some(Signature::new(lin_t.clone(), vec![lin_t.clone(); 3])); assert_eq!( bad_copy, Err(LinearizeError::WrongSignature { @@ -782,11 +779,7 @@ mod test { let mut dfb = DFGBuilder::new(inout_sig(usize_t(), type_row![])).unwrap(); let discard_fn = { let mut fb = dfb - .define_function( - "drop", - Signature::new(lin_t.clone(), type_row![]) - .with_extension_delta(e.name().clone()), - ) + .define_function("drop", Signature::new(lin_t.clone(), type_row![])) .unwrap(); let ins = fb.input_wires(); fb.add_dataflow_op( diff --git a/hugr-passes/src/untuple.rs b/hugr-passes/src/untuple.rs index b2782e8d9..1c9be1c75 100644 --- a/hugr-passes/src/untuple.rs +++ b/hugr-passes/src/untuple.rs @@ -278,9 +278,7 @@ mod test { /// These can be removed entirely. #[fixture] fn unused_pack() -> Hugr { - let mut h = - DFGBuilder::new(Signature::new(vec![bool_t(), bool_t()], vec![]).with_prelude()) - .unwrap(); + let mut h = DFGBuilder::new(Signature::new(vec![bool_t(), bool_t()], vec![])).unwrap(); let mut inps = h.input_wires(); let b1 = inps.next().unwrap(); let b2 = inps.next().unwrap(); @@ -295,8 +293,7 @@ mod test { /// These can be removed entirely. #[fixture] fn simple_pack_unpack() -> Hugr { - let mut h = - DFGBuilder::new(Signature::new_endo(vec![qb_t(), bool_t()]).with_prelude()).unwrap(); + let mut h = DFGBuilder::new(Signature::new_endo(vec![qb_t(), bool_t()])).unwrap(); let mut inps = h.input_wires(); let qb1 = inps.next().unwrap(); let b2 = inps.next().unwrap(); @@ -315,8 +312,7 @@ mod test { /// we just remove everything. #[fixture] fn ordered_pack_unpack() -> Hugr { - let mut h = - DFGBuilder::new(Signature::new_endo(vec![qb_t(), bool_t()]).with_prelude()).unwrap(); + let mut h = DFGBuilder::new(Signature::new_endo(vec![qb_t(), bool_t()])).unwrap(); let mut inps = h.input_wires(); let qb1 = inps.next().unwrap(); let b2 = inps.next().unwrap(); @@ -338,13 +334,10 @@ mod test { /// These can be removed entirely. #[fixture] fn multi_unpack() -> Hugr { - let mut h = DFGBuilder::new( - Signature::new( - vec![bool_t(), bool_t()], - vec![bool_t(), bool_t(), bool_t(), bool_t()], - ) - .with_prelude(), - ) + let mut h = DFGBuilder::new(Signature::new( + vec![bool_t(), bool_t()], + vec![bool_t(), bool_t(), bool_t(), bool_t()], + )) .unwrap(); let mut inps = h.input_wires(); let b1 = inps.next().unwrap(); @@ -369,17 +362,14 @@ mod test { /// The unpack operation can be removed, but the pack operation cannot. #[fixture] fn partial_unpack() -> Hugr { - let mut h = DFGBuilder::new( - Signature::new( - vec![bool_t(), bool_t()], - vec![ - bool_t(), - bool_t(), - Type::new_tuple(vec![bool_t(), bool_t()]), - ], - ) - .with_prelude(), - ) + let mut h = DFGBuilder::new(Signature::new( + vec![bool_t(), bool_t()], + vec![ + bool_t(), + bool_t(), + Type::new_tuple(vec![bool_t(), bool_t()]), + ], + )) .unwrap(); let mut inps = h.input_wires(); let b1 = inps.next().unwrap(); diff --git a/hugr-passes/src/validation.rs b/hugr-passes/src/validation.rs index 6c3e61fb4..90d338faf 100644 --- a/hugr-passes/src/validation.rs +++ b/hugr-passes/src/validation.rs @@ -16,11 +16,8 @@ use hugr_core::HugrView; pub enum ValidationLevel { /// Do no verification. None, - /// Validate using [HugrView::validate_no_extensions]. This is useful when you - /// do not expect valid Extension annotations on Nodes. - WithoutExtensions, /// Validate using [HugrView::validate]. - WithExtensions, + Validate, } #[derive(Error, Debug, PartialEq)] @@ -44,8 +41,7 @@ pub enum ValidatePassError { impl Default for ValidationLevel { fn default() -> Self { if cfg!(test) { - // Many tests fail when run with Self::WithExtensions - Self::WithoutExtensions + Self::Validate } else { Self::None } @@ -86,8 +82,7 @@ impl ValidationLevel { { match self { ValidationLevel::None => Ok(()), - ValidationLevel::WithoutExtensions => hugr.validate_no_extensions(), - ValidationLevel::WithExtensions => hugr.validate(), + ValidationLevel::Validate => hugr.validate(), } .map_err(|err| mk_err(err, hugr.mermaid_string()).into()) } diff --git a/hugr-py/src/hugr/_serialization/extension.py b/hugr-py/src/hugr/_serialization/extension.py index 95e59754e..3bb377ed5 100644 --- a/hugr-py/src/hugr/_serialization/extension.py +++ b/hugr-py/src/hugr/_serialization/extension.py @@ -86,9 +86,7 @@ class OpDef(ConfiguredBaseModel, populate_by_name=True): def deserialize(self, extension: ext.Extension) -> ext.OpDef: signature = ext.OpDefSig( - self.signature.deserialize().with_runtime_reqs([extension.name]) - if self.signature - else None, + self.signature.deserialize() if self.signature else None, self.binary, ) @@ -106,7 +104,6 @@ def deserialize(self, extension: ext.Extension) -> ext.OpDef: class Extension(ConfiguredBaseModel): version: SemanticVersion name: ExtensionId - runtime_reqs: set[ExtensionId] types: dict[str, TypeDef] operations: dict[str, OpDef] @@ -118,7 +115,6 @@ def deserialize(self) -> ext.Extension: e = ext.Extension( version=self.version, # type: ignore[arg-type] name=self.name, - runtime_reqs=self.runtime_reqs, ) for k, t in self.types.items(): diff --git a/hugr-py/src/hugr/_serialization/ops.py b/hugr-py/src/hugr/_serialization/ops.py index 48b4e6b87..28a1daf5e 100644 --- a/hugr-py/src/hugr/_serialization/ops.py +++ b/hugr-py/src/hugr/_serialization/ops.py @@ -206,7 +206,6 @@ class DataflowBlock(BaseOp): inputs: TypeRow = Field(default_factory=list) other_outputs: TypeRow = Field(default_factory=list) sum_rows: list[TypeRow] - extension_delta: ExtensionSet = Field(default_factory=list) def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: num_cases = len(out_types) @@ -384,13 +383,11 @@ class DFG(DataflowOp): signature: FunctionType = Field(default_factory=FunctionType.empty) def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None: - self.signature = FunctionType( - input=list(inputs), output=list(outputs), runtime_reqs=ExtensionSet([]) - ) + self.signature = FunctionType(input=list(inputs), output=list(outputs)) def deserialize(self) -> ops.DFG: sig = self.signature.deserialize() - return ops.DFG(sig.input, sig.output, sig.runtime_reqs) + return ops.DFG(sig.input, sig.output) # ------------------------------------------------ @@ -407,8 +404,6 @@ class Conditional(DataflowOp): sum_rows: list[TypeRow] = Field( description="The possible rows of the Sum input", default_factory=list ) - # Extensions used to produce the outputs - extension_delta: ExtensionSet = Field(default_factory=list) def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: # First port is a predicate, i.e. a sum of tuple types. We need to unpack @@ -442,9 +437,7 @@ class Case(BaseOp): signature: FunctionType = Field(default_factory=FunctionType.empty) def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None: - self.signature = stys.FunctionType( - input=list(inputs), output=list(outputs), runtime_reqs=ExtensionSet([]) - ) + self.signature = stys.FunctionType(input=list(inputs), output=list(outputs)) def deserialize(self) -> ops.Case: sig = self.signature.deserialize() @@ -455,11 +448,12 @@ class TailLoop(DataflowOp): """Tail-controlled loop.""" op: Literal["TailLoop"] = "TailLoop" - just_inputs: TypeRow = Field(default_factory=list) # Types that are only input - just_outputs: TypeRow = Field(default_factory=list) # Types that are only output + # Types that are only input + just_inputs: TypeRow = Field(default_factory=list) + # Types that are only output + just_outputs: TypeRow = Field(default_factory=list) # Types that are appended to both input and output: rest: TypeRow = Field(default_factory=list) - extension_delta: ExtensionSet = Field(default_factory=list) def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: assert in_types == out_types @@ -472,7 +466,6 @@ def deserialize(self) -> ops.TailLoop: just_inputs=deser_it(self.just_inputs), _just_outputs=deser_it(self.just_outputs), rest=deser_it(self.rest), - extension_delta=self.extension_delta, ) @@ -484,7 +477,8 @@ class CFG(DataflowOp): def insert_port_types(self, inputs: TypeRow, outputs: TypeRow) -> None: self.signature = FunctionType( - input=list(inputs), output=list(outputs), runtime_reqs=ExtensionSet([]) + input=list(inputs), + output=list(outputs), ) def deserialize(self) -> ops.CFG: diff --git a/hugr-py/src/hugr/_serialization/tys.py b/hugr-py/src/hugr/_serialization/tys.py index 4a0a0e75b..c00a73375 100644 --- a/hugr-py/src/hugr/_serialization/tys.py +++ b/hugr-py/src/hugr/_serialization/tys.py @@ -110,23 +110,11 @@ def deserialize(self) -> tys.TupleParam: return tys.TupleParam(params=deser_it(self.params)) -class ExtensionsParam(BaseTypeParam): - tp: Literal["Extensions"] = "Extensions" - - def deserialize(self) -> tys.ExtensionsParam: - return tys.ExtensionsParam() - - class TypeParam(RootModel): """A type parameter.""" root: Annotated[ - TypeTypeParam - | BoundedNatParam - | StringParam - | ListParam - | TupleParam - | ExtensionsParam, + TypeTypeParam | BoundedNatParam | StringParam | ListParam | TupleParam, WrapValidator(_json_custom_error_validator), ] = Field(discriminator="tp") @@ -178,14 +166,6 @@ def deserialize(self) -> tys.SequenceArg: return tys.SequenceArg(elems=deser_it(self.elems)) -class ExtensionsArg(BaseTypeArg): - tya: Literal["Extensions"] = "Extensions" - es: ExtensionSet - - def deserialize(self) -> tys.ExtensionsArg: - return tys.ExtensionsArg(extensions=self.es) - - class VariableArg(BaseTypeArg): tya: Literal["Variable"] = "Variable" idx: int @@ -199,12 +179,7 @@ class TypeArg(RootModel): """A type argument.""" root: Annotated[ - TypeTypeArg - | BoundedNatArg - | StringArg - | SequenceArg - | ExtensionsArg - | VariableArg, + TypeTypeArg | BoundedNatArg | StringArg | SequenceArg | VariableArg, WrapValidator(_json_custom_error_validator), ] = Field(discriminator="tya") @@ -307,18 +282,15 @@ class FunctionType(BaseType): input: TypeRow # Value inputs of the function. output: TypeRow # Value outputs of the function. - # The extension requirements which are added by the operation - runtime_reqs: ExtensionSet = Field(default_factory=ExtensionSet) @classmethod def empty(cls) -> FunctionType: - return FunctionType(input=[], output=[], runtime_reqs=[]) + return FunctionType(input=[], output=[]) def deserialize(self) -> tys.FunctionType: return tys.FunctionType( input=deser_it(self.input), output=deser_it(self.output), - runtime_reqs=self.runtime_reqs, ) model_config = ConfigDict( diff --git a/hugr-py/src/hugr/ext.py b/hugr-py/src/hugr/ext.py index 7bd02f982..fd59da0fc 100644 --- a/hugr-py/src/hugr/ext.py +++ b/hugr-py/src/hugr/ext.py @@ -235,13 +235,6 @@ def instantiate( concrete_signature: Concrete function type of the operation, only required if the operation is polymorphic. """ - # Add the extension where the operation is defined as a runtime requirement. - # We don't store this in the json definition as it is redundant information. - if concrete_signature is not None: - concrete_signature = concrete_signature.with_runtime_reqs( - [self.get_extension().name] - ) - return ops.ExtOp(self, concrete_signature, list(args or [])) @@ -256,8 +249,6 @@ class Extension: name: ExtensionId #: The version of the extension. version: Version - #: Extensions required by this extension at runtime, identified by name. - runtime_reqs: set[ExtensionId] = field(default_factory=set) #: Type definitions in the extension. types: dict[str, TypeDef] = field(default_factory=dict) #: Operation definitions in the extension. @@ -273,7 +264,6 @@ def _to_serial(self) -> ext_s.Extension: return ext_s.Extension( name=self.name, version=self.version, # type: ignore[arg-type] - runtime_reqs=self.runtime_reqs, types={k: v._to_serial() for k, v in self.types.items()}, operations={k: v._to_serial() for k, v in self.operations.items()}, ) @@ -303,12 +293,6 @@ def add_op_def(self, op_def: OpDef) -> OpDef: Returns: The added operation definition, now associated with the extension. """ - if op_def.signature.poly_func is not None: - # Ensure the op def signature has the extension as a requirement - op_def.signature.poly_func = op_def.signature.poly_func.with_runtime_reqs( - [self.name] - ) - op_def._extension = self self.operations[op_def.name] = op_def return self.operations[op_def.name] diff --git a/hugr-py/src/hugr/ops.py b/hugr-py/src/hugr/ops.py index 1555dab4d..b6030b6a0 100644 --- a/hugr-py/src/hugr/ops.py +++ b/hugr-py/src/hugr/ops.py @@ -456,7 +456,6 @@ def cached_signature(self) -> tys.FunctionType | None: return tys.FunctionType( input=self.types, output=[tys.Tuple(*self.types)], - runtime_reqs=["prelude"], ) def type_args(self) -> list[tys.TypeArg]: @@ -499,7 +498,6 @@ def cached_signature(self) -> tys.FunctionType | None: return tys.FunctionType( input=[tys.Tuple(*self.types)], output=self.types, - runtime_reqs=["prelude"], ) def type_args(self) -> list[tys.TypeArg]: @@ -632,7 +630,6 @@ class DFG(DfParentOp, DataflowOp): #: Inputs types of the operation. inputs: tys.TypeRow _outputs: tys.TypeRow | None = field(default=None, repr=False) - _extension_delta: tys.ExtensionSet = field(default_factory=list, repr=False) @property def outputs(self) -> tys.TypeRow: @@ -650,7 +647,7 @@ def signature(self) -> tys.FunctionType: Raises: IncompleteOp: If the outputs have not been set. """ - return tys.FunctionType(self.inputs, self.outputs, self._extension_delta) + return tys.FunctionType(self.inputs, self.outputs) @property def num_out(self) -> int: @@ -729,7 +726,6 @@ class DataflowBlock(DfParentOp): inputs: tys.TypeRow _sum: tys.Sum | None = None _other_outputs: tys.TypeRow | None = field(default=None, repr=False) - extension_delta: tys.ExtensionSet = field(default_factory=list) @property def sum_ty(self) -> tys.Sum: @@ -762,7 +758,6 @@ def _to_serial(self, parent: Node) -> sops.DataflowBlock: inputs=ser_it(self.inputs), sum_rows=list(map(ser_it, self.sum_ty.variant_rows)), other_outputs=ser_it(self.other_outputs), - extension_delta=self.extension_delta, ) def inner_signature(self) -> tys.FunctionType: @@ -993,7 +988,6 @@ class TailLoop(DfParentOp, DataflowOp): #: Types that are appended to both inputs and outputs of the graph. rest: tys.TypeRow _just_outputs: tys.TypeRow | None = field(default=None, repr=False) - extension_delta: tys.ExtensionSet = field(default_factory=list, repr=False) @property def just_outputs(self) -> tys.TypeRow: @@ -1014,7 +1008,6 @@ def _to_serial(self, parent: Node) -> sops.TailLoop: just_inputs=ser_it(self.just_inputs), just_outputs=ser_it(self.just_outputs), rest=ser_it(self.rest), - extension_delta=self.extension_delta, ) def inner_signature(self) -> tys.FunctionType: @@ -1334,13 +1327,11 @@ def type_args(self) -> list[tys.TypeArg]: def cached_signature(self) -> tys.FunctionType | None: return tys.FunctionType.endo( [self.type_], - runtime_reqs=["prelude"], ) def outer_signature(self) -> tys.FunctionType: return tys.FunctionType.endo( [self.type_], - runtime_reqs=["prelude"], ) def _set_in_types(self, types: tys.TypeRow) -> None: diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/conversions.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/conversions.json index 1d310df25..3c2fb983c 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/conversions.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/conversions.json @@ -1,10 +1,6 @@ { "version": "0.1.0", "name": "arithmetic.conversions", - "runtime_reqs": [ - "arithmetic.float.types", - "arithmetic.int.types" - ], "types": {}, "operations": { "bytecast_float64_to_int64": { @@ -36,8 +32,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -71,8 +66,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -115,8 +109,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -159,8 +152,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -192,8 +184,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -223,8 +214,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -256,8 +246,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -300,8 +289,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -344,8 +332,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -375,8 +362,7 @@ { "t": "I" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -436,8 +422,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -497,8 +482,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/float.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/float.json index 8da056772..60180ec84 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/float.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/float.json @@ -1,9 +1,6 @@ { "version": "0.1.0", "name": "arithmetic.float", - "runtime_reqs": [ - "arithmetic.int.types" - ], "types": {}, "operations": { "fabs": { @@ -30,8 +27,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -67,8 +63,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -97,8 +92,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -134,8 +128,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -169,8 +162,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -199,8 +191,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -234,8 +225,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -269,8 +259,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -304,8 +293,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -339,8 +327,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -376,8 +363,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -413,8 +399,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -450,8 +435,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -485,8 +469,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -515,8 +498,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -552,8 +534,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -582,8 +563,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -619,8 +599,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -649,8 +628,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/float/types.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/float/types.json index 0c563c474..33db43f5b 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/float/types.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/float/types.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "arithmetic.float.types", - "runtime_reqs": [], "types": { "float64": { "extension": "arithmetic.float.types", diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/int.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/int.json index 5b1a81250..e8e6fdca8 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/int.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/int.json @@ -1,9 +1,6 @@ { "version": "0.1.0", "name": "arithmetic.int", - "runtime_reqs": [ - "arithmetic.int.types" - ], "types": {}, "operations": { "iabs": { @@ -53,8 +50,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -122,8 +118,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -191,8 +186,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -277,8 +271,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -363,8 +356,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -432,8 +424,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -501,8 +492,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -611,8 +601,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -721,8 +710,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -806,8 +794,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -891,8 +878,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -949,8 +935,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1007,8 +992,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1065,8 +1049,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1123,8 +1106,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1181,8 +1163,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1239,8 +1220,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1297,8 +1277,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1355,8 +1334,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1413,8 +1391,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1482,8 +1459,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1551,8 +1527,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1620,8 +1595,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1689,8 +1663,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1775,8 +1748,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1861,8 +1833,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1930,8 +1901,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1999,8 +1969,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2068,8 +2037,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2142,8 +2110,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": true @@ -2216,8 +2183,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": true @@ -2274,8 +2240,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2327,8 +2292,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2380,8 +2344,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2449,8 +2412,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2518,8 +2480,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2587,8 +2548,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2656,8 +2616,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2709,8 +2668,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2778,8 +2736,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2847,8 +2804,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2916,8 +2872,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2969,8 +2924,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -3026,8 +2980,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": true @@ -3083,8 +3036,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": true @@ -3152,8 +3104,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/int/types.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/int/types.json index 36df125a6..0b77d2e55 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/int/types.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/int/types.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "arithmetic.int.types", - "runtime_reqs": [], "types": { "int": { "extension": "arithmetic.int.types", diff --git a/hugr-py/src/hugr/std/_json_defs/collections/array.json b/hugr-py/src/hugr/std/_json_defs/collections/array.json index 375e13c72..fba222793 100644 --- a/hugr-py/src/hugr/std/_json_defs/collections/array.json +++ b/hugr-py/src/hugr/std/_json_defs/collections/array.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "collections.array", - "runtime_reqs": [], "types": { "array": { "extension": "collections.array", @@ -60,8 +59,7 @@ "bound": "A" } ], - "output": [], - "runtime_reqs": [] + "output": [] } }, "binary": false @@ -126,8 +124,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -166,9 +163,6 @@ { "tp": "Type", "b": "A" - }, - { - "tp": "Extensions" } ], "body": { @@ -182,9 +176,6 @@ "i": 1, "b": "A" } - ], - "runtime_reqs": [ - "2" ] } ], @@ -213,8 +204,7 @@ ], "bound": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -243,9 +233,6 @@ "tp": "Type", "b": "A" } - }, - { - "tp": "Extensions" } ], "body": { @@ -299,9 +286,6 @@ "i": 3, "b": "A" } - ], - "runtime_reqs": [ - "4" ] }, { @@ -340,8 +324,7 @@ "i": 3, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -465,8 +448,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -578,8 +560,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/hugr-py/src/hugr/std/_json_defs/collections/list.json b/hugr-py/src/hugr/std/_json_defs/collections/list.json index 8a60d3544..de9736e4e 100644 --- a/hugr-py/src/hugr/std/_json_defs/collections/list.json +++ b/hugr-py/src/hugr/std/_json_defs/collections/list.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "collections.list", - "runtime_reqs": [], "types": { "List": { "extension": "collections.list", @@ -70,8 +69,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -151,8 +149,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -207,8 +204,7 @@ { "t": "I" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -274,8 +270,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -332,8 +327,7 @@ ], "bound": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -413,8 +407,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/hugr-py/src/hugr/std/_json_defs/collections/static_array.json b/hugr-py/src/hugr/std/_json_defs/collections/static_array.json index 53b8e61c7..cde35e063 100644 --- a/hugr-py/src/hugr/std/_json_defs/collections/static_array.json +++ b/hugr-py/src/hugr/std/_json_defs/collections/static_array.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "collections.static_array", - "runtime_reqs": [], "types": { "static_array": { "extension": "collections.static_array", @@ -68,8 +67,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -108,8 +106,7 @@ { "t": "I" } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/hugr-py/src/hugr/std/_json_defs/logic.json b/hugr-py/src/hugr/std/_json_defs/logic.json index ff29d2c21..45cd7f606 100644 --- a/hugr-py/src/hugr/std/_json_defs/logic.json +++ b/hugr-py/src/hugr/std/_json_defs/logic.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "logic", - "runtime_reqs": [], "types": {}, "operations": { "And": { @@ -29,8 +28,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -60,8 +58,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -86,8 +83,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -117,8 +113,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -148,8 +143,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/hugr-py/src/hugr/std/_json_defs/prelude.json b/hugr-py/src/hugr/std/_json_defs/prelude.json index ec392b155..7cf1d02c7 100644 --- a/hugr-py/src/hugr/std/_json_defs/prelude.json +++ b/hugr-py/src/hugr/std/_json_defs/prelude.json @@ -1,7 +1,6 @@ { "version": "0.2.0", "name": "prelude", - "runtime_reqs": [], "types": { "error": { "extension": "prelude", @@ -73,8 +72,7 @@ "i": 0, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -115,8 +113,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -146,8 +143,7 @@ "i": 0, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -188,8 +184,7 @@ "i": 0, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -236,8 +231,7 @@ "i": 1, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -259,8 +253,7 @@ { "t": "I" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -307,8 +300,7 @@ "i": 1, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -329,8 +321,7 @@ "bound": "C" } ], - "output": [], - "runtime_reqs": [] + "output": [] } }, "binary": false diff --git a/hugr-py/src/hugr/std/_json_defs/ptr.json b/hugr-py/src/hugr/std/_json_defs/ptr.json index 614b6aecf..d701fff53 100644 --- a/hugr-py/src/hugr/std/_json_defs/ptr.json +++ b/hugr-py/src/hugr/std/_json_defs/ptr.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "ptr", - "runtime_reqs": [], "types": { "ptr": { "extension": "ptr", @@ -56,8 +55,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -98,8 +96,7 @@ "i": 0, "b": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -139,8 +136,7 @@ "b": "C" } ], - "output": [], - "runtime_reqs": [] + "output": [] } }, "binary": false diff --git a/hugr-py/src/hugr/std/int.py b/hugr-py/src/hugr/std/int.py index 4c1d0cdeb..27432a3d5 100644 --- a/hugr-py/src/hugr/std/int.py +++ b/hugr-py/src/hugr/std/int.py @@ -93,7 +93,7 @@ def type_args(self) -> list[tys.TypeArg]: def cached_signature(self) -> tys.FunctionType | None: row: list[tys.Type] = [int_t(self.width)] * 2 - return tys.FunctionType.endo(row, runtime_reqs=[INT_OPS_EXTENSION.name]) + return tys.FunctionType.endo(row) @classmethod def from_ext(cls, custom: ExtOp) -> Self | None: diff --git a/hugr-py/src/hugr/tys.py b/hugr-py/src/hugr/tys.py index fbaadf7d3..8411f19bf 100644 --- a/hugr-py/src/hugr/tys.py +++ b/hugr-py/src/hugr/tys.py @@ -188,21 +188,6 @@ def to_model(self) -> model.Term: return model.Apply("core.tuple", [item_types]) -@dataclass(frozen=True) -class ExtensionsParam(TypeParam): - """An extension set parameter.""" - - def _to_serial(self) -> stys.ExtensionsParam: - return stys.ExtensionsParam() - - def __str__(self) -> str: - return "Extensions" - - def to_model(self) -> model.Term: - # Since extension sets will be deprecated, this is just a placeholder. - return model.Apply("compat.ext_set_type") - - # ------------------------------------------ # --------------- TypeArg ------------------ # ------------------------------------------ @@ -280,23 +265,6 @@ def to_model(self) -> model.Term: return model.List([elem.to_model() for elem in self.elems]) -@dataclass(frozen=True) -class ExtensionsArg(TypeArg): - """Type argument for an :class:`ExtensionsParam`.""" - - extensions: ExtensionSet - - def _to_serial(self) -> stys.ExtensionsArg: - return stys.ExtensionsArg(es=self.extensions) - - def __str__(self) -> str: - return f"Extensions({comma_sep_str(self.extensions)})" - - def to_model(self) -> model.Term: - # Since extension sets will be deprecated, this is just a placeholder. - return model.Apply("compat.ext_set") - - @dataclass(frozen=True) class VariableArg(TypeArg): """A type argument variable.""" @@ -518,7 +486,6 @@ class FunctionType(Type): input: TypeRow output: TypeRow - runtime_reqs: ExtensionSet = field(default_factory=ExtensionSet) def type_bound(self) -> TypeBound: return TypeBound.Copyable @@ -527,7 +494,6 @@ def _to_serial(self) -> stys.FunctionType: return stys.FunctionType( input=ser_it(self.input), output=ser_it(self.output), - runtime_reqs=self.runtime_reqs, ) @classmethod @@ -541,16 +507,14 @@ def empty(cls) -> FunctionType: return cls(input=[], output=[]) @classmethod - def endo( - cls, tys: TypeRow, runtime_reqs: ExtensionSet | None = None - ) -> FunctionType: + def endo(cls, tys: TypeRow) -> FunctionType: """Function type with the same input and output types. Example: >>> FunctionType.endo([Qubit]) FunctionType([Qubit], [Qubit]) """ - return cls(input=tys, output=tys, runtime_reqs=runtime_reqs or ExtensionSet()) + return cls(input=tys, output=tys) def flip(self) -> FunctionType: """Return a new function type with input and output types swapped. @@ -569,17 +533,8 @@ def resolve(self, registry: ext.ExtensionRegistry) -> FunctionType: return FunctionType( input=[ty.resolve(registry) for ty in self.input], output=[ty.resolve(registry) for ty in self.output], - runtime_reqs=self.runtime_reqs, ) - def with_runtime_reqs(self, runtime_reqs: ExtensionSet) -> FunctionType: - """Adds a list of extension requirements to the function type, and - returns the new signature. - """ - exts = set(self.runtime_reqs) - exts = exts.union(runtime_reqs) - return FunctionType(self.input, self.output, [*exts]) - def __str__(self) -> str: return f"{comma_sep_str(self.input)} -> {comma_sep_str(self.output)}" @@ -614,15 +569,6 @@ def resolve(self, registry: ext.ExtensionRegistry) -> PolyFuncType: body=self.body.resolve(registry), ) - def with_runtime_reqs(self, runtime_reqs: ExtensionSet) -> PolyFuncType: - """Adds a list of extension requirements to the function type, and - returns the new signature. - """ - return PolyFuncType( - params=self.params, - body=self.body.with_runtime_reqs(runtime_reqs), - ) - def __str__(self) -> str: return f"∀ {comma_sep_str(self.params)}. {self.body!s}" diff --git a/hugr-py/tests/serialization/test_extension.py b/hugr-py/tests/serialization/test_extension.py index cf595319a..7f1ea28bf 100644 --- a/hugr-py/tests/serialization/test_extension.py +++ b/hugr-py/tests/serialization/test_extension.py @@ -25,7 +25,6 @@ { "version": "0.1.0", "name": "ext", - "runtime_reqs": [], "types": { "foo": { "extension": "ext", @@ -64,8 +63,7 @@ "b": "C" } ], - "output": [], - "runtime_reqs": [] + "output": [] } }, "lower_funcs": [] @@ -99,7 +97,6 @@ def test_extension(): ext = Extension( version=SemanticVersion(0, 1, 0), name="ext", - runtime_reqs=set(), types={"foo": type_def}, values={}, operations={"New": op_def}, @@ -121,7 +118,6 @@ def test_package(): ext = Extension( version=SemanticVersion(0, 1, 0), name="ext", - runtime_reqs=set(), types={}, values={}, operations={}, diff --git a/hugr-py/tests/test_custom.py b/hugr-py/tests/test_custom.py index 3018bf863..48f57de7a 100644 --- a/hugr-py/tests/test_custom.py +++ b/hugr-py/tests/test_custom.py @@ -37,7 +37,7 @@ def type_args(self) -> list[tys.TypeArg]: return [tys.StringArg(self.tag)] def cached_signature(self) -> tys.FunctionType | None: - return tys.FunctionType.endo([], runtime_reqs=[STRINGLY_EXT.name]) + return tys.FunctionType.endo([]) @classmethod def from_ext(cls, custom: ops.ExtOp) -> "StringlyOp": diff --git a/hugr-py/tests/test_tys.py b/hugr-py/tests/test_tys.py index e2c6d7d51..33bf55561 100644 --- a/hugr-py/tests/test_tys.py +++ b/hugr-py/tests/test_tys.py @@ -14,8 +14,6 @@ BoundedNatArg, BoundedNatParam, Either, - ExtensionsArg, - ExtensionsParam, ExtType, FunctionType, ListParam, @@ -95,7 +93,6 @@ def test_tys_sum_str(ty: Type, string: str, repr_str: str): "(Any, Nat(3))", ), (ListParam(StringParam()), "[String]"), - (ExtensionsParam(), "Extensions"), ], ) def test_params_str(param: TypeParam, string: str): @@ -113,7 +110,6 @@ def test_params_str(param: TypeParam, string: str): "(Type(Qubit), 3)", ), (VariableArg(2, StringParam()), "$2"), - (ExtensionsArg(["A", "B"]), "Extensions(A, B)"), ], ) def test_args_str(arg: TypeArg, string: str): diff --git a/hugr/Cargo.toml b/hugr/Cargo.toml index 3763366ae..3385df9f1 100644 --- a/hugr/Cargo.toml +++ b/hugr/Cargo.toml @@ -24,7 +24,6 @@ path = "src/lib.rs" [features] default = ["zstd"] -extension_inference = ["hugr-core/extension_inference"] declarative = ["hugr-core/declarative"] llvm = ["hugr-llvm/llvm14-0"] llvm-test = ["hugr-llvm/llvm14-0", "hugr-llvm/test-utils"] diff --git a/hugr/README.md b/hugr/README.md index b54d4f62d..83a2cc501 100644 --- a/hugr/README.md +++ b/hugr/README.md @@ -1,7 +1,6 @@ ![](/hugr/assets/hugr_logo.svg) -hugr -=============== +# hugr [![build_status][]](https://github.com/CQCL/hugr/actions) [![crates][]](https://crates.io/crates/hugr) @@ -29,10 +28,6 @@ Please read the [API documentation here][]. ## Experimental Features -- `extension_inference`: - Experimental feature which allows automatic inference of which extra extensions - are required at runtime by a HUGR when validating it. - Not enabled by default. - `declarative`: Experimental support for declaring extensions in YAML files, support is limited. diff --git a/hugr/benches/benchmarks/hugr/examples.rs b/hugr/benches/benchmarks/hugr/examples.rs index 0ece1eefb..2b7676439 100644 --- a/hugr/benches/benchmarks/hugr/examples.rs +++ b/hugr/benches/benchmarks/hugr/examples.rs @@ -7,9 +7,8 @@ use hugr::builder::{ HugrBuilder, ModuleBuilder, }; use hugr::extension::prelude::{bool_t, qb_t, usize_t}; -use hugr::extension::ExtensionSet; use hugr::ops::OpName; -use hugr::std_extensions::arithmetic::float_types::{self, float64_type, ConstF64}; +use hugr::std_extensions::arithmetic::float_types::{float64_type, ConstF64}; use hugr::types::Signature; use hugr::{type_row, CircuitUnit, Extension, Hugr, Node}; use lazy_static::lazy_static; @@ -97,11 +96,7 @@ pub fn circuit(layers: usize) -> (Hugr, Vec) { let h_gate = QUANTUM_EXT.instantiate_extension_op("H", []).unwrap(); let cx_gate = QUANTUM_EXT.instantiate_extension_op("CX", []).unwrap(); let rz = QUANTUM_EXT.instantiate_extension_op("Rz", []).unwrap(); - let signature = - Signature::new_endo(vec![qb_t(), qb_t()]).with_extension_delta(ExtensionSet::from_iter([ - QUANTUM_EXT.name().clone(), - float_types::EXTENSION_ID, - ])); + let signature = Signature::new_endo(vec![qb_t(), qb_t()]); let mut module_builder = ModuleBuilder::new(); let mut f_build = module_builder.define_function("main", signature).unwrap(); diff --git a/justfile b/justfile index 7b8075f94..d7e3f81f2 100644 --- a/justfile +++ b/justfile @@ -23,7 +23,7 @@ test-rust: HUGR_TEST_SCHEMA=1 cargo test \ --workspace \ --exclude 'hugr-py' \ - --features 'hugr/extension_inference hugr/declarative hugr/llvm hugr/llvm-test hugr/zstd' + --features 'hugr/declarative hugr/llvm hugr/llvm-test hugr/zstd' # Run all python tests. test-python: uv run maturin develop --uv diff --git a/specification/hugr.md b/specification/hugr.md index 6204e0e4f..3bd22b8ef 100644 --- a/specification/hugr.md +++ b/specification/hugr.md @@ -891,71 +891,6 @@ See [Declarative Format](#declarative-format) for more examples. Note that since a row variable does not have kind Type, it cannot be used as the type of an edge. -### Extension Tracking - -The type of `Function` includes a set of [extensions](#extension-system) which are required to execute the graph. -Similarly, every dataflow node in the HUGR has a set of extensions required to execute the node (computed from its operation), -also known as the "delta". The delta of any node must be a subset of its parent's delta, -except for FuncDefn's: -* the delta of any child of a FuncDefn must be a subset of the extensions in the FuncDefn's *type* -* the FuncDefn itself has no delta (trivially a subset of any parent): this reflects that the extensions -are not needed to *know* the FuncDefn, only to *execute* it -(by a Call node, whose delta is taken from the called FuncDefn's *type*). - -Keeping track of the extension requirements like this allows extension designers -and third-party tooling to control how/where a module is run. - -Concretely, if a plugin writer adds an extension -*X*, then some function from -a plugin needs to provide a mechanism to convert the -*X* to some other extension -requirement before it can interface with other plugins which don't know -about *X*. - -A runtime could have access to means of -running different extensions. By the same mechanism, the runtime can reason -about where to run different parts of the graph by inspecting their -extension requirements. - -Special operations **lift** and **liftGraph** can add extension requirements: -* `lift>` is a node with input and output rows `R` and extension-delta `{E}` -* `liftGraph, E: ExtensionSet, O: List>` has one input -$ \vec{I}^{\underrightarrow{\;E\;}}\vec{O} $ and one output $ \vec{I}^{\underrightarrow{\;E \cup N\;}}\vec{O}$. -That is, given a graph, it adds extensions $N$ to the requirements of the graph. - -The latter is useful for higher-order operations such as conditionally selecting -one function or another, where the output must have a consistent type (including -the extension-requirements of the function). - -### Rewriting Extension Requirements - -Extension requirements help denote different runtime capabilities. -For example, a quantum computer may not be able to handle arithmetic -while running a circuit, so its use is tracked in the function type so that -rewrites can be performed which remove the arithmetic. - -Simple circuits may look something like: - -```haskell -Function[Quantum](Array(5, Q), (ms: Array(5, Qubit), results: Array(5, Bit))) -``` - -A circuit built using a higher-order extension to manage control flow -could then look like: - -```haskell -Function[Quantum, HigherOrder](Array(5, Qubit), (ms: Array(5, Qubit), results: Array(5, Bit))) -``` - -So the compiler would need to perform some graph transformation pass to turn the -graph-based control flow into a CFG node that a quantum computer could -run, which removes the `HigherOrder` extension requirement. - -```haskell -precompute :: Function[](Function[Quantum,HigherOrder](Array(5, Qubit), (ms: Array(5, Qubit), results: Array(5, Bit))), - Function[Quantum](Array(5, Qubit), (ms: Array(5, Qubit), results: Array(5, Bit)))) -``` - ## Extension System ### Goals and constraints diff --git a/specification/schema/hugr_schema_live.json b/specification/schema/hugr_schema_live.json index ea08dff5b..02889a3f4 100644 --- a/specification/schema/hugr_schema_live.json +++ b/specification/schema/hugr_schema_live.json @@ -277,13 +277,6 @@ }, "title": "Sum Rows", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -428,13 +421,6 @@ }, "title": "Sum Rows", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -502,14 +488,6 @@ "title": "Name", "type": "string" }, - "runtime_reqs": { - "items": { - "type": "string" - }, - "title": "Runtime Reqs", - "type": "array", - "uniqueItems": true - }, "types": { "additionalProperties": { "$ref": "#/$defs/TypeDef" @@ -528,7 +506,6 @@ "required": [ "version", "name", - "runtime_reqs", "types", "operations" ], @@ -581,42 +558,6 @@ "title": "ExtensionOp", "type": "object" }, - "ExtensionsArg": { - "additionalProperties": true, - "properties": { - "tya": { - "const": "Extensions", - "default": "Extensions", - "title": "Tya", - "type": "string" - }, - "es": { - "items": { - "type": "string" - }, - "title": "Es", - "type": "array" - } - }, - "required": [ - "es" - ], - "title": "ExtensionsArg", - "type": "object" - }, - "ExtensionsParam": { - "additionalProperties": true, - "properties": { - "tp": { - "const": "Extensions", - "default": "Extensions", - "title": "Tp", - "type": "string" - } - }, - "title": "ExtensionsParam", - "type": "object" - }, "FixedHugr": { "properties": { "extensions": { @@ -742,13 +683,6 @@ }, "title": "Output", "type": "array" - }, - "runtime_reqs": { - "items": { - "type": "string" - }, - "title": "Runtime Reqs", - "type": "array" } }, "required": [ @@ -1541,13 +1475,6 @@ }, "title": "Rest", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -1654,7 +1581,6 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatArg", - "Extensions": "#/$defs/ExtensionsArg", "Sequence": "#/$defs/SequenceArg", "String": "#/$defs/StringArg", "Type": "#/$defs/TypeTypeArg", @@ -1675,9 +1601,6 @@ { "$ref": "#/$defs/SequenceArg" }, - { - "$ref": "#/$defs/ExtensionsArg" - }, { "$ref": "#/$defs/VariableArg" } @@ -1753,7 +1676,6 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatParam", - "Extensions": "#/$defs/ExtensionsParam", "List": "#/$defs/ListParam", "String": "#/$defs/StringParam", "Tuple": "#/$defs/TupleParam", @@ -1776,9 +1698,6 @@ }, { "$ref": "#/$defs/TupleParam" - }, - { - "$ref": "#/$defs/ExtensionsParam" } ], "required": [ diff --git a/specification/schema/hugr_schema_strict_live.json b/specification/schema/hugr_schema_strict_live.json index 8b65bae94..558f64c57 100644 --- a/specification/schema/hugr_schema_strict_live.json +++ b/specification/schema/hugr_schema_strict_live.json @@ -277,13 +277,6 @@ }, "title": "Sum Rows", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -428,13 +421,6 @@ }, "title": "Sum Rows", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -502,14 +488,6 @@ "title": "Name", "type": "string" }, - "runtime_reqs": { - "items": { - "type": "string" - }, - "title": "Runtime Reqs", - "type": "array", - "uniqueItems": true - }, "types": { "additionalProperties": { "$ref": "#/$defs/TypeDef" @@ -528,7 +506,6 @@ "required": [ "version", "name", - "runtime_reqs", "types", "operations" ], @@ -581,42 +558,6 @@ "title": "ExtensionOp", "type": "object" }, - "ExtensionsArg": { - "additionalProperties": false, - "properties": { - "tya": { - "const": "Extensions", - "default": "Extensions", - "title": "Tya", - "type": "string" - }, - "es": { - "items": { - "type": "string" - }, - "title": "Es", - "type": "array" - } - }, - "required": [ - "es" - ], - "title": "ExtensionsArg", - "type": "object" - }, - "ExtensionsParam": { - "additionalProperties": false, - "properties": { - "tp": { - "const": "Extensions", - "default": "Extensions", - "title": "Tp", - "type": "string" - } - }, - "title": "ExtensionsParam", - "type": "object" - }, "FixedHugr": { "properties": { "extensions": { @@ -742,13 +683,6 @@ }, "title": "Output", "type": "array" - }, - "runtime_reqs": { - "items": { - "type": "string" - }, - "title": "Runtime Reqs", - "type": "array" } }, "required": [ @@ -1541,13 +1475,6 @@ }, "title": "Rest", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -1654,7 +1581,6 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatArg", - "Extensions": "#/$defs/ExtensionsArg", "Sequence": "#/$defs/SequenceArg", "String": "#/$defs/StringArg", "Type": "#/$defs/TypeTypeArg", @@ -1675,9 +1601,6 @@ { "$ref": "#/$defs/SequenceArg" }, - { - "$ref": "#/$defs/ExtensionsArg" - }, { "$ref": "#/$defs/VariableArg" } @@ -1753,7 +1676,6 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatParam", - "Extensions": "#/$defs/ExtensionsParam", "List": "#/$defs/ListParam", "String": "#/$defs/StringParam", "Tuple": "#/$defs/TupleParam", @@ -1776,9 +1698,6 @@ }, { "$ref": "#/$defs/TupleParam" - }, - { - "$ref": "#/$defs/ExtensionsParam" } ], "required": [ diff --git a/specification/schema/testing_hugr_schema_live.json b/specification/schema/testing_hugr_schema_live.json index 91b121da6..f534a3cbd 100644 --- a/specification/schema/testing_hugr_schema_live.json +++ b/specification/schema/testing_hugr_schema_live.json @@ -277,13 +277,6 @@ }, "title": "Sum Rows", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -428,13 +421,6 @@ }, "title": "Sum Rows", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -502,14 +488,6 @@ "title": "Name", "type": "string" }, - "runtime_reqs": { - "items": { - "type": "string" - }, - "title": "Runtime Reqs", - "type": "array", - "uniqueItems": true - }, "types": { "additionalProperties": { "$ref": "#/$defs/TypeDef" @@ -528,7 +506,6 @@ "required": [ "version", "name", - "runtime_reqs", "types", "operations" ], @@ -581,42 +558,6 @@ "title": "ExtensionOp", "type": "object" }, - "ExtensionsArg": { - "additionalProperties": true, - "properties": { - "tya": { - "const": "Extensions", - "default": "Extensions", - "title": "Tya", - "type": "string" - }, - "es": { - "items": { - "type": "string" - }, - "title": "Es", - "type": "array" - } - }, - "required": [ - "es" - ], - "title": "ExtensionsArg", - "type": "object" - }, - "ExtensionsParam": { - "additionalProperties": true, - "properties": { - "tp": { - "const": "Extensions", - "default": "Extensions", - "title": "Tp", - "type": "string" - } - }, - "title": "ExtensionsParam", - "type": "object" - }, "FixedHugr": { "properties": { "extensions": { @@ -742,13 +683,6 @@ }, "title": "Output", "type": "array" - }, - "runtime_reqs": { - "items": { - "type": "string" - }, - "title": "Runtime Reqs", - "type": "array" } }, "required": [ @@ -1540,13 +1474,6 @@ }, "title": "Rest", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -1732,7 +1659,6 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatArg", - "Extensions": "#/$defs/ExtensionsArg", "Sequence": "#/$defs/SequenceArg", "String": "#/$defs/StringArg", "Type": "#/$defs/TypeTypeArg", @@ -1753,9 +1679,6 @@ { "$ref": "#/$defs/SequenceArg" }, - { - "$ref": "#/$defs/ExtensionsArg" - }, { "$ref": "#/$defs/VariableArg" } @@ -1831,7 +1754,6 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatParam", - "Extensions": "#/$defs/ExtensionsParam", "List": "#/$defs/ListParam", "String": "#/$defs/StringParam", "Tuple": "#/$defs/TupleParam", @@ -1854,9 +1776,6 @@ }, { "$ref": "#/$defs/TupleParam" - }, - { - "$ref": "#/$defs/ExtensionsParam" } ], "required": [ diff --git a/specification/schema/testing_hugr_schema_strict_live.json b/specification/schema/testing_hugr_schema_strict_live.json index eae6a13a7..eb3fcff0f 100644 --- a/specification/schema/testing_hugr_schema_strict_live.json +++ b/specification/schema/testing_hugr_schema_strict_live.json @@ -277,13 +277,6 @@ }, "title": "Sum Rows", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -428,13 +421,6 @@ }, "title": "Sum Rows", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -502,14 +488,6 @@ "title": "Name", "type": "string" }, - "runtime_reqs": { - "items": { - "type": "string" - }, - "title": "Runtime Reqs", - "type": "array", - "uniqueItems": true - }, "types": { "additionalProperties": { "$ref": "#/$defs/TypeDef" @@ -528,7 +506,6 @@ "required": [ "version", "name", - "runtime_reqs", "types", "operations" ], @@ -581,42 +558,6 @@ "title": "ExtensionOp", "type": "object" }, - "ExtensionsArg": { - "additionalProperties": false, - "properties": { - "tya": { - "const": "Extensions", - "default": "Extensions", - "title": "Tya", - "type": "string" - }, - "es": { - "items": { - "type": "string" - }, - "title": "Es", - "type": "array" - } - }, - "required": [ - "es" - ], - "title": "ExtensionsArg", - "type": "object" - }, - "ExtensionsParam": { - "additionalProperties": false, - "properties": { - "tp": { - "const": "Extensions", - "default": "Extensions", - "title": "Tp", - "type": "string" - } - }, - "title": "ExtensionsParam", - "type": "object" - }, "FixedHugr": { "properties": { "extensions": { @@ -742,13 +683,6 @@ }, "title": "Output", "type": "array" - }, - "runtime_reqs": { - "items": { - "type": "string" - }, - "title": "Runtime Reqs", - "type": "array" } }, "required": [ @@ -1540,13 +1474,6 @@ }, "title": "Rest", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -1732,7 +1659,6 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatArg", - "Extensions": "#/$defs/ExtensionsArg", "Sequence": "#/$defs/SequenceArg", "String": "#/$defs/StringArg", "Type": "#/$defs/TypeTypeArg", @@ -1753,9 +1679,6 @@ { "$ref": "#/$defs/SequenceArg" }, - { - "$ref": "#/$defs/ExtensionsArg" - }, { "$ref": "#/$defs/VariableArg" } @@ -1831,7 +1754,6 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatParam", - "Extensions": "#/$defs/ExtensionsParam", "List": "#/$defs/ListParam", "String": "#/$defs/StringParam", "Tuple": "#/$defs/TupleParam", @@ -1854,9 +1776,6 @@ }, { "$ref": "#/$defs/TupleParam" - }, - { - "$ref": "#/$defs/ExtensionsParam" } ], "required": [ diff --git a/specification/std_extensions/arithmetic/conversions.json b/specification/std_extensions/arithmetic/conversions.json index 1d310df25..3c2fb983c 100644 --- a/specification/std_extensions/arithmetic/conversions.json +++ b/specification/std_extensions/arithmetic/conversions.json @@ -1,10 +1,6 @@ { "version": "0.1.0", "name": "arithmetic.conversions", - "runtime_reqs": [ - "arithmetic.float.types", - "arithmetic.int.types" - ], "types": {}, "operations": { "bytecast_float64_to_int64": { @@ -36,8 +32,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -71,8 +66,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -115,8 +109,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -159,8 +152,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -192,8 +184,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -223,8 +214,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -256,8 +246,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -300,8 +289,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -344,8 +332,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -375,8 +362,7 @@ { "t": "I" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -436,8 +422,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -497,8 +482,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/specification/std_extensions/arithmetic/float.json b/specification/std_extensions/arithmetic/float.json index 8da056772..60180ec84 100644 --- a/specification/std_extensions/arithmetic/float.json +++ b/specification/std_extensions/arithmetic/float.json @@ -1,9 +1,6 @@ { "version": "0.1.0", "name": "arithmetic.float", - "runtime_reqs": [ - "arithmetic.int.types" - ], "types": {}, "operations": { "fabs": { @@ -30,8 +27,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -67,8 +63,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -97,8 +92,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -134,8 +128,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -169,8 +162,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -199,8 +191,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -234,8 +225,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -269,8 +259,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -304,8 +293,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -339,8 +327,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -376,8 +363,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -413,8 +399,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -450,8 +435,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -485,8 +469,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -515,8 +498,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -552,8 +534,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -582,8 +563,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -619,8 +599,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -649,8 +628,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/specification/std_extensions/arithmetic/float/types.json b/specification/std_extensions/arithmetic/float/types.json index 0c563c474..33db43f5b 100644 --- a/specification/std_extensions/arithmetic/float/types.json +++ b/specification/std_extensions/arithmetic/float/types.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "arithmetic.float.types", - "runtime_reqs": [], "types": { "float64": { "extension": "arithmetic.float.types", diff --git a/specification/std_extensions/arithmetic/int.json b/specification/std_extensions/arithmetic/int.json index 5b1a81250..e8e6fdca8 100644 --- a/specification/std_extensions/arithmetic/int.json +++ b/specification/std_extensions/arithmetic/int.json @@ -1,9 +1,6 @@ { "version": "0.1.0", "name": "arithmetic.int", - "runtime_reqs": [ - "arithmetic.int.types" - ], "types": {}, "operations": { "iabs": { @@ -53,8 +50,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -122,8 +118,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -191,8 +186,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -277,8 +271,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -363,8 +356,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -432,8 +424,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -501,8 +492,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -611,8 +601,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -721,8 +710,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -806,8 +794,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -891,8 +878,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -949,8 +935,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1007,8 +992,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1065,8 +1049,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1123,8 +1106,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1181,8 +1163,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1239,8 +1220,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1297,8 +1277,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1355,8 +1334,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1413,8 +1391,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1482,8 +1459,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1551,8 +1527,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1620,8 +1595,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1689,8 +1663,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1775,8 +1748,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1861,8 +1833,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1930,8 +1901,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1999,8 +1969,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2068,8 +2037,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2142,8 +2110,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": true @@ -2216,8 +2183,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": true @@ -2274,8 +2240,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2327,8 +2292,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2380,8 +2344,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2449,8 +2412,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2518,8 +2480,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2587,8 +2548,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2656,8 +2616,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2709,8 +2668,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2778,8 +2736,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2847,8 +2804,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2916,8 +2872,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2969,8 +2924,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -3026,8 +2980,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": true @@ -3083,8 +3036,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": true @@ -3152,8 +3104,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/specification/std_extensions/arithmetic/int/types.json b/specification/std_extensions/arithmetic/int/types.json index 36df125a6..0b77d2e55 100644 --- a/specification/std_extensions/arithmetic/int/types.json +++ b/specification/std_extensions/arithmetic/int/types.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "arithmetic.int.types", - "runtime_reqs": [], "types": { "int": { "extension": "arithmetic.int.types", diff --git a/specification/std_extensions/collections/array.json b/specification/std_extensions/collections/array.json index 375e13c72..fba222793 100644 --- a/specification/std_extensions/collections/array.json +++ b/specification/std_extensions/collections/array.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "collections.array", - "runtime_reqs": [], "types": { "array": { "extension": "collections.array", @@ -60,8 +59,7 @@ "bound": "A" } ], - "output": [], - "runtime_reqs": [] + "output": [] } }, "binary": false @@ -126,8 +124,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -166,9 +163,6 @@ { "tp": "Type", "b": "A" - }, - { - "tp": "Extensions" } ], "body": { @@ -182,9 +176,6 @@ "i": 1, "b": "A" } - ], - "runtime_reqs": [ - "2" ] } ], @@ -213,8 +204,7 @@ ], "bound": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -243,9 +233,6 @@ "tp": "Type", "b": "A" } - }, - { - "tp": "Extensions" } ], "body": { @@ -299,9 +286,6 @@ "i": 3, "b": "A" } - ], - "runtime_reqs": [ - "4" ] }, { @@ -340,8 +324,7 @@ "i": 3, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -465,8 +448,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -578,8 +560,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/specification/std_extensions/collections/list.json b/specification/std_extensions/collections/list.json index 8a60d3544..de9736e4e 100644 --- a/specification/std_extensions/collections/list.json +++ b/specification/std_extensions/collections/list.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "collections.list", - "runtime_reqs": [], "types": { "List": { "extension": "collections.list", @@ -70,8 +69,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -151,8 +149,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -207,8 +204,7 @@ { "t": "I" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -274,8 +270,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -332,8 +327,7 @@ ], "bound": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -413,8 +407,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/specification/std_extensions/collections/static_array.json b/specification/std_extensions/collections/static_array.json index 53b8e61c7..cde35e063 100644 --- a/specification/std_extensions/collections/static_array.json +++ b/specification/std_extensions/collections/static_array.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "collections.static_array", - "runtime_reqs": [], "types": { "static_array": { "extension": "collections.static_array", @@ -68,8 +67,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -108,8 +106,7 @@ { "t": "I" } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/specification/std_extensions/logic.json b/specification/std_extensions/logic.json index ff29d2c21..45cd7f606 100644 --- a/specification/std_extensions/logic.json +++ b/specification/std_extensions/logic.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "logic", - "runtime_reqs": [], "types": {}, "operations": { "And": { @@ -29,8 +28,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -60,8 +58,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -86,8 +83,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -117,8 +113,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -148,8 +143,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/specification/std_extensions/prelude.json b/specification/std_extensions/prelude.json index ec392b155..7cf1d02c7 100644 --- a/specification/std_extensions/prelude.json +++ b/specification/std_extensions/prelude.json @@ -1,7 +1,6 @@ { "version": "0.2.0", "name": "prelude", - "runtime_reqs": [], "types": { "error": { "extension": "prelude", @@ -73,8 +72,7 @@ "i": 0, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -115,8 +113,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -146,8 +143,7 @@ "i": 0, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -188,8 +184,7 @@ "i": 0, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -236,8 +231,7 @@ "i": 1, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -259,8 +253,7 @@ { "t": "I" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -307,8 +300,7 @@ "i": 1, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -329,8 +321,7 @@ "bound": "C" } ], - "output": [], - "runtime_reqs": [] + "output": [] } }, "binary": false diff --git a/specification/std_extensions/ptr.json b/specification/std_extensions/ptr.json index 614b6aecf..d701fff53 100644 --- a/specification/std_extensions/ptr.json +++ b/specification/std_extensions/ptr.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "ptr", - "runtime_reqs": [], "types": { "ptr": { "extension": "ptr", @@ -56,8 +55,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -98,8 +96,7 @@ "i": 0, "b": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -139,8 +136,7 @@ "b": "C" } ], - "output": [], - "runtime_reqs": [] + "output": [] } }, "binary": false