@@ -10,8 +10,9 @@ use hugr_core::{
10
10
Node ,
11
11
} ;
12
12
13
- use hugr_core:: hugr:: { hugrmut:: HugrMut , internal :: HugrMutInternals , Hugr , HugrView , OpType } ;
13
+ use hugr_core:: hugr:: { hugrmut:: HugrMut , Hugr , HugrView , OpType } ;
14
14
use itertools:: Itertools as _;
15
+ use thiserror:: Error ;
15
16
16
17
/// Replaces calls to polymorphic functions with calls to new monomorphic
17
18
/// instantiations of the polymorphic ones.
@@ -28,26 +29,25 @@ use itertools::Itertools as _;
28
29
/// children of the root node. We make best effort to ensure that names (derived
29
30
/// from parent function names and concrete type args) of new functions are unique
30
31
/// whenever the names of their parents are unique, but this is not guaranteed.
32
+ #[ deprecated(
33
+ since = "0.14.1" ,
34
+ note = "Use `hugr::algorithms::MonomorphizePass` instead."
35
+ ) ]
36
+ // TODO: Deprecated. Remove on a breaking release and rename private `monomorphize_ref` to `monomorphize`.
31
37
pub fn monomorphize ( mut h : Hugr ) -> Hugr {
32
- let validate = |h : & Hugr | h. validate ( ) . unwrap_or_else ( |e| panic ! ( "{e}" ) ) ;
33
-
34
- // We clone the extension registry because we will need a reference to
35
- // create our mutable substitutions. This is cannot cause a problem because
36
- // we will not be adding any new types or extension ops to the HUGR.
37
- #[ cfg( debug_assertions) ]
38
- validate ( & h) ;
38
+ monomorphize_ref ( & mut h) ;
39
+ h
40
+ }
39
41
42
+ fn monomorphize_ref ( h : & mut impl HugrMut ) {
40
43
let root = h. root ( ) ;
41
44
// If the root is a polymorphic function, then there are no external calls, so nothing to do
42
45
if !is_polymorphic_funcdefn ( h. get_optype ( root) ) {
43
- mono_scan ( & mut h, root, None , & mut HashMap :: new ( ) ) ;
46
+ mono_scan ( h, root, None , & mut HashMap :: new ( ) ) ;
44
47
if !h. get_optype ( root) . is_module ( ) {
45
- return remove_polyfuncs ( h) ;
48
+ remove_polyfuncs_ref ( h) ;
46
49
}
47
50
}
48
- #[ cfg( debug_assertions) ]
49
- validate ( & h) ;
50
- h
51
51
}
52
52
53
53
/// Removes any polymorphic [FuncDefn]s from the Hugr. Note that if these have
@@ -57,6 +57,11 @@ pub fn monomorphize(mut h: Hugr) -> Hugr {
57
57
/// TODO replace this with a more general remove-unused-functions pass
58
58
/// <https://github.com/CQCL/hugr/issues/1753>
59
59
pub fn remove_polyfuncs ( mut h : Hugr ) -> Hugr {
60
+ remove_polyfuncs_ref ( & mut h) ;
61
+ h
62
+ }
63
+
64
+ fn remove_polyfuncs_ref ( h : & mut impl HugrMut ) {
60
65
let mut pfs_to_delete = Vec :: new ( ) ;
61
66
let mut to_scan = Vec :: from_iter ( h. children ( h. root ( ) ) ) ;
62
67
while let Some ( n) = to_scan. pop ( ) {
@@ -69,7 +74,6 @@ pub fn remove_polyfuncs(mut h: Hugr) -> Hugr {
69
74
for n in pfs_to_delete {
70
75
h. remove_subtree ( n) ;
71
76
}
72
- h
73
77
}
74
78
75
79
fn is_polymorphic ( fd : & FuncDefn ) -> bool {
@@ -93,7 +97,7 @@ type Instantiations = HashMap<Node, HashMap<Vec<TypeArg>, Node>>;
93
97
/// Optionally copies the subtree into a new location whilst applying a substitution.
94
98
/// The subtree should be monomorphic after the substitution (if provided) has been applied.
95
99
fn mono_scan (
96
- h : & mut Hugr ,
100
+ h : & mut impl HugrMut ,
97
101
parent : Node ,
98
102
mut subst_into : Option < & mut Instantiating > ,
99
103
cache : & mut Instantiations ,
@@ -161,7 +165,7 @@ fn mono_scan(
161
165
}
162
166
163
167
fn instantiate (
164
- h : & mut Hugr ,
168
+ h : & mut impl HugrMut ,
165
169
poly_func : Node ,
166
170
type_args : Vec < TypeArg > ,
167
171
mono_sig : Signature ,
@@ -218,20 +222,20 @@ fn instantiate(
218
222
// 'ext' edges by copying every node before recursing on any of them,
219
223
// 'dom' edges would *also* require recursing in dominator-tree preorder.
220
224
for ( & old_ch, & new_ch) in node_map. iter ( ) {
221
- for inport in h. node_inputs ( old_ch) . collect :: < Vec < _ > > ( ) {
225
+ for in_port in h. node_inputs ( old_ch) . collect :: < Vec < _ > > ( ) {
222
226
// Edges from monomorphized functions to their calls already added during mono_scan()
223
227
// as these depend not just on the original FuncDefn but also the TypeArgs
224
- if h. linked_outputs ( new_ch, inport ) . next ( ) . is_some ( ) {
228
+ if h. linked_outputs ( new_ch, in_port ) . next ( ) . is_some ( ) {
225
229
continue ;
226
230
} ;
227
- let srcs = h. linked_outputs ( old_ch, inport ) . collect :: < Vec < _ > > ( ) ;
231
+ let srcs = h. linked_outputs ( old_ch, in_port ) . collect :: < Vec < _ > > ( ) ;
228
232
for ( src, outport) in srcs {
229
233
// Sources could be a mixture of within this polymorphic FuncDefn, and Static edges from outside
230
234
h. connect (
231
235
node_map. get ( & src) . copied ( ) . unwrap_or ( src) ,
232
236
outport,
233
237
new_ch,
234
- inport ,
238
+ in_port ,
235
239
) ;
236
240
}
237
241
}
@@ -240,6 +244,57 @@ fn instantiate(
240
244
mono_tgt
241
245
}
242
246
247
+ use crate :: validation:: { ValidatePassError , ValidationLevel } ;
248
+
249
+ /// Replaces calls to polymorphic functions with calls to new monomorphic
250
+ /// instantiations of the polymorphic ones.
251
+ ///
252
+ /// If the Hugr is [Module](OpType::Module)-rooted,
253
+ /// * then the original polymorphic [FuncDefn]s are left untouched (including Calls inside them)
254
+ /// - call [remove_polyfuncs] when no other Hugr will be linked in that might instantiate these
255
+ /// * else, the originals are removed (they are invisible from outside the Hugr).
256
+ ///
257
+ /// If the Hugr is [FuncDefn](OpType::FuncDefn)-rooted with polymorphic
258
+ /// signature then the HUGR will not be modified.
259
+ ///
260
+ /// Monomorphic copies of polymorphic functions will be added to the HUGR as
261
+ /// children of the root node. We make best effort to ensure that names (derived
262
+ /// from parent function names and concrete type args) of new functions are unique
263
+ /// whenever the names of their parents are unique, but this is not guaranteed.
264
+ #[ derive( Debug , Clone , Default ) ]
265
+ pub struct MonomorphizePass {
266
+ validation : ValidationLevel ,
267
+ }
268
+
269
+ #[ derive( Debug , Error ) ]
270
+ #[ non_exhaustive]
271
+ /// Errors produced by [MonomorphizePass].
272
+ pub enum MonomorphizeError {
273
+ #[ error( transparent) ]
274
+ #[ allow( missing_docs) ]
275
+ ValidationError ( #[ from] ValidatePassError ) ,
276
+ }
277
+
278
+ impl MonomorphizePass {
279
+ /// Sets the validation level used before and after the pass is run.
280
+ pub fn validation_level ( mut self , level : ValidationLevel ) -> Self {
281
+ self . validation = level;
282
+ self
283
+ }
284
+
285
+ /// Run the Monomorphization pass.
286
+ fn run_no_validate ( & self , hugr : & mut impl HugrMut ) -> Result < ( ) , MonomorphizeError > {
287
+ monomorphize_ref ( hugr) ;
288
+ Ok ( ( ) )
289
+ }
290
+
291
+ /// Run the pass using specified configuration.
292
+ pub fn run < H : HugrMut > ( & self , hugr : & mut H ) -> Result < ( ) , MonomorphizeError > {
293
+ self . validation
294
+ . run_validated_pass ( hugr, |hugr : & mut H , _| self . run_no_validate ( hugr) )
295
+ }
296
+ }
297
+
243
298
struct TypeArgsList < ' a > ( & ' a [ TypeArg ] ) ;
244
299
245
300
impl std:: fmt:: Display for TypeArgsList < ' _ > {
@@ -322,7 +377,9 @@ mod test {
322
377
use hugr_core:: { Hugr , HugrView , Node } ;
323
378
use rstest:: rstest;
324
379
325
- use super :: { is_polymorphic, mangle_inner_func, mangle_name, monomorphize, remove_polyfuncs} ;
380
+ use crate :: monomorphize:: { remove_polyfuncs_ref, MonomorphizePass } ;
381
+
382
+ use super :: { is_polymorphic, mangle_inner_func, mangle_name, remove_polyfuncs} ;
326
383
327
384
fn pair_type ( ty : Type ) -> Type {
328
385
Type :: new_tuple ( vec ! [ ty. clone( ) , ty] )
@@ -342,7 +399,8 @@ mod test {
342
399
DFGBuilder :: new ( Signature :: new ( vec ! [ usize_t( ) ] , vec ! [ usize_t( ) ] ) ) . unwrap ( ) ;
343
400
let [ i1] = dfg_builder. input_wires_arr ( ) ;
344
401
let hugr = dfg_builder. finish_hugr_with_outputs ( [ i1] ) . unwrap ( ) ;
345
- let hugr2 = monomorphize ( hugr. clone ( ) ) ;
402
+ let mut hugr2 = hugr. clone ( ) ;
403
+ MonomorphizePass :: default ( ) . run ( & mut hugr2) . unwrap ( ) ;
346
404
assert_eq ! ( hugr, hugr2) ;
347
405
}
348
406
@@ -397,14 +455,15 @@ mod test {
397
455
let [ res2] = fb. call ( tr. handle ( ) , & [ pty] , pair. outputs ( ) ) ?. outputs_arr ( ) ;
398
456
fb. finish_with_outputs ( [ res1, res2] ) ?;
399
457
}
400
- let hugr = mb. finish_hugr ( ) ?;
458
+ let mut hugr = mb. finish_hugr ( ) ?;
401
459
assert_eq ! (
402
460
hugr. nodes( )
403
461
. filter( |n| hugr. get_optype( * n) . is_func_defn( ) )
404
462
. count( ) ,
405
463
3
406
464
) ;
407
- let mono = monomorphize ( hugr) ;
465
+ MonomorphizePass :: default ( ) . run ( & mut hugr) ?;
466
+ let mono = hugr;
408
467
mono. validate ( ) ?;
409
468
410
469
let mut funcs = list_funcs ( & mono) ;
@@ -423,8 +482,10 @@ mod test {
423
482
funcs. into_keys( ) . sorted( ) . collect_vec( ) ,
424
483
[ "double" , "main" , "triple" ]
425
484
) ;
485
+ let mut mono2 = mono. clone ( ) ;
486
+ MonomorphizePass :: default ( ) . run ( & mut mono2) ?;
426
487
427
- assert_eq ! ( monomorphize ( mono . clone ( ) ) , mono) ; // Idempotent
488
+ assert_eq ! ( mono2 , mono) ; // Idempotent
428
489
429
490
let nopoly = remove_polyfuncs ( mono) ;
430
491
let mut funcs = list_funcs ( & nopoly) ;
@@ -527,9 +588,10 @@ mod test {
527
588
. call ( pf1. handle ( ) , & [ sa ( n - 1 ) ] , [ ar2_unwrapped] )
528
589
. unwrap ( )
529
590
. outputs_arr ( ) ;
530
- let hugr = outer. finish_hugr_with_outputs ( [ e1, e2] ) . unwrap ( ) ;
591
+ let mut hugr = outer. finish_hugr_with_outputs ( [ e1, e2] ) . unwrap ( ) ;
531
592
532
- let mono_hugr = monomorphize ( hugr) ;
593
+ MonomorphizePass :: default ( ) . run ( & mut hugr) . unwrap ( ) ;
594
+ let mono_hugr = hugr;
533
595
mono_hugr. validate ( ) . unwrap ( ) ;
534
596
let funcs = list_funcs ( & mono_hugr) ;
535
597
let pf2_name = mangle_inner_func ( "pf1" , "pf2" ) ;
@@ -588,8 +650,9 @@ mod test {
588
650
. outputs_arr ( ) ;
589
651
let mono = mono. finish_with_outputs ( [ a, b] ) . unwrap ( ) ;
590
652
let c = dfg. call ( mono. handle ( ) , & [ ] , dfg. input_wires ( ) ) . unwrap ( ) ;
591
- let hugr = dfg. finish_hugr_with_outputs ( c. outputs ( ) ) . unwrap ( ) ;
592
- let mono_hugr = monomorphize ( hugr) ;
653
+ let mut hugr = dfg. finish_hugr_with_outputs ( c. outputs ( ) ) . unwrap ( ) ;
654
+ MonomorphizePass :: default ( ) . run ( & mut hugr) ?;
655
+ let mono_hugr = hugr;
593
656
594
657
let mut funcs = list_funcs ( & mono_hugr) ;
595
658
assert ! ( funcs. values( ) . all( |( _, fd) | !is_polymorphic( fd) ) ) ;
@@ -606,7 +669,7 @@ mod test {
606
669
607
670
#[ test]
608
671
fn load_function ( ) {
609
- let hugr = {
672
+ let mut hugr = {
610
673
let mut module_builder = ModuleBuilder :: new ( ) ;
611
674
let foo = {
612
675
let builder = module_builder
@@ -645,9 +708,10 @@ mod test {
645
708
module_builder. finish_hugr ( ) . unwrap ( )
646
709
} ;
647
710
648
- let mono_hugr = remove_polyfuncs ( monomorphize ( hugr) ) ;
711
+ MonomorphizePass :: default ( ) . run ( & mut hugr) . unwrap ( ) ;
712
+ remove_polyfuncs_ref ( & mut hugr) ;
649
713
650
- let funcs = list_funcs ( & mono_hugr ) ;
714
+ let funcs = list_funcs ( & hugr ) ;
651
715
assert ! ( funcs. values( ) . all( |( _, fd) | !is_polymorphic( fd) ) ) ;
652
716
}
653
717
0 commit comments