Skip to content

Commit 7ee173c

Browse files
ss2165doug-q
andauthored
feat: add MonomorphizePass and deprecate monomorphize (#1809)
Co-authored-by: Douglas Wilson <[email protected]>
1 parent e065d70 commit 7ee173c

File tree

2 files changed

+104
-33
lines changed

2 files changed

+104
-33
lines changed

hugr-passes/src/lib.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,14 @@ mod half_node;
77
pub mod lower;
88
pub mod merge_bbs;
99
mod monomorphize;
10-
pub use monomorphize::{monomorphize, remove_polyfuncs};
10+
// TODO: Deprecated re-export. Remove on a breaking release.
11+
#[deprecated(
12+
since = "0.14.1",
13+
note = "Use `hugr::algorithms::MonomorphizePass` instead."
14+
)]
15+
#[allow(deprecated)]
16+
pub use monomorphize::monomorphize;
17+
pub use monomorphize::{remove_polyfuncs, MonomorphizeError, MonomorphizePass};
1118
pub mod nest_cfgs;
1219
pub mod non_local;
1320
pub mod validation;

hugr-passes/src/monomorphize.rs

Lines changed: 96 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@ use hugr_core::{
1010
Node,
1111
};
1212

13-
use hugr_core::hugr::{hugrmut::HugrMut, internal::HugrMutInternals, Hugr, HugrView, OpType};
13+
use hugr_core::hugr::{hugrmut::HugrMut, Hugr, HugrView, OpType};
1414
use itertools::Itertools as _;
15+
use thiserror::Error;
1516

1617
/// Replaces calls to polymorphic functions with calls to new monomorphic
1718
/// instantiations of the polymorphic ones.
@@ -28,26 +29,25 @@ use itertools::Itertools as _;
2829
/// children of the root node. We make best effort to ensure that names (derived
2930
/// from parent function names and concrete type args) of new functions are unique
3031
/// whenever the names of their parents are unique, but this is not guaranteed.
32+
#[deprecated(
33+
since = "0.14.1",
34+
note = "Use `hugr::algorithms::MonomorphizePass` instead."
35+
)]
36+
// TODO: Deprecated. Remove on a breaking release and rename private `monomorphize_ref` to `monomorphize`.
3137
pub fn monomorphize(mut h: Hugr) -> Hugr {
32-
let validate = |h: &Hugr| h.validate().unwrap_or_else(|e| panic!("{e}"));
33-
34-
// We clone the extension registry because we will need a reference to
35-
// create our mutable substitutions. This is cannot cause a problem because
36-
// we will not be adding any new types or extension ops to the HUGR.
37-
#[cfg(debug_assertions)]
38-
validate(&h);
38+
monomorphize_ref(&mut h);
39+
h
40+
}
3941

42+
fn monomorphize_ref(h: &mut impl HugrMut) {
4043
let root = h.root();
4144
// If the root is a polymorphic function, then there are no external calls, so nothing to do
4245
if !is_polymorphic_funcdefn(h.get_optype(root)) {
43-
mono_scan(&mut h, root, None, &mut HashMap::new());
46+
mono_scan(h, root, None, &mut HashMap::new());
4447
if !h.get_optype(root).is_module() {
45-
return remove_polyfuncs(h);
48+
remove_polyfuncs_ref(h);
4649
}
4750
}
48-
#[cfg(debug_assertions)]
49-
validate(&h);
50-
h
5151
}
5252

