Skip to content

feat: Faster import/export through better name handling. #2055

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions hugr-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ extension_inference = []
declarative = ["serde_yaml"]
model_unstable = ["hugr-model"]
zstd = ["dep:zstd"]
default = ["model_unstable"]

[lib]
bench = false
Expand Down
158 changes: 80 additions & 78 deletions hugr-core/src/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ use crate::{
use fxhash::{FxBuildHasher, FxHashMap};
use hugr_model::v0::{
self as model,
bumpalo::{collections::String as BumpString, collections::Vec as BumpVec, Bump},
bumpalo::{collections::Vec as BumpVec, Bump},
table,
};
use petgraph::unionfind::UnionFind;
use std::fmt::Write;
use smol_str::{format_smolstr, SmolStr, ToSmolStr};

/// Export a [`Package`] to its representation in the model.
pub fn export_package<'a>(package: &'a Package, bump: &'a Bump) -> table::Package<'a> {
Expand Down Expand Up @@ -77,18 +77,18 @@ struct Context<'a> {
links: Links,

/// The symbol table tracking symbols that are currently in scope.
symbols: model::scope::SymbolTable<'a>,
symbols: model::scope::SymbolTable<CoreSymbol>,

/// Mapping from implicit imports to their node ids.
implicit_imports: FxHashMap<&'a str, table::NodeId>,
implicit_imports: FxHashMap<CoreSymbol, table::NodeId>,

/// Map from node ids in the [`Hugr`] to the corresponding node ids in the model.
node_to_id: FxHashMap<Node, table::NodeId>,

/// Mapping from node ids in the [`Hugr`] to the corresponding model nodes.
id_to_node: FxHashMap<table::NodeId, Node>,
// TODO: Once this module matures, we should consider adding an auxiliary structure
// that ensures that the `node_to_id` and `id_to_node` maps stay in sync.

static_symbols: FxHashMap<&'static str, CoreSymbol>,
}

impl<'a> Context<'a> {
Expand All @@ -110,6 +110,7 @@ impl<'a> Context<'a> {
implicit_imports: FxHashMap::default(),
node_to_id: FxHashMap::default(),
id_to_node: FxHashMap::default(),
static_symbols: FxHashMap::default(),
}
}

Expand Down Expand Up @@ -183,26 +184,6 @@ impl<'a> Context<'a> {
.or_insert_with(|| self.module.insert_term(term))
}

/// Builds the string `<extension>.<name>` inside the bump arena and returns
/// it as a string slice with the arena's lifetime.
pub fn make_qualified_name(
    &mut self,
    extension: &ExtensionId,
    name: impl AsRef<str>,
) -> &'a str {
    let name = name.as_ref();
    // Reserve room for both parts plus one byte for the `.` separator.
    let needed = extension.len() + name.len() + 1;
    let mut qualified = BumpString::with_capacity_in(needed, self.bump);
    // The result of the formatted write is deliberately discarded.
    let _ = write!(&mut qualified, "{}.{}", extension, name);
    qualified.into_bump_str()
}

/// Resolves the symbol `<extension>.<name>` to a node id.
///
/// The qualified name is first built with [`Self::make_qualified_name`] and
/// then passed to `resolve_symbol`.
pub fn make_named_global_ref(
    &mut self,
    extension: &IdentList,
    name: impl AsRef<str>,
) -> table::NodeId {
    // Kept as two statements: both calls need `&mut self`.
    let qualified = self.make_qualified_name(extension, name);
    self.resolve_symbol(qualified)
}

/// Get the node that declares or defines the function associated with the given
/// node via the static input. Returns `None` if the node is not connected to a function.
fn connected_function(&self, node: Node) -> Option<Node> {
Expand All @@ -215,16 +196,6 @@ impl<'a> Context<'a> {
}
}

