1
0
forked from M-Labs/nac3

codegen/expr: function codegen and refactoring

This commit is contained in:
pca006132 2021-08-25 15:29:58 +08:00
parent 93270d7227
commit 173102fc56
6 changed files with 170 additions and 64 deletions

View File

@ -1,10 +1,10 @@
use std::{collections::HashMap, convert::TryInto, iter::once}; use std::{collections::HashMap, convert::TryInto, iter::once};
use super::{get_llvm_type, CodeGenContext};
use crate::{ use crate::{
codegen::{get_llvm_type, CodeGenContext, CodeGenTask},
symbol_resolver::SymbolValue, symbol_resolver::SymbolValue,
toplevel::{DefinitionId, TopLevelDef}, toplevel::{DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type, TypeEnum}, typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier},
}; };
use inkwell::{ use inkwell::{
types::{BasicType, BasicTypeEnum}, types::{BasicType, BasicTypeEnum},
@ -31,7 +31,12 @@ pub fn assert_pointer_val<'ctx>(val: BasicValueEnum<'ctx>) -> PointerValue<'ctx>
} }
impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
fn get_subst_key(&mut self, obj: Option<Type>, fun: &FunSignature) -> String { fn get_subst_key(
&mut self,
obj: Option<Type>,
fun: &FunSignature,
filter: Option<&Vec<u32>>,
) -> String {
let mut vars = obj let mut vars = obj
.map(|ty| { .map(|ty| {
if let TypeEnum::TObj { params, .. } = &*self.unifier.get_ty(ty) { if let TypeEnum::TObj { params, .. } = &*self.unifier.get_ty(ty) {
@ -42,7 +47,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
}) })
.unwrap_or_default(); .unwrap_or_default();
vars.extend(fun.vars.iter()); vars.extend(fun.vars.iter());
let sorted = vars.keys().sorted(); let sorted =
vars.keys().filter(|id| filter.map(|v| v.contains(id)).unwrap_or(true)).sorted();
sorted sorted
.map(|id| { .map(|id| {
self.unifier.stringify(vars[id], &mut |id| id.to_string(), &mut |id| id.to_string()) self.unifier.stringify(vars[id], &mut |id| id.to_string(), &mut |id| id.to_string())
@ -101,24 +107,116 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
obj: Option<(Type, BasicValueEnum<'ctx>)>, obj: Option<(Type, BasicValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId), fun: (&FunSignature, DefinitionId),
params: Vec<(Option<String>, BasicValueEnum<'ctx>)>, params: Vec<(Option<String>, BasicValueEnum<'ctx>)>,
ret: Type,
) -> Option<BasicValueEnum<'ctx>> { ) -> Option<BasicValueEnum<'ctx>> {
let key = self.get_subst_key(obj.map(|(a, _)| a), fun.0); let key = self.get_subst_key(obj.map(|a| a.0), fun.0, None);
let top_level_defs = self.top_level.definitions.read(); let top_level_defs = self.top_level.definitions.read();
let definition = top_level_defs.get(fun.1 .0).unwrap(); let definition = top_level_defs.get(fun.1 .0).unwrap();
let val = if let TopLevelDef::Function { instance_to_symbol, .. } = &*definition.read() { let symbol =
let symbol = instance_to_symbol.get(&key).unwrap_or_else(|| { if let TopLevelDef::Function { instance_to_symbol, .. } = &*definition.read() {
// TODO: codegen for function that are not yet generated instance_to_symbol.get(&key).cloned()
unimplemented!() } else {
unreachable!()
}
.unwrap_or_else(|| {
if let TopLevelDef::Function {
name,
instance_to_symbol,
instance_to_stmt,
var_id,
resolver,
..
} = &mut *definition.write()
{
instance_to_symbol.get(&key).cloned().unwrap_or_else(|| {
let symbol = format!("{}_{}", name, instance_to_symbol.len());
instance_to_symbol.insert(key, symbol.clone());
let key = self.get_subst_key(obj.map(|a| a.0), fun.0, Some(var_id));
let instance = instance_to_stmt.get(&key).unwrap();
let unifiers = self.top_level.unifiers.read();
let (unifier, primitives) = &unifiers[instance.unifier_id];
let mut unifier = Unifier::from_shared_unifier(&unifier);
let mut type_cache = [
(self.primitives.int32, primitives.int32),
(self.primitives.int64, primitives.int64),
(self.primitives.float, primitives.float),
(self.primitives.bool, primitives.bool),
(self.primitives.none, primitives.none),
]
.iter()
.map(|(a, b)| {
(self.unifier.get_representative(*a), unifier.get_representative(*b))
})
.collect();
let subst = fun
.0
.vars
.iter()
.map(|(id, ty)| {
(
*instance.subst.get(id).unwrap(),
unifier.copy_from(&mut self.unifier, *ty, &mut type_cache),
)
})
.collect();
let signature = FunSignature {
args: fun
.0
.args
.iter()
.map(|arg| FuncArg {
name: arg.name.clone(),
ty: unifier.copy_from(
&mut self.unifier,
arg.ty,
&mut type_cache,
),
default_value: arg.default_value.clone(),
})
.collect(),
ret: unifier.copy_from(&mut self.unifier, fun.0.ret, &mut type_cache),
vars: fun
.0
.vars
.iter()
.map(|(id, ty)| {
(
*id,
unifier.copy_from(&mut self.unifier, *ty, &mut type_cache),
)
})
.collect(),
};
let unifier = (unifier.get_shared_unifier(), *primitives);
let task = CodeGenTask {
symbol_name: symbol.clone(),
body: instance.body.clone(),
resolver: resolver.as_ref().unwrap().clone(),
calls: instance.calls.clone(),
subst,
signature,
unifier,
};
self.registry.add_task(task);
symbol
})
} else {
unreachable!()
}
}); });
let fun_val = self.module.get_function(symbol).unwrap_or_else(|| {
let fun_val = self.module.get_function(&symbol).unwrap_or_else(|| {
let params = fun.0.args.iter().map(|arg| self.get_llvm_type(arg.ty)).collect_vec(); let params = fun.0.args.iter().map(|arg| self.get_llvm_type(arg.ty)).collect_vec();
let fun_ty = if self.unifier.unioned(ret, self.primitives.none) { let fun_ty = if self.unifier.unioned(fun.0.ret, self.primitives.none) {
self.ctx.void_type().fn_type(&params, false) self.ctx.void_type().fn_type(&params, false)
} else { } else {
self.get_llvm_type(ret).fn_type(&params, false) self.get_llvm_type(fun.0.ret).fn_type(&params, false)
}; };
self.module.add_function(symbol, fun_ty, None) self.module.add_function(&symbol, fun_ty, None)
}); });
let mut keys = fun.0.args.clone(); let mut keys = fun.0.args.clone();
let mut mapping = HashMap::new(); let mut mapping = HashMap::new();
@ -130,13 +228,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
mapping.insert(k.name, self.gen_symbol_val(&k.default_value.unwrap())); mapping.insert(k.name, self.gen_symbol_val(&k.default_value.unwrap()));
} }
// reorder the parameters // reorder the parameters
let params = let params = fun.0.args.iter().map(|arg| mapping.remove(&arg.name).unwrap()).collect_vec();
fun.0.args.iter().map(|arg| mapping.remove(&arg.name).unwrap()).collect_vec();
self.builder.build_call(fun_val, &params, "call").try_as_basic_value().left() self.builder.build_call(fun_val, &params, "call").try_as_basic_value().left()
} else {
unreachable!()
};
val
} }
fn gen_const(&mut self, value: &Constant, ty: Type) -> BasicValueEnum<'ctx> { fn gen_const(&mut self, value: &Constant, ty: Type) -> BasicValueEnum<'ctx> {
@ -516,9 +609,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
} }
ExprKind::Call { func, args, keywords } => { ExprKind::Call { func, args, keywords } => {
if let ExprKind::Name { id, .. } = &func.as_ref().node { if let ExprKind::Name { id, .. } = &func.as_ref().node {
// TODO: handle primitive casts // TODO: handle primitive casts and function pointers
let fun = self.resolver.get_identifier_def(&id).expect("Unknown identifier"); let fun = self.resolver.get_identifier_def(&id).expect("Unknown identifier");
let ret = expr.custom.unwrap();
let mut params = let mut params =
args.iter().map(|arg| (None, self.gen_expr(arg).unwrap())).collect_vec(); args.iter().map(|arg| (None, self.gen_expr(arg).unwrap())).collect_vec();
let kw_iter = keywords.iter().map(|kw| { let kw_iter = keywords.iter().map(|kw| {
@ -532,8 +624,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
.unifier .unifier
.get_call_signature(*self.calls.get(&expr.location.into()).unwrap()) .get_call_signature(*self.calls.get(&expr.location.into()).unwrap())
.unwrap(); .unwrap();
return self.gen_call(None, (&signature, fun), params, ret); return self.gen_call(None, (&signature, fun), params);
} else { } else {
// TODO: method
unimplemented!() unimplemented!()
} }
} }

View File

@ -3,7 +3,7 @@ use crate::{
toplevel::{TopLevelContext, TopLevelDef}, toplevel::{TopLevelContext, TopLevelDef},
typecheck::{ typecheck::{
type_inferencer::{CodeLocation, PrimitiveStore}, type_inferencer::{CodeLocation, PrimitiveStore},
typedef::{CallId, FunSignature, Type, TypeEnum, Unifier}, typedef::{CallId, FunSignature, SharedUnifier, Type, TypeEnum, Unifier},
}, },
}; };
use crossbeam::channel::{unbounded, Receiver, Sender}; use crossbeam::channel::{unbounded, Receiver, Sender};
@ -38,11 +38,12 @@ pub struct CodeGenContext<'ctx, 'a> {
pub module: Module<'ctx>, pub module: Module<'ctx>,
pub top_level: &'a TopLevelContext, pub top_level: &'a TopLevelContext,
pub unifier: Unifier, pub unifier: Unifier,
pub resolver: Arc<dyn SymbolResolver>, pub resolver: Arc<Box<dyn SymbolResolver + Send + Sync>>,
pub var_assignment: HashMap<String, PointerValue<'ctx>>, pub var_assignment: HashMap<String, PointerValue<'ctx>>,
pub type_cache: HashMap<Type, BasicTypeEnum<'ctx>>, pub type_cache: HashMap<Type, BasicTypeEnum<'ctx>>,
pub primitives: PrimitiveStore, pub primitives: PrimitiveStore,
pub calls: HashMap<CodeLocation, CallId>, pub calls: HashMap<CodeLocation, CallId>,
pub registry: &'a WorkerRegistry,
// stores the alloca for variables // stores the alloca for variables
pub init_bb: BasicBlock<'ctx>, pub init_bb: BasicBlock<'ctx>,
// where continue and break should go to respectively // where continue and break should go to respectively
@ -166,7 +167,7 @@ impl WorkerRegistry {
let mut module = context.create_module(&module_name); let mut module = context.create_module(&module_name);
while let Some(task) = self.receiver.recv().unwrap() { while let Some(task) = self.receiver.recv().unwrap() {
let result = gen_func(&context, builder, module, task, top_level_ctx.clone()); let result = gen_func(&context, self, builder, module, task, top_level_ctx.clone());
builder = result.0; builder = result.0;
module = result.1; module = result.1;
*self.task_count.lock() -= 1; *self.task_count.lock() -= 1;
@ -188,8 +189,8 @@ pub struct CodeGenTask {
pub signature: FunSignature, pub signature: FunSignature,
pub body: Vec<Stmt<Option<Type>>>, pub body: Vec<Stmt<Option<Type>>>,
pub calls: HashMap<CodeLocation, CallId>, pub calls: HashMap<CodeLocation, CallId>,
pub unifier_index: usize, pub unifier: (SharedUnifier, PrimitiveStore),
pub resolver: Arc<dyn SymbolResolver + Send + Sync>, pub resolver: Arc<Box<dyn SymbolResolver + Send + Sync>>,
} }
fn get_llvm_type<'ctx>( fn get_llvm_type<'ctx>(
@ -244,6 +245,7 @@ fn get_llvm_type<'ctx>(
pub fn gen_func<'ctx>( pub fn gen_func<'ctx>(
context: &'ctx Context, context: &'ctx Context,
registry: &WorkerRegistry,
builder: Builder<'ctx>, builder: Builder<'ctx>,
module: Module<'ctx>, module: Module<'ctx>,
task: CodeGenTask, task: CodeGenTask,
@ -251,9 +253,8 @@ pub fn gen_func<'ctx>(
) -> (Builder<'ctx>, Module<'ctx>) { ) -> (Builder<'ctx>, Module<'ctx>) {
// unwrap_or(0) is for unit tests without using rayon // unwrap_or(0) is for unit tests without using rayon
let (mut unifier, primitives) = { let (mut unifier, primitives) = {
let unifiers = top_level_ctx.unifiers.read(); let (unifier, primitives) = task.unifier;
let (unifier, primitives) = &unifiers[task.unifier_index]; (Unifier::from_shared_unifier(&unifier), primitives)
(Unifier::from_shared_unifier(unifier), *primitives)
}; };
for (a, b) in task.subst.iter() { for (a, b) in task.subst.iter() {
@ -327,6 +328,7 @@ pub fn gen_func<'ctx>(
top_level: top_level_ctx.as_ref(), top_level: top_level_ctx.as_ref(),
calls: task.calls, calls: task.calls,
loop_bb: None, loop_bb: None,
registry,
var_assignment, var_assignment,
type_cache, type_cache,
primitives, primitives,

View File

@ -4,9 +4,12 @@ use std::{collections::HashMap, collections::HashSet, sync::Arc};
use super::typecheck::type_inferencer::PrimitiveStore; use super::typecheck::type_inferencer::PrimitiveStore;
use super::typecheck::typedef::{FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, Unifier}; use super::typecheck::typedef::{FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, Unifier};
use crate::symbol_resolver::SymbolResolver; use crate::{
symbol_resolver::SymbolResolver,
typecheck::{type_inferencer::CodeLocation, typedef::CallId},
};
use itertools::{izip, Itertools}; use itertools::{izip, Itertools};
use parking_lot::{Mutex, RwLock}; use parking_lot::RwLock;
use rustpython_parser::ast::{self, Stmt}; use rustpython_parser::ast::{self, Stmt};
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] #[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
@ -15,6 +18,13 @@ pub struct DefinitionId(pub usize);
mod type_annotation; mod type_annotation;
use type_annotation::*; use type_annotation::*;
pub struct FunInstance {
pub body: Vec<Stmt<Option<Type>>>,
pub calls: HashMap<CodeLocation, CallId>,
pub subst: HashMap<u32, Type>,
pub unifier_id: usize,
}
pub enum TopLevelDef { pub enum TopLevelDef {
Class { Class {
// name for error messages and symbols // name for error messages and symbols
@ -33,13 +43,15 @@ pub enum TopLevelDef {
// ancestor classes, including itself. // ancestor classes, including itself.
ancestors: Vec<TypeAnnotation>, ancestors: Vec<TypeAnnotation>,
// symbol resolver of the module defined the class, none if it is built-in type // symbol resolver of the module defined the class, none if it is built-in type
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>, resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
}, },
Function { Function {
// prefix for symbol, should be unique globally, and not ending with numbers // prefix for symbol, should be unique globally, and not ending with numbers
name: String, name: String,
// function signature. // function signature.
signature: Type, signature: Type,
// instantiated type variable IDs
var_id: Vec<u32>,
/// Function instance to symbol mapping /// Function instance to symbol mapping
/// Key: string representation of type variable values, sorted by variable ID in ascending /// Key: string representation of type variable values, sorted by variable ID in ascending
/// order, including type variables associated with the class. /// order, including type variables associated with the class.
@ -49,11 +61,10 @@ pub enum TopLevelDef {
/// Key: string representation of type variable values, sorted by variable ID in ascending /// Key: string representation of type variable values, sorted by variable ID in ascending
/// order, including type variables associated with the class. Excluding rigid type /// order, including type variables associated with the class. Excluding rigid type
/// variables. /// variables.
/// Value: AST annotated with types together with a unification table index. Could contain
/// rigid type variables that would be substituted when the function is instantiated. /// rigid type variables that would be substituted when the function is instantiated.
instance_to_stmt: HashMap<String, (Stmt<Option<Type>>, usize)>, instance_to_stmt: HashMap<String, FunInstance>,
// symbol resolver of the module defined the class // symbol resolver of the module defined the class
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>, resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
}, },
Initializer { Initializer {
class_id: DefinitionId, class_id: DefinitionId,
@ -171,7 +182,7 @@ impl TopLevelComposer {
/// when first regitering, the type_vars, fields, methods, ancestors are invalid /// when first regitering, the type_vars, fields, methods, ancestors are invalid
pub fn make_top_level_class_def( pub fn make_top_level_class_def(
index: usize, index: usize,
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>, resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
name: &str, name: &str,
) -> TopLevelDef { ) -> TopLevelDef {
TopLevelDef::Class { TopLevelDef::Class {
@ -189,11 +200,12 @@ impl TopLevelComposer {
pub fn make_top_level_function_def( pub fn make_top_level_function_def(
name: String, name: String,
ty: Type, ty: Type,
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>, resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
) -> TopLevelDef { ) -> TopLevelDef {
TopLevelDef::Function { TopLevelDef::Function {
name, name,
signature: ty, signature: ty,
var_id: Default::default(),
instance_to_symbol: Default::default(), instance_to_symbol: Default::default(),
instance_to_stmt: Default::default(), instance_to_stmt: Default::default(),
resolver, resolver,
@ -214,7 +226,7 @@ impl TopLevelComposer {
pub fn register_top_level( pub fn register_top_level(
&mut self, &mut self,
ast: ast::Stmt<()>, ast: ast::Stmt<()>,
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>, resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
) -> Result<(String, DefinitionId), String> { ) -> Result<(String, DefinitionId), String> {
let mut defined_class_name: HashSet<String> = HashSet::new(); let mut defined_class_name: HashSet<String> = HashSet::new();
let mut defined_class_method_name: HashSet<String> = HashSet::new(); let mut defined_class_method_name: HashSet<String> = HashSet::new();
@ -363,7 +375,7 @@ impl TopLevelComposer {
continue; continue;
} }
}; };
let class_resolver = class_resolver.as_ref().unwrap().lock(); let class_resolver = class_resolver.as_ref().unwrap();
let class_resolver = class_resolver.deref(); let class_resolver = class_resolver.deref();
let mut is_generic = false; let mut is_generic = false;
@ -467,7 +479,7 @@ impl TopLevelComposer {
continue; continue;
} }
}; };
let class_resolver = class_resolver.as_ref().unwrap().lock(); let class_resolver = class_resolver.as_ref().unwrap();
let class_resolver = class_resolver.deref(); let class_resolver = class_resolver.deref();
let mut has_base = false; let mut has_base = false;
@ -563,7 +575,7 @@ impl TopLevelComposer {
if let ast::StmtKind::FunctionDef { args, returns, .. } = &function_ast.node { if let ast::StmtKind::FunctionDef { args, returns, .. } = &function_ast.node {
let resolver = resolver.as_ref(); let resolver = resolver.as_ref();
let resolver = resolver.unwrap(); let resolver = resolver.unwrap();
let resolver = resolver.deref().lock(); let resolver = resolver.deref();
let function_resolver = resolver.deref(); let function_resolver = resolver.deref();
// occured type vars should not be handled separately // occured type vars should not be handled separately
@ -708,8 +720,7 @@ impl TopLevelComposer {
unreachable!("here must be class def ast"); unreachable!("here must be class def ast");
}; };
let class_resolver = class_resolver.as_ref().unwrap(); let class_resolver = class_resolver.as_ref().unwrap();
let mut class_resolver = class_resolver.lock(); let class_resolver = class_resolver;
let class_resolver = class_resolver.deref_mut();
for b in class_body_ast { for b in class_body_ast {
if let ast::StmtKind::FunctionDef { args, returns, name, body, .. } = &b.node { if let ast::StmtKind::FunctionDef { args, returns, name, body, .. } = &b.node {

View File

@ -21,7 +21,7 @@ pub enum TypeAnnotation {
} }
pub fn parse_ast_to_type_annotation_kinds<T>( pub fn parse_ast_to_type_annotation_kinds<T>(
resolver: &dyn SymbolResolver, resolver: &Box<dyn SymbolResolver + Send + Sync>,
top_level_defs: &[Arc<RwLock<TopLevelDef>>], top_level_defs: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier, unifier: &mut Unifier,
primitives: &PrimitiveStore, primitives: &PrimitiveStore,

View File

@ -38,7 +38,7 @@ pub struct PrimitiveStore {
} }
pub struct FunctionData { pub struct FunctionData {
pub resolver: Arc<dyn SymbolResolver + Send + Sync>, pub resolver: Arc<Box<dyn SymbolResolver + Send + Sync>>,
pub return_type: Option<Type>, pub return_type: Option<Type>,
pub bound_variables: Vec<Type>, pub bound_variables: Vec<Type>,
} }

View File

@ -125,7 +125,7 @@ impl Unifier {
ty: Type, ty: Type,
type_cache: &mut HashMap<Type, Type>, type_cache: &mut HashMap<Type, Type>,
) -> Type { ) -> Type {
let representative = self.get_representative(ty); let representative = unifier.get_representative(ty);
type_cache.get(&representative).cloned().unwrap_or_else(|| { type_cache.get(&representative).cloned().unwrap_or_else(|| {
// put in a placeholder first to handle possible recursive type // put in a placeholder first to handle possible recursive type
let placeholder = self.get_fresh_var().0; let placeholder = self.get_fresh_var().0;
@ -183,7 +183,7 @@ impl Unifier {
TypeEnum::TVirtual { ty: self.copy_from(unifier, *ty, type_cache) } TypeEnum::TVirtual { ty: self.copy_from(unifier, *ty, type_cache) }
} }
}; };
let ty = unifier.add_ty(ty); let ty = self.add_ty(ty);
self.unify(placeholder, ty).unwrap(); self.unify(placeholder, ty).unwrap();
type_cache.insert(representative, ty); type_cache.insert(representative, ty);
ty ty