Skip to content

Commit

Permalink
Add tail call elimination pass
Browse files Browse the repository at this point in the history
  • Loading branch information
nameoverflow committed Sep 14, 2017
1 parent d510b21 commit b78f4c7
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 33 deletions.
11 changes: 9 additions & 2 deletions libllvm/src/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ impl LLVMFunctionPassManager {
transforms::scalar::LLVMAddMergedLoadStoreMotionPass(llfpm);
transforms::scalar::LLVMAddConstantPropagationPass(llfpm);
transforms::scalar::LLVMAddPromoteMemoryToRegisterPass(llfpm);
transforms::scalar::LLVMAddTailCallEliminationPass(llfpm);
LLVMInitializeFunctionPassManager(llfpm);
LLVMFunctionPassManager(llfpm)
}
Expand All @@ -207,6 +208,12 @@ impl LLVMType {
pub fn get_element(&self) -> Self {
unsafe { LLVMType(LLVMGetElementType(self.0.clone())) }
}

pub fn get_null_ptr(&self) -> LLVMValue {
unsafe {
LLVMValue::from_ref(LLVMConstNull(self.raw_ptr()))
}
}
}

impl LLVMValue {
Expand Down Expand Up @@ -247,8 +254,8 @@ impl LLVMFunction {
unsafe { LLVMBasicBlock::from_ref(LLVMGetEntryBasicBlock(self.raw_ptr())) }
}

pub fn verify(&self, action: LLVMVerifierFailureAction) {
unsafe { LLVMVerifyFunction(self.raw_ptr(), action) };
pub fn verify(&self, action: LLVMVerifierFailureAction) -> bool {
unsafe { LLVMVerifyFunction(self.raw_ptr(), action) == 0 }
}
}

Expand Down
44 changes: 21 additions & 23 deletions src/codegen/emit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,12 @@ impl<'i> LLVMEmit<'i> {
let fv_tys = fv_allocas.iter().map(|v| v.get_type().get_element()).collect();
let fv_ty_actual = self.context().get_struct_type(&fv_tys, false);
let fv_ptr_actual = self.builder().bit_cast(&p_fvs, &fv_ty_actual.get_ptr(0), "fv");
p_fvs.dump();
fv_ptr_actual.dump();

for (i, fv) in fv_allocas.into_iter().enumerate() {
let fv_id = formal_fvs[i].name();
let fv_name = self.interner.trace(fv_id);
let fv_val_ptr = self.builder().struct_field_ptr(&fv_ptr_actual, i, "tmp");
let fv_ty_actual = fv.get_type();
// println!("22222222222222FUCK::::::::::::::::::");
// let fv_val_ptr_cast = self.builder().bit_cast(&fv_val_ptr, &fv_ty_actual.get_ptr(0), "cast");
let fv_val = self.builder().load(&fv_val_ptr, fv_name);
self.builder().store(&fv_val, &fv);
symtbl.insert(fv_id, fv);
Expand All @@ -132,21 +129,27 @@ impl<'i> LLVMEmit<'i> {

self.builder().ret(&fun_body);

fun.verify(LLVMVerifierFailureAction::LLVMPrintMessageAction);
if self.funpass { self.generator.passer.run(&fun); }
if !fun.verify(LLVMVerifierFailureAction::LLVMPrintMessageAction) {
panic!();
} else {
if self.funpass { self.generator.passer.run(&fun); }
}
}

pub fn gen_main(&mut self, def: &FunDef, prelude: &VarEnv) {
let main_ty = self.generator.get_main_type();
let fun = self.module().add_function("main", &main_ty);

let block = self.context().append_basic_block(&fun, "entry");
self.builder().set_position_at_end(&block);
let zero = self.context().get_int32_const(0);

let mut symtbl = prelude.sub_env();
let res = self.gen_expr(def.body(), &mut symtbl);
let alloca = self.builder().alloca(&res.get_type(), "res");
self.builder().store(&res, &alloca);
self.gen_expr(def.body(), &mut symtbl);

let zero = self.context().get_int32_const(0);
self.builder().ret(&zero);


fun.verify(LLVMVerifierFailureAction::LLVMPrintMessageAction);
if self.funpass { self.generator.passer.run(&fun); }
}
Expand All @@ -163,7 +166,7 @@ impl<'i> LLVMEmit<'i> {
match symbols.lookup(&vn) {
Some(v) => self.builder().load(v, var_name),
_ => {
println!("cannot find variable {}", self.interner.trace(vn));
eprintln!("cannot find variable {}", self.interner.trace(vn));
unreachable!()
}
}
Expand Down Expand Up @@ -216,8 +219,8 @@ impl<'i> LLVMEmit<'i> {
self.builder().call(&fun, &mut argsv, "call")
}
ApplyDir(VarDecl(fun, ref fun_ty), ref args) => {
let empty_fv_ty = self.context().get_int8_type();
let empty_fv_ptr = self.builder().alloca(&empty_fv_ty, "fv");
let empty_fv_ty = self.context().get_int8_type().get_ptr(0);
let empty_fv_ptr = empty_fv_ty.get_null_ptr();
let mut argsv: Vec<_> =
args.iter().map(|arg| self.gen_expr(arg, symbols)).collect();
argsv.push(empty_fv_ptr);
Expand Down Expand Up @@ -271,7 +274,6 @@ impl<'i> LLVMEmit<'i> {
self.builder().store(&cls_cast, &cls_ptr);


cls_value.dump();
// set function entry
let cls_fun = self.builder().struct_field_ptr(&cls_value, 0, "cls.fn");
let fn_ent = {
Expand Down Expand Up @@ -306,29 +308,25 @@ impl<'i> LLVMEmit<'i> {

let blk = self.builder().get_insert_block();
let parent = blk.get_parent();
let then_blk = self.context().append_basic_block(&parent, "if.then");
let else_blk = self.context().append_basic_block(&parent, "if.else");
let cont_blk = self.context().append_basic_block(&parent, "if.cont");
let mut then_blk = self.context().append_basic_block(&parent, "if.then");
let mut else_blk = self.context().append_basic_block(&parent, "if.else");
let mut cont_blk = self.context().append_basic_block(&parent, "if.cont");

self.builder().cond_br(&cond, &then_blk, &else_blk);
self.builder().set_position_at_end(&then_blk);

self.builder().set_position_at_end(&then_blk);
let then = self.gen_expr(t, symbols);
self.builder().br(&cont_blk);

let then_end = self.builder().get_insert_block();

self.builder().set_position_at_end(&else_blk);

let els = self.gen_expr(f, symbols);
self.builder().br(&cont_blk);

let els_end = self.builder().get_insert_block();

self.builder().set_position_at_end(&cont_blk);

let ret_ty = then.get_type();
self.builder().phi_node(&ret_ty, &[(&then, &then_blk), (&els, &else_blk)], "if.res")
self.builder().phi_node(&ret_ty, &[(&then, &then_end), (&els, &els_end)], "if.res")
}
List(_) => unimplemented!(),
Unary(_, _) => unimplemented!(),
Expand Down
1 change: 0 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ fn compile(name: &str, src: &str) -> Result<LLVMCodegen, CompileError> {
emitter.gen_main(mf.deref(), &env);
}
for def in top.values() {
println!("{:#?}", def);
emitter.gen_top_level(def.deref(), &env);
}
Ok(emitter.generator)
Expand Down
19 changes: 12 additions & 7 deletions test/fibonacci.gs
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
def fibonacci(n) =
if (n == 1 || n == 2)
1
else
fibonacci(n - 1) + fibonacci(n - 2)
def fib(n) =
let fib_tail = (n, p, c) ->
if (n == 0)
-1
else
if (n == 1)
p
else
fib_tail(n - 1, c, p + c)
in fib_tail(n, 0, 1)

def main() =
let fib = fibonacci(10) in
putNumber(fib)
let f = fib(10) in
putNumber(f)
Empty file added test/fizzbuzz.gs
Empty file.

0 comments on commit b78f4c7

Please sign in to comment.