Skip to content

Commit

Permalink
rework how lookup ids are stored
Browse files Browse the repository at this point in the history
  • Loading branch information
jay3332 committed Mar 17, 2024
1 parent 880e1d1 commit bfa326e
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 77 deletions.
22 changes: 13 additions & 9 deletions hir/src/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
BinaryIntIntrinsic, BoolIntrinsic, Constraint, Expr, IntIntrinsic, LocalEnv, Relation, Ty,
TypedExpr, UnaryIntIntrinsic, UnificationTable,
},
Hir, IntSign, Lookup, ModuleId, Node, Op, Pattern, PrimitiveTy, Scope, ScopeId,
Hir, IntSign, ItemKind, ModuleId, Node, Op, Pattern, PrimitiveTy, Scope, ScopeId,
};
use common::span::{Spanned, SpannedExt};

Expand Down Expand Up @@ -334,7 +334,7 @@ impl<'a> TypeChecker<'a> {
}
Expr::CallFunc { func, .. } => {
let func = self.thir_mut().funcs.get(func).unwrap();
*ty = func.header.ret_ty.clone();
table.unify_constraint(Constraint(ty.clone(), func.header.ret_ty.clone()));
}
_ => (),
}
Expand Down Expand Up @@ -421,7 +421,6 @@ impl<'a> TypeChecker<'a> {
}
_ => (),
}
// debug substitutions
typed_expr.value_mut().1.apply(&table.substitutions);
}

