hm-inference #6

Merged
sb10q merged 136 commits from hm-inference into master 2021-08-19 11:46:50 +08:00
4 changed files with 101 additions and 6 deletions
Showing only changes of commit d8c713ce3d - Show all commits

View File

@ -33,7 +33,7 @@ impl<'ctx> CodeGenContext<'ctx> {
.join(", ") .join(", ")
} }
fn get_attr_index(&mut self, ty: Type, attr: &str) -> usize { pub fn get_attr_index(&mut self, ty: Type, attr: &str) -> usize {
let obj_id = match &*self.unifier.get_ty(ty) { let obj_id = match &*self.unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } => *obj_id, TypeEnum::TObj { obj_id, .. } => *obj_id,
// we cannot have other types, virtual type should be handled by function calls // we cannot have other types, virtual type should be handled by function calls
@ -48,7 +48,7 @@ impl<'ctx> CodeGenContext<'ctx> {
index index
} }
fn get_llvm_type(&mut self, ty: Type) -> BasicTypeEnum<'ctx> { pub fn get_llvm_type(&mut self, ty: Type) -> BasicTypeEnum<'ctx> {
use TypeEnum::*; use TypeEnum::*;
// we assume the type cache should already contain primitive types, // we assume the type cache should already contain primitive types,
// and they should be passed by value instead of passing as pointer. // and they should be passed by value instead of passing as pointer.
@ -275,7 +275,15 @@ impl<'ctx> CodeGenContext<'ctx> {
} }
ExprKind::Name { id, .. } => { ExprKind::Name { id, .. } => {
let ptr = self.var_assignment.get(id).unwrap(); let ptr = self.var_assignment.get(id).unwrap();
self.builder.build_load(*ptr, "load") let primitives = &self.primitives;
// we should only dereference primitive types
if [primitives.int32, primitives.int64, primitives.float, primitives.bool]
.contains(&self.unifier.get_representative(expr.custom.unwrap()))
{
self.builder.build_load(*ptr, "load")
} else {
(*ptr).into()
}
} }
ExprKind::List { elts, .. } => { ExprKind::List { elts, .. } => {
// this shall be optimized later for constant primitive lists... // this shall be optimized later for constant primitive lists...
@ -472,7 +480,7 @@ impl<'ctx> CodeGenContext<'ctx> {
ops.iter(), ops.iter(),
) )
.fold(None, |prev, (lhs, rhs, op)| { .fold(None, |prev, (lhs, rhs, op)| {
let ty = lhs.custom.unwrap(); let ty = self.unifier.get_representative(lhs.custom.unwrap());
let current = let current =
if [self.primitives.int32, self.primitives.int64, self.primitives.bool] if [self.primitives.int32, self.primitives.int64, self.primitives.bool]
.contains(&ty) .contains(&ty)

View File

@ -1 +0,0 @@

View File

@ -1,2 +1,2 @@
mod expr; mod expr;
mod helper; mod stmt;

View File

@ -0,0 +1,88 @@
use std::convert::TryInto;
use crate::{top_level::CodeGenContext, typecheck::typedef::Type};
use inkwell::{
types::BasicTypeEnum,
values::{BasicValueEnum, PointerValue},
};
use rustpython_parser::ast::{Expr, ExprKind, Stmt, StmtKind};
impl<'ctx> CodeGenContext<'ctx> {
fn gen_var(&mut self, ty: Type) -> PointerValue<'ctx> {
let ty = self.get_llvm_type(ty);
let ty = if let BasicTypeEnum::PointerType(ty) = ty {
ty.get_element_type().try_into().unwrap()
} else {
ty
};
self.builder.build_alloca(ty, "tmp")
}
fn parse_pattern(&mut self, pattern: &Expr<Option<Type>>) -> PointerValue<'ctx> {
// very similar to gen_expr, but we don't do an extra load at the end
// and we flatten nested tuples
match &pattern.node {
ExprKind::Name { id, .. } => {
self.var_assignment.get(id).cloned().unwrap_or_else(|| {
let ptr = self.gen_var(pattern.custom.unwrap());
self.var_assignment.insert(id.clone(), ptr);
ptr
})
}
ExprKind::Attribute { value, attr, .. } => {
let index = self.get_attr_index(value.custom.unwrap(), attr);
let val = self.gen_expr(value);
let ptr = if let BasicValueEnum::PointerValue(v) = val {
v
} else {
unreachable!();
};
unsafe {
ptr.const_in_bounds_gep(&[
self.ctx.i32_type().const_zero(),
self.ctx.i32_type().const_int(index as u64, false),
])
}
}
ExprKind::Subscript { .. } => unimplemented!(),
_ => unreachable!(),
}
}
fn gen_assignment(&mut self, target: &Expr<Option<Type>>, value: BasicValueEnum<'ctx>) {
if let ExprKind::Tuple { elts, .. } = &target.node {
if let BasicValueEnum::PointerValue(ptr) = value {
for (i, elt) in elts.iter().enumerate() {
unsafe {
let t = ptr.const_in_bounds_gep(&[
self.ctx.i32_type().const_zero(),
self.ctx.i32_type().const_int(i as u64, false),
]);
let v = self.builder.build_load(t, "tmpload");
self.gen_assignment(elt, v);
}
}
} else {
unreachable!()
}
} else {
let ptr = self.parse_pattern(target);
self.builder.build_store(ptr, value);
}
}
pub fn gen_stmt(&mut self, stmt: &Stmt<Option<Type>>) {
match &stmt.node {
StmtKind::Expr { value } => {
self.gen_expr(&value);
}
StmtKind::Assign { targets, value, .. } => {
let value = self.gen_expr(&value);
for target in targets.iter() {
self.gen_assignment(target, value);
}
}
_ => unimplemented!(),
}
}
}