Skip to content

Commit 96cac0e

Browse files
authored
Refactor: move rewrite inside hugr, Rewrite -> Replace implementing new 'Rewrite' trait (#119)
* Remove Pattern and Pattern.rs * Move rewrite module to hugr/rewrite * Rename old `Rewrite` struct to `Replace` but leave as skeletal * Add `Rewrite` trait, (parameterized) Hugr::apply_rewrite dispatches to that * Associated `Error` type and `unchanged_on_failure: bool` * unchanged_on_failure as trait associated constant * Drive-by: simple_replace.rs: change ".ok();"s to unwrap
1 parent 344ef0c commit 96cac0e

File tree

7 files changed

+295
-226
lines changed

7 files changed

+295
-226
lines changed

src/hugr.rs

Lines changed: 7 additions & 187 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,25 @@
22
33
mod hugrmut;
44

5+
pub mod rewrite;
56
pub mod serialize;
67
pub mod typecheck;
78
pub mod validate;
89
pub mod view;
910

10-
use std::collections::HashMap;
11-
1211
pub(crate) use self::hugrmut::HugrMut;
1312
pub use self::validate::ValidationError;
1413

1514
use derive_more::From;
15+
pub use rewrite::{Replace, ReplaceError, Rewrite, SimpleReplacement, SimpleReplacementError};
16+
1617
use portgraph::dot::{DotFormat, EdgeStyle, NodeStyle, PortStyle};
1718
use portgraph::multiportgraph::MultiPortGraph;
18-
use portgraph::{Hierarchy, LinkView, NodeIndex, PortView, UnmanagedDenseMap};
19+
use portgraph::{Hierarchy, LinkView, PortView, UnmanagedDenseMap};
1920
use thiserror::Error;
2021

2122
pub use self::view::HugrView;
22-
use crate::ops::tag::OpTag;
23-
use crate::ops::{OpName, OpTrait, OpType};
24-
use crate::replacement::{SimpleReplacement, SimpleReplacementError};
25-
use crate::rewrite::{Rewrite, RewriteError};
23+
use crate::ops::{OpName, OpType};
2624
use crate::types::EdgeKind;
2725

2826
/// The Hugr data structure.
@@ -81,187 +79,9 @@ pub struct Wire(Node, usize);
8179

8280
/// Public API for HUGRs.
8381
impl Hugr {
84-
/// Apply a simple replacement operation to the HUGR.
85-
pub fn apply_simple_replacement(
86-
&mut self,
87-
r: SimpleReplacement,
88-
) -> Result<(), SimpleReplacementError> {
89-
// 1. Check the parent node exists and is a DFG node.
90-
if self.get_optype(r.parent).tag() != OpTag::Dfg {
91-
return Err(SimpleReplacementError::InvalidParentNode());
92-
}
93-
// 2. Check that all the to-be-removed nodes are children of it and are leaves.
94-
for node in &r.removal {
95-
if self.hierarchy.parent(node.index) != Some(r.parent.index)
96-
|| self.hierarchy.has_children(node.index)
97-
{
98-
return Err(SimpleReplacementError::InvalidRemovedNode());
99-
}
100-
}
101-
// 3. Do the replacement.
102-
// 3.1. Add copies of all replacement nodes and edges to self. Exclude Input/Output nodes.
103-
// Create map from old NodeIndex (in r.replacement) to new NodeIndex (in self).
104-
let mut index_map: HashMap<NodeIndex, NodeIndex> = HashMap::new();
105-
let replacement_nodes = r
106-
.replacement
107-
.children(r.replacement.root())
108-
.collect::<Vec<Node>>();
109-
// slice of nodes omitting Input and Output:
110-
let replacement_inner_nodes = &replacement_nodes[2..];
111-
for &node in replacement_inner_nodes {
112-
// Check there are no const inputs.
113-
if !r
114-
.replacement
115-
.get_optype(node)
116-
.signature()
117-
.const_input
118-
.is_empty()
119-
{
120-
return Err(SimpleReplacementError::InvalidReplacementNode());
121-
}
122-
}
123-
let self_output_node_index = self.children(r.parent).nth(1).unwrap();
124-
let replacement_output_node = *replacement_nodes.get(1).unwrap();
125-
for &node in replacement_inner_nodes {
126-
// Add the nodes.
127-
let op: &OpType = r.replacement.get_optype(node);
128-
let new_node_index = self
129-
.add_op_after(self_output_node_index, op.clone())
130-
.unwrap();
131-
index_map.insert(node.index, new_node_index.index);
132-
}
133-
// Add edges between all newly added nodes matching those in replacement.
134-
// TODO This will probably change when implicit copies are implemented.
135-
for &node in replacement_inner_nodes {
136-
let new_node_index = index_map.get(&node.index).unwrap();
137-
for node_successor in r.replacement.output_neighbours(node) {
138-
if r.replacement.get_optype(node_successor).tag() != OpTag::Output {
139-
let new_node_successor_index = index_map.get(&node_successor.index).unwrap();
140-
for connection in r
141-
.replacement
142-
.graph
143-
.get_connections(node.index, node_successor.index)
144-
{
145-
let src_offset = r
146-
.replacement
147-
.graph
148-
.port_offset(connection.0)
149-
.unwrap()
150-
.index();
151-
let tgt_offset = r
152-
.replacement
153-
.graph
154-
.port_offset(connection.1)
155-
.unwrap()
156-
.index();
157-
self.graph
158-
.link_nodes(
159-
*new_node_index,
160-
src_offset,
161-
*new_node_successor_index,
162-
tgt_offset,
163-
)
164-
.ok();
165-
}
166-
}
167-
}
168-
}
169-
// 3.2. For each p = r.nu_inp[q] such that q is not an Output port, add an edge from the
170-
// predecessor of p to (the new copy of) q.
171-
for ((rep_inp_node, rep_inp_port), (rem_inp_node, rem_inp_port)) in &r.nu_inp {
172-
if r.replacement.get_optype(*rep_inp_node).tag() != OpTag::Output {
173-
let new_inp_node_index = index_map.get(&rep_inp_node.index).unwrap();
174-
// add edge from predecessor of (s_inp_node, s_inp_port) to (new_inp_node, n_inp_port)
175-
let rem_inp_port_index = self
176-
.graph
177-
.port_index(rem_inp_node.index, rem_inp_port.offset)
178-
.unwrap();
179-
let rem_inp_predecessor_port_index =
180-
self.graph.port_link(rem_inp_port_index).unwrap().port();
181-
let new_inp_port_index = self
182-
.graph
183-
.port_index(*new_inp_node_index, rep_inp_port.offset)
184-
.unwrap();
185-
self.graph.unlink_port(rem_inp_predecessor_port_index);
186-
self.graph
187-
.link_ports(rem_inp_predecessor_port_index, new_inp_port_index)
188-
.ok();
189-
}
190-
}
191-
// 3.3. For each q = r.nu_out[p] such that the predecessor of q is not an Input port, add an
192-
// edge from (the new copy of) the predecessor of q to p.
193-
for ((rem_out_node, rem_out_port), rep_out_port) in &r.nu_out {
194-
let rem_out_port_index = self
195-
.graph
196-
.port_index(rem_out_node.index, rem_out_port.offset)
197-
.unwrap();
198-
let rep_out_port_index = r
199-
.replacement
200-
.graph
201-
.port_index(replacement_output_node.index, rep_out_port.offset)
202-
.unwrap();
203-
let rep_out_predecessor_port_index =
204-
r.replacement.graph.port_link(rep_out_port_index).unwrap();
205-
let rep_out_predecessor_node_index = r
206-
.replacement
207-
.graph
208-
.port_node(rep_out_predecessor_port_index)
209-
.unwrap();
210-
if r.replacement
211-
.get_optype(rep_out_predecessor_node_index.into())
212-
.tag()
213-
!= OpTag::Input
214-
{
215-
let rep_out_predecessor_port_offset = r
216-
.replacement
217-
.graph
218-
.port_offset(rep_out_predecessor_port_index)
219-
.unwrap();
220-
let new_out_node_index = index_map.get(&rep_out_predecessor_node_index).unwrap();
221-
let new_out_port_index = self
222-
.graph
223-
.port_index(*new_out_node_index, rep_out_predecessor_port_offset)
224-
.unwrap();
225-
self.graph.unlink_port(rem_out_port_index);
226-
self.graph
227-
.link_ports(new_out_port_index, rem_out_port_index)
228-
.ok();
229-
}
230-
}
231-
// 3.4. For each q = r.nu_out[p1], p0 = r.nu_inp[q], add an edge from the predecessor of p0
232-
// to p1.
233-
for ((rem_out_node, rem_out_port), &rep_out_port) in &r.nu_out {
234-
let rem_inp_nodeport = r.nu_inp.get(&(replacement_output_node, rep_out_port));
235-
if let Some((rem_inp_node, rem_inp_port)) = rem_inp_nodeport {
236-
// add edge from predecessor of (rem_inp_node, rem_inp_port) to (rem_out_node, rem_out_port):
237-
let rem_inp_port_index = self
238-
.graph
239-
.port_index(rem_inp_node.index, rem_inp_port.offset)
240-
.unwrap();
241-
let rem_inp_predecessor_port_index =
242-
self.graph.port_link(rem_inp_port_index).unwrap().port();
243-
let rem_out_port_index = self
244-
.graph
245-
.port_index(rem_out_node.index, rem_out_port.offset)
246-
.unwrap();
247-
self.graph.unlink_port(rem_inp_port_index);
248-
self.graph.unlink_port(rem_out_port_index);
249-
self.graph
250-
.link_ports(rem_inp_predecessor_port_index, rem_out_port_index)
251-
.ok();
252-
}
253-
}
254-
// 3.5. Remove all nodes in r.removal and edges between them.
255-
for node in &r.removal {
256-
self.graph.remove_node(node.index);
257-
self.hierarchy.remove(node.index);
258-
}
259-
Ok(())
260-
}
261-
26282
/// Applies a rewrite to the graph.
263-
pub fn apply_rewrite(self, _rewrite: Rewrite) -> Result<(), RewriteError> {
264-
unimplemented!()
83+
pub fn apply_rewrite<E>(&mut self, rw: impl Rewrite<Error = E>) -> Result<(), E> {
84+
rw.apply(self)
26585
}
26686

26787
/// Return dot string showing underlying graph and hierarchy side by side.

src/hugr/rewrite.rs

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
//! Rewrite operations on the HUGR - replacement, outlining, etc.
2+
3+
pub mod replace;
4+
pub mod simple_replace;
5+
use std::mem;
6+
7+
use crate::Hugr;
8+
pub use replace::{OpenHugr, Replace, ReplaceError};
9+
pub use simple_replace::{SimpleReplacement, SimpleReplacementError};
10+
11+
/// An operation that can be applied to mutate a Hugr
12+
pub trait Rewrite {
13+
/// The type of Error with which this Rewrite may fail
14+
type Error: std::error::Error;
15+
16+
/// If `true`, [self.apply]'s of this rewrite guarantee that they do not mutate the Hugr when they return an Err.
17+
/// If `false`, there is no guarantee; the Hugr should be assumed invalid when Err is returned.
18+
const UNCHANGED_ON_FAILURE: bool;
19+
20+
/// Checks whether the rewrite would succeed on the specified Hugr.
21+
/// If this call succeeds, [self.apply] should also succeed on the same `h`
22+
/// If this calls fails, [self.apply] would fail with the same error.
23+
fn verify(&self, h: &Hugr) -> Result<(), Self::Error>;
24+
25+
/// Mutate the specified Hugr, or fail with an error.
26+
/// If [self.unchanged_on_failure] is true, then `h` must be unchanged if Err is returned.
27+
/// See also [self.verify]
28+
/// # Panics
29+
/// May panic if-and-only-if `h` would have failed [Hugr::validate]; that is,
30+
/// implementations may begin with `assert!(h.validate())`, with `debug_assert!(h.validate())`
31+
/// being preferred.
32+
fn apply(self, h: &mut Hugr) -> Result<(), Self::Error>;
33+
}
34+
35+
/// Wraps any rewrite into a transaction (i.e. that has no effect upon failure)
36+
pub struct Transactional<R> {
37+
underlying: R,
38+
}
39+
40+
// Note we might like to constrain R to Rewrite<unchanged_on_failure=false> but this
41+
// is not yet supported, https://github.com/rust-lang/rust/issues/92827
42+
impl<R: Rewrite> Rewrite for Transactional<R> {
43+
type Error = R::Error;
44+
const UNCHANGED_ON_FAILURE: bool = true;
45+
46+
fn verify(&self, h: &Hugr) -> Result<(), Self::Error> {
47+
self.underlying.verify(h)
48+
}
49+
50+
fn apply(self, h: &mut Hugr) -> Result<(), Self::Error> {
51+
if R::UNCHANGED_ON_FAILURE {
52+
return self.underlying.apply(h);
53+
}
54+
let backup = h.clone();
55+
let r = self.underlying.apply(h);
56+
if r.is_err() {
57+
// drop the old h, it was undefined
58+
let _ = mem::replace(h, backup);
59+
}
60+
r
61+
}
62+
}

src/rewrite/rewrite.rs renamed to src/hugr/rewrite/replace.rs

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
#![allow(missing_docs)]
2-
//! Rewrite operations on Hugr graphs.
2+
//! Replace operations on Hugr graphs. This is a nonfunctional
3+
//! dummy implementation just to demonstrate design principles.
34
45
use std::collections::HashMap;
56

67
use portgraph::substitute::OpenGraph;
78
use portgraph::{NodeIndex, PortIndex};
89
use thiserror::Error;
910

11+
use super::Rewrite;
1012
use crate::Hugr;
1113

1214
/// A subset of the nodes in a graph, and the ports that it is connected to.
@@ -77,7 +79,7 @@ pub type ParentsMap = HashMap<NodeIndex, NodeIndex>;
7779
/// Includes the new weights for the nodes in the replacement graph.
7880
#[derive(Debug, Clone)]
7981
#[allow(unused)]
80-
pub struct Rewrite {
82+
pub struct Replace {
8183
/// The subgraph to be replaced.
8284
subgraph: BoundedSubgraph,
8385
/// The replacement graph.
@@ -86,7 +88,7 @@ pub struct Rewrite {
8688
parents: ParentsMap,
8789
}
8890

89-
impl Rewrite {
91+
impl Replace {
9092
/// Creates a new rewrite operation.
9193
pub fn new(
9294
subgraph: BoundedSubgraph,
@@ -114,30 +116,42 @@ impl Rewrite {
114116
)
115117
}
116118

119+
pub fn verify_convexity(&self) -> Result<(), ReplaceError> {
120+
unimplemented!()
121+
}
122+
123+
pub fn verify_boundaries(&self) -> Result<(), ReplaceError> {
124+
unimplemented!()
125+
}
126+
}
127+
128+
impl Rewrite for Replace {
129+
type Error = ReplaceError;
130+
const UNCHANGED_ON_FAILURE: bool = false;
131+
117132
/// Checks that the rewrite is valid.
118133
///
119134
/// This includes having a convex subgraph (TODO: include definition), and
120135
/// having matching numbers of ports on the boundaries.
121-
pub fn verify(&self) -> Result<(), RewriteError> {
136+
/// TODO not clear this implementation really provides much guarantee about [self.apply]
137+
/// but this class is not really working anyway.
138+
fn verify(&self, _h: &Hugr) -> Result<(), ReplaceError> {
122139
self.verify_convexity()?;
123140
self.verify_boundaries()?;
124141
Ok(())
125142
}
126143

127-
pub fn verify_convexity(&self) -> Result<(), RewriteError> {
128-
todo!()
129-
}
130-
131-
pub fn verify_boundaries(&self) -> Result<(), RewriteError> {
132-
todo!()
144+
/// Performs a Replace operation on the graph.
145+
fn apply(self, _h: &mut Hugr) -> Result<(), ReplaceError> {
146+
unimplemented!()
133147
}
134148
}
135149

136150
/// Error generated when a rewrite fails.
137151
#[derive(Debug, Clone, Error, PartialEq, Eq)]
138-
pub enum RewriteError {
139-
/// The rewrite failed because the boundary defined by the
140-
/// [`Rewrite`] could not be matched to the dangling ports of the
152+
pub enum ReplaceError {
153+
/// The replacement failed because the boundary defined by the
154+
/// [`Replace`] could not be matched to the dangling ports of the
141155
/// [`OpenHugr`].
142156
#[error("The boundary defined by the rewrite could not be matched to the dangling ports of the OpenHugr")]
143157
BoundarySize(#[source] portgraph::substitute::RewriteError),
@@ -152,7 +166,7 @@ pub enum RewriteError {
152166
NotConvex(),
153167
}
154168

155-
impl From<portgraph::substitute::RewriteError> for RewriteError {
169+
impl From<portgraph::substitute::RewriteError> for ReplaceError {
156170
fn from(e: portgraph::substitute::RewriteError) -> Self {
157171
match e {
158172
portgraph::substitute::RewriteError::BoundarySize => Self::BoundarySize(e),

0 commit comments

Comments
 (0)