/// Looks up the name carried by a function node.
///
/// Returns the `name` field when `func_node` is a `FuncDecl` or `FuncDefn`
/// operation, and `None` for every other operation type.
fn get_func_name(&self, func_node: Node) -> Option<&'a str> {
    let name = match self.hugr.get_optype(func_node) {
        OpType::FuncDecl(decl) => decl.name.as_str(),
        OpType::FuncDefn(defn) => defn.name.as_str(),
        _ => return None,
    };
    Some(name)
}

fn with_local_scope<T>(&mut self, node: table::NodeId, f: impl FnOnce(&mut Self) -> T) -> T {
let prev_local_scope = self.local_scope.replace(node);
let prev_local_constraints = std::mem::take(&mut self.local_constraints);
Expand Down Expand Up @@ -253,16 +224,16 @@ impl<'a> Context<'a> {

// We record the name of the symbol defined by the node, if any.
let symbol = match optype {
OpType::FuncDefn(func_defn) => Some(func_defn.name.as_str()),
OpType::FuncDecl(func_decl) => Some(func_decl.name.as_str()),
OpType::AliasDecl(alias_decl) => Some(alias_decl.name.as_str()),
OpType::AliasDefn(alias_defn) => Some(alias_defn.name.as_str()),
OpType::FuncDefn(func_defn) => Some(func_defn.name.to_smolstr()),
OpType::FuncDecl(func_decl) => Some(func_decl.name.to_smolstr()),
OpType::AliasDecl(alias_decl) => Some(alias_decl.name.clone()),
OpType::AliasDefn(alias_defn) => Some(alias_defn.name.clone()),
_ => None,
};

if let Some(symbol) = symbol {
self.symbols
.insert(symbol, node_id)
.insert(CoreSymbol::Local(symbol), node_id)
.expect("duplicate symbol");
}

Expand Down Expand Up @@ -319,7 +290,7 @@ impl<'a> Context<'a> {
}

OpType::FuncDefn(func) => self.with_local_scope(node_id, |this| {
let name = this.get_func_name(node).unwrap();
let name = this.symbols.symbol_name(node_id).unwrap().clone();
let symbol = this.export_poly_func_type(name, &func.signature);
regions = this
.bump
Expand All @@ -328,7 +299,7 @@ impl<'a> Context<'a> {
}),

OpType::FuncDecl(func) => self.with_local_scope(node_id, |this| {
let name = this.get_func_name(node).unwrap();
let name = this.symbols.symbol_name(node_id).unwrap().clone();
let symbol = this.export_poly_func_type(name, &func.signature);
table::Operation::DeclareFunc(symbol)
}),
Expand All @@ -337,7 +308,7 @@ impl<'a> Context<'a> {
// TODO: We should support aliases with different types and with parameters
let signature = this.make_term_apply(model::CORE_TYPE, &[]);
let symbol = this.bump.alloc(table::Symbol {
name: &alias.name,
name: model::SymbolName::new(alias.name.clone()),
params: &[],
constraints: &[],
signature,
Expand All @@ -350,7 +321,7 @@ impl<'a> Context<'a> {
// TODO: We should support aliases with different types and with parameters
let signature = this.make_term_apply(model::CORE_TYPE, &[]);
let symbol = this.bump.alloc(table::Symbol {
name: &alias.name,
name: model::SymbolName::new(alias.name.clone()),
params: &[],
constraints: &[],
signature,
Expand Down Expand Up @@ -446,7 +417,8 @@ impl<'a> Context<'a> {
}

OpType::OpaqueOp(op) => {
let node = self.make_named_global_ref(op.extension(), op.op_name());
let name = CoreSymbol::Qualified(op.extension().clone(), op.op_name().clone());
let node = self.resolve_symbol(&name);
let params = self
.bump
.alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_type_arg(arg)));
Expand Down Expand Up @@ -505,7 +477,11 @@ impl<'a> Context<'a> {

let poly_func_type = match opdef.signature_func() {
SignatureFunc::PolyFuncType(poly_func_type) => poly_func_type,
_ => return self.make_named_global_ref(opdef.extension_id(), opdef.name()),
_ => {
let name =
CoreSymbol::Qualified(opdef.extension_id().clone(), opdef.name().clone());
return self.resolve_symbol(&name);
}
};

let key = (opdef.extension_id().clone(), opdef.name().clone());
Expand All @@ -519,7 +495,7 @@ impl<'a> Context<'a> {
};

let symbol = self.with_local_scope(node, |this| {
let name = this.make_qualified_name(opdef.extension_id(), opdef.name());
let name = CoreSymbol::Qualified(opdef.extension_id().clone(), opdef.name().clone());
this.export_poly_func_type(name, poly_func_type)
});

Expand Down Expand Up @@ -737,7 +713,7 @@ impl<'a> Context<'a> {
/// Exports a polymorphic function type.
pub fn export_poly_func_type<RV: MaybeRV>(
&mut self,
name: &'a str,
name: CoreSymbol,
t: &PolyFuncTypeBase<RV>,
) -> &'a table::Symbol<'a> {
let mut params = BumpVec::with_capacity_in(t.params().len(), self.bump);
Expand All @@ -746,7 +722,7 @@ impl<'a> Context<'a> {
.expect("exporting poly func type outside of local scope");

for (i, param) in t.params().iter().enumerate() {
let name = self.bump.alloc_str(&i.to_string());
let name = model::VarName::new(i.to_smolstr());
let r#type = self.export_type_param(param, Some((scope, i as _)));
let param = table::Param { name, r#type };
params.push(param)
Expand All @@ -756,7 +732,7 @@ impl<'a> Context<'a> {
let body = self.export_func_type(t.body());

self.bump.alloc(table::Symbol {
name,
name: name.into(),
params: params.into_bump_slice(),
constraints,
signature: body,
Expand All @@ -771,7 +747,8 @@ impl<'a> Context<'a> {
match t {
TypeEnum::Extension(ext) => self.export_custom_type(ext),
TypeEnum::Alias(alias) => {
let symbol = self.resolve_symbol(self.bump.alloc_str(alias.name()));
let name = CoreSymbol::Local(alias.name().to_smolstr());
let symbol = self.resolve_symbol(&name);
self.make_term(table::Term::Apply(symbol, &[]))
}
TypeEnum::Function(func) => self.export_func_type(func),
Expand All @@ -791,7 +768,8 @@ impl<'a> Context<'a> {
}

pub fn export_custom_type(&mut self, t: &CustomType) -> table::TermId {
let symbol = self.make_named_global_ref(t.extension(), t.name());
let name = CoreSymbol::Qualified(t.extension().clone(), t.name().clone());
let symbol = self.resolve_symbol(&name);

let args = self
.bump
Expand Down Expand Up @@ -943,26 +921,19 @@ impl<'a> Context<'a> {
}

let contents = self.make_term(table::Term::List(contents.into_bump_slice()));

let symbol = self.resolve_symbol(ArrayValue::CTR_NAME);
let args = self.bump.alloc_slice_copy(&[len, element_type, contents]);
return self.make_term(table::Term::Apply(symbol, args));
return self
.make_term_apply(ArrayValue::CTR_NAME, &[len, element_type, contents]);
}

if let Some(v) = e.value().downcast_ref::<ConstInt>() {
let bitwidth = self.make_term(model::Literal::Nat(v.log_width() as u64).into());
let literal = self.make_term(model::Literal::Nat(v.value_u()).into());

let symbol = self.resolve_symbol(ConstInt::CTR_NAME);
let args = self.bump.alloc_slice_copy(&[bitwidth, literal]);
return self.make_term(table::Term::Apply(symbol, args));
return self.make_term_apply(ConstInt::CTR_NAME, &[bitwidth, literal]);
}

if let Some(v) = e.value().downcast_ref::<ConstF64>() {
let literal = self.make_term(model::Literal::Float(v.value().into()).into());
let symbol = self.resolve_symbol(ConstF64::CTR_NAME);
let args = self.bump.alloc_slice_copy(&[literal]);
return self.make_term(table::Term::Apply(symbol, args));
return self.make_term_apply(ConstF64::CTR_NAME, &[literal]);
}

let json = match e.value().downcast_ref::<CustomSerialized>() {
Expand All @@ -973,9 +944,7 @@ impl<'a> Context<'a> {

let json = self.make_term(model::Literal::Str(json.into()).into());
let runtime_type = self.export_type(&e.get_type());
let args = self.bump.alloc_slice_copy(&[runtime_type, json]);
let symbol = self.resolve_symbol(model::COMPAT_CONST_JSON);
self.make_term(table::Term::Apply(symbol, args))
return self.make_term_apply(model::COMPAT_CONST_JSON, &[runtime_type, json]);
}

Value::Function { hugr } => {
Expand Down Expand Up @@ -1030,29 +999,62 @@ impl<'a> Context<'a> {
self.make_term_apply(model::COMPAT_META_JSON, &[name, value])
}

fn resolve_symbol(&mut self, name: &'a str) -> table::NodeId {
fn resolve_symbol(&mut self, name: &CoreSymbol) -> table::NodeId {
let result = self.symbols.resolve(name);

match result {
Ok(node) => node,
Err(_) => *self.implicit_imports.entry(name).or_insert_with(|| {
self.module.insert_node(table::Node {
operation: table::Operation::Import { name },
..table::Node::default()
})
}),
if let Ok(node) = result {
return node;
}

// NOTE: We do not use the entry API here in order to avoid allocating
// a new `SymbolName` when the import already exists.
if let Some(node) = self.implicit_imports.get(name) {
return *node;
}

let node = self.module.insert_node(table::Node {
operation: table::Operation::Import(name.clone().into()),
..table::Node::default()
});
self.implicit_imports.insert(name.clone(), node);
node
}

fn make_term_apply(&mut self, name: &'a str, args: &[table::TermId]) -> table::TermId {
let symbol = self.resolve_symbol(name);
fn make_term_apply(&mut self, name: &'static str, args: &[table::TermId]) -> table::TermId {
let name = self
.static_symbols
.entry(name)
.or_insert_with(|| {
let (ext, name) = name.rsplit_once(".").unwrap();
let ext = IdentList::new_unchecked(ext);
let name = name.to_smolstr();
CoreSymbol::Qualified(ext, name)
})
.clone();

let symbol = self.resolve_symbol(&name);
let args = self.bump.alloc_slice_copy(args);
self.make_term(table::Term::Apply(symbol, args))
}
}

/// Shorthand for an [`indexmap::IndexSet`] that hashes with [`FxBuildHasher`].
type FxIndexSet<T> = indexmap::IndexSet<T, FxBuildHasher>;

/// A symbol name as tracked by the exporter, kept in structured form so that
/// the qualified string only has to be built when converting into a
/// `model::SymbolName`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum CoreSymbol {
/// A name qualified by an extension id; converts to `<extension>.<name>`.
Qualified(ExtensionId, SmolStr),
/// An unqualified name; converts to the name unchanged.
Local(SmolStr),
}

impl From<CoreSymbol> for model::SymbolName {
    /// Converts a [`CoreSymbol`] into a model symbol name.
    ///
    /// A local symbol is used verbatim; a qualified symbol is rendered as
    /// `<extension>.<name>`.
    fn from(value: CoreSymbol) -> Self {
        let text = match value {
            CoreSymbol::Local(local) => local,
            CoreSymbol::Qualified(extension, item) => {
                format_smolstr!("{}.{}", extension, item)
            }
        };
        Self::new(text)
    }
}

/// Data structure for translating the edges between ports in the `Hugr` graph
/// into the hypergraph representation used by `hugr_model`.
struct Links {
Expand Down
Loading
Loading