Skip to content

Commit

Permalink
lower function args
Browse files Browse the repository at this point in the history
  • Loading branch information
jay3332 committed Mar 16, 2024
1 parent a925e6b commit 2eb2336
Show file tree
Hide file tree
Showing 11 changed files with 431 additions and 253 deletions.
76 changes: 49 additions & 27 deletions codegen/src/aot.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use common::span::Spanned;
use inkwell::attributes::{Attribute, AttributeLoc};
use inkwell::{
attributes::Attribute,
basic_block::BasicBlock,
builder::Builder,
context::Context,
Expand All @@ -11,8 +11,8 @@ use inkwell::{
IntPredicate,
};
use mir::{
BlockId, Constant, Expr, Func, IntIntrinsic, IntSign, IntWidth, LocalEnv, LocalId, Node,
PrimitiveTy, Ty, UnaryIntIntrinsic,
BlockId, Constant, Expr, Func, Ident, IntIntrinsic, IntSign, IntWidth, LocalEnv, LocalId,
LookupId, Node, PrimitiveTy, Ty, UnaryIntIntrinsic,
};
use std::{collections::HashMap, mem::MaybeUninit, ops::Not};

Expand All @@ -28,11 +28,12 @@ pub struct Compiler<'a, 'ctx> {
pub builder: &'a Builder<'ctx>,
pub fpm: &'a PassManager<FunctionValue<'ctx>>,
pub module: &'a Module<'ctx>,
pub func: &'a Func,

lowering: MaybeUninit<Func>,
fn_value: MaybeUninit<FunctionValue<'ctx>>,
functions: HashMap<LookupId, FunctionValue<'ctx>>,
locals: HashMap<LocalId, Option<Local<'ctx>>>,
blocks: HashMap<BlockId, BasicBlock<'ctx>>,
fn_value: MaybeUninit<FunctionValue<'ctx>>,
increment: usize,
}

Expand All @@ -44,6 +45,11 @@ impl<'a, 'ctx> Compiler<'a, 'ctx> {
self.context.custom_width_int_type(width as usize as _)
}

#[inline]
fn lowering_mut(&mut self) -> &mut Func {
unsafe { self.lowering.assume_init_mut() }
}

#[inline]
const fn fn_value(&self) -> FunctionValue<'ctx> {
unsafe { self.fn_value.assume_init() }
Expand Down Expand Up @@ -178,19 +184,17 @@ impl<'a, 'ctx> Compiler<'a, 'ctx> {
BasicValueEnum::IntValue(bool_value)
}
Expr::Call(func, args) => {
let f = self.module.get_function(&func.to_string()).unwrap();
let args = args
.into_iter()
.map(|arg| self.lower_expr(arg).unwrap().into())
.collect::<Vec<_>>();

self.builder
.build_call(f, &args, &self.next_increment())
.build_call(self.functions[&func], &args, &self.next_increment())
.try_as_basic_value()
.left()
.unwrap()
}
_ => todo!(),
})
}

Expand Down Expand Up @@ -249,12 +253,12 @@ impl<'a, 'ctx> Compiler<'a, 'ctx> {

/// Lowers a block given its ID.
pub fn lower_block(&mut self, block_id: BlockId) {
let block = self.func.blocks.get(&block_id).unwrap();
let block = self.lowering_mut().blocks.remove(&block_id).unwrap();
self.builder
.position_at_end(*self.blocks.get(&block_id).unwrap());

for node in block {
self.lower_node(node.clone())
self.lower_node(node)
}
}

Expand Down Expand Up @@ -294,10 +298,9 @@ impl<'a, 'ctx> Compiler<'a, 'ctx> {
.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0)
}

