forked from M-Labs/nac3
codegen for function call
This commit is contained in:
parent
34d3317ea0
commit
7a38ab3119
|
@ -1,12 +1,41 @@
|
||||||
use std::{convert::TryInto, iter::once};
|
use std::{convert::TryInto, iter::once};
|
||||||
|
|
||||||
use crate::top_level::{CodeGenContext, TopLevelDef};
|
use crate::{
|
||||||
use crate::typecheck::typedef::{Type, TypeEnum};
|
top_level::DefinitionId,
|
||||||
use inkwell::{types::BasicType, values::BasicValueEnum};
|
typecheck::typedef::{Type, TypeEnum},
|
||||||
|
};
|
||||||
|
use crate::{
|
||||||
|
top_level::{CodeGenContext, TopLevelDef},
|
||||||
|
typecheck::typedef::FunSignature,
|
||||||
|
};
|
||||||
|
use inkwell::{
|
||||||
|
types::{BasicType, BasicTypeEnum},
|
||||||
|
values::BasicValueEnum,
|
||||||
|
AddressSpace,
|
||||||
|
};
|
||||||
use itertools::{chain, izip, zip, Itertools};
|
use itertools::{chain, izip, zip, Itertools};
|
||||||
use rustpython_parser::ast::{self, Boolop, Constant, Expr, ExprKind, Operator};
|
use rustpython_parser::ast::{self, Boolop, Constant, Expr, ExprKind, Operator};
|
||||||
|
|
||||||
impl<'ctx> CodeGenContext<'ctx> {
|
impl<'ctx> CodeGenContext<'ctx> {
|
||||||
|
fn get_subst_key(&mut self, obj: Option<Type>, fun: &FunSignature) -> String {
|
||||||
|
let mut vars = obj
|
||||||
|
.map(|ty| {
|
||||||
|
if let TypeEnum::TObj { params, .. } = &*self.unifier.get_ty(ty) {
|
||||||
|
params.clone()
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.unwrap_or_default();
|
||||||
|
vars.extend(fun.vars.iter());
|
||||||
|
let sorted = vars.keys().sorted();
|
||||||
|
sorted
|
||||||
|
.map(|id| {
|
||||||
|
self.unifier.stringify(vars[id], &mut |id| id.to_string(), &mut |id| id.to_string())
|
||||||
|
})
|
||||||
|
.join(", ")
|
||||||
|
}
|
||||||
|
|
||||||
fn get_attr_index(&mut self, ty: Type, attr: &str) -> usize {
|
fn get_attr_index(&mut self, ty: Type, attr: &str) -> usize {
|
||||||
let obj_id = match &*self.unifier.get_ty(ty) {
|
let obj_id = match &*self.unifier.get_ty(ty) {
|
||||||
TypeEnum::TObj { obj_id, .. } => *obj_id,
|
TypeEnum::TObj { obj_id, .. } => *obj_id,
|
||||||
|
@ -22,17 +51,88 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||||
index
|
index
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_llvm_type(&mut self, ty: Type) -> BasicTypeEnum<'ctx> {
|
||||||
|
use TypeEnum::*;
|
||||||
|
// we assume the type cache should already contain primitive types,
|
||||||
|
// and they should be passed by value instead of passing as pointer.
|
||||||
|
self.type_cache.get(&ty).cloned().unwrap_or_else(|| match &*self.unifier.get_ty(ty) {
|
||||||
|
TObj { obj_id, fields, .. } => {
|
||||||
|
// a struct with fields in the order of declaration
|
||||||
|
let defs = self.top_level.definitions.read();
|
||||||
|
let definition = defs.get(obj_id.0).unwrap();
|
||||||
|
let ty = if let TopLevelDef::Class { fields: fields_list, .. } = &*definition.read()
|
||||||
|
{
|
||||||
|
let fields = fields.borrow();
|
||||||
|
let fields =
|
||||||
|
fields_list.iter().map(|f| self.get_llvm_type(fields[&f.0])).collect_vec();
|
||||||
|
self.ctx
|
||||||
|
.struct_type(&fields, false)
|
||||||
|
.ptr_type(AddressSpace::Generic)
|
||||||
|
.into()
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
ty
|
||||||
|
}
|
||||||
|
TTuple { ty } => {
|
||||||
|
// a struct with fields in the order present in the tuple
|
||||||
|
let fields = ty.iter().map(|ty| self.get_llvm_type(*ty)).collect_vec();
|
||||||
|
self.ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into()
|
||||||
|
}
|
||||||
|
TList { ty } => {
|
||||||
|
// a struct with an integer and a pointer to an array
|
||||||
|
let element_type = self.get_llvm_type(*ty);
|
||||||
|
let fields = [
|
||||||
|
self.ctx.i32_type().into(),
|
||||||
|
element_type.ptr_type(AddressSpace::Generic).into(),
|
||||||
|
];
|
||||||
|
self.ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into()
|
||||||
|
}
|
||||||
|
_ => unreachable!(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn gen_call(
|
||||||
|
&mut self,
|
||||||
|
obj: Option<(Type, BasicValueEnum<'ctx>)>,
|
||||||
|
fun: (&FunSignature, DefinitionId),
|
||||||
|
params: &[BasicValueEnum<'ctx>],
|
||||||
|
ret: Type,
|
||||||
|
) -> Option<BasicValueEnum<'ctx>> {
|
||||||
|
let key = self.get_subst_key(obj.map(|(a, _)| a), fun.0);
|
||||||
|
let defs = self.top_level.definitions.read();
|
||||||
|
let definition = defs.get(fun.1.0).unwrap();
|
||||||
|
let val = if let TopLevelDef::Function { instance_to_symbol, .. } = &*definition.read() {
|
||||||
|
// TODO: codegen for function that are not yet generated
|
||||||
|
let symbol = instance_to_symbol.get(&key).unwrap();
|
||||||
|
let fun_val = self.module.get_function(symbol).unwrap_or_else(|| {
|
||||||
|
let params = fun.0.args.iter().map(|arg| self.get_llvm_type(arg.ty)).collect_vec();
|
||||||
|
let fun_ty = if self.unifier.unioned(ret, self.primitives.none) {
|
||||||
|
self.ctx.void_type().fn_type(¶ms, false)
|
||||||
|
} else {
|
||||||
|
self.get_llvm_type(ret).fn_type(¶ms, false)
|
||||||
|
};
|
||||||
|
self.module.add_function(symbol, fun_ty, None)
|
||||||
|
});
|
||||||
|
// TODO: deal with default parameters and reordering based on keys
|
||||||
|
self.builder.build_call(fun_val, params, "call").try_as_basic_value().left()
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
val
|
||||||
|
}
|
||||||
|
|
||||||
fn gen_const(&mut self, value: &Constant, ty: Type) -> BasicValueEnum<'ctx> {
|
fn gen_const(&mut self, value: &Constant, ty: Type) -> BasicValueEnum<'ctx> {
|
||||||
match value {
|
match value {
|
||||||
Constant::Bool(v) => {
|
Constant::Bool(v) => {
|
||||||
assert!(self.unifier.unioned(ty, self.top_level.primitives.bool));
|
assert!(self.unifier.unioned(ty, self.primitives.bool));
|
||||||
let ty = self.ctx.bool_type();
|
let ty = self.ctx.bool_type();
|
||||||
ty.const_int(if *v { 1 } else { 0 }, false).into()
|
ty.const_int(if *v { 1 } else { 0 }, false).into()
|
||||||
}
|
}
|
||||||
Constant::Int(v) => {
|
Constant::Int(v) => {
|
||||||
let ty = if self.unifier.unioned(ty, self.top_level.primitives.int32) {
|
let ty = if self.unifier.unioned(ty, self.primitives.int32) {
|
||||||
self.ctx.i32_type()
|
self.ctx.i32_type()
|
||||||
} else if self.unifier.unioned(ty, self.top_level.primitives.int64) {
|
} else if self.unifier.unioned(ty, self.primitives.int64) {
|
||||||
self.ctx.i64_type()
|
self.ctx.i64_type()
|
||||||
} else {
|
} else {
|
||||||
unreachable!();
|
unreachable!();
|
||||||
|
@ -40,7 +140,7 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||||
ty.const_int(v.try_into().unwrap(), false).into()
|
ty.const_int(v.try_into().unwrap(), false).into()
|
||||||
}
|
}
|
||||||
Constant::Float(v) => {
|
Constant::Float(v) => {
|
||||||
assert!(self.unifier.unioned(ty, self.top_level.primitives.float));
|
assert!(self.unifier.unioned(ty, self.primitives.float));
|
||||||
let ty = self.ctx.f64_type();
|
let ty = self.ctx.f64_type();
|
||||||
ty.const_float(*v).into()
|
ty.const_float(*v).into()
|
||||||
}
|
}
|
||||||
|
@ -134,7 +234,6 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||||
|
|
||||||
pub fn gen_expr(&mut self, expr: &Expr<Option<Type>>) -> BasicValueEnum<'ctx> {
|
pub fn gen_expr(&mut self, expr: &Expr<Option<Type>>) -> BasicValueEnum<'ctx> {
|
||||||
let zero = self.ctx.i32_type().const_int(0, false);
|
let zero = self.ctx.i32_type().const_int(0, false);
|
||||||
let primitives = &self.top_level.primitives;
|
|
||||||
match &expr.node {
|
match &expr.node {
|
||||||
ExprKind::Constant { value, .. } => {
|
ExprKind::Constant { value, .. } => {
|
||||||
let ty = expr.custom.unwrap();
|
let ty = expr.custom.unwrap();
|
||||||
|
@ -146,25 +245,36 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||||
}
|
}
|
||||||
ExprKind::List { elts, .. } => {
|
ExprKind::List { elts, .. } => {
|
||||||
// this shall be optimized later for constant primitive lists...
|
// this shall be optimized later for constant primitive lists...
|
||||||
|
// we should use memcpy for that instead of generating thousands of stores
|
||||||
let elements = elts.iter().map(|x| self.gen_expr(x)).collect_vec();
|
let elements = elts.iter().map(|x| self.gen_expr(x)).collect_vec();
|
||||||
let ty = if elements.is_empty() {
|
let ty = if elements.is_empty() {
|
||||||
self.ctx.i32_type().into()
|
self.ctx.i32_type().into()
|
||||||
} else {
|
} else {
|
||||||
elements[0].get_type()
|
elements[0].get_type()
|
||||||
};
|
};
|
||||||
// this length includes the leading length element
|
let arr_ptr = self.builder.build_array_alloca(
|
||||||
|
ty,
|
||||||
|
self.ctx.i32_type().const_int(elements.len() as u64, false),
|
||||||
|
"tmparr",
|
||||||
|
);
|
||||||
let arr_ty = self.ctx.struct_type(
|
let arr_ty = self.ctx.struct_type(
|
||||||
&[self.ctx.i32_type().into(), ty.array_type(elements.len() as u32).into()],
|
&[
|
||||||
|
self.ctx.i32_type().into(),
|
||||||
|
ty.ptr_type(AddressSpace::Generic).into(),
|
||||||
|
],
|
||||||
false,
|
false,
|
||||||
);
|
);
|
||||||
let arr_ptr = self.builder.build_alloca(arr_ty, "tmparr");
|
let arr_str_ptr = self.builder.build_alloca(arr_ty, "tmparrstr");
|
||||||
unsafe {
|
unsafe {
|
||||||
let len_ptr = arr_ptr
|
|
||||||
.const_in_bounds_gep(&[zero, self.ctx.i32_type().const_int(0u64, false)]);
|
|
||||||
self.builder.build_store(
|
self.builder.build_store(
|
||||||
len_ptr,
|
arr_str_ptr.const_in_bounds_gep(&[zero, zero]),
|
||||||
self.ctx.i32_type().const_int(elements.len() as u64, false),
|
self.ctx.i32_type().const_int(elements.len() as u64, false),
|
||||||
);
|
);
|
||||||
|
self.builder.build_store(
|
||||||
|
arr_str_ptr
|
||||||
|
.const_in_bounds_gep(&[zero, self.ctx.i32_type().const_int(1, false)]),
|
||||||
|
arr_ptr,
|
||||||
|
);
|
||||||
let arr_offset = self.ctx.i32_type().const_int(1, false);
|
let arr_offset = self.ctx.i32_type().const_int(1, false);
|
||||||
for (i, v) in elements.iter().enumerate() {
|
for (i, v) in elements.iter().enumerate() {
|
||||||
let ptr = self.builder.build_in_bounds_gep(
|
let ptr = self.builder.build_in_bounds_gep(
|
||||||
|
@ -175,7 +285,7 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||||
self.builder.build_store(ptr, *v);
|
self.builder.build_store(ptr, *v);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
arr_ptr.into()
|
arr_str_ptr.into()
|
||||||
}
|
}
|
||||||
ExprKind::Tuple { elts, .. } => {
|
ExprKind::Tuple { elts, .. } => {
|
||||||
let element_val = elts.iter().map(|x| self.gen_expr(x)).collect_vec();
|
let element_val = elts.iter().map(|x| self.gen_expr(x)).collect_vec();
|
||||||
|
@ -266,9 +376,9 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||||
// when doing code generation for function instances
|
// when doing code generation for function instances
|
||||||
if ty1 != ty2 {
|
if ty1 != ty2 {
|
||||||
unimplemented!()
|
unimplemented!()
|
||||||
} else if [primitives.int32, primitives.int64].contains(&ty1) {
|
} else if [self.primitives.int32, self.primitives.int64].contains(&ty1) {
|
||||||
self.gen_int_ops(op, left, right)
|
self.gen_int_ops(op, left, right)
|
||||||
} else if primitives.float == ty1 {
|
} else if self.primitives.float == ty1 {
|
||||||
self.gen_float_ops(op, left, right)
|
self.gen_float_ops(op, left, right)
|
||||||
} else {
|
} else {
|
||||||
unimplemented!()
|
unimplemented!()
|
||||||
|
@ -277,7 +387,7 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||||
ExprKind::UnaryOp { op, operand } => {
|
ExprKind::UnaryOp { op, operand } => {
|
||||||
let ty = self.unifier.get_representative(operand.custom.unwrap());
|
let ty = self.unifier.get_representative(operand.custom.unwrap());
|
||||||
let val = self.gen_expr(operand);
|
let val = self.gen_expr(operand);
|
||||||
if ty == primitives.bool {
|
if ty == self.primitives.bool {
|
||||||
let val =
|
let val =
|
||||||
if let BasicValueEnum::IntValue(val) = val { val } else { unreachable!() };
|
if let BasicValueEnum::IntValue(val) = val { val } else { unreachable!() };
|
||||||
match op {
|
match op {
|
||||||
|
@ -286,7 +396,7 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||||
}
|
}
|
||||||
_ => val.into(),
|
_ => val.into(),
|
||||||
}
|
}
|
||||||
} else if [primitives.int32, primitives.int64].contains(&ty) {
|
} else if [self.primitives.int32, self.primitives.int64].contains(&ty) {
|
||||||
let val =
|
let val =
|
||||||
if let BasicValueEnum::IntValue(val) = val { val } else { unreachable!() };
|
if let BasicValueEnum::IntValue(val) = val { val } else { unreachable!() };
|
||||||
match op {
|
match op {
|
||||||
|
@ -303,7 +413,7 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||||
.into(),
|
.into(),
|
||||||
_ => val.into(),
|
_ => val.into(),
|
||||||
}
|
}
|
||||||
} else if ty == primitives.float {
|
} else if ty == self.primitives.float {
|
||||||
let val = if let BasicValueEnum::FloatValue(val) = val {
|
let val = if let BasicValueEnum::FloatValue(val) = val {
|
||||||
val
|
val
|
||||||
} else {
|
} else {
|
||||||
|
@ -334,7 +444,7 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||||
)
|
)
|
||||||
.fold(None, |prev, (lhs, rhs, op)| {
|
.fold(None, |prev, (lhs, rhs, op)| {
|
||||||
let ty = lhs.custom.unwrap();
|
let ty = lhs.custom.unwrap();
|
||||||
let current = if [primitives.int32, primitives.int64, primitives.bool]
|
let current = if [self.primitives.int32, self.primitives.int64, self.primitives.bool]
|
||||||
.contains(&ty)
|
.contains(&ty)
|
||||||
{
|
{
|
||||||
let (lhs, rhs) =
|
let (lhs, rhs) =
|
||||||
|
@ -355,7 +465,7 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
self.builder.build_int_compare(op, lhs, rhs, "cmp")
|
self.builder.build_int_compare(op, lhs, rhs, "cmp")
|
||||||
} else if ty == primitives.float {
|
} else if ty == self.primitives.float {
|
||||||
let (lhs, rhs) = if let (
|
let (lhs, rhs) = if let (
|
||||||
BasicValueEnum::FloatValue(lhs),
|
BasicValueEnum::FloatValue(lhs),
|
||||||
BasicValueEnum::FloatValue(rhs),
|
BasicValueEnum::FloatValue(rhs),
|
||||||
|
|
|
@ -3,7 +3,7 @@ use std::{collections::HashMap, sync::Arc};
|
||||||
use super::typecheck::type_inferencer::PrimitiveStore;
|
use super::typecheck::type_inferencer::PrimitiveStore;
|
||||||
use super::typecheck::typedef::{SharedUnifier, Type, Unifier};
|
use super::typecheck::typedef::{SharedUnifier, Type, Unifier};
|
||||||
use crate::symbol_resolver::SymbolResolver;
|
use crate::symbol_resolver::SymbolResolver;
|
||||||
use inkwell::{builder::Builder, context::Context, module::Module, values::PointerValue};
|
use inkwell::{builder::Builder, context::Context, module::Module, types::BasicTypeEnum, values::PointerValue};
|
||||||
use parking_lot::RwLock;
|
use parking_lot::RwLock;
|
||||||
use rustpython_parser::ast::Stmt;
|
use rustpython_parser::ast::Stmt;
|
||||||
|
|
||||||
|
@ -16,14 +16,17 @@ pub enum TopLevelDef {
|
||||||
object_id: DefinitionId,
|
object_id: DefinitionId,
|
||||||
// type variables bounded to the class.
|
// type variables bounded to the class.
|
||||||
type_vars: Vec<Type>,
|
type_vars: Vec<Type>,
|
||||||
// class fields and method signature.
|
// class fields
|
||||||
fields: Vec<(String, Type)>,
|
fields: Vec<(String, Type)>,
|
||||||
// class methods, pointing to the corresponding function definition.
|
// class methods, pointing to the corresponding function definition.
|
||||||
methods: Vec<(String, DefinitionId)>,
|
methods: Vec<(String, Type, DefinitionId)>,
|
||||||
// ancestor classes, including itself.
|
// ancestor classes, including itself.
|
||||||
ancestors: Vec<DefinitionId>,
|
ancestors: Vec<DefinitionId>,
|
||||||
},
|
},
|
||||||
Function {
|
Function {
|
||||||
|
// prefix for symbol, should be unique globally, and not ending with numbers
|
||||||
|
name: String,
|
||||||
|
// function signature.
|
||||||
signature: Type,
|
signature: Type,
|
||||||
/// Function instance to symbol mapping
|
/// Function instance to symbol mapping
|
||||||
/// Key: string representation of type variable values, sorted by variable ID in ascending
|
/// Key: string representation of type variable values, sorted by variable ID in ascending
|
||||||
|
@ -48,7 +51,6 @@ pub struct CodeGenTask {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct TopLevelContext {
|
pub struct TopLevelContext {
|
||||||
pub primitives: PrimitiveStore,
|
|
||||||
pub definitions: Arc<RwLock<Vec<RwLock<TopLevelDef>>>>,
|
pub definitions: Arc<RwLock<Vec<RwLock<TopLevelDef>>>>,
|
||||||
pub unifiers: Arc<RwLock<Vec<SharedUnifier>>>,
|
pub unifiers: Arc<RwLock<Vec<SharedUnifier>>>,
|
||||||
}
|
}
|
||||||
|
@ -61,4 +63,6 @@ pub struct CodeGenContext<'ctx> {
|
||||||
pub unifier: Unifier,
|
pub unifier: Unifier,
|
||||||
pub resolver: Box<dyn SymbolResolver>,
|
pub resolver: Box<dyn SymbolResolver>,
|
||||||
pub var_assignment: HashMap<String, PointerValue<'ctx>>,
|
pub var_assignment: HashMap<String, PointerValue<'ctx>>,
|
||||||
|
pub type_cache: HashMap<Type, BasicTypeEnum<'ctx>>,
|
||||||
|
pub primitives: PrimitiveStore,
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,6 +30,8 @@ pub struct Call {
|
||||||
pub struct FuncArg {
|
pub struct FuncArg {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
pub ty: Type,
|
pub ty: Type,
|
||||||
|
// TODO: change this to an optional value
|
||||||
|
// for primitive types
|
||||||
pub is_optional: bool,
|
pub is_optional: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
|
|
||||||
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
|
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
|
||||||
pub struct UnificationKey(usize);
|
pub struct UnificationKey(usize);
|
||||||
|
|
||||||
pub struct UnificationTable<V> {
|
pub struct UnificationTable<V> {
|
||||||
|
|
Loading…
Reference in New Issue