codegen/expr: function codegen and refactoring

This commit is contained in:
pca006132 2021-08-25 15:29:58 +08:00
parent 93270d7227
commit 173102fc56
6 changed files with 170 additions and 64 deletions

View File

@ -1,10 +1,10 @@
use std::{collections::HashMap, convert::TryInto, iter::once};
use super::{get_llvm_type, CodeGenContext};
use crate::{
codegen::{get_llvm_type, CodeGenContext, CodeGenTask},
symbol_resolver::SymbolValue,
toplevel::{DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type, TypeEnum},
typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier},
};
use inkwell::{
types::{BasicType, BasicTypeEnum},
@ -31,7 +31,12 @@ pub fn assert_pointer_val<'ctx>(val: BasicValueEnum<'ctx>) -> PointerValue<'ctx>
}
impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
fn get_subst_key(&mut self, obj: Option<Type>, fun: &FunSignature) -> String {
fn get_subst_key(
&mut self,
obj: Option<Type>,
fun: &FunSignature,
filter: Option<&Vec<u32>>,
) -> String {
let mut vars = obj
.map(|ty| {
if let TypeEnum::TObj { params, .. } = &*self.unifier.get_ty(ty) {
@ -42,7 +47,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
})
.unwrap_or_default();
vars.extend(fun.vars.iter());
let sorted = vars.keys().sorted();
let sorted =
vars.keys().filter(|id| filter.map(|v| v.contains(id)).unwrap_or(true)).sorted();
sorted
.map(|id| {
self.unifier.stringify(vars[id], &mut |id| id.to_string(), &mut |id| id.to_string())
@ -101,42 +107,129 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
obj: Option<(Type, BasicValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
params: Vec<(Option<String>, BasicValueEnum<'ctx>)>,
ret: Type,
) -> Option<BasicValueEnum<'ctx>> {
let key = self.get_subst_key(obj.map(|(a, _)| a), fun.0);
let key = self.get_subst_key(obj.map(|a| a.0), fun.0, None);
let top_level_defs = self.top_level.definitions.read();
let definition = top_level_defs.get(fun.1 .0).unwrap();
let val = if let TopLevelDef::Function { instance_to_symbol, .. } = &*definition.read() {
let symbol = instance_to_symbol.get(&key).unwrap_or_else(|| {
// TODO: codegen for function that are not yet generated
unimplemented!()
});
let fun_val = self.module.get_function(symbol).unwrap_or_else(|| {
let params = fun.0.args.iter().map(|arg| self.get_llvm_type(arg.ty)).collect_vec();
let fun_ty = if self.unifier.unioned(ret, self.primitives.none) {
self.ctx.void_type().fn_type(&params, false)
let symbol =
if let TopLevelDef::Function { instance_to_symbol, .. } = &*definition.read() {
instance_to_symbol.get(&key).cloned()
} else {
unreachable!()
}
.unwrap_or_else(|| {
if let TopLevelDef::Function {
name,
instance_to_symbol,
instance_to_stmt,
var_id,
resolver,
..
} = &mut *definition.write()
{
instance_to_symbol.get(&key).cloned().unwrap_or_else(|| {
let symbol = format!("{}_{}", name, instance_to_symbol.len());
instance_to_symbol.insert(key, symbol.clone());
let key = self.get_subst_key(obj.map(|a| a.0), fun.0, Some(var_id));
let instance = instance_to_stmt.get(&key).unwrap();
let unifiers = self.top_level.unifiers.read();
let (unifier, primitives) = &unifiers[instance.unifier_id];
let mut unifier = Unifier::from_shared_unifier(&unifier);
let mut type_cache = [
(self.primitives.int32, primitives.int32),
(self.primitives.int64, primitives.int64),
(self.primitives.float, primitives.float),
(self.primitives.bool, primitives.bool),
(self.primitives.none, primitives.none),
]
.iter()
.map(|(a, b)| {
(self.unifier.get_representative(*a), unifier.get_representative(*b))
})
.collect();
let subst = fun
.0
.vars
.iter()
.map(|(id, ty)| {
(
*instance.subst.get(id).unwrap(),
unifier.copy_from(&mut self.unifier, *ty, &mut type_cache),
)
})
.collect();
let signature = FunSignature {
args: fun
.0
.args
.iter()
.map(|arg| FuncArg {
name: arg.name.clone(),
ty: unifier.copy_from(
&mut self.unifier,
arg.ty,
&mut type_cache,
),
default_value: arg.default_value.clone(),
})
.collect(),
ret: unifier.copy_from(&mut self.unifier, fun.0.ret, &mut type_cache),
vars: fun
.0
.vars
.iter()
.map(|(id, ty)| {
(
*id,
unifier.copy_from(&mut self.unifier, *ty, &mut type_cache),
)
})
.collect(),
};
let unifier = (unifier.get_shared_unifier(), *primitives);
let task = CodeGenTask {
symbol_name: symbol.clone(),
body: instance.body.clone(),
resolver: resolver.as_ref().unwrap().clone(),
calls: instance.calls.clone(),
subst,
signature,
unifier,
};
self.registry.add_task(task);
symbol
})
} else {
self.get_llvm_type(ret).fn_type(&params, false)
};
self.module.add_function(symbol, fun_ty, None)
unreachable!()
}
});
let mut keys = fun.0.args.clone();
let mut mapping = HashMap::new();
for (key, value) in params.into_iter() {
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value);
}
// default value handling
for k in keys.into_iter() {
mapping.insert(k.name, self.gen_symbol_val(&k.default_value.unwrap()));
}
// reorder the parameters
let params =
fun.0.args.iter().map(|arg| mapping.remove(&arg.name).unwrap()).collect_vec();
self.builder.build_call(fun_val, &params, "call").try_as_basic_value().left()
} else {
unreachable!()
};
val
let fun_val = self.module.get_function(&symbol).unwrap_or_else(|| {
let params = fun.0.args.iter().map(|arg| self.get_llvm_type(arg.ty)).collect_vec();
let fun_ty = if self.unifier.unioned(fun.0.ret, self.primitives.none) {
self.ctx.void_type().fn_type(&params, false)
} else {
self.get_llvm_type(fun.0.ret).fn_type(&params, false)
};
self.module.add_function(&symbol, fun_ty, None)
});
let mut keys = fun.0.args.clone();
let mut mapping = HashMap::new();
for (key, value) in params.into_iter() {
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value);
}
// default value handling
for k in keys.into_iter() {
mapping.insert(k.name, self.gen_symbol_val(&k.default_value.unwrap()));
}
// reorder the parameters
let params = fun.0.args.iter().map(|arg| mapping.remove(&arg.name).unwrap()).collect_vec();
self.builder.build_call(fun_val, &params, "call").try_as_basic_value().left()
}
fn gen_const(&mut self, value: &Constant, ty: Type) -> BasicValueEnum<'ctx> {
@ -516,9 +609,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
}
ExprKind::Call { func, args, keywords } => {
if let ExprKind::Name { id, .. } = &func.as_ref().node {
// TODO: handle primitive casts
// TODO: handle primitive casts and function pointers
let fun = self.resolver.get_identifier_def(&id).expect("Unknown identifier");
let ret = expr.custom.unwrap();
let mut params =
args.iter().map(|arg| (None, self.gen_expr(arg).unwrap())).collect_vec();
let kw_iter = keywords.iter().map(|kw| {
@ -532,8 +624,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
.unifier
.get_call_signature(*self.calls.get(&expr.location.into()).unwrap())
.unwrap();
return self.gen_call(None, (&signature, fun), params, ret);
return self.gen_call(None, (&signature, fun), params);
} else {
// TODO: method
unimplemented!()
}
}

View File

@ -3,7 +3,7 @@ use crate::{
toplevel::{TopLevelContext, TopLevelDef},
typecheck::{
type_inferencer::{CodeLocation, PrimitiveStore},
typedef::{CallId, FunSignature, Type, TypeEnum, Unifier},
typedef::{CallId, FunSignature, SharedUnifier, Type, TypeEnum, Unifier},
},
};
use crossbeam::channel::{unbounded, Receiver, Sender};
@ -38,11 +38,12 @@ pub struct CodeGenContext<'ctx, 'a> {
pub module: Module<'ctx>,
pub top_level: &'a TopLevelContext,
pub unifier: Unifier,
pub resolver: Arc<dyn SymbolResolver>,
pub resolver: Arc<Box<dyn SymbolResolver + Send + Sync>>,
pub var_assignment: HashMap<String, PointerValue<'ctx>>,
pub type_cache: HashMap<Type, BasicTypeEnum<'ctx>>,
pub primitives: PrimitiveStore,
pub calls: HashMap<CodeLocation, CallId>,
pub registry: &'a WorkerRegistry,
// stores the alloca for variables
pub init_bb: BasicBlock<'ctx>,
// where continue and break should go to respectively
@ -166,7 +167,7 @@ impl WorkerRegistry {
let mut module = context.create_module(&module_name);
while let Some(task) = self.receiver.recv().unwrap() {
let result = gen_func(&context, builder, module, task, top_level_ctx.clone());
let result = gen_func(&context, self, builder, module, task, top_level_ctx.clone());
builder = result.0;
module = result.1;
*self.task_count.lock() -= 1;
@ -188,8 +189,8 @@ pub struct CodeGenTask {
pub signature: FunSignature,
pub body: Vec<Stmt<Option<Type>>>,
pub calls: HashMap<CodeLocation, CallId>,
pub unifier_index: usize,
pub resolver: Arc<dyn SymbolResolver + Send + Sync>,
pub unifier: (SharedUnifier, PrimitiveStore),
pub resolver: Arc<Box<dyn SymbolResolver + Send + Sync>>,
}
fn get_llvm_type<'ctx>(
@ -244,6 +245,7 @@ fn get_llvm_type<'ctx>(
pub fn gen_func<'ctx>(
context: &'ctx Context,
registry: &WorkerRegistry,
builder: Builder<'ctx>,
module: Module<'ctx>,
task: CodeGenTask,
@ -251,9 +253,8 @@ pub fn gen_func<'ctx>(
) -> (Builder<'ctx>, Module<'ctx>) {
// unwrap_or(0) is for unit tests without using rayon
let (mut unifier, primitives) = {
let unifiers = top_level_ctx.unifiers.read();
let (unifier, primitives) = &unifiers[task.unifier_index];
(Unifier::from_shared_unifier(unifier), *primitives)
let (unifier, primitives) = task.unifier;
(Unifier::from_shared_unifier(&unifier), primitives)
};
for (a, b) in task.subst.iter() {
@ -327,6 +328,7 @@ pub fn gen_func<'ctx>(
top_level: top_level_ctx.as_ref(),
calls: task.calls,
loop_bb: None,
registry,
var_assignment,
type_cache,
primitives,

View File

@ -4,9 +4,12 @@ use std::{collections::HashMap, collections::HashSet, sync::Arc};
use super::typecheck::type_inferencer::PrimitiveStore;
use super::typecheck::typedef::{FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, Unifier};
use crate::symbol_resolver::SymbolResolver;
use crate::{
symbol_resolver::SymbolResolver,
typecheck::{type_inferencer::CodeLocation, typedef::CallId},
};
use itertools::{izip, Itertools};
use parking_lot::{Mutex, RwLock};
use parking_lot::RwLock;
use rustpython_parser::ast::{self, Stmt};
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
@ -15,6 +18,13 @@ pub struct DefinitionId(pub usize);
mod type_annotation;
use type_annotation::*;
/// Per-instance data for a function, stored as the value type of
/// `TopLevelDef::Function::instance_to_stmt`.
pub struct FunInstance {
// statements of the function body, each AST node carrying an optional type
pub body: Vec<Stmt<Option<Type>>>,
// call ID recorded per code location; presumably one entry per call site in `body`
pub calls: HashMap<CodeLocation, CallId>,
// looked up by type variable ID (codegen queries it with the IDs of the signature's `vars`)
// NOTE(review): the values look like types living in the unifier selected by `unifier_id` - confirm
pub subst: HashMap<u32, Type>,
// index into the top-level context's `unifiers` list
pub unifier_id: usize,
}
pub enum TopLevelDef {
Class {
// name for error messages and symbols
@ -33,13 +43,15 @@ pub enum TopLevelDef {
// ancestor classes, including itself.
ancestors: Vec<TypeAnnotation>,
// symbol resolver of the module that defined the class, None if it is a built-in type
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>,
resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
},
Function {
// prefix for symbol, should be unique globally, and not ending with numbers
name: String,
// function signature.
signature: Type,
// instantiated type variable IDs
var_id: Vec<u32>,
/// Function instance to symbol mapping
/// Key: string representation of type variable values, sorted by variable ID in ascending
/// order, including type variables associated with the class.
@ -49,11 +61,10 @@ pub enum TopLevelDef {
/// Key: string representation of type variable values, sorted by variable ID in ascending
/// order, including type variables associated with the class. Excluding rigid type
/// variables.
/// Value: AST annotated with types together with a unification table index. Could contain
/// rigid type variables that would be substituted when the function is instantiated.
instance_to_stmt: HashMap<String, (Stmt<Option<Type>>, usize)>,
instance_to_stmt: HashMap<String, FunInstance>,
// symbol resolver of the module defined the class
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>,
resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
},
Initializer {
class_id: DefinitionId,
@ -171,7 +182,7 @@ impl TopLevelComposer {
/// when first registering, the type_vars, fields, methods, ancestors are invalid
pub fn make_top_level_class_def(
index: usize,
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>,
resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
name: &str,
) -> TopLevelDef {
TopLevelDef::Class {
@ -189,11 +200,12 @@ impl TopLevelComposer {
pub fn make_top_level_function_def(
name: String,
ty: Type,
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>,
resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
) -> TopLevelDef {
TopLevelDef::Function {
name,
signature: ty,
var_id: Default::default(),
instance_to_symbol: Default::default(),
instance_to_stmt: Default::default(),
resolver,
@ -214,7 +226,7 @@ impl TopLevelComposer {
pub fn register_top_level(
&mut self,
ast: ast::Stmt<()>,
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>,
resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
) -> Result<(String, DefinitionId), String> {
let mut defined_class_name: HashSet<String> = HashSet::new();
let mut defined_class_method_name: HashSet<String> = HashSet::new();
@ -363,7 +375,7 @@ impl TopLevelComposer {
continue;
}
};
let class_resolver = class_resolver.as_ref().unwrap().lock();
let class_resolver = class_resolver.as_ref().unwrap();
let class_resolver = class_resolver.deref();
let mut is_generic = false;
@ -467,7 +479,7 @@ impl TopLevelComposer {
continue;
}
};
let class_resolver = class_resolver.as_ref().unwrap().lock();
let class_resolver = class_resolver.as_ref().unwrap();
let class_resolver = class_resolver.deref();
let mut has_base = false;
@ -563,7 +575,7 @@ impl TopLevelComposer {
if let ast::StmtKind::FunctionDef { args, returns, .. } = &function_ast.node {
let resolver = resolver.as_ref();
let resolver = resolver.unwrap();
let resolver = resolver.deref().lock();
let resolver = resolver.deref();
let function_resolver = resolver.deref();
// occurred type vars should not be handled separately
@ -708,8 +720,7 @@ impl TopLevelComposer {
unreachable!("here must be class def ast");
};
let class_resolver = class_resolver.as_ref().unwrap();
let mut class_resolver = class_resolver.lock();
let class_resolver = class_resolver.deref_mut();
let class_resolver = class_resolver;
for b in class_body_ast {
if let ast::StmtKind::FunctionDef { args, returns, name, body, .. } = &b.node {

View File

@ -21,7 +21,7 @@ pub enum TypeAnnotation {
}
pub fn parse_ast_to_type_annotation_kinds<T>(
resolver: &dyn SymbolResolver,
resolver: &Box<dyn SymbolResolver + Send + Sync>,
top_level_defs: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier,
primitives: &PrimitiveStore,

View File

@ -38,7 +38,7 @@ pub struct PrimitiveStore {
}
pub struct FunctionData {
pub resolver: Arc<dyn SymbolResolver + Send + Sync>,
pub resolver: Arc<Box<dyn SymbolResolver + Send + Sync>>,
pub return_type: Option<Type>,
pub bound_variables: Vec<Type>,
}

View File

@ -125,7 +125,7 @@ impl Unifier {
ty: Type,
type_cache: &mut HashMap<Type, Type>,
) -> Type {
let representative = self.get_representative(ty);
let representative = unifier.get_representative(ty);
type_cache.get(&representative).cloned().unwrap_or_else(|| {
// put in a placeholder first to handle possible recursive type
let placeholder = self.get_fresh_var().0;
@ -183,7 +183,7 @@ impl Unifier {
TypeEnum::TVirtual { ty: self.copy_from(unifier, *ty, type_cache) }
}
};
let ty = unifier.add_ty(ty);
let ty = self.add_ty(ty);
self.unify(placeholder, ty).unwrap();
type_cache.insert(representative, ty);
ty