diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 7552ed36f..37f3eb6e0 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -3,17 +3,14 @@ //! An (example) use of the [dataflow analysis framework](super::dataflow). pub mod value_handle; +use itertools::{Either, Itertools}; use std::{collections::HashMap, sync::Arc}; use thiserror::Error; use hugr_core::{ - hugr::{ - hugrmut::HugrMut, - views::{DescendantsGraph, ExtractHugr, HierarchyView}, - }, + hugr::hugrmut::HugrMut, ops::{ - constant::OpaqueValue, handle::FuncID, Const, DataflowOpTrait, ExtensionOp, LoadConstant, - OpType, Value, + constant::OpaqueValue, Const, DataflowOpTrait, ExtensionOp, LoadConstant, OpType, Value, }, types::{EdgeKind, TypeArg}, HugrView, IncomingPort, Node, NodeIndex, OutgoingPort, PortIndex, Wire, @@ -207,14 +204,7 @@ pub fn constant_fold_pass(h: &mut H) { struct ConstFoldContext<'a, H>(&'a H); -impl std::ops::Deref for ConstFoldContext<'_, H> { - type Target = H; - fn deref(&self) -> &H { - self.0 - } -} - -impl> ConstLoader> for ConstFoldContext<'_, H> { +impl ConstLoader> for ConstFoldContext<'_, H> { type Node = H::Node; fn value_from_opaque( @@ -238,17 +228,7 @@ impl> ConstLoader> for ConstFoldCo node: H::Node, type_args: &[TypeArg], ) -> Option> { - if !type_args.is_empty() { - // TODO: substitution across Hugr (https://github.com/CQCL/hugr/issues/709) - return None; - }; - // Returning the function body as a value, here, would be sufficient for inlining IndirectCall - // but not for transforming to a direct Call. - let func = DescendantsGraph::>::try_new(&**self, node).ok()?; - Some(ValueHandle::new_const_hugr( - ConstLocation::Node(node), - Box::new(func.extract_hugr()), - )) + Some(ValueHandle::NodeRef(node, type_args.to_vec())) } } @@ -278,6 +258,51 @@ impl> DFContext> for ConstFoldCont partial_from_const(self, ConstLocation::Field(p.index(), &node.into()), &v); } } + + fn interpret_call_indirect( + &mut self, + func: &PartialValue>, + args: &[PartialValue>], + outs: &mut [PartialValue>], + ) { + let PartialValue::Value(func) = func else { + return; + }; + let inputs = args.iter().cloned().enumerate().map(|(i, v)| (i.into(), v)); + let vals: Vec<_> = match func { + ValueHandle::NodeRef(node, _) => { + let mut m = Machine::new(self.0); + m.prepopulate_inputs(*node, inputs).unwrap(); + let results = m.run(ConstFoldContext(self.0), []); + (0..outs.len()) + .map(|p| results.read_out_wire(Wire::new(*node, p))) + .collect() + } + ValueHandle::Unhashable { + leaf: Either::Right(hugr), + .. + } => { + let h = hugr.as_ref(); + // The problem here---which we'd see if we didn't constrain H::Node==Node, + // because the ValueHandle's would be incompatible---is that `args` may contain + // (a) UnhashableConsts keyed by NodeId in the *outer* Hugr (`self.0`, not `h`). + // (b) NodeRefs referring to nodes in the outer Hugr ! + // + // We can solve the first by remapping keys of HashedConsts to the input ports + // (like `fresh_node` in ConstFoldContext::run_no_validate), but not the second. + let results = Machine::new(h).run(ConstFoldContext(h), inputs); + (0..outs.len()) + .map(|p| results.read_out_wire(Wire::new(h.root(), p))) + .collect() + } + _ => return, + }; + for (val, out) in vals.into_iter().zip_eq(outs) { + if let Some(val) = val { + *out = val; + } + } + } } #[cfg(test)] diff --git a/hugr-passes/src/const_fold/value_handle.rs b/hugr-passes/src/const_fold/value_handle.rs index bda7bffd2..441b88788 100644 --- a/hugr-passes/src/const_fold/value_handle.rs +++ b/hugr-passes/src/const_fold/value_handle.rs @@ -7,6 +7,7 @@ use std::sync::Arc; use hugr_core::core::HugrNode; use hugr_core::ops::constant::OpaqueValue; use hugr_core::ops::Value; +use hugr_core::types::TypeArg; use hugr_core::{Hugr, Node}; use itertools::Either; @@ -46,6 +47,15 @@ impl Hash for HashedConst { /// An [Eq]-able and [Hash]-able leaf (non-[Sum](Value::Sum)) Value #[derive(Clone, Debug)] pub enum ValueHandle { + /// The result of [LoadFunction] on a [FuncDefn] (or [FuncDecl]), i.e. a "function + /// pointer" to a function in the Hugr. (Cannot be represented as a [Value::Function] + /// without lots of cloning, because it may have static edges from other + /// functions/constants/etc.) + /// + /// [LoadFunction]: hugr_core::ops::LoadFunction + /// [FuncDefn]: hugr_core::ops::FuncDefn + /// [FuncDecl]: hugr_core::ops::FuncDefn + NodeRef(N, Vec), /// A [Value::Extension] that has been hashed Hashable(HashedConst), /// Either a [Value::Extension] that can't be hashed, or a [Value::Function]. @@ -108,6 +118,7 @@ impl AbstractValue for ValueHandle {} impl PartialEq for ValueHandle { fn eq(&self, other: &Self) -> bool { match (self, other) { + (Self::NodeRef(n1, args1), Self::NodeRef(n2, args2)) => n1 == n2 && args1 == args2, (Self::Hashable(h1), Self::Hashable(h2)) => h1 == h2, ( Self::Unhashable { @@ -138,6 +149,10 @@ impl Eq for ValueHandle {} impl Hash for ValueHandle { fn hash(&self, state: &mut I) { match self { + ValueHandle::NodeRef(n, args) => { + n.hash(state); + args.hash(state); + } ValueHandle::Hashable(hc) => hc.hash(state), ValueHandle::Unhashable { node, @@ -153,9 +168,11 @@ impl Hash for ValueHandle { // Unfortunately we need From for Value to be able to pass // Value's into interpret_leaf_op. So that probably doesn't make sense... -impl From> for Value { - fn from(value: ValueHandle) -> Self { - match value { +impl TryFrom> for Value { + type Error = N; + fn try_from(value: ValueHandle) -> Result { + Ok(match value { + ValueHandle::NodeRef(n, _) => return Err(n), ValueHandle::Hashable(HashedConst { val, .. }) | ValueHandle::Unhashable { leaf: Either::Left(val), @@ -169,7 +186,7 @@ impl From> for Value { } => Value::function(Arc::try_unwrap(hugr).unwrap_or_else(|a| a.as_ref().clone())) .map_err(|e| e.to_string()) .unwrap(), - } + }) } } diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 43caa9c94..6db55b92b 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -35,6 +35,18 @@ pub trait DFContext: ConstLoader { _outs: &mut [PartialValue], ) { } + + /// Given lattice values for the called function, and arguments to pass to it, update + /// lattice values for the (dataflow) outputs + /// of a [CallIndirect](hugr_core::ops::CallIndirect). + /// (The default does nothing, i.e. leaves `Top` for all outputs.) + fn interpret_call_indirect( + &mut self, + _func: &PartialValue, + _args: &[PartialValue], + _outs: &mut [PartialValue], + ) { + } } /// A location where a [Value] could be find in a Hugr. That is, diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 13e510daf..037ca2bda 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -401,6 +401,13 @@ fn propagate_leaf_op( outs })) } - o => todo!("Unhandled: {:?}", o), // At least CallIndirect, and OpType is "non-exhaustive" + OpType::CallIndirect(_) => Some(ValueRow::from_iter(if row_contains_bottom(ins) { + vec![PartialValue::Bottom; num_outs] + } else { + let mut outs = vec![PartialValue::Top; num_outs]; + ctx.interpret_call_indirect(&ins[0], &ins[1..], &mut outs[..]); + outs + })), + o => todo!("Unhandled: {:?}", o), // OpType is "non-exhaustive" } }