/// Compiles the specified function into an LLVM `FunctionValue`.
fn compile_fn(&mut self) {
let (names, param_tys) = self
.func
/// Registers the specified function into an LLVM `FunctionValue`.
fn register_fn(&mut self, id: LookupId, func: &Func) -> Vec<Ident> {
let (names, param_tys) = func
.params
.iter()
.filter_map(|(name, ty)| {
Expand All @@ -307,18 +310,29 @@ impl<'a, 'ctx> Compiler<'a, 'ctx> {
})
.unzip::<_, _, Vec<_>, Vec<_>>();

let fn_ty = match self.func.ret_ty.is_zst() {
let fn_ty = match func.ret_ty.is_zst() {
true => self.context.void_type().fn_type(&param_tys, false),
false => self.lower_ty(&self.func.ret_ty).fn_type(&param_tys, false),
false => self.lower_ty(&func.ret_ty).fn_type(&param_tys, false),
};

// TODO: qualified name
let name = self.func.name.to_string();
self.fn_value
.write(self.module.add_function(&name, fn_ty, None));
let name = func.name.to_string();
let fn_value = self.module.add_function(&name, fn_ty, None);
self.functions.insert(id, fn_value);
names
}

/// Compiles the body of the given function.
fn compile_fn(&mut self, fn_value: FunctionValue<'ctx>, func: Func, names: Vec<Ident>) {
let block_ids = func.blocks.keys().copied().collect::<Vec<_>>();
self.lowering = MaybeUninit::new(func);
self.fn_value.write(fn_value);
self.locals.clear();
self.blocks.clear();
self.increment = 0;

// Create blocks
for id in self.func.blocks.keys() {
for id in &block_ids {
let bb = self
.context
.append_basic_block(self.fn_value(), &id.to_string());
Expand All @@ -343,8 +357,9 @@ impl<'a, 'ctx> Compiler<'a, 'ctx> {
}

// Compile body
self.func.blocks.keys().for_each(|&id| self.lower_block(id));
block_ids.into_iter().for_each(|id| self.lower_block(id));
self.fn_value().print_to_string();
unsafe { self.lowering.assume_init_drop() };

// Verify and run optimizations
if self.fn_value().verify(true) {
Expand All @@ -362,20 +377,27 @@ impl<'a, 'ctx> Compiler<'a, 'ctx> {
builder: &'a Builder<'ctx>,
pass_manager: &'a PassManager<FunctionValue<'ctx>>,
module: &'a Module<'ctx>,
func: &'a Func,
) -> FunctionValue<'ctx> {
functions: HashMap<LookupId, Func>,
) {
let mut compiler = Self {
context,
builder,
fpm: pass_manager,
module,
func,
functions: HashMap::with_capacity(functions.len()),
lowering: MaybeUninit::uninit(),
fn_value: MaybeUninit::uninit(),
locals: HashMap::new(),
blocks: HashMap::with_capacity(func.blocks.len()),
blocks: HashMap::new(),
increment: 0,
};
compiler.compile_fn();
compiler.fn_value()

let mut names = Vec::with_capacity(functions.len());
for (id, func) in &functions {
names.push(compiler.register_fn(*id, func));
}
for ((id, func), names) in functions.into_iter().zip(names) {
compiler.compile_fn(compiler.functions[&id], func, names);
}
}
}
19 changes: 5 additions & 14 deletions codegen/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,11 @@ pub use inkwell::{
};

use inkwell::passes::PassManager;
use mir::{Mir, ModuleId};
use mir::{Func, LookupId};
use std::collections::HashMap;

pub fn compile_llvm<'ctx>(
context: &'ctx Context,
mir: &Mir,
module_id: ModuleId, /*, options: CompileOptions*/
) -> Module<'ctx> {
let module = context.create_module(&module_id.to_string());
pub fn compile_llvm(context: &Context, functions: HashMap<LookupId, Func>) -> Module {
let module = context.create_module("root");
let builder = context.create_builder();

// Create FPM
Expand All @@ -33,12 +30,6 @@ pub fn compile_llvm<'ctx>(
fpm.add_reassociate_pass();
fpm.initialize();

for func in mir
.functions
.iter()
.filter_map(|(id, func)| id.0.eq(&module_id).then_some(func))
{
aot::Compiler::compile(&context, &builder, &fpm, &module, func);
}
aot::Compiler::compile(&context, &builder, &fpm, &module, functions);
module
}
9 changes: 6 additions & 3 deletions hir/src/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
BinaryIntIntrinsic, BoolIntrinsic, Constraint, Expr, IntIntrinsic, LocalEnv, Relation, Ty,
TypedExpr, UnaryIntIntrinsic, UnificationTable,
},
Hir, IntSign, ModuleId, Node, Op, Pattern, PrimitiveTy, Scope, ScopeId,
Hir, IntSign, Lookup, ModuleId, Node, Op, Pattern, PrimitiveTy, Scope, ScopeId,
};
use common::span::{Spanned, SpannedExt};

Expand Down Expand Up @@ -447,8 +447,11 @@ impl<'a> TypeChecker<'a> {
.expect("scope not found");

// Substitute over all functions in the scope
for (_, func) in &mut scope.funcs {
self.substitute_scope(module, func.body, table);
for (_, &Lookup(_, id)) in &scope.items {
let scope = self.thir_mut().funcs[&id].body;
self.substitute_scope(module, scope, table);

let func = self.thir_mut().funcs.get_mut(&id).unwrap();
func.header.ret_ty.apply(&table.substitutions);
}

Expand Down
Loading

0 comments on commit 2eb2336

Please sign in to comment.