diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index d21d3ba79..59bb04300 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -14,7 +14,7 @@ use inkwell::{ use itertools::{chain, izip, zip, Itertools}; use rustpython_parser::ast::{self, Boolop, Constant, Expr, ExprKind, Operator}; -impl<'ctx> CodeGenContext<'ctx> { +impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { fn get_subst_key(&mut self, obj: Option, fun: &FunSignature) -> String { let mut vars = obj .map(|ty| { diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index c3fba2dbc..9ca1cb583 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -16,21 +16,23 @@ use inkwell::{ AddressSpace, }; use itertools::Itertools; -use rayon::current_thread_index; -use rustpython_parser::ast::{Stmt, StmtKind}; +use rustpython_parser::ast::Stmt; use std::collections::HashMap; use std::sync::Arc; mod expr; mod stmt; -pub struct CodeGenContext<'ctx> { +#[cfg(test)] +mod test; + +pub struct CodeGenContext<'ctx, 'a> { pub ctx: &'ctx Context, pub builder: Builder<'ctx>, pub module: Module<'ctx>, - pub top_level: &'ctx TopLevelContext, + pub top_level: &'a TopLevelContext, pub unifier: Unifier, - pub resolver: Box, + pub resolver: Arc, pub var_assignment: HashMap>, pub type_cache: HashMap>, pub primitives: PrimitiveStore, @@ -45,9 +47,9 @@ pub struct CodeGenTask { pub subst: Vec<(Type, Type)>, pub symbol_name: String, pub signature: FunSignature, - pub body: Stmt>, + pub body: Vec>>, pub unifier_index: usize, - pub resolver: Box, + pub resolver: Arc, } fn get_llvm_type<'ctx>( @@ -60,7 +62,7 @@ fn get_llvm_type<'ctx>( use TypeEnum::*; // we assume the type cache should already contain primitive types, // and they should be passed by value instead of passing as pointer. - type_cache.get(&ty).cloned().unwrap_or_else(|| match &*unifier.get_ty(ty) { + type_cache.get(&unifier.get_representative(ty)).cloned().unwrap_or_else(|| match &*unifier.get_ty(ty) { TObj { obj_id, fields, .. } => { // a struct with fields in the order of declaration let defs = top_level.definitions.read(); @@ -97,16 +99,13 @@ fn get_llvm_type<'ctx>( }) } -pub fn gen_func(task: CodeGenTask, top_level_ctx: Arc) { +pub fn gen_func<'ctx>(context: &'ctx Context, builder: Builder<'ctx>, module: Module<'ctx>, task: CodeGenTask, top_level_ctx: Arc) -> Module<'ctx> { // unwrap_or(0) is for unit tests without using rayon - let thread_id = current_thread_index().unwrap_or(0); let (mut unifier, primitives) = { let unifiers = top_level_ctx.unifiers.read(); let (unifier, primitives) = &unifiers[task.unifier_index]; (Unifier::from_shared_unifier(unifier), *primitives) }; - let contexts = top_level_ctx.conetexts.read(); - let context = contexts[thread_id].lock(); for (a, b) in task.subst.iter() { // this should be unification between variables and concrete types @@ -124,10 +123,10 @@ pub fn gen_func(task: CodeGenTask, top_level_ctx: Arc) { }; let mut type_cache: HashMap<_, _> = [ - (primitives.int32, context.i32_type().into()), - (primitives.int64, context.i64_type().into()), - (primitives.float, context.f64_type().into()), - (primitives.bool, context.bool_type().into()), + (unifier.get_representative(primitives.int32), context.i32_type().into()), + (unifier.get_representative(primitives.int64), context.i64_type().into()), + (unifier.get_representative(primitives.float), context.f64_type().into()), + (unifier.get_representative(primitives.bool), context.bool_type().into()), ] .iter() .cloned() @@ -155,8 +154,6 @@ pub fn gen_func(task: CodeGenTask, top_level_ctx: Arc) { .fn_type(¶ms, false) }; - let builder = context.create_builder(); - let module = context.create_module(&task.symbol_name); let fn_val = module.add_function(&task.symbol_name, fn_type, None); let init_bb = context.append_basic_block(fn_val, "init"); builder.position_at_end(init_bb); @@ -189,9 +186,9 @@ pub fn gen_func(task: CodeGenTask, top_level_ctx: Arc) { unifier, }; - if let StmtKind::FunctionDef { body, .. } = &task.body.node { - for stmt in body.iter() { - code_gen_context.gen_stmt(stmt); - } + for stmt in task.body.iter() { + code_gen_context.gen_stmt(stmt); } + + code_gen_context.module } diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index fa9727f80..9f4686097 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -3,7 +3,7 @@ use crate::typecheck::typedef::Type; use inkwell::values::{BasicValue, BasicValueEnum, PointerValue}; use rustpython_parser::ast::{Expr, ExprKind, Stmt, StmtKind}; -impl<'ctx> CodeGenContext<'ctx> { +impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { fn gen_var(&mut self, ty: Type) -> PointerValue<'ctx> { // put the alloca in init block let current = self.builder.get_insert_block().unwrap(); diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs new file mode 100644 index 000000000..c9401f099 --- /dev/null +++ b/nac3core/src/codegen/test.rs @@ -0,0 +1,246 @@ +use super::{gen_func, CodeGenTask}; +use crate::{ + location::Location, + symbol_resolver::{SymbolResolver, SymbolValue}, + top_level::{DefinitionId, TopLevelContext}, + typecheck::{ + magic_methods::set_primitives_magic_methods, + type_inferencer::{CodeLocation, FunctionData, Inferencer, PrimitiveStore}, + typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier}, + }, +}; +use indoc::indoc; +use inkwell::context::Context; +use parking_lot::RwLock; +use rustpython_parser::{ast::fold::Fold, parser::parse_program}; +use std::collections::HashMap; +use std::sync::Arc; + +#[derive(Clone)] +struct Resolver { + id_to_type: HashMap, + id_to_def: HashMap, + class_names: HashMap, +} + +impl SymbolResolver for Resolver { + fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option { + self.id_to_type.get(str).cloned() + } + + fn get_symbol_value(&self, _: &str) -> Option { + unimplemented!() + } + + fn get_symbol_location(&self, _: &str) -> Option { + unimplemented!() + } + + fn get_identifier_def(&self, id: &str) -> Option { + self.id_to_def.get(id).cloned() + } +} + +struct TestEnvironment { + pub unifier: Unifier, + pub function_data: FunctionData, + pub primitives: PrimitiveStore, + pub id_to_name: HashMap, + pub identifier_mapping: HashMap, + pub virtual_checks: Vec<(Type, Type)>, + pub calls: HashMap>, + pub top_level: TopLevelContext, +} + +impl TestEnvironment { + pub fn basic_test_env() -> TestEnvironment { + let mut unifier = Unifier::new(); + + let int32 = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(0), + fields: HashMap::new().into(), + params: HashMap::new(), + }); + let int64 = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(1), + fields: HashMap::new().into(), + params: HashMap::new(), + }); + let float = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(2), + fields: HashMap::new().into(), + params: HashMap::new(), + }); + let bool = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(3), + fields: HashMap::new().into(), + params: HashMap::new(), + }); + let none = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(4), + fields: HashMap::new().into(), + params: HashMap::new(), + }); + let primitives = PrimitiveStore { int32, int64, float, bool, none }; + set_primitives_magic_methods(&primitives, &mut unifier); + + let id_to_name = [ + (0, "int32".to_string()), + (1, "int64".to_string()), + (2, "float".to_string()), + (3, "bool".to_string()), + (4, "none".to_string()), + ] + .iter() + .cloned() + .collect(); + + let mut identifier_mapping = HashMap::new(); + identifier_mapping.insert("None".into(), none); + + let resolver = Arc::new(Resolver { + id_to_type: identifier_mapping.clone(), + id_to_def: Default::default(), + class_names: Default::default(), + }) as Arc; + + TestEnvironment { + unifier, + top_level: TopLevelContext { + definitions: Default::default(), + unifiers: Default::default(), + conetexts: Default::default(), + }, + function_data: FunctionData { + resolver, + bound_variables: Vec::new(), + return_type: Some(primitives.int32), + }, + primitives, + id_to_name, + identifier_mapping, + virtual_checks: Vec::new(), + calls: HashMap::new(), + } + } + + fn get_inferencer(&mut self) -> Inferencer { + Inferencer { + top_level: &self.top_level, + function_data: &mut self.function_data, + unifier: &mut self.unifier, + variable_mapping: Default::default(), + primitives: &mut self.primitives, + virtual_checks: &mut self.virtual_checks, + calls: &mut self.calls, + } + } +} + +#[test] +fn test_primitives() { + let mut env = TestEnvironment::basic_test_env(); + let context = Context::create(); + let module = context.create_module("test"); + let builder = context.create_builder(); + + let signature = FunSignature { + args: vec![ + FuncArg { name: "a".to_string(), ty: env.primitives.int32, default_value: None }, + FuncArg { name: "b".to_string(), ty: env.primitives.int32, default_value: None }, + ], + ret: env.primitives.int32, + vars: HashMap::new(), + }; + + let mut inferencer = env.get_inferencer(); + let source = indoc! { " + c = a + b + d = a if c == 1 else 0 + return d + "}; + let statements = parse_program(source).unwrap(); + + let statements = statements + .into_iter() + .map(|v| inferencer.fold_stmt(v)) + .collect::, _>>() + .unwrap(); + + let top_level = Arc::new(TopLevelContext { + definitions: Default::default(), + unifiers: Arc::new(RwLock::new(vec![(env.unifier.get_shared_unifier(), env.primitives)])), + conetexts: Default::default(), + }); + + let task = CodeGenTask { + subst: Default::default(), + symbol_name: "testing".to_string(), + body: statements, + unifier_index: 0, + resolver: env.function_data.resolver.clone(), + signature, + }; + + let module = gen_func(&context, builder, module, task, top_level); + // the following IR is equivalent to + // ``` + // ; ModuleID = 'test.ll' + // source_filename = "test" + // + // ; Function Attrs: norecurse nounwind readnone + // define i32 @testing(i32 %0, i32 %1) local_unnamed_addr #0 { + // init: + // %add = add i32 %1, %0 + // %cmp = icmp eq i32 %add, 1 + // %ifexpr = select i1 %cmp, i32 %0, i32 0 + // ret i32 %ifexpr + // } + // + // attributes #0 = { norecurse nounwind readnone } + // ``` + // after O2 optimization + + let expected = indoc! {" + ; ModuleID = 'test' + source_filename = \"test\" + + define i32 @testing(i32 %0, i32 %1) { + init: + %a = alloca i32 + store i32 %0, i32* %a + %b = alloca i32 + store i32 %1, i32* %b + %tmp = alloca i32 + %tmp4 = alloca i32 + br label %body + + body: ; preds = %init + %load = load i32, i32* %a + %load1 = load i32, i32* %b + %add = add i32 %load, %load1 + store i32 %add, i32* %tmp + %load2 = load i32, i32* %tmp + %cmp = icmp eq i32 %load2, 1 + br i1 %cmp, label %then, label %else + + then: ; preds = %body + %load3 = load i32, i32* %a + br label %cont + + else: ; preds = %body + br label %cont + + cont: ; preds = %else, %then + %ifexpr = phi i32 [ %load3, %then ], [ 0, %else ] + store i32 %ifexpr, i32* %tmp4 + %load5 = load i32, i32* %tmp4 + ret i32 %load5 + } + "} + .trim(); + let ir = module.print_to_string().to_string(); + println!("src:\n{}", source); + println!("IR:\n{}", ir); + assert_eq!(expected, ir.trim()); +}