diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 8e629f9b0..26ec835c4 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1,12 +1,41 @@ use std::{convert::TryInto, iter::once}; -use crate::top_level::{CodeGenContext, TopLevelDef}; -use crate::typecheck::typedef::{Type, TypeEnum}; -use inkwell::{types::BasicType, values::BasicValueEnum}; +use crate::{ + top_level::DefinitionId, + typecheck::typedef::{Type, TypeEnum}, +}; +use crate::{ + top_level::{CodeGenContext, TopLevelDef}, + typecheck::typedef::FunSignature, +}; +use inkwell::{ + types::{BasicType, BasicTypeEnum}, + values::BasicValueEnum, + AddressSpace, +}; use itertools::{chain, izip, zip, Itertools}; use rustpython_parser::ast::{self, Boolop, Constant, Expr, ExprKind, Operator}; impl<'ctx> CodeGenContext<'ctx> { + fn get_subst_key(&mut self, obj: Option, fun: &FunSignature) -> String { + let mut vars = obj + .map(|ty| { + if let TypeEnum::TObj { params, .. } = &*self.unifier.get_ty(ty) { + params.clone() + } else { + unreachable!() + } + }) + .unwrap_or_default(); + vars.extend(fun.vars.iter()); + let sorted = vars.keys().sorted(); + sorted + .map(|id| { + self.unifier.stringify(vars[id], &mut |id| id.to_string(), &mut |id| id.to_string()) + }) + .join(", ") + } + fn get_attr_index(&mut self, ty: Type, attr: &str) -> usize { let obj_id = match &*self.unifier.get_ty(ty) { TypeEnum::TObj { obj_id, .. } => *obj_id, @@ -22,17 +51,88 @@ impl<'ctx> CodeGenContext<'ctx> { index } + fn get_llvm_type(&mut self, ty: Type) -> BasicTypeEnum<'ctx> { + use TypeEnum::*; + // we assume the type cache should already contain primitive types, + // and they should be passed by value instead of passing as pointer. + self.type_cache.get(&ty).cloned().unwrap_or_else(|| match &*self.unifier.get_ty(ty) { + TObj { obj_id, fields, .. } => { + // a struct with fields in the order of declaration + let defs = self.top_level.definitions.read(); + let definition = defs.get(obj_id.0).unwrap(); + let ty = if let TopLevelDef::Class { fields: fields_list, .. } = &*definition.read() + { + let fields = fields.borrow(); + let fields = + fields_list.iter().map(|f| self.get_llvm_type(fields[&f.0])).collect_vec(); + self.ctx + .struct_type(&fields, false) + .ptr_type(AddressSpace::Generic) + .into() + } else { + unreachable!() + }; + ty + } + TTuple { ty } => { + // a struct with fields in the order present in the tuple + let fields = ty.iter().map(|ty| self.get_llvm_type(*ty)).collect_vec(); + self.ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into() + } + TList { ty } => { + // a struct with an integer and a pointer to an array + let element_type = self.get_llvm_type(*ty); + let fields = [ + self.ctx.i32_type().into(), + element_type.ptr_type(AddressSpace::Generic).into(), + ]; + self.ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into() + } + _ => unreachable!(), + }) + } + + fn gen_call( + &mut self, + obj: Option<(Type, BasicValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + params: &[BasicValueEnum<'ctx>], + ret: Type, + ) -> Option> { + let key = self.get_subst_key(obj.map(|(a, _)| a), fun.0); + let defs = self.top_level.definitions.read(); + let definition = defs.get(fun.1.0).unwrap(); + let val = if let TopLevelDef::Function { instance_to_symbol, .. } = &*definition.read() { + // TODO: codegen for function that are not yet generated + let symbol = instance_to_symbol.get(&key).unwrap(); + let fun_val = self.module.get_function(symbol).unwrap_or_else(|| { + let params = fun.0.args.iter().map(|arg| self.get_llvm_type(arg.ty)).collect_vec(); + let fun_ty = if self.unifier.unioned(ret, self.primitives.none) { + self.ctx.void_type().fn_type(¶ms, false) + } else { + self.get_llvm_type(ret).fn_type(¶ms, false) + }; + self.module.add_function(symbol, fun_ty, None) + }); + // TODO: deal with default parameters and reordering based on keys + self.builder.build_call(fun_val, params, "call").try_as_basic_value().left() + } else { + unreachable!() + }; + val + } + fn gen_const(&mut self, value: &Constant, ty: Type) -> BasicValueEnum<'ctx> { match value { Constant::Bool(v) => { - assert!(self.unifier.unioned(ty, self.top_level.primitives.bool)); + assert!(self.unifier.unioned(ty, self.primitives.bool)); let ty = self.ctx.bool_type(); ty.const_int(if *v { 1 } else { 0 }, false).into() } Constant::Int(v) => { - let ty = if self.unifier.unioned(ty, self.top_level.primitives.int32) { + let ty = if self.unifier.unioned(ty, self.primitives.int32) { self.ctx.i32_type() - } else if self.unifier.unioned(ty, self.top_level.primitives.int64) { + } else if self.unifier.unioned(ty, self.primitives.int64) { self.ctx.i64_type() } else { unreachable!(); @@ -40,7 +140,7 @@ impl<'ctx> CodeGenContext<'ctx> { ty.const_int(v.try_into().unwrap(), false).into() } Constant::Float(v) => { - assert!(self.unifier.unioned(ty, self.top_level.primitives.float)); + assert!(self.unifier.unioned(ty, self.primitives.float)); let ty = self.ctx.f64_type(); ty.const_float(*v).into() } @@ -134,7 +234,6 @@ impl<'ctx> CodeGenContext<'ctx> { pub fn gen_expr(&mut self, expr: &Expr>) -> BasicValueEnum<'ctx> { let zero = self.ctx.i32_type().const_int(0, false); - let primitives = &self.top_level.primitives; match &expr.node { ExprKind::Constant { value, .. } => { let ty = expr.custom.unwrap(); @@ -146,25 +245,36 @@ impl<'ctx> CodeGenContext<'ctx> { } ExprKind::List { elts, .. } => { // this shall be optimized later for constant primitive lists... + // we should use memcpy for that instead of generating thousands of stores let elements = elts.iter().map(|x| self.gen_expr(x)).collect_vec(); let ty = if elements.is_empty() { self.ctx.i32_type().into() } else { elements[0].get_type() }; - // this length includes the leading length element + let arr_ptr = self.builder.build_array_alloca( + ty, + self.ctx.i32_type().const_int(elements.len() as u64, false), + "tmparr", + ); let arr_ty = self.ctx.struct_type( - &[self.ctx.i32_type().into(), ty.array_type(elements.len() as u32).into()], + &[ + self.ctx.i32_type().into(), + ty.ptr_type(AddressSpace::Generic).into(), + ], false, ); - let arr_ptr = self.builder.build_alloca(arr_ty, "tmparr"); + let arr_str_ptr = self.builder.build_alloca(arr_ty, "tmparrstr"); unsafe { - let len_ptr = arr_ptr - .const_in_bounds_gep(&[zero, self.ctx.i32_type().const_int(0u64, false)]); self.builder.build_store( - len_ptr, + arr_str_ptr.const_in_bounds_gep(&[zero, zero]), self.ctx.i32_type().const_int(elements.len() as u64, false), ); + self.builder.build_store( + arr_str_ptr + .const_in_bounds_gep(&[zero, self.ctx.i32_type().const_int(1, false)]), + arr_ptr, + ); let arr_offset = self.ctx.i32_type().const_int(1, false); for (i, v) in elements.iter().enumerate() { let ptr = self.builder.build_in_bounds_gep( @@ -175,7 +285,7 @@ impl<'ctx> CodeGenContext<'ctx> { self.builder.build_store(ptr, *v); } } - arr_ptr.into() + arr_str_ptr.into() } ExprKind::Tuple { elts, .. } => { let element_val = elts.iter().map(|x| self.gen_expr(x)).collect_vec(); @@ -266,9 +376,9 @@ impl<'ctx> CodeGenContext<'ctx> { // when doing code generation for function instances if ty1 != ty2 { unimplemented!() - } else if [primitives.int32, primitives.int64].contains(&ty1) { + } else if [self.primitives.int32, self.primitives.int64].contains(&ty1) { self.gen_int_ops(op, left, right) - } else if primitives.float == ty1 { + } else if self.primitives.float == ty1 { self.gen_float_ops(op, left, right) } else { unimplemented!() @@ -277,7 +387,7 @@ impl<'ctx> CodeGenContext<'ctx> { ExprKind::UnaryOp { op, operand } => { let ty = self.unifier.get_representative(operand.custom.unwrap()); let val = self.gen_expr(operand); - if ty == primitives.bool { + if ty == self.primitives.bool { let val = if let BasicValueEnum::IntValue(val) = val { val } else { unreachable!() }; match op { @@ -286,7 +396,7 @@ impl<'ctx> CodeGenContext<'ctx> { } _ => val.into(), } - } else if [primitives.int32, primitives.int64].contains(&ty) { + } else if [self.primitives.int32, self.primitives.int64].contains(&ty) { let val = if let BasicValueEnum::IntValue(val) = val { val } else { unreachable!() }; match op { @@ -303,7 +413,7 @@ impl<'ctx> CodeGenContext<'ctx> { .into(), _ => val.into(), } - } else if ty == primitives.float { + } else if ty == self.primitives.float { let val = if let BasicValueEnum::FloatValue(val) = val { val } else { @@ -334,7 +444,7 @@ impl<'ctx> CodeGenContext<'ctx> { ) .fold(None, |prev, (lhs, rhs, op)| { let ty = lhs.custom.unwrap(); - let current = if [primitives.int32, primitives.int64, primitives.bool] + let current = if [self.primitives.int32, self.primitives.int64, self.primitives.bool] .contains(&ty) { let (lhs, rhs) = @@ -355,7 +465,7 @@ impl<'ctx> CodeGenContext<'ctx> { _ => unreachable!(), }; self.builder.build_int_compare(op, lhs, rhs, "cmp") - } else if ty == primitives.float { + } else if ty == self.primitives.float { let (lhs, rhs) = if let ( BasicValueEnum::FloatValue(lhs), BasicValueEnum::FloatValue(rhs), diff --git a/nac3core/src/top_level.rs b/nac3core/src/top_level.rs index ef2651cd6..e205af288 100644 --- a/nac3core/src/top_level.rs +++ b/nac3core/src/top_level.rs @@ -3,7 +3,7 @@ use std::{collections::HashMap, sync::Arc}; use super::typecheck::type_inferencer::PrimitiveStore; use super::typecheck::typedef::{SharedUnifier, Type, Unifier}; use crate::symbol_resolver::SymbolResolver; -use inkwell::{builder::Builder, context::Context, module::Module, values::PointerValue}; +use inkwell::{builder::Builder, context::Context, module::Module, types::BasicTypeEnum, values::PointerValue}; use parking_lot::RwLock; use rustpython_parser::ast::Stmt; @@ -16,14 +16,17 @@ pub enum TopLevelDef { object_id: DefinitionId, // type variables bounded to the class. type_vars: Vec, - // class fields and method signature. + // class fields fields: Vec<(String, Type)>, // class methods, pointing to the corresponding function definition. - methods: Vec<(String, DefinitionId)>, + methods: Vec<(String, Type, DefinitionId)>, // ancestor classes, including itself. ancestors: Vec, }, Function { + // prefix for symbol, should be unique globally, and not ending with numbers + name: String, + // function signature. signature: Type, /// Function instance to symbol mapping /// Key: string representation of type variable values, sorted by variable ID in ascending @@ -48,7 +51,6 @@ pub struct CodeGenTask { } pub struct TopLevelContext { - pub primitives: PrimitiveStore, pub definitions: Arc>>>, pub unifiers: Arc>>, } @@ -61,4 +63,6 @@ pub struct CodeGenContext<'ctx> { pub unifier: Unifier, pub resolver: Box, pub var_assignment: HashMap>, + pub type_cache: HashMap>, + pub primitives: PrimitiveStore, } diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index daf7fcfe7..c880b6c2d 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -30,6 +30,8 @@ pub struct Call { pub struct FuncArg { pub name: String, pub ty: Type, + // TODO: change this to an optional value + // for primitive types pub is_optional: bool, } diff --git a/nac3core/src/typecheck/unification_table.rs b/nac3core/src/typecheck/unification_table.rs index 7a95a2a20..7475afcef 100644 --- a/nac3core/src/typecheck/unification_table.rs +++ b/nac3core/src/typecheck/unification_table.rs @@ -1,6 +1,6 @@ use std::rc::Rc; -#[derive(Copy, Clone, PartialEq, Eq, Debug)] +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] pub struct UnificationKey(usize); pub struct UnificationTable {