1
0
forked from M-Labs/nac3

codegen for simple function call, and various fixes

This commit is contained in:
pca006132 2021-08-19 15:30:15 +08:00
parent f205a8282a
commit 3279f7a776
7 changed files with 132 additions and 37 deletions

View File

@ -88,7 +88,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
) -> Option<BasicValueEnum<'ctx>> { ) -> Option<BasicValueEnum<'ctx>> {
let key = self.get_subst_key(obj.map(|(a, _)| a), fun.0); let key = self.get_subst_key(obj.map(|(a, _)| a), fun.0);
let defs = self.top_level.definitions.read(); let defs = self.top_level.definitions.read();
let definition = defs.get(fun.1 .0).unwrap(); let definition = defs.get(fun.1.0).unwrap();
let val = if let TopLevelDef::Function { instance_to_symbol, .. } = &*definition.read() { let val = if let TopLevelDef::Function { instance_to_symbol, .. } = &*definition.read() {
let symbol = instance_to_symbol.get(&key).unwrap_or_else(|| { let symbol = instance_to_symbol.get(&key).unwrap_or_else(|| {
// TODO: codegen for function that are not yet generated // TODO: codegen for function that are not yet generated
@ -232,9 +232,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
} }
} }
pub fn gen_expr(&mut self, expr: &Expr<Option<Type>>) -> BasicValueEnum<'ctx> { pub fn gen_expr(&mut self, expr: &Expr<Option<Type>>) -> Option<BasicValueEnum<'ctx>> {
let zero = self.ctx.i32_type().const_int(0, false); let zero = self.ctx.i32_type().const_int(0, false);
match &expr.node { Some(match &expr.node {
ExprKind::Constant { value, .. } => { ExprKind::Constant { value, .. } => {
let ty = expr.custom.unwrap(); let ty = expr.custom.unwrap();
self.gen_const(value, ty) self.gen_const(value, ty)
@ -254,7 +254,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
ExprKind::List { elts, .. } => { ExprKind::List { elts, .. } => {
// this shall be optimized later for constant primitive lists... // this shall be optimized later for constant primitive lists...
// we should use memcpy for that instead of generating thousands of stores // we should use memcpy for that instead of generating thousands of stores
let elements = elts.iter().map(|x| self.gen_expr(x)).collect_vec(); let elements = elts.iter().map(|x| self.gen_expr(x).unwrap()).collect_vec();
let ty = if elements.is_empty() { let ty = if elements.is_empty() {
self.ctx.i32_type().into() self.ctx.i32_type().into()
} else { } else {
@ -293,7 +293,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
arr_str_ptr.into() arr_str_ptr.into()
} }
ExprKind::Tuple { elts, .. } => { ExprKind::Tuple { elts, .. } => {
let element_val = elts.iter().map(|x| self.gen_expr(x)).collect_vec(); let element_val = elts.iter().map(|x| self.gen_expr(x).unwrap()).collect_vec();
let element_ty = element_val.iter().map(BasicValueEnum::get_type).collect_vec(); let element_ty = element_val.iter().map(BasicValueEnum::get_type).collect_vec();
let tuple_ty = self.ctx.struct_type(&element_ty, false); let tuple_ty = self.ctx.struct_type(&element_ty, false);
let tuple_ptr = self.builder.build_alloca(tuple_ty, "tuple"); let tuple_ptr = self.builder.build_alloca(tuple_ty, "tuple");
@ -311,7 +311,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
ExprKind::Attribute { value, attr, .. } => { ExprKind::Attribute { value, attr, .. } => {
// note that we would handle class methods directly in calls // note that we would handle class methods directly in calls
let index = self.get_attr_index(value.custom.unwrap(), attr); let index = self.get_attr_index(value.custom.unwrap(), attr);
let val = self.gen_expr(value); let val = self.gen_expr(value).unwrap();
let ptr = if let BasicValueEnum::PointerValue(v) = val { let ptr = if let BasicValueEnum::PointerValue(v) = val {
v v
} else { } else {
@ -327,11 +327,12 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
} }
ExprKind::BoolOp { op, values } => { ExprKind::BoolOp { op, values } => {
// requires conditional branches for short-circuiting... // requires conditional branches for short-circuiting...
let left = if let BasicValueEnum::IntValue(left) = self.gen_expr(&values[0]) { let left =
left if let BasicValueEnum::IntValue(left) = self.gen_expr(&values[0]).unwrap() {
} else { left
unreachable!() } else {
}; unreachable!()
};
let current = self.builder.get_insert_block().unwrap().get_parent().unwrap(); let current = self.builder.get_insert_block().unwrap().get_parent().unwrap();
let a_bb = self.ctx.append_basic_block(current, "a"); let a_bb = self.ctx.append_basic_block(current, "a");
let b_bb = self.ctx.append_basic_block(current, "b"); let b_bb = self.ctx.append_basic_block(current, "b");
@ -343,7 +344,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
let a = self.ctx.bool_type().const_int(1, false); let a = self.ctx.bool_type().const_int(1, false);
self.builder.build_unconditional_branch(cont_bb); self.builder.build_unconditional_branch(cont_bb);
self.builder.position_at_end(b_bb); self.builder.position_at_end(b_bb);
let b = if let BasicValueEnum::IntValue(b) = self.gen_expr(&values[1]) { let b = if let BasicValueEnum::IntValue(b) =
self.gen_expr(&values[1]).unwrap()
{
b b
} else { } else {
unreachable!() unreachable!()
@ -353,7 +356,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
} }
Boolop::And => { Boolop::And => {
self.builder.position_at_end(a_bb); self.builder.position_at_end(a_bb);
let a = if let BasicValueEnum::IntValue(a) = self.gen_expr(&values[1]) { let a = if let BasicValueEnum::IntValue(a) =
self.gen_expr(&values[1]).unwrap()
{
a a
} else { } else {
unreachable!() unreachable!()
@ -373,8 +378,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
ExprKind::BinOp { op, left, right } => { ExprKind::BinOp { op, left, right } => {
let ty1 = self.unifier.get_representative(left.custom.unwrap()); let ty1 = self.unifier.get_representative(left.custom.unwrap());
let ty2 = self.unifier.get_representative(right.custom.unwrap()); let ty2 = self.unifier.get_representative(right.custom.unwrap());
let left = self.gen_expr(left); let left = self.gen_expr(left).unwrap();
let right = self.gen_expr(right); let right = self.gen_expr(right).unwrap();
// we can directly compare the types, because we've got their representatives // we can directly compare the types, because we've got their representatives
// which would be unchanged until further unification, which we would never do // which would be unchanged until further unification, which we would never do
@ -389,7 +394,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
} }
ExprKind::UnaryOp { op, operand } => { ExprKind::UnaryOp { op, operand } => {
let ty = self.unifier.get_representative(operand.custom.unwrap()); let ty = self.unifier.get_representative(operand.custom.unwrap());
let val = self.gen_expr(operand); let val = self.gen_expr(operand).unwrap();
if ty == self.primitives.bool { if ty == self.primitives.bool {
let val = let val =
if let BasicValueEnum::IntValue(val) = val { val } else { unreachable!() }; if let BasicValueEnum::IntValue(val) = val { val } else { unreachable!() };
@ -454,7 +459,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
let (lhs, rhs) = if let ( let (lhs, rhs) = if let (
BasicValueEnum::IntValue(lhs), BasicValueEnum::IntValue(lhs),
BasicValueEnum::IntValue(rhs), BasicValueEnum::IntValue(rhs),
) = (self.gen_expr(lhs), self.gen_expr(rhs)) ) =
(self.gen_expr(lhs).unwrap(), self.gen_expr(rhs).unwrap())
{ {
(lhs, rhs) (lhs, rhs)
} else { } else {
@ -474,7 +480,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
let (lhs, rhs) = if let ( let (lhs, rhs) = if let (
BasicValueEnum::FloatValue(lhs), BasicValueEnum::FloatValue(lhs),
BasicValueEnum::FloatValue(rhs), BasicValueEnum::FloatValue(rhs),
) = (self.gen_expr(lhs), self.gen_expr(rhs)) ) =
(self.gen_expr(lhs).unwrap(), self.gen_expr(rhs).unwrap())
{ {
(lhs, rhs) (lhs, rhs)
} else { } else {
@ -499,7 +506,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
.into() // as there should be at least 1 element, it should never be none .into() // as there should be at least 1 element, it should never be none
} }
ExprKind::IfExp { test, body, orelse } => { ExprKind::IfExp { test, body, orelse } => {
let test = if let BasicValueEnum::IntValue(test) = self.gen_expr(test) { let test = if let BasicValueEnum::IntValue(test) = self.gen_expr(test).unwrap() {
test test
} else { } else {
unreachable!() unreachable!()
@ -511,17 +518,40 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
let cont_bb = self.ctx.append_basic_block(current, "cont"); let cont_bb = self.ctx.append_basic_block(current, "cont");
self.builder.build_conditional_branch(test, then_bb, else_bb); self.builder.build_conditional_branch(test, then_bb, else_bb);
self.builder.position_at_end(then_bb); self.builder.position_at_end(then_bb);
let a = self.gen_expr(body); let a = self.gen_expr(body).unwrap();
self.builder.build_unconditional_branch(cont_bb); self.builder.build_unconditional_branch(cont_bb);
self.builder.position_at_end(else_bb); self.builder.position_at_end(else_bb);
let b = self.gen_expr(orelse); let b = self.gen_expr(orelse).unwrap();
self.builder.build_unconditional_branch(cont_bb); self.builder.build_unconditional_branch(cont_bb);
self.builder.position_at_end(cont_bb); self.builder.position_at_end(cont_bb);
let phi = self.builder.build_phi(a.get_type(), "ifexpr"); let phi = self.builder.build_phi(a.get_type(), "ifexpr");
phi.add_incoming(&[(&a, then_bb), (&b, else_bb)]); phi.add_incoming(&[(&a, then_bb), (&b, else_bb)]);
phi.as_basic_value() phi.as_basic_value()
} }
ExprKind::Call { func, args, keywords } => {
if let ExprKind::Name { id, .. } = &func.as_ref().node {
// TODO: handle primitive casts
let fun = self.resolver.get_identifier_def(&id).expect("Unknown identifier");
let ret = expr.custom.unwrap();
let mut params =
args.iter().map(|arg| (None, self.gen_expr(arg).unwrap())).collect_vec();
let kw_iter = keywords.iter().map(|kw| {
(
Some(kw.node.arg.as_ref().unwrap().clone()),
self.gen_expr(&kw.node.value).unwrap(),
)
});
params.extend(kw_iter);
let signature = self
.unifier
.get_call_signature(*self.calls.get(&expr.location.into()).unwrap())
.unwrap();
return self.gen_call(None, (&signature, fun), params, ret);
} else {
unimplemented!()
}
}
_ => unimplemented!(), _ => unimplemented!(),
} })
} }
} }

View File

@ -2,8 +2,8 @@ use crate::{
symbol_resolver::SymbolResolver, symbol_resolver::SymbolResolver,
top_level::{TopLevelContext, TopLevelDef}, top_level::{TopLevelContext, TopLevelDef},
typecheck::{ typecheck::{
type_inferencer::PrimitiveStore, type_inferencer::{CodeLocation, PrimitiveStore},
typedef::{FunSignature, Type, TypeEnum, Unifier}, typedef::{CallId, FunSignature, Type, TypeEnum, Unifier},
}, },
}; };
use crossbeam::channel::{unbounded, Receiver, Sender}; use crossbeam::channel::{unbounded, Receiver, Sender};
@ -42,6 +42,7 @@ pub struct CodeGenContext<'ctx, 'a> {
pub var_assignment: HashMap<String, PointerValue<'ctx>>, pub var_assignment: HashMap<String, PointerValue<'ctx>>,
pub type_cache: HashMap<Type, BasicTypeEnum<'ctx>>, pub type_cache: HashMap<Type, BasicTypeEnum<'ctx>>,
pub primitives: PrimitiveStore, pub primitives: PrimitiveStore,
pub calls: HashMap<CodeLocation, CallId>,
// stores the alloca for variables // stores the alloca for variables
pub init_bb: BasicBlock<'ctx>, pub init_bb: BasicBlock<'ctx>,
// where continue and break should go to respectively // where continue and break should go to respectively
@ -186,6 +187,7 @@ pub struct CodeGenTask {
pub symbol_name: String, pub symbol_name: String,
pub signature: FunSignature, pub signature: FunSignature,
pub body: Vec<Stmt<Option<Type>>>, pub body: Vec<Stmt<Option<Type>>>,
pub calls: HashMap<CodeLocation, CallId>,
pub unifier_index: usize, pub unifier_index: usize,
pub resolver: Arc<dyn SymbolResolver + Send + Sync>, pub resolver: Arc<dyn SymbolResolver + Send + Sync>,
} }
@ -323,6 +325,7 @@ pub fn gen_func<'ctx>(
ctx: &context, ctx: &context,
resolver: task.resolver, resolver: task.resolver,
top_level: top_level_ctx.as_ref(), top_level: top_level_ctx.as_ref(),
calls: task.calls,
loop_bb: None, loop_bb: None,
var_assignment, var_assignment,
type_cache, type_cache,

View File

@ -28,7 +28,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
} }
ExprKind::Attribute { value, attr, .. } => { ExprKind::Attribute { value, attr, .. } => {
let index = self.get_attr_index(value.custom.unwrap(), attr); let index = self.get_attr_index(value.custom.unwrap(), attr);
let val = self.gen_expr(value); let val = self.gen_expr(value).unwrap();
let ptr = if let BasicValueEnum::PointerValue(v) = val { let ptr = if let BasicValueEnum::PointerValue(v) = val {
v v
} else { } else {
@ -68,33 +68,82 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
} }
} }
pub fn gen_stmt(&mut self, stmt: &Stmt<Option<Type>>) { // return true if it contains terminator
pub fn gen_stmt(&mut self, stmt: &Stmt<Option<Type>>) -> bool {
match &stmt.node { match &stmt.node {
StmtKind::Expr { value } => { StmtKind::Expr { value } => {
self.gen_expr(&value); self.gen_expr(&value);
} }
StmtKind::Return { value } => { StmtKind::Return { value } => {
let value = value.as_ref().map(|v| self.gen_expr(&v)); let value = value.as_ref().map(|v| self.gen_expr(&v).unwrap());
let value = value.as_ref().map(|v| v as &dyn BasicValue); let value = value.as_ref().map(|v| v as &dyn BasicValue);
self.builder.build_return(value); self.builder.build_return(value);
return true;
} }
StmtKind::AnnAssign { target, value, .. } => { StmtKind::AnnAssign { target, value, .. } => {
if let Some(value) = value { if let Some(value) = value {
let value = self.gen_expr(&value); let value = self.gen_expr(&value).unwrap();
self.gen_assignment(target, value); self.gen_assignment(target, value);
} }
} }
StmtKind::Assign { targets, value, .. } => { StmtKind::Assign { targets, value, .. } => {
let value = self.gen_expr(&value); let value = self.gen_expr(&value).unwrap();
for target in targets.iter() { for target in targets.iter() {
self.gen_assignment(target, value); self.gen_assignment(target, value);
} }
} }
StmtKind::Continue => { StmtKind::Continue => {
self.builder.build_unconditional_branch(self.loop_bb.unwrap().0); self.builder.build_unconditional_branch(self.loop_bb.unwrap().0);
return true;
} }
StmtKind::Break => { StmtKind::Break => {
self.builder.build_unconditional_branch(self.loop_bb.unwrap().1); self.builder.build_unconditional_branch(self.loop_bb.unwrap().1);
return true;
}
StmtKind::If { test, body, orelse } => {
let current = self.builder.get_insert_block().unwrap().get_parent().unwrap();
let test_bb = self.ctx.append_basic_block(current, "test");
let body_bb = self.ctx.append_basic_block(current, "body");
let cont_bb = self.ctx.append_basic_block(current, "cont");
// if there is no orelse, we just go to cont_bb
let orelse_bb = if orelse.is_empty() {
cont_bb
} else {
self.ctx.append_basic_block(current, "orelse")
};
self.builder.build_unconditional_branch(test_bb);
self.builder.position_at_end(test_bb);
let test = self.gen_expr(test).unwrap();
if let BasicValueEnum::IntValue(test) = test {
self.builder.build_conditional_branch(test, body_bb, orelse_bb);
} else {
unreachable!()
};
self.builder.position_at_end(body_bb);
let mut exited = false;
for stmt in body.iter() {
exited = self.gen_stmt(stmt);
if exited {
break;
}
}
if !exited {
self.builder.build_unconditional_branch(cont_bb);
}
if !orelse.is_empty() {
exited = false;
self.builder.position_at_end(orelse_bb);
for stmt in orelse.iter() {
exited = self.gen_stmt(stmt);
if exited {
break;
}
}
if !exited {
self.builder.build_unconditional_branch(cont_bb);
}
}
self.builder.position_at_end(cont_bb);
} }
StmtKind::While { test, body, orelse } => { StmtKind::While { test, body, orelse } => {
let current = self.builder.get_insert_block().unwrap().get_parent().unwrap(); let current = self.builder.get_insert_block().unwrap().get_parent().unwrap();
@ -111,7 +160,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
let loop_bb = self.loop_bb.replace((test_bb, cont_bb)); let loop_bb = self.loop_bb.replace((test_bb, cont_bb));
self.builder.build_unconditional_branch(test_bb); self.builder.build_unconditional_branch(test_bb);
self.builder.position_at_end(test_bb); self.builder.position_at_end(test_bb);
let test = self.gen_expr(test); let test = self.gen_expr(test).unwrap();
if let BasicValueEnum::IntValue(test) = test { if let BasicValueEnum::IntValue(test) = test {
self.builder.build_conditional_branch(test, body_bb, orelse_bb); self.builder.build_conditional_branch(test, body_bb, orelse_bb);
} else { } else {
@ -132,7 +181,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
self.builder.position_at_end(cont_bb); self.builder.position_at_end(cont_bb);
self.loop_bb = loop_bb; self.loop_bb = loop_bb;
} }
_ => unimplemented!(), _ => unimplemented!("{:?}", stmt),
} };
false
} }
} }

View File

@ -179,6 +179,7 @@ fn test_primitives() {
body: statements, body: statements,
unifier_index: 0, unifier_index: 0,
resolver: env.function_data.resolver.clone(), resolver: env.function_data.resolver.clone(),
calls: Default::default(),
signature, signature,
}; };

View File

@ -1,8 +1,8 @@
#![warn(clippy::all)] #![warn(clippy::all)]
#![allow(dead_code)] #![allow(dead_code)]
mod codegen; pub mod codegen;
mod location; pub mod location;
mod symbol_resolver; pub mod symbol_resolver;
mod top_level; pub mod top_level;
mod typecheck; pub mod typecheck;

View File

@ -97,6 +97,7 @@ impl TypeEnum {
pub type SharedUnifier = Arc<Mutex<(UnificationTable<TypeEnum>, u32, Vec<Call>)>>; pub type SharedUnifier = Arc<Mutex<(UnificationTable<TypeEnum>, u32, Vec<Call>)>>;
#[derive(Clone)]
pub struct Unifier { pub struct Unifier {
unification_table: UnificationTable<Rc<TypeEnum>>, unification_table: UnificationTable<Rc<TypeEnum>>,
calls: Vec<Rc<Call>>, calls: Vec<Rc<Call>>,
@ -153,6 +154,15 @@ impl Unifier {
id id
} }
pub fn get_call_signature(&mut self, id: CallId) -> Option<FunSignature> {
let fun = self.calls.get(id.0).unwrap().fun.borrow().unwrap();
if let TypeEnum::TFunc(sign) = &*self.get_ty(fun) {
Some(sign.borrow().clone())
} else {
None
}
}
pub fn get_representative(&mut self, ty: Type) -> Type { pub fn get_representative(&mut self, ty: Type) -> Type {
self.unification_table.get_representative(ty) self.unification_table.get_representative(ty)
} }

View File

@ -3,6 +3,7 @@ use std::rc::Rc;
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub struct UnificationKey(usize); pub struct UnificationKey(usize);
#[derive(Clone)]
pub struct UnificationTable<V> { pub struct UnificationTable<V> {
parents: Vec<usize>, parents: Vec<usize>,
ranks: Vec<u32>, ranks: Vec<u32>,