forked from M-Labs/nac3
codegen/expr: function codegen and refactoring
This commit is contained in:
parent
93270d7227
commit
173102fc56
|
@ -1,10 +1,10 @@
|
|||
use std::{collections::HashMap, convert::TryInto, iter::once};
|
||||
|
||||
use super::{get_llvm_type, CodeGenContext};
|
||||
use crate::{
|
||||
codegen::{get_llvm_type, CodeGenContext, CodeGenTask},
|
||||
symbol_resolver::SymbolValue,
|
||||
toplevel::{DefinitionId, TopLevelDef},
|
||||
typecheck::typedef::{FunSignature, Type, TypeEnum},
|
||||
typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier},
|
||||
};
|
||||
use inkwell::{
|
||||
types::{BasicType, BasicTypeEnum},
|
||||
|
@ -31,7 +31,12 @@ pub fn assert_pointer_val<'ctx>(val: BasicValueEnum<'ctx>) -> PointerValue<'ctx>
|
|||
}
|
||||
|
||||
impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||
fn get_subst_key(&mut self, obj: Option<Type>, fun: &FunSignature) -> String {
|
||||
fn get_subst_key(
|
||||
&mut self,
|
||||
obj: Option<Type>,
|
||||
fun: &FunSignature,
|
||||
filter: Option<&Vec<u32>>,
|
||||
) -> String {
|
||||
let mut vars = obj
|
||||
.map(|ty| {
|
||||
if let TypeEnum::TObj { params, .. } = &*self.unifier.get_ty(ty) {
|
||||
|
@ -42,7 +47,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
|||
})
|
||||
.unwrap_or_default();
|
||||
vars.extend(fun.vars.iter());
|
||||
let sorted = vars.keys().sorted();
|
||||
let sorted =
|
||||
vars.keys().filter(|id| filter.map(|v| v.contains(id)).unwrap_or(true)).sorted();
|
||||
sorted
|
||||
.map(|id| {
|
||||
self.unifier.stringify(vars[id], &mut |id| id.to_string(), &mut |id| id.to_string())
|
||||
|
@ -101,24 +107,116 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
|||
obj: Option<(Type, BasicValueEnum<'ctx>)>,
|
||||
fun: (&FunSignature, DefinitionId),
|
||||
params: Vec<(Option<String>, BasicValueEnum<'ctx>)>,
|
||||
ret: Type,
|
||||
) -> Option<BasicValueEnum<'ctx>> {
|
||||
let key = self.get_subst_key(obj.map(|(a, _)| a), fun.0);
|
||||
let key = self.get_subst_key(obj.map(|a| a.0), fun.0, None);
|
||||
let top_level_defs = self.top_level.definitions.read();
|
||||
let definition = top_level_defs.get(fun.1 .0).unwrap();
|
||||
let val = if let TopLevelDef::Function { instance_to_symbol, .. } = &*definition.read() {
|
||||
let symbol = instance_to_symbol.get(&key).unwrap_or_else(|| {
|
||||
// TODO: codegen for function that are not yet generated
|
||||
unimplemented!()
|
||||
let symbol =
|
||||
if let TopLevelDef::Function { instance_to_symbol, .. } = &*definition.read() {
|
||||
instance_to_symbol.get(&key).cloned()
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
.unwrap_or_else(|| {
|
||||
if let TopLevelDef::Function {
|
||||
name,
|
||||
instance_to_symbol,
|
||||
instance_to_stmt,
|
||||
var_id,
|
||||
resolver,
|
||||
..
|
||||
} = &mut *definition.write()
|
||||
{
|
||||
instance_to_symbol.get(&key).cloned().unwrap_or_else(|| {
|
||||
let symbol = format!("{}_{}", name, instance_to_symbol.len());
|
||||
instance_to_symbol.insert(key, symbol.clone());
|
||||
let key = self.get_subst_key(obj.map(|a| a.0), fun.0, Some(var_id));
|
||||
let instance = instance_to_stmt.get(&key).unwrap();
|
||||
let unifiers = self.top_level.unifiers.read();
|
||||
let (unifier, primitives) = &unifiers[instance.unifier_id];
|
||||
let mut unifier = Unifier::from_shared_unifier(&unifier);
|
||||
|
||||
let mut type_cache = [
|
||||
(self.primitives.int32, primitives.int32),
|
||||
(self.primitives.int64, primitives.int64),
|
||||
(self.primitives.float, primitives.float),
|
||||
(self.primitives.bool, primitives.bool),
|
||||
(self.primitives.none, primitives.none),
|
||||
]
|
||||
.iter()
|
||||
.map(|(a, b)| {
|
||||
(self.unifier.get_representative(*a), unifier.get_representative(*b))
|
||||
})
|
||||
.collect();
|
||||
|
||||
let subst = fun
|
||||
.0
|
||||
.vars
|
||||
.iter()
|
||||
.map(|(id, ty)| {
|
||||
(
|
||||
*instance.subst.get(id).unwrap(),
|
||||
unifier.copy_from(&mut self.unifier, *ty, &mut type_cache),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let signature = FunSignature {
|
||||
args: fun
|
||||
.0
|
||||
.args
|
||||
.iter()
|
||||
.map(|arg| FuncArg {
|
||||
name: arg.name.clone(),
|
||||
ty: unifier.copy_from(
|
||||
&mut self.unifier,
|
||||
arg.ty,
|
||||
&mut type_cache,
|
||||
),
|
||||
default_value: arg.default_value.clone(),
|
||||
})
|
||||
.collect(),
|
||||
ret: unifier.copy_from(&mut self.unifier, fun.0.ret, &mut type_cache),
|
||||
vars: fun
|
||||
.0
|
||||
.vars
|
||||
.iter()
|
||||
.map(|(id, ty)| {
|
||||
(
|
||||
*id,
|
||||
unifier.copy_from(&mut self.unifier, *ty, &mut type_cache),
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
};
|
||||
|
||||
let unifier = (unifier.get_shared_unifier(), *primitives);
|
||||
|
||||
let task = CodeGenTask {
|
||||
symbol_name: symbol.clone(),
|
||||
body: instance.body.clone(),
|
||||
resolver: resolver.as_ref().unwrap().clone(),
|
||||
calls: instance.calls.clone(),
|
||||
subst,
|
||||
signature,
|
||||
unifier,
|
||||
};
|
||||
self.registry.add_task(task);
|
||||
symbol
|
||||
})
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
});
|
||||
let fun_val = self.module.get_function(symbol).unwrap_or_else(|| {
|
||||
|
||||
let fun_val = self.module.get_function(&symbol).unwrap_or_else(|| {
|
||||
let params = fun.0.args.iter().map(|arg| self.get_llvm_type(arg.ty)).collect_vec();
|
||||
let fun_ty = if self.unifier.unioned(ret, self.primitives.none) {
|
||||
let fun_ty = if self.unifier.unioned(fun.0.ret, self.primitives.none) {
|
||||
self.ctx.void_type().fn_type(¶ms, false)
|
||||
} else {
|
||||
self.get_llvm_type(ret).fn_type(¶ms, false)
|
||||
self.get_llvm_type(fun.0.ret).fn_type(¶ms, false)
|
||||
};
|
||||
self.module.add_function(symbol, fun_ty, None)
|
||||
self.module.add_function(&symbol, fun_ty, None)
|
||||
});
|
||||
let mut keys = fun.0.args.clone();
|
||||
let mut mapping = HashMap::new();
|
||||
|
@ -130,13 +228,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
|||
mapping.insert(k.name, self.gen_symbol_val(&k.default_value.unwrap()));
|
||||
}
|
||||
// reorder the parameters
|
||||
let params =
|
||||
fun.0.args.iter().map(|arg| mapping.remove(&arg.name).unwrap()).collect_vec();
|
||||
let params = fun.0.args.iter().map(|arg| mapping.remove(&arg.name).unwrap()).collect_vec();
|
||||
self.builder.build_call(fun_val, ¶ms, "call").try_as_basic_value().left()
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
val
|
||||
}
|
||||
|
||||
fn gen_const(&mut self, value: &Constant, ty: Type) -> BasicValueEnum<'ctx> {
|
||||
|
@ -516,9 +609,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
|||
}
|
||||
ExprKind::Call { func, args, keywords } => {
|
||||
if let ExprKind::Name { id, .. } = &func.as_ref().node {
|
||||
// TODO: handle primitive casts
|
||||
// TODO: handle primitive casts and function pointers
|
||||
let fun = self.resolver.get_identifier_def(&id).expect("Unknown identifier");
|
||||
let ret = expr.custom.unwrap();
|
||||
let mut params =
|
||||
args.iter().map(|arg| (None, self.gen_expr(arg).unwrap())).collect_vec();
|
||||
let kw_iter = keywords.iter().map(|kw| {
|
||||
|
@ -532,8 +624,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
|||
.unifier
|
||||
.get_call_signature(*self.calls.get(&expr.location.into()).unwrap())
|
||||
.unwrap();
|
||||
return self.gen_call(None, (&signature, fun), params, ret);
|
||||
return self.gen_call(None, (&signature, fun), params);
|
||||
} else {
|
||||
// TODO: method
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@ use crate::{
|
|||
toplevel::{TopLevelContext, TopLevelDef},
|
||||
typecheck::{
|
||||
type_inferencer::{CodeLocation, PrimitiveStore},
|
||||
typedef::{CallId, FunSignature, Type, TypeEnum, Unifier},
|
||||
typedef::{CallId, FunSignature, SharedUnifier, Type, TypeEnum, Unifier},
|
||||
},
|
||||
};
|
||||
use crossbeam::channel::{unbounded, Receiver, Sender};
|
||||
|
@ -38,11 +38,12 @@ pub struct CodeGenContext<'ctx, 'a> {
|
|||
pub module: Module<'ctx>,
|
||||
pub top_level: &'a TopLevelContext,
|
||||
pub unifier: Unifier,
|
||||
pub resolver: Arc<dyn SymbolResolver>,
|
||||
pub resolver: Arc<Box<dyn SymbolResolver + Send + Sync>>,
|
||||
pub var_assignment: HashMap<String, PointerValue<'ctx>>,
|
||||
pub type_cache: HashMap<Type, BasicTypeEnum<'ctx>>,
|
||||
pub primitives: PrimitiveStore,
|
||||
pub calls: HashMap<CodeLocation, CallId>,
|
||||
pub registry: &'a WorkerRegistry,
|
||||
// stores the alloca for variables
|
||||
pub init_bb: BasicBlock<'ctx>,
|
||||
// where continue and break should go to respectively
|
||||
|
@ -166,7 +167,7 @@ impl WorkerRegistry {
|
|||
let mut module = context.create_module(&module_name);
|
||||
|
||||
while let Some(task) = self.receiver.recv().unwrap() {
|
||||
let result = gen_func(&context, builder, module, task, top_level_ctx.clone());
|
||||
let result = gen_func(&context, self, builder, module, task, top_level_ctx.clone());
|
||||
builder = result.0;
|
||||
module = result.1;
|
||||
*self.task_count.lock() -= 1;
|
||||
|
@ -188,8 +189,8 @@ pub struct CodeGenTask {
|
|||
pub signature: FunSignature,
|
||||
pub body: Vec<Stmt<Option<Type>>>,
|
||||
pub calls: HashMap<CodeLocation, CallId>,
|
||||
pub unifier_index: usize,
|
||||
pub resolver: Arc<dyn SymbolResolver + Send + Sync>,
|
||||
pub unifier: (SharedUnifier, PrimitiveStore),
|
||||
pub resolver: Arc<Box<dyn SymbolResolver + Send + Sync>>,
|
||||
}
|
||||
|
||||
fn get_llvm_type<'ctx>(
|
||||
|
@ -244,6 +245,7 @@ fn get_llvm_type<'ctx>(
|
|||
|
||||
pub fn gen_func<'ctx>(
|
||||
context: &'ctx Context,
|
||||
registry: &WorkerRegistry,
|
||||
builder: Builder<'ctx>,
|
||||
module: Module<'ctx>,
|
||||
task: CodeGenTask,
|
||||
|
@ -251,9 +253,8 @@ pub fn gen_func<'ctx>(
|
|||
) -> (Builder<'ctx>, Module<'ctx>) {
|
||||
// unwrap_or(0) is for unit tests without using rayon
|
||||
let (mut unifier, primitives) = {
|
||||
let unifiers = top_level_ctx.unifiers.read();
|
||||
let (unifier, primitives) = &unifiers[task.unifier_index];
|
||||
(Unifier::from_shared_unifier(unifier), *primitives)
|
||||
let (unifier, primitives) = task.unifier;
|
||||
(Unifier::from_shared_unifier(&unifier), primitives)
|
||||
};
|
||||
|
||||
for (a, b) in task.subst.iter() {
|
||||
|
@ -327,6 +328,7 @@ pub fn gen_func<'ctx>(
|
|||
top_level: top_level_ctx.as_ref(),
|
||||
calls: task.calls,
|
||||
loop_bb: None,
|
||||
registry,
|
||||
var_assignment,
|
||||
type_cache,
|
||||
primitives,
|
||||
|
|
|
@ -4,9 +4,12 @@ use std::{collections::HashMap, collections::HashSet, sync::Arc};
|
|||
|
||||
use super::typecheck::type_inferencer::PrimitiveStore;
|
||||
use super::typecheck::typedef::{FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, Unifier};
|
||||
use crate::symbol_resolver::SymbolResolver;
|
||||
use crate::{
|
||||
symbol_resolver::SymbolResolver,
|
||||
typecheck::{type_inferencer::CodeLocation, typedef::CallId},
|
||||
};
|
||||
use itertools::{izip, Itertools};
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use parking_lot::RwLock;
|
||||
use rustpython_parser::ast::{self, Stmt};
|
||||
|
||||
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
|
||||
|
@ -15,6 +18,13 @@ pub struct DefinitionId(pub usize);
|
|||
mod type_annotation;
|
||||
use type_annotation::*;
|
||||
|
||||
pub struct FunInstance {
|
||||
pub body: Vec<Stmt<Option<Type>>>,
|
||||
pub calls: HashMap<CodeLocation, CallId>,
|
||||
pub subst: HashMap<u32, Type>,
|
||||
pub unifier_id: usize,
|
||||
}
|
||||
|
||||
pub enum TopLevelDef {
|
||||
Class {
|
||||
// name for error messages and symbols
|
||||
|
@ -33,13 +43,15 @@ pub enum TopLevelDef {
|
|||
// ancestor classes, including itself.
|
||||
ancestors: Vec<TypeAnnotation>,
|
||||
// symbol resolver of the module defined the class, none if it is built-in type
|
||||
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>,
|
||||
resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
|
||||
},
|
||||
Function {
|
||||
// prefix for symbol, should be unique globally, and not ending with numbers
|
||||
name: String,
|
||||
// function signature.
|
||||
signature: Type,
|
||||
// instantiated type variable IDs
|
||||
var_id: Vec<u32>,
|
||||
/// Function instance to symbol mapping
|
||||
/// Key: string representation of type variable values, sorted by variable ID in ascending
|
||||
/// order, including type variables associated with the class.
|
||||
|
@ -49,11 +61,10 @@ pub enum TopLevelDef {
|
|||
/// Key: string representation of type variable values, sorted by variable ID in ascending
|
||||
/// order, including type variables associated with the class. Excluding rigid type
|
||||
/// variables.
|
||||
/// Value: AST annotated with types together with a unification table index. Could contain
|
||||
/// rigid type variables that would be substituted when the function is instantiated.
|
||||
instance_to_stmt: HashMap<String, (Stmt<Option<Type>>, usize)>,
|
||||
instance_to_stmt: HashMap<String, FunInstance>,
|
||||
// symbol resolver of the module defined the class
|
||||
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>,
|
||||
resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
|
||||
},
|
||||
Initializer {
|
||||
class_id: DefinitionId,
|
||||
|
@ -171,7 +182,7 @@ impl TopLevelComposer {
|
|||
/// when first regitering, the type_vars, fields, methods, ancestors are invalid
|
||||
pub fn make_top_level_class_def(
|
||||
index: usize,
|
||||
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>,
|
||||
resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
|
||||
name: &str,
|
||||
) -> TopLevelDef {
|
||||
TopLevelDef::Class {
|
||||
|
@ -189,11 +200,12 @@ impl TopLevelComposer {
|
|||
pub fn make_top_level_function_def(
|
||||
name: String,
|
||||
ty: Type,
|
||||
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>,
|
||||
resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
|
||||
) -> TopLevelDef {
|
||||
TopLevelDef::Function {
|
||||
name,
|
||||
signature: ty,
|
||||
var_id: Default::default(),
|
||||
instance_to_symbol: Default::default(),
|
||||
instance_to_stmt: Default::default(),
|
||||
resolver,
|
||||
|
@ -214,7 +226,7 @@ impl TopLevelComposer {
|
|||
pub fn register_top_level(
|
||||
&mut self,
|
||||
ast: ast::Stmt<()>,
|
||||
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>,
|
||||
resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
|
||||
) -> Result<(String, DefinitionId), String> {
|
||||
let mut defined_class_name: HashSet<String> = HashSet::new();
|
||||
let mut defined_class_method_name: HashSet<String> = HashSet::new();
|
||||
|
@ -363,7 +375,7 @@ impl TopLevelComposer {
|
|||
continue;
|
||||
}
|
||||
};
|
||||
let class_resolver = class_resolver.as_ref().unwrap().lock();
|
||||
let class_resolver = class_resolver.as_ref().unwrap();
|
||||
let class_resolver = class_resolver.deref();
|
||||
|
||||
let mut is_generic = false;
|
||||
|
@ -467,7 +479,7 @@ impl TopLevelComposer {
|
|||
continue;
|
||||
}
|
||||
};
|
||||
let class_resolver = class_resolver.as_ref().unwrap().lock();
|
||||
let class_resolver = class_resolver.as_ref().unwrap();
|
||||
let class_resolver = class_resolver.deref();
|
||||
|
||||
let mut has_base = false;
|
||||
|
@ -563,7 +575,7 @@ impl TopLevelComposer {
|
|||
if let ast::StmtKind::FunctionDef { args, returns, .. } = &function_ast.node {
|
||||
let resolver = resolver.as_ref();
|
||||
let resolver = resolver.unwrap();
|
||||
let resolver = resolver.deref().lock();
|
||||
let resolver = resolver.deref();
|
||||
let function_resolver = resolver.deref();
|
||||
|
||||
// occured type vars should not be handled separately
|
||||
|
@ -708,8 +720,7 @@ impl TopLevelComposer {
|
|||
unreachable!("here must be class def ast");
|
||||
};
|
||||
let class_resolver = class_resolver.as_ref().unwrap();
|
||||
let mut class_resolver = class_resolver.lock();
|
||||
let class_resolver = class_resolver.deref_mut();
|
||||
let class_resolver = class_resolver;
|
||||
|
||||
for b in class_body_ast {
|
||||
if let ast::StmtKind::FunctionDef { args, returns, name, body, .. } = &b.node {
|
||||
|
|
|
@ -21,7 +21,7 @@ pub enum TypeAnnotation {
|
|||
}
|
||||
|
||||
pub fn parse_ast_to_type_annotation_kinds<T>(
|
||||
resolver: &dyn SymbolResolver,
|
||||
resolver: &Box<dyn SymbolResolver + Send + Sync>,
|
||||
top_level_defs: &[Arc<RwLock<TopLevelDef>>],
|
||||
unifier: &mut Unifier,
|
||||
primitives: &PrimitiveStore,
|
||||
|
|
|
@ -38,7 +38,7 @@ pub struct PrimitiveStore {
|
|||
}
|
||||
|
||||
pub struct FunctionData {
|
||||
pub resolver: Arc<dyn SymbolResolver + Send + Sync>,
|
||||
pub resolver: Arc<Box<dyn SymbolResolver + Send + Sync>>,
|
||||
pub return_type: Option<Type>,
|
||||
pub bound_variables: Vec<Type>,
|
||||
}
|
||||
|
|
|
@ -125,7 +125,7 @@ impl Unifier {
|
|||
ty: Type,
|
||||
type_cache: &mut HashMap<Type, Type>,
|
||||
) -> Type {
|
||||
let representative = self.get_representative(ty);
|
||||
let representative = unifier.get_representative(ty);
|
||||
type_cache.get(&representative).cloned().unwrap_or_else(|| {
|
||||
// put in a placeholder first to handle possible recursive type
|
||||
let placeholder = self.get_fresh_var().0;
|
||||
|
@ -183,7 +183,7 @@ impl Unifier {
|
|||
TypeEnum::TVirtual { ty: self.copy_from(unifier, *ty, type_cache) }
|
||||
}
|
||||
};
|
||||
let ty = unifier.add_ty(ty);
|
||||
let ty = self.add_ty(ty);
|
||||
self.unify(placeholder, ty).unwrap();
|
||||
type_cache.insert(representative, ty);
|
||||
ty
|
||||
|
|
Loading…
Reference in New Issue