diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 81a2161b..9329ba6f 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -88,7 +88,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { ) -> Option> { let key = self.get_subst_key(obj.map(|(a, _)| a), fun.0); let defs = self.top_level.definitions.read(); - let definition = defs.get(fun.1 .0).unwrap(); + let definition = defs.get(fun.1.0).unwrap(); let val = if let TopLevelDef::Function { instance_to_symbol, .. } = &*definition.read() { let symbol = instance_to_symbol.get(&key).unwrap_or_else(|| { // TODO: codegen for function that are not yet generated @@ -232,9 +232,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { } } - pub fn gen_expr(&mut self, expr: &Expr>) -> BasicValueEnum<'ctx> { + pub fn gen_expr(&mut self, expr: &Expr>) -> Option> { let zero = self.ctx.i32_type().const_int(0, false); - match &expr.node { + Some(match &expr.node { ExprKind::Constant { value, .. } => { let ty = expr.custom.unwrap(); self.gen_const(value, ty) @@ -254,7 +254,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { ExprKind::List { elts, .. } => { // this shall be optimized later for constant primitive lists... // we should use memcpy for that instead of generating thousands of stores - let elements = elts.iter().map(|x| self.gen_expr(x)).collect_vec(); + let elements = elts.iter().map(|x| self.gen_expr(x).unwrap()).collect_vec(); let ty = if elements.is_empty() { self.ctx.i32_type().into() } else { @@ -293,7 +293,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { arr_str_ptr.into() } ExprKind::Tuple { elts, .. } => { - let element_val = elts.iter().map(|x| self.gen_expr(x)).collect_vec(); + let element_val = elts.iter().map(|x| self.gen_expr(x).unwrap()).collect_vec(); let element_ty = element_val.iter().map(BasicValueEnum::get_type).collect_vec(); let tuple_ty = self.ctx.struct_type(&element_ty, false); let tuple_ptr = self.builder.build_alloca(tuple_ty, "tuple"); @@ -311,7 +311,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { ExprKind::Attribute { value, attr, .. } => { // note that we would handle class methods directly in calls let index = self.get_attr_index(value.custom.unwrap(), attr); - let val = self.gen_expr(value); + let val = self.gen_expr(value).unwrap(); let ptr = if let BasicValueEnum::PointerValue(v) = val { v } else { @@ -327,11 +327,12 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { } ExprKind::BoolOp { op, values } => { // requires conditional branches for short-circuiting... - let left = if let BasicValueEnum::IntValue(left) = self.gen_expr(&values[0]) { - left - } else { - unreachable!() - }; + let left = + if let BasicValueEnum::IntValue(left) = self.gen_expr(&values[0]).unwrap() { + left + } else { + unreachable!() + }; let current = self.builder.get_insert_block().unwrap().get_parent().unwrap(); let a_bb = self.ctx.append_basic_block(current, "a"); let b_bb = self.ctx.append_basic_block(current, "b"); @@ -343,7 +344,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { let a = self.ctx.bool_type().const_int(1, false); self.builder.build_unconditional_branch(cont_bb); self.builder.position_at_end(b_bb); - let b = if let BasicValueEnum::IntValue(b) = self.gen_expr(&values[1]) { + let b = if let BasicValueEnum::IntValue(b) = + self.gen_expr(&values[1]).unwrap() + { b } else { unreachable!() @@ -353,7 +356,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { } Boolop::And => { self.builder.position_at_end(a_bb); - let a = if let BasicValueEnum::IntValue(a) = self.gen_expr(&values[1]) { + let a = if let BasicValueEnum::IntValue(a) = + self.gen_expr(&values[1]).unwrap() + { a } else { unreachable!() @@ -373,8 +378,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { ExprKind::BinOp { op, left, right } => { let ty1 = self.unifier.get_representative(left.custom.unwrap()); let ty2 = self.unifier.get_representative(right.custom.unwrap()); - let left = self.gen_expr(left); - let right = self.gen_expr(right); + let left = self.gen_expr(left).unwrap(); + let right = self.gen_expr(right).unwrap(); // we can directly compare the types, because we've got their representatives // which would be unchanged until further unification, which we would never do @@ -389,7 +394,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { } ExprKind::UnaryOp { op, operand } => { let ty = self.unifier.get_representative(operand.custom.unwrap()); - let val = self.gen_expr(operand); + let val = self.gen_expr(operand).unwrap(); if ty == self.primitives.bool { let val = if let BasicValueEnum::IntValue(val) = val { val } else { unreachable!() }; @@ -454,7 +459,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { let (lhs, rhs) = if let ( BasicValueEnum::IntValue(lhs), BasicValueEnum::IntValue(rhs), - ) = (self.gen_expr(lhs), self.gen_expr(rhs)) + ) = + (self.gen_expr(lhs).unwrap(), self.gen_expr(rhs).unwrap()) { (lhs, rhs) } else { @@ -474,7 +480,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { let (lhs, rhs) = if let ( BasicValueEnum::FloatValue(lhs), BasicValueEnum::FloatValue(rhs), - ) = (self.gen_expr(lhs), self.gen_expr(rhs)) + ) = + (self.gen_expr(lhs).unwrap(), self.gen_expr(rhs).unwrap()) { (lhs, rhs) } else { @@ -499,7 +506,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { .into() // as there should be at least 1 element, it should never be none } ExprKind::IfExp { test, body, orelse } => { - let test = if let BasicValueEnum::IntValue(test) = self.gen_expr(test) { + let test = if let BasicValueEnum::IntValue(test) = self.gen_expr(test).unwrap() { test } else { unreachable!() @@ -511,17 +518,40 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { let cont_bb = self.ctx.append_basic_block(current, "cont"); self.builder.build_conditional_branch(test, then_bb, else_bb); self.builder.position_at_end(then_bb); - let a = self.gen_expr(body); + let a = self.gen_expr(body).unwrap(); self.builder.build_unconditional_branch(cont_bb); self.builder.position_at_end(else_bb); - let b = self.gen_expr(orelse); + let b = self.gen_expr(orelse).unwrap(); self.builder.build_unconditional_branch(cont_bb); self.builder.position_at_end(cont_bb); let phi = self.builder.build_phi(a.get_type(), "ifexpr"); phi.add_incoming(&[(&a, then_bb), (&b, else_bb)]); phi.as_basic_value() } + ExprKind::Call { func, args, keywords } => { + if let ExprKind::Name { id, .. } = &func.as_ref().node { + // TODO: handle primitive casts + let fun = self.resolver.get_identifier_def(&id).expect("Unknown identifier"); + let ret = expr.custom.unwrap(); + let mut params = + args.iter().map(|arg| (None, self.gen_expr(arg).unwrap())).collect_vec(); + let kw_iter = keywords.iter().map(|kw| { + ( + Some(kw.node.arg.as_ref().unwrap().clone()), + self.gen_expr(&kw.node.value).unwrap(), + ) + }); + params.extend(kw_iter); + let signature = self + .unifier + .get_call_signature(*self.calls.get(&expr.location.into()).unwrap()) + .unwrap(); + return self.gen_call(None, (&signature, fun), params, ret); + } else { + unimplemented!() + } + } _ => unimplemented!(), - } + }) } } diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 4df3ee99..374ea417 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -2,8 +2,8 @@ use crate::{ symbol_resolver::SymbolResolver, top_level::{TopLevelContext, TopLevelDef}, typecheck::{ - type_inferencer::PrimitiveStore, - typedef::{FunSignature, Type, TypeEnum, Unifier}, + type_inferencer::{CodeLocation, PrimitiveStore}, + typedef::{CallId, FunSignature, Type, TypeEnum, Unifier}, }, }; use crossbeam::channel::{unbounded, Receiver, Sender}; @@ -42,6 +42,7 @@ pub struct CodeGenContext<'ctx, 'a> { pub var_assignment: HashMap>, pub type_cache: HashMap>, pub primitives: PrimitiveStore, + pub calls: HashMap, // stores the alloca for variables pub init_bb: BasicBlock<'ctx>, // where continue and break should go to respectively @@ -186,6 +187,7 @@ pub struct CodeGenTask { pub symbol_name: String, pub signature: FunSignature, pub body: Vec>>, + pub calls: HashMap, pub unifier_index: usize, pub resolver: Arc, } @@ -323,6 +325,7 @@ pub fn gen_func<'ctx>( ctx: &context, resolver: task.resolver, top_level: top_level_ctx.as_ref(), + calls: task.calls, loop_bb: None, var_assignment, type_cache, diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 9f468609..3aa95c13 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -28,7 +28,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { } ExprKind::Attribute { value, attr, .. } => { let index = self.get_attr_index(value.custom.unwrap(), attr); - let val = self.gen_expr(value); + let val = self.gen_expr(value).unwrap(); let ptr = if let BasicValueEnum::PointerValue(v) = val { v } else { @@ -68,33 +68,82 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { } } - pub fn gen_stmt(&mut self, stmt: &Stmt>) { + // return true if it contains terminator + pub fn gen_stmt(&mut self, stmt: &Stmt>) -> bool { match &stmt.node { StmtKind::Expr { value } => { self.gen_expr(&value); } StmtKind::Return { value } => { - let value = value.as_ref().map(|v| self.gen_expr(&v)); + let value = value.as_ref().map(|v| self.gen_expr(&v).unwrap()); let value = value.as_ref().map(|v| v as &dyn BasicValue); self.builder.build_return(value); + return true; } StmtKind::AnnAssign { target, value, .. } => { if let Some(value) = value { - let value = self.gen_expr(&value); + let value = self.gen_expr(&value).unwrap(); self.gen_assignment(target, value); } } StmtKind::Assign { targets, value, .. } => { - let value = self.gen_expr(&value); + let value = self.gen_expr(&value).unwrap(); for target in targets.iter() { self.gen_assignment(target, value); } } StmtKind::Continue => { self.builder.build_unconditional_branch(self.loop_bb.unwrap().0); + return true; } StmtKind::Break => { self.builder.build_unconditional_branch(self.loop_bb.unwrap().1); + return true; + } + StmtKind::If { test, body, orelse } => { + let current = self.builder.get_insert_block().unwrap().get_parent().unwrap(); + let test_bb = self.ctx.append_basic_block(current, "test"); + let body_bb = self.ctx.append_basic_block(current, "body"); + let cont_bb = self.ctx.append_basic_block(current, "cont"); + // if there is no orelse, we just go to cont_bb + let orelse_bb = if orelse.is_empty() { + cont_bb + } else { + self.ctx.append_basic_block(current, "orelse") + }; + self.builder.build_unconditional_branch(test_bb); + self.builder.position_at_end(test_bb); + let test = self.gen_expr(test).unwrap(); + if let BasicValueEnum::IntValue(test) = test { + self.builder.build_conditional_branch(test, body_bb, orelse_bb); + } else { + unreachable!() + }; + self.builder.position_at_end(body_bb); + let mut exited = false; + for stmt in body.iter() { + exited = self.gen_stmt(stmt); + if exited { + break; + } + } + if !exited { + self.builder.build_unconditional_branch(cont_bb); + } + if !orelse.is_empty() { + exited = false; + self.builder.position_at_end(orelse_bb); + for stmt in orelse.iter() { + exited = self.gen_stmt(stmt); + if exited { + break; + } + } + if !exited { + self.builder.build_unconditional_branch(cont_bb); + } + } + self.builder.position_at_end(cont_bb); } StmtKind::While { test, body, orelse } => { let current = self.builder.get_insert_block().unwrap().get_parent().unwrap(); @@ -111,7 +160,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { let loop_bb = self.loop_bb.replace((test_bb, cont_bb)); self.builder.build_unconditional_branch(test_bb); self.builder.position_at_end(test_bb); - let test = self.gen_expr(test); + let test = self.gen_expr(test).unwrap(); if let BasicValueEnum::IntValue(test) = test { self.builder.build_conditional_branch(test, body_bb, orelse_bb); } else { @@ -132,7 +181,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { self.builder.position_at_end(cont_bb); self.loop_bb = loop_bb; } - _ => unimplemented!(), - } + _ => unimplemented!("{:?}", stmt), + }; + false } } diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 72ea9a75..2b11d195 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -179,6 +179,7 @@ fn test_primitives() { body: statements, unifier_index: 0, resolver: env.function_data.resolver.clone(), + calls: Default::default(), signature, }; diff --git a/nac3core/src/lib.rs b/nac3core/src/lib.rs index 6d7de6f5..ad1725b4 100644 --- a/nac3core/src/lib.rs +++ b/nac3core/src/lib.rs @@ -1,8 +1,8 @@ #![warn(clippy::all)] #![allow(dead_code)] -mod codegen; -mod location; -mod symbol_resolver; -mod top_level; -mod typecheck; +pub mod codegen; +pub mod location; +pub mod symbol_resolver; +pub mod top_level; +pub mod typecheck; diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 621ea65b..5ca9a9fd 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -97,6 +97,7 @@ impl TypeEnum { pub type SharedUnifier = Arc, u32, Vec)>>; +#[derive(Clone)] pub struct Unifier { unification_table: UnificationTable>, calls: Vec>, @@ -153,6 +154,15 @@ impl Unifier { id } + pub fn get_call_signature(&mut self, id: CallId) -> Option { + let fun = self.calls.get(id.0).unwrap().fun.borrow().unwrap(); + if let TypeEnum::TFunc(sign) = &*self.get_ty(fun) { + Some(sign.borrow().clone()) + } else { + None + } + } + pub fn get_representative(&mut self, ty: Type) -> Type { self.unification_table.get_representative(ty) } diff --git a/nac3core/src/typecheck/unification_table.rs b/nac3core/src/typecheck/unification_table.rs index 7475afce..19f836d8 100644 --- a/nac3core/src/typecheck/unification_table.rs +++ b/nac3core/src/typecheck/unification_table.rs @@ -3,6 +3,7 @@ use std::rc::Rc; #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] pub struct UnificationKey(usize); +#[derive(Clone)] pub struct UnificationTable { parents: Vec, ranks: Vec,