Skip to content

Commit

Permalink
fix cyclic functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jay3332 committed Mar 17, 2024
1 parent 2eb2336 commit 880e1d1
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 76 deletions.
9 changes: 5 additions & 4 deletions hir/src/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,10 @@ impl<'a> TypeChecker<'a> {
Expr::CallOp(op, target, args) => {
Self::lower_op(*op, (**target).clone(), args.first().cloned(), expr, ty);
}
Expr::CallFunc { func, .. } => {
let func = self.thir_mut().funcs.get(func).unwrap();
*ty = func.header.ret_ty.clone();
}
_ => (),
}
// Unify the new type with the old type
Expand Down Expand Up @@ -417,6 +421,7 @@ impl<'a> TypeChecker<'a> {
}
_ => (),
}
// debug substitutions
typed_expr.value_mut().1.apply(&table.substitutions);
}

Expand Down Expand Up @@ -484,9 +489,5 @@ impl<'a> TypeChecker<'a> {
pub fn check_module(&mut self, module_id: ModuleId, table: &mut UnificationTable) {
let scope_id = *self.thir().modules.get(&module_id).unwrap();
self.substitute_scope(module_id, scope_id, table);

for (i, subst) in table.substitutions.iter().enumerate() {
println!("${} => {}", i, subst);
}
}
}
51 changes: 29 additions & 22 deletions hir/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -929,11 +929,8 @@ impl TypeLowerer {
}

