hm-inference #6

Merged
sb10q merged 136 commits from hm-inference into master 2021-08-19 11:46:50 +08:00
4 changed files with 143 additions and 27 deletions
Showing only changes of commit 7a38ab3119 - Show all commits

View File

@ -1,12 +1,41 @@
use std::{convert::TryInto, iter::once};
use crate::top_level::{CodeGenContext, TopLevelDef};
use crate::typecheck::typedef::{Type, TypeEnum};
use inkwell::{types::BasicType, values::BasicValueEnum};
use crate::{
top_level::DefinitionId,
typecheck::typedef::{Type, TypeEnum},
};
use crate::{
top_level::{CodeGenContext, TopLevelDef},
typecheck::typedef::FunSignature,
};
use inkwell::{
types::{BasicType, BasicTypeEnum},
values::BasicValueEnum,
AddressSpace,
};
use itertools::{chain, izip, zip, Itertools};
use rustpython_parser::ast::{self, Boolop, Constant, Expr, ExprKind, Operator};
impl<'ctx> CodeGenContext<'ctx> {
/// Builds the string key that identifies one instantiation of `fun`.
///
/// The type parameters of the receiver `obj` (when there is one) are merged
/// with the type variables of `fun`; entries coming from `fun` replace
/// receiver entries with the same variable ID. The type bound to each
/// variable is then rendered through the unifier, in ascending variable ID
/// order, and the rendered types are joined with ", ".
fn get_subst_key(&mut self, obj: Option<Type>, fun: &FunSignature) -> String {
    // Start from the receiver's type parameters; without a receiver the map is empty.
    let mut type_vars = match obj {
        Some(receiver) => match &*self.unifier.get_ty(receiver) {
            TypeEnum::TObj { params, .. } => params.clone(),
            // any receiver that is not an object type is treated as impossible
            _ => unreachable!(),
        },
        None => Default::default(),
    };
    type_vars.extend(fun.vars.iter());

    let mut rendered = Vec::with_capacity(type_vars.len());
    for var_id in type_vars.keys().sorted() {
        let text = self.unifier.stringify(
            type_vars[var_id],
            &mut |id| id.to_string(),
            &mut |id| id.to_string(),
        );
        rendered.push(text);
    }
    rendered.join(", ")
}
fn get_attr_index(&mut self, ty: Type, attr: &str) -> usize {
let obj_id = match &*self.unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } => *obj_id,
@ -22,17 +51,88 @@ impl<'ctx> CodeGenContext<'ctx> {
index
}
/// Maps a type-checker type to the LLVM type used for its values.
///
/// A type found in `type_cache` is returned from there; the cache is assumed
/// to already hold the primitive types, which are passed by value rather
/// than by pointer. Every other supported type becomes a pointer to a struct:
/// * object: the class fields, in declaration order;
/// * tuple: the element types, in tuple order;
/// * list: an `i32` followed by a pointer to the element type.
///
/// Any other uncached type hits `unreachable!()`.
fn get_llvm_type(&mut self, ty: Type) -> BasicTypeEnum<'ctx> {
    use TypeEnum::*;
    // Fast path: cached (primitive) types are handed back unchanged.
    if let Some(cached) = self.type_cache.get(&ty).cloned() {
        return cached;
    }
    let llvm_ty: BasicTypeEnum<'ctx> = match &*self.unifier.get_ty(ty) {
        TObj { obj_id, fields, .. } => {
            // a struct with fields in the order of declaration
            let defs = self.top_level.definitions.read();
            let class_def = defs.get(obj_id.0).unwrap();
            // Bound to a local first so the read guard created in the
            // scrutinee is released before `class_def` and `defs` go away.
            let class_ty: BasicTypeEnum<'ctx> = match &*class_def.read() {
                TopLevelDef::Class { fields: declared, .. } => {
                    let field_tys = fields.borrow();
                    let mut members = Vec::with_capacity(declared.len());
                    for field in declared.iter() {
                        members.push(self.get_llvm_type(field_tys[&field.0]));
                    }
                    self.ctx.struct_type(&members, false).ptr_type(AddressSpace::Generic).into()
                }
                _ => unreachable!(),
            };
            class_ty
        }
        TTuple { ty: elements } => {
            // a struct with fields in the order present in the tuple
            let mut members = Vec::with_capacity(elements.len());
            for element in elements.iter() {
                members.push(self.get_llvm_type(*element));
            }
            self.ctx.struct_type(&members, false).ptr_type(AddressSpace::Generic).into()
        }
        TList { ty: element } => {
            // a struct with an integer and a pointer to an array
            let element_ty = self.get_llvm_type(*element);
            let members = [
                self.ctx.i32_type().into(),
                element_ty.ptr_type(AddressSpace::Generic).into(),
            ];
            self.ctx.struct_type(&members, false).ptr_type(AddressSpace::Generic).into()
        }
        _ => unreachable!(),
    };
    llvm_ty
}
/// Emits a call to one concrete instance of a function and returns the
/// value produced by the call, if any.
///
/// * `obj` - receiver type and value for method calls, `None` otherwise.
///   Only the type is read here, to build the instance key.
/// * `fun` - signature and definition ID of the callee.
/// * `params` - argument values, forwarded to the call unchanged.
/// * `ret` - return type of the callee; when `unioned` reports it together
///   with the `none` primitive, a newly declared function gets a void
///   return type.
///
/// The instance key from `get_subst_key` is looked up in the definition's
/// `instance_to_symbol` map. If the module has no function with that symbol
/// yet, one is declared from the argument types of the signature.
///
/// Returns the left (basic value) side of the call result, so `None` is
/// returned when the call yields no basic value.
///
/// Panics if the definition ID is out of range or no symbol is registered
/// for the key; a definition that is not a function is `unreachable!()`.
///
/// NOTE(review): the receiver value in `obj` is never added to the call
/// arguments here — confirm that callers already include it in `params`.
fn gen_call(
&mut self,
obj: Option<(Type, BasicValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
params: &[BasicValueEnum<'ctx>],
ret: Type,
) -> Option<BasicValueEnum<'ctx>> {
// key identifying which instantiation of the function is being called
let key = self.get_subst_key(obj.map(|(a, _)| a), fun.0);
let defs = self.top_level.definitions.read();
let definition = defs.get(fun.1.0).unwrap();
// bound to `val` first so the read guard taken in the condition is
// released before `definition` and `defs` are dropped
let val = if let TopLevelDef::Function { instance_to_symbol, .. } = &*definition.read() {
// TODO: codegen for functions that are not yet generated
let symbol = instance_to_symbol.get(&key).unwrap();
// reuse the declaration already in the module, or declare it now
let fun_val = self.module.get_function(symbol).unwrap_or_else(|| {
let params = fun.0.args.iter().map(|arg| self.get_llvm_type(arg.ty)).collect_vec();
let fun_ty = if self.unifier.unioned(ret, self.primitives.none) {
self.ctx.void_type().fn_type(&params, false)
} else {
self.get_llvm_type(ret).fn_type(&params, false)
};
self.module.add_function(symbol, fun_ty, None)
});
// TODO: deal with default parameters and reordering based on keys
self.builder.build_call(fun_val, params, "call").try_as_basic_value().left()
} else {
unreachable!()
};
val
}
fn gen_const(&mut self, value: &Constant, ty: Type) -> BasicValueEnum<'ctx> {
match value {
Constant::Bool(v) => {
assert!(self.unifier.unioned(ty, self.top_level.primitives.bool));
assert!(self.unifier.unioned(ty, self.primitives.bool));
let ty = self.ctx.bool_type();
ty.const_int(if *v { 1 } else { 0 }, false).into()
}
Constant::Int(v) => {
let ty = if self.unifier.unioned(ty, self.top_level.primitives.int32) {
let ty = if self.unifier.unioned(ty, self.primitives.int32) {
self.ctx.i32_type()
} else if self.unifier.unioned(ty, self.top_level.primitives.int64) {
} else if self.unifier.unioned(ty, self.primitives.int64) {
self.ctx.i64_type()
} else {
unreachable!();
@ -40,7 +140,7 @@ impl<'ctx> CodeGenContext<'ctx> {
ty.const_int(v.try_into().unwrap(), false).into()
}
Constant::Float(v) => {
assert!(self.unifier.unioned(ty, self.top_level.primitives.float));
assert!(self.unifier.unioned(ty, self.primitives.float));
let ty = self.ctx.f64_type();
ty.const_float(*v).into()
}
@ -134,7 +234,6 @@ impl<'ctx> CodeGenContext<'ctx> {
pub fn gen_expr(&mut self, expr: &Expr<Option<Type>>) -> BasicValueEnum<'ctx> {
let zero = self.ctx.i32_type().const_int(0, false);
let primitives = &self.top_level.primitives;
match &expr.node {
ExprKind::Constant { value, .. } => {
let ty = expr.custom.unwrap();
@ -146,25 +245,36 @@ impl<'ctx> CodeGenContext<'ctx> {
}
ExprKind::List { elts, .. } => {
// this shall be optimized later for constant primitive lists...
// we should use memcpy for that instead of generating thousands of stores
let elements = elts.iter().map(|x| self.gen_expr(x)).collect_vec();
let ty = if elements.is_empty() {
self.ctx.i32_type().into()
} else {
elements[0].get_type()
};
// this length includes the leading length element
let arr_ptr = self.builder.build_array_alloca(
ty,
self.ctx.i32_type().const_int(elements.len() as u64, false),
"tmparr",
);
let arr_ty = self.ctx.struct_type(
&[self.ctx.i32_type().into(), ty.array_type(elements.len() as u32).into()],
&[
self.ctx.i32_type().into(),
ty.ptr_type(AddressSpace::Generic).into(),
],
false,
);
let arr_ptr = self.builder.build_alloca(arr_ty, "tmparr");
let arr_str_ptr = self.builder.build_alloca(arr_ty, "tmparrstr");
unsafe {
let len_ptr = arr_ptr
.const_in_bounds_gep(&[zero, self.ctx.i32_type().const_int(0u64, false)]);
self.builder.build_store(
len_ptr,
arr_str_ptr.const_in_bounds_gep(&[zero, zero]),
self.ctx.i32_type().const_int(elements.len() as u64, false),
);
self.builder.build_store(
arr_str_ptr
.const_in_bounds_gep(&[zero, self.ctx.i32_type().const_int(1, false)]),
arr_ptr,
);
let arr_offset = self.ctx.i32_type().const_int(1, false);
for (i, v) in elements.iter().enumerate() {
let ptr = self.builder.build_in_bounds_gep(
@ -175,7 +285,7 @@ impl<'ctx> CodeGenContext<'ctx> {
self.builder.build_store(ptr, *v);
}
}
arr_ptr.into()
arr_str_ptr.into()
}
ExprKind::Tuple { elts, .. } => {
let element_val = elts.iter().map(|x| self.gen_expr(x)).collect_vec();
@ -266,9 +376,9 @@ impl<'ctx> CodeGenContext<'ctx> {
// when doing code generation for function instances
if ty1 != ty2 {
unimplemented!()
} else if [primitives.int32, primitives.int64].contains(&ty1) {
} else if [self.primitives.int32, self.primitives.int64].contains(&ty1) {
self.gen_int_ops(op, left, right)
} else if primitives.float == ty1 {
} else if self.primitives.float == ty1 {
self.gen_float_ops(op, left, right)
} else {
unimplemented!()
@ -277,7 +387,7 @@ impl<'ctx> CodeGenContext<'ctx> {
ExprKind::UnaryOp { op, operand } => {
let ty = self.unifier.get_representative(operand.custom.unwrap());
let val = self.gen_expr(operand);
if ty == primitives.bool {
if ty == self.primitives.bool {
let val =
if let BasicValueEnum::IntValue(val) = val { val } else { unreachable!() };
match op {
@ -286,7 +396,7 @@ impl<'ctx> CodeGenContext<'ctx> {
}
_ => val.into(),
}
} else if [primitives.int32, primitives.int64].contains(&ty) {
} else if [self.primitives.int32, self.primitives.int64].contains(&ty) {
let val =
if let BasicValueEnum::IntValue(val) = val { val } else { unreachable!() };
match op {
@ -303,7 +413,7 @@ impl<'ctx> CodeGenContext<'ctx> {
.into(),
_ => val.into(),
}
} else if ty == primitives.float {
} else if ty == self.primitives.float {
let val = if let BasicValueEnum::FloatValue(val) = val {
val
} else {
@ -334,7 +444,7 @@ impl<'ctx> CodeGenContext<'ctx> {
)
.fold(None, |prev, (lhs, rhs, op)| {
let ty = lhs.custom.unwrap();
let current = if [primitives.int32, primitives.int64, primitives.bool]
let current = if [self.primitives.int32, self.primitives.int64, self.primitives.bool]
.contains(&ty)
{
let (lhs, rhs) =
@ -355,7 +465,7 @@ impl<'ctx> CodeGenContext<'ctx> {
_ => unreachable!(),
};
self.builder.build_int_compare(op, lhs, rhs, "cmp")
} else if ty == primitives.float {
} else if ty == self.primitives.float {
let (lhs, rhs) = if let (
BasicValueEnum::FloatValue(lhs),
BasicValueEnum::FloatValue(rhs),

View File

@ -3,7 +3,7 @@ use std::{collections::HashMap, sync::Arc};
use super::typecheck::type_inferencer::PrimitiveStore;
use super::typecheck::typedef::{SharedUnifier, Type, Unifier};
use crate::symbol_resolver::SymbolResolver;
use inkwell::{builder::Builder, context::Context, module::Module, values::PointerValue};
use inkwell::{builder::Builder, context::Context, module::Module, types::BasicTypeEnum, values::PointerValue};
use parking_lot::RwLock;
use rustpython_parser::ast::Stmt;
@ -16,14 +16,17 @@ pub enum TopLevelDef {
object_id: DefinitionId,
// type variables bound to the class.
type_vars: Vec<Type>,
// class fields and method signature.
// class fields
fields: Vec<(String, Type)>,
// class methods, pointing to the corresponding function definition.
methods: Vec<(String, DefinitionId)>,
methods: Vec<(String, Type, DefinitionId)>,
// ancestor classes, including itself.
ancestors: Vec<DefinitionId>,
},
Function {
// prefix for symbol, should be unique globally, and not ending with numbers
name: String,
// function signature.
signature: Type,
/// Function instance to symbol mapping
/// Key: string representation of type variable values, sorted by variable ID in ascending
@ -48,7 +51,6 @@ pub struct CodeGenTask {
}
pub struct TopLevelContext {
pub primitives: PrimitiveStore,
pub definitions: Arc<RwLock<Vec<RwLock<TopLevelDef>>>>,
pub unifiers: Arc<RwLock<Vec<SharedUnifier>>>,
}
@ -61,4 +63,6 @@ pub struct CodeGenContext<'ctx> {
pub unifier: Unifier,
pub resolver: Box<dyn SymbolResolver>,
pub var_assignment: HashMap<String, PointerValue<'ctx>>,
pub type_cache: HashMap<Type, BasicTypeEnum<'ctx>>,
pub primitives: PrimitiveStore,
}

View File

@ -30,6 +30,8 @@ pub struct Call {
// A named, typed function argument.
pub struct FuncArg {
// name of the argument
pub name: String,
// type of the argument
pub ty: Type,
// TODO: change this to an optional value
// for primitive types
// NOTE(review): presumably marks arguments the caller may omit — confirm
pub is_optional: bool,
}

View File

@ -1,6 +1,6 @@
use std::rc::Rc;
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub struct UnificationKey(usize);
pub struct UnificationTable<V> {