Skip to content

feat!: Handle CallIndirect in constant folding #2046

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 11 commits into from
Closed
75 changes: 50 additions & 25 deletions hugr-passes/src/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,14 @@
//! An (example) use of the [dataflow analysis framework](super::dataflow).

pub mod value_handle;
use itertools::{Either, Itertools};
use std::{collections::HashMap, sync::Arc};
use thiserror::Error;

use hugr_core::{
hugr::{
hugrmut::HugrMut,
views::{DescendantsGraph, ExtractHugr, HierarchyView},
},
hugr::hugrmut::HugrMut,
ops::{
constant::OpaqueValue, handle::FuncID, Const, DataflowOpTrait, ExtensionOp, LoadConstant,
OpType, Value,
constant::OpaqueValue, Const, DataflowOpTrait, ExtensionOp, LoadConstant, OpType, Value,
},
types::{EdgeKind, TypeArg},
HugrView, IncomingPort, Node, NodeIndex, OutgoingPort, PortIndex, Wire,
Expand Down Expand Up @@ -207,14 +204,7 @@ pub fn constant_fold_pass<H: HugrMut>(h: &mut H) {

struct ConstFoldContext<'a, H>(&'a H);

impl<H: HugrView> std::ops::Deref for ConstFoldContext<'_, H> {
type Target = H;
fn deref(&self) -> &H {
self.0
}
}

impl<H: HugrView<Node = Node>> ConstLoader<ValueHandle<H::Node>> for ConstFoldContext<'_, H> {
impl<H: HugrView> ConstLoader<ValueHandle<H::Node>> for ConstFoldContext<'_, H> {
type Node = H::Node;

fn value_from_opaque(
Expand All @@ -238,17 +228,7 @@ impl<H: HugrView<Node = Node>> ConstLoader<ValueHandle<H::Node>> for ConstFoldCo
node: H::Node,
type_args: &[TypeArg],
) -> Option<ValueHandle<H::Node>> {
if !type_args.is_empty() {
// TODO: substitution across Hugr (https://github.com/CQCL/hugr/issues/709)
return None;
};
// Returning the function body as a value, here, would be sufficient for inlining IndirectCall
// but not for transforming to a direct Call.
let func = DescendantsGraph::<FuncID<true>>::try_new(&**self, node).ok()?;
Some(ValueHandle::new_const_hugr(
ConstLocation::Node(node),
Box::new(func.extract_hugr()),
))
Some(ValueHandle::NodeRef(node, type_args.to_vec()))
}
}

Expand Down Expand Up @@ -278,6 +258,51 @@ impl<H: HugrView<Node = Node>> DFContext<ValueHandle<H::Node>> for ConstFoldCont
partial_from_const(self, ConstLocation::Field(p.index(), &node.into()), &v);
}
}

fn interpret_call_indirect(
&mut self,
func: &PartialValue<ValueHandle<H::Node>>,
args: &[PartialValue<ValueHandle<H::Node>>],
outs: &mut [PartialValue<ValueHandle<H::Node>>],
) {
let PartialValue::Value(func) = func else {
return;
};
let inputs = args.iter().cloned().enumerate().map(|(i, v)| (i.into(), v));
let vals: Vec<_> = match func {
ValueHandle::NodeRef(node, _) => {
let mut m = Machine::new(self.0);
m.prepopulate_inputs(*node, inputs).unwrap();
let results = m.run(ConstFoldContext(self.0), []);
(0..outs.len())
.map(|p| results.read_out_wire(Wire::new(*node, p)))
.collect()
}
ValueHandle::Unhashable {
leaf: Either::Right(hugr),
..
} => {
let h = hugr.as_ref();
// The problem here---which we'd see if we didn't constrain H::Node==Node,
// because the ValueHandle's would be incompatible---is that `args` may contain
// (a) UnhashableConsts keyed by NodeId in the *outer* Hugr (`self.0`, not `h`).
// (b) NodeRefs referring to nodes in the outer Hugr !
//
// We can solve the first by remapping keys of HashedConsts to the input ports
// (like `fresh_node` in ConstFoldContext::run_no_validate), but not the second.
let results = Machine::new(h).run(ConstFoldContext(h), inputs);
(0..outs.len())
.map(|p| results.read_out_wire(Wire::new(h.root(), p)))
.collect()
}
_ => return,
};
for (val, out) in vals.into_iter().zip_eq(outs) {
if let Some(val) = val {
*out = val;
}
}
}
}