/// Runs type inference over a function.
///
/// TODO: run type inference over parameters
pub fn lower_func(&mut self, func: Func) -> Result<Func<InferMetadata>> {
let header = func.header;
let mut header = FuncHeader::<InferMetadata> {
pub fn lower_func_header(&mut self, header: FuncHeader) -> Result<FuncHeader<InferMetadata>> {
Ok(FuncHeader {
name: header.name,
ty_params: header
.ty_params
Expand Down Expand Up @@ -965,10 +962,12 @@ impl TypeLowerer {
.collect::<Result<Vec<_>>>()?,
ret_ty: self.lower_hir_ty(header.ret_ty),
ret_ty_span: header.ret_ty_span,
};
})
}

pub fn lower_func_scope(&mut self, func: &Func<InferMetadata>) -> Result<Ty> {
let mut bindings = Vec::new();
for param in &header.params {
for param in &func.header.params {
if let Err(why) = flatten_param(&param.pat, param.ty.clone(), &mut bindings) {
self.err_nonfatal(why);
}
Expand All @@ -977,17 +976,11 @@ impl TypeLowerer {
func.body,
ScopeKind::Func,
bindings,
header.ty_params.clone(),
func.header.ty_params.clone(),
true,
Some((header.ret_ty.clone(), header.ret_ty_span)),
Some((func.header.ret_ty.clone(), func.header.ret_ty_span)),
)?;
header.ret_ty = self.resolution_lookup[&func.body].0.clone();

Ok(Func {
vis: func.vis,
header,
body: func.body,
})
Ok(self.resolution_lookup[&func.body].0.clone())
}

/// Runs type inference over a scope.
Expand Down Expand Up @@ -1017,17 +1010,31 @@ impl TypeLowerer {
resolution,
);

let mut items = HashMap::new();
let mut lowering = Vec::with_capacity(scope.items.len());
let mut items = HashMap::with_capacity(scope.items.len());
for (name, lookup @ Lookup(_, id)) in scope.items.extract_if(|_, l| l.0 == ItemKind::Func) {
let func = self.hir.funcs.remove(&id).expect("func not found");
let func = self.lower_func(func)?;
let header = self.lower_func_header(func.header)?;
// register the function in the scope
self.scope_mut()
.funcs
.insert(name, (id, func.header.clone()));
self.thir.funcs.insert(id, func);
self.scope_mut().funcs.insert(name, (id, header.clone()));
let func = Func {
vis: func.vis,
header,
body: func.body,
};
self.thir.funcs.insert(id, func.clone());
lowering.push((id, func));
items.insert(name, lookup);
}
for (id, func) in lowering {
let ty = self.lower_func_scope(&func)?;
let old = &mut self.thir.funcs.get_mut(&id).unwrap().header.ret_ty;
self.table
.constraints
.push_back(Constraint(old.clone(), ty.clone()));
let _ = self.table.unify_all();
*old = ty;
}

let mut exit_action = None;
let full_span = scope.children.span();
Expand Down
1 change: 1 addition & 0 deletions hir/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#![feature(hash_extract_if)]
#![feature(let_chains)]
#![feature(more_qualified_paths)]
#![feature(map_try_insert)]

pub mod check;
pub mod error;
Expand Down
4 changes: 2 additions & 2 deletions hir/src/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,8 @@ impl AstLowerer {
let sty = self.lower_struct_def_into_ty(module, sct.clone(), scope)?;

// Update type parameters with their bounds
if let Some(ty_def) = self.hir.types.get_mut(&scope.lookup_id_or_panic(item_id))
{
let ty_def = self.hir.types.get_mut(&scope.lookup_id_or_panic(item_id));
if let Some(ty_def) = ty_def {
ty_def.ty_params = sty.ty_params.clone();
}
self.propagate_nonfatal(self.assert_item_unique(scope, &item_id, sct_name));
Expand Down
29 changes: 19 additions & 10 deletions hir/src/typed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -541,44 +541,53 @@ impl Ty {
}
}

pub fn apply(&mut self, substitutions: &VecDeque<Ty>) {
// check cycles with Vec over HashSet despite .contains checks since cycles are usually very
// small and the overhead of hashing is not worth it
#[inline]
fn apply_cyclic(&mut self, substitutions: &VecDeque<Ty>, cycle: &mut Vec<usize>) {
match self {
Self::Unknown(i) => {
cycle.push(*i);
match substitutions.get(*i) {
// If this substitution is a reference to another unknown, we need to apply that
// substitution as well.
Some(Self::Unknown(j)) if *j != *i => {
// substitution as well
Some(Self::Unknown(j)) if !cycle.contains(j) => {
*self = Self::Unknown(*j);
self.apply(substitutions);
self.apply_cyclic(substitutions, cycle);
}
Some(ty) => *self = ty.clone(),
None => (),
}
}
Self::Tuple(tys) => {
for ty in tys {
ty.apply(substitutions);
ty.apply_cyclic(substitutions, cycle);
}
}
Self::Array(ty, len) => {
ty.apply(substitutions);
len.as_mut().map(|len| len.apply(substitutions));
ty.apply_cyclic(substitutions, cycle);
len.as_mut()
.map(|len| len.apply_cyclic(substitutions, cycle));
}
Self::Struct(_, tys) => {
for ty in tys {
ty.apply(substitutions);
ty.apply_cyclic(substitutions, cycle);
}
}
Self::Func(params, ret) => {
for ty in params {
ty.apply(substitutions);
ty.apply_cyclic(substitutions, cycle);
}
ret.apply(substitutions);
ret.apply_cyclic(substitutions, cycle);
}
_ => {}
}
}

pub fn apply(&mut self, substitutions: &VecDeque<Ty>) {
self.apply_cyclic(substitutions, &mut Vec::new());
}

pub fn has_any_unknown(&self) -> bool {
match self {
Self::Unknown(_) => true,
Expand Down
88 changes: 50 additions & 38 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,13 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {

match nodes {
Ok(nodes) => {
println!("{nodes:#?}");
for node in &nodes {
println!("{node}");
// for line in node.to_string().lines() {
// writeln!(file, "// {line}")?;
// }
}
// println!("{nodes:#?}");
// for node in &nodes {
// println!("{node}");
// // for line in node.to_string().lines() {
// // writeln!(file, "// {line}")?;
// // }
// }

let mut lowerer = AstLowerer::new(nodes);

Expand All @@ -60,22 +60,24 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
) {
Ok(_) => {
full += start.elapsed();
println!(
"=== [ HIR ({:?} to lower) ] ===\n\n{}",
start.elapsed(),
lowerer.hir
);
// println!(
// "=== [ HIR ({:?} to lower) ] ===\n\n{}",
// start.elapsed(),
// lowerer.hir
// );
println!("hir: {:?}", start.elapsed());

let start = std::time::Instant::now();
let mut ty_lowerer = TypeLowerer::new(lowerer.hir.clone());
match ty_lowerer.lower_module(ModuleId::from(Src::None)) {
Ok(_) => {
full += start.elapsed();
println!(
"=== [ THIR ({:?} to type) ] ===\n\n{}",
start.elapsed(),
ty_lowerer.thir
);
// println!(
// "=== [ THIR ({:?} to type) ] ===\n\n{}",
// start.elapsed(),
// ty_lowerer.thir
// );
println!("thir: {:?}", start.elapsed());

let start = std::time::Instant::now();
let mut typeck = TypeChecker::from_lowerer(&mut ty_lowerer);
Expand All @@ -84,11 +86,12 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
typeck.check_module(ModuleId::from(Src::None), &mut table);

full += start.elapsed();
println!(
"=== [ THIR ({:?} to check) ] ===\n\n{}",
start.elapsed(),
typeck.lower.thir
);
// println!(
// "=== [ THIR ({:?} to check) ] ===\n\n{}",
// start.elapsed(),
// typeck.lower.thir
// );
println!("typeck: {:?}", start.elapsed());
for error in typeck.lower.errors.drain(..) {
dwriter.write_diagnostic(
&mut std::io::stdout(),
Expand All @@ -102,11 +105,12 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
mir_lowerer.lower_module(ModuleId::from(Src::None));

full += start.elapsed();
println!(
"=== [ MIR ({:?} to lower) ] ===\n\n{}",
start.elapsed(),
mir_lowerer.mir
);
// println!(
// "=== [ MIR ({:?} to lower) ] ===\n\n{}",
// start.elapsed(),
// mir_lowerer.mir
// );
println!("mir: {:?}", start.elapsed());
for error in mir_lowerer.errors.drain(..) {
dwriter.write_diagnostic(
&mut std::io::stdout(),
Expand All @@ -120,8 +124,16 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let module = compile_llvm(&ctx, mir_lowerer.mir.functions);

full += start.elapsed();
println!("=== [ LLVM IR ({:?} to compile) ] ===", start.elapsed());
println!("{}", module.to_string());
// println!("=== [ LLVM IR ({:?} to compile) ] ===", start.elapsed());
// println!("{}", module.to_string());
println!("llvm: {:?}", start.elapsed());
println!("total cmptime: {full:?}");
type F = unsafe extern "C" fn() -> i32;
let engine = module
.create_jit_execution_engine(OptimizationLevel::Aggressive)?;
let f = unsafe { engine.get_function::<F>("test")? };
println!("evaluating test()...");
println!("-> {}", unsafe { f.call() });

module.write_bitcode_to_path(&*PathBuf::from("out.bc"));

Expand All @@ -147,15 +159,15 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let assembly = buffer.as_slice().to_vec();

full += start.elapsed();
println!(
"=== {} ASM ({:?} to compile, {:?} total) ===",
TargetMachine::get_default_triple()
.as_str()
.to_string_lossy(),
start.elapsed(),
full
);
println!("{}", String::from_utf8_lossy(&assembly));
// println!(
// "=== {} ASM ({:?} to compile, {:?} total) ===",
// TargetMachine::get_default_triple()
// .as_str()
// .to_string_lossy(),
// start.elapsed(),
// full
// );
// println!("{}", String::from_utf8_lossy(&assembly));
machine.write_to_file(
&module,
FileType::Object,
Expand Down

0 comments on commit 880e1d1

Please sign in to comment.