diff --git a/hugr-core/Cargo.toml b/hugr-core/Cargo.toml index fa60f704d..7e6230dc2 100644 --- a/hugr-core/Cargo.toml +++ b/hugr-core/Cargo.toml @@ -21,6 +21,7 @@ extension_inference = [] declarative = ["serde_yaml"] model_unstable = ["hugr-model"] zstd = ["dep:zstd"] +default = ["model_unstable"] [lib] bench = false diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index ecf5abd70..a77b3a0f9 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -22,11 +22,11 @@ use crate::{ use fxhash::{FxBuildHasher, FxHashMap}; use hugr_model::v0::{ self as model, - bumpalo::{collections::String as BumpString, collections::Vec as BumpVec, Bump}, + bumpalo::{collections::Vec as BumpVec, Bump}, table, }; use petgraph::unionfind::UnionFind; -use std::fmt::Write; +use smol_str::{format_smolstr, SmolStr, ToSmolStr}; /// Export a [`Package`] to its representation in the model. pub fn export_package<'a>(package: &'a Package, bump: &'a Bump) -> table::Package<'a> { @@ -77,18 +77,18 @@ struct Context<'a> { links: Links, /// The symbol table tracking symbols that are currently in scope. - symbols: model::scope::SymbolTable<'a>, + symbols: model::scope::SymbolTable, /// Mapping from implicit imports to their node ids. - implicit_imports: FxHashMap<&'a str, table::NodeId>, + implicit_imports: FxHashMap, /// Map from node ids in the [`Hugr`] to the corresponding node ids in the model. node_to_id: FxHashMap, /// Mapping from node ids in the [`Hugr`] to the corresponding model nodes. id_to_node: FxHashMap, - // TODO: Once this module matures, we should consider adding an auxiliary structure - // that ensures that the `node_to_id` and `id_to_node` maps stay in sync. + + static_symbols: FxHashMap<&'static str, CoreSymbol>, } impl<'a> Context<'a> { @@ -110,6 +110,7 @@ impl<'a> Context<'a> { implicit_imports: FxHashMap::default(), node_to_id: FxHashMap::default(), id_to_node: FxHashMap::default(), + static_symbols: FxHashMap::default(), } } @@ -183,26 +184,6 @@ impl<'a> Context<'a> { .or_insert_with(|| self.module.insert_term(term)) } - pub fn make_qualified_name( - &mut self, - extension: &ExtensionId, - name: impl AsRef, - ) -> &'a str { - let capacity = extension.len() + name.as_ref().len() + 1; - let mut output = BumpString::with_capacity_in(capacity, self.bump); - let _ = write!(&mut output, "{}.{}", extension, name.as_ref()); - output.into_bump_str() - } - - pub fn make_named_global_ref( - &mut self, - extension: &IdentList, - name: impl AsRef, - ) -> table::NodeId { - let symbol = self.make_qualified_name(extension, name); - self.resolve_symbol(symbol) - } - /// Get the node that declares or defines the function associated with the given /// node via the static input. Returns `None` if the node is not connected to a function. fn connected_function(&self, node: Node) -> Option { @@ -215,16 +196,6 @@ impl<'a> Context<'a> { } } - /// Get the name of a function definition or declaration node. Returns `None` if not - /// one of those operations. - fn get_func_name(&self, func_node: Node) -> Option<&'a str> { - match self.hugr.get_optype(func_node) { - OpType::FuncDecl(func_decl) => Some(&func_decl.name), - OpType::FuncDefn(func_defn) => Some(&func_defn.name), - _ => None, - } - } - fn with_local_scope(&mut self, node: table::NodeId, f: impl FnOnce(&mut Self) -> T) -> T { let prev_local_scope = self.local_scope.replace(node); let prev_local_constraints = std::mem::take(&mut self.local_constraints); @@ -253,16 +224,16 @@ impl<'a> Context<'a> { // We record the name of the symbol defined by the node, if any. let symbol = match optype { - OpType::FuncDefn(func_defn) => Some(func_defn.name.as_str()), - OpType::FuncDecl(func_decl) => Some(func_decl.name.as_str()), - OpType::AliasDecl(alias_decl) => Some(alias_decl.name.as_str()), - OpType::AliasDefn(alias_defn) => Some(alias_defn.name.as_str()), + OpType::FuncDefn(func_defn) => Some(func_defn.name.to_smolstr()), + OpType::FuncDecl(func_decl) => Some(func_decl.name.to_smolstr()), + OpType::AliasDecl(alias_decl) => Some(alias_decl.name.clone()), + OpType::AliasDefn(alias_defn) => Some(alias_defn.name.clone()), _ => None, }; if let Some(symbol) = symbol { self.symbols - .insert(symbol, node_id) + .insert(CoreSymbol::Local(symbol), node_id) .expect("duplicate symbol"); } @@ -319,7 +290,7 @@ impl<'a> Context<'a> { } OpType::FuncDefn(func) => self.with_local_scope(node_id, |this| { - let name = this.get_func_name(node).unwrap(); + let name = this.symbols.symbol_name(node_id).unwrap().clone(); let symbol = this.export_poly_func_type(name, &func.signature); regions = this .bump @@ -328,7 +299,7 @@ impl<'a> Context<'a> { }), OpType::FuncDecl(func) => self.with_local_scope(node_id, |this| { - let name = this.get_func_name(node).unwrap(); + let name = this.symbols.symbol_name(node_id).unwrap().clone(); let symbol = this.export_poly_func_type(name, &func.signature); table::Operation::DeclareFunc(symbol) }), @@ -337,7 +308,7 @@ impl<'a> Context<'a> { // TODO: We should support aliases with different types and with parameters let signature = this.make_term_apply(model::CORE_TYPE, &[]); let symbol = this.bump.alloc(table::Symbol { - name: &alias.name, + name: model::SymbolName::new(alias.name.clone()), params: &[], constraints: &[], signature, @@ -350,7 +321,7 @@ impl<'a> Context<'a> { // TODO: We should support aliases with different types and with parameters let signature = this.make_term_apply(model::CORE_TYPE, &[]); let symbol = this.bump.alloc(table::Symbol { - name: &alias.name, + name: model::SymbolName::new(alias.name.clone()), params: &[], constraints: &[], signature, @@ -446,7 +417,8 @@ impl<'a> Context<'a> { } OpType::OpaqueOp(op) => { - let node = self.make_named_global_ref(op.extension(), op.op_name()); + let name = CoreSymbol::Qualified(op.extension().clone(), op.op_name().clone()); + let node = self.resolve_symbol(&name); let params = self .bump .alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_type_arg(arg))); @@ -505,7 +477,11 @@ impl<'a> Context<'a> { let poly_func_type = match opdef.signature_func() { SignatureFunc::PolyFuncType(poly_func_type) => poly_func_type, - _ => return self.make_named_global_ref(opdef.extension_id(), opdef.name()), + _ => { + let name = + CoreSymbol::Qualified(opdef.extension_id().clone(), opdef.name().clone()); + return self.resolve_symbol(&name); + } }; let key = (opdef.extension_id().clone(), opdef.name().clone()); @@ -519,7 +495,7 @@ impl<'a> Context<'a> { }; let symbol = self.with_local_scope(node, |this| { - let name = this.make_qualified_name(opdef.extension_id(), opdef.name()); + let name = CoreSymbol::Qualified(opdef.extension_id().clone(), opdef.name().clone()); this.export_poly_func_type(name, poly_func_type) }); @@ -737,7 +713,7 @@ impl<'a> Context<'a> { /// Exports a polymorphic function type. pub fn export_poly_func_type( &mut self, - name: &'a str, + name: CoreSymbol, t: &PolyFuncTypeBase, ) -> &'a table::Symbol<'a> { let mut params = BumpVec::with_capacity_in(t.params().len(), self.bump); @@ -746,7 +722,7 @@ impl<'a> Context<'a> { .expect("exporting poly func type outside of local scope"); for (i, param) in t.params().iter().enumerate() { - let name = self.bump.alloc_str(&i.to_string()); + let name = model::VarName::new(i.to_smolstr()); let r#type = self.export_type_param(param, Some((scope, i as _))); let param = table::Param { name, r#type }; params.push(param) @@ -756,7 +732,7 @@ impl<'a> Context<'a> { let body = self.export_func_type(t.body()); self.bump.alloc(table::Symbol { - name, + name: name.into(), params: params.into_bump_slice(), constraints, signature: body, @@ -771,7 +747,8 @@ impl<'a> Context<'a> { match t { TypeEnum::Extension(ext) => self.export_custom_type(ext), TypeEnum::Alias(alias) => { - let symbol = self.resolve_symbol(self.bump.alloc_str(alias.name())); + let name = CoreSymbol::Local(alias.name().to_smolstr()); + let symbol = self.resolve_symbol(&name); self.make_term(table::Term::Apply(symbol, &[])) } TypeEnum::Function(func) => self.export_func_type(func), @@ -791,7 +768,8 @@ impl<'a> Context<'a> { } pub fn export_custom_type(&mut self, t: &CustomType) -> table::TermId { - let symbol = self.make_named_global_ref(t.extension(), t.name()); + let name = CoreSymbol::Qualified(t.extension().clone(), t.name().clone()); + let symbol = self.resolve_symbol(&name); let args = self .bump @@ -943,26 +921,19 @@ impl<'a> Context<'a> { } let contents = self.make_term(table::Term::List(contents.into_bump_slice())); - - let symbol = self.resolve_symbol(ArrayValue::CTR_NAME); - let args = self.bump.alloc_slice_copy(&[len, element_type, contents]); - return self.make_term(table::Term::Apply(symbol, args)); + return self + .make_term_apply(ArrayValue::CTR_NAME, &[len, element_type, contents]); } if let Some(v) = e.value().downcast_ref::() { let bitwidth = self.make_term(model::Literal::Nat(v.log_width() as u64).into()); let literal = self.make_term(model::Literal::Nat(v.value_u()).into()); - - let symbol = self.resolve_symbol(ConstInt::CTR_NAME); - let args = self.bump.alloc_slice_copy(&[bitwidth, literal]); - return self.make_term(table::Term::Apply(symbol, args)); + return self.make_term_apply(ConstInt::CTR_NAME, &[bitwidth, literal]); } if let Some(v) = e.value().downcast_ref::() { let literal = self.make_term(model::Literal::Float(v.value().into()).into()); - let symbol = self.resolve_symbol(ConstF64::CTR_NAME); - let args = self.bump.alloc_slice_copy(&[literal]); - return self.make_term(table::Term::Apply(symbol, args)); + return self.make_term_apply(ConstF64::CTR_NAME, &[literal]); } let json = match e.value().downcast_ref::() { @@ -973,9 +944,7 @@ impl<'a> Context<'a> { let json = self.make_term(model::Literal::Str(json.into()).into()); let runtime_type = self.export_type(&e.get_type()); - let args = self.bump.alloc_slice_copy(&[runtime_type, json]); - let symbol = self.resolve_symbol(model::COMPAT_CONST_JSON); - self.make_term(table::Term::Apply(symbol, args)) + return self.make_term_apply(model::COMPAT_CONST_JSON, &[runtime_type, json]); } Value::Function { hugr } => { @@ -1030,22 +999,40 @@ impl<'a> Context<'a> { self.make_term_apply(model::COMPAT_META_JSON, &[name, value]) } - fn resolve_symbol(&mut self, name: &'a str) -> table::NodeId { + fn resolve_symbol(&mut self, name: &CoreSymbol) -> table::NodeId { let result = self.symbols.resolve(name); - match result { - Ok(node) => node, - Err(_) => *self.implicit_imports.entry(name).or_insert_with(|| { - self.module.insert_node(table::Node { - operation: table::Operation::Import { name }, - ..table::Node::default() - }) - }), + if let Ok(node) = result { + return node; + } + + // NOTE: We do not use the entry API here in order to avoid allocating + // a new `SymbolName` when the import already exists. + if let Some(node) = self.implicit_imports.get(name) { + return *node; } + + let node = self.module.insert_node(table::Node { + operation: table::Operation::Import(name.clone().into()), + ..table::Node::default() + }); + self.implicit_imports.insert(name.clone(), node); + node } - fn make_term_apply(&mut self, name: &'a str, args: &[table::TermId]) -> table::TermId { - let symbol = self.resolve_symbol(name); + fn make_term_apply(&mut self, name: &'static str, args: &[table::TermId]) -> table::TermId { + let name = self + .static_symbols + .entry(name) + .or_insert_with(|| { + let (ext, name) = name.rsplit_once(".").unwrap(); + let ext = IdentList::new_unchecked(ext); + let name = name.to_smolstr(); + CoreSymbol::Qualified(ext, name) + }) + .clone(); + + let symbol = self.resolve_symbol(&name); let args = self.bump.alloc_slice_copy(args); self.make_term(table::Term::Apply(symbol, args)) } @@ -1053,6 +1040,21 @@ impl<'a> Context<'a> { type FxIndexSet = indexmap::IndexSet; +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +enum CoreSymbol { + Qualified(ExtensionId, SmolStr), + Local(SmolStr), +} + +impl From for model::SymbolName { + fn from(value: CoreSymbol) -> Self { + Self::new(match value { + CoreSymbol::Qualified(ext, name) => format_smolstr!("{}.{}", ext, name), + CoreSymbol::Local(name) => name, + }) + } +} + /// Data structure for translating the edges between ports in the `Hugr` graph /// into the hypergraph representation used by `hugr_model`. struct Links { diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index edf06b8c6..db065ab32 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -6,7 +6,7 @@ use std::sync::Arc; use crate::{ - extension::{ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError}, + extension::{ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, TypeDef}, hugr::HugrMut, ops::{ constant::{CustomConst, CustomSerialized, OpaqueValue}, @@ -24,7 +24,7 @@ use crate::{ PolyFuncType, PolyFuncTypeBase, RowVariable, Signature, Type, TypeArg, TypeBase, TypeBound, TypeEnum, TypeName, TypeRow, }, - Direction, Hugr, HugrView, Node, Port, + Direction, Extension, Hugr, HugrView, Node, Port, }; use fxhash::FxHashMap; use hugr_model::v0 as model; @@ -111,8 +111,9 @@ pub fn import_hugr( extensions, nodes: FxHashMap::default(), local_vars: FxHashMap::default(), - custom_name_cache: FxHashMap::default(), region_scope: table::RegionId::default(), + cache_typedef: FxHashMap::default(), + cache_custom_name: FxHashMap::default(), }; ctx.import_root()?; @@ -145,7 +146,8 @@ struct Context<'a> { local_vars: FxHashMap, - custom_name_cache: FxHashMap<&'a str, (ExtensionId, SmolStr)>, + cache_custom_name: FxHashMap, + cache_typedef: FxHashMap, &'a TypeDef)>, region_scope: table::RegionId, } @@ -284,11 +286,14 @@ impl<'a> Context<'a> { Ok(()) } - fn get_symbol_name(&self, node_id: table::NodeId) -> Result<&'a str, ImportError> { + fn get_symbol_name( + &self, + node_id: table::NodeId, + ) -> Result<&'a model::SymbolName, ImportError> { let node_data = self.get_node(node_id)?; let name = node_data .operation - .symbol() + .symbol_name() .ok_or(table::ModelError::InvalidSymbol(node_id))?; Ok(name) } @@ -303,7 +308,7 @@ impl<'a> Context<'a> { _ => return Err(table::ModelError::UnexpectedOperation(func_node).into()), }; - self.import_poly_func_type(func_node, *symbol, |_, signature| Ok(signature)) + self.import_poly_func_type(func_node, symbol, |_, signature| Ok(signature)) } /// Import the root region of the module. @@ -359,7 +364,7 @@ impl<'a> Context<'a> { } table::Operation::DefineFunc(symbol) => { - self.import_poly_func_type(node_id, *symbol, |ctx, signature| { + self.import_poly_func_type(node_id, symbol, |ctx, signature| { let optype = OpType::FuncDefn(FuncDefn { name: symbol.name.to_string(), signature, @@ -378,7 +383,7 @@ impl<'a> Context<'a> { } table::Operation::DeclareFunc(symbol) => { - self.import_poly_func_type(node_id, *symbol, |ctx, signature| { + self.import_poly_func_type(node_id, symbol, |ctx, signature| { let optype = OpType::FuncDecl(FuncDecl { name: symbol.name.to_string(), signature, @@ -507,12 +512,11 @@ impl<'a> Context<'a> { let table::Term::Apply(node, params) = self.get_term(operation)? else { return Err(table::ModelError::TypeError(operation).into()); }; - let name = self.get_symbol_name(*node)?; let args = params .iter() .map(|param| self.import_type_arg(*param)) .collect::, _>>()?; - let (extension, name) = self.import_custom_name(name)?; + let (extension, name) = self.import_custom_name(*node)?; let signature = self.get_node_signature(node_id)?; // TODO: Currently we do not have the description or any other metadata for @@ -893,7 +897,7 @@ impl<'a> Context<'a> { fn import_poly_func_type( &mut self, node: table::NodeId, - symbol: table::Symbol<'a>, + symbol: &'a table::Symbol<'a>, in_scope: impl FnOnce(&mut Self, PolyFuncTypeBase) -> Result, ) -> Result { let mut imported_params = Vec::with_capacity(symbol.params.len()); @@ -1159,25 +1163,32 @@ impl<'a> Context<'a> { .map(|arg| self.import_type_arg(*arg)) .collect::, _>>()?; - let name = self.get_symbol_name(*symbol)?; - let (extension, id) = self.import_custom_name(name)?; - - let extension_ref = - self.extensions - .get(&extension) - .ok_or_else(|| ImportError::Extension { - missing_ext: extension.clone(), - available: self.extensions.ids().cloned().collect(), + let (extension_ref, ext_type) = match self.cache_typedef.get(symbol) { + Some(ext_type) => *ext_type, + None => { + let (extension, id) = self.import_custom_name(*symbol)?; + + let extension_ref = self.extensions.get(&extension).ok_or_else(|| { + ImportError::Extension { + missing_ext: extension.clone(), + available: self.extensions.ids().cloned().collect(), + } })?; - let ext_type = - extension_ref - .get_type(&id) - .ok_or_else(|| ImportError::ExtensionType { - ext: extension.clone(), - name: id.clone(), + let typedef = extension_ref.get_type(&id).ok_or_else(|| { + ImportError::ExtensionType { + ext: extension.clone(), + name: id.clone(), + } })?; + self.cache_typedef.insert(*symbol, (extension_ref, typedef)); + (extension_ref, typedef) + } + }; + + let id = ext_type.name().clone(); + let extension = extension_ref.name().clone(); let bound = ext_type.bound(&args); Ok(TypeBase::new_extension(CustomType::new( @@ -1339,23 +1350,23 @@ impl<'a> Context<'a> { fn import_custom_name( &mut self, - symbol: &'a str, + node: table::NodeId, ) -> Result<(ExtensionId, SmolStr), ImportError> { - use std::collections::hash_map::Entry; - match self.custom_name_cache.entry(symbol) { - Entry::Occupied(occupied_entry) => Ok(occupied_entry.get().clone()), - Entry::Vacant(vacant_entry) => { - let qualified_name = ExtensionId::new(symbol) - .map_err(|_| table::ModelError::MalformedName(symbol.to_smolstr()))?; - - let (extension, id) = qualified_name - .split_last() - .ok_or_else(|| table::ModelError::MalformedName(symbol.to_smolstr()))?; - - vacant_entry.insert((extension.clone(), id.clone())); - Ok((extension, id)) - } + if let Some(ext_id) = self.cache_custom_name.get(&node) { + return Ok(ext_id.clone()); } + + let symbol = self.get_symbol_name(node)?; + + let qualified_name = ExtensionId::new(symbol.as_ref()) + .map_err(|_| table::ModelError::MalformedName(symbol.as_ref().to_smolstr()))?; + + let ext_id = qualified_name + .split_last() + .ok_or_else(|| table::ModelError::MalformedName(symbol.as_ref().to_smolstr()))?; + + self.cache_custom_name.insert(node, ext_id.clone()); + Ok(ext_id) } fn import_json_meta( @@ -1527,7 +1538,7 @@ impl<'a> Context<'a> { return Ok(None); }; - if name != self.get_symbol_name(*symbol)? { + if name != self.get_symbol_name(*symbol)?.as_ref() { return Ok(None); } diff --git a/hugr-model/src/v0/ast/resolve.rs b/hugr-model/src/v0/ast/resolve.rs index 2f8a5ba6e..ed82f7b3d 100644 --- a/hugr-model/src/v0/ast/resolve.rs +++ b/hugr-model/src/v0/ast/resolve.rs @@ -17,7 +17,7 @@ pub struct Context<'a> { bump: &'a Bump, vars: VarTable<'a>, links: LinkTable<&'a str>, - symbols: SymbolTable<'a>, + symbols: SymbolTable, imports: FxHashMap, terms: FxHashMap, TermId>, } @@ -125,7 +125,7 @@ impl<'a> Context<'a> { for (id, node) in zip_eq(ids, nodes) { if let Some(symbol_name) = node.operation.symbol_name() { self.symbols - .insert(symbol_name.as_ref(), *id) + .insert(symbol_name.clone(), *id) .map_err(|_| ResolveError::DuplicateSymbol(symbol_name.clone()))?; } } @@ -182,9 +182,7 @@ impl<'a> Context<'a> { let symbol = self.resolve_symbol(symbol)?; table::Operation::DeclareOperation(symbol) } - Operation::Import(symbol_name) => table::Operation::Import { - name: symbol_name.as_ref(), - }, + Operation::Import(symbol_name) => table::Operation::Import(symbol_name.clone()), Operation::Custom(term) => { let term = self.resolve_term(term)?; table::Operation::Custom(term) @@ -288,7 +286,7 @@ impl<'a> Context<'a> { } fn resolve_symbol(&mut self, symbol: &'a Symbol) -> BuildResult<&'a table::Symbol<'a>> { - let name = symbol.name.as_ref(); + let name = symbol.name.clone(); let params = self.resolve_params(&symbol.params)?; let constraints = self.resolve_terms(&symbol.constraints)?; let signature = self.resolve_term(&symbol.signature)?; @@ -306,7 +304,7 @@ impl<'a> Context<'a> { /// This incrementally inserts the names of the parameters into the current /// variable scope, so that any parameter is in scope for each of its /// succeeding parameters. - fn resolve_params(&mut self, params: &'a [Param]) -> BuildResult<&'a [table::Param<'a>]> { + fn resolve_params(&mut self, params: &'a [Param]) -> BuildResult<&'a [table::Param]> { try_alloc_slice( self.bump, params.iter().map(|param| self.resolve_param(param)), @@ -317,8 +315,8 @@ impl<'a> Context<'a> { /// /// This inserts the name of the parameter into the current variable scope, /// making the parameter accessible as a variable. - fn resolve_param(&mut self, param: &'a Param) -> BuildResult> { - let name = param.name.as_ref(); + fn resolve_param(&mut self, param: &'a Param) -> BuildResult { + let name = param.name.clone(); let r#type = self.resolve_term(¶m.r#type)?; self.vars @@ -347,9 +345,7 @@ impl<'a> Context<'a> { *self.imports.entry(symbol_name.clone()).or_insert_with(|| { self.module.insert_node(table::Node { - operation: table::Operation::Import { - name: symbol_name.as_ref(), - }, + operation: table::Operation::Import(symbol_name.clone()), ..Default::default() }) }) diff --git a/hugr-model/src/v0/ast/view.rs b/hugr-model/src/v0/ast/view.rs index 8feb15853..74fbfecfc 100644 --- a/hugr-model/src/v0/ast/view.rs +++ b/hugr-model/src/v0/ast/view.rs @@ -26,35 +26,35 @@ impl<'a> View<'a, NodeId> for Node { fn view(module: &'a table::Module<'a>, id: NodeId) -> Option { let node = module.get_node(id)?; - let operation = match node.operation { + let operation = match &node.operation { table::Operation::Invalid => Operation::Invalid, table::Operation::Dfg => Operation::Dfg, table::Operation::Cfg => Operation::Cfg, table::Operation::Block => Operation::Block, table::Operation::DefineFunc(symbol) => { - Operation::DefineFunc(Box::new(module.view(*symbol)?)) + Operation::DefineFunc(Box::new(module.view(symbol)?)) } table::Operation::DeclareFunc(symbol) => { - Operation::DeclareFunc(Box::new(module.view(*symbol)?)) + Operation::DeclareFunc(Box::new(module.view(symbol)?)) } table::Operation::Custom(operation) => Operation::Custom(module.view(operation)?), table::Operation::DefineAlias(symbol, value) => { - let symbol = Box::new(module.view(*symbol)?); + let symbol = Box::new(module.view(symbol)?); let value = module.view(value)?; Operation::DefineAlias(symbol, value) } table::Operation::DeclareAlias(symbol) => { - Operation::DeclareAlias(Box::new(module.view(*symbol)?)) + Operation::DeclareAlias(Box::new(module.view(symbol)?)) } table::Operation::DeclareConstructor(symbol) => { - Operation::DeclareConstructor(Box::new(module.view(*symbol)?)) + Operation::DeclareConstructor(Box::new(module.view(symbol)?)) } table::Operation::DeclareOperation(symbol) => { - Operation::DeclareOperation(Box::new(module.view(*symbol)?)) + Operation::DeclareOperation(Box::new(module.view(symbol)?)) } table::Operation::TailLoop => Operation::TailLoop, table::Operation::Conditional => Operation::Conditional, - table::Operation::Import { name } => Operation::Import(SymbolName::new(name)), + table::Operation::Import(name) => Operation::Import(name.clone()), }; let meta = module.view(node.meta)?; @@ -89,9 +89,9 @@ impl<'a> View<'a, table::SeqPart> for SeqPart { } } -impl<'a> View<'a, table::Symbol<'a>> for Symbol { - fn view(module: &'a table::Module<'a>, id: table::Symbol<'a>) -> Option { - let name = SymbolName::new(id.name); +impl<'a> View<'a, &'a table::Symbol<'a>> for Symbol { + fn view(module: &'a table::Module<'a>, id: &'a table::Symbol<'a>) -> Option { + let name = id.name.clone(); let params = module.view(id.params)?; let constraints = module.view(id.constraints)?; let signature = module.view(id.signature)?; @@ -104,9 +104,9 @@ impl<'a> View<'a, table::Symbol<'a>> for Symbol { } } -impl<'a> View<'a, table::Param<'a>> for Param { - fn view(module: &'a table::Module<'a>, param: table::Param<'a>) -> Option { - let name = VarName::new(param.name); +impl<'a> View<'a, &'a table::Param> for Param { + fn view(module: &'a table::Module<'a>, param: &'a table::Param) -> Option { + let name = param.name.clone(); let r#type = module.view(param.r#type)?; Some(Param { name, r#type }) } @@ -147,14 +147,14 @@ impl<'a> View<'a, VarId> for VarName { }; let param = &symbol.params[id.1 as usize]; - Some(Self(param.name.into())) + Some(param.name.clone()) } } impl<'a> View<'a, NodeId> for SymbolName { fn view(module: &'a table::Module<'a>, id: NodeId) -> Option { let node = module.get_node(id)?; - let name = node.operation.symbol()?; - Some(Self(name.into())) + let name = node.operation.symbol_name()?; + Some(name.clone()) } } diff --git a/hugr-model/src/v0/binary/read.rs b/hugr-model/src/v0/binary/read.rs index ed0de96d7..dc177be6e 100644 --- a/hugr-model/src/v0/binary/read.rs +++ b/hugr-model/src/v0/binary/read.rs @@ -128,7 +128,7 @@ fn read_operation<'a>( Which::Block(()) => table::Operation::Block, Which::FuncDefn(reader) => { let reader = reader?; - let name = bump.alloc_str(reader.get_name()?.to_str()?); + let name = model::SymbolName::new(reader.get_name()?.to_str()?); let params = read_list!(bump, reader.get_params()?, read_param); let constraints = read_scalar_list!(bump, reader, get_constraints, table::TermId); let signature = table::TermId(reader.get_signature()); @@ -142,7 +142,7 @@ fn read_operation<'a>( } Which::FuncDecl(reader) => { let reader = reader?; - let name = bump.alloc_str(reader.get_name()?.to_str()?); + let name = model::SymbolName::new(reader.get_name()?.to_str()?); let params = read_list!(bump, reader.get_params()?, read_param); let constraints = read_scalar_list!(bump, reader, get_constraints, table::TermId); let signature = table::TermId(reader.get_signature()); @@ -157,7 +157,7 @@ fn read_operation<'a>( Which::AliasDefn(reader) => { let symbol = reader.get_symbol()?; let value = table::TermId(reader.get_value()); - let name = bump.alloc_str(symbol.get_name()?.to_str()?); + let name = model::SymbolName::new(symbol.get_name()?.to_str()?); let params = read_list!(bump, symbol.get_params()?, read_param); let signature = table::TermId(symbol.get_signature()); let symbol = bump.alloc(table::Symbol { @@ -170,7 +170,7 @@ fn read_operation<'a>( } Which::AliasDecl(reader) => { let reader = reader?; - let name = bump.alloc_str(reader.get_name()?.to_str()?); + let name = model::SymbolName::new(reader.get_name()?.to_str()?); let params = read_list!(bump, reader.get_params()?, read_param); let signature = table::TermId(reader.get_signature()); let symbol = bump.alloc(table::Symbol { @@ -183,7 +183,7 @@ fn read_operation<'a>( } Which::ConstructorDecl(reader) => { let reader = reader?; - let name = bump.alloc_str(reader.get_name()?.to_str()?); + let name = model::SymbolName::new(reader.get_name()?.to_str()?); let params = read_list!(bump, reader.get_params()?, read_param); let constraints = read_scalar_list!(bump, reader, get_constraints, table::TermId); let signature = table::TermId(reader.get_signature()); @@ -197,7 +197,7 @@ fn read_operation<'a>( } Which::OperationDecl(reader) => { let reader = reader?; - let name = bump.alloc_str(reader.get_name()?.to_str()?); + let name = model::SymbolName::new(reader.get_name()?.to_str()?); let params = read_list!(bump, reader.get_params()?, read_param); let constraints = read_scalar_list!(bump, reader, get_constraints, table::TermId); let signature = table::TermId(reader.get_signature()); @@ -212,9 +212,10 @@ fn read_operation<'a>( Which::Custom(operation) => table::Operation::Custom(table::TermId(operation)), Which::TailLoop(()) => table::Operation::TailLoop, Which::Conditional(()) => table::Operation::Conditional, - Which::Import(name) => table::Operation::Import { - name: bump.alloc_str(name?.to_str()?), - }, + Which::Import(name) => { + let name = model::SymbolName::new(name?.to_str()?); + table::Operation::Import(name) + } }) } @@ -304,11 +305,8 @@ fn read_seq_part( }) } -fn read_param<'a>( - bump: &'a Bump, - reader: hugr_capnp::param::Reader, -) -> ReadResult> { - let name = bump.alloc_str(reader.get_name()?.to_str()?); +fn read_param<'a>(_: &'a Bump, reader: hugr_capnp::param::Reader) -> ReadResult { + let name = model::VarName::new(reader.get_name()?.to_str()?); let r#type = table::TermId(reader.get_type()); Ok(table::Param { name, r#type }) } diff --git a/hugr-model/src/v0/binary/write.rs b/hugr-model/src/v0/binary/write.rs index 1217e9766..c8619ec1d 100644 --- a/hugr-model/src/v0/binary/write.rs +++ b/hugr-model/src/v0/binary/write.rs @@ -99,8 +99,8 @@ fn write_operation(mut builder: hugr_capnp::operation::Builder, operation: &tabl write_symbol(builder, symbol); } - table::Operation::Import { name } => { - builder.set_import(*name); + table::Operation::Import(name) => { + builder.set_import(name); } table::Operation::Invalid => builder.set_invalid(()), @@ -108,14 +108,14 @@ fn write_operation(mut builder: hugr_capnp::operation::Builder, operation: &tabl } fn write_symbol(mut builder: hugr_capnp::symbol::Builder, symbol: &table::Symbol) { - builder.set_name(symbol.name); + builder.set_name(&symbol.name); write_list!(builder, init_params, write_param, symbol.params); let _ = builder.set_constraints(table::TermId::unwrap_slice(symbol.constraints)); builder.set_signature(symbol.signature.0); } fn write_param(mut builder: hugr_capnp::param::Builder, param: &table::Param) { - builder.set_name(param.name); + builder.set_name(¶m.name); builder.set_type(param.r#type.0); } diff --git a/hugr-model/src/v0/mod.rs b/hugr-model/src/v0/mod.rs index 9aba742e0..badd772f4 100644 --- a/hugr-model/src/v0/mod.rs +++ b/hugr-model/src/v0/mod.rs @@ -88,6 +88,7 @@ use pyo3::types::PyAnyMethods as _; #[cfg(feature = "pyo3")] use pyo3::PyTypeInfo as _; use smol_str::SmolStr; +use std::borrow::Borrow; use std::sync::Arc; use table::LinkIndex; @@ -375,6 +376,12 @@ impl AsRef for VarName { } } +impl Borrow for VarName { + fn borrow(&self) -> &str { + self.as_ref() + } +} + #[cfg(feature = "pyo3")] impl<'py> pyo3::FromPyObject<'py> for VarName { fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult { @@ -411,6 +418,12 @@ impl AsRef for SymbolName { } } +impl Borrow for SymbolName { + fn borrow(&self) -> &str { + self.as_ref() + } +} + #[cfg(feature = "pyo3")] impl<'py> pyo3::FromPyObject<'py> for SymbolName { fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult { @@ -442,6 +455,12 @@ impl AsRef for LinkName { } } +impl Borrow for LinkName { + fn borrow(&self) -> &str { + self.as_ref() + } +} + #[cfg(feature = "pyo3")] impl<'py> pyo3::FromPyObject<'py> for LinkName { fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult { diff --git a/hugr-model/src/v0/scope/symbol.rs b/hugr-model/src/v0/scope/symbol.rs index 20c99f763..e5492e77c 100644 --- a/hugr-model/src/v0/scope/symbol.rs +++ b/hugr-model/src/v0/scope/symbol.rs @@ -1,10 +1,14 @@ -use std::{borrow::Cow, hash::BuildHasherDefault}; - use fxhash::FxHasher; +use indexmap::Equivalent; use indexmap::IndexMap; +use std::hash::BuildHasherDefault; +use std::hash::Hash; use thiserror::Error; -use crate::v0::table::{NodeId, RegionId}; +use crate::v0::{ + table::{NodeId, RegionId}, + SymbolName, +}; type FxIndexMap = IndexMap>; @@ -39,17 +43,24 @@ type FxIndexMap = IndexMap>; /// assert!(symbols.is_visible(NodeId(0))); /// assert!(!symbols.is_visible(NodeId(1))); /// ``` -#[derive(Debug, Clone, Default)] -pub struct SymbolTable<'a> { - symbols: FxIndexMap<&'a str, BindingIndex>, +#[derive(Debug, Clone)] +pub struct SymbolTable { + symbols: FxIndexMap, bindings: FxIndexMap, scopes: FxIndexMap, } -impl<'a> SymbolTable<'a> { +impl SymbolTable +where + K: Clone + Eq + Hash, +{ /// Create a new symbol table. pub fn new() -> Self { - Self::default() + Self { + symbols: Default::default(), + bindings: Default::default(), + scopes: Default::default(), + } } /// Enter a new scope for the given region. @@ -92,23 +103,32 @@ impl<'a> SymbolTable<'a> { /// # Panics /// /// Panics if there is no current scope. - pub fn insert(&mut self, name: &'a str, node: NodeId) -> Result<(), DuplicateSymbolError> { + pub fn insert(&mut self, name: K, node: NodeId) -> Result<(), DuplicateSymbolError> { let scope_depth = self.scopes.len() as u16 - 1; - let (symbol_index, shadowed) = self.symbols.insert_full(name, self.bindings.len()); - if let Some(shadowed) = shadowed { - let (shadowed_node, shadowed_binding) = self.bindings.get_index(shadowed).unwrap(); - if shadowed_binding.scope_depth == scope_depth { - self.symbols.insert(name, shadowed); - return Err(DuplicateSymbolError(name.into(), node, *shadowed_node)); + let (symbol_index, shadows) = match self.symbols.entry(name) { + indexmap::map::Entry::Occupied(entry) => { + let (shadowed_node, shadowed_binding) = + self.bindings.get_index(*entry.get()).unwrap(); + + if shadowed_binding.scope_depth == scope_depth { + return Err(DuplicateSymbolError(node, *shadowed_node)); + } + + (entry.index(), Some(*entry.get())) } - } + indexmap::map::Entry::Vacant(entry) => { + let index = entry.index(); + entry.insert(self.bindings.len()); + (index, None) + } + }; self.bindings.insert( node, Binding { scope_depth, - shadows: shadowed, + shadows, symbol_index, }, ); @@ -116,6 +136,13 @@ impl<'a> SymbolTable<'a> { Ok(()) } + /// Get the name of the symbol defined by the given node. + pub fn symbol_name(&self, node: NodeId) -> Option<&K> { + let binding = self.bindings.get(&node)?; + let (name, _) = self.symbols.get_index(binding.symbol_index)?; + Some(name) + } + /// Check whether a symbol is currently visible in the current scope. pub fn is_visible(&self, node: NodeId) -> bool { let Some(binding) = self.bindings.get(&node) else { @@ -127,11 +154,11 @@ impl<'a> SymbolTable<'a> { } /// Tries to resolve a symbol name in the current scope. - pub fn resolve(&self, name: &'a str) -> Result { - let index = *self - .symbols - .get(name) - .ok_or(UnknownSymbolError(name.into()))?; + pub fn resolve(&self, name: &Q) -> Result + where + Q: ?Sized + Hash + Equivalent, + { + let index = *self.symbols.get(name).ok_or(UnknownSymbolError)?; // NOTE: The unwrap is safe because the `symbols` map // points to valid indices in the `bindings` map. @@ -159,6 +186,15 @@ impl<'a> SymbolTable<'a> { } } +impl Default for SymbolTable +where + K: Clone + Eq + Hash, +{ + fn default() -> Self { + Self::new() + } +} + #[derive(Debug, Clone, Copy)] struct Binding { /// The depth of the scope in which this binding is defined. @@ -189,10 +225,10 @@ pub type ScopeDepth = u16; /// Error that occurs when trying to resolve an unknown symbol. #[derive(Debug, Clone, Error)] -#[error("symbol name `{0}` not found in this scope")] -pub struct UnknownSymbolError<'a>(pub Cow<'a, str>); +#[error("symbol name not found in this scope")] +pub struct UnknownSymbolError; /// Error that occurs when trying to introduce a symbol that is already defined in the current scope. #[derive(Debug, Clone, Error)] -#[error("symbol `{0}` is already defined in this scope")] -pub struct DuplicateSymbolError<'a>(pub Cow<'a, str>, pub NodeId, pub NodeId); +#[error("symbol is already defined in this scope")] +pub struct DuplicateSymbolError(pub NodeId, pub NodeId); diff --git a/hugr-model/src/v0/table/mod.rs b/hugr-model/src/v0/table/mod.rs index 756a52c1e..f49e63f44 100644 --- a/hugr-model/src/v0/table/mod.rs +++ b/hugr-model/src/v0/table/mod.rs @@ -29,7 +29,7 @@ use smol_str::SmolStr; use thiserror::Error; mod view; -use super::{ast, Literal, RegionKind}; +use super::{ast, Literal, RegionKind, SymbolName, VarName}; pub use view::View; /// A package consisting of a sequence of [`Module`]s. @@ -230,23 +230,20 @@ pub enum Operation<'a> { DeclareOperation(&'a Symbol<'a>), /// Import a symbol. - Import { - /// The name of the symbol to be imported. - name: &'a str, - }, + Import(SymbolName), } impl<'a> Operation<'a> { - /// Returns the symbol introduced by the operation, if any. - pub fn symbol(&self) -> Option<&'a str> { + /// Returns the name of the symbol introduced by the operation, if any. + pub fn symbol_name(&self) -> Option<&SymbolName> { match self { - Operation::DefineFunc(symbol) => Some(symbol.name), - Operation::DeclareFunc(symbol) => Some(symbol.name), - Operation::DefineAlias(symbol, _) => Some(symbol.name), - Operation::DeclareAlias(symbol) => Some(symbol.name), - Operation::DeclareConstructor(symbol) => Some(symbol.name), - Operation::DeclareOperation(symbol) => Some(symbol.name), - Operation::Import { name } => Some(name), + Operation::DefineFunc(symbol) => Some(&symbol.name), + Operation::DeclareFunc(symbol) => Some(&symbol.name), + Operation::DefineAlias(symbol, _) => Some(&symbol.name), + Operation::DeclareAlias(symbol) => Some(&symbol.name), + Operation::DeclareConstructor(symbol) => Some(&symbol.name), + Operation::DeclareOperation(symbol) => Some(&symbol.name), + Operation::Import(name) => Some(name), _ => None, } } @@ -289,12 +286,12 @@ pub struct RegionScope { /// See [`ast::Symbol`] for the AST representation. /// /// [`ast::Symbol`]: crate::v0::ast::Symbol -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Symbol<'a> { /// The name of the symbol. - pub name: &'a str, + pub name: SymbolName, /// The static parameters. - pub params: &'a [Param<'a>], + pub params: &'a [Param], /// The constraints on the static parameters. pub constraints: &'a [TermId], /// The signature of the symbol. @@ -374,10 +371,10 @@ pub enum SeqPart { /// See [`ast::Param`] for the AST representation. /// /// [`ast::Param`]: crate::v0::ast::Param -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct Param<'a> { +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct Param { /// The name of the parameter. - pub name: &'a str, + pub name: VarName, /// The type of the parameter. pub r#type: TermId, }