forked from M-Labs/nac3
codegen for function call
This commit is contained in:
parent
34d3317ea0
commit
7a38ab3119
@ -1,12 +1,41 @@
|
||||
use std::{convert::TryInto, iter::once};
|
||||
|
||||
use crate::top_level::{CodeGenContext, TopLevelDef};
|
||||
use crate::typecheck::typedef::{Type, TypeEnum};
|
||||
use inkwell::{types::BasicType, values::BasicValueEnum};
|
||||
use crate::{
|
||||
top_level::DefinitionId,
|
||||
typecheck::typedef::{Type, TypeEnum},
|
||||
};
|
||||
use crate::{
|
||||
top_level::{CodeGenContext, TopLevelDef},
|
||||
typecheck::typedef::FunSignature,
|
||||
};
|
||||
use inkwell::{
|
||||
types::{BasicType, BasicTypeEnum},
|
||||
values::BasicValueEnum,
|
||||
AddressSpace,
|
||||
};
|
||||
use itertools::{chain, izip, zip, Itertools};
|
||||
use rustpython_parser::ast::{self, Boolop, Constant, Expr, ExprKind, Operator};
|
||||
|
||||
impl<'ctx> CodeGenContext<'ctx> {
|
||||
fn get_subst_key(&mut self, obj: Option<Type>, fun: &FunSignature) -> String {
|
||||
let mut vars = obj
|
||||
.map(|ty| {
|
||||
if let TypeEnum::TObj { params, .. } = &*self.unifier.get_ty(ty) {
|
||||
params.clone()
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
})
|
||||
.unwrap_or_default();
|
||||
vars.extend(fun.vars.iter());
|
||||
let sorted = vars.keys().sorted();
|
||||
sorted
|
||||
.map(|id| {
|
||||
self.unifier.stringify(vars[id], &mut |id| id.to_string(), &mut |id| id.to_string())
|
||||
})
|
||||
.join(", ")
|
||||
}
|
||||
|
||||
fn get_attr_index(&mut self, ty: Type, attr: &str) -> usize {
|
||||
let obj_id = match &*self.unifier.get_ty(ty) {
|
||||
TypeEnum::TObj { obj_id, .. } => *obj_id,
|
||||
@ -22,17 +51,88 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||
index
|
||||
}
|
||||
|
||||
fn get_llvm_type(&mut self, ty: Type) -> BasicTypeEnum<'ctx> {
|
||||
use TypeEnum::*;
|
||||
// we assume the type cache should already contain primitive types,
|
||||
// and they should be passed by value instead of passing as pointer.
|
||||
self.type_cache.get(&ty).cloned().unwrap_or_else(|| match &*self.unifier.get_ty(ty) {
|
||||
TObj { obj_id, fields, .. } => {
|
||||
// a struct with fields in the order of declaration
|
||||
let defs = self.top_level.definitions.read();
|
||||
let definition = defs.get(obj_id.0).unwrap();
|
||||
let ty = if let TopLevelDef::Class { fields: fields_list, .. } = &*definition.read()
|
||||
{
|
||||
let fields = fields.borrow();
|
||||
let fields =
|
||||
fields_list.iter().map(|f| self.get_llvm_type(fields[&f.0])).collect_vec();
|
||||
self.ctx
|
||||
.struct_type(&fields, false)
|
||||
.ptr_type(AddressSpace::Generic)
|
||||
.into()
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
ty
|
||||
}
|
||||
TTuple { ty } => {
|
||||
// a struct with fields in the order present in the tuple
|
||||
let fields = ty.iter().map(|ty| self.get_llvm_type(*ty)).collect_vec();
|
||||
self.ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into()
|
||||
}
|
||||
TList { ty } => {
|
||||
// a struct with an integer and a pointer to an array
|
||||
let element_type = self.get_llvm_type(*ty);
|
||||
let fields = [
|
||||
self.ctx.i32_type().into(),
|
||||
element_type.ptr_type(AddressSpace::Generic).into(),
|
||||
];
|
||||
self.ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into()
|
||||
}
|
||||
_ => unreachable!(),
|
||||
})
|
||||
}
|
||||
|
||||
fn gen_call(
|
||||
&mut self,
|
||||
obj: Option<(Type, BasicValueEnum<'ctx>)>,
|
||||
fun: (&FunSignature, DefinitionId),
|
||||
params: &[BasicValueEnum<'ctx>],
|
||||
ret: Type,
|
||||
) -> Option<BasicValueEnum<'ctx>> {
|
||||
let key = self.get_subst_key(obj.map(|(a, _)| a), fun.0);
|
||||
let defs = self.top_level.definitions.read();
|
||||
let definition = defs.get(fun.1.0).unwrap();
|
||||
let val = if let TopLevelDef::Function { instance_to_symbol, .. } = &*definition.read() {
|
||||
// TODO: codegen for function that are not yet generated
|
||||
let symbol = instance_to_symbol.get(&key).unwrap();
|
||||
let fun_val = self.module.get_function(symbol).unwrap_or_else(|| {
|
||||
let params = fun.0.args.iter().map(|arg| self.get_llvm_type(arg.ty)).collect_vec();
|
||||
let fun_ty = if self.unifier.unioned(ret, self.primitives.none) {
|
||||
self.ctx.void_type().fn_type(¶ms, false)
|
||||
} else {
|
||||
self.get_llvm_type(ret).fn_type(¶ms, false)
|
||||
};
|
||||
self.module.add_function(symbol, fun_ty, None)
|
||||
});
|
||||
// TODO: deal with default parameters and reordering based on keys
|
||||
self.builder.build_call(fun_val, params, "call").try_as_basic_value().left()
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
val
|
||||
}
|
||||
|
||||
fn gen_const(&mut self, value: &Constant, ty: Type) -> BasicValueEnum<'ctx> {
|
||||
match value {
|
||||
Constant::Bool(v) => {
|
||||
assert!(self.unifier.unioned(ty, self.top_level.primitives.bool));
|
||||
assert!(self.unifier.unioned(ty, self.primitives.bool));
|
||||
let ty = self.ctx.bool_type();
|
||||
ty.const_int(if *v { 1 } else { 0 }, false).into()
|
||||
}
|
||||
Constant::Int(v) => {
|
||||
let ty = if self.unifier.unioned(ty, self.top_level.primitives.int32) {
|
||||
let ty = if self.unifier.unioned(ty, self.primitives.int32) {
|
||||
self.ctx.i32_type()
|
||||
} else if self.unifier.unioned(ty, self.top_level.primitives.int64) {
|
||||
} else if self.unifier.unioned(ty, self.primitives.int64) {
|
||||
self.ctx.i64_type()
|
||||
} else {
|
||||
unreachable!();
|
||||
@ -40,7 +140,7 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||
ty.const_int(v.try_into().unwrap(), false).into()
|
||||
}
|
||||
Constant::Float(v) => {
|
||||
assert!(self.unifier.unioned(ty, self.top_level.primitives.float));
|
||||
assert!(self.unifier.unioned(ty, self.primitives.float));
|
||||
let ty = self.ctx.f64_type();
|
||||
ty.const_float(*v).into()
|
||||
}
|
||||
@ -134,7 +234,6 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||
|
||||
pub fn gen_expr(&mut self, expr: &Expr<Option<Type>>) -> BasicValueEnum<'ctx> {
|
||||
let zero = self.ctx.i32_type().const_int(0, false);
|
||||
let primitives = &self.top_level.primitives;
|
||||
match &expr.node {
|
||||
ExprKind::Constant { value, .. } => {
|
||||
let ty = expr.custom.unwrap();
|
||||
@ -146,25 +245,36 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||
}
|
||||
ExprKind::List { elts, .. } => {
|
||||
// this shall be optimized later for constant primitive lists...
|
||||
// we should use memcpy for that instead of generating thousands of stores
|
||||
let elements = elts.iter().map(|x| self.gen_expr(x)).collect_vec();
|
||||
let ty = if elements.is_empty() {
|
||||
self.ctx.i32_type().into()
|
||||
} else {
|
||||
elements[0].get_type()
|
||||
};
|
||||
// this length includes the leading length element
|
||||
let arr_ptr = self.builder.build_array_alloca(
|
||||
ty,
|
||||
self.ctx.i32_type().const_int(elements.len() as u64, false),
|
||||
"tmparr",
|
||||
);
|
||||
let arr_ty = self.ctx.struct_type(
|
||||
&[self.ctx.i32_type().into(), ty.array_type(elements.len() as u32).into()],
|
||||
&[
|
||||
self.ctx.i32_type().into(),
|
||||
ty.ptr_type(AddressSpace::Generic).into(),
|
||||
],
|
||||
false,
|
||||
);
|
||||
let arr_ptr = self.builder.build_alloca(arr_ty, "tmparr");
|
||||
let arr_str_ptr = self.builder.build_alloca(arr_ty, "tmparrstr");
|
||||
unsafe {
|
||||
let len_ptr = arr_ptr
|
||||
.const_in_bounds_gep(&[zero, self.ctx.i32_type().const_int(0u64, false)]);
|
||||
self.builder.build_store(
|
||||
len_ptr,
|
||||
arr_str_ptr.const_in_bounds_gep(&[zero, zero]),
|
||||
self.ctx.i32_type().const_int(elements.len() as u64, false),
|
||||
);
|
||||
self.builder.build_store(
|
||||
arr_str_ptr
|
||||
.const_in_bounds_gep(&[zero, self.ctx.i32_type().const_int(1, false)]),
|
||||
arr_ptr,
|
||||
);
|
||||
let arr_offset = self.ctx.i32_type().const_int(1, false);
|
||||
for (i, v) in elements.iter().enumerate() {
|
||||
let ptr = self.builder.build_in_bounds_gep(
|
||||
@ -175,7 +285,7 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||
self.builder.build_store(ptr, *v);
|
||||
}
|
||||
}
|
||||
arr_ptr.into()
|
||||
arr_str_ptr.into()
|
||||
}
|
||||
ExprKind::Tuple { elts, .. } => {
|
||||
let element_val = elts.iter().map(|x| self.gen_expr(x)).collect_vec();
|
||||
@ -266,9 +376,9 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||
// when doing code generation for function instances
|
||||
if ty1 != ty2 {
|
||||
unimplemented!()
|
||||
} else if [primitives.int32, primitives.int64].contains(&ty1) {
|
||||
} else if [self.primitives.int32, self.primitives.int64].contains(&ty1) {
|
||||
self.gen_int_ops(op, left, right)
|
||||
} else if primitives.float == ty1 {
|
||||
} else if self.primitives.float == ty1 {
|
||||
self.gen_float_ops(op, left, right)
|
||||
} else {
|
||||
unimplemented!()
|
||||
@ -277,7 +387,7 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||
ExprKind::UnaryOp { op, operand } => {
|
||||
let ty = self.unifier.get_representative(operand.custom.unwrap());
|
||||
let val = self.gen_expr(operand);
|
||||
if ty == primitives.bool {
|
||||
if ty == self.primitives.bool {
|
||||
let val =
|
||||
if let BasicValueEnum::IntValue(val) = val { val } else { unreachable!() };
|
||||
match op {
|
||||
@ -286,7 +396,7 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||
}
|
||||
_ => val.into(),
|
||||
}
|
||||
} else if [primitives.int32, primitives.int64].contains(&ty) {
|
||||
} else if [self.primitives.int32, self.primitives.int64].contains(&ty) {
|
||||
let val =
|
||||
if let BasicValueEnum::IntValue(val) = val { val } else { unreachable!() };
|
||||
match op {
|
||||
@ -303,7 +413,7 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||
.into(),
|
||||
_ => val.into(),
|
||||
}
|
||||
} else if ty == primitives.float {
|
||||
} else if ty == self.primitives.float {
|
||||
let val = if let BasicValueEnum::FloatValue(val) = val {
|
||||
val
|
||||
} else {
|
||||
@ -334,7 +444,7 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||
)
|
||||
.fold(None, |prev, (lhs, rhs, op)| {
|
||||
let ty = lhs.custom.unwrap();
|
||||
let current = if [primitives.int32, primitives.int64, primitives.bool]
|
||||
let current = if [self.primitives.int32, self.primitives.int64, self.primitives.bool]
|
||||
.contains(&ty)
|
||||
{
|
||||
let (lhs, rhs) =
|
||||
@ -355,7 +465,7 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||
_ => unreachable!(),
|
||||
};
|
||||
self.builder.build_int_compare(op, lhs, rhs, "cmp")
|
||||
} else if ty == primitives.float {
|
||||
} else if ty == self.primitives.float {
|
||||
let (lhs, rhs) = if let (
|
||||
BasicValueEnum::FloatValue(lhs),
|
||||
BasicValueEnum::FloatValue(rhs),
|
||||
|
@ -3,7 +3,7 @@ use std::{collections::HashMap, sync::Arc};
|
||||
use super::typecheck::type_inferencer::PrimitiveStore;
|
||||
use super::typecheck::typedef::{SharedUnifier, Type, Unifier};
|
||||
use crate::symbol_resolver::SymbolResolver;
|
||||
use inkwell::{builder::Builder, context::Context, module::Module, values::PointerValue};
|
||||
use inkwell::{builder::Builder, context::Context, module::Module, types::BasicTypeEnum, values::PointerValue};
|
||||
use parking_lot::RwLock;
|
||||
use rustpython_parser::ast::Stmt;
|
||||
|
||||
@ -16,14 +16,17 @@ pub enum TopLevelDef {
|
||||
object_id: DefinitionId,
|
||||
// type variables bounded to the class.
|
||||
type_vars: Vec<Type>,
|
||||
// class fields and method signature.
|
||||
// class fields
|
||||
fields: Vec<(String, Type)>,
|
||||
// class methods, pointing to the corresponding function definition.
|
||||
methods: Vec<(String, DefinitionId)>,
|
||||
methods: Vec<(String, Type, DefinitionId)>,
|
||||
// ancestor classes, including itself.
|
||||
ancestors: Vec<DefinitionId>,
|
||||
},
|
||||
Function {
|
||||
// prefix for symbol, should be unique globally, and not ending with numbers
|
||||
name: String,
|
||||
// function signature.
|
||||
signature: Type,
|
||||
/// Function instance to symbol mapping
|
||||
/// Key: string representation of type variable values, sorted by variable ID in ascending
|
||||
@ -48,7 +51,6 @@ pub struct CodeGenTask {
|
||||
}
|
||||
|
||||
pub struct TopLevelContext {
|
||||
pub primitives: PrimitiveStore,
|
||||
pub definitions: Arc<RwLock<Vec<RwLock<TopLevelDef>>>>,
|
||||
pub unifiers: Arc<RwLock<Vec<SharedUnifier>>>,
|
||||
}
|
||||
@ -61,4 +63,6 @@ pub struct CodeGenContext<'ctx> {
|
||||
pub unifier: Unifier,
|
||||
pub resolver: Box<dyn SymbolResolver>,
|
||||
pub var_assignment: HashMap<String, PointerValue<'ctx>>,
|
||||
pub type_cache: HashMap<Type, BasicTypeEnum<'ctx>>,
|
||||
pub primitives: PrimitiveStore,
|
||||
}
|
||||
|
@ -30,6 +30,8 @@ pub struct Call {
|
||||
pub struct FuncArg {
|
||||
pub name: String,
|
||||
pub ty: Type,
|
||||
// TODO: change this to an optional value
|
||||
// for primitive types
|
||||
pub is_optional: bool,
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
use std::rc::Rc;
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
|
||||
pub struct UnificationKey(usize);
|
||||
|
||||
pub struct UnificationTable<V> {
|
||||
|
Loading…
Reference in New Issue
Block a user