From 173102fc56884b477eb495643ad2f6a137495262 Mon Sep 17 00:00:00 2001 From: pca006132 Date: Wed, 25 Aug 2021 15:29:58 +0800 Subject: [PATCH] codegen/expr: function codegen and refactoring --- nac3core/src/codegen/expr.rs | 169 ++++++++++++++---- nac3core/src/codegen/mod.rs | 18 +- nac3core/src/toplevel/mod.rs | 39 ++-- nac3core/src/toplevel/type_annotation.rs | 2 +- nac3core/src/typecheck/type_inferencer/mod.rs | 2 +- nac3core/src/typecheck/typedef/mod.rs | 4 +- 6 files changed, 170 insertions(+), 64 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index a267e86f..036aee97 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1,10 +1,10 @@ use std::{collections::HashMap, convert::TryInto, iter::once}; -use super::{get_llvm_type, CodeGenContext}; use crate::{ + codegen::{get_llvm_type, CodeGenContext, CodeGenTask}, symbol_resolver::SymbolValue, toplevel::{DefinitionId, TopLevelDef}, - typecheck::typedef::{FunSignature, Type, TypeEnum}, + typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier}, }; use inkwell::{ types::{BasicType, BasicTypeEnum}, @@ -31,7 +31,12 @@ pub fn assert_pointer_val<'ctx>(val: BasicValueEnum<'ctx>) -> PointerValue<'ctx> } impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { - fn get_subst_key(&mut self, obj: Option, fun: &FunSignature) -> String { + fn get_subst_key( + &mut self, + obj: Option, + fun: &FunSignature, + filter: Option<&Vec>, + ) -> String { let mut vars = obj .map(|ty| { if let TypeEnum::TObj { params, .. } = &*self.unifier.get_ty(ty) { @@ -42,7 +47,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { }) .unwrap_or_default(); vars.extend(fun.vars.iter()); - let sorted = vars.keys().sorted(); + let sorted = + vars.keys().filter(|id| filter.map(|v| v.contains(id)).unwrap_or(true)).sorted(); sorted .map(|id| { self.unifier.stringify(vars[id], &mut |id| id.to_string(), &mut |id| id.to_string()) @@ -101,42 +107,129 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { obj: Option<(Type, BasicValueEnum<'ctx>)>, fun: (&FunSignature, DefinitionId), params: Vec<(Option, BasicValueEnum<'ctx>)>, - ret: Type, ) -> Option> { - let key = self.get_subst_key(obj.map(|(a, _)| a), fun.0); + let key = self.get_subst_key(obj.map(|a| a.0), fun.0, None); let top_level_defs = self.top_level.definitions.read(); let definition = top_level_defs.get(fun.1 .0).unwrap(); - let val = if let TopLevelDef::Function { instance_to_symbol, .. } = &*definition.read() { - let symbol = instance_to_symbol.get(&key).unwrap_or_else(|| { - // TODO: codegen for function that are not yet generated - unimplemented!() - }); - let fun_val = self.module.get_function(symbol).unwrap_or_else(|| { - let params = fun.0.args.iter().map(|arg| self.get_llvm_type(arg.ty)).collect_vec(); - let fun_ty = if self.unifier.unioned(ret, self.primitives.none) { - self.ctx.void_type().fn_type(¶ms, false) + let symbol = + if let TopLevelDef::Function { instance_to_symbol, .. } = &*definition.read() { + instance_to_symbol.get(&key).cloned() + } else { + unreachable!() + } + .unwrap_or_else(|| { + if let TopLevelDef::Function { + name, + instance_to_symbol, + instance_to_stmt, + var_id, + resolver, + .. + } = &mut *definition.write() + { + instance_to_symbol.get(&key).cloned().unwrap_or_else(|| { + let symbol = format!("{}_{}", name, instance_to_symbol.len()); + instance_to_symbol.insert(key, symbol.clone()); + let key = self.get_subst_key(obj.map(|a| a.0), fun.0, Some(var_id)); + let instance = instance_to_stmt.get(&key).unwrap(); + let unifiers = self.top_level.unifiers.read(); + let (unifier, primitives) = &unifiers[instance.unifier_id]; + let mut unifier = Unifier::from_shared_unifier(&unifier); + + let mut type_cache = [ + (self.primitives.int32, primitives.int32), + (self.primitives.int64, primitives.int64), + (self.primitives.float, primitives.float), + (self.primitives.bool, primitives.bool), + (self.primitives.none, primitives.none), + ] + .iter() + .map(|(a, b)| { + (self.unifier.get_representative(*a), unifier.get_representative(*b)) + }) + .collect(); + + let subst = fun + .0 + .vars + .iter() + .map(|(id, ty)| { + ( + *instance.subst.get(id).unwrap(), + unifier.copy_from(&mut self.unifier, *ty, &mut type_cache), + ) + }) + .collect(); + + let signature = FunSignature { + args: fun + .0 + .args + .iter() + .map(|arg| FuncArg { + name: arg.name.clone(), + ty: unifier.copy_from( + &mut self.unifier, + arg.ty, + &mut type_cache, + ), + default_value: arg.default_value.clone(), + }) + .collect(), + ret: unifier.copy_from(&mut self.unifier, fun.0.ret, &mut type_cache), + vars: fun + .0 + .vars + .iter() + .map(|(id, ty)| { + ( + *id, + unifier.copy_from(&mut self.unifier, *ty, &mut type_cache), + ) + }) + .collect(), + }; + + let unifier = (unifier.get_shared_unifier(), *primitives); + + let task = CodeGenTask { + symbol_name: symbol.clone(), + body: instance.body.clone(), + resolver: resolver.as_ref().unwrap().clone(), + calls: instance.calls.clone(), + subst, + signature, + unifier, + }; + self.registry.add_task(task); + symbol + }) } else { - self.get_llvm_type(ret).fn_type(¶ms, false) - }; - self.module.add_function(symbol, fun_ty, None) + unreachable!() + } }); - let mut keys = fun.0.args.clone(); - let mut mapping = HashMap::new(); - for (key, value) in params.into_iter() { - mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value); - } - // default value handling - for k in keys.into_iter() { - mapping.insert(k.name, self.gen_symbol_val(&k.default_value.unwrap())); - } - // reorder the parameters - let params = - fun.0.args.iter().map(|arg| mapping.remove(&arg.name).unwrap()).collect_vec(); - self.builder.build_call(fun_val, ¶ms, "call").try_as_basic_value().left() - } else { - unreachable!() - }; - val + + let fun_val = self.module.get_function(&symbol).unwrap_or_else(|| { + let params = fun.0.args.iter().map(|arg| self.get_llvm_type(arg.ty)).collect_vec(); + let fun_ty = if self.unifier.unioned(fun.0.ret, self.primitives.none) { + self.ctx.void_type().fn_type(¶ms, false) + } else { + self.get_llvm_type(fun.0.ret).fn_type(¶ms, false) + }; + self.module.add_function(&symbol, fun_ty, None) + }); + let mut keys = fun.0.args.clone(); + let mut mapping = HashMap::new(); + for (key, value) in params.into_iter() { + mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value); + } + // default value handling + for k in keys.into_iter() { + mapping.insert(k.name, self.gen_symbol_val(&k.default_value.unwrap())); + } + // reorder the parameters + let params = fun.0.args.iter().map(|arg| mapping.remove(&arg.name).unwrap()).collect_vec(); + self.builder.build_call(fun_val, ¶ms, "call").try_as_basic_value().left() } fn gen_const(&mut self, value: &Constant, ty: Type) -> BasicValueEnum<'ctx> { @@ -516,9 +609,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { } ExprKind::Call { func, args, keywords } => { if let ExprKind::Name { id, .. } = &func.as_ref().node { - // TODO: handle primitive casts + // TODO: handle primitive casts and function pointers let fun = self.resolver.get_identifier_def(&id).expect("Unknown identifier"); - let ret = expr.custom.unwrap(); let mut params = args.iter().map(|arg| (None, self.gen_expr(arg).unwrap())).collect_vec(); let kw_iter = keywords.iter().map(|kw| { @@ -532,8 +624,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { .unifier .get_call_signature(*self.calls.get(&expr.location.into()).unwrap()) .unwrap(); - return self.gen_call(None, (&signature, fun), params, ret); + return self.gen_call(None, (&signature, fun), params); } else { + // TODO: method unimplemented!() } } diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 09082b1d..e8e8e8ba 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -3,7 +3,7 @@ use crate::{ toplevel::{TopLevelContext, TopLevelDef}, typecheck::{ type_inferencer::{CodeLocation, PrimitiveStore}, - typedef::{CallId, FunSignature, Type, TypeEnum, Unifier}, + typedef::{CallId, FunSignature, SharedUnifier, Type, TypeEnum, Unifier}, }, }; use crossbeam::channel::{unbounded, Receiver, Sender}; @@ -38,11 +38,12 @@ pub struct CodeGenContext<'ctx, 'a> { pub module: Module<'ctx>, pub top_level: &'a TopLevelContext, pub unifier: Unifier, - pub resolver: Arc, + pub resolver: Arc>, pub var_assignment: HashMap>, pub type_cache: HashMap>, pub primitives: PrimitiveStore, pub calls: HashMap, + pub registry: &'a WorkerRegistry, // stores the alloca for variables pub init_bb: BasicBlock<'ctx>, // where continue and break should go to respectively @@ -166,7 +167,7 @@ impl WorkerRegistry { let mut module = context.create_module(&module_name); while let Some(task) = self.receiver.recv().unwrap() { - let result = gen_func(&context, builder, module, task, top_level_ctx.clone()); + let result = gen_func(&context, self, builder, module, task, top_level_ctx.clone()); builder = result.0; module = result.1; *self.task_count.lock() -= 1; @@ -188,8 +189,8 @@ pub struct CodeGenTask { pub signature: FunSignature, pub body: Vec>>, pub calls: HashMap, - pub unifier_index: usize, - pub resolver: Arc, + pub unifier: (SharedUnifier, PrimitiveStore), + pub resolver: Arc>, } fn get_llvm_type<'ctx>( @@ -244,6 +245,7 @@ fn get_llvm_type<'ctx>( pub fn gen_func<'ctx>( context: &'ctx Context, + registry: &WorkerRegistry, builder: Builder<'ctx>, module: Module<'ctx>, task: CodeGenTask, @@ -251,9 +253,8 @@ pub fn gen_func<'ctx>( ) -> (Builder<'ctx>, Module<'ctx>) { // unwrap_or(0) is for unit tests without using rayon let (mut unifier, primitives) = { - let unifiers = top_level_ctx.unifiers.read(); - let (unifier, primitives) = &unifiers[task.unifier_index]; - (Unifier::from_shared_unifier(unifier), *primitives) + let (unifier, primitives) = task.unifier; + (Unifier::from_shared_unifier(&unifier), primitives) }; for (a, b) in task.subst.iter() { @@ -327,6 +328,7 @@ pub fn gen_func<'ctx>( top_level: top_level_ctx.as_ref(), calls: task.calls, loop_bb: None, + registry, var_assignment, type_cache, primitives, diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index f30aaab6..3f1be59d 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -4,9 +4,12 @@ use std::{collections::HashMap, collections::HashSet, sync::Arc}; use super::typecheck::type_inferencer::PrimitiveStore; use super::typecheck::typedef::{FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, Unifier}; -use crate::symbol_resolver::SymbolResolver; +use crate::{ + symbol_resolver::SymbolResolver, + typecheck::{type_inferencer::CodeLocation, typedef::CallId}, +}; use itertools::{izip, Itertools}; -use parking_lot::{Mutex, RwLock}; +use parking_lot::RwLock; use rustpython_parser::ast::{self, Stmt}; #[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] @@ -15,6 +18,13 @@ pub struct DefinitionId(pub usize); mod type_annotation; use type_annotation::*; +pub struct FunInstance { + pub body: Vec>>, + pub calls: HashMap, + pub subst: HashMap, + pub unifier_id: usize, +} + pub enum TopLevelDef { Class { // name for error messages and symbols @@ -33,13 +43,15 @@ pub enum TopLevelDef { // ancestor classes, including itself. ancestors: Vec, // symbol resolver of the module defined the class, none if it is built-in type - resolver: Option>>, + resolver: Option>>, }, Function { // prefix for symbol, should be unique globally, and not ending with numbers name: String, // function signature. signature: Type, + // instantiated type variable IDs + var_id: Vec, /// Function instance to symbol mapping /// Key: string representation of type variable values, sorted by variable ID in ascending /// order, including type variables associated with the class. @@ -49,11 +61,10 @@ pub enum TopLevelDef { /// Key: string representation of type variable values, sorted by variable ID in ascending /// order, including type variables associated with the class. Excluding rigid type /// variables. - /// Value: AST annotated with types together with a unification table index. Could contain /// rigid type variables that would be substituted when the function is instantiated. - instance_to_stmt: HashMap>, usize)>, + instance_to_stmt: HashMap, // symbol resolver of the module defined the class - resolver: Option>>, + resolver: Option>>, }, Initializer { class_id: DefinitionId, @@ -171,7 +182,7 @@ impl TopLevelComposer { /// when first regitering, the type_vars, fields, methods, ancestors are invalid pub fn make_top_level_class_def( index: usize, - resolver: Option>>, + resolver: Option>>, name: &str, ) -> TopLevelDef { TopLevelDef::Class { @@ -189,11 +200,12 @@ impl TopLevelComposer { pub fn make_top_level_function_def( name: String, ty: Type, - resolver: Option>>, + resolver: Option>>, ) -> TopLevelDef { TopLevelDef::Function { name, signature: ty, + var_id: Default::default(), instance_to_symbol: Default::default(), instance_to_stmt: Default::default(), resolver, @@ -214,7 +226,7 @@ impl TopLevelComposer { pub fn register_top_level( &mut self, ast: ast::Stmt<()>, - resolver: Option>>, + resolver: Option>>, ) -> Result<(String, DefinitionId), String> { let mut defined_class_name: HashSet = HashSet::new(); let mut defined_class_method_name: HashSet = HashSet::new(); @@ -363,7 +375,7 @@ impl TopLevelComposer { continue; } }; - let class_resolver = class_resolver.as_ref().unwrap().lock(); + let class_resolver = class_resolver.as_ref().unwrap(); let class_resolver = class_resolver.deref(); let mut is_generic = false; @@ -467,7 +479,7 @@ impl TopLevelComposer { continue; } }; - let class_resolver = class_resolver.as_ref().unwrap().lock(); + let class_resolver = class_resolver.as_ref().unwrap(); let class_resolver = class_resolver.deref(); let mut has_base = false; @@ -563,7 +575,7 @@ impl TopLevelComposer { if let ast::StmtKind::FunctionDef { args, returns, .. } = &function_ast.node { let resolver = resolver.as_ref(); let resolver = resolver.unwrap(); - let resolver = resolver.deref().lock(); + let resolver = resolver.deref(); let function_resolver = resolver.deref(); // occured type vars should not be handled separately @@ -708,8 +720,7 @@ impl TopLevelComposer { unreachable!("here must be class def ast"); }; let class_resolver = class_resolver.as_ref().unwrap(); - let mut class_resolver = class_resolver.lock(); - let class_resolver = class_resolver.deref_mut(); + let class_resolver = class_resolver; for b in class_body_ast { if let ast::StmtKind::FunctionDef { args, returns, name, body, .. } = &b.node { diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index d6405da2..e0f04d1b 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -21,7 +21,7 @@ pub enum TypeAnnotation { } pub fn parse_ast_to_type_annotation_kinds( - resolver: &dyn SymbolResolver, + resolver: &Box, top_level_defs: &[Arc>], unifier: &mut Unifier, primitives: &PrimitiveStore, diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index e81a7473..898d5f3e 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -38,7 +38,7 @@ pub struct PrimitiveStore { } pub struct FunctionData { - pub resolver: Arc, + pub resolver: Arc>, pub return_type: Option, pub bound_variables: Vec, } diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index f4acb808..aea8642c 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -125,7 +125,7 @@ impl Unifier { ty: Type, type_cache: &mut HashMap, ) -> Type { - let representative = self.get_representative(ty); + let representative = unifier.get_representative(ty); type_cache.get(&representative).cloned().unwrap_or_else(|| { // put in a placeholder first to handle possible recursive type let placeholder = self.get_fresh_var().0; @@ -183,7 +183,7 @@ impl Unifier { TypeEnum::TVirtual { ty: self.copy_from(unifier, *ty, type_cache) } } }; - let ty = unifier.add_ty(ty); + let ty = self.add_ty(ty); self.unify(placeholder, ty).unwrap(); type_cache.insert(representative, ty); ty