From 1f5bea2448ef6919858b3ffbb73befb0a60eecac Mon Sep 17 00:00:00 2001 From: pca006132 Date: Sat, 16 Oct 2021 22:17:36 +0800 Subject: [PATCH] nac3core/codegen: refactor according to #23 --- nac3artiq/src/lib.rs | 10 +- nac3core/src/codegen/expr.rs | 973 +++++++++++++++--------------- nac3core/src/codegen/generator.rs | 142 +++++ nac3core/src/codegen/mod.rs | 28 +- nac3core/src/codegen/stmt.rs | 443 +++++++------- nac3core/src/codegen/test.rs | 10 +- nac3standalone/src/main.rs | 80 ++- 7 files changed, 954 insertions(+), 732 deletions(-) create mode 100644 nac3core/src/codegen/generator.rs diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index bbd298c9..edead542 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -19,7 +19,7 @@ use rustpython_parser::{ use parking_lot::{Mutex, RwLock}; use nac3core::{ - codegen::{CodeGenTask, WithCall, WorkerRegistry}, + codegen::{CodeGenTask, DefaultCodeGenerator, WithCall, WorkerRegistry}, symbol_resolver::SymbolResolver, toplevel::{composer::TopLevelComposer, DefinitionId, GenCall, TopLevelContext, TopLevelDef}, typecheck::typedef::{FunSignature, FuncArg}, @@ -423,11 +423,13 @@ impl Nac3 { .expect("couldn't write module to file"); }))); let thread_names: Vec = (0..4).map(|i| format!("module{}", i)).collect(); - let threads: Vec<_> = thread_names.iter().map(|s| s.as_str()).collect(); + let threads: Vec<_> = thread_names + .iter() + .map(|s| Box::new(DefaultCodeGenerator::new(s.to_string()))) + .collect(); py.allow_threads(|| { - let (registry, handles) = - WorkerRegistry::create_workers(&threads, top_level.clone(), f); + let (registry, handles) = WorkerRegistry::create_workers(threads, top_level.clone(), f); registry.add_task(task); registry.wait_tasks_complete(handles); }); diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 69a4d50a..7f319fd9 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -14,6 +14,8 @@ use inkwell::{ use itertools::{chain, izip, zip, Itertools}; use rustpython_parser::ast::{self, Boolop, Constant, Expr, ExprKind, Operator, StrRef}; +use super::CodeGenerator; + pub fn assert_int_val(val: BasicValueEnum<'_>) -> IntValue<'_> { if let BasicValueEnum::IntValue(v) = val { v @@ -102,172 +104,6 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { get_llvm_type(self.ctx, &mut self.unifier, self.top_level, &mut self.type_cache, ty) } - fn gen_call( - &mut self, - obj: Option<(Type, BasicValueEnum<'ctx>)>, - fun: (&FunSignature, DefinitionId), - params: Vec<(Option, BasicValueEnum<'ctx>)>, - ) -> Option> { - let definition = self.top_level.definitions.read().get(fun.1 .0).cloned().unwrap(); - let mut task = None; - let key = self.get_subst_key(obj.map(|a| a.0), fun.0, None); - let symbol = { - // make sure this lock guard is dropped at the end of this scope... - let def = definition.read(); - match &*def { - TopLevelDef::Function { instance_to_symbol, codegen_callback, .. } => { - if let Some(callback) = codegen_callback { - return callback.run(self, obj, fun, params); - } - instance_to_symbol.get(&key).cloned() - } - TopLevelDef::Class { methods, .. } => { - // TODO: what about other fields that require alloca? - let mut fun_id = None; - for (name, _, id) in methods.iter() { - if name == &"__init__".into() { - fun_id = Some(*id); - } - } - let ty = self.get_llvm_type(fun.0.ret).into_pointer_type(); - let zelf_ty: BasicTypeEnum = ty.get_element_type().try_into().unwrap(); - let zelf = self.builder.build_alloca(zelf_ty, "alloca").into(); - // call `__init__` if there is one - if let Some(fun_id) = fun_id { - let mut sign = fun.0.clone(); - sign.ret = self.primitives.none; - self.gen_call(Some((fun.0.ret, zelf)), (&sign, fun_id), params); - } - return Some(zelf); - } - } - } - .unwrap_or_else(|| { - if let TopLevelDef::Function { - name, - instance_to_symbol, - instance_to_stmt, - var_id, - resolver, - .. - } = &mut *definition.write() - { - instance_to_symbol.get(&key).cloned().unwrap_or_else(|| { - let symbol = format!("{}.{}", name, instance_to_symbol.len()); - instance_to_symbol.insert(key, symbol.clone()); - let key = self.get_subst_key(obj.map(|a| a.0), fun.0, Some(var_id)); - let instance = instance_to_stmt.get(&key).unwrap(); - let unifiers = self.top_level.unifiers.read(); - let (unifier, primitives) = &unifiers[instance.unifier_id]; - let mut unifier = Unifier::from_shared_unifier(unifier); - - let mut type_cache = [ - (self.primitives.int32, primitives.int32), - (self.primitives.int64, primitives.int64), - (self.primitives.float, primitives.float), - (self.primitives.bool, primitives.bool), - (self.primitives.none, primitives.none), - ] - .iter() - .map(|(a, b)| { - (self.unifier.get_representative(*a), unifier.get_representative(*b)) - }) - .collect(); - - let subst = fun - .0 - .vars - .iter() - .map(|(id, ty)| { - ( - *instance.subst.get(id).unwrap(), - unifier.copy_from(&mut self.unifier, *ty, &mut type_cache), - ) - }) - .collect(); - - let mut signature = FunSignature { - args: fun - .0 - .args - .iter() - .map(|arg| FuncArg { - name: arg.name, - ty: unifier.copy_from(&mut self.unifier, arg.ty, &mut type_cache), - default_value: arg.default_value.clone(), - }) - .collect(), - ret: unifier.copy_from(&mut self.unifier, fun.0.ret, &mut type_cache), - vars: fun - .0 - .vars - .iter() - .map(|(id, ty)| { - (*id, unifier.copy_from(&mut self.unifier, *ty, &mut type_cache)) - }) - .collect(), - }; - - if let Some(obj) = &obj { - signature.args.insert( - 0, - FuncArg { name: "self".into(), ty: obj.0, default_value: None }, - ); - } - - let unifier = (unifier.get_shared_unifier(), *primitives); - - task = Some(CodeGenTask { - symbol_name: symbol.clone(), - body: instance.body.clone(), - resolver: resolver.as_ref().unwrap().clone(), - calls: instance.calls.clone(), - subst, - signature, - unifier, - }); - symbol - }) - } else { - unreachable!() - } - }); - - if let Some(task) = task { - self.registry.add_task(task); - } - - let fun_val = self.module.get_function(&symbol).unwrap_or_else(|| { - let mut args = fun.0.args.clone(); - if let Some(obj) = &obj { - args.insert(0, FuncArg { name: "self".into(), ty: obj.0, default_value: None }); - } - let params = args.iter().map(|arg| self.get_llvm_type(arg.ty)).collect_vec(); - let fun_ty = if self.unifier.unioned(fun.0.ret, self.primitives.none) { - self.ctx.void_type().fn_type(¶ms, false) - } else { - self.get_llvm_type(fun.0.ret).fn_type(¶ms, false) - }; - self.module.add_function(&symbol, fun_ty, None) - }); - let mut keys = fun.0.args.clone(); - let mut mapping = HashMap::new(); - for (key, value) in params.into_iter() { - mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value); - } - // default value handling - for k in keys.into_iter() { - mapping.insert(k.name, self.gen_symbol_val(&k.default_value.unwrap())); - } - // reorder the parameters - let mut params = - fun.0.args.iter().map(|arg| mapping.remove(&arg.name).unwrap()).collect_vec(); - if let Some(obj) = obj { - params.insert(0, obj.1); - } - self.builder.build_call(fun_val, ¶ms, "call").try_as_basic_value().left() - } - fn gen_const(&mut self, value: &Constant, ty: Type) -> BasicValueEnum<'ctx> { match value { Constant::Bool(v) => { @@ -378,216 +214,392 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { _ => unimplemented!(), } } +} - pub fn gen_expr(&mut self, expr: &Expr>) -> Option> { - let zero = self.ctx.i32_type().const_int(0, false); - Some(match &expr.node { - ExprKind::Constant { value, .. } => { - let ty = expr.custom.unwrap(); - self.gen_const(value, ty) - } - ExprKind::Name { id, .. } => { - let ptr = self.var_assignment.get(id); - if let Some(ptr) = ptr { - self.builder.build_load(*ptr, "load") - } else { - let resolver = self.resolver.clone(); - resolver.get_symbol_value(*id, self).unwrap() +pub fn gen_constructor<'ctx, 'a, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + signature: &FunSignature, + def: &TopLevelDef, + params: Vec<(Option, BasicValueEnum<'ctx>)>, +) -> BasicValueEnum<'ctx> { + match def { + TopLevelDef::Class { methods, .. } => { + // TODO: what about other fields that require alloca? + let mut fun_id = None; + for (name, _, id) in methods.iter() { + if name == &"__init__".into() { + fun_id = Some(*id); } } - ExprKind::List { elts, .. } => { - // this shall be optimized later for constant primitive lists... - // we should use memcpy for that instead of generating thousands of stores - let elements = elts.iter().map(|x| self.gen_expr(x).unwrap()).collect_vec(); - let ty = if elements.is_empty() { - self.ctx.i32_type().into() - } else { - elements[0].get_type() - }; - let arr_ptr = self.builder.build_array_alloca( - ty, - self.ctx.i32_type().const_int(elements.len() as u64, false), - "tmparr", - ); - let arr_ty = self.ctx.struct_type( - &[self.ctx.i32_type().into(), ty.ptr_type(AddressSpace::Generic).into()], - false, - ); - let arr_str_ptr = self.builder.build_alloca(arr_ty, "tmparrstr"); - unsafe { - let len_ptr = - self.builder.build_in_bounds_gep(arr_str_ptr, &[zero, zero], "len_ptr"); - self.builder.build_store( - len_ptr, - self.ctx.i32_type().const_int(elements.len() as u64, false), - ); - let ptr_to_arr = self.builder.build_in_bounds_gep( - arr_str_ptr, - &[zero, self.ctx.i32_type().const_int(1, false)], - "ptr_to_arr", - ); - self.builder.build_store(ptr_to_arr, arr_ptr); - let i32_type = self.ctx.i32_type(); - for (i, v) in elements.iter().enumerate() { - let elem_ptr = self.builder.build_in_bounds_gep( - arr_ptr, - &[i32_type.const_int(i as u64, false)], - "elem_ptr", - ); - self.builder.build_store(elem_ptr, *v); - } - } - arr_str_ptr.into() + let ty = ctx.get_llvm_type(signature.ret).into_pointer_type(); + let zelf_ty: BasicTypeEnum = ty.get_element_type().try_into().unwrap(); + let zelf = ctx.builder.build_alloca(zelf_ty, "alloca").into(); + // call `__init__` if there is one + if let Some(fun_id) = fun_id { + let mut sign = signature.clone(); + sign.ret = ctx.primitives.none; + generator.gen_call(ctx, Some((signature.ret, zelf)), (&sign, fun_id), params); } - ExprKind::Tuple { elts, .. } => { - let element_val = elts.iter().map(|x| self.gen_expr(x).unwrap()).collect_vec(); - let element_ty = element_val.iter().map(BasicValueEnum::get_type).collect_vec(); - let tuple_ty = self.ctx.struct_type(&element_ty, false); - let tuple_ptr = self.builder.build_alloca(tuple_ty, "tuple"); - for (i, v) in element_val.into_iter().enumerate() { - unsafe { - let ptr = self.builder.build_in_bounds_gep( - tuple_ptr, - &[zero, self.ctx.i32_type().const_int(i as u64, false)], - "ptr", - ); - self.builder.build_store(ptr, v); - } - } - tuple_ptr.into() - } - ExprKind::Attribute { value, attr, .. } => { - // note that we would handle class methods directly in calls - let index = self.get_attr_index(value.custom.unwrap(), *attr); - let val = self.gen_expr(value).unwrap(); - let ptr = assert_pointer_val(val); - unsafe { - let ptr = self.builder.build_in_bounds_gep( - ptr, - &[zero, self.ctx.i32_type().const_int(index as u64, false)], - "attr", - ); - self.builder.build_load(ptr, "field") - } - } - ExprKind::BoolOp { op, values } => { - // requires conditional branches for short-circuiting... - let left = assert_int_val(self.gen_expr(&values[0]).unwrap()); - let current = self.builder.get_insert_block().unwrap().get_parent().unwrap(); - let a_bb = self.ctx.append_basic_block(current, "a"); - let b_bb = self.ctx.append_basic_block(current, "b"); - let cont_bb = self.ctx.append_basic_block(current, "cont"); - self.builder.build_conditional_branch(left, a_bb, b_bb); - let (a, b) = match op { - Boolop::Or => { - self.builder.position_at_end(a_bb); - let a = self.ctx.bool_type().const_int(1, false); - self.builder.build_unconditional_branch(cont_bb); - self.builder.position_at_end(b_bb); - let b = assert_int_val(self.gen_expr(&values[1]).unwrap()); - self.builder.build_unconditional_branch(cont_bb); - (a, b) - } - Boolop::And => { - self.builder.position_at_end(a_bb); - let a = assert_int_val(self.gen_expr(&values[1]).unwrap()); - self.builder.build_unconditional_branch(cont_bb); - self.builder.position_at_end(b_bb); - let b = self.ctx.bool_type().const_int(0, false); - self.builder.build_unconditional_branch(cont_bb); - (a, b) - } - }; - self.builder.position_at_end(cont_bb); - let phi = self.builder.build_phi(self.ctx.bool_type(), "phi"); - phi.add_incoming(&[(&a, a_bb), (&b, b_bb)]); - phi.as_basic_value() - } - ExprKind::BinOp { op, left, right } => { - let ty1 = self.unifier.get_representative(left.custom.unwrap()); - let ty2 = self.unifier.get_representative(right.custom.unwrap()); - let left = self.gen_expr(left).unwrap(); - let right = self.gen_expr(right).unwrap(); + zelf + } + _ => unreachable!(), + } +} - // we can directly compare the types, because we've got their representatives - // which would be unchanged until further unification, which we would never do - // when doing code generation for function instances - if ty1 == ty2 && [self.primitives.int32, self.primitives.int64].contains(&ty1) { - self.gen_int_ops(op, left, right) - } else if ty1 == ty2 && self.primitives.float == ty1 { - self.gen_float_ops(op, left, right) - } else { - unimplemented!() +pub fn gen_func_instance<'ctx, 'a>( + ctx: &mut CodeGenContext<'ctx, 'a>, + obj: Option<(Type, BasicValueEnum<'ctx>)>, + fun: (&FunSignature, &mut TopLevelDef, String), +) -> String { + if let ( + sign, + TopLevelDef::Function { + name, instance_to_symbol, instance_to_stmt, var_id, resolver, .. + }, + key, + ) = fun + { + instance_to_symbol.get(&key).cloned().unwrap_or_else(|| { + let symbol = format!("{}.{}", name, instance_to_symbol.len()); + instance_to_symbol.insert(key, symbol.clone()); + let key = ctx.get_subst_key(obj.map(|a| a.0), sign, Some(var_id)); + let instance = instance_to_stmt.get(&key).unwrap(); + let unifiers = ctx.top_level.unifiers.read(); + let (unifier, primitives) = &unifiers[instance.unifier_id]; + let mut unifier = Unifier::from_shared_unifier(unifier); + + let mut type_cache = [ + (ctx.primitives.int32, primitives.int32), + (ctx.primitives.int64, primitives.int64), + (ctx.primitives.float, primitives.float), + (ctx.primitives.bool, primitives.bool), + (ctx.primitives.none, primitives.none), + ] + .iter() + .map(|(a, b)| (ctx.unifier.get_representative(*a), unifier.get_representative(*b))) + .collect(); + + let subst = sign + .vars + .iter() + .map(|(id, ty)| { + ( + *instance.subst.get(id).unwrap(), + unifier.copy_from(&mut ctx.unifier, *ty, &mut type_cache), + ) + }) + .collect(); + + let mut signature = FunSignature { + args: sign + .args + .iter() + .map(|arg| FuncArg { + name: arg.name, + ty: unifier.copy_from(&mut ctx.unifier, arg.ty, &mut type_cache), + default_value: arg.default_value.clone(), + }) + .collect(), + ret: unifier.copy_from(&mut ctx.unifier, sign.ret, &mut type_cache), + vars: sign + .vars + .iter() + .map(|(id, ty)| { + (*id, unifier.copy_from(&mut ctx.unifier, *ty, &mut type_cache)) + }) + .collect(), + }; + + if let Some(obj) = &obj { + signature + .args + .insert(0, FuncArg { name: "self".into(), ty: obj.0, default_value: None }); + } + + let unifier = (unifier.get_shared_unifier(), *primitives); + + ctx.registry.add_task(CodeGenTask { + symbol_name: symbol.clone(), + body: instance.body.clone(), + resolver: resolver.as_ref().unwrap().clone(), + calls: instance.calls.clone(), + subst, + signature, + unifier, + }); + symbol + }) + } else { + unreachable!() + } +} + +pub fn gen_call<'ctx, 'a, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + obj: Option<(Type, BasicValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + params: Vec<(Option, BasicValueEnum<'ctx>)>, +) -> Option> { + let definition = ctx.top_level.definitions.read().get(fun.1 .0).cloned().unwrap(); + let key = ctx.get_subst_key(obj.map(|a| a.0), fun.0, None); + let symbol = { + // make sure this lock guard is dropped at the end of this scope... + let def = definition.read(); + match &*def { + TopLevelDef::Function { instance_to_symbol, codegen_callback, .. } => { + if let Some(callback) = codegen_callback { + return callback.run(ctx, obj, fun, params); + } + instance_to_symbol.get(&key).cloned() + } + TopLevelDef::Class { .. } => { + return Some(generator.gen_constructor(ctx, fun.0, &*def, params)) + } + } + } + .unwrap_or_else(|| { + generator.gen_func_instance(ctx, obj, (fun.0, &mut *definition.write(), key)) + }); + let fun_val = ctx.module.get_function(&symbol).unwrap_or_else(|| { + let mut args = fun.0.args.clone(); + if let Some(obj) = &obj { + args.insert(0, FuncArg { name: "self".into(), ty: obj.0, default_value: None }); + } + let params = args.iter().map(|arg| ctx.get_llvm_type(arg.ty)).collect_vec(); + let fun_ty = if ctx.unifier.unioned(fun.0.ret, ctx.primitives.none) { + ctx.ctx.void_type().fn_type(¶ms, false) + } else { + ctx.get_llvm_type(fun.0.ret).fn_type(¶ms, false) + }; + ctx.module.add_function(&symbol, fun_ty, None) + }); + let mut keys = fun.0.args.clone(); + let mut mapping = HashMap::new(); + for (key, value) in params.into_iter() { + mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value); + } + // default value handling + for k in keys.into_iter() { + mapping.insert(k.name, ctx.gen_symbol_val(&k.default_value.unwrap())); + } + // reorder the parameters + let mut params = fun.0.args.iter().map(|arg| mapping.remove(&arg.name).unwrap()).collect_vec(); + if let Some(obj) = obj { + params.insert(0, obj.1); + } + ctx.builder.build_call(fun_val, ¶ms, "call").try_as_basic_value().left() +} + +pub fn gen_expr<'ctx, 'a, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + expr: &Expr>, +) -> Option> { + let zero = ctx.ctx.i32_type().const_int(0, false); + Some(match &expr.node { + ExprKind::Constant { value, .. } => { + let ty = expr.custom.unwrap(); + ctx.gen_const(value, ty) + } + ExprKind::Name { id, .. } => { + let ptr = ctx.var_assignment.get(id); + if let Some(ptr) = ptr { + ctx.builder.build_load(*ptr, "load") + } else { + let resolver = ctx.resolver.clone(); + resolver.get_symbol_value(*id, ctx).unwrap() + } + } + ExprKind::List { elts, .. } => { + // this shall be optimized later for constant primitive lists... + // we should use memcpy for that instead of generating thousands of stores + let elements = elts.iter().map(|x| generator.gen_expr(ctx, x).unwrap()).collect_vec(); + let ty = if elements.is_empty() { + ctx.ctx.i32_type().into() + } else { + elements[0].get_type() + }; + let arr_ptr = ctx.builder.build_array_alloca( + ty, + ctx.ctx.i32_type().const_int(elements.len() as u64, false), + "tmparr", + ); + let arr_ty = ctx.ctx.struct_type( + &[ctx.ctx.i32_type().into(), ty.ptr_type(AddressSpace::Generic).into()], + false, + ); + let arr_str_ptr = ctx.builder.build_alloca(arr_ty, "tmparrstr"); + unsafe { + let len_ptr = + ctx.builder.build_in_bounds_gep(arr_str_ptr, &[zero, zero], "len_ptr"); + ctx.builder.build_store( + len_ptr, + ctx.ctx.i32_type().const_int(elements.len() as u64, false), + ); + let ptr_to_arr = ctx.builder.build_in_bounds_gep( + arr_str_ptr, + &[zero, ctx.ctx.i32_type().const_int(1, false)], + "ptr_to_arr", + ); + ctx.builder.build_store(ptr_to_arr, arr_ptr); + let i32_type = ctx.ctx.i32_type(); + for (i, v) in elements.iter().enumerate() { + let elem_ptr = ctx.builder.build_in_bounds_gep( + arr_ptr, + &[i32_type.const_int(i as u64, false)], + "elem_ptr", + ); + ctx.builder.build_store(elem_ptr, *v); } } - ExprKind::UnaryOp { op, operand } => { - let ty = self.unifier.get_representative(operand.custom.unwrap()); - let val = self.gen_expr(operand).unwrap(); - if ty == self.primitives.bool { - let val = assert_int_val(val); - match op { - ast::Unaryop::Invert | ast::Unaryop::Not => { - self.builder.build_not(val, "not").into() - } - _ => val.into(), - } - } else if [self.primitives.int32, self.primitives.int64].contains(&ty) { - let val = assert_int_val(val); - match op { - ast::Unaryop::USub => self.builder.build_int_neg(val, "neg").into(), - ast::Unaryop::Invert => self.builder.build_not(val, "not").into(), - ast::Unaryop::Not => self - .builder - .build_int_compare( - inkwell::IntPredicate::EQ, - val, - val.get_type().const_zero(), - "not", - ) - .into(), - _ => val.into(), - } - } else if ty == self.primitives.float { - let val = if let BasicValueEnum::FloatValue(val) = val { - val - } else { - unreachable!() - }; - match op { - ast::Unaryop::USub => self.builder.build_float_neg(val, "neg").into(), - ast::Unaryop::Not => self - .builder - .build_float_compare( - inkwell::FloatPredicate::OEQ, - val, - val.get_type().const_zero(), - "not", - ) - .into(), - _ => val.into(), - } - } else { - unimplemented!() + arr_str_ptr.into() + } + ExprKind::Tuple { elts, .. } => { + let element_val = + elts.iter().map(|x| generator.gen_expr(ctx, x).unwrap()).collect_vec(); + let element_ty = element_val.iter().map(BasicValueEnum::get_type).collect_vec(); + let tuple_ty = ctx.ctx.struct_type(&element_ty, false); + let tuple_ptr = ctx.builder.build_alloca(tuple_ty, "tuple"); + for (i, v) in element_val.into_iter().enumerate() { + unsafe { + let ptr = ctx.builder.build_in_bounds_gep( + tuple_ptr, + &[zero, ctx.ctx.i32_type().const_int(i as u64, false)], + "ptr", + ); + ctx.builder.build_store(ptr, v); } } - ExprKind::Compare { left, ops, comparators } => { - izip!( - chain(once(left.as_ref()), comparators.iter()), - comparators.iter(), - ops.iter(), - ) + tuple_ptr.into() + } + ExprKind::Attribute { value, attr, .. } => { + // note that we would handle class methods directly in calls + let index = ctx.get_attr_index(value.custom.unwrap(), *attr); + let val = generator.gen_expr(ctx, value).unwrap(); + let ptr = assert_pointer_val(val); + unsafe { + let ptr = ctx.builder.build_in_bounds_gep( + ptr, + &[zero, ctx.ctx.i32_type().const_int(index as u64, false)], + "attr", + ); + ctx.builder.build_load(ptr, "field") + } + } + ExprKind::BoolOp { op, values } => { + // requires conditional branches for short-circuiting... + let left = assert_int_val(generator.gen_expr(ctx, &values[0]).unwrap()); + let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); + let a_bb = ctx.ctx.append_basic_block(current, "a"); + let b_bb = ctx.ctx.append_basic_block(current, "b"); + let cont_bb = ctx.ctx.append_basic_block(current, "cont"); + ctx.builder.build_conditional_branch(left, a_bb, b_bb); + let (a, b) = match op { + Boolop::Or => { + ctx.builder.position_at_end(a_bb); + let a = ctx.ctx.bool_type().const_int(1, false); + ctx.builder.build_unconditional_branch(cont_bb); + ctx.builder.position_at_end(b_bb); + let b = assert_int_val(generator.gen_expr(ctx, &values[1]).unwrap()); + ctx.builder.build_unconditional_branch(cont_bb); + (a, b) + } + Boolop::And => { + ctx.builder.position_at_end(a_bb); + let a = assert_int_val(generator.gen_expr(ctx, &values[1]).unwrap()); + ctx.builder.build_unconditional_branch(cont_bb); + ctx.builder.position_at_end(b_bb); + let b = ctx.ctx.bool_type().const_int(0, false); + ctx.builder.build_unconditional_branch(cont_bb); + (a, b) + } + }; + ctx.builder.position_at_end(cont_bb); + let phi = ctx.builder.build_phi(ctx.ctx.bool_type(), "phi"); + phi.add_incoming(&[(&a, a_bb), (&b, b_bb)]); + phi.as_basic_value() + } + ExprKind::BinOp { op, left, right } => { + let ty1 = ctx.unifier.get_representative(left.custom.unwrap()); + let ty2 = ctx.unifier.get_representative(right.custom.unwrap()); + let left = generator.gen_expr(ctx, left).unwrap(); + let right = generator.gen_expr(ctx, right).unwrap(); + + // we can directly compare the types, because we've got their representatives + // which would be unchanged until further unification, which we would never do + // when doing code generation for function instances + if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) { + ctx.gen_int_ops(op, left, right) + } else if ty1 == ty2 && ctx.primitives.float == ty1 { + ctx.gen_float_ops(op, left, right) + } else { + unimplemented!() + } + } + ExprKind::UnaryOp { op, operand } => { + let ty = ctx.unifier.get_representative(operand.custom.unwrap()); + let val = generator.gen_expr(ctx, operand).unwrap(); + if ty == ctx.primitives.bool { + let val = assert_int_val(val); + match op { + ast::Unaryop::Invert | ast::Unaryop::Not => { + ctx.builder.build_not(val, "not").into() + } + _ => val.into(), + } + } else if [ctx.primitives.int32, ctx.primitives.int64].contains(&ty) { + let val = assert_int_val(val); + match op { + ast::Unaryop::USub => ctx.builder.build_int_neg(val, "neg").into(), + ast::Unaryop::Invert => ctx.builder.build_not(val, "not").into(), + ast::Unaryop::Not => ctx + .builder + .build_int_compare( + inkwell::IntPredicate::EQ, + val, + val.get_type().const_zero(), + "not", + ) + .into(), + _ => val.into(), + } + } else if ty == ctx.primitives.float { + let val = + if let BasicValueEnum::FloatValue(val) = val { val } else { unreachable!() }; + match op { + ast::Unaryop::USub => ctx.builder.build_float_neg(val, "neg").into(), + ast::Unaryop::Not => ctx + .builder + .build_float_compare( + inkwell::FloatPredicate::OEQ, + val, + val.get_type().const_zero(), + "not", + ) + .into(), + _ => val.into(), + } + } else { + unimplemented!() + } + } + ExprKind::Compare { left, ops, comparators } => { + izip!(chain(once(left.as_ref()), comparators.iter()), comparators.iter(), ops.iter(),) .fold(None, |prev, (lhs, rhs, op)| { - let ty = self.unifier.get_representative(lhs.custom.unwrap()); + let ty = ctx.unifier.get_representative(lhs.custom.unwrap()); let current = - if [self.primitives.int32, self.primitives.int64, self.primitives.bool] + if [ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.bool] .contains(&ty) { let (lhs, rhs) = if let ( BasicValueEnum::IntValue(lhs), BasicValueEnum::IntValue(rhs), - ) = - (self.gen_expr(lhs).unwrap(), self.gen_expr(rhs).unwrap()) - { + ) = ( + generator.gen_expr(ctx, lhs).unwrap(), + generator.gen_expr(ctx, rhs).unwrap(), + ) { (lhs, rhs) } else { unreachable!() @@ -601,14 +613,15 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { ast::Cmpop::GtE => inkwell::IntPredicate::SGE, _ => unreachable!(), }; - self.builder.build_int_compare(op, lhs, rhs, "cmp") - } else if ty == self.primitives.float { + ctx.builder.build_int_compare(op, lhs, rhs, "cmp") + } else if ty == ctx.primitives.float { let (lhs, rhs) = if let ( BasicValueEnum::FloatValue(lhs), BasicValueEnum::FloatValue(rhs), - ) = - (self.gen_expr(lhs).unwrap(), self.gen_expr(rhs).unwrap()) - { + ) = ( + generator.gen_expr(ctx, lhs).unwrap(), + generator.gen_expr(ctx, rhs).unwrap(), + ) { (lhs, rhs) } else { unreachable!() @@ -622,128 +635,130 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { ast::Cmpop::GtE => inkwell::FloatPredicate::OGE, _ => unreachable!(), }; - self.builder.build_float_compare(op, lhs, rhs, "cmp") + ctx.builder.build_float_compare(op, lhs, rhs, "cmp") } else { unimplemented!() }; - prev.map(|v| self.builder.build_and(v, current, "cmp")).or(Some(current)) + prev.map(|v| ctx.builder.build_and(v, current, "cmp")).or(Some(current)) }) .unwrap() .into() // as there should be at least 1 element, it should never be none - } - ExprKind::IfExp { test, body, orelse } => { - let test = assert_int_val(self.gen_expr(test).unwrap()); - let current = self.builder.get_insert_block().unwrap().get_parent().unwrap(); - let then_bb = self.ctx.append_basic_block(current, "then"); - let else_bb = self.ctx.append_basic_block(current, "else"); - let cont_bb = self.ctx.append_basic_block(current, "cont"); - self.builder.build_conditional_branch(test, then_bb, else_bb); - self.builder.position_at_end(then_bb); - let a = self.gen_expr(body).unwrap(); - self.builder.build_unconditional_branch(cont_bb); - self.builder.position_at_end(else_bb); - let b = self.gen_expr(orelse).unwrap(); - self.builder.build_unconditional_branch(cont_bb); - self.builder.position_at_end(cont_bb); - let phi = self.builder.build_phi(a.get_type(), "ifexpr"); - phi.add_incoming(&[(&a, then_bb), (&b, else_bb)]); - phi.as_basic_value() - } - ExprKind::Call { func, args, keywords } => { - let mut params = - args.iter().map(|arg| (None, self.gen_expr(arg).unwrap())).collect_vec(); - let kw_iter = keywords.iter().map(|kw| { - (Some(*kw.node.arg.as_ref().unwrap()), self.gen_expr(&kw.node.value).unwrap()) - }); - params.extend(kw_iter); - let call = self.calls.get(&expr.location.into()); - let signature = match call { - Some(call) => self.unifier.get_call_signature(*call).unwrap(), - None => { - let ty = func.custom.unwrap(); - if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(ty) { - sign.borrow().clone() - } else { - unreachable!() - } - } - }; - match &func.as_ref().node { - ExprKind::Name { id, .. } => { - // TODO: handle primitive casts and function pointers - let fun = - self.resolver.get_identifier_def(*id).expect("Unknown identifier"); - return self.gen_call(None, (&signature, fun), params); - } - ExprKind::Attribute { value, attr, .. } => { - let val = self.gen_expr(value).unwrap(); - let id = if let TypeEnum::TObj { obj_id, .. } = - &*self.unifier.get_ty(value.custom.unwrap()) - { - *obj_id - } else { - unreachable!() - }; - let fun_id = { - let defs = self.top_level.definitions.read(); - let obj_def = defs.get(id.0).unwrap().read(); - if let TopLevelDef::Class { methods, .. } = &*obj_def { - let mut fun_id = None; - for (name, _, id) in methods.iter() { - if name == attr { - fun_id = Some(*id); - } - } - fun_id.unwrap() - } else { - unreachable!() - } - }; - return self.gen_call( - Some((value.custom.unwrap(), val)), - (&signature, fun_id), - params, - ); - } - _ => unimplemented!(), - } - } - ExprKind::Subscript { value, slice, .. } => { - if let TypeEnum::TList { .. } = &*self.unifier.get_ty(value.custom.unwrap()) { - if let ExprKind::Slice { .. } = slice.node { - unimplemented!() + } + ExprKind::IfExp { test, body, orelse } => { + let test = assert_int_val(generator.gen_expr(ctx, test).unwrap()); + let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); + let then_bb = ctx.ctx.append_basic_block(current, "then"); + let else_bb = ctx.ctx.append_basic_block(current, "else"); + let cont_bb = ctx.ctx.append_basic_block(current, "cont"); + ctx.builder.build_conditional_branch(test, then_bb, else_bb); + ctx.builder.position_at_end(then_bb); + let a = generator.gen_expr(ctx, body).unwrap(); + ctx.builder.build_unconditional_branch(cont_bb); + ctx.builder.position_at_end(else_bb); + let b = generator.gen_expr(ctx, orelse).unwrap(); + ctx.builder.build_unconditional_branch(cont_bb); + ctx.builder.position_at_end(cont_bb); + let phi = ctx.builder.build_phi(a.get_type(), "ifexpr"); + phi.add_incoming(&[(&a, then_bb), (&b, else_bb)]); + phi.as_basic_value() + } + ExprKind::Call { func, args, keywords } => { + let mut params = + args.iter().map(|arg| (None, generator.gen_expr(ctx, arg).unwrap())).collect_vec(); + let kw_iter = keywords.iter().map(|kw| { + ( + Some(*kw.node.arg.as_ref().unwrap()), + generator.gen_expr(ctx, &kw.node.value).unwrap(), + ) + }); + params.extend(kw_iter); + let call = ctx.calls.get(&expr.location.into()); + let signature = match call { + Some(call) => ctx.unifier.get_call_signature(*call).unwrap(), + None => { + let ty = func.custom.unwrap(); + if let TypeEnum::TFunc(sign) = &*ctx.unifier.get_ty(ty) { + sign.borrow().clone() } else { - // TODO: bound check - let i32_type = self.ctx.i32_type(); - let v = assert_pointer_val(self.gen_expr(value).unwrap()); - let index = assert_int_val(self.gen_expr(slice).unwrap()); - unsafe { - let ptr_to_arr = self.builder.build_in_bounds_gep( - v, - &[i32_type.const_zero(), i32_type.const_int(1, false)], - "ptr_to_arr", - ); - let arr_ptr = - assert_pointer_val(self.builder.build_load(ptr_to_arr, "loadptr")); - let ptr = self.builder.build_gep(arr_ptr, &[index], "loadarrgep"); - self.builder.build_load(ptr, "loadarr") - } - } - } else { - let i32_type = self.ctx.i32_type(); - let v = assert_pointer_val(self.gen_expr(value).unwrap()); - let index = assert_int_val(self.gen_expr(slice).unwrap()); - unsafe { - let ptr_to_elem = self.builder.build_in_bounds_gep( - v, - &[i32_type.const_zero(), index], - "ptr_to_elem", - ); - self.builder.build_load(ptr_to_elem, "loadelem") + unreachable!() } } + }; + match &func.as_ref().node { + ExprKind::Name { id, .. } => { + // TODO: handle primitive casts and function pointers + let fun = ctx.resolver.get_identifier_def(*id).expect("Unknown identifier"); + return generator.gen_call(ctx, None, (&signature, fun), params); + } + ExprKind::Attribute { value, attr, .. } => { + let val = generator.gen_expr(ctx, value).unwrap(); + let id = if let TypeEnum::TObj { obj_id, .. } = + &*ctx.unifier.get_ty(value.custom.unwrap()) + { + *obj_id + } else { + unreachable!() + }; + let fun_id = { + let defs = ctx.top_level.definitions.read(); + let obj_def = defs.get(id.0).unwrap().read(); + if let TopLevelDef::Class { methods, .. } = &*obj_def { + let mut fun_id = None; + for (name, _, id) in methods.iter() { + if name == attr { + fun_id = Some(*id); + } + } + fun_id.unwrap() + } else { + unreachable!() + } + }; + return generator.gen_call( + ctx, + Some((value.custom.unwrap(), val)), + (&signature, fun_id), + params, + ); + } + _ => unimplemented!(), } - _ => unimplemented!(), - }) - } + } + ExprKind::Subscript { value, slice, .. } => { + if let TypeEnum::TList { .. } = &*ctx.unifier.get_ty(value.custom.unwrap()) { + if let ExprKind::Slice { .. } = slice.node { + unimplemented!() + } else { + // TODO: bound check + let i32_type = ctx.ctx.i32_type(); + let v = assert_pointer_val(generator.gen_expr(ctx, value).unwrap()); + let index = assert_int_val(generator.gen_expr(ctx, slice).unwrap()); + unsafe { + let ptr_to_arr = ctx.builder.build_in_bounds_gep( + v, + &[i32_type.const_zero(), i32_type.const_int(1, false)], + "ptr_to_arr", + ); + let arr_ptr = + assert_pointer_val(ctx.builder.build_load(ptr_to_arr, "loadptr")); + let ptr = ctx.builder.build_gep(arr_ptr, &[index], "loadarrgep"); + ctx.builder.build_load(ptr, "loadarr") + } + } + } else { + let i32_type = ctx.ctx.i32_type(); + let v = assert_pointer_val(generator.gen_expr(ctx, value).unwrap()); + let index = assert_int_val(generator.gen_expr(ctx, slice).unwrap()); + unsafe { + let ptr_to_elem = ctx.builder.build_in_bounds_gep( + v, + &[i32_type.const_zero(), index], + "ptr_to_elem", + ); + ctx.builder.build_load(ptr_to_elem, "loadelem") + } + } + } + _ => unimplemented!(), + }) } diff --git a/nac3core/src/codegen/generator.rs b/nac3core/src/codegen/generator.rs new file mode 100644 index 00000000..023c7476 --- /dev/null +++ b/nac3core/src/codegen/generator.rs @@ -0,0 +1,142 @@ +use crate::{ + codegen::{expr::*, stmt::*, CodeGenContext}, + toplevel::{DefinitionId, TopLevelDef}, + typecheck::typedef::{FunSignature, Type}, +}; +use inkwell::values::{BasicValueEnum, PointerValue}; +use rustpython_parser::ast::{Expr, Stmt, StrRef}; + +pub trait CodeGenerator { + /// Return the module name for the code generator. + fn get_name(&self) -> &str; + + /// Generate function call and returns the function return value. + /// - obj: Optional object for method call. + /// - fun: Function signature and definition ID. + /// - params: Function parameters. Note that this does not include the object even if the + /// function is a class method. + fn gen_call<'ctx, 'a>( + &mut self, + ctx: &mut CodeGenContext<'ctx, 'a>, + obj: Option<(Type, BasicValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + params: Vec<(Option, BasicValueEnum<'ctx>)>, + ) -> Option> { + gen_call(self, ctx, obj, fun, params) + } + + /// Generate object constructor and returns the constructed object. + /// - signature: Function signature of the contructor. + /// - def: Class definition for the constructor class. + /// - params: Function parameters. + fn gen_constructor<'ctx, 'a>( + &mut self, + ctx: &mut CodeGenContext<'ctx, 'a>, + signature: &FunSignature, + def: &TopLevelDef, + params: Vec<(Option, BasicValueEnum<'ctx>)>, + ) -> BasicValueEnum<'ctx> { + gen_constructor(self, ctx, signature, def, params) + } + + /// Generate a function instance. + /// - obj: Optional object for method call. + /// - fun: Function signature, definition ID and the substitution key. + /// - params: Function parameters. Note that this does not include the object even if the + /// function is a class method. + /// Note that this function should check if the function is generated in another thread (due to + /// possible race condition), see the default implementation for an example. + fn gen_func_instance<'ctx, 'a>( + &mut self, + ctx: &mut CodeGenContext<'ctx, 'a>, + obj: Option<(Type, BasicValueEnum<'ctx>)>, + fun: (&FunSignature, &mut TopLevelDef, String), + ) -> String { + gen_func_instance(ctx, obj, fun) + } + + /// Generate the code for an expression. + fn gen_expr<'ctx, 'a>( + &mut self, + ctx: &mut CodeGenContext<'ctx, 'a>, + expr: &Expr>, + ) -> Option> { + gen_expr(self, ctx, expr) + } + + /// Allocate memory for a variable and return a pointer pointing to it. + /// The default implementation places the allocations at the start of the function. + fn gen_var_alloc<'ctx, 'a>( + &mut self, + ctx: &mut CodeGenContext<'ctx, 'a>, + ty: Type, + ) -> PointerValue<'ctx> { + gen_var(ctx, ty) + } + + /// Return a pointer pointing to the target of the expression. + fn gen_store_target<'ctx, 'a>( + &mut self, + ctx: &mut CodeGenContext<'ctx, 'a>, + pattern: &Expr>, + ) -> PointerValue<'ctx> { + gen_store_target(self, ctx, pattern) + } + + /// Generate code for an assignment expression. + fn gen_assign<'ctx, 'a>( + &mut self, + ctx: &mut CodeGenContext<'ctx, 'a>, + target: &Expr>, + value: BasicValueEnum<'ctx>, + ) { + gen_assign(self, ctx, target, value) + } + + /// Generate code for a while expression. + /// Return true if the while loop must early return + fn gen_while<'ctx, 'a>( + &mut self, + ctx: &mut CodeGenContext<'ctx, 'a>, + stmt: &Stmt>, + ) -> bool { + gen_while(self, ctx, stmt); + false + } + + /// Generate code for an if expression. + /// Return true if the statement must early return + fn gen_if<'ctx, 'a>( + &mut self, + ctx: &mut CodeGenContext<'ctx, 'a>, + stmt: &Stmt>, + ) -> bool { + gen_if(self, ctx, stmt) + } + + /// Generate code for a statement + /// Return true if the statement must early return + fn gen_stmt<'ctx, 'a>( + &mut self, + ctx: &mut CodeGenContext<'ctx, 'a>, + stmt: &Stmt>, + ) -> bool { + gen_stmt(self, ctx, stmt) + } +} + +pub struct DefaultCodeGenerator { + name: String, +} + +impl DefaultCodeGenerator { + pub fn new(name: String) -> DefaultCodeGenerator { + DefaultCodeGenerator { name } + } +} + +impl CodeGenerator for DefaultCodeGenerator { + fn get_name(&self) -> &str { + &self.name + } +} diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 81bf3e83..04039028 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -27,11 +27,14 @@ use std::sync::{ use std::thread; mod expr; +mod generator; mod stmt; #[cfg(test)] mod test; +pub use generator::{CodeGenerator, DefaultCodeGenerator}; + pub struct CodeGenContext<'ctx, 'a> { pub ctx: &'ctx Context, pub builder: Builder<'ctx>, @@ -77,8 +80,8 @@ pub struct WorkerRegistry { } impl WorkerRegistry { - pub fn create_workers( - names: &[&str], + pub fn create_workers( + generators: Vec>, top_level_ctx: Arc, f: Arc, ) -> (Arc, Vec>) { @@ -89,21 +92,20 @@ impl WorkerRegistry { let registry = Arc::new(WorkerRegistry { sender: Arc::new(sender), receiver: Arc::new(receiver), - thread_count: names.len(), + thread_count: generators.len(), panicked: AtomicBool::new(false), task_count, wait_condvar, }); let mut handles = Vec::new(); - for name in names.iter() { + for mut generator in generators.into_iter() { let top_level_ctx = top_level_ctx.clone(); let registry = registry.clone(); let registry2 = registry.clone(); - let name = name.to_string(); let f = f.clone(); let handle = thread::spawn(move || { - registry.worker_thread(name, top_level_ctx, f); + registry.worker_thread(generator.as_mut(), top_level_ctx, f); }); let handle = thread::spawn(move || { if let Err(e) = handle.join() { @@ -156,18 +158,19 @@ impl WorkerRegistry { self.sender.send(Some(task)).unwrap(); } - fn worker_thread( + fn worker_thread( &self, - module_name: String, + generator: &mut G, top_level_ctx: Arc, f: Arc, ) { let context = Context::create(); let mut builder = context.create_builder(); - let mut module = context.create_module(&module_name); + let mut module = context.create_module(generator.get_name()); while let Some(task) = self.receiver.recv().unwrap() { - let result = gen_func(&context, self, builder, module, task, top_level_ctx.clone()); + let result = + gen_func(&context, generator, self, builder, module, task, top_level_ctx.clone()); builder = result.0; module = result.1; *self.task_count.lock() -= 1; @@ -243,8 +246,9 @@ fn get_llvm_type<'ctx>( }) } -pub fn gen_func<'ctx>( +pub fn gen_func<'ctx, G: CodeGenerator + ?Sized>( context: &'ctx Context, + generator: &mut G, registry: &WorkerRegistry, builder: Builder<'ctx>, module: Module<'ctx>, @@ -351,7 +355,7 @@ pub fn gen_func<'ctx>( let mut returned = false; for stmt in task.body.iter() { - returned = code_gen_context.gen_stmt(stmt); + returned = generator.gen_stmt(&mut code_gen_context, stmt); if returned { break; } diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index b40d8d1c..806472b9 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -1,221 +1,250 @@ use super::{ expr::{assert_int_val, assert_pointer_val}, - CodeGenContext, + CodeGenContext, CodeGenerator, }; use crate::typecheck::typedef::Type; use inkwell::values::{BasicValue, BasicValueEnum, PointerValue}; use rustpython_parser::ast::{Expr, ExprKind, Stmt, StmtKind}; -impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { - fn gen_var(&mut self, ty: Type) -> PointerValue<'ctx> { - // put the alloca in init block - let current = self.builder.get_insert_block().unwrap(); - // position before the last branching instruction... - self.builder.position_before(&self.init_bb.get_last_instruction().unwrap()); - let ty = self.get_llvm_type(ty); - let ptr = self.builder.build_alloca(ty, "tmp"); - self.builder.position_at_end(current); - ptr - } +pub fn gen_var<'ctx, 'a>(ctx: &mut CodeGenContext<'ctx, 'a>, ty: Type) -> PointerValue<'ctx> { + // put the alloca in init block + let current = ctx.builder.get_insert_block().unwrap(); + // position before the last branching instruction... + ctx.builder.position_before(&ctx.init_bb.get_last_instruction().unwrap()); + let ty = ctx.get_llvm_type(ty); + let ptr = ctx.builder.build_alloca(ty, "tmp"); + ctx.builder.position_at_end(current); + ptr +} - fn parse_pattern(&mut self, pattern: &Expr>) -> PointerValue<'ctx> { - // very similar to gen_expr, but we don't do an extra load at the end - // and we flatten nested tuples - match &pattern.node { - ExprKind::Name { id, .. } => { - self.var_assignment.get(id).cloned().unwrap_or_else(|| { - let ptr = self.gen_var(pattern.custom.unwrap()); - self.var_assignment.insert(*id, ptr); - ptr - }) - } - ExprKind::Attribute { value, attr, .. } => { - let index = self.get_attr_index(value.custom.unwrap(), *attr); - let val = self.gen_expr(value).unwrap(); - let ptr = if let BasicValueEnum::PointerValue(v) = val { - v - } else { - unreachable!(); - }; - unsafe { - self.builder.build_in_bounds_gep( - ptr, - &[ - self.ctx.i32_type().const_zero(), - self.ctx.i32_type().const_int(index as u64, false), - ], - "attr", - ) - } - } - ExprKind::Subscript { value, slice, .. } => { - let i32_type = self.ctx.i32_type(); - let v = assert_pointer_val(self.gen_expr(value).unwrap()); - let index = assert_int_val(self.gen_expr(slice).unwrap()); - unsafe { - let ptr_to_arr = self.builder.build_in_bounds_gep( - v, - &[i32_type.const_zero(), i32_type.const_int(1, false)], - "ptr_to_arr", - ); - let arr_ptr = - assert_pointer_val(self.builder.build_load(ptr_to_arr, "loadptr")); - self.builder.build_gep(arr_ptr, &[index], "loadarrgep") - } - } - _ => unreachable!(), - } - } - - fn gen_assignment(&mut self, target: &Expr>, value: BasicValueEnum<'ctx>) { - let i32_type = self.ctx.i32_type(); - if let ExprKind::Tuple { elts, .. } = &target.node { - if let BasicValueEnum::PointerValue(ptr) = value { - for (i, elt) in elts.iter().enumerate() { - unsafe { - let t = self.builder.build_in_bounds_gep( - ptr, - &[i32_type.const_zero(), i32_type.const_int(i as u64, false)], - "elem", - ); - let v = self.builder.build_load(t, "tmpload"); - self.gen_assignment(elt, v); - } - } +pub fn gen_store_target<'ctx, 'a, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + pattern: &Expr>, +) -> PointerValue<'ctx> { + // very similar to gen_expr, but we don't do an extra load at the end + // and we flatten nested tuples + match &pattern.node { + ExprKind::Name { id, .. } => ctx.var_assignment.get(id).cloned().unwrap_or_else(|| { + let ptr = generator.gen_var_alloc(ctx, pattern.custom.unwrap()); + ctx.var_assignment.insert(*id, ptr); + ptr + }), + ExprKind::Attribute { value, attr, .. } => { + let index = ctx.get_attr_index(value.custom.unwrap(), *attr); + let val = generator.gen_expr(ctx, value).unwrap(); + let ptr = if let BasicValueEnum::PointerValue(v) = val { + v } else { - unreachable!() + unreachable!(); + }; + unsafe { + ctx.builder.build_in_bounds_gep( + ptr, + &[ + ctx.ctx.i32_type().const_zero(), + ctx.ctx.i32_type().const_int(index as u64, false), + ], + "attr", + ) } - } else { - let ptr = self.parse_pattern(target); - self.builder.build_store(ptr, value); } - } - - // return true if it contains terminator - pub fn gen_stmt(&mut self, stmt: &Stmt>) -> bool { - match &stmt.node { - StmtKind::Pass => {}, - StmtKind::Expr { value } => { - self.gen_expr(value); + ExprKind::Subscript { value, slice, .. } => { + let i32_type = ctx.ctx.i32_type(); + let v = assert_pointer_val(generator.gen_expr(ctx, value).unwrap()); + let index = assert_int_val(generator.gen_expr(ctx, slice).unwrap()); + unsafe { + let ptr_to_arr = ctx.builder.build_in_bounds_gep( + v, + &[i32_type.const_zero(), i32_type.const_int(1, false)], + "ptr_to_arr", + ); + let arr_ptr = assert_pointer_val(ctx.builder.build_load(ptr_to_arr, "loadptr")); + ctx.builder.build_gep(arr_ptr, &[index], "loadarrgep") } - StmtKind::Return { value } => { - let value = value.as_ref().map(|v| self.gen_expr(v).unwrap()); - let value = value.as_ref().map(|v| v as &dyn BasicValue); - self.builder.build_return(value); - return true; - } - StmtKind::AnnAssign { target, value, .. } => { - if let Some(value) = value { - let value = self.gen_expr(value).unwrap(); - self.gen_assignment(target, value); - } - } - StmtKind::Assign { targets, value, .. } => { - let value = self.gen_expr(value).unwrap(); - for target in targets.iter() { - self.gen_assignment(target, value); - } - } - StmtKind::Continue => { - self.builder.build_unconditional_branch(self.loop_bb.unwrap().0); - return true; - } - StmtKind::Break => { - self.builder.build_unconditional_branch(self.loop_bb.unwrap().1); - return true; - } - StmtKind::If { test, body, orelse } => { - let current = self.builder.get_insert_block().unwrap().get_parent().unwrap(); - let test_bb = self.ctx.append_basic_block(current, "test"); - let body_bb = self.ctx.append_basic_block(current, "body"); - let mut cont_bb = None; // self.ctx.append_basic_block(current, "cont"); - // if there is no orelse, we just go to cont_bb - let orelse_bb = if orelse.is_empty() { - cont_bb = Some(self.ctx.append_basic_block(current, "cont")); - cont_bb.unwrap() - } else { - self.ctx.append_basic_block(current, "orelse") - }; - self.builder.build_unconditional_branch(test_bb); - self.builder.position_at_end(test_bb); - let test = self.gen_expr(test).unwrap(); - if let BasicValueEnum::IntValue(test) = test { - self.builder.build_conditional_branch(test, body_bb, orelse_bb); - } else { - unreachable!() - }; - self.builder.position_at_end(body_bb); - let mut exited = false; - for stmt in body.iter() { - exited = self.gen_stmt(stmt); - if exited { - break; - } - } - if !exited { - if cont_bb.is_none() { - cont_bb = Some(self.ctx.append_basic_block(current, "cont")); - } - self.builder.build_unconditional_branch(cont_bb.unwrap()); - } - if !orelse.is_empty() { - exited = false; - self.builder.position_at_end(orelse_bb); - for stmt in orelse.iter() { - exited = self.gen_stmt(stmt); - if exited { - break; - } - } - if !exited { - if cont_bb.is_none() { - cont_bb = Some(self.ctx.append_basic_block(current, "cont")); - } - self.builder.build_unconditional_branch(cont_bb.unwrap()); - } - } - if let Some(cont_bb) = cont_bb { - self.builder.position_at_end(cont_bb); - } - } - StmtKind::While { test, body, orelse } => { - let current = self.builder.get_insert_block().unwrap().get_parent().unwrap(); - let test_bb = self.ctx.append_basic_block(current, "test"); - let body_bb = self.ctx.append_basic_block(current, "body"); - let cont_bb = self.ctx.append_basic_block(current, "cont"); - // if there is no orelse, we just go to cont_bb - let orelse_bb = if orelse.is_empty() { - cont_bb - } else { - self.ctx.append_basic_block(current, "orelse") - }; - // store loop bb information and restore it later - let loop_bb = self.loop_bb.replace((test_bb, cont_bb)); - self.builder.build_unconditional_branch(test_bb); - self.builder.position_at_end(test_bb); - let test = self.gen_expr(test).unwrap(); - if let BasicValueEnum::IntValue(test) = test { - self.builder.build_conditional_branch(test, body_bb, orelse_bb); - } else { - unreachable!() - }; - self.builder.position_at_end(body_bb); - for stmt in body.iter() { - self.gen_stmt(stmt); - } - self.builder.build_unconditional_branch(test_bb); - if !orelse.is_empty() { - self.builder.position_at_end(orelse_bb); - for stmt in orelse.iter() { - self.gen_stmt(stmt); - } - self.builder.build_unconditional_branch(cont_bb); - } - self.builder.position_at_end(cont_bb); - self.loop_bb = loop_bb; - } - _ => unimplemented!("{:?}", stmt), - }; - false + } + _ => unreachable!(), } } + +pub fn gen_assign<'ctx, 'a, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + target: &Expr>, + value: BasicValueEnum<'ctx>, +) { + let i32_type = ctx.ctx.i32_type(); + if let ExprKind::Tuple { elts, .. } = &target.node { + if let BasicValueEnum::PointerValue(ptr) = value { + for (i, elt) in elts.iter().enumerate() { + unsafe { + let t = ctx.builder.build_in_bounds_gep( + ptr, + &[i32_type.const_zero(), i32_type.const_int(i as u64, false)], + "elem", + ); + let v = ctx.builder.build_load(t, "tmpload"); + generator.gen_assign(ctx, elt, v); + } + } + } else { + unreachable!() + } + } else { + let ptr = generator.gen_store_target(ctx, target); + ctx.builder.build_store(ptr, value); + } +} + +pub fn gen_while<'ctx, 'a, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + stmt: &Stmt>, +) { + if let StmtKind::While { test, body, orelse } = &stmt.node { + let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); + let test_bb = ctx.ctx.append_basic_block(current, "test"); + let body_bb = ctx.ctx.append_basic_block(current, "body"); + let cont_bb = ctx.ctx.append_basic_block(current, "cont"); + // if there is no orelse, we just go to cont_bb + let orelse_bb = + if orelse.is_empty() { cont_bb } else { ctx.ctx.append_basic_block(current, "orelse") }; + // store loop bb information and restore it later + let loop_bb = ctx.loop_bb.replace((test_bb, cont_bb)); + ctx.builder.build_unconditional_branch(test_bb); + ctx.builder.position_at_end(test_bb); + let test = generator.gen_expr(ctx, test).unwrap(); + if let BasicValueEnum::IntValue(test) = test { + ctx.builder.build_conditional_branch(test, body_bb, orelse_bb); + } else { + unreachable!() + }; + ctx.builder.position_at_end(body_bb); + for stmt in body.iter() { + generator.gen_stmt(ctx, stmt); + } + ctx.builder.build_unconditional_branch(test_bb); + if !orelse.is_empty() { + ctx.builder.position_at_end(orelse_bb); + for stmt in orelse.iter() { + generator.gen_stmt(ctx, stmt); + } + ctx.builder.build_unconditional_branch(cont_bb); + } + ctx.builder.position_at_end(cont_bb); + ctx.loop_bb = loop_bb; + } else { + unreachable!() + } +} + +pub fn gen_if<'ctx, 'a, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + stmt: &Stmt>, +) -> bool { + if let StmtKind::If { test, body, orelse } = &stmt.node { + let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); + let test_bb = ctx.ctx.append_basic_block(current, "test"); + let body_bb = ctx.ctx.append_basic_block(current, "body"); + let mut cont_bb = None; + // if there is no orelse, we just go to cont_bb + let orelse_bb = if orelse.is_empty() { + cont_bb = Some(ctx.ctx.append_basic_block(current, "cont")); + cont_bb.unwrap() + } else { + ctx.ctx.append_basic_block(current, "orelse") + }; + ctx.builder.build_unconditional_branch(test_bb); + ctx.builder.position_at_end(test_bb); + let test = generator.gen_expr(ctx, test).unwrap(); + if let BasicValueEnum::IntValue(test) = test { + ctx.builder.build_conditional_branch(test, body_bb, orelse_bb); + } else { + unreachable!() + }; + ctx.builder.position_at_end(body_bb); + let mut exited = false; + for stmt in body.iter() { + exited = generator.gen_stmt(ctx, stmt); + if exited { + break; + } + } + if !exited { + if cont_bb.is_none() { + cont_bb = Some(ctx.ctx.append_basic_block(current, "cont")); + } + ctx.builder.build_unconditional_branch(cont_bb.unwrap()); + } + let then_exited = exited; + let else_exited = if !orelse.is_empty() { + exited = false; + ctx.builder.position_at_end(orelse_bb); + for stmt in orelse.iter() { + exited = generator.gen_stmt(ctx, stmt); + if exited { + break; + } + } + if !exited { + if cont_bb.is_none() { + cont_bb = Some(ctx.ctx.append_basic_block(current, "cont")); + } + ctx.builder.build_unconditional_branch(cont_bb.unwrap()); + } + exited + } else { + false + }; + if let Some(cont_bb) = cont_bb { + ctx.builder.position_at_end(cont_bb); + } + then_exited && else_exited + } else { + unreachable!() + } +} + +pub fn gen_stmt<'ctx, 'a, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + stmt: &Stmt>, +) -> bool { + match &stmt.node { + StmtKind::Pass => {} + StmtKind::Expr { value } => { + generator.gen_expr(ctx, value); + } + StmtKind::Return { value } => { + let value = value.as_ref().map(|v| generator.gen_expr(ctx, v).unwrap()); + let value = value.as_ref().map(|v| v as &dyn BasicValue); + ctx.builder.build_return(value); + return true; + } + StmtKind::AnnAssign { target, value, .. } => { + if let Some(value) = value { + let value = generator.gen_expr(ctx, value).unwrap(); + generator.gen_assign(ctx, target, value); + } + } + StmtKind::Assign { targets, value, .. } => { + let value = generator.gen_expr(ctx, value).unwrap(); + for target in targets.iter() { + generator.gen_assign(ctx, target, value); + } + } + StmtKind::Continue => { + ctx.builder.build_unconditional_branch(ctx.loop_bb.unwrap().0); + return true; + } + StmtKind::Break => { + ctx.builder.build_unconditional_branch(ctx.loop_bb.unwrap().1); + return true; + } + StmtKind::If { .. } => return generator.gen_if(ctx, stmt), + StmtKind::While { .. } => return generator.gen_while(ctx, stmt), + _ => unimplemented!() + }; + false +} diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 068d6790..9682c69d 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -1,5 +1,5 @@ use crate::{ - codegen::{CodeGenTask, WithCall, WorkerRegistry, CodeGenContext}, + codegen::{CodeGenTask, WithCall, WorkerRegistry, CodeGenContext, DefaultCodeGenerator}, location::Location, symbol_resolver::SymbolResolver, toplevel::{ @@ -72,7 +72,7 @@ fn test_primitives() { class_names: Default::default(), }) as Arc; - let threads = ["test"]; + let threads = vec![DefaultCodeGenerator::new("test".into()).into()]; let signature = FunSignature { args: vec![ FuncArg { name: "a".into(), ty: primitives.int32, default_value: None }, @@ -186,7 +186,7 @@ fn test_primitives() { .trim(); assert_eq!(expected, module.print_to_string().to_str().unwrap().trim()); }))); - let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f); + let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, f); registry.add_task(task); registry.wait_tasks_complete(handles); } @@ -245,7 +245,7 @@ fn test_simple_call() { unreachable!() } - let threads = ["test"]; + let threads = vec![DefaultCodeGenerator::new("test".into()).into()]; let mut function_data = FunctionData { resolver: resolver.clone(), bound_variables: Vec::new(), @@ -351,7 +351,7 @@ fn test_simple_call() { .trim(); assert_eq!(expected, module.print_to_string().to_str().unwrap().trim()); }))); - let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f); + let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, f); registry.add_task(task); registry.wait_tasks_complete(handles); } diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index 65740ba0..83839edc 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -1,12 +1,16 @@ -use std::env; -use std::fs; -use inkwell::{OptimizationLevel, passes::{PassManager, PassManagerBuilder}, targets::*}; +use inkwell::{ + passes::{PassManager, PassManagerBuilder}, + targets::*, + OptimizationLevel, +}; use nac3core::typecheck::type_inferencer::PrimitiveStore; use rustpython_parser::parser; +use std::env; +use std::fs; use std::{collections::HashMap, path::Path, sync::Arc, time::SystemTime}; use nac3core::{ - codegen::{CodeGenTask, WithCall, WorkerRegistry}, + codegen::{CodeGenTask, DefaultCodeGenerator, WithCall, WorkerRegistry}, symbol_resolver::SymbolResolver, toplevel::{composer::TopLevelComposer, TopLevelDef}, typecheck::typedef::FunSignature, @@ -17,7 +21,10 @@ use basic_symbol_resolver::*; fn main() { let demo_name = env::args().nth(1).unwrap(); - let threads: u32 = env::args().nth(2).map(|s| str::parse(&s).unwrap()).unwrap_or(1); + let threads: u32 = env::args() + .nth(2) + .map(|s| str::parse(&s).unwrap()) + .unwrap_or(1); let start = SystemTime::now(); @@ -29,7 +36,7 @@ fn main() { println!("Cannot open input file: {}", err); return; } - }; + }; let primitive: PrimitiveStore = TopLevelComposer::make_primitives().0; let (mut composer, builtins_def, builtins_ty) = TopLevelComposer::new(vec![]); @@ -38,23 +45,27 @@ fn main() { id_to_type: builtins_ty.into(), id_to_def: builtins_def.into(), class_names: Default::default(), - }.into(); - let resolver = Arc::new( - Resolver(internal_resolver.clone()) - ) as Arc; + } + .into(); + let resolver = + Arc::new(Resolver(internal_resolver.clone())) as Arc; let setup_time = SystemTime::now(); - println!("setup time: {}ms", setup_time.duration_since(start).unwrap().as_millis()); + println!( + "setup time: {}ms", + setup_time.duration_since(start).unwrap().as_millis() + ); let parser_result = parser::parse_program(&program).unwrap(); let parse_time = SystemTime::now(); - println!("parse time: {}ms", parse_time.duration_since(setup_time).unwrap().as_millis()); + println!( + "parse time: {}ms", + parse_time.duration_since(setup_time).unwrap().as_millis() + ); for stmt in parser_result.into_iter() { - let (name, def_id, ty) = composer.register_top_level( - stmt, - Some(resolver.clone()), - "__main__".into(), - ).unwrap(); + let (name, def_id, ty) = composer + .register_top_level(stmt, Some(resolver.clone()), "__main__".into()) + .unwrap(); internal_resolver.add_id_def(name, def_id); if let Some(ty) = ty { @@ -64,7 +75,13 @@ fn main() { composer.start_analysis(true).unwrap(); let analysis_time = SystemTime::now(); - println!("analysis time: {}ms", analysis_time.duration_since(parse_time).unwrap().as_millis()); + println!( + "analysis time: {}ms", + analysis_time + .duration_since(parse_time) + .unwrap() + .as_millis() + ); let top_level = Arc::new(composer.make_top_level_context()); @@ -119,19 +136,32 @@ fn main() { ) .expect("couldn't create target machine"); target_machine - .write_to_file(module, FileType::Object, Path::new(&format!("{}.o", module.get_name().to_str().unwrap()))) + .write_to_file( + module, + FileType::Object, + Path::new(&format!("{}.o", module.get_name().to_str().unwrap())), + ) .expect("couldn't write module to file"); // println!("IR:\n{}", module.print_to_string().to_str().unwrap()); - }))); - let threads: Vec = (0..threads).map(|i| format!("module{}", i)).collect(); - let threads: Vec<_> = threads.iter().map(|s| s.as_str()).collect(); - let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f); + let threads = (0..threads) + .map(|i| Box::new(DefaultCodeGenerator::new(format!("module{}", i)))) + .collect(); + let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, f); registry.add_task(task); registry.wait_tasks_complete(handles); let final_time = SystemTime::now(); - println!("codegen time (including LLVM): {}ms", final_time.duration_since(analysis_time).unwrap().as_millis()); - println!("total time: {}ms", final_time.duration_since(start).unwrap().as_millis()); + println!( + "codegen time (including LLVM): {}ms", + final_time + .duration_since(analysis_time) + .unwrap() + .as_millis() + ); + println!( + "total time: {}ms", + final_time.duration_since(start).unwrap().as_millis() + ); }