5353
/// Removes any polymorphic [FuncDefn]s from the Hugr. Note that if these have
@@ -57,6 +57,11 @@ pub fn monomorphize(mut h: Hugr) -> Hugr {
5757
/// TODO replace this with a more general remove-unused-functions pass
5858
/// <https://github.com/CQCL/hugr/issues/1753>
5959
pub fn remove_polyfuncs(mut h: Hugr) -> Hugr {
60+
remove_polyfuncs_ref(&mut h);
61+
h
62+
}
63+
64+
fn remove_polyfuncs_ref(h: &mut impl HugrMut) {
6065
let mut pfs_to_delete = Vec::new();
6166
let mut to_scan = Vec::from_iter(h.children(h.root()));
6267
while let Some(n) = to_scan.pop() {
@@ -69,7 +74,6 @@ pub fn remove_polyfuncs(mut h: Hugr) -> Hugr {
6974
for n in pfs_to_delete {
7075
h.remove_subtree(n);
7176
}
72-
h
7377
}
7478

7579
fn is_polymorphic(fd: &FuncDefn) -> bool {
@@ -93,7 +97,7 @@ type Instantiations = HashMap<Node, HashMap<Vec<TypeArg>, Node>>;
9397
/// Optionally copies the subtree into a new location whilst applying a substitution.
9498
/// The subtree should be monomorphic after the substitution (if provided) has been applied.
9599
fn mono_scan(
96-
h: &mut Hugr,
100+
h: &mut impl HugrMut,
97101
parent: Node,
98102
mut subst_into: Option<&mut Instantiating>,
99103
cache: &mut Instantiations,
@@ -161,7 +165,7 @@ fn mono_scan(
161165
}
162166

163167
fn instantiate(
164-
h: &mut Hugr,
168+
h: &mut impl HugrMut,
165169
poly_func: Node,
166170
type_args: Vec<TypeArg>,
167171
mono_sig: Signature,
@@ -218,20 +222,20 @@ fn instantiate(
218222
// 'ext' edges by copying every node before recursing on any of them,
219223
// 'dom' edges would *also* require recursing in dominator-tree preorder.
220224
for (&old_ch, &new_ch) in node_map.iter() {
221-
for inport in h.node_inputs(old_ch).collect::<Vec<_>>() {
225+
for in_port in h.node_inputs(old_ch).collect::<Vec<_>>() {
222226
// Edges from monomorphized functions to their calls already added during mono_scan()
223227
// as these depend not just on the original FuncDefn but also the TypeArgs
224-
if h.linked_outputs(new_ch, inport).next().is_some() {
228+
if h.linked_outputs(new_ch, in_port).next().is_some() {
225229
continue;
226230
};
227-
let srcs = h.linked_outputs(old_ch, inport).collect::<Vec<_>>();
231+
let srcs = h.linked_outputs(old_ch, in_port).collect::<Vec<_>>();
228232
for (src, outport) in srcs {
229233
// Sources could be a mixture of within this polymorphic FuncDefn, and Static edges from outside
230234
h.connect(
231235
node_map.get(&src).copied().unwrap_or(src),
232236
outport,
233237
new_ch,
234-
inport,
238+
in_port,
235239
);
236240
}
237241
}
@@ -240,6 +244,57 @@ fn instantiate(
240244
mono_tgt
241245
}
242246

247+
use crate::validation::{ValidatePassError, ValidationLevel};
248+
249+
/// Replaces calls to polymorphic functions with calls to new monomorphic
250+
/// instantiations of the polymorphic ones.
251+
///
252+
/// If the Hugr is [Module](OpType::Module)-rooted,
253+
/// * then the original polymorphic [FuncDefn]s are left untouched (including Calls inside them)
254+
/// - call [remove_polyfuncs] when no other Hugr will be linked in that might instantiate these
255+
/// * else, the originals are removed (they are invisible from outside the Hugr).
256+
///
257+
/// If the Hugr is [FuncDefn](OpType::FuncDefn)-rooted with polymorphic
258+
/// signature then the HUGR will not be modified.
259+
///
260+
/// Monomorphic copies of polymorphic functions will be added to the HUGR as
261+
/// children of the root node. We make best effort to ensure that names (derived
262+
/// from parent function names and concrete type args) of new functions are unique
263+
/// whenever the names of their parents are unique, but this is not guaranteed.
264+
#[derive(Debug, Clone, Default)]
265+
pub struct MonomorphizePass {
266+
validation: ValidationLevel,
267+
}
268+
269+
#[derive(Debug, Error)]
270+
#[non_exhaustive]
271+
/// Errors produced by [MonomorphizePass].
272+
pub enum MonomorphizeError {
273+
#[error(transparent)]
274+
#[allow(missing_docs)]
275+
ValidationError(#[from] ValidatePassError),
276+
}
277+
278+
impl MonomorphizePass {
279+
/// Sets the validation level used before and after the pass is run.
280+
pub fn validation_level(mut self, level: ValidationLevel) -> Self {
281+
self.validation = level;
282+
self
283+
}
284+
285+
/// Run the Monomorphization pass.
286+
fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), MonomorphizeError> {
287+
monomorphize_ref(hugr);
288+
Ok(())
289+
}
290+
291+
/// Run the pass using specified configuration.
292+
pub fn run<H: HugrMut>(&self, hugr: &mut H) -> Result<(), MonomorphizeError> {
293+
self.validation
294+
.run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr))
295+
}
296+
}
297+
243298
struct TypeArgsList<'a>(&'a [TypeArg]);
244299

245300
impl std::fmt::Display for TypeArgsList<'_> {
@@ -322,7 +377,9 @@ mod test {
322377
use hugr_core::{Hugr, HugrView, Node};
323378
use rstest::rstest;
324379

325-
use super::{is_polymorphic, mangle_inner_func, mangle_name, monomorphize, remove_polyfuncs};
380+
use crate::monomorphize::{remove_polyfuncs_ref, MonomorphizePass};
381+
382+
use super::{is_polymorphic, mangle_inner_func, mangle_name, remove_polyfuncs};
326383

327384
fn pair_type(ty: Type) -> Type {
328385
Type::new_tuple(vec![ty.clone(), ty])
@@ -342,7 +399,8 @@ mod test {
342399
DFGBuilder::new(Signature::new(vec![usize_t()], vec![usize_t()])).unwrap();
343400
let [i1] = dfg_builder.input_wires_arr();
344401
let hugr = dfg_builder.finish_hugr_with_outputs([i1]).unwrap();
345-
let hugr2 = monomorphize(hugr.clone());
402+
let mut hugr2 = hugr.clone();
403+
MonomorphizePass::default().run(&mut hugr2).unwrap();
346404
assert_eq!(hugr, hugr2);
347405
}
348406

@@ -397,14 +455,15 @@ mod test {
397455
let [res2] = fb.call(tr.handle(), &[pty], pair.outputs())?.outputs_arr();
398456
fb.finish_with_outputs([res1, res2])?;
399457
}
400-
let hugr = mb.finish_hugr()?;
458+
let mut hugr = mb.finish_hugr()?;
401459
assert_eq!(
402460
hugr.nodes()
403461
.filter(|n| hugr.get_optype(*n).is_func_defn())
404462
.count(),
405463
3
406464
);
407-
let mono = monomorphize(hugr);
465+
MonomorphizePass::default().run(&mut hugr)?;
466+
let mono = hugr;
408467
mono.validate()?;
409468

410469
let mut funcs = list_funcs(&mono);
@@ -423,8 +482,10 @@ mod test {
423482
funcs.into_keys().sorted().collect_vec(),
424483
["double", "main", "triple"]
425484
);
485+
let mut mono2 = mono.clone();
486+
MonomorphizePass::default().run(&mut mono2)?;
426487

427-
assert_eq!(monomorphize(mono.clone()), mono); // Idempotent
488+
assert_eq!(mono2, mono); // Idempotent
428489

429490
let nopoly = remove_polyfuncs(mono);
430491
let mut funcs = list_funcs(&nopoly);
@@ -527,9 +588,10 @@ mod test {
527588
.call(pf1.handle(), &[sa(n - 1)], [ar2_unwrapped])
528589
.unwrap()
529590
.outputs_arr();
530-
let hugr = outer.finish_hugr_with_outputs([e1, e2]).unwrap();
591+
let mut hugr = outer.finish_hugr_with_outputs([e1, e2]).unwrap();
531592

532-
let mono_hugr = monomorphize(hugr);
593+
MonomorphizePass::default().run(&mut hugr).unwrap();
594+
let mono_hugr = hugr;
533595
mono_hugr.validate().unwrap();
534596
let funcs = list_funcs(&mono_hugr);
535597
let pf2_name = mangle_inner_func("pf1", "pf2");
@@ -588,8 +650,9 @@ mod test {
588650
.outputs_arr();
589651
let mono = mono.finish_with_outputs([a, b]).unwrap();
590652
let c = dfg.call(mono.handle(), &[], dfg.input_wires()).unwrap();
591-
let hugr = dfg.finish_hugr_with_outputs(c.outputs()).unwrap();
592-
let mono_hugr = monomorphize(hugr);
653+
let mut hugr = dfg.finish_hugr_with_outputs(c.outputs()).unwrap();
654+
MonomorphizePass::default().run(&mut hugr)?;
655+
let mono_hugr = hugr;
593656

594657
let mut funcs = list_funcs(&mono_hugr);
595658
assert!(funcs.values().all(|(_, fd)| !is_polymorphic(fd)));
@@ -606,7 +669,7 @@ mod test {
606669

607670
#[test]
608671
fn load_function() {
609-
let hugr = {
672+
let mut hugr = {
610673
let mut module_builder = ModuleBuilder::new();
611674
let foo = {
612675
let builder = module_builder
@@ -645,9 +708,10 @@ mod test {
645708
module_builder.finish_hugr().unwrap()
646709
};
647710

648-
let mono_hugr = remove_polyfuncs(monomorphize(hugr));
711+
MonomorphizePass::default().run(&mut hugr).unwrap();
712+
remove_polyfuncs_ref(&mut hugr);
649713

650-
let funcs = list_funcs(&mono_hugr);
714+
let funcs = list_funcs(&hugr);
651715
assert!(funcs.values().all(|(_, fd)| !is_polymorphic(fd)));
652716
}
653717

0 commit comments

Comments
 (0)