Expand Down Expand Up @@ -451,13 +450,18 @@ impl<'a> TypeChecker<'a> {
.remove(&scope_id)
.expect("scope not found");

// Substitute over all functions in the scope
for (_, &Lookup(_, id)) in &scope.items {
let scope = self.thir_mut().funcs[&id].body;
self.substitute_scope(module, scope, table);
// Substitute over all items in the scope
for ((kind, _), id) in &scope.items {
match kind {
ItemKind::Func => {
let scope = self.thir_mut().funcs[&id].body;
self.substitute_scope(module, scope, table);

let func = self.thir_mut().funcs.get_mut(&id).unwrap();
func.header.ret_ty.apply(&table.substitutions);
let func = self.thir_mut().funcs.get_mut(&id).unwrap();
func.header.ret_ty.apply(&table.substitutions);
}
_ => (),
}
}

// Substitute over the scope
Expand Down
11 changes: 5 additions & 6 deletions hir/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ use crate::{
},
warning::Warning,
Expr, FloatWidth, Func, FuncHeader, FuncParam, Hir, Ident, IntSign, IntWidth, ItemId, ItemKind,
Literal, LogicalOp, Lookup, LookupId, Metadata, ModuleId, Node, Pattern, PrimitiveTy, ScopeId,
TyParam,
Literal, LogicalOp, LookupId, Metadata, ModuleId, Node, Pattern, PrimitiveTy, ScopeId, TyParam,
};
use common::span::{Span, Spanned, SpannedExt};
use std::{borrow::Cow, collections::HashMap};
Expand Down Expand Up @@ -351,8 +350,8 @@ impl TypeLowerer {
if let Some(cnst) = self.hir.scopes.get(&scope.id).and_then(|scope| {
scope
.items
.get(&item)
.and_then(|Lookup(_, id)| self.hir.consts.get(id))
.get(&(ItemKind::Const, item))
.and_then(|id| self.hir.consts.get(id))
}) {
return Ok(Binding {
def_span: cnst.name.span(),
Expand Down Expand Up @@ -1012,7 +1011,7 @@ impl TypeLowerer {

let mut lowering = Vec::with_capacity(scope.items.len());
let mut items = HashMap::with_capacity(scope.items.len());
for (name, lookup @ Lookup(_, id)) in scope.items.extract_if(|_, l| l.0 == ItemKind::Func) {
for ((kind, name), id) in scope.items.extract_if(|k, _| k.0 == ItemKind::Func) {
let func = self.hir.funcs.remove(&id).expect("func not found");
let header = self.lower_func_header(func.header)?;
// register the function in the scope
Expand All @@ -1024,7 +1023,7 @@ impl TypeLowerer {
};
self.thir.funcs.insert(id, func.clone());
lowering.push((id, func));
items.insert(name, lookup);
items.insert((kind, name), id);
}
for (id, func) in lowering {
let ty = self.lower_func_scope(&func)?;
Expand Down
22 changes: 8 additions & 14 deletions hir/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ pub enum Node<M: Metadata = LowerMetadata> {
ImplicitReturn(Spanned<M::Expr>),
}

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum ItemKind {
Func,
Alias,
Expand All @@ -269,9 +269,6 @@ pub enum ItemKind {
Type,
}

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct Lookup(pub ItemKind, pub LookupId);

#[derive(Clone, Debug)]
pub struct Scope<M: Metadata = LowerMetadata> {
/// The module in which this scope is defined.
Expand All @@ -283,7 +280,7 @@ pub struct Scope<M: Metadata = LowerMetadata> {
/// The children of this scope.
pub children: Spanned<Vec<Spanned<Node<M>>>>,
/// A lookup of all items in the scope.
pub items: HashMap<ItemId, Lookup>,
pub items: HashMap<(ItemKind, ItemId), LookupId>,
}

impl<M: Metadata> Scope<M> {
Expand All @@ -306,16 +303,13 @@ impl<M: Metadata> Scope<M> {
}
}

#[inline]
#[must_use]
pub fn lookup_id(&self, id: ItemId) -> Option<LookupId> {
self.items.get(&id).map(|lookup| lookup.1)
pub(crate) fn get_lookup(&self, kind: ItemKind, id: ItemId) -> Option<LookupId> {
self.items.get(&(kind, id)).copied()
}

#[inline]
#[must_use]
pub(crate) fn lookup_id_or_panic(&self, id: ItemId) -> LookupId {
self.items[&id].1
pub(crate) fn get_lookup_or_panic(&self, kind: ItemKind, id: ItemId) -> LookupId {
self.get_lookup(kind, id)
.expect(&format!("item {id} not found in scope"))
}
}

Expand Down Expand Up @@ -981,7 +975,7 @@ where
format!("@!{decorator}").write_indent(f)?;
}

for (item_id, Lookup(kind, lookup)) in &scope.items {
for ((kind, item_id), lookup) in &scope.items {
writeln!(f, "{item_id}:")?;
match kind {
ItemKind::Func => WithHir(&self.funcs[&lookup], self).write_indent(f)?,
Expand Down
63 changes: 33 additions & 30 deletions hir/src/lower.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::{
error::{Error, Result},
Const, Decorator, Expr, FieldVisibility, FloatWidth, Func, FuncHeader, FuncParam, Hir, Ident,
IntSign, IntWidth, ItemId, ItemKind, Literal, LogicalOp, Lookup, LookupId, ModuleId, Node, Op,
Pattern, PrimitiveTy, Scope, ScopeId, StructField, StructTy, Ty, TyDef, TyParam,
IntSign, IntWidth, ItemId, ItemKind, Literal, LogicalOp, LookupId, ModuleId, Node, Op, Pattern,
PrimitiveTy, Scope, ScopeId, StructField, StructTy, Ty, TyDef, TyParam,
};
use common::span::{Span, Spanned, SpannedExt};
use grammar::{
Expand Down Expand Up @@ -87,10 +87,10 @@ fn ty_params_into_unbounded_ty_param(ty_params: &[ast::TyParam]) -> Vec<TyParam>
}

macro_rules! insert_lookup {
($self:ident, $target:ident, $kind:ident, $e:expr) => {{
($self:ident, $target:ident, $e:expr) => {{
let id = $self.next_lookup_id();
$self.hir.$target.insert(id, $e);
Lookup(ItemKind::$kind, id)
id
}};
}

Expand Down Expand Up @@ -131,13 +131,13 @@ impl AstLowerer {
}

#[inline]
fn get_item_name(&self, Lookup(kind, lookup): &Lookup) -> Option<Spanned<Ident>> {
fn get_item_name(&self, kind: ItemKind, id: LookupId) -> Option<Spanned<Ident>> {
match kind {
ItemKind::Func => self.hir.funcs.get(lookup).map(|func| func.header.name),
ItemKind::Alias => self.hir.aliases.get(lookup).map(|alias| alias.name),
ItemKind::Const => self.hir.consts.get(lookup).map(|cnst| cnst.name),
ItemKind::Struct => self.hir.structs.get(lookup).map(|sty| sty.name),
ItemKind::Type => self.hir.types.get(lookup).map(|ty| ty.name),
ItemKind::Func => self.hir.funcs.get(&id).map(|func| func.header.name),
ItemKind::Alias => self.hir.aliases.get(&id).map(|alias| alias.name),
ItemKind::Const => self.hir.consts.get(&id).map(|cnst| cnst.name),
ItemKind::Struct => self.hir.structs.get(&id).map(|sty| sty.name),
ItemKind::Type => self.hir.types.get(&id).map(|ty| ty.name),
}
}

Expand All @@ -151,8 +151,9 @@ impl AstLowerer {
) -> Result<()> {
if let Some(occupied) = scope
.items
.get(item)
.and_then(|item| self.get_item_name(item))
.iter()
.find_map(|((kind, name), id)| (item == name).then_some((kind, id)))
.and_then(|(kind, id)| self.get_item_name(*kind, *id))
{
return Err(Error::NameConflict(occupied.span(), src));
}
Expand Down Expand Up @@ -208,9 +209,10 @@ impl AstLowerer {
// Do a pass over all types to identify them
for node in &nodes {
if let Some((item_id, ty_def)) = self.pass_over_ty_def(module, node)? {
scope
.items
.insert(item_id, insert_lookup!(self, types, Type, ty_def));
scope.items.insert(
(ItemKind::Type, item_id),
insert_lookup!(self, types, ty_def),
);
}
}

Expand All @@ -224,14 +226,18 @@ impl AstLowerer {
let sty = self.lower_struct_def_into_ty(module, sct.clone(), scope)?;

// Update type parameters with their bounds
let ty_def = self.hir.types.get_mut(&scope.lookup_id_or_panic(item_id));
let ty_def = self
.hir
.types
.get_mut(&scope.get_lookup_or_panic(ItemKind::Type, item_id));
if let Some(ty_def) = ty_def {
ty_def.ty_params = sty.ty_params.clone();
}
self.propagate_nonfatal(self.assert_item_unique(scope, &item_id, sct_name));
scope
.items
.insert(item_id, insert_lookup!(self, structs, Struct, sty));
scope.items.insert(
(ItemKind::Struct, item_id),
insert_lookup!(self, structs, sty),
);
}
_ => (),
}
Expand Down Expand Up @@ -273,7 +279,7 @@ impl AstLowerer {
let fields = self
.hir
.structs
.get(&scope.lookup_id_or_panic(pid))
.get(&scope.get_lookup_or_panic(ItemKind::Struct, pid))
.cloned()
.expect("struct not found, this is a bug")
.into_adhoc_struct_ty_with_applied_ty_params(Some(dest.span()), args)?
Expand All @@ -284,7 +290,7 @@ impl AstLowerer {
let sty = self
.hir
.structs
.get_mut(&scope.lookup_id_or_panic(*child))
.get_mut(&scope.get_lookup_or_panic(ItemKind::Struct, *child))
.expect("struct not found, this is a bug");

let mut fields = fields.clone();
Expand Down Expand Up @@ -327,7 +333,7 @@ impl AstLowerer {
self.propagate_nonfatal(self.assert_item_unique(scope, &item, name));
scope
.items
.insert(item, insert_lookup!(self, consts, Const, cnst));
.insert((ItemKind::Const, item), insert_lookup!(self, consts, cnst));
}
}
Ok(())
Expand Down Expand Up @@ -447,7 +453,7 @@ impl AstLowerer {
/// struct A<__0> { a: __0 }
/// ```
fn desugar_inferred_types_in_structs(&mut self, scope: &mut Scope) {
for Lookup(kind, id) in scope.items.values() {
for ((kind, _), id) in &scope.items {
if *kind != ItemKind::Struct {
continue;
}
Expand Down Expand Up @@ -543,15 +549,12 @@ impl AstLowerer {
};

let err = Error::TypeNotFound(full_span, Spanned(tail, span), mid);
let Lookup(kind, id) = ctx
let id = ctx
.scope
.items
.get(&ItemId(mid, ident))
.get(&(ItemKind::Type, ItemId(mid, ident)))
.ok_or(err.clone())?;

if *kind != ItemKind::Type {
return Err(err.clone());
}
let ty_def = self.hir.types.get(id).ok_or(err)?;

let ty_params = match application {
Expand Down Expand Up @@ -643,7 +646,7 @@ impl AstLowerer {
self.propagate_nonfatal(self.assert_item_unique(ctx.scope, &item, name));
scope
.items
.insert(item, insert_lookup!(self, funcs, Func, func));
.insert((ItemKind::Func, item), insert_lookup!(self, funcs, func));
}
}
Ok(())
Expand Down Expand Up @@ -1144,7 +1147,7 @@ impl AstLowerer {
) -> Result<Spanned<Expr>> {
let item_id = ItemId(ctx.module(), *ident.value());
Ok(
if let Some(Lookup(ItemKind::Const, id)) = ctx.scope.items.get(&item_id) {
if let Some(id) = ctx.scope.items.get(&(ItemKind::Const, item_id)) {
// TODO: true const-eval instead of inline (this will be replaced by `alias`)
if let Some(app) = app {
return Err(Error::ExplicitTypeArgumentsNotAllowed(app.span()));
Expand Down
5 changes: 2 additions & 3 deletions mir/src/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ use common::span::{Spanned, SpannedExt};
use hir::infer::flatten_param;
use hir::{
typed::{self, LocalEnv, Ty, TypedExpr},
Ident, IntSign, ItemId, ItemKind, Literal, Lookup, LookupId, ModuleId, Pattern, PrimitiveTy,
ScopeId,
Ident, IntSign, ItemId, ItemKind, Literal, LookupId, ModuleId, Pattern, PrimitiveTy, ScopeId,
};
use std::collections::HashMap;

Expand Down Expand Up @@ -326,7 +325,7 @@ impl Lowerer {

// First, lower all static items in the scope
let mut funcs = Vec::new();
for (item, Lookup(kind, id)) in scope.items {
for ((kind, item), id) in scope.items {
match kind {
ItemKind::Func => {
let func = self.thir.funcs.remove(&id).expect("no such func");
Expand Down
17 changes: 9 additions & 8 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,11 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
typeck.check_module(ModuleId::from(Src::None), &mut table);

full += start.elapsed();
// println!(
// "=== [ THIR ({:?} to check) ] ===\n\n{}",
// start.elapsed(),
// typeck.lower.thir
// );
println!(
"=== [ THIR ({:?} to check) ] ===\n\n{}",
start.elapsed(),
typeck.lower.thir
);
println!("typeck: {:?}", start.elapsed());
for error in typeck.lower.errors.drain(..) {
dwriter.write_diagnostic(
Expand Down Expand Up @@ -131,9 +131,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
type F = unsafe extern "C" fn() -> i32;
let engine = module
.create_jit_execution_engine(OptimizationLevel::Aggressive)?;
let f = unsafe { engine.get_function::<F>("test")? };
println!("evaluating test()...");
println!("-> {}", unsafe { f.call() });
if let Ok(f) = unsafe { engine.get_function::<F>("test") } {
println!("evaluating test()...");
println!("-> {}", unsafe { f.call() });
}

module.write_bitcode_to_path(&*PathBuf::from("out.bc"));

Expand Down
11 changes: 4 additions & 7 deletions test.trb
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
func a(x: int32, y: int32) -> int32 {
x + y
}

func b() {
a(6, 5) + 10
}
func double(x: int32) = 2 * x;
func triple(x: int32) = 3 * x;
func sum(a: int32, b: int32) = a + b;

func test() = double(2) + triple(3) + sum(2, 3) + 1;

0 comments on commit bfa326e

Please sign in to comment.