forked from M-Labs/nac3
1
0
Fork 0

codegen for simple function call, and various fixes

This commit is contained in:
pca006132 2021-08-19 15:30:15 +08:00
parent f205a8282a
commit 3279f7a776
7 changed files with 132 additions and 37 deletions

View File

@ -88,7 +88,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
) -> Option<BasicValueEnum<'ctx>> {
let key = self.get_subst_key(obj.map(|(a, _)| a), fun.0);
let defs = self.top_level.definitions.read();
let definition = defs.get(fun.1 .0).unwrap();
let definition = defs.get(fun.1.0).unwrap();
let val = if let TopLevelDef::Function { instance_to_symbol, .. } = &*definition.read() {
let symbol = instance_to_symbol.get(&key).unwrap_or_else(|| {
// TODO: codegen for function that are not yet generated
@ -232,9 +232,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
}
}
pub fn gen_expr(&mut self, expr: &Expr<Option<Type>>) -> BasicValueEnum<'ctx> {
pub fn gen_expr(&mut self, expr: &Expr<Option<Type>>) -> Option<BasicValueEnum<'ctx>> {
let zero = self.ctx.i32_type().const_int(0, false);
match &expr.node {
Some(match &expr.node {
ExprKind::Constant { value, .. } => {
let ty = expr.custom.unwrap();
self.gen_const(value, ty)
@ -254,7 +254,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
ExprKind::List { elts, .. } => {
// this shall be optimized later for constant primitive lists...
// we should use memcpy for that instead of generating thousands of stores
let elements = elts.iter().map(|x| self.gen_expr(x)).collect_vec();
let elements = elts.iter().map(|x| self.gen_expr(x).unwrap()).collect_vec();
let ty = if elements.is_empty() {
self.ctx.i32_type().into()
} else {
@ -293,7 +293,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
arr_str_ptr.into()
}
ExprKind::Tuple { elts, .. } => {
let element_val = elts.iter().map(|x| self.gen_expr(x)).collect_vec();
let element_val = elts.iter().map(|x| self.gen_expr(x).unwrap()).collect_vec();
let element_ty = element_val.iter().map(BasicValueEnum::get_type).collect_vec();
let tuple_ty = self.ctx.struct_type(&element_ty, false);
let tuple_ptr = self.builder.build_alloca(tuple_ty, "tuple");
@ -311,7 +311,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
ExprKind::Attribute { value, attr, .. } => {
// note that we would handle class methods directly in calls
let index = self.get_attr_index(value.custom.unwrap(), attr);
let val = self.gen_expr(value);
let val = self.gen_expr(value).unwrap();
let ptr = if let BasicValueEnum::PointerValue(v) = val {
v
} else {
@ -327,7 +327,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
}
ExprKind::BoolOp { op, values } => {
// requires conditional branches for short-circuiting...
let left = if let BasicValueEnum::IntValue(left) = self.gen_expr(&values[0]) {
let left =
if let BasicValueEnum::IntValue(left) = self.gen_expr(&values[0]).unwrap() {
left
} else {
unreachable!()
@ -343,7 +344,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
let a = self.ctx.bool_type().const_int(1, false);
self.builder.build_unconditional_branch(cont_bb);
self.builder.position_at_end(b_bb);
let b = if let BasicValueEnum::IntValue(b) = self.gen_expr(&values[1]) {
let b = if let BasicValueEnum::IntValue(b) =
self.gen_expr(&values[1]).unwrap()
{
b
} else {
unreachable!()
@ -353,7 +356,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
}
Boolop::And => {
self.builder.position_at_end(a_bb);
let a = if let BasicValueEnum::IntValue(a) = self.gen_expr(&values[1]) {
let a = if let BasicValueEnum::IntValue(a) =
self.gen_expr(&values[1]).unwrap()
{
a
} else {
unreachable!()
@ -373,8 +378,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
ExprKind::BinOp { op, left, right } => {
let ty1 = self.unifier.get_representative(left.custom.unwrap());
let ty2 = self.unifier.get_representative(right.custom.unwrap());
let left = self.gen_expr(left);
let right = self.gen_expr(right);
let left = self.gen_expr(left).unwrap();
let right = self.gen_expr(right).unwrap();
// we can directly compare the types, because we've got their representatives
// which would be unchanged until further unification, which we would never do
@ -389,7 +394,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
}
ExprKind::UnaryOp { op, operand } => {
let ty = self.unifier.get_representative(operand.custom.unwrap());
let val = self.gen_expr(operand);
let val = self.gen_expr(operand).unwrap();
if ty == self.primitives.bool {
let val =
if let BasicValueEnum::IntValue(val) = val { val } else { unreachable!() };
@ -454,7 +459,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
let (lhs, rhs) = if let (
BasicValueEnum::IntValue(lhs),
BasicValueEnum::IntValue(rhs),
) = (self.gen_expr(lhs), self.gen_expr(rhs))
) =
(self.gen_expr(lhs).unwrap(), self.gen_expr(rhs).unwrap())
{
(lhs, rhs)
} else {
@ -474,7 +480,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
let (lhs, rhs) = if let (
BasicValueEnum::FloatValue(lhs),
BasicValueEnum::FloatValue(rhs),
) = (self.gen_expr(lhs), self.gen_expr(rhs))
) =
(self.gen_expr(lhs).unwrap(), self.gen_expr(rhs).unwrap())
{
(lhs, rhs)
} else {
@ -499,7 +506,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
.into() // as there should be at least 1 element, it should never be none
}
ExprKind::IfExp { test, body, orelse } => {
let test = if let BasicValueEnum::IntValue(test) = self.gen_expr(test) {
let test = if let BasicValueEnum::IntValue(test) = self.gen_expr(test).unwrap() {
test
} else {
unreachable!()
@ -511,17 +518,40 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
let cont_bb = self.ctx.append_basic_block(current, "cont");
self.builder.build_conditional_branch(test, then_bb, else_bb);
self.builder.position_at_end(then_bb);
let a = self.gen_expr(body);
let a = self.gen_expr(body).unwrap();
self.builder.build_unconditional_branch(cont_bb);
self.builder.position_at_end(else_bb);
let b = self.gen_expr(orelse);
let b = self.gen_expr(orelse).unwrap();
self.builder.build_unconditional_branch(cont_bb);
self.builder.position_at_end(cont_bb);
let phi = self.builder.build_phi(a.get_type(), "ifexpr");
phi.add_incoming(&[(&a, then_bb), (&b, else_bb)]);
phi.as_basic_value()
}
_ => unimplemented!(),
ExprKind::Call { func, args, keywords } => {
if let ExprKind::Name { id, .. } = &func.as_ref().node {
// TODO: handle primitive casts
let fun = self.resolver.get_identifier_def(&id).expect("Unknown identifier");
let ret = expr.custom.unwrap();
let mut params =
args.iter().map(|arg| (None, self.gen_expr(arg).unwrap())).collect_vec();
let kw_iter = keywords.iter().map(|kw| {
(
Some(kw.node.arg.as_ref().unwrap().clone()),
self.gen_expr(&kw.node.value).unwrap(),
)
});
params.extend(kw_iter);
let signature = self
.unifier
.get_call_signature(*self.calls.get(&expr.location.into()).unwrap())
.unwrap();
return self.gen_call(None, (&signature, fun), params, ret);
} else {
unimplemented!()
}
}
_ => unimplemented!(),
})
}
}

View File

@ -2,8 +2,8 @@ use crate::{
symbol_resolver::SymbolResolver,
top_level::{TopLevelContext, TopLevelDef},
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{FunSignature, Type, TypeEnum, Unifier},
type_inferencer::{CodeLocation, PrimitiveStore},
typedef::{CallId, FunSignature, Type, TypeEnum, Unifier},
},
};
use crossbeam::channel::{unbounded, Receiver, Sender};
@ -42,6 +42,7 @@ pub struct CodeGenContext<'ctx, 'a> {
pub var_assignment: HashMap<String, PointerValue<'ctx>>,
pub type_cache: HashMap<Type, BasicTypeEnum<'ctx>>,
pub primitives: PrimitiveStore,
pub calls: HashMap<CodeLocation, CallId>,
// stores the alloca for variables
pub init_bb: BasicBlock<'ctx>,
// where continue and break should go to respectively
@ -186,6 +187,7 @@ pub struct CodeGenTask {
pub symbol_name: String,
pub signature: FunSignature,
pub body: Vec<Stmt<Option<Type>>>,
pub calls: HashMap<CodeLocation, CallId>,
pub unifier_index: usize,
pub resolver: Arc<dyn SymbolResolver + Send + Sync>,
}
@ -323,6 +325,7 @@ pub fn gen_func<'ctx>(
ctx: &context,
resolver: task.resolver,
top_level: top_level_ctx.as_ref(),
calls: task.calls,
loop_bb: None,
var_assignment,
type_cache,

View File

@ -28,7 +28,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
}
ExprKind::Attribute { value, attr, .. } => {
let index = self.get_attr_index(value.custom.unwrap(), attr);
let val = self.gen_expr(value);
let val = self.gen_expr(value).unwrap();
let ptr = if let BasicValueEnum::PointerValue(v) = val {
v
} else {
@ -68,33 +68,82 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
}
}
pub fn gen_stmt(&mut self, stmt: &Stmt<Option<Type>>) {
// return true if it contains terminator
pub fn gen_stmt(&mut self, stmt: &Stmt<Option<Type>>) -> bool {
match &stmt.node {
StmtKind::Expr { value } => {
self.gen_expr(&value);
}
StmtKind::Return { value } => {
let value = value.as_ref().map(|v| self.gen_expr(&v));
let value = value.as_ref().map(|v| self.gen_expr(&v).unwrap());
let value = value.as_ref().map(|v| v as &dyn BasicValue);
self.builder.build_return(value);
return true;
}
StmtKind::AnnAssign { target, value, .. } => {
if let Some(value) = value {
let value = self.gen_expr(&value);
let value = self.gen_expr(&value).unwrap();
self.gen_assignment(target, value);
}
}
StmtKind::Assign { targets, value, .. } => {
let value = self.gen_expr(&value);
let value = self.gen_expr(&value).unwrap();
for target in targets.iter() {
self.gen_assignment(target, value);
}
}
StmtKind::Continue => {
self.builder.build_unconditional_branch(self.loop_bb.unwrap().0);
return true;
}
StmtKind::Break => {
self.builder.build_unconditional_branch(self.loop_bb.unwrap().1);
return true;
}
StmtKind::If { test, body, orelse } => {
let current = self.builder.get_insert_block().unwrap().get_parent().unwrap();
let test_bb = self.ctx.append_basic_block(current, "test");
let body_bb = self.ctx.append_basic_block(current, "body");
let cont_bb = self.ctx.append_basic_block(current, "cont");
// if there is no orelse, we just go to cont_bb
let orelse_bb = if orelse.is_empty() {
cont_bb
} else {
self.ctx.append_basic_block(current, "orelse")
};
self.builder.build_unconditional_branch(test_bb);
self.builder.position_at_end(test_bb);
let test = self.gen_expr(test).unwrap();
if let BasicValueEnum::IntValue(test) = test {
self.builder.build_conditional_branch(test, body_bb, orelse_bb);
} else {
unreachable!()
};
self.builder.position_at_end(body_bb);
let mut exited = false;
for stmt in body.iter() {
exited = self.gen_stmt(stmt);
if exited {
break;
}
}
if !exited {
self.builder.build_unconditional_branch(cont_bb);
}
if !orelse.is_empty() {
exited = false;
self.builder.position_at_end(orelse_bb);
for stmt in orelse.iter() {
exited = self.gen_stmt(stmt);
if exited {
break;
}
}
if !exited {
self.builder.build_unconditional_branch(cont_bb);
}
}
self.builder.position_at_end(cont_bb);
}
StmtKind::While { test, body, orelse } => {
let current = self.builder.get_insert_block().unwrap().get_parent().unwrap();
@ -111,7 +160,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
let loop_bb = self.loop_bb.replace((test_bb, cont_bb));
self.builder.build_unconditional_branch(test_bb);
self.builder.position_at_end(test_bb);
let test = self.gen_expr(test);
let test = self.gen_expr(test).unwrap();
if let BasicValueEnum::IntValue(test) = test {
self.builder.build_conditional_branch(test, body_bb, orelse_bb);
} else {
@ -132,7 +181,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
self.builder.position_at_end(cont_bb);
self.loop_bb = loop_bb;
}
_ => unimplemented!(),
}
_ => unimplemented!("{:?}", stmt),
};
false
}
}

View File

@ -179,6 +179,7 @@ fn test_primitives() {
body: statements,
unifier_index: 0,
resolver: env.function_data.resolver.clone(),
calls: Default::default(),
signature,
};

View File

@ -1,8 +1,8 @@
#![warn(clippy::all)]
#![allow(dead_code)]
mod codegen;
mod location;
mod symbol_resolver;
mod top_level;
mod typecheck;
pub mod codegen;
pub mod location;
pub mod symbol_resolver;
pub mod top_level;
pub mod typecheck;

View File

@ -97,6 +97,7 @@ impl TypeEnum {
pub type SharedUnifier = Arc<Mutex<(UnificationTable<TypeEnum>, u32, Vec<Call>)>>;
#[derive(Clone)]
pub struct Unifier {
unification_table: UnificationTable<Rc<TypeEnum>>,
calls: Vec<Rc<Call>>,
@ -153,6 +154,15 @@ impl Unifier {
id
}
pub fn get_call_signature(&mut self, id: CallId) -> Option<FunSignature> {
let fun = self.calls.get(id.0).unwrap().fun.borrow().unwrap();
if let TypeEnum::TFunc(sign) = &*self.get_ty(fun) {
Some(sign.borrow().clone())
} else {
None
}
}
pub fn get_representative(&mut self, ty: Type) -> Type {
self.unification_table.get_representative(ty)
}

View File

@ -3,6 +3,7 @@ use std::rc::Rc;
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub struct UnificationKey(usize);
#[derive(Clone)]
pub struct UnificationTable<V> {
parents: Vec<usize>,
ranks: Vec<u32>,