#[cfg(test)]
Expand Down
25 changes: 21 additions & 4 deletions hugr-passes/src/const_fold/value_handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::sync::Arc;
use hugr_core::core::HugrNode;
use hugr_core::ops::constant::OpaqueValue;
use hugr_core::ops::Value;
use hugr_core::types::TypeArg;
use hugr_core::{Hugr, Node};
use itertools::Either;

Expand Down Expand Up @@ -46,6 +47,15 @@ impl Hash for HashedConst {
/// An [Eq]-able and [Hash]-able leaf (non-[Sum](Value::Sum)) Value
#[derive(Clone, Debug)]
pub enum ValueHandle<N = Node> {
/// The result of [LoadFunction] on a [FuncDefn] (or [FuncDecl]), i.e. a "function
/// pointer" to a function in the Hugr. (Cannot be represented as a [Value::Function]
/// without lots of cloning, because it may have static edges from other
/// functions/constants/etc.)
///
/// [LoadFunction]: hugr_core::ops::LoadFunction
/// [FuncDefn]: hugr_core::ops::FuncDefn
/// [FuncDecl]: hugr_core::ops::FuncDefn
NodeRef(N, Vec<TypeArg>),
/// A [Value::Extension] that has been hashed
Hashable(HashedConst),
/// Either a [Value::Extension] that can't be hashed, or a [Value::Function].
Expand Down Expand Up @@ -108,6 +118,7 @@ impl<N: HugrNode> AbstractValue for ValueHandle<N> {}
impl<N: HugrNode> PartialEq for ValueHandle<N> {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::NodeRef(n1, args1), Self::NodeRef(n2, args2)) => n1 == n2 && args1 == args2,
(Self::Hashable(h1), Self::Hashable(h2)) => h1 == h2,
(
Self::Unhashable {
Expand Down Expand Up @@ -138,6 +149,10 @@ impl<N: HugrNode> Eq for ValueHandle<N> {}
impl<N: HugrNode> Hash for ValueHandle<N> {
fn hash<I: Hasher>(&self, state: &mut I) {
match self {
ValueHandle::NodeRef(n, args) => {
n.hash(state);
args.hash(state);
}
ValueHandle::Hashable(hc) => hc.hash(state),
ValueHandle::Unhashable {
node,
Expand All @@ -153,9 +168,11 @@ impl<N: HugrNode> Hash for ValueHandle<N> {

// Unfortunately we need From<ValueHandle> for Value to be able to pass
// Value's into interpret_leaf_op. So that probably doesn't make sense...
impl<N: HugrNode> From<ValueHandle<N>> for Value {
fn from(value: ValueHandle<N>) -> Self {
match value {
impl<N: HugrNode> TryFrom<ValueHandle<N>> for Value {
type Error = N;
fn try_from(value: ValueHandle<N>) -> Result<Value, N> {
Ok(match value {
ValueHandle::NodeRef(n, _) => return Err(n),
ValueHandle::Hashable(HashedConst { val, .. })
| ValueHandle::Unhashable {
leaf: Either::Left(val),
Expand All @@ -169,7 +186,7 @@ impl<N: HugrNode> From<ValueHandle<N>> for Value {
} => Value::function(Arc::try_unwrap(hugr).unwrap_or_else(|a| a.as_ref().clone()))
.map_err(|e| e.to_string())
.unwrap(),
}
})
}
}

Expand Down
12 changes: 12 additions & 0 deletions hugr-passes/src/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,18 @@ pub trait DFContext<V>: ConstLoader<V> {
_outs: &mut [PartialValue<V>],
) {
}

/// Given lattice values for the called function, and arguments to pass to it, update
/// lattice values for the (dataflow) outputs
/// of a [CallIndirect](hugr_core::ops::CallIndirect).
/// (The default does nothing, i.e. leaves `Top` for all outputs.)
fn interpret_call_indirect(
&mut self,
_func: &PartialValue<V>,
_args: &[PartialValue<V>],
_outs: &mut [PartialValue<V>],
) {
}
}

/// A location where a [Value] could be find in a Hugr. That is,
Expand Down
9 changes: 8 additions & 1 deletion hugr-passes/src/dataflow/datalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,13 @@ fn propagate_leaf_op<V: AbstractValue, H: HugrView>(
outs
}))
}
o => todo!("Unhandled: {:?}", o), // At least CallIndirect, and OpType is "non-exhaustive"
OpType::CallIndirect(_) => Some(ValueRow::from_iter(if row_contains_bottom(ins) {
vec![PartialValue::Bottom; num_outs]
} else {
let mut outs = vec![PartialValue::Top; num_outs];
ctx.interpret_call_indirect(&ins[0], &ins[1..], &mut outs[..]);
outs
})),
o => todo!("Unhandled: {:?}", o), // OpType is "non-exhaustive"
}
}
Loading