Skip to content

Commit e40b6c7

Browse files
authored
fix: Call ops not tracking their parameter extensions (#1805)
Fixes #1795 I forgot to track the `type_args` of `Call` and `LoadFunction` when doing extension resolution drive-by: Add derives to `types::Substitution`
1 parent d730c43 commit e40b6c7

File tree

4 files changed

+58
-3
lines changed

4 files changed

+58
-3
lines changed

hugr-core/src/extension/resolution/test.rs

+43-2
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@ use crate::extension::{
2222
use crate::ops::constant::test::CustomTestValue;
2323
use crate::ops::constant::CustomConst;
2424
use crate::ops::{CallIndirect, ExtensionOp, Input, OpTrait, OpType, Tag, Value};
25-
use crate::std_extensions::arithmetic::float_types::{float64_type, ConstF64};
25+
use crate::std_extensions::arithmetic::float_types::{self, float64_type, ConstF64};
2626
use crate::std_extensions::arithmetic::int_ops;
2727
use crate::std_extensions::arithmetic::int_types::{self, int_type};
2828
use crate::std_extensions::collections::list::ListValue;
29-
use crate::types::{Signature, Type};
29+
use crate::types::type_param::TypeParam;
30+
use crate::types::{PolyFuncType, Signature, Type, TypeArg, TypeBound};
3031
use crate::{std_extensions, type_row, Extension, Hugr, HugrView};
3132

3233
#[rstest]
@@ -346,6 +347,46 @@ fn resolve_custom_const(#[case] custom_const: impl CustomConst) {
346347
check_extension_resolution(hugr);
347348
}
348349

350+
/// Test resolution of function call with type arguments.
351+
#[rstest]
352+
fn resolve_call() {
353+
let dummy_fn_sig = PolyFuncType::new(
354+
vec![TypeParam::Type { b: TypeBound::Any }],
355+
Signature::new(vec![], vec![bool_t()]),
356+
);
357+
358+
let generic_type_1 = TypeArg::Type { ty: float64_type() };
359+
let generic_type_2 = TypeArg::Type { ty: int_type(6) };
360+
let expected_exts = [
361+
float_types::EXTENSION_ID.to_owned(),
362+
int_types::EXTENSION_ID.to_owned(),
363+
]
364+
.into_iter()
365+
.collect::<ExtensionSet>();
366+
367+
let mut module = ModuleBuilder::new();
368+
let dummy_fn = module.declare("called_fn", dummy_fn_sig).unwrap();
369+
370+
let mut func = module
371+
.define_function(
372+
"caller_fn",
373+
Signature::new(vec![], vec![bool_t()])
374+
.with_extension_delta(ExtensionSet::from_iter(expected_exts.clone())),
375+
)
376+
.unwrap();
377+
let _load_func = func.load_func(&dummy_fn, &[generic_type_1]).unwrap();
378+
let call = func.call(&dummy_fn, &[generic_type_2], vec![]).unwrap();
379+
func.finish_with_outputs(call.outputs()).unwrap();
380+
381+
let hugr = module.finish_hugr().unwrap();
382+
383+
for ext in expected_exts {
384+
assert!(hugr.extensions().contains(&ext));
385+
}
386+
387+
check_extension_resolution(hugr);
388+
}
389+
349390
/// Fail when collecting extensions but the weak pointers are not resolved.
350391
#[rstest]
351392
fn dropped_weak_extensions() {

hugr-core/src/extension/resolution/types.rs

+7-1
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,18 @@ pub(crate) fn collect_op_types_extensions(
5050
OpType::Call(c) => {
5151
collect_signature_exts(c.func_sig.body(), &mut used, &mut missing);
5252
collect_signature_exts(&c.instantiation, &mut used, &mut missing);
53+
for arg in &c.type_args {
54+
collect_typearg_exts(arg, &mut used, &mut missing);
55+
}
5356
}
5457
OpType::CallIndirect(c) => collect_signature_exts(&c.signature, &mut used, &mut missing),
5558
OpType::LoadConstant(lc) => collect_type_exts(&lc.datatype, &mut used, &mut missing),
5659
OpType::LoadFunction(lf) => {
5760
collect_signature_exts(lf.func_sig.body(), &mut used, &mut missing);
5861
collect_signature_exts(&lf.instantiation, &mut used, &mut missing);
62+
for arg in &lf.type_args {
63+
collect_typearg_exts(arg, &mut used, &mut missing);
64+
}
5965
}
6066
OpType::DFG(dfg) => collect_signature_exts(&dfg.signature, &mut used, &mut missing),
6167
OpType::OpaqueOp(op) => {
@@ -203,7 +209,7 @@ pub(super) fn collect_type_exts<RV: MaybeRV>(
203209
/// - `used_extensions`: A The registry where to store the used extensions.
204210
/// - `missing_extensions`: A set of `ExtensionId`s of which the
205211
/// `Weak<Extension>` pointer has been invalidated.
206-
fn collect_typearg_exts(
212+
pub(super) fn collect_typearg_exts(
207213
arg: &TypeArg,
208214
used_extensions: &mut WeakExtensionRegistry,
209215
missing_extensions: &mut ExtensionSet,

hugr-core/src/extension/resolution/types_mut.rs

+6
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ pub fn resolve_op_types_extensions(
5050
OpType::Call(c) => {
5151
resolve_signature_exts(node, c.func_sig.body_mut(), extensions, used_extensions)?;
5252
resolve_signature_exts(node, &mut c.instantiation, extensions, used_extensions)?;
53+
for arg in &mut c.type_args {
54+
resolve_typearg_exts(node, arg, extensions, used_extensions)?;
55+
}
5356
}
5457
OpType::CallIndirect(c) => {
5558
resolve_signature_exts(node, &mut c.signature, extensions, used_extensions)?
@@ -60,6 +63,9 @@ pub fn resolve_op_types_extensions(
6063
OpType::LoadFunction(lf) => {
6164
resolve_signature_exts(node, lf.func_sig.body_mut(), extensions, used_extensions)?;
6265
resolve_signature_exts(node, &mut lf.instantiation, extensions, used_extensions)?;
66+
for arg in &mut lf.type_args {
67+
resolve_typearg_exts(node, arg, extensions, used_extensions)?;
68+
}
6369
}
6470
OpType::DFG(dfg) => {
6571
resolve_signature_exts(node, &mut dfg.signature, extensions, used_extensions)?

hugr-core/src/types.rs

+2
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,8 @@ impl From<Type> for TypeRV {
587587

588588
/// Details a replacement of type variables with a finite list of known values.
589589
/// (Variables out of the range of the list will result in a panic)
590+
#[derive(Clone, Debug, derive_more::Display)]
591+
#[display("[{}]", _0.iter().map(|ta| ta.to_string()).join(", "))]
590592
pub struct Substitution<'a>(&'a [TypeArg]);
591593

592594
impl<'a> Substitution<'a> {

0 commit comments

Comments
 (0)