@@ -22,11 +22,12 @@ use crate::extension::{
22
22
use crate :: ops:: constant:: test:: CustomTestValue ;
23
23
use crate :: ops:: constant:: CustomConst ;
24
24
use crate :: ops:: { CallIndirect , ExtensionOp , Input , OpTrait , OpType , Tag , Value } ;
25
- use crate :: std_extensions:: arithmetic:: float_types:: { float64_type, ConstF64 } ;
25
+ use crate :: std_extensions:: arithmetic:: float_types:: { self , float64_type, ConstF64 } ;
26
26
use crate :: std_extensions:: arithmetic:: int_ops;
27
27
use crate :: std_extensions:: arithmetic:: int_types:: { self , int_type} ;
28
28
use crate :: std_extensions:: collections:: list:: ListValue ;
29
- use crate :: types:: { Signature , Type } ;
29
+ use crate :: types:: type_param:: TypeParam ;
30
+ use crate :: types:: { PolyFuncType , Signature , Type , TypeArg , TypeBound } ;
30
31
use crate :: { std_extensions, type_row, Extension , Hugr , HugrView } ;
31
32
32
33
#[ rstest]
@@ -346,6 +347,46 @@ fn resolve_custom_const(#[case] custom_const: impl CustomConst) {
346
347
check_extension_resolution ( hugr) ;
347
348
}
348
349
350
+ /// Test resolution of function call with type arguments.
351
+ #[ rstest]
352
+ fn resolve_call ( ) {
353
+ let dummy_fn_sig = PolyFuncType :: new (
354
+ vec ! [ TypeParam :: Type { b: TypeBound :: Any } ] ,
355
+ Signature :: new ( vec ! [ ] , vec ! [ bool_t( ) ] ) ,
356
+ ) ;
357
+
358
+ let generic_type_1 = TypeArg :: Type { ty : float64_type ( ) } ;
359
+ let generic_type_2 = TypeArg :: Type { ty : int_type ( 6 ) } ;
360
+ let expected_exts = [
361
+ float_types:: EXTENSION_ID . to_owned ( ) ,
362
+ int_types:: EXTENSION_ID . to_owned ( ) ,
363
+ ]
364
+ . into_iter ( )
365
+ . collect :: < ExtensionSet > ( ) ;
366
+
367
+ let mut module = ModuleBuilder :: new ( ) ;
368
+ let dummy_fn = module. declare ( "called_fn" , dummy_fn_sig) . unwrap ( ) ;
369
+
370
+ let mut func = module
371
+ . define_function (
372
+ "caller_fn" ,
373
+ Signature :: new ( vec ! [ ] , vec ! [ bool_t( ) ] )
374
+ . with_extension_delta ( ExtensionSet :: from_iter ( expected_exts. clone ( ) ) ) ,
375
+ )
376
+ . unwrap ( ) ;
377
+ let _load_func = func. load_func ( & dummy_fn, & [ generic_type_1] ) . unwrap ( ) ;
378
+ let call = func. call ( & dummy_fn, & [ generic_type_2] , vec ! [ ] ) . unwrap ( ) ;
379
+ func. finish_with_outputs ( call. outputs ( ) ) . unwrap ( ) ;
380
+
381
+ let hugr = module. finish_hugr ( ) . unwrap ( ) ;
382
+
383
+ for ext in expected_exts {
384
+ assert ! ( hugr. extensions( ) . contains( & ext) ) ;
385
+ }
386
+
387
+ check_extension_resolution ( hugr) ;
388
+ }
389
+
349
390
/// Fail when collecting extensions but the weak pointers are not resolved.
350
391
#[ rstest]
351
392
fn dropped_weak_extensions ( ) {
0 commit comments