|
2 | 2 |
|
3 | 3 | mod hugrmut;
|
4 | 4 |
|
| 5 | +pub mod rewrite; |
5 | 6 | pub mod serialize;
|
6 | 7 | pub mod typecheck;
|
7 | 8 | pub mod validate;
|
8 | 9 | pub mod view;
|
9 | 10 |
|
10 |
| -use std::collections::HashMap; |
11 |
| - |
12 | 11 | pub(crate) use self::hugrmut::HugrMut;
|
13 | 12 | pub use self::validate::ValidationError;
|
14 | 13 |
|
15 | 14 | use derive_more::From;
|
| 15 | +pub use rewrite::{Replace, ReplaceError, Rewrite, SimpleReplacement, SimpleReplacementError}; |
| 16 | + |
16 | 17 | use portgraph::dot::{DotFormat, EdgeStyle, NodeStyle, PortStyle};
|
17 | 18 | use portgraph::multiportgraph::MultiPortGraph;
|
18 |
| -use portgraph::{Hierarchy, LinkView, NodeIndex, PortView, UnmanagedDenseMap}; |
| 19 | +use portgraph::{Hierarchy, LinkView, PortView, UnmanagedDenseMap}; |
19 | 20 | use thiserror::Error;
|
20 | 21 |
|
21 | 22 | pub use self::view::HugrView;
|
22 |
| -use crate::ops::tag::OpTag; |
23 |
| -use crate::ops::{OpName, OpTrait, OpType}; |
24 |
| -use crate::replacement::{SimpleReplacement, SimpleReplacementError}; |
25 |
| -use crate::rewrite::{Rewrite, RewriteError}; |
| 23 | +use crate::ops::{OpName, OpType}; |
26 | 24 | use crate::types::EdgeKind;
|
27 | 25 |
|
28 | 26 | /// The Hugr data structure.
|
@@ -81,187 +79,9 @@ pub struct Wire(Node, usize);
|
81 | 79 |
|
82 | 80 | /// Public API for HUGRs.
|
83 | 81 | impl Hugr {
|
84 |
| - /// Apply a simple replacement operation to the HUGR. |
85 |
| - pub fn apply_simple_replacement( |
86 |
| - &mut self, |
87 |
| - r: SimpleReplacement, |
88 |
| - ) -> Result<(), SimpleReplacementError> { |
89 |
| - // 1. Check the parent node exists and is a DFG node. |
90 |
| - if self.get_optype(r.parent).tag() != OpTag::Dfg { |
91 |
| - return Err(SimpleReplacementError::InvalidParentNode()); |
92 |
| - } |
93 |
| - // 2. Check that all the to-be-removed nodes are children of it and are leaves. |
94 |
| - for node in &r.removal { |
95 |
| - if self.hierarchy.parent(node.index) != Some(r.parent.index) |
96 |
| - || self.hierarchy.has_children(node.index) |
97 |
| - { |
98 |
| - return Err(SimpleReplacementError::InvalidRemovedNode()); |
99 |
| - } |
100 |
| - } |
101 |
| - // 3. Do the replacement. |
102 |
| - // 3.1. Add copies of all replacement nodes and edges to self. Exclude Input/Output nodes. |
103 |
| - // Create map from old NodeIndex (in r.replacement) to new NodeIndex (in self). |
104 |
| - let mut index_map: HashMap<NodeIndex, NodeIndex> = HashMap::new(); |
105 |
| - let replacement_nodes = r |
106 |
| - .replacement |
107 |
| - .children(r.replacement.root()) |
108 |
| - .collect::<Vec<Node>>(); |
109 |
| - // slice of nodes omitting Input and Output: |
110 |
| - let replacement_inner_nodes = &replacement_nodes[2..]; |
111 |
| - for &node in replacement_inner_nodes { |
112 |
| - // Check there are no const inputs. |
113 |
| - if !r |
114 |
| - .replacement |
115 |
| - .get_optype(node) |
116 |
| - .signature() |
117 |
| - .const_input |
118 |
| - .is_empty() |
119 |
| - { |
120 |
| - return Err(SimpleReplacementError::InvalidReplacementNode()); |
121 |
| - } |
122 |
| - } |
123 |
| - let self_output_node_index = self.children(r.parent).nth(1).unwrap(); |
124 |
| - let replacement_output_node = *replacement_nodes.get(1).unwrap(); |
125 |
| - for &node in replacement_inner_nodes { |
126 |
| - // Add the nodes. |
127 |
| - let op: &OpType = r.replacement.get_optype(node); |
128 |
| - let new_node_index = self |
129 |
| - .add_op_after(self_output_node_index, op.clone()) |
130 |
| - .unwrap(); |
131 |
| - index_map.insert(node.index, new_node_index.index); |
132 |
| - } |
133 |
| - // Add edges between all newly added nodes matching those in replacement. |
134 |
| - // TODO This will probably change when implicit copies are implemented. |
135 |
| - for &node in replacement_inner_nodes { |
136 |
| - let new_node_index = index_map.get(&node.index).unwrap(); |
137 |
| - for node_successor in r.replacement.output_neighbours(node) { |
138 |
| - if r.replacement.get_optype(node_successor).tag() != OpTag::Output { |
139 |
| - let new_node_successor_index = index_map.get(&node_successor.index).unwrap(); |
140 |
| - for connection in r |
141 |
| - .replacement |
142 |
| - .graph |
143 |
| - .get_connections(node.index, node_successor.index) |
144 |
| - { |
145 |
| - let src_offset = r |
146 |
| - .replacement |
147 |
| - .graph |
148 |
| - .port_offset(connection.0) |
149 |
| - .unwrap() |
150 |
| - .index(); |
151 |
| - let tgt_offset = r |
152 |
| - .replacement |
153 |
| - .graph |
154 |
| - .port_offset(connection.1) |
155 |
| - .unwrap() |
156 |
| - .index(); |
157 |
| - self.graph |
158 |
| - .link_nodes( |
159 |
| - *new_node_index, |
160 |
| - src_offset, |
161 |
| - *new_node_successor_index, |
162 |
| - tgt_offset, |
163 |
| - ) |
164 |
| - .ok(); |
165 |
| - } |
166 |
| - } |
167 |
| - } |
168 |
| - } |
169 |
| - // 3.2. For each p = r.nu_inp[q] such that q is not an Output port, add an edge from the |
170 |
| - // predecessor of p to (the new copy of) q. |
171 |
| - for ((rep_inp_node, rep_inp_port), (rem_inp_node, rem_inp_port)) in &r.nu_inp { |
172 |
| - if r.replacement.get_optype(*rep_inp_node).tag() != OpTag::Output { |
173 |
| - let new_inp_node_index = index_map.get(&rep_inp_node.index).unwrap(); |
174 |
| - // add edge from predecessor of (s_inp_node, s_inp_port) to (new_inp_node, n_inp_port) |
175 |
| - let rem_inp_port_index = self |
176 |
| - .graph |
177 |
| - .port_index(rem_inp_node.index, rem_inp_port.offset) |
178 |
| - .unwrap(); |
179 |
| - let rem_inp_predecessor_port_index = |
180 |
| - self.graph.port_link(rem_inp_port_index).unwrap().port(); |
181 |
| - let new_inp_port_index = self |
182 |
| - .graph |
183 |
| - .port_index(*new_inp_node_index, rep_inp_port.offset) |
184 |
| - .unwrap(); |
185 |
| - self.graph.unlink_port(rem_inp_predecessor_port_index); |
186 |
| - self.graph |
187 |
| - .link_ports(rem_inp_predecessor_port_index, new_inp_port_index) |
188 |
| - .ok(); |
189 |
| - } |
190 |
| - } |
191 |
| - // 3.3. For each q = r.nu_out[p] such that the predecessor of q is not an Input port, add an |
192 |
| - // edge from (the new copy of) the predecessor of q to p. |
193 |
| - for ((rem_out_node, rem_out_port), rep_out_port) in &r.nu_out { |
194 |
| - let rem_out_port_index = self |
195 |
| - .graph |
196 |
| - .port_index(rem_out_node.index, rem_out_port.offset) |
197 |
| - .unwrap(); |
198 |
| - let rep_out_port_index = r |
199 |
| - .replacement |
200 |
| - .graph |
201 |
| - .port_index(replacement_output_node.index, rep_out_port.offset) |
202 |
| - .unwrap(); |
203 |
| - let rep_out_predecessor_port_index = |
204 |
| - r.replacement.graph.port_link(rep_out_port_index).unwrap(); |
205 |
| - let rep_out_predecessor_node_index = r |
206 |
| - .replacement |
207 |
| - .graph |
208 |
| - .port_node(rep_out_predecessor_port_index) |
209 |
| - .unwrap(); |
210 |
| - if r.replacement |
211 |
| - .get_optype(rep_out_predecessor_node_index.into()) |
212 |
| - .tag() |
213 |
| - != OpTag::Input |
214 |
| - { |
215 |
| - let rep_out_predecessor_port_offset = r |
216 |
| - .replacement |
217 |
| - .graph |
218 |
| - .port_offset(rep_out_predecessor_port_index) |
219 |
| - .unwrap(); |
220 |
| - let new_out_node_index = index_map.get(&rep_out_predecessor_node_index).unwrap(); |
221 |
| - let new_out_port_index = self |
222 |
| - .graph |
223 |
| - .port_index(*new_out_node_index, rep_out_predecessor_port_offset) |
224 |
| - .unwrap(); |
225 |
| - self.graph.unlink_port(rem_out_port_index); |
226 |
| - self.graph |
227 |
| - .link_ports(new_out_port_index, rem_out_port_index) |
228 |
| - .ok(); |
229 |
| - } |
230 |
| - } |
231 |
| - // 3.4. For each q = r.nu_out[p1], p0 = r.nu_inp[q], add an edge from the predecessor of p0 |
232 |
| - // to p1. |
233 |
| - for ((rem_out_node, rem_out_port), &rep_out_port) in &r.nu_out { |
234 |
| - let rem_inp_nodeport = r.nu_inp.get(&(replacement_output_node, rep_out_port)); |
235 |
| - if let Some((rem_inp_node, rem_inp_port)) = rem_inp_nodeport { |
236 |
| - // add edge from predecessor of (rem_inp_node, rem_inp_port) to (rem_out_node, rem_out_port): |
237 |
| - let rem_inp_port_index = self |
238 |
| - .graph |
239 |
| - .port_index(rem_inp_node.index, rem_inp_port.offset) |
240 |
| - .unwrap(); |
241 |
| - let rem_inp_predecessor_port_index = |
242 |
| - self.graph.port_link(rem_inp_port_index).unwrap().port(); |
243 |
| - let rem_out_port_index = self |
244 |
| - .graph |
245 |
| - .port_index(rem_out_node.index, rem_out_port.offset) |
246 |
| - .unwrap(); |
247 |
| - self.graph.unlink_port(rem_inp_port_index); |
248 |
| - self.graph.unlink_port(rem_out_port_index); |
249 |
| - self.graph |
250 |
| - .link_ports(rem_inp_predecessor_port_index, rem_out_port_index) |
251 |
| - .ok(); |
252 |
| - } |
253 |
| - } |
254 |
| - // 3.5. Remove all nodes in r.removal and edges between them. |
255 |
| - for node in &r.removal { |
256 |
| - self.graph.remove_node(node.index); |
257 |
| - self.hierarchy.remove(node.index); |
258 |
| - } |
259 |
| - Ok(()) |
260 |
| - } |
261 |
| - |
262 | 82 | /// Applies a rewrite to the graph.
|
263 |
| - pub fn apply_rewrite(self, _rewrite: Rewrite) -> Result<(), RewriteError> { |
264 |
| - unimplemented!() |
| 83 | + pub fn apply_rewrite<E>(&mut self, rw: impl Rewrite<Error = E>) -> Result<(), E> { |
| 84 | + rw.apply(self) |
265 | 85 | }
|
266 | 86 |
|
267 | 87 | /// Return dot string showing underlying graph and hierarchy side by side.
|
|
0 commit comments