Merge pull request 'hm-inference' (#6) from hm-inference into master

Reviewed-on: M-Labs/nac3#6
sb10q 2021-08-19 11:46:50 +08:00
commit f205a8282a
29 changed files with 5928 additions and 2948 deletions

Cargo.lock generated

File diff suppressed because it is too large

@@ -15,20 +15,20 @@ caller to specify which methods should be compiled). After type checking, the
compiler would analyse the set of functions/classes that are used and perform
code generation.
Symbol resolver:
- Str -> Nac3Type
- Str -> Value
A value can be an integer, a boolean, bytes (for memcpy), or a function ID
(full name + concrete type).
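A minimal sketch of what these two mappings could look like (the names and types below are invented for illustration; the actual `SymbolResolver` trait added by this PR, implemented for tests further down in this diff, also takes the unifier and primitive store as arguments and returns the compiler's own types):

```rust
use std::collections::HashMap;

// Illustrative stand-ins for the compiler's own type and value representations.
#[derive(Clone, Debug, PartialEq)]
enum Nac3Type {
    Int32,
    Bool,
}

#[derive(Clone, Debug)]
enum Value {
    I32(i32),
    Bool(bool),
    Bytes(Vec<u8>),              // e.g. a payload to memcpy
    Function(String, Nac3Type),  // full name + concrete type
}

// Str -> Nac3Type and Str -> Value, as described above.
trait Resolver {
    fn resolve_type(&self, name: &str) -> Option<Nac3Type>;
    fn resolve_value(&self, name: &str) -> Option<Value>;
}

struct MapResolver {
    types: HashMap<String, Nac3Type>,
    values: HashMap<String, Value>,
}

impl Resolver for MapResolver {
    fn resolve_type(&self, name: &str) -> Option<Nac3Type> {
        self.types.get(name).cloned()
    }
    fn resolve_value(&self, name: &str) -> Option<Value> {
        self.values.get(name).cloned()
    }
}

fn main() {
    let resolver = MapResolver {
        types: HashMap::from([("x".to_string(), Nac3Type::Int32)]),
        values: HashMap::from([("FLAG".to_string(), Value::Bool(true))]),
    };
    assert_eq!(resolver.resolve_type("x"), Some(Nac3Type::Int32));
    assert!(resolver.resolve_value("FLAG").is_some());
}
```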
## Current Plan
1. Write out the syntax-directed type checking/inferencing rules. Fix the rule
for type variable instantiation (a sketch of the instantiation idea follows this list).
2. Update the library dependencies and rewrite some of the type checking code.
3. Design the symbol resolver API.
4. Move tests from code to external files to clean up the code.
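For step 1, the key property of the instantiation rule is that every use of a generic signature gets fresh type variables before unification, so different call sites can resolve to different concrete types. A small self-contained sketch of that idea (the names below are invented for the example and are not the `Unifier`/`Inferencer` API added by this PR):

```rust
use std::collections::HashMap;

// A tiny type language: two primitives, type variables, and function types.
#[derive(Clone, Debug, PartialEq)]
enum Ty {
    Int32,
    Float,
    Var(u32),              // type variable
    Fun(Vec<Ty>, Box<Ty>), // argument types -> return type
}

// Fresh type-variable supply.
struct Fresh(u32);

impl Fresh {
    fn next(&mut self) -> Ty {
        self.0 += 1;
        Ty::Var(self.0)
    }
}

// Instantiation: replace every variable bound by the signature with a fresh
// one, so each call site is unified independently of the others.
fn instantiate(sig: &Ty, fresh: &mut Fresh, map: &mut HashMap<u32, Ty>) -> Ty {
    match sig {
        Ty::Var(id) => map.entry(*id).or_insert_with(|| fresh.next()).clone(),
        Ty::Fun(args, ret) => {
            let args: Vec<Ty> = args.iter().map(|a| instantiate(a, fresh, map)).collect();
            let ret = Box::new(instantiate(ret, fresh, map));
            Ty::Fun(args, ret)
        }
        other => other.clone(),
    }
}

fn main() {
    // identity: fn(T) -> T, with the bound variable written as Var(0)
    let identity = Ty::Fun(vec![Ty::Var(0)], Box::new(Ty::Var(0)));
    let mut fresh = Fresh(0);

    // Each call site receives its own copy of the signature, so unifying one
    // copy with int32 and the other with float can never conflict.
    let call1 = instantiate(&identity, &mut fresh, &mut HashMap::new());
    let call2 = instantiate(&identity, &mut fresh, &mut HashMap::new());
    println!("{:?}", call1); // Fun([Var(1)], Var(1))
    println!("{:?}", call2); // Fun([Var(2)], Var(2))
    assert_ne!(call1, call2);
}
```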
Type checking:
- [x] Basic interface for symbol resolver.
- [x] Track location information in context object (for diagnostics).
- [ ] Refactor old expression and statement type inference code. (anto)
- [ ] Error diagnostics utilities. (pca)
- [ ] Move tests to external files, write scripts for testing. (pca)
- [ ] Implement function type checking (instantiate bounded type parameters),
loop unrolling, type inference for lists with virtual objects. (pca)


@@ -7,5 +7,13 @@ edition = "2018"
[dependencies]
num-bigint = "0.3"
num-traits = "0.2"
inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm10-0"] }
inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm11-0"] }
rustpython-parser = { git = "https://github.com/RustPython/RustPython", branch = "master" }
itertools = "0.10.1"
crossbeam = "0.8.1"
parking_lot = "0.11.1"
rayon = "1.5.1"
[dev-dependencies]
test-case = "1.2.0"
indoc = "1.0"

nac3core/rustfmt.toml

@@ -0,0 +1 @@
use_small_heuristics = "Max"


@@ -0,0 +1,527 @@
use std::{collections::HashMap, convert::TryInto, iter::once};
use super::{get_llvm_type, CodeGenContext};
use crate::{
symbol_resolver::SymbolValue,
top_level::{DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type, TypeEnum},
};
use inkwell::{
types::{BasicType, BasicTypeEnum},
values::BasicValueEnum,
AddressSpace,
};
use itertools::{chain, izip, zip, Itertools};
use rustpython_parser::ast::{self, Boolop, Constant, Expr, ExprKind, Operator};
impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
fn get_subst_key(&mut self, obj: Option<Type>, fun: &FunSignature) -> String {
let mut vars = obj
.map(|ty| {
if let TypeEnum::TObj { params, .. } = &*self.unifier.get_ty(ty) {
params.borrow().clone()
} else {
unreachable!()
}
})
.unwrap_or_default();
vars.extend(fun.vars.iter());
let sorted = vars.keys().sorted();
sorted
.map(|id| {
self.unifier.stringify(vars[id], &mut |id| id.to_string(), &mut |id| id.to_string())
})
.join(", ")
}
pub fn get_attr_index(&mut self, ty: Type, attr: &str) -> usize {
let obj_id = match &*self.unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } => *obj_id,
// we cannot have other types here; virtual types should be handled by function calls
_ => unreachable!(),
};
let def = &self.top_level.definitions.read()[obj_id.0];
let index = if let TopLevelDef::Class { fields, .. } = &*def.read() {
fields.iter().find_position(|x| x.0 == attr).unwrap().0
} else {
unreachable!()
};
index
}
fn gen_symbol_val(&mut self, val: &SymbolValue) -> BasicValueEnum<'ctx> {
match val {
SymbolValue::I32(v) => self.ctx.i32_type().const_int(*v as u64, true).into(),
SymbolValue::I64(v) => self.ctx.i64_type().const_int(*v as u64, true).into(),
SymbolValue::Bool(v) => self.ctx.bool_type().const_int(*v as u64, true).into(),
SymbolValue::Double(v) => self.ctx.f64_type().const_float(*v).into(),
SymbolValue::Tuple(ls) => {
let vals = ls.iter().map(|v| self.gen_symbol_val(v)).collect_vec();
let fields = vals.iter().map(|v| v.get_type()).collect_vec();
let ty = self.ctx.struct_type(&fields, false);
let ptr = self.builder.build_alloca(ty, "tuple");
let zero = self.ctx.i32_type().const_zero();
unsafe {
for (i, val) in vals.into_iter().enumerate() {
let p = ptr.const_in_bounds_gep(&[
zero,
self.ctx.i32_type().const_int(i as u64, false),
]);
self.builder.build_store(p, val);
}
}
ptr.into()
}
}
}
pub fn get_llvm_type(&mut self, ty: Type) -> BasicTypeEnum<'ctx> {
get_llvm_type(self.ctx, &mut self.unifier, self.top_level, &mut self.type_cache, ty)
}
fn gen_call(
&mut self,
obj: Option<(Type, BasicValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
params: Vec<(Option<String>, BasicValueEnum<'ctx>)>,
ret: Type,
) -> Option<BasicValueEnum<'ctx>> {
let key = self.get_subst_key(obj.map(|(a, _)| a), fun.0);
let defs = self.top_level.definitions.read();
let definition = defs.get(fun.1 .0).unwrap();
let val = if let TopLevelDef::Function { instance_to_symbol, .. } = &*definition.read() {
let symbol = instance_to_symbol.get(&key).unwrap_or_else(|| {
// TODO: codegen for functions that are not yet generated
unimplemented!()
});
let fun_val = self.module.get_function(symbol).unwrap_or_else(|| {
let params = fun.0.args.iter().map(|arg| self.get_llvm_type(arg.ty)).collect_vec();
let fun_ty = if self.unifier.unioned(ret, self.primitives.none) {
self.ctx.void_type().fn_type(&params, false)
} else {
self.get_llvm_type(ret).fn_type(&params, false)
};
self.module.add_function(symbol, fun_ty, None)
});
let mut keys = fun.0.args.clone();
let mut mapping = HashMap::new();
for (key, value) in params.into_iter() {
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value);
}
// default value handling
for k in keys.into_iter() {
mapping.insert(k.name, self.gen_symbol_val(&k.default_value.unwrap()));
}
// reorder the parameters
let params =
fun.0.args.iter().map(|arg| mapping.remove(&arg.name).unwrap()).collect_vec();
self.builder.build_call(fun_val, &params, "call").try_as_basic_value().left()
} else {
unreachable!()
};
val
}
fn gen_const(&mut self, value: &Constant, ty: Type) -> BasicValueEnum<'ctx> {
match value {
Constant::Bool(v) => {
assert!(self.unifier.unioned(ty, self.primitives.bool));
let ty = self.ctx.bool_type();
ty.const_int(if *v { 1 } else { 0 }, false).into()
}
Constant::Int(v) => {
let ty = if self.unifier.unioned(ty, self.primitives.int32) {
self.ctx.i32_type()
} else if self.unifier.unioned(ty, self.primitives.int64) {
self.ctx.i64_type()
} else {
unreachable!();
};
ty.const_int(v.try_into().unwrap(), false).into()
}
Constant::Float(v) => {
assert!(self.unifier.unioned(ty, self.primitives.float));
let ty = self.ctx.f64_type();
ty.const_float(*v).into()
}
Constant::Tuple(v) => {
let ty = self.unifier.get_ty(ty);
let types =
if let TypeEnum::TTuple { ty } = &*ty { ty.clone() } else { unreachable!() };
let values = zip(types.into_iter(), v.iter())
.map(|(ty, v)| self.gen_const(v, ty))
.collect_vec();
let types = values.iter().map(BasicValueEnum::get_type).collect_vec();
let ty = self.ctx.struct_type(&types, false);
ty.const_named_struct(&values).into()
}
_ => unreachable!(),
}
}
fn gen_int_ops(
&mut self,
op: &Operator,
lhs: BasicValueEnum<'ctx>,
rhs: BasicValueEnum<'ctx>,
) -> BasicValueEnum<'ctx> {
let (lhs, rhs) =
if let (BasicValueEnum::IntValue(lhs), BasicValueEnum::IntValue(rhs)) = (lhs, rhs) {
(lhs, rhs)
} else {
unreachable!()
};
match op {
Operator::Add => self.builder.build_int_add(lhs, rhs, "add").into(),
Operator::Sub => self.builder.build_int_sub(lhs, rhs, "sub").into(),
Operator::Mult => self.builder.build_int_mul(lhs, rhs, "mul").into(),
Operator::Div => {
let float = self.ctx.f64_type();
let left = self.builder.build_signed_int_to_float(lhs, float, "i2f");
let right = self.builder.build_signed_int_to_float(rhs, float, "i2f");
self.builder.build_float_div(left, right, "fdiv").into()
}
Operator::Mod => self.builder.build_int_signed_rem(lhs, rhs, "mod").into(),
Operator::BitOr => self.builder.build_or(lhs, rhs, "or").into(),
Operator::BitXor => self.builder.build_xor(lhs, rhs, "xor").into(),
Operator::BitAnd => self.builder.build_and(lhs, rhs, "and").into(),
Operator::LShift => self.builder.build_left_shift(lhs, rhs, "lshift").into(),
Operator::RShift => self.builder.build_right_shift(lhs, rhs, true, "rshift").into(),
Operator::FloorDiv => self.builder.build_int_signed_div(lhs, rhs, "floordiv").into(),
// special implementation?
Operator::Pow => unimplemented!(),
Operator::MatMult => unreachable!(),
}
}
fn gen_float_ops(
&mut self,
op: &Operator,
lhs: BasicValueEnum<'ctx>,
rhs: BasicValueEnum<'ctx>,
) -> BasicValueEnum<'ctx> {
let (lhs, rhs) = if let (BasicValueEnum::FloatValue(lhs), BasicValueEnum::FloatValue(rhs)) =
(lhs, rhs)
{
(lhs, rhs)
} else {
unreachable!()
};
match op {
Operator::Add => self.builder.build_float_add(lhs, rhs, "fadd").into(),
Operator::Sub => self.builder.build_float_sub(lhs, rhs, "fsub").into(),
Operator::Mult => self.builder.build_float_mul(lhs, rhs, "fmul").into(),
Operator::Div => self.builder.build_float_div(lhs, rhs, "fdiv").into(),
Operator::Mod => self.builder.build_float_rem(lhs, rhs, "fmod").into(),
Operator::FloorDiv => {
let div = self.builder.build_float_div(lhs, rhs, "fdiv");
let floor_intrinsic =
self.module.get_function("llvm.floor.f64").unwrap_or_else(|| {
let float = self.ctx.f64_type();
let fn_type = float.fn_type(&[float.into()], false);
self.module.add_function("llvm.floor.f64", fn_type, None)
});
self.builder
.build_call(floor_intrinsic, &[div.into()], "floor")
.try_as_basic_value()
.left()
.unwrap()
}
// special implementation?
_ => unimplemented!(),
}
}
pub fn gen_expr(&mut self, expr: &Expr<Option<Type>>) -> BasicValueEnum<'ctx> {
let zero = self.ctx.i32_type().const_int(0, false);
match &expr.node {
ExprKind::Constant { value, .. } => {
let ty = expr.custom.unwrap();
self.gen_const(value, ty)
}
ExprKind::Name { id, .. } => {
let ptr = self.var_assignment.get(id).unwrap();
let primitives = &self.primitives;
// we should only dereference primitive types
if [primitives.int32, primitives.int64, primitives.float, primitives.bool]
.contains(&self.unifier.get_representative(expr.custom.unwrap()))
{
self.builder.build_load(*ptr, "load")
} else {
(*ptr).into()
}
}
ExprKind::List { elts, .. } => {
// this shall be optimized later for constant primitive lists...
// we should use memcpy for that instead of generating thousands of stores
let elements = elts.iter().map(|x| self.gen_expr(x)).collect_vec();
let ty = if elements.is_empty() {
self.ctx.i32_type().into()
} else {
elements[0].get_type()
};
let arr_ptr = self.builder.build_array_alloca(
ty,
self.ctx.i32_type().const_int(elements.len() as u64, false),
"tmparr",
);
let arr_ty = self.ctx.struct_type(
&[self.ctx.i32_type().into(), ty.ptr_type(AddressSpace::Generic).into()],
false,
);
let arr_str_ptr = self.builder.build_alloca(arr_ty, "tmparrstr");
unsafe {
self.builder.build_store(
arr_str_ptr.const_in_bounds_gep(&[zero, zero]),
self.ctx.i32_type().const_int(elements.len() as u64, false),
);
self.builder.build_store(
arr_str_ptr
.const_in_bounds_gep(&[zero, self.ctx.i32_type().const_int(1, false)]),
arr_ptr,
);
let arr_offset = self.ctx.i32_type().const_int(1, false);
for (i, v) in elements.iter().enumerate() {
let ptr = self.builder.build_in_bounds_gep(
arr_ptr,
&[zero, arr_offset, self.ctx.i32_type().const_int(i as u64, false)],
"arr_element",
);
self.builder.build_store(ptr, *v);
}
}
arr_str_ptr.into()
}
ExprKind::Tuple { elts, .. } => {
let element_val = elts.iter().map(|x| self.gen_expr(x)).collect_vec();
let element_ty = element_val.iter().map(BasicValueEnum::get_type).collect_vec();
let tuple_ty = self.ctx.struct_type(&element_ty, false);
let tuple_ptr = self.builder.build_alloca(tuple_ty, "tuple");
for (i, v) in element_val.into_iter().enumerate() {
unsafe {
let ptr = tuple_ptr.const_in_bounds_gep(&[
zero,
self.ctx.i32_type().const_int(i as u64, false),
]);
self.builder.build_store(ptr, v);
}
}
tuple_ptr.into()
}
ExprKind::Attribute { value, attr, .. } => {
// note that we would handle class methods directly in calls
let index = self.get_attr_index(value.custom.unwrap(), attr);
let val = self.gen_expr(value);
let ptr = if let BasicValueEnum::PointerValue(v) = val {
v
} else {
unreachable!();
};
unsafe {
let ptr = ptr.const_in_bounds_gep(&[
zero,
self.ctx.i32_type().const_int(index as u64, false),
]);
self.builder.build_load(ptr, "field")
}
}
ExprKind::BoolOp { op, values } => {
// requires conditional branches for short-circuiting...
let left = if let BasicValueEnum::IntValue(left) = self.gen_expr(&values[0]) {
left
} else {
unreachable!()
};
let current = self.builder.get_insert_block().unwrap().get_parent().unwrap();
let a_bb = self.ctx.append_basic_block(current, "a");
let b_bb = self.ctx.append_basic_block(current, "b");
let cont_bb = self.ctx.append_basic_block(current, "cont");
self.builder.build_conditional_branch(left, a_bb, b_bb);
let (a, b) = match op {
Boolop::Or => {
self.builder.position_at_end(a_bb);
let a = self.ctx.bool_type().const_int(1, false);
self.builder.build_unconditional_branch(cont_bb);
self.builder.position_at_end(b_bb);
let b = if let BasicValueEnum::IntValue(b) = self.gen_expr(&values[1]) {
b
} else {
unreachable!()
};
self.builder.build_unconditional_branch(cont_bb);
(a, b)
}
Boolop::And => {
self.builder.position_at_end(a_bb);
let a = if let BasicValueEnum::IntValue(a) = self.gen_expr(&values[1]) {
a
} else {
unreachable!()
};
self.builder.build_unconditional_branch(cont_bb);
self.builder.position_at_end(b_bb);
let b = self.ctx.bool_type().const_int(0, false);
self.builder.build_unconditional_branch(cont_bb);
(a, b)
}
};
self.builder.position_at_end(cont_bb);
let phi = self.builder.build_phi(self.ctx.bool_type(), "phi");
phi.add_incoming(&[(&a, a_bb), (&b, b_bb)]);
phi.as_basic_value()
}
ExprKind::BinOp { op, left, right } => {
let ty1 = self.unifier.get_representative(left.custom.unwrap());
let ty2 = self.unifier.get_representative(right.custom.unwrap());
let left = self.gen_expr(left);
let right = self.gen_expr(right);
// we can compare the types directly because we have their representatives,
// which remain unchanged until further unification, and we never unify
// while generating code for function instances
if ty1 == ty2 && [self.primitives.int32, self.primitives.int64].contains(&ty1) {
self.gen_int_ops(op, left, right)
} else if ty1 == ty2 && self.primitives.float == ty1 {
self.gen_float_ops(op, left, right)
} else {
unimplemented!()
}
}
ExprKind::UnaryOp { op, operand } => {
let ty = self.unifier.get_representative(operand.custom.unwrap());
let val = self.gen_expr(operand);
if ty == self.primitives.bool {
let val =
if let BasicValueEnum::IntValue(val) = val { val } else { unreachable!() };
match op {
ast::Unaryop::Invert | ast::Unaryop::Not => {
self.builder.build_not(val, "not").into()
}
_ => val.into(),
}
} else if [self.primitives.int32, self.primitives.int64].contains(&ty) {
let val =
if let BasicValueEnum::IntValue(val) = val { val } else { unreachable!() };
match op {
ast::Unaryop::USub => self.builder.build_int_neg(val, "neg").into(),
ast::Unaryop::Invert => self.builder.build_not(val, "not").into(),
ast::Unaryop::Not => self
.builder
.build_int_compare(
inkwell::IntPredicate::EQ,
val,
val.get_type().const_zero(),
"not",
)
.into(),
_ => val.into(),
}
} else if ty == self.primitives.float {
let val = if let BasicValueEnum::FloatValue(val) = val {
val
} else {
unreachable!()
};
match op {
ast::Unaryop::USub => self.builder.build_float_neg(val, "neg").into(),
ast::Unaryop::Not => self
.builder
.build_float_compare(
inkwell::FloatPredicate::OEQ,
val,
val.get_type().const_zero(),
"not",
)
.into(),
_ => val.into(),
}
} else {
unimplemented!()
}
}
ExprKind::Compare { left, ops, comparators } => {
izip!(
chain(once(left.as_ref()), comparators.iter()),
comparators.iter(),
ops.iter(),
)
.fold(None, |prev, (lhs, rhs, op)| {
let ty = self.unifier.get_representative(lhs.custom.unwrap());
let current =
if [self.primitives.int32, self.primitives.int64, self.primitives.bool]
.contains(&ty)
{
let (lhs, rhs) = if let (
BasicValueEnum::IntValue(lhs),
BasicValueEnum::IntValue(rhs),
) = (self.gen_expr(lhs), self.gen_expr(rhs))
{
(lhs, rhs)
} else {
unreachable!()
};
let op = match op {
ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::IntPredicate::EQ,
ast::Cmpop::NotEq => inkwell::IntPredicate::NE,
ast::Cmpop::Lt => inkwell::IntPredicate::SLT,
ast::Cmpop::LtE => inkwell::IntPredicate::SLE,
ast::Cmpop::Gt => inkwell::IntPredicate::SGT,
ast::Cmpop::GtE => inkwell::IntPredicate::SGE,
_ => unreachable!(),
};
self.builder.build_int_compare(op, lhs, rhs, "cmp")
} else if ty == self.primitives.float {
let (lhs, rhs) = if let (
BasicValueEnum::FloatValue(lhs),
BasicValueEnum::FloatValue(rhs),
) = (self.gen_expr(lhs), self.gen_expr(rhs))
{
(lhs, rhs)
} else {
unreachable!()
};
let op = match op {
ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::FloatPredicate::OEQ,
ast::Cmpop::NotEq => inkwell::FloatPredicate::ONE,
ast::Cmpop::Lt => inkwell::FloatPredicate::OLT,
ast::Cmpop::LtE => inkwell::FloatPredicate::OLE,
ast::Cmpop::Gt => inkwell::FloatPredicate::OGT,
ast::Cmpop::GtE => inkwell::FloatPredicate::OGE,
_ => unreachable!(),
};
self.builder.build_float_compare(op, lhs, rhs, "cmp")
} else {
unimplemented!()
};
prev.map(|v| self.builder.build_and(v, current, "cmp")).or(Some(current))
})
.unwrap()
.into() // as there should be at least 1 element, it should never be none
}
ExprKind::IfExp { test, body, orelse } => {
let test = if let BasicValueEnum::IntValue(test) = self.gen_expr(test) {
test
} else {
unreachable!()
};
let current = self.builder.get_insert_block().unwrap().get_parent().unwrap();
let then_bb = self.ctx.append_basic_block(current, "then");
let else_bb = self.ctx.append_basic_block(current, "else");
let cont_bb = self.ctx.append_basic_block(current, "cont");
self.builder.build_conditional_branch(test, then_bb, else_bb);
self.builder.position_at_end(then_bb);
let a = self.gen_expr(body);
self.builder.build_unconditional_branch(cont_bb);
self.builder.position_at_end(else_bb);
let b = self.gen_expr(orelse);
self.builder.build_unconditional_branch(cont_bb);
self.builder.position_at_end(cont_bb);
let phi = self.builder.build_phi(a.get_type(), "ifexpr");
phi.add_incoming(&[(&a, then_bb), (&b, else_bb)]);
phi.as_basic_value()
}
_ => unimplemented!(),
}
}
}

nac3core/src/codegen/mod.rs

@@ -0,0 +1,343 @@
use crate::{
symbol_resolver::SymbolResolver,
top_level::{TopLevelContext, TopLevelDef},
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{FunSignature, Type, TypeEnum, Unifier},
},
};
use crossbeam::channel::{unbounded, Receiver, Sender};
use inkwell::{
basic_block::BasicBlock,
builder::Builder,
context::Context,
module::Module,
types::{BasicType, BasicTypeEnum},
values::PointerValue,
AddressSpace,
};
use itertools::Itertools;
use parking_lot::{Condvar, Mutex};
use rustpython_parser::ast::Stmt;
use std::collections::HashMap;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use std::thread;
mod expr;
mod stmt;
#[cfg(test)]
mod test;
pub struct CodeGenContext<'ctx, 'a> {
pub ctx: &'ctx Context,
pub builder: Builder<'ctx>,
pub module: Module<'ctx>,
pub top_level: &'a TopLevelContext,
pub unifier: Unifier,
pub resolver: Arc<dyn SymbolResolver>,
pub var_assignment: HashMap<String, PointerValue<'ctx>>,
pub type_cache: HashMap<Type, BasicTypeEnum<'ctx>>,
pub primitives: PrimitiveStore,
// stores the alloca for variables
pub init_bb: BasicBlock<'ctx>,
// where continue and break should branch to, respectively:
// the first is the loop's test_bb, and the second is the block after the loop
pub loop_bb: Option<(BasicBlock<'ctx>, BasicBlock<'ctx>)>,
}
type Fp = Box<dyn Fn(&Module) + Send + Sync>;
pub struct WithCall {
fp: Fp,
}
impl WithCall {
pub fn new(fp: Fp) -> WithCall {
WithCall { fp }
}
pub fn run<'ctx>(&self, m: &Module<'ctx>) {
(self.fp)(m)
}
}
pub struct WorkerRegistry {
sender: Arc<Sender<Option<CodeGenTask>>>,
receiver: Arc<Receiver<Option<CodeGenTask>>>,
panicked: AtomicBool,
task_count: Mutex<usize>,
thread_count: usize,
wait_condvar: Condvar,
}
impl WorkerRegistry {
pub fn create_workers(
names: &[&str],
top_level_ctx: Arc<TopLevelContext>,
f: Arc<WithCall>,
) -> (Arc<WorkerRegistry>, Vec<thread::JoinHandle<()>>) {
let (sender, receiver) = unbounded();
let task_count = Mutex::new(0);
let wait_condvar = Condvar::new();
let registry = Arc::new(WorkerRegistry {
sender: Arc::new(sender),
receiver: Arc::new(receiver),
thread_count: names.len(),
panicked: AtomicBool::new(false),
task_count,
wait_condvar,
});
let mut handles = Vec::new();
for name in names.iter() {
let top_level_ctx = top_level_ctx.clone();
let registry = registry.clone();
let registry2 = registry.clone();
let name = name.to_string();
let f = f.clone();
let handle = thread::spawn(move || {
registry.worker_thread(name, top_level_ctx, f);
});
let handle = thread::spawn(move || {
if let Err(e) = handle.join() {
if let Some(e) = e.downcast_ref::<&'static str>() {
eprintln!("Got an error: {}", e);
} else {
eprintln!("Got an unknown error: {:?}", e);
}
registry2.panicked.store(true, Ordering::SeqCst);
registry2.wait_condvar.notify_all();
}
});
handles.push(handle);
}
(registry, handles)
}
pub fn wait_tasks_complete(&self, handles: Vec<thread::JoinHandle<()>>) {
{
let mut count = self.task_count.lock();
while *count != 0 {
if self.panicked.load(Ordering::SeqCst) {
break;
}
self.wait_condvar.wait(&mut count);
}
}
for _ in 0..self.thread_count {
self.sender.send(None).unwrap();
}
{
let mut count = self.task_count.lock();
while *count != self.thread_count {
if self.panicked.load(Ordering::SeqCst) {
break;
}
self.wait_condvar.wait(&mut count);
}
}
for handle in handles {
handle.join().unwrap();
}
if self.panicked.load(Ordering::SeqCst) {
panic!("tasks panicked");
}
}
pub fn add_task(&self, task: CodeGenTask) {
*self.task_count.lock() += 1;
self.sender.send(Some(task)).unwrap();
}
fn worker_thread(
&self,
module_name: String,
top_level_ctx: Arc<TopLevelContext>,
f: Arc<WithCall>,
) {
let context = Context::create();
let mut builder = context.create_builder();
let mut module = context.create_module(&module_name);
while let Some(task) = self.receiver.recv().unwrap() {
let result = gen_func(&context, builder, module, task, top_level_ctx.clone());
builder = result.0;
module = result.1;
*self.task_count.lock() -= 1;
self.wait_condvar.notify_all();
}
// do whatever...
let mut lock = self.task_count.lock();
module.verify().unwrap();
f.run(&module);
*lock += 1;
self.wait_condvar.notify_all();
}
}
pub struct CodeGenTask {
pub subst: Vec<(Type, Type)>,
pub symbol_name: String,
pub signature: FunSignature,
pub body: Vec<Stmt<Option<Type>>>,
pub unifier_index: usize,
pub resolver: Arc<dyn SymbolResolver + Send + Sync>,
}
fn get_llvm_type<'ctx>(
ctx: &'ctx Context,
unifier: &mut Unifier,
top_level: &TopLevelContext,
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
ty: Type,
) -> BasicTypeEnum<'ctx> {
use TypeEnum::*;
// we assume the type cache already contains the primitive types,
// and that they are passed by value instead of by pointer.
type_cache.get(&unifier.get_representative(ty)).cloned().unwrap_or_else(|| {
match &*unifier.get_ty(ty) {
TObj { obj_id, fields, .. } => {
// a struct with fields in the order of declaration
let defs = top_level.definitions.read();
let definition = defs.get(obj_id.0).unwrap();
let ty = if let TopLevelDef::Class { fields: fields_list, .. } = &*definition.read()
{
let fields = fields.borrow();
let fields = fields_list
.iter()
.map(|f| get_llvm_type(ctx, unifier, top_level, type_cache, fields[&f.0]))
.collect_vec();
ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into()
} else {
unreachable!()
};
ty
}
TTuple { ty } => {
// a struct with fields in the order present in the tuple
let fields = ty
.iter()
.map(|ty| get_llvm_type(ctx, unifier, top_level, type_cache, *ty))
.collect_vec();
ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into()
}
TList { ty } => {
// a struct with an integer and a pointer to an array
let element_type = get_llvm_type(ctx, unifier, top_level, type_cache, *ty);
let fields =
[ctx.i32_type().into(), element_type.ptr_type(AddressSpace::Generic).into()];
ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into()
}
TVirtual { .. } => unimplemented!(),
_ => unreachable!(),
}
})
}
pub fn gen_func<'ctx>(
context: &'ctx Context,
builder: Builder<'ctx>,
module: Module<'ctx>,
task: CodeGenTask,
top_level_ctx: Arc<TopLevelContext>,
) -> (Builder<'ctx>, Module<'ctx>) {
// unwrap_or(0) is for unit tests without using rayon
let (mut unifier, primitives) = {
let unifiers = top_level_ctx.unifiers.read();
let (unifier, primitives) = &unifiers[task.unifier_index];
(Unifier::from_shared_unifier(unifier), *primitives)
};
for (a, b) in task.subst.iter() {
// this should be unification between variables and concrete types
// and should not cause any problem...
unifier.unify(*a, *b).unwrap();
}
// rebuild primitive store with unique representatives
let primitives = PrimitiveStore {
int32: unifier.get_representative(primitives.int32),
int64: unifier.get_representative(primitives.int64),
float: unifier.get_representative(primitives.float),
bool: unifier.get_representative(primitives.bool),
none: unifier.get_representative(primitives.none),
};
let mut type_cache: HashMap<_, _> = [
(unifier.get_representative(primitives.int32), context.i32_type().into()),
(unifier.get_representative(primitives.int64), context.i64_type().into()),
(unifier.get_representative(primitives.float), context.f64_type().into()),
(unifier.get_representative(primitives.bool), context.bool_type().into()),
]
.iter()
.cloned()
.collect();
let params = task
.signature
.args
.iter()
.map(|arg| {
get_llvm_type(&context, &mut unifier, top_level_ctx.as_ref(), &mut type_cache, arg.ty)
})
.collect_vec();
let fn_type = if unifier.unioned(task.signature.ret, primitives.none) {
context.void_type().fn_type(&params, false)
} else {
get_llvm_type(
&context,
&mut unifier,
top_level_ctx.as_ref(),
&mut type_cache,
task.signature.ret,
)
.fn_type(&params, false)
};
let fn_val = module.add_function(&task.symbol_name, fn_type, None);
let init_bb = context.append_basic_block(fn_val, "init");
builder.position_at_end(init_bb);
let body_bb = context.append_basic_block(fn_val, "body");
let mut var_assignment = HashMap::new();
for (n, arg) in task.signature.args.iter().enumerate() {
let param = fn_val.get_nth_param(n as u32).unwrap();
let alloca = builder.build_alloca(
get_llvm_type(&context, &mut unifier, top_level_ctx.as_ref(), &mut type_cache, arg.ty),
&arg.name,
);
builder.build_store(alloca, param);
var_assignment.insert(arg.name.clone(), alloca);
}
builder.build_unconditional_branch(body_bb);
builder.position_at_end(body_bb);
let mut code_gen_context = CodeGenContext {
ctx: &context,
resolver: task.resolver,
top_level: top_level_ctx.as_ref(),
loop_bb: None,
var_assignment,
type_cache,
primitives,
init_bb,
builder,
module,
unifier,
};
for stmt in task.body.iter() {
code_gen_context.gen_stmt(stmt);
}
let CodeGenContext { builder, module, .. } = code_gen_context;
(builder, module)
}


@@ -0,0 +1,138 @@
use super::CodeGenContext;
use crate::typecheck::typedef::Type;
use inkwell::values::{BasicValue, BasicValueEnum, PointerValue};
use rustpython_parser::ast::{Expr, ExprKind, Stmt, StmtKind};
impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
fn gen_var(&mut self, ty: Type) -> PointerValue<'ctx> {
// put the alloca in init block
let current = self.builder.get_insert_block().unwrap();
// position before the last branching instruction...
self.builder.position_before(&self.init_bb.get_last_instruction().unwrap());
let ty = self.get_llvm_type(ty);
let ptr = self.builder.build_alloca(ty, "tmp");
self.builder.position_at_end(current);
ptr
}
fn parse_pattern(&mut self, pattern: &Expr<Option<Type>>) -> PointerValue<'ctx> {
// very similar to gen_expr, but we don't do an extra load at the end
// and we flatten nested tuples
match &pattern.node {
ExprKind::Name { id, .. } => {
self.var_assignment.get(id).cloned().unwrap_or_else(|| {
let ptr = self.gen_var(pattern.custom.unwrap());
self.var_assignment.insert(id.clone(), ptr);
ptr
})
}
ExprKind::Attribute { value, attr, .. } => {
let index = self.get_attr_index(value.custom.unwrap(), attr);
let val = self.gen_expr(value);
let ptr = if let BasicValueEnum::PointerValue(v) = val {
v
} else {
unreachable!();
};
unsafe {
ptr.const_in_bounds_gep(&[
self.ctx.i32_type().const_zero(),
self.ctx.i32_type().const_int(index as u64, false),
])
}
}
ExprKind::Subscript { .. } => unimplemented!(),
_ => unreachable!(),
}
}
fn gen_assignment(&mut self, target: &Expr<Option<Type>>, value: BasicValueEnum<'ctx>) {
if let ExprKind::Tuple { elts, .. } = &target.node {
if let BasicValueEnum::PointerValue(ptr) = value {
for (i, elt) in elts.iter().enumerate() {
unsafe {
let t = ptr.const_in_bounds_gep(&[
self.ctx.i32_type().const_zero(),
self.ctx.i32_type().const_int(i as u64, false),
]);
let v = self.builder.build_load(t, "tmpload");
self.gen_assignment(elt, v);
}
}
} else {
unreachable!()
}
} else {
let ptr = self.parse_pattern(target);
self.builder.build_store(ptr, value);
}
}
pub fn gen_stmt(&mut self, stmt: &Stmt<Option<Type>>) {
match &stmt.node {
StmtKind::Expr { value } => {
self.gen_expr(&value);
}
StmtKind::Return { value } => {
let value = value.as_ref().map(|v| self.gen_expr(&v));
let value = value.as_ref().map(|v| v as &dyn BasicValue);
self.builder.build_return(value);
}
StmtKind::AnnAssign { target, value, .. } => {
if let Some(value) = value {
let value = self.gen_expr(&value);
self.gen_assignment(target, value);
}
}
StmtKind::Assign { targets, value, .. } => {
let value = self.gen_expr(&value);
for target in targets.iter() {
self.gen_assignment(target, value);
}
}
StmtKind::Continue => {
self.builder.build_unconditional_branch(self.loop_bb.unwrap().0);
}
StmtKind::Break => {
self.builder.build_unconditional_branch(self.loop_bb.unwrap().1);
}
StmtKind::While { test, body, orelse } => {
let current = self.builder.get_insert_block().unwrap().get_parent().unwrap();
let test_bb = self.ctx.append_basic_block(current, "test");
let body_bb = self.ctx.append_basic_block(current, "body");
let cont_bb = self.ctx.append_basic_block(current, "cont");
// if there is no orelse, we just go to cont_bb
let orelse_bb = if orelse.is_empty() {
cont_bb
} else {
self.ctx.append_basic_block(current, "orelse")
};
// store loop bb information and restore it later
let loop_bb = self.loop_bb.replace((test_bb, cont_bb));
self.builder.build_unconditional_branch(test_bb);
self.builder.position_at_end(test_bb);
let test = self.gen_expr(test);
if let BasicValueEnum::IntValue(test) = test {
self.builder.build_conditional_branch(test, body_bb, orelse_bb);
} else {
unreachable!()
};
self.builder.position_at_end(body_bb);
for stmt in body.iter() {
self.gen_stmt(stmt);
}
self.builder.build_unconditional_branch(test_bb);
if !orelse.is_empty() {
self.builder.position_at_end(orelse_bb);
for stmt in orelse.iter() {
self.gen_stmt(stmt);
}
self.builder.build_unconditional_branch(cont_bb);
}
self.builder.position_at_end(cont_bb);
self.loop_bb = loop_bb;
}
_ => unimplemented!(),
}
}
}


@@ -0,0 +1,247 @@
use super::{CodeGenTask, WorkerRegistry};
use crate::{
codegen::WithCall,
location::Location,
symbol_resolver::{SymbolResolver, SymbolValue},
top_level::{DefinitionId, TopLevelContext},
typecheck::{
magic_methods::set_primitives_magic_methods,
type_inferencer::{CodeLocation, FunctionData, Inferencer, PrimitiveStore},
typedef::{CallId, FunSignature, FuncArg, Type, TypeEnum, Unifier},
},
};
use indoc::indoc;
use parking_lot::RwLock;
use rustpython_parser::{ast::fold::Fold, parser::parse_program};
use std::collections::HashMap;
use std::sync::Arc;
#[derive(Clone)]
struct Resolver {
id_to_type: HashMap<String, Type>,
id_to_def: HashMap<String, DefinitionId>,
class_names: HashMap<String, Type>,
}
impl SymbolResolver for Resolver {
fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option<Type> {
self.id_to_type.get(str).cloned()
}
fn get_symbol_value(&self, _: &str) -> Option<SymbolValue> {
unimplemented!()
}
fn get_symbol_location(&self, _: &str) -> Option<Location> {
unimplemented!()
}
fn get_identifier_def(&self, id: &str) -> Option<DefinitionId> {
self.id_to_def.get(id).cloned()
}
}
struct TestEnvironment {
pub unifier: Unifier,
pub function_data: FunctionData,
pub primitives: PrimitiveStore,
pub id_to_name: HashMap<usize, String>,
pub identifier_mapping: HashMap<String, Type>,
pub virtual_checks: Vec<(Type, Type)>,
pub calls: HashMap<CodeLocation, CallId>,
pub top_level: TopLevelContext,
}
impl TestEnvironment {
pub fn basic_test_env() -> TestEnvironment {
let mut unifier = Unifier::new();
let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(0),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(1),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let float = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(2),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let bool = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(3),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let none = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(4),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let primitives = PrimitiveStore { int32, int64, float, bool, none };
set_primitives_magic_methods(&primitives, &mut unifier);
let id_to_name = [
(0, "int32".to_string()),
(1, "int64".to_string()),
(2, "float".to_string()),
(3, "bool".to_string()),
(4, "none".to_string()),
]
.iter()
.cloned()
.collect();
let mut identifier_mapping = HashMap::new();
identifier_mapping.insert("None".into(), none);
let resolver = Arc::new(Resolver {
id_to_type: identifier_mapping.clone(),
id_to_def: Default::default(),
class_names: Default::default(),
}) as Arc<dyn SymbolResolver + Send + Sync>;
TestEnvironment {
unifier,
top_level: TopLevelContext {
definitions: Default::default(),
unifiers: Default::default(),
// contexts: Default::default(),
},
function_data: FunctionData {
resolver,
bound_variables: Vec::new(),
return_type: Some(primitives.int32),
},
primitives,
id_to_name,
identifier_mapping,
virtual_checks: Vec::new(),
calls: HashMap::new(),
}
}
fn get_inferencer(&mut self) -> Inferencer {
Inferencer {
top_level: &self.top_level,
function_data: &mut self.function_data,
unifier: &mut self.unifier,
variable_mapping: Default::default(),
primitives: &mut self.primitives,
virtual_checks: &mut self.virtual_checks,
calls: &mut self.calls,
}
}
}
#[test]
fn test_primitives() {
let mut env = TestEnvironment::basic_test_env();
let threads = ["test"];
let signature = FunSignature {
args: vec![
FuncArg { name: "a".to_string(), ty: env.primitives.int32, default_value: None },
FuncArg { name: "b".to_string(), ty: env.primitives.int32, default_value: None },
],
ret: env.primitives.int32,
vars: HashMap::new(),
};
let mut inferencer = env.get_inferencer();
inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32);
inferencer.variable_mapping.insert("b".into(), inferencer.primitives.int32);
let source = indoc! { "
c = a + b
d = a if c == 1 else 0
return d
"};
let statements = parse_program(source).unwrap();
let statements = statements
.into_iter()
.map(|v| inferencer.fold_stmt(v))
.collect::<Result<Vec<_>, _>>()
.unwrap();
let mut identifiers = vec!["a".to_string(), "b".to_string()];
inferencer.check_block(&statements, &mut identifiers).unwrap();
let top_level = Arc::new(TopLevelContext {
definitions: Default::default(),
unifiers: Arc::new(RwLock::new(vec![(env.unifier.get_shared_unifier(), env.primitives)])),
// contexts: Default::default(),
});
let task = CodeGenTask {
subst: Default::default(),
symbol_name: "testing".to_string(),
body: statements,
unifier_index: 0,
resolver: env.function_data.resolver.clone(),
signature,
};
let f = Arc::new(WithCall::new(Box::new(|module| {
// the following IR is equivalent to
// ```
// ; ModuleID = 'test.ll'
// source_filename = "test"
//
// ; Function Attrs: norecurse nounwind readnone
// define i32 @testing(i32 %0, i32 %1) local_unnamed_addr #0 {
// init:
// %add = add i32 %1, %0
// %cmp = icmp eq i32 %add, 1
// %ifexpr = select i1 %cmp, i32 %0, i32 0
// ret i32 %ifexpr
// }
//
// attributes #0 = { norecurse nounwind readnone }
// ```
// after O2 optimization
let expected = indoc! {"
; ModuleID = 'test'
source_filename = \"test\"
define i32 @testing(i32 %0, i32 %1) {
init:
%a = alloca i32, align 4
store i32 %0, i32* %a, align 4
%b = alloca i32, align 4
store i32 %1, i32* %b, align 4
%tmp = alloca i32, align 4
%tmp4 = alloca i32, align 4
br label %body
body: ; preds = %init
%load = load i32, i32* %a, align 4
%load1 = load i32, i32* %b, align 4
%add = add i32 %load, %load1
store i32 %add, i32* %tmp, align 4
%load2 = load i32, i32* %tmp, align 4
%cmp = icmp eq i32 %load2, 1
br i1 %cmp, label %then, label %else
then: ; preds = %body
%load3 = load i32, i32* %a, align 4
br label %cont
else: ; preds = %body
br label %cont
cont: ; preds = %else, %then
%ifexpr = phi i32 [ %load3, %then ], [ 0, %else ]
store i32 %ifexpr, i32* %tmp4, align 4
%load5 = load i32, i32* %tmp4, align 4
ret i32 %load5
}
"}
.trim();
assert_eq!(expected, module.print_to_string().to_str().unwrap().trim());
})));
let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f);
registry.add_task(task);
registry.wait_tasks_complete(handles);
}


@@ -1,212 +0,0 @@
use super::TopLevelContext;
use crate::typedef::*;
use std::boxed::Box;
use std::collections::HashMap;
struct ContextStack<'a> {
/// stack level, starts from 0
level: u32,
/// stack of variable definitions containing (id, def, level) where `def` is the original
/// definition in `level-1`.
var_defs: Vec<(usize, VarDef<'a>, u32)>,
/// stack of symbol definitions containing (name, level) where `level` is the smallest level
/// where the name is assigned a value
sym_def: Vec<(&'a str, u32)>,
}
pub struct InferenceContext<'a> {
/// top level context
top_level: TopLevelContext<'a>,
/// list of primitive instances
primitives: Vec<Type>,
/// list of variable instances
variables: Vec<Type>,
/// identifier to (type, readable) mapping.
/// an identifier might be defined earlier but has no value (for some code path), thus not
/// readable.
sym_table: HashMap<&'a str, (Type, bool)>,
/// resolution function reference, that may resolve unbounded identifiers to some type
resolution_fn: Box<dyn FnMut(&str) -> Result<Type, String>>,
/// stack
stack: ContextStack<'a>,
}
// non-trivial implementations here
impl<'a> InferenceContext<'a> {
/// return a new `InferenceContext` from `TopLevelContext` and resolution function.
pub fn new(
top_level: TopLevelContext,
resolution_fn: Box<dyn FnMut(&str) -> Result<Type, String>>,
) -> InferenceContext {
let primitives = (0..top_level.primitive_defs.len())
.map(|v| TypeEnum::PrimitiveType(PrimitiveId(v)).into())
.collect();
let variables = (0..top_level.var_defs.len())
.map(|v| TypeEnum::TypeVariable(VariableId(v)).into())
.collect();
InferenceContext {
top_level,
primitives,
variables,
sym_table: HashMap::new(),
resolution_fn,
stack: ContextStack {
level: 0,
var_defs: Vec::new(),
sym_def: Vec::new(),
},
}
}
/// execute the function with new scope.
/// variable assignment would be limited within the scope (not readable outside), and type
/// variable type guard would be limited within the scope.
/// returns the list of variables assigned within the scope, and the result of the function
pub fn with_scope<F, R>(&mut self, f: F) -> (Vec<&'a str>, R)
where
F: FnOnce(&mut Self) -> R,
{
self.stack.level += 1;
let result = f(self);
self.stack.level -= 1;
while !self.stack.var_defs.is_empty() {
let (_, _, level) = self.stack.var_defs.last().unwrap();
if *level > self.stack.level {
let (id, def, _) = self.stack.var_defs.pop().unwrap();
self.top_level.var_defs[id] = def;
} else {
break;
}
}
let mut poped_names = Vec::new();
while !self.stack.sym_def.is_empty() {
let (_, level) = self.stack.sym_def.last().unwrap();
if *level > self.stack.level {
let (name, _) = self.stack.sym_def.pop().unwrap();
self.sym_table.remove(name).unwrap();
poped_names.push(name);
} else {
break;
}
}
(poped_names, result)
}
/// assign a type to an identifier.
/// may return error if the identifier was defined but with different type
pub fn assign(&mut self, name: &'a str, ty: Type) -> Result<Type, String> {
if let Some((t, x)) = self.sym_table.get_mut(name) {
if t == &ty {
if !*x {
self.stack.sym_def.push((name, self.stack.level));
}
*x = true;
Ok(ty)
} else {
Err("different types".into())
}
} else {
self.stack.sym_def.push((name, self.stack.level));
self.sym_table.insert(name, (ty.clone(), true));
Ok(ty)
}
}
/// check if an identifier is already defined
pub fn defined(&self, name: &str) -> bool {
self.sym_table.get(name).is_some()
}
/// get the type of an identifier
/// may return error if the identifier is not defined, and cannot be resolved with the
/// resolution function.
pub fn resolve(&mut self, name: &str) -> Result<Type, String> {
if let Some((t, x)) = self.sym_table.get(name) {
if *x {
Ok(t.clone())
} else {
Err("may not have value".into())
}
} else {
self.resolution_fn.as_mut()(name)
}
}
/// restrict the bound of a type variable by replacing its definition.
/// used for implementing type guard
pub fn restrict(&mut self, id: VariableId, mut def: VarDef<'a>) {
std::mem::swap(self.top_level.var_defs.get_mut(id.0).unwrap(), &mut def);
self.stack.var_defs.push((id.0, def, self.stack.level));
}
}
// trivial getters:
impl<'a> InferenceContext<'a> {
pub fn get_primitive(&self, id: PrimitiveId) -> Type {
self.primitives.get(id.0).unwrap().clone()
}
pub fn get_variable(&self, id: VariableId) -> Type {
self.variables.get(id.0).unwrap().clone()
}
pub fn get_fn_def(&self, name: &str) -> Option<&FnDef> {
self.top_level.fn_table.get(name)
}
pub fn get_primitive_def(&self, id: PrimitiveId) -> &TypeDef {
self.top_level.primitive_defs.get(id.0).unwrap()
}
pub fn get_class_def(&self, id: ClassId) -> &ClassDef {
self.top_level.class_defs.get(id.0).unwrap()
}
pub fn get_parametric_def(&self, id: ParamId) -> &ParametricDef {
self.top_level.parametric_defs.get(id.0).unwrap()
}
pub fn get_variable_def(&self, id: VariableId) -> &VarDef {
self.top_level.var_defs.get(id.0).unwrap()
}
pub fn get_type(&self, name: &str) -> Option<Type> {
self.top_level.get_type(name)
}
}
impl TypeEnum {
pub fn subst(&self, map: &HashMap<VariableId, Type>) -> TypeEnum {
match self {
TypeEnum::TypeVariable(id) => map.get(id).map(|v| v.as_ref()).unwrap_or(self).clone(),
TypeEnum::ParametricType(id, params) => TypeEnum::ParametricType(
*id,
params
.iter()
.map(|v| v.as_ref().subst(map).into())
.collect(),
),
_ => self.clone(),
}
}
pub fn get_subst(&self, ctx: &InferenceContext) -> HashMap<VariableId, Type> {
match self {
TypeEnum::ParametricType(id, params) => {
let vars = &ctx.get_parametric_def(*id).params;
vars.iter()
.zip(params)
.map(|(v, p)| (*v, p.as_ref().clone().into()))
.collect()
}
// if this proves to be slow, we can use option type
_ => HashMap::new(),
}
}
pub fn get_base<'b: 'a, 'a>(&'a self, ctx: &'b InferenceContext) -> Option<&'b TypeDef> {
match self {
TypeEnum::PrimitiveType(id) => Some(ctx.get_primitive_def(*id)),
TypeEnum::ClassType(id) | TypeEnum::VirtualClassType(id) => {
Some(&ctx.get_class_def(*id).base)
}
TypeEnum::ParametricType(id, _) => Some(&ctx.get_parametric_def(*id).base),
_ => None,
}
}
}


@@ -1,4 +0,0 @@
mod inference_context;
mod top_level_context;
pub use inference_context::InferenceContext;
pub use top_level_context::TopLevelContext;


@@ -1,136 +0,0 @@
use crate::typedef::*;
use std::collections::HashMap;
use std::rc::Rc;
/// Structure for storing top-level type definitions.
/// Used for collecting type signature from source code.
/// Can be converted to `InferenceContext` for type inference in functions.
pub struct TopLevelContext<'a> {
/// List of primitive definitions.
pub(super) primitive_defs: Vec<TypeDef<'a>>,
/// List of class definitions.
pub(super) class_defs: Vec<ClassDef<'a>>,
/// List of parametric type definitions.
pub(super) parametric_defs: Vec<ParametricDef<'a>>,
/// List of type variable definitions.
pub(super) var_defs: Vec<VarDef<'a>>,
/// Function name to signature mapping.
pub(super) fn_table: HashMap<&'a str, FnDef>,
/// Type name to type mapping.
pub(super) sym_table: HashMap<&'a str, Type>,
primitives: Vec<Type>,
variables: Vec<Type>,
}
impl<'a> TopLevelContext<'a> {
pub fn new(primitive_defs: Vec<TypeDef<'a>>) -> TopLevelContext {
let mut sym_table = HashMap::new();
let mut primitives = Vec::new();
for (i, t) in primitive_defs.iter().enumerate() {
primitives.push(TypeEnum::PrimitiveType(PrimitiveId(i)).into());
sym_table.insert(t.name, TypeEnum::PrimitiveType(PrimitiveId(i)).into());
}
TopLevelContext {
primitive_defs,
class_defs: Vec::new(),
parametric_defs: Vec::new(),
var_defs: Vec::new(),
fn_table: HashMap::new(),
sym_table,
primitives,
variables: Vec::new(),
}
}
pub fn add_class(&mut self, def: ClassDef<'a>) -> ClassId {
self.sym_table.insert(
def.base.name,
TypeEnum::ClassType(ClassId(self.class_defs.len())).into(),
);
self.class_defs.push(def);
ClassId(self.class_defs.len() - 1)
}
pub fn add_parametric(&mut self, def: ParametricDef<'a>) -> ParamId {
let params = def
.params
.iter()
.map(|&v| Rc::new(TypeEnum::TypeVariable(v)))
.collect();
self.sym_table.insert(
def.base.name,
TypeEnum::ParametricType(ParamId(self.parametric_defs.len()), params).into(),
);
self.parametric_defs.push(def);
ParamId(self.parametric_defs.len() - 1)
}
pub fn add_variable(&mut self, def: VarDef<'a>) -> VariableId {
self.sym_table.insert(
def.name,
TypeEnum::TypeVariable(VariableId(self.var_defs.len())).into(),
);
self.add_variable_private(def)
}
pub fn add_variable_private(&mut self, def: VarDef<'a>) -> VariableId {
self.var_defs.push(def);
self.variables
.push(TypeEnum::TypeVariable(VariableId(self.var_defs.len() - 1)).into());
VariableId(self.var_defs.len() - 1)
}
pub fn add_fn(&mut self, name: &'a str, def: FnDef) {
self.fn_table.insert(name, def);
}
pub fn get_fn_def(&self, name: &str) -> Option<&FnDef> {
self.fn_table.get(name)
}
pub fn get_primitive_def_mut(&mut self, id: PrimitiveId) -> &mut TypeDef<'a> {
self.primitive_defs.get_mut(id.0).unwrap()
}
pub fn get_primitive_def(&self, id: PrimitiveId) -> &TypeDef {
self.primitive_defs.get(id.0).unwrap()
}
pub fn get_class_def_mut(&mut self, id: ClassId) -> &mut ClassDef<'a> {
self.class_defs.get_mut(id.0).unwrap()
}
pub fn get_class_def(&self, id: ClassId) -> &ClassDef {
self.class_defs.get(id.0).unwrap()
}
pub fn get_parametric_def_mut(&mut self, id: ParamId) -> &mut ParametricDef<'a> {
self.parametric_defs.get_mut(id.0).unwrap()
}
pub fn get_parametric_def(&self, id: ParamId) -> &ParametricDef {
self.parametric_defs.get(id.0).unwrap()
}
pub fn get_variable_def_mut(&mut self, id: VariableId) -> &mut VarDef<'a> {
self.var_defs.get_mut(id.0).unwrap()
}
pub fn get_variable_def(&self, id: VariableId) -> &VarDef {
self.var_defs.get(id.0).unwrap()
}
pub fn get_primitive(&self, id: PrimitiveId) -> Type {
self.primitives.get(id.0).unwrap().clone()
}
pub fn get_variable(&self, id: VariableId) -> Type {
self.variables.get(id.0).unwrap().clone()
}
pub fn get_type(&self, name: &str) -> Option<Type> {
// TODO: handle parametric types
self.sym_table.get(name).cloned()
}
}


@@ -1,922 +0,0 @@
use crate::context::InferenceContext;
use crate::inference_core::resolve_call;
use crate::magic_methods::*;
use crate::primitives::*;
use crate::typedef::{Type, TypeEnum::*};
use rustpython_parser::ast::{
Comparison, Comprehension, ComprehensionKind, Expression, ExpressionType, Operator,
UnaryOperator,
};
use std::convert::TryInto;
type ParserResult = Result<Option<Type>, String>;
pub fn infer_expr<'b: 'a, 'a>(
ctx: &mut InferenceContext<'a>,
expr: &'b Expression,
) -> ParserResult {
match &expr.node {
ExpressionType::Number { value } => infer_constant(ctx, value),
ExpressionType::Identifier { name } => infer_identifier(ctx, name),
ExpressionType::List { elements } => infer_list(ctx, elements),
ExpressionType::Tuple { elements } => infer_tuple(ctx, elements),
ExpressionType::Attribute { value, name } => infer_attribute(ctx, value, name),
ExpressionType::BoolOp { values, .. } => infer_bool_ops(ctx, values),
ExpressionType::Binop { a, b, op } => infer_bin_ops(ctx, op, a, b),
ExpressionType::Unop { op, a } => infer_unary_ops(ctx, op, a),
ExpressionType::Compare { vals, ops } => infer_compare(ctx, vals, ops),
ExpressionType::Call {
args,
function,
keywords,
} => {
if !keywords.is_empty() {
Err("keyword is not supported".into())
} else {
infer_call(ctx, &args, &function)
}
}
ExpressionType::Subscript { a, b } => infer_subscript(ctx, a, b),
ExpressionType::IfExpression { test, body, orelse } => {
infer_if_expr(ctx, &test, &body, orelse)
}
ExpressionType::Comprehension { kind, generators } => match kind.as_ref() {
ComprehensionKind::List { element } => {
if generators.len() == 1 {
infer_list_comprehension(ctx, element, &generators[0])
} else {
Err("only 1 generator statement is supported".into())
}
}
_ => Err("only list comprehension is supported".into()),
},
ExpressionType::True | ExpressionType::False => Ok(Some(ctx.get_primitive(BOOL_TYPE))),
_ => Err("not supported".into()),
}
}
fn infer_constant(
ctx: &mut InferenceContext,
value: &rustpython_parser::ast::Number,
) -> ParserResult {
use rustpython_parser::ast::Number;
match value {
Number::Integer { value } => {
let int32: Result<i32, _> = value.try_into();
if int32.is_ok() {
Ok(Some(ctx.get_primitive(INT32_TYPE)))
} else {
Err("integer out of range".into())
}
}
Number::Float { .. } => Ok(Some(ctx.get_primitive(FLOAT_TYPE))),
_ => Err("not supported".into()),
}
}
fn infer_identifier(ctx: &mut InferenceContext, name: &str) -> ParserResult {
Ok(Some(ctx.resolve(name)?))
}
fn infer_list<'b: 'a, 'a>(
ctx: &mut InferenceContext<'a>,
elements: &'b [Expression],
) -> ParserResult {
if elements.is_empty() {
return Ok(Some(ParametricType(LIST_TYPE, vec![BotType.into()]).into()));
}
let mut types = elements.iter().map(|v| infer_expr(ctx, v));
let head = types.next().unwrap()?;
if head.is_none() {
return Err("list elements must have some type".into());
}
for v in types {
// TODO: try virtual type...
if v? != head {
return Err("inhomogeneous list is not allowed".into());
}
}
Ok(Some(ParametricType(LIST_TYPE, vec![head.unwrap()]).into()))
}
fn infer_tuple<'b: 'a, 'a>(
ctx: &mut InferenceContext<'a>,
elements: &'b [Expression],
) -> ParserResult {
let types: Result<Option<Vec<_>>, String> =
elements.iter().map(|v| infer_expr(ctx, v)).collect();
if let Some(t) = types? {
Ok(Some(ParametricType(TUPLE_TYPE, t).into()))
} else {
Err("tuple elements must have some type".into())
}
}
fn infer_attribute<'a>(
ctx: &mut InferenceContext<'a>,
value: &'a Expression,
name: &str,
) -> ParserResult {
let value = infer_expr(ctx, value)?.ok_or_else(|| "no value".to_string())?;
if let TypeVariable(_) = value.as_ref() {
return Err("no fields for type variable".into());
}
value
.get_base(ctx)
.and_then(|b| b.fields.get(name).cloned())
.map_or_else(|| Err("no such field".to_string()), |v| Ok(Some(v)))
}
fn infer_bool_ops<'a>(ctx: &mut InferenceContext<'a>, values: &'a [Expression]) -> ParserResult {
assert_eq!(values.len(), 2);
let left = infer_expr(ctx, &values[0])?.ok_or_else(|| "no value".to_string())?;
let right = infer_expr(ctx, &values[1])?.ok_or_else(|| "no value".to_string())?;
let b = ctx.get_primitive(BOOL_TYPE);
if left == b && right == b {
Ok(Some(b))
} else {
Err("bool operands must be bool".into())
}
}
fn infer_bin_ops<'b: 'a, 'a>(
ctx: &mut InferenceContext<'a>,
op: &Operator,
left: &'b Expression,
right: &'b Expression,
) -> ParserResult {
let left = infer_expr(ctx, left)?.ok_or_else(|| "no value".to_string())?;
let right = infer_expr(ctx, right)?.ok_or_else(|| "no value".to_string())?;
let fun = binop_name(op);
resolve_call(ctx, Some(left), fun, &[right])
}
fn infer_unary_ops<'b: 'a, 'a>(
ctx: &mut InferenceContext<'a>,
op: &UnaryOperator,
obj: &'b Expression,
) -> ParserResult {
let ty = infer_expr(ctx, obj)?.ok_or_else(|| "no value".to_string())?;
if let UnaryOperator::Not = op {
if ty == ctx.get_primitive(BOOL_TYPE) {
Ok(Some(ty))
} else {
Err("logical not must be applied to bool".into())
}
} else {
resolve_call(ctx, Some(ty), unaryop_name(op), &[])
}
}
fn infer_compare<'b: 'a, 'a>(
ctx: &mut InferenceContext<'a>,
vals: &'b [Expression],
ops: &'b [Comparison],
) -> ParserResult {
let types: Result<Option<Vec<_>>, _> = vals.iter().map(|v| infer_expr(ctx, v)).collect();
let types = types?;
if types.is_none() {
return Err("comparison operands must have type".into());
}
let types = types.unwrap();
let boolean = ctx.get_primitive(BOOL_TYPE);
let left = &types[..types.len() - 1];
let right = &types[1..];
for ((a, b), op) in left.iter().zip(right.iter()).zip(ops.iter()) {
let fun = comparison_name(op).ok_or_else(|| "unsupported comparison".to_string())?;
let ty = resolve_call(ctx, Some(a.clone()), fun, &[b.clone()])?;
if ty.is_none() || ty.unwrap() != boolean {
return Err("comparison result must be boolean".into());
}
}
Ok(Some(boolean))
}
fn infer_call<'b: 'a, 'a>(
ctx: &mut InferenceContext<'a>,
args: &'b [Expression],
function: &'b Expression,
) -> ParserResult {
// TODO: special handling for int64 constant
let types: Result<Option<Vec<_>>, _> = args.iter().map(|v| infer_expr(ctx, v)).collect();
let types = types?;
if types.is_none() {
return Err("function params must have type".into());
}
let (obj, fun) = match &function.node {
ExpressionType::Identifier { name } => (None, name),
ExpressionType::Attribute { value, name } => (
Some(infer_expr(ctx, &value)?.ok_or_else(|| "no value".to_string())?),
name,
),
_ => return Err("not supported".into()),
};
resolve_call(ctx, obj, fun.as_str(), &types.unwrap())
}
fn infer_subscript<'b: 'a, 'a>(
ctx: &mut InferenceContext<'a>,
a: &'b Expression,
b: &'b Expression,
) -> ParserResult {
let a = infer_expr(ctx, a)?.ok_or_else(|| "no value".to_string())?;
let t = if let ParametricType(LIST_TYPE, ls) = a.as_ref() {
ls[0].clone()
} else {
return Err("subscript is not supported for types other than list".into());
};
match &b.node {
ExpressionType::Slice { elements } => {
let int32 = ctx.get_primitive(INT32_TYPE);
let types: Result<Option<Vec<_>>, _> = elements
.iter()
.map(|v| {
if let ExpressionType::None = v.node {
Ok(Some(int32.clone()))
} else {
infer_expr(ctx, v)
}
})
.collect();
let types = types?.ok_or_else(|| "slice must have type".to_string())?;
if types.iter().all(|v| v == &int32) {
Ok(Some(a))
} else {
Err("slice must be int32 type".into())
}
}
_ => {
let b = infer_expr(ctx, b)?.ok_or_else(|| "no value".to_string())?;
if b == ctx.get_primitive(INT32_TYPE) {
Ok(Some(t))
} else {
Err("index must be either slice or int32".into())
}
}
}
}
fn infer_if_expr<'b: 'a, 'a>(
ctx: &mut InferenceContext<'a>,
test: &'b Expression,
body: &'b Expression,
orelse: &'b Expression,
) -> ParserResult {
let test = infer_expr(ctx, test)?.ok_or_else(|| "no value".to_string())?;
if test != ctx.get_primitive(BOOL_TYPE) {
return Err("test should be bool".into());
}
let body = infer_expr(ctx, body)?;
let orelse = infer_expr(ctx, orelse)?;
if body.as_ref() == orelse.as_ref() {
Ok(body)
} else {
Err("divergent type".into())
}
}
fn infer_simple_binding<'a: 'b, 'b>(
ctx: &mut InferenceContext<'b>,
name: &'a Expression,
ty: Type,
) -> Result<(), String> {
match &name.node {
ExpressionType::Identifier { name } => {
if name == "_" {
Ok(())
} else if ctx.defined(name.as_str()) {
Err("duplicated naming".into())
} else {
ctx.assign(name.as_str(), ty)?;
Ok(())
}
}
ExpressionType::Tuple { elements } => {
if let ParametricType(TUPLE_TYPE, ls) = ty.as_ref() {
if elements.len() == ls.len() {
for (a, b) in elements.iter().zip(ls.iter()) {
infer_simple_binding(ctx, a, b.clone())?;
}
Ok(())
} else {
Err("different length".into())
}
} else {
Err("not supported".into())
}
}
_ => Err("not supported".into()),
}
}
fn infer_list_comprehension<'b: 'a, 'a>(
ctx: &mut InferenceContext<'a>,
element: &'b Expression,
comprehension: &'b Comprehension,
) -> ParserResult {
if comprehension.is_async {
return Err("async is not supported".into());
}
let iter = infer_expr(ctx, &comprehension.iter)?.ok_or_else(|| "no value".to_string())?;
if let ParametricType(LIST_TYPE, ls) = iter.as_ref() {
ctx.with_scope(|ctx| {
infer_simple_binding(ctx, &comprehension.target, ls[0].clone())?;
let boolean = ctx.get_primitive(BOOL_TYPE);
for test in comprehension.ifs.iter() {
let result =
infer_expr(ctx, test)?.ok_or_else(|| "no value in test".to_string())?;
if result != boolean {
return Err("test must be bool".into());
}
}
let result = infer_expr(ctx, element)?.ok_or_else(|| "no value")?;
Ok(Some(ParametricType(LIST_TYPE, vec![result]).into()))
})
.1
} else {
Err("iteration is supported for list only".into())
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::context::*;
use crate::typedef::*;
use rustpython_parser::parser::parse_expression;
use std::collections::HashMap;
use std::rc::Rc;
fn get_inference_context(ctx: TopLevelContext) -> InferenceContext {
InferenceContext::new(ctx, Box::new(|_| Err("unbounded identifier".into())))
}
#[test]
fn test_constants() {
let ctx = basic_ctx();
let mut ctx = get_inference_context(ctx);
let ast = parse_expression("123").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("2147483647").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("2147483648").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("integer out of range".into()));
//
// let ast = parse_expression("2147483648").unwrap();
// let result = infer_expr(&mut ctx, &ast);
// assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT64_TYPE));
// let ast = parse_expression("9223372036854775807").unwrap();
// let result = infer_expr(&mut ctx, &ast);
// assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT64_TYPE));
// let ast = parse_expression("9223372036854775808").unwrap();
// let result = infer_expr(&mut ctx, &ast);
// assert_eq!(result, Err("integer out of range".into()));
let ast = parse_expression("123.456").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(FLOAT_TYPE));
let ast = parse_expression("True").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE));
let ast = parse_expression("False").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE));
}
#[test]
fn test_identifier() {
let ctx = basic_ctx();
let mut ctx = get_inference_context(ctx);
ctx.assign("abc", ctx.get_primitive(INT32_TYPE)).unwrap();
let ast = parse_expression("abc").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("ab").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("unbounded identifier".into()));
}
#[test]
fn test_list() {
let mut ctx = basic_ctx();
ctx.add_fn(
"foo",
FnDef {
args: vec![],
result: None,
},
);
let mut ctx = get_inference_context(ctx);
ctx.assign("abc", ctx.get_primitive(INT32_TYPE)).unwrap();
// def is reserved...
ctx.assign("efg", ctx.get_primitive(INT32_TYPE)).unwrap();
ctx.assign("xyz", ctx.get_primitive(FLOAT_TYPE)).unwrap();
let ast = parse_expression("[]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result.unwrap().unwrap(),
ParametricType(LIST_TYPE, vec![BotType.into()]).into()
);
let ast = parse_expression("[abc]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result.unwrap().unwrap(),
ParametricType(LIST_TYPE, vec![ctx.get_primitive(INT32_TYPE)]).into()
);
let ast = parse_expression("[abc, efg]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result.unwrap().unwrap(),
ParametricType(LIST_TYPE, vec![ctx.get_primitive(INT32_TYPE)]).into()
);
let ast = parse_expression("[abc, efg, xyz]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("inhomogeneous list is not allowed".into()));
let ast = parse_expression("[foo()]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("list elements must have some type".into()));
}
#[test]
fn test_tuple() {
let mut ctx = basic_ctx();
ctx.add_fn(
"foo",
FnDef {
args: vec![],
result: None,
},
);
let mut ctx = get_inference_context(ctx);
ctx.assign("abc", ctx.get_primitive(INT32_TYPE)).unwrap();
ctx.assign("efg", ctx.get_primitive(FLOAT_TYPE)).unwrap();
let ast = parse_expression("(abc, efg)").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result.unwrap().unwrap(),
ParametricType(
TUPLE_TYPE,
vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(FLOAT_TYPE)]
)
.into()
);
let ast = parse_expression("(abc, efg, foo())").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("tuple elements must have some type".into()));
}
#[test]
fn test_attribute() {
let mut ctx = basic_ctx();
ctx.add_fn(
"none",
FnDef {
args: vec![],
result: None,
},
);
let int32 = ctx.get_primitive(INT32_TYPE);
let float = ctx.get_primitive(FLOAT_TYPE);
let foo = ctx.add_class(ClassDef {
base: TypeDef {
name: "Foo",
fields: HashMap::new(),
methods: HashMap::new(),
},
parents: vec![],
});
let foo_def = ctx.get_class_def_mut(foo);
foo_def.base.fields.insert("a", int32.clone());
foo_def.base.fields.insert("b", ClassType(foo).into());
foo_def.base.fields.insert("c", int32.clone());
let bar = ctx.add_class(ClassDef {
base: TypeDef {
name: "Bar",
fields: HashMap::new(),
methods: HashMap::new(),
},
parents: vec![],
});
let bar_def = ctx.get_class_def_mut(bar);
bar_def.base.fields.insert("a", int32);
bar_def.base.fields.insert("b", ClassType(bar).into());
bar_def.base.fields.insert("c", float);
let v0 = ctx.add_variable(VarDef {
name: "v0",
bound: vec![],
});
let v1 = ctx.add_variable(VarDef {
name: "v1",
bound: vec![ClassType(foo).into(), ClassType(bar).into()],
});
let mut ctx = get_inference_context(ctx);
ctx.assign("foo", Rc::new(ClassType(foo))).unwrap();
ctx.assign("bar", Rc::new(ClassType(bar))).unwrap();
ctx.assign("foobar", Rc::new(VirtualClassType(foo)))
.unwrap();
ctx.assign("v0", ctx.get_variable(v0)).unwrap();
ctx.assign("v1", ctx.get_variable(v1)).unwrap();
ctx.assign("bot", Rc::new(BotType)).unwrap();
let ast = parse_expression("foo.a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("foo.d").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no such field".into()));
let ast = parse_expression("foobar.a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("v0.a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no fields for type variable".into()));
let ast = parse_expression("v1.a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no fields for type variable".into()));
let ast = parse_expression("none().a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no value".into()));
let ast = parse_expression("bot.a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no such field".into()));
}
#[test]
fn test_bool_ops() {
let mut ctx = basic_ctx();
ctx.add_fn(
"none",
FnDef {
args: vec![],
result: None,
},
);
let mut ctx = get_inference_context(ctx);
let ast = parse_expression("True and False").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE));
let ast = parse_expression("True and none()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no value".into()));
let ast = parse_expression("True and 123").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("bool operands must be bool".into()));
}
#[test]
fn test_bin_ops() {
let mut ctx = basic_ctx();
let v0 = ctx.add_variable(VarDef {
name: "v0",
bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(INT64_TYPE)],
});
let mut ctx = get_inference_context(ctx);
ctx.assign("a", TypeVariable(v0).into()).unwrap();
let ast = parse_expression("1 + 2 + 3").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("a + a + a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("not supported".into()));
}
#[test]
fn test_unary_ops() {
let mut ctx = basic_ctx();
let v0 = ctx.add_variable(VarDef {
name: "v0",
bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(INT64_TYPE)],
});
let mut ctx = get_inference_context(ctx);
ctx.assign("a", TypeVariable(v0).into()).unwrap();
let ast = parse_expression("-(123)").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("-a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("not supported".into()));
let ast = parse_expression("not True").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE));
let ast = parse_expression("not (1)").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("logical not must be applied to bool".into()));
}
#[test]
fn test_compare() {
let mut ctx = basic_ctx();
let v0 = ctx.add_variable(VarDef {
name: "v0",
bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(INT64_TYPE)],
});
let mut ctx = get_inference_context(ctx);
ctx.assign("a", TypeVariable(v0).into()).unwrap();
let ast = parse_expression("a == a == a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("not supported".into()));
let ast = parse_expression("a == a == 1").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("not supported".into()));
let ast = parse_expression("True > False").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no such function".into()));
let ast = parse_expression("True in False").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("unsupported comparison".into()));
}
#[test]
fn test_call() {
let mut ctx = basic_ctx();
ctx.add_fn(
"none",
FnDef {
args: vec![],
result: None,
},
);
let foo = ctx.add_class(ClassDef {
base: TypeDef {
name: "Foo",
fields: HashMap::new(),
methods: HashMap::new(),
},
parents: vec![],
});
let foo_def = ctx.get_class_def_mut(foo);
foo_def.base.methods.insert(
"a",
FnDef {
args: vec![],
result: Some(Rc::new(ClassType(foo))),
},
);
let bar = ctx.add_class(ClassDef {
base: TypeDef {
name: "Bar",
fields: HashMap::new(),
methods: HashMap::new(),
},
parents: vec![],
});
let bar_def = ctx.get_class_def_mut(bar);
bar_def.base.methods.insert(
"a",
FnDef {
args: vec![],
result: Some(Rc::new(ClassType(bar))),
},
);
let v0 = ctx.add_variable(VarDef {
name: "v0",
bound: vec![],
});
let v1 = ctx.add_variable(VarDef {
name: "v1",
bound: vec![ClassType(foo).into(), ClassType(bar).into()],
});
let v2 = ctx.add_variable(VarDef {
name: "v2",
bound: vec![
ClassType(foo).into(),
ClassType(bar).into(),
ctx.get_primitive(INT32_TYPE),
],
});
let mut ctx = get_inference_context(ctx);
ctx.assign("foo", Rc::new(ClassType(foo))).unwrap();
ctx.assign("bar", Rc::new(ClassType(bar))).unwrap();
ctx.assign("foobar", Rc::new(VirtualClassType(foo)))
.unwrap();
ctx.assign("v0", ctx.get_variable(v0)).unwrap();
ctx.assign("v1", ctx.get_variable(v1)).unwrap();
ctx.assign("v2", ctx.get_variable(v2)).unwrap();
ctx.assign("bot", Rc::new(BotType)).unwrap();
let ast = parse_expression("foo.a()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ClassType(foo).into());
let ast = parse_expression("v1.a()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("not supported".into()));
let ast = parse_expression("foobar.a()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ClassType(foo).into());
let ast = parse_expression("none().a()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no value".into()));
let ast = parse_expression("bot.a()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("not supported".into()));
let ast = parse_expression("[][0].a()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("not supported".into()));
}
#[test]
fn infer_subscript() {
let mut ctx = basic_ctx();
ctx.add_fn(
"none",
FnDef {
args: vec![],
result: None,
},
);
let mut ctx = get_inference_context(ctx);
let ast = parse_expression("[1, 2, 3][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("[[1]][0][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("[1, 2, 3][1:2]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result.unwrap().unwrap(),
ParametricType(LIST_TYPE, vec![ctx.get_primitive(INT32_TYPE)]).into()
);
let ast = parse_expression("[1, 2, 3][1:2:2]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result.unwrap().unwrap(),
ParametricType(LIST_TYPE, vec![ctx.get_primitive(INT32_TYPE)]).into()
);
let ast = parse_expression("[1, 2, 3][1:1.2]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("slice must be int32 type".into()));
let ast = parse_expression("[1, 2, 3][1:none()]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("slice must have type".into()));
let ast = parse_expression("[1, 2, 3][1.2]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("index must be either slice or int32".into()));
let ast = parse_expression("[1, 2, 3][none()]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no value".into()));
let ast = parse_expression("none()[1.2]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no value".into()));
let ast = parse_expression("123[1]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result,
Err("subscript is not supported for types other than list".into())
);
}
#[test]
fn test_if_expr() {
let mut ctx = basic_ctx();
ctx.add_fn(
"none",
FnDef {
args: vec![],
result: None,
},
);
let mut ctx = get_inference_context(ctx);
let ast = parse_expression("1 if True else 0").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("none() if True else none()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap(), None);
let ast = parse_expression("none() if 1 else none()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("test should be bool".into()));
let ast = parse_expression("1 if True else none()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("divergent type".into()));
}
#[test]
fn test_list_comp() {
let mut ctx = basic_ctx();
ctx.add_fn(
"none",
FnDef {
args: vec![],
result: None,
},
);
let int32 = ctx.get_primitive(INT32_TYPE);
let mut ctx = get_inference_context(ctx);
ctx.assign("z", int32.clone()).unwrap();
let ast = parse_expression("[x for x in [(1, 2), (2, 3), (3, 4)]][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result.unwrap().unwrap(),
ParametricType(TUPLE_TYPE, vec![int32.clone(), int32.clone()]).into()
);
let ast = parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)]][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), int32);
let ast =
parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)] if x > 0][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), int32);
let ast = parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)] if x][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("test must be bool".into()));
let ast = parse_expression("[y for x in []][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("unbounded identifier".into()));
let ast = parse_expression("[none() for x in []][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no value".into()));
let ast = parse_expression("[z for z in []][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("duplicated naming".into()));
let ast = parse_expression("[x for x in [] for y in []]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result,
Err("only 1 generator statement is supported".into())
);
}
}


@ -1,525 +0,0 @@
use crate::context::InferenceContext;
use crate::typedef::{TypeEnum::*, *};
use std::collections::HashMap;
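/// Attempt to unify type `a` against type `b`. `valuation` optionally fixes the value of a
/// single type variable occurring in `a`; substitutions discovered for type variables
/// occurring in `b` are accumulated in `sub`.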
fn find_subst(
ctx: &InferenceContext,
valuation: &Option<(VariableId, Type)>,
sub: &mut HashMap<VariableId, Type>,
mut a: Type,
mut b: Type,
) -> Result<(), String> {
// TODO: fix error messages later
if let TypeVariable(id) = a.as_ref() {
if let Some((assumption_id, t)) = valuation {
if assumption_id == id {
a = t.clone();
}
}
}
let mut substituted = false;
if let TypeVariable(id) = b.as_ref() {
if let Some(c) = sub.get(&id) {
b = c.clone();
substituted = true;
}
}
match (a.as_ref(), b.as_ref()) {
(BotType, _) => Ok(()),
(TypeVariable(id_a), TypeVariable(id_b)) => {
if substituted {
return if id_a == id_b {
Ok(())
} else {
Err("different variables".to_string())
};
}
let v_a = ctx.get_variable_def(*id_a);
let v_b = ctx.get_variable_def(*id_b);
if !v_b.bound.is_empty() {
if v_a.bound.is_empty() {
return Err("unbounded a".to_string());
} else {
let diff: Vec<_> = v_a
.bound
.iter()
.filter(|x| !v_b.bound.contains(x))
.collect();
if !diff.is_empty() {
return Err("different domain".to_string());
}
}
}
sub.insert(*id_b, a.clone());
Ok(())
}
(TypeVariable(id_a), _) => {
let v_a = ctx.get_variable_def(*id_a);
if v_a.bound.len() == 1 && v_a.bound[0].as_ref() == b.as_ref() {
Ok(())
} else {
Err("different domain".to_string())
}
}
(_, TypeVariable(id_b)) => {
let v_b = ctx.get_variable_def(*id_b);
if v_b.bound.is_empty() || v_b.bound.contains(&a) {
sub.insert(*id_b, a.clone());
Ok(())
} else {
Err("different domain".to_string())
}
}
(_, VirtualClassType(id_b)) => {
let mut parents;
match a.as_ref() {
ClassType(id_a) => {
parents = [*id_a].to_vec();
}
VirtualClassType(id_a) => {
parents = [*id_a].to_vec();
}
_ => {
return Err("cannot substitute non-class type into virtual class".to_string());
}
};
while !parents.is_empty() {
if *id_b == parents[0] {
return Ok(());
}
let c = ctx.get_class_def(parents.remove(0));
parents.extend_from_slice(&c.parents);
}
Err("not subtype".to_string())
}
(ParametricType(id_a, param_a), ParametricType(id_b, param_b)) => {
if id_a != id_b || param_a.len() != param_b.len() {
Err("different parametric types".to_string())
} else {
for (x, y) in param_a.iter().zip(param_b.iter()) {
find_subst(ctx, valuation, sub, x.clone(), y.clone())?;
}
Ok(())
}
}
(_, _) => {
if a == b {
Ok(())
} else {
Err("not equal".to_string())
}
}
}
}
fn resolve_call_rec(
ctx: &InferenceContext,
valuation: &Option<(VariableId, Type)>,
obj: Option<Type>,
func: &str,
args: &[Type],
) -> Result<Option<Type>, String> {
let mut subst = obj
.as_ref()
.map(|v| v.get_subst(ctx))
.unwrap_or_else(HashMap::new);
let fun = match &obj {
Some(obj) => {
let base = match obj.as_ref() {
PrimitiveType(id) => &ctx.get_primitive_def(*id),
ClassType(id) | VirtualClassType(id) => &ctx.get_class_def(*id).base,
ParametricType(id, _) => &ctx.get_parametric_def(*id).base,
_ => return Err("not supported".to_string()),
};
base.methods.get(func)
}
None => ctx.get_fn_def(func),
}
.ok_or_else(|| "no such function".to_string())?;
if args.len() != fun.args.len() {
return Err("incorrect parameter number".to_string());
}
for (a, b) in args.iter().zip(fun.args.iter()) {
find_subst(ctx, valuation, &mut subst, a.clone(), b.clone())?;
}
let result = fun.result.as_ref().map(|v| v.subst(&subst));
Ok(result.map(|result| {
if let SelfType = result {
obj.unwrap()
} else {
result.into()
}
}))
}
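/// Resolve a call to `func` (optionally as a method of `obj`) against the argument types
/// `args`, returning the result type if the call type checks.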
pub fn resolve_call(
ctx: &InferenceContext,
obj: Option<Type>,
func: &str,
args: &[Type],
) -> Result<Option<Type>, String> {
resolve_call_rec(ctx, &None, obj, func, args)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::context::TopLevelContext;
use crate::primitives::*;
use std::rc::Rc;
fn get_inference_context(ctx: TopLevelContext) -> InferenceContext {
InferenceContext::new(ctx, Box::new(|_| Err("unbounded identifier".into())))
}
#[test]
fn test_simple_generic() {
let mut ctx = basic_ctx();
let v1 = ctx.add_variable(VarDef {
name: "V1",
bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(FLOAT_TYPE)],
});
let v1 = ctx.get_variable(v1);
let v2 = ctx.add_variable(VarDef {
name: "V2",
bound: vec![
ctx.get_primitive(BOOL_TYPE),
ctx.get_primitive(INT32_TYPE),
ctx.get_primitive(FLOAT_TYPE),
],
});
let v2 = ctx.get_variable(v2);
let ctx = get_inference_context(ctx);
assert_eq!(
resolve_call(&ctx, None, "int32", &[ctx.get_primitive(FLOAT_TYPE)]),
Ok(Some(ctx.get_primitive(INT32_TYPE)))
);
assert_eq!(
resolve_call(&ctx, None, "int32", &[ctx.get_primitive(INT32_TYPE)],),
Ok(Some(ctx.get_primitive(INT32_TYPE)))
);
assert_eq!(
resolve_call(&ctx, None, "float", &[ctx.get_primitive(INT32_TYPE)]),
Ok(Some(ctx.get_primitive(FLOAT_TYPE)))
);
assert_eq!(
resolve_call(&ctx, None, "float", &[ctx.get_primitive(BOOL_TYPE)]),
Err("different domain".to_string())
);
assert_eq!(
resolve_call(&ctx, None, "float", &[]),
Err("incorrect parameter number".to_string())
);
assert_eq!(
resolve_call(&ctx, None, "float", &[v1]),
Ok(Some(ctx.get_primitive(FLOAT_TYPE)))
);
assert_eq!(
resolve_call(&ctx, None, "float", &[v2]),
Err("different domain".to_string())
);
}
#[test]
fn test_methods() {
let mut ctx = basic_ctx();
let v0 = ctx.add_variable(VarDef {
name: "V0",
bound: vec![],
});
let v0 = ctx.get_variable(v0);
let int32 = ctx.get_primitive(INT32_TYPE);
let int64 = ctx.get_primitive(INT64_TYPE);
let ctx = get_inference_context(ctx);
// simple cases
assert_eq!(
resolve_call(&ctx, Some(int32.clone()), "__add__", &[int32.clone()]),
Ok(Some(int32.clone()))
);
assert_ne!(
resolve_call(&ctx, Some(int32.clone()), "__add__", &[int32.clone()]),
Ok(Some(int64.clone()))
);
assert_eq!(
resolve_call(&ctx, Some(int32), "__add__", &[int64]),
Err("not equal".to_string())
);
// with type variables
assert_eq!(
resolve_call(&ctx, Some(v0.clone()), "__add__", &[v0.clone()]),
Err("not supported".into())
);
}
#[test]
fn test_multi_generic() {
let mut ctx = basic_ctx();
let v0 = ctx.add_variable(VarDef {
name: "V0",
bound: vec![],
});
let v0 = ctx.get_variable(v0);
let v1 = ctx.add_variable(VarDef {
name: "V1",
bound: vec![],
});
let v1 = ctx.get_variable(v1);
let v2 = ctx.add_variable(VarDef {
name: "V2",
bound: vec![],
});
let v2 = ctx.get_variable(v2);
let v3 = ctx.add_variable(VarDef {
name: "V3",
bound: vec![],
});
let v3 = ctx.get_variable(v3);
ctx.add_fn(
"foo",
FnDef {
args: vec![v0.clone(), v0.clone(), v1.clone()],
result: Some(v0.clone()),
},
);
ctx.add_fn(
"foo1",
FnDef {
args: vec![ParametricType(TUPLE_TYPE, vec![v0.clone(), v0.clone(), v1]).into()],
result: Some(v0),
},
);
let ctx = get_inference_context(ctx);
assert_eq!(
resolve_call(&ctx, None, "foo", &[v2.clone(), v2.clone(), v2.clone()]),
Ok(Some(v2.clone()))
);
assert_eq!(
resolve_call(&ctx, None, "foo", &[v2.clone(), v2.clone(), v3.clone()]),
Ok(Some(v2.clone()))
);
assert_eq!(
resolve_call(&ctx, None, "foo", &[v2.clone(), v3.clone(), v3.clone()]),
Err("different variables".to_string())
);
assert_eq!(
resolve_call(
&ctx,
None,
"foo1",
&[ParametricType(TUPLE_TYPE, vec![v2.clone(), v2.clone(), v2.clone()]).into()]
),
Ok(Some(v2.clone()))
);
assert_eq!(
resolve_call(
&ctx,
None,
"foo1",
&[ParametricType(TUPLE_TYPE, vec![v2.clone(), v2.clone(), v3.clone()]).into()]
),
Ok(Some(v2.clone()))
);
assert_eq!(
resolve_call(
&ctx,
None,
"foo1",
&[ParametricType(TUPLE_TYPE, vec![v2, v3.clone(), v3]).into()]
),
Err("different variables".to_string())
);
}
#[test]
fn test_class_generics() {
let mut ctx = basic_ctx();
let list = ctx.get_parametric_def_mut(LIST_TYPE);
let t = Rc::new(TypeVariable(list.params[0]));
list.base.methods.insert(
"head",
FnDef {
args: vec![],
result: Some(t.clone()),
},
);
list.base.methods.insert(
"append",
FnDef {
args: vec![t],
result: None,
},
);
let v0 = ctx.add_variable(VarDef {
name: "V0",
bound: vec![],
});
let v0 = ctx.get_variable(v0);
let v1 = ctx.add_variable(VarDef {
name: "V1",
bound: vec![],
});
let v1 = ctx.get_variable(v1);
let ctx = get_inference_context(ctx);
assert_eq!(
resolve_call(
&ctx,
Some(ParametricType(LIST_TYPE, vec![v0.clone()]).into()),
"head",
&[]
),
Ok(Some(v0.clone()))
);
assert_eq!(
resolve_call(
&ctx,
Some(ParametricType(LIST_TYPE, vec![v0.clone()]).into()),
"append",
&[v0.clone()]
),
Ok(None)
);
assert_eq!(
resolve_call(
&ctx,
Some(ParametricType(LIST_TYPE, vec![v0]).into()),
"append",
&[v1]
),
Err("different variables".to_string())
);
}
#[test]
fn test_virtual_class() {
let mut ctx = basic_ctx();
let foo = ctx.add_class(ClassDef {
base: TypeDef {
name: "Foo",
methods: HashMap::new(),
fields: HashMap::new(),
},
parents: vec![],
});
let foo1 = ctx.add_class(ClassDef {
base: TypeDef {
name: "Foo1",
methods: HashMap::new(),
fields: HashMap::new(),
},
parents: vec![foo],
});
let foo2 = ctx.add_class(ClassDef {
base: TypeDef {
name: "Foo2",
methods: HashMap::new(),
fields: HashMap::new(),
},
parents: vec![foo1],
});
let bar = ctx.add_class(ClassDef {
base: TypeDef {
name: "bar",
methods: HashMap::new(),
fields: HashMap::new(),
},
parents: vec![],
});
ctx.add_fn(
"foo",
FnDef {
args: vec![VirtualClassType(foo).into()],
result: None,
},
);
ctx.add_fn(
"foo1",
FnDef {
args: vec![VirtualClassType(foo1).into()],
result: None,
},
);
let ctx = get_inference_context(ctx);
assert_eq!(
resolve_call(&ctx, None, "foo", &[ClassType(foo).into()]),
Ok(None)
);
assert_eq!(
resolve_call(&ctx, None, "foo", &[ClassType(foo1).into()]),
Ok(None)
);
assert_eq!(
resolve_call(&ctx, None, "foo", &[ClassType(foo2).into()]),
Ok(None)
);
assert_eq!(
resolve_call(&ctx, None, "foo", &[ClassType(bar).into()]),
Err("not subtype".to_string())
);
assert_eq!(
resolve_call(&ctx, None, "foo1", &[ClassType(foo1).into()]),
Ok(None)
);
assert_eq!(
resolve_call(&ctx, None, "foo1", &[ClassType(foo2).into()]),
Ok(None)
);
assert_eq!(
resolve_call(&ctx, None, "foo1", &[ClassType(foo).into()]),
Err("not subtype".to_string())
);
// virtual class substitution
assert_eq!(
resolve_call(&ctx, None, "foo", &[VirtualClassType(foo).into()]),
Ok(None)
);
assert_eq!(
resolve_call(&ctx, None, "foo", &[VirtualClassType(foo1).into()]),
Ok(None)
);
assert_eq!(
resolve_call(&ctx, None, "foo", &[VirtualClassType(foo2).into()]),
Ok(None)
);
assert_eq!(
resolve_call(&ctx, None, "foo", &[VirtualClassType(bar).into()]),
Err("not subtype".to_string())
);
}
}


@ -1,593 +1,8 @@
#![warn(clippy::all)]
#![allow(clippy::clone_double_ref)]
#![allow(dead_code)]
extern crate num_bigint;
extern crate inkwell;
extern crate rustpython_parser;
pub mod expression_inference;
pub mod inference_core;
mod magic_methods;
pub mod primitives;
pub mod typedef;
pub mod context;
use std::error::Error;
use std::fmt;
use std::path::Path;
use std::collections::HashMap;
use num_traits::cast::ToPrimitive;
use rustpython_parser::ast;
use inkwell::OptimizationLevel;
use inkwell::builder::Builder;
use inkwell::context::Context;
use inkwell::module::Module;
use inkwell::targets::*;
use inkwell::types;
use inkwell::types::BasicType;
use inkwell::values;
use inkwell::{IntPredicate, FloatPredicate};
use inkwell::basic_block;
use inkwell::passes;
#[derive(Debug)]
enum CompileErrorKind {
Unsupported(&'static str),
MissingTypeAnnotation,
UnknownTypeAnnotation,
IncompatibleTypes,
UnboundIdentifier,
BreakOutsideLoop,
Internal(&'static str)
}
impl fmt::Display for CompileErrorKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
CompileErrorKind::Unsupported(feature)
=> write!(f, "The following Python feature is not supported by NAC3: {}", feature),
CompileErrorKind::MissingTypeAnnotation
=> write!(f, "Missing type annotation"),
CompileErrorKind::UnknownTypeAnnotation
=> write!(f, "Unknown type annotation"),
CompileErrorKind::IncompatibleTypes
=> write!(f, "Incompatible types"),
CompileErrorKind::UnboundIdentifier
=> write!(f, "Unbound identifier"),
CompileErrorKind::BreakOutsideLoop
=> write!(f, "Break outside loop"),
CompileErrorKind::Internal(details)
=> write!(f, "Internal compiler error: {}", details),
}
}
}
#[derive(Debug)]
pub struct CompileError {
location: ast::Location,
kind: CompileErrorKind,
}
impl fmt::Display for CompileError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}, at {}", self.kind, self.location)
}
}
impl Error for CompileError {}
type CompileResult<T> = Result<T, CompileError>;
pub struct CodeGen<'ctx> {
context: &'ctx Context,
module: Module<'ctx>,
pass_manager: passes::PassManager<values::FunctionValue<'ctx>>,
builder: Builder<'ctx>,
current_source_location: ast::Location,
namespace: HashMap<String, values::PointerValue<'ctx>>,
break_bb: Option<basic_block::BasicBlock<'ctx>>,
}
impl<'ctx> CodeGen<'ctx> {
pub fn new(context: &'ctx Context) -> CodeGen<'ctx> {
let module = context.create_module("kernel");
let pass_manager = passes::PassManager::create(&module);
pass_manager.add_instruction_combining_pass();
pass_manager.add_reassociate_pass();
pass_manager.add_gvn_pass();
pass_manager.add_cfg_simplification_pass();
pass_manager.add_basic_alias_analysis_pass();
pass_manager.add_promote_memory_to_register_pass();
pass_manager.add_instruction_combining_pass();
pass_manager.add_reassociate_pass();
pass_manager.initialize();
let i32_type = context.i32_type();
let fn_type = i32_type.fn_type(&[i32_type.into()], false);
module.add_function("output", fn_type, None);
CodeGen {
context, module, pass_manager,
builder: context.create_builder(),
current_source_location: ast::Location::default(),
namespace: HashMap::new(),
break_bb: None,
}
}
fn set_source_location(&mut self, location: ast::Location) {
self.current_source_location = location;
}
fn compile_error(&self, kind: CompileErrorKind) -> CompileError {
CompileError {
location: self.current_source_location,
kind
}
}
fn get_basic_type(&self, name: &str) -> CompileResult<types::BasicTypeEnum<'ctx>> {
match name {
"bool" => Ok(self.context.bool_type().into()),
"int32" => Ok(self.context.i32_type().into()),
"int64" => Ok(self.context.i64_type().into()),
"float32" => Ok(self.context.f32_type().into()),
"float64" => Ok(self.context.f64_type().into()),
_ => Err(self.compile_error(CompileErrorKind::UnknownTypeAnnotation))
}
}
fn compile_function_def(
&mut self,
name: &str,
args: &ast::Parameters,
body: &ast::Suite,
decorator_list: &[ast::Expression],
returns: &Option<ast::Expression>,
is_async: bool,
) -> CompileResult<values::FunctionValue<'ctx>> {
if is_async {
return Err(self.compile_error(CompileErrorKind::Unsupported("async functions")))
}
for decorator in decorator_list.iter() {
self.set_source_location(decorator.location);
if let ast::ExpressionType::Identifier { name } = &decorator.node {
if name != "kernel" && name != "portable" {
return Err(self.compile_error(CompileErrorKind::Unsupported("custom decorators")))
}
} else {
return Err(self.compile_error(CompileErrorKind::Unsupported("decorator must be an identifier")))
}
}
let args_type = args.args.iter().map(|val| {
self.set_source_location(val.location);
if let Some(annotation) = &val.annotation {
if let ast::ExpressionType::Identifier { name } = &annotation.node {
Ok(self.get_basic_type(&name)?)
} else {
Err(self.compile_error(CompileErrorKind::Unsupported("type annotation must be an identifier")))
}
} else {
Err(self.compile_error(CompileErrorKind::MissingTypeAnnotation))
}
}).collect::<CompileResult<Vec<types::BasicTypeEnum>>>()?;
let return_type = if let Some(returns) = returns {
self.set_source_location(returns.location);
if let ast::ExpressionType::Identifier { name } = &returns.node {
if name == "None" { None } else { Some(self.get_basic_type(name)?) }
} else {
return Err(self.compile_error(CompileErrorKind::Unsupported("type annotation must be an identifier")))
}
} else {
None
};
let fn_type = match return_type {
Some(ty) => ty.fn_type(&args_type, false),
None => self.context.void_type().fn_type(&args_type, false)
};
let function = self.module.add_function(name, fn_type, None);
let basic_block = self.context.append_basic_block(function, "entry");
self.builder.position_at_end(basic_block);
for (n, arg) in args.args.iter().enumerate() {
let param = function.get_nth_param(n as u32).unwrap();
let alloca = self.builder.build_alloca(param.get_type(), &arg.arg);
self.builder.build_store(alloca, param);
self.namespace.insert(arg.arg.clone(), alloca);
}
self.compile_suite(body, return_type)?;
Ok(function)
}
fn compile_expression(
&mut self,
expression: &ast::Expression
) -> CompileResult<values::BasicValueEnum<'ctx>> {
self.set_source_location(expression.location);
match &expression.node {
ast::ExpressionType::True => Ok(self.context.bool_type().const_int(1, false).into()),
ast::ExpressionType::False => Ok(self.context.bool_type().const_int(0, false).into()),
ast::ExpressionType::Number { value: ast::Number::Integer { value } } => {
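// Choose i32 or i64 from the bit length of the constant; negative values get one
// extra bit to account for the two's-complement sign.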
let mut bits = value.bits();
if value.sign() == num_bigint::Sign::Minus {
bits += 1;
}
match bits {
0..=32 => Ok(self.context.i32_type().const_int(value.to_i32().unwrap() as _, true).into()),
33..=64 => Ok(self.context.i64_type().const_int(value.to_i64().unwrap() as _, true).into()),
_ => Err(self.compile_error(CompileErrorKind::Unsupported("integers larger than 64 bits")))
}
},
ast::ExpressionType::Number { value: ast::Number::Float { value } } => {
Ok(self.context.f64_type().const_float(*value).into())
},
ast::ExpressionType::Identifier { name } => {
match self.namespace.get(name) {
Some(value) => Ok(self.builder.build_load(*value, name).into()),
None => Err(self.compile_error(CompileErrorKind::UnboundIdentifier))
}
},
ast::ExpressionType::Unop { op, a } => {
let a = self.compile_expression(&a)?;
match (op, a) {
(ast::UnaryOperator::Pos, values::BasicValueEnum::IntValue(a))
=> Ok(a.into()),
(ast::UnaryOperator::Pos, values::BasicValueEnum::FloatValue(a))
=> Ok(a.into()),
(ast::UnaryOperator::Neg, values::BasicValueEnum::IntValue(a))
=> Ok(self.builder.build_int_neg(a, "tmpneg").into()),
(ast::UnaryOperator::Neg, values::BasicValueEnum::FloatValue(a))
=> Ok(self.builder.build_float_neg(a, "tmpneg").into()),
(ast::UnaryOperator::Inv, values::BasicValueEnum::IntValue(a))
=> Ok(self.builder.build_not(a, "tmpnot").into()),
(ast::UnaryOperator::Not, values::BasicValueEnum::IntValue(a)) => {
// boolean "not"
if a.get_type().get_bit_width() != 1 {
Err(self.compile_error(CompileErrorKind::Unsupported("unimplemented unary operation")))
} else {
Ok(self.builder.build_not(a, "tmpnot").into())
}
},
_ => Err(self.compile_error(CompileErrorKind::Unsupported("unimplemented unary operation"))),
}
},
ast::ExpressionType::Binop { a, op, b } => {
let a = self.compile_expression(&a)?;
let b = self.compile_expression(&b)?;
if a.get_type() != b.get_type() {
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
}
use ast::Operator::*;
match (op, a, b) {
(Add, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b))
=> Ok(self.builder.build_int_add(a, b, "tmpadd").into()),
(Sub, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b))
=> Ok(self.builder.build_int_sub(a, b, "tmpsub").into()),
(Mult, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b))
=> Ok(self.builder.build_int_mul(a, b, "tmpmul").into()),
(Add, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b))
=> Ok(self.builder.build_float_add(a, b, "tmpadd").into()),
(Sub, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b))
=> Ok(self.builder.build_float_sub(a, b, "tmpsub").into()),
(Mult, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b))
=> Ok(self.builder.build_float_mul(a, b, "tmpmul").into()),
(Div, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b))
=> Ok(self.builder.build_float_div(a, b, "tmpdiv").into()),
(FloorDiv, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b))
=> Ok(self.builder.build_int_signed_div(a, b, "tmpdiv").into()),
_ => Err(self.compile_error(CompileErrorKind::Unsupported("unimplemented binary operation"))),
}
},
ast::ExpressionType::Compare { vals, ops } => {
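// Chained comparisons (e.g. a < b < c) are lowered to pairwise comparisons combined
// with logical AND.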
let mut vals = vals.iter();
let mut ops = ops.iter();
let mut result = None;
let mut a = self.compile_expression(vals.next().unwrap())?;
loop {
if let Some(op) = ops.next() {
let b = self.compile_expression(vals.next().unwrap())?;
if a.get_type() != b.get_type() {
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
}
let this_result = match (a, b) {
(values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b)) => {
match op {
ast::Comparison::Equal
=> self.builder.build_int_compare(IntPredicate::EQ, a, b, "tmpeq"),
ast::Comparison::NotEqual
=> self.builder.build_int_compare(IntPredicate::NE, a, b, "tmpne"),
ast::Comparison::Less
=> self.builder.build_int_compare(IntPredicate::SLT, a, b, "tmpslt"),
ast::Comparison::LessOrEqual
=> self.builder.build_int_compare(IntPredicate::SLE, a, b, "tmpsle"),
ast::Comparison::Greater
=> self.builder.build_int_compare(IntPredicate::SGT, a, b, "tmpsgt"),
ast::Comparison::GreaterOrEqual
=> self.builder.build_int_compare(IntPredicate::SGE, a, b, "tmpsge"),
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("special comparison"))),
}
},
(values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b)) => {
match op {
ast::Comparison::Equal
=> self.builder.build_float_compare(FloatPredicate::OEQ, a, b, "tmpoeq"),
ast::Comparison::NotEqual
=> self.builder.build_float_compare(FloatPredicate::UNE, a, b, "tmpune"),
ast::Comparison::Less
=> self.builder.build_float_compare(FloatPredicate::OLT, a, b, "tmpolt"),
ast::Comparison::LessOrEqual
=> self.builder.build_float_compare(FloatPredicate::OLE, a, b, "tmpole"),
ast::Comparison::Greater
=> self.builder.build_float_compare(FloatPredicate::OGT, a, b, "tmpogt"),
ast::Comparison::GreaterOrEqual
=> self.builder.build_float_compare(FloatPredicate::OGE, a, b, "tmpoge"),
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("special comparison"))),
}
},
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("comparison of non-numerical types"))),
};
match result {
Some(last) => {
result = Some(self.builder.build_and(last, this_result, "tmpand"));
}
None => {
result = Some(this_result);
}
}
a = b;
} else {
return Ok(result.unwrap().into())
}
}
},
ast::ExpressionType::Call { function, args, keywords } => {
if !keywords.is_empty() {
return Err(self.compile_error(CompileErrorKind::Unsupported("keyword arguments")))
}
let args = args.iter().map(|val| self.compile_expression(val))
.collect::<CompileResult<Vec<values::BasicValueEnum>>>()?;
self.set_source_location(expression.location);
if let ast::ExpressionType::Identifier { name } = &function.node {
match (name.as_str(), args[0]) {
("int32", values::BasicValueEnum::IntValue(a)) => {
let nbits = a.get_type().get_bit_width();
if nbits < 32 {
Ok(self.builder.build_int_s_extend(a, self.context.i32_type(), "tmpsext").into())
} else if nbits > 32 {
Ok(self.builder.build_int_truncate(a, self.context.i32_type(), "tmptrunc").into())
} else {
Ok(a.into())
}
},
("int64", values::BasicValueEnum::IntValue(a)) => {
let nbits = a.get_type().get_bit_width();
if nbits < 64 {
Ok(self.builder.build_int_s_extend(a, self.context.i64_type(), "tmpsext").into())
} else {
Ok(a.into())
}
},
("int32", values::BasicValueEnum::FloatValue(a)) => {
Ok(self.builder.build_float_to_signed_int(a, self.context.i32_type(), "tmpfptosi").into())
},
("int64", values::BasicValueEnum::FloatValue(a)) => {
Ok(self.builder.build_float_to_signed_int(a, self.context.i64_type(), "tmpfptosi").into())
},
("float32", values::BasicValueEnum::IntValue(a)) => {
Ok(self.builder.build_signed_int_to_float(a, self.context.f32_type(), "tmpsitofp").into())
},
("float64", values::BasicValueEnum::IntValue(a)) => {
Ok(self.builder.build_signed_int_to_float(a, self.context.f64_type(), "tmpsitofp").into())
},
("float32", values::BasicValueEnum::FloatValue(a)) => {
if a.get_type() == self.context.f64_type() {
Ok(self.builder.build_float_trunc(a, self.context.f32_type(), "tmptrunc").into())
} else {
Ok(a.into())
}
},
("float64", values::BasicValueEnum::FloatValue(a)) => {
if a.get_type() == self.context.f32_type() {
Ok(self.builder.build_float_ext(a, self.context.f64_type(), "tmpext").into())
} else {
Ok(a.into())
}
},
("output", values::BasicValueEnum::IntValue(a)) => {
let fn_value = self.module.get_function("output").unwrap();
Ok(self.builder.build_call(fn_value, &[a.into()], "call")
.try_as_basic_value().left().unwrap())
},
_ => Err(self.compile_error(CompileErrorKind::Unsupported("unrecognized call")))
}
} else {
return Err(self.compile_error(CompileErrorKind::Unsupported("function must be an identifier")))
}
},
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("unimplemented expression"))),
}
}
fn compile_statement(
&mut self,
statement: &ast::Statement,
return_type: Option<types::BasicTypeEnum>
) -> CompileResult<()> {
self.set_source_location(statement.location);
use ast::StatementType::*;
match &statement.node {
Assign { targets, value } => {
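// Evaluate the right-hand side once, then store it into every target, allocating a
// stack slot for a name on its first assignment.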
let value = self.compile_expression(value)?;
for target in targets.iter() {
self.set_source_location(target.location);
if let ast::ExpressionType::Identifier { name } = &target.node {
let builder = &self.builder;
let target = self.namespace.entry(name.clone()).or_insert_with(
|| builder.build_alloca(value.get_type(), name));
if target.get_type() != value.get_type().ptr_type(inkwell::AddressSpace::Generic) {
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
}
builder.build_store(*target, value);
} else {
return Err(self.compile_error(CompileErrorKind::Unsupported("assignment target must be an identifier")))
}
}
},
Expression { expression } => { self.compile_expression(expression)?; },
If { test, body, orelse } => {
let test = self.compile_expression(test)?;
if test.get_type() != self.context.bool_type().into() {
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
}
let parent = self.builder.get_insert_block().unwrap().get_parent().unwrap();
let then_bb = self.context.append_basic_block(parent, "then");
let else_bb = self.context.append_basic_block(parent, "else");
let cont_bb = self.context.append_basic_block(parent, "ifcont");
self.builder.build_conditional_branch(test.into_int_value(), then_bb, else_bb);
self.builder.position_at_end(then_bb);
self.compile_suite(body, return_type)?;
self.builder.build_unconditional_branch(cont_bb);
self.builder.position_at_end(else_bb);
if let Some(orelse) = orelse {
self.compile_suite(orelse, return_type)?;
}
self.builder.build_unconditional_branch(cont_bb);
self.builder.position_at_end(cont_bb);
},
While { test, body, orelse } => {
let parent = self.builder.get_insert_block().unwrap().get_parent().unwrap();
let test_bb = self.context.append_basic_block(parent, "test");
self.builder.build_unconditional_branch(test_bb);
self.builder.position_at_end(test_bb);
let test = self.compile_expression(test)?;
if test.get_type() != self.context.bool_type().into() {
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
}
let then_bb = self.context.append_basic_block(parent, "then");
let else_bb = self.context.append_basic_block(parent, "else");
let cont_bb = self.context.append_basic_block(parent, "ifcont");
self.builder.build_conditional_branch(test.into_int_value(), then_bb, else_bb);
self.break_bb = Some(cont_bb);
self.builder.position_at_end(then_bb);
self.compile_suite(body, return_type)?;
self.builder.build_unconditional_branch(test_bb);
self.builder.position_at_end(else_bb);
if let Some(orelse) = orelse {
self.compile_suite(orelse, return_type)?;
}
self.builder.build_unconditional_branch(cont_bb);
self.builder.position_at_end(cont_bb);
self.break_bb = None;
},
Break => {
if let Some(bb) = self.break_bb {
self.builder.build_unconditional_branch(bb);
let parent = self.builder.get_insert_block().unwrap().get_parent().unwrap();
let unreachable_bb = self.context.append_basic_block(parent, "unreachable");
self.builder.position_at_end(unreachable_bb);
} else {
return Err(self.compile_error(CompileErrorKind::BreakOutsideLoop));
}
}
Return { value: Some(value) } => {
if let Some(return_type) = return_type {
let value = self.compile_expression(value)?;
if value.get_type() != return_type {
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
}
self.builder.build_return(Some(&value));
} else {
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
}
},
Return { value: None } => {
if !return_type.is_none() {
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
}
self.builder.build_return(None);
},
Pass => (),
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("special statement"))),
}
Ok(())
}
fn compile_suite(
&mut self,
suite: &ast::Suite,
return_type: Option<types::BasicTypeEnum>
) -> CompileResult<()> {
for statement in suite.iter() {
self.compile_statement(statement, return_type)?;
}
Ok(())
}
pub fn compile_toplevel(&mut self, statement: &ast::Statement) -> CompileResult<()> {
self.set_source_location(statement.location);
if let ast::StatementType::FunctionDef {
is_async,
name,
args,
body,
decorator_list,
returns,
} = &statement.node {
let function = self.compile_function_def(name, args, body, decorator_list, returns, *is_async)?;
self.pass_manager.run_on(&function);
Ok(())
} else {
Err(self.compile_error(CompileErrorKind::Internal("top-level is not a function definition")))
}
}
pub fn print_ir(&self) {
self.module.print_to_stderr();
}
pub fn output(&self, filename: &str) {
//let triple = TargetTriple::create("riscv32-none-linux-gnu");
let triple = TargetMachine::get_default_triple();
let target = Target::from_triple(&triple)
.expect("couldn't create target from target triple");
let target_machine = target
.create_target_machine(
&triple,
"",
"",
OptimizationLevel::Default,
RelocMode::Default,
CodeModel::Default,
)
.expect("couldn't create target machine");
target_machine
.write_to_file(&self.module, FileType::Object, Path::new(filename))
.expect("couldn't write module to file");
}
}
mod codegen;
mod location;
mod symbol_resolver;
mod top_level;
mod typecheck;

31 nac3core/src/location.rs Normal file

@ -0,0 +1,31 @@
use rustpython_parser::ast;
use std::vec::Vec;
#[derive(Clone, Copy, PartialEq)]
pub struct FileID(u32);
#[derive(Clone, Copy, PartialEq)]
pub enum Location {
CodeRange(FileID, ast::Location),
Builtin,
}
pub struct FileRegistry {
files: Vec<String>,
}
impl FileRegistry {
pub fn new() -> FileRegistry {
FileRegistry { files: Vec::new() }
}
pub fn add_file(&mut self, path: &str) -> FileID {
let index = self.files.len() as u32;
self.files.push(path.to_owned());
FileID(index)
}
pub fn query_file(&self, id: FileID) -> &str {
&self.files[id.0 as usize]
}
}
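// Illustrative sketch, not part of this commit: minimal usage of the registry above.
// The file path is made up for the example.
#[cfg(test)]
mod registry_example {
    use super::*;

    #[test]
    fn register_and_query() {
        let mut registry = FileRegistry::new();
        let id = registry.add_file("kernel.py");
        assert_eq!(registry.query_file(id), "kernel.py");
    }
}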


@ -1,58 +0,0 @@
use rustpython_parser::ast::{Comparison, Operator, UnaryOperator};
pub fn binop_name(op: &Operator) -> &'static str {
match op {
Operator::Add => "__add__",
Operator::Sub => "__sub__",
Operator::Div => "__truediv__",
Operator::Mod => "__mod__",
Operator::Mult => "__mul__",
Operator::Pow => "__pow__",
Operator::BitOr => "__or__",
Operator::BitXor => "__xor__",
Operator::BitAnd => "__and__",
Operator::LShift => "__lshift__",
Operator::RShift => "__rshift__",
Operator::FloorDiv => "__floordiv__",
Operator::MatMult => "__matmul__",
}
}
pub fn binop_assign_name(op: &Operator) -> &'static str {
match op {
Operator::Add => "__iadd__",
Operator::Sub => "__isub__",
Operator::Div => "__itruediv__",
Operator::Mod => "__imod__",
Operator::Mult => "__imul__",
Operator::Pow => "__ipow__",
Operator::BitOr => "__ior__",
Operator::BitXor => "__ixor__",
Operator::BitAnd => "__iand__",
Operator::LShift => "__ilshift__",
Operator::RShift => "__irshift__",
Operator::FloorDiv => "__ifloordiv__",
Operator::MatMult => "__imatmul__",
}
}
pub fn unaryop_name(op: &UnaryOperator) -> &'static str {
match op {
UnaryOperator::Pos => "__pos__",
UnaryOperator::Neg => "__neg__",
UnaryOperator::Not => "__not__",
UnaryOperator::Inv => "__inv__",
}
}
pub fn comparison_name(op: &Comparison) -> Option<&'static str> {
match op {
Comparison::Less => Some("__lt__"),
Comparison::LessOrEqual => Some("__le__"),
Comparison::Greater => Some("__gt__"),
Comparison::GreaterOrEqual => Some("__ge__"),
Comparison::Equal => Some("__eq__"),
Comparison::NotEqual => Some("__ne__"),
_ => None,
}
}


@ -1,184 +0,0 @@
use super::typedef::{TypeEnum::*, *};
use crate::context::*;
use std::collections::HashMap;
pub const TUPLE_TYPE: ParamId = ParamId(0);
pub const LIST_TYPE: ParamId = ParamId(1);
pub const BOOL_TYPE: PrimitiveId = PrimitiveId(0);
pub const INT32_TYPE: PrimitiveId = PrimitiveId(1);
pub const INT64_TYPE: PrimitiveId = PrimitiveId(2);
pub const FLOAT_TYPE: PrimitiveId = PrimitiveId(3);
fn impl_math(def: &mut TypeDef, ty: &Type) {
let result = Some(ty.clone());
let fun = FnDef {
args: vec![ty.clone()],
result: result.clone(),
};
def.methods.insert("__add__", fun.clone());
def.methods.insert("__sub__", fun.clone());
def.methods.insert("__mul__", fun.clone());
def.methods.insert(
"__neg__",
FnDef {
args: vec![],
result,
},
);
def.methods.insert(
"__truediv__",
FnDef {
args: vec![ty.clone()],
result: Some(PrimitiveType(FLOAT_TYPE).into()),
},
);
def.methods.insert("__floordiv__", fun.clone());
def.methods.insert("__mod__", fun.clone());
def.methods.insert("__pow__", fun);
}
fn impl_bits(def: &mut TypeDef, ty: &Type) {
let result = Some(ty.clone());
let fun = FnDef {
args: vec![PrimitiveType(INT32_TYPE).into()],
result,
};
def.methods.insert("__lshift__", fun.clone());
def.methods.insert("__rshift__", fun);
def.methods.insert(
"__xor__",
FnDef {
args: vec![ty.clone()],
result: Some(ty.clone()),
},
);
}
fn impl_eq(def: &mut TypeDef, ty: &Type) {
let fun = FnDef {
args: vec![ty.clone()],
result: Some(PrimitiveType(BOOL_TYPE).into()),
};
def.methods.insert("__eq__", fun.clone());
def.methods.insert("__ne__", fun);
}
fn impl_order(def: &mut TypeDef, ty: &Type) {
let fun = FnDef {
args: vec![ty.clone()],
result: Some(PrimitiveType(BOOL_TYPE).into()),
};
def.methods.insert("__lt__", fun.clone());
def.methods.insert("__gt__", fun.clone());
def.methods.insert("__le__", fun.clone());
def.methods.insert("__ge__", fun);
}
pub fn basic_ctx() -> TopLevelContext<'static> {
let primitives = [
TypeDef {
name: "bool",
fields: HashMap::new(),
methods: HashMap::new(),
},
TypeDef {
name: "int32",
fields: HashMap::new(),
methods: HashMap::new(),
},
TypeDef {
name: "int64",
fields: HashMap::new(),
methods: HashMap::new(),
},
TypeDef {
name: "float",
fields: HashMap::new(),
methods: HashMap::new(),
},
]
.to_vec();
let mut ctx = TopLevelContext::new(primitives);
let b = ctx.get_primitive(BOOL_TYPE);
let b_def = ctx.get_primitive_def_mut(BOOL_TYPE);
impl_eq(b_def, &b);
let int32 = ctx.get_primitive(INT32_TYPE);
let int32_def = ctx.get_primitive_def_mut(INT32_TYPE);
impl_math(int32_def, &int32);
impl_bits(int32_def, &int32);
impl_order(int32_def, &int32);
impl_eq(int32_def, &int32);
let int64 = ctx.get_primitive(INT64_TYPE);
let int64_def = ctx.get_primitive_def_mut(INT64_TYPE);
impl_math(int64_def, &int64);
impl_bits(int64_def, &int64);
impl_order(int64_def, &int64);
impl_eq(int64_def, &int64);
let float = ctx.get_primitive(FLOAT_TYPE);
let float_def = ctx.get_primitive_def_mut(FLOAT_TYPE);
impl_math(float_def, &float);
impl_order(float_def, &float);
impl_eq(float_def, &float);
let t = ctx.add_variable_private(VarDef {
name: "T",
bound: vec![],
});
ctx.add_parametric(ParametricDef {
base: TypeDef {
name: "tuple",
fields: HashMap::new(),
methods: HashMap::new(),
},
// we have nothing for tuple, so no param def
params: vec![],
});
ctx.add_parametric(ParametricDef {
base: TypeDef {
name: "list",
fields: HashMap::new(),
methods: HashMap::new(),
},
params: vec![t],
});
let i = ctx.add_variable_private(VarDef {
name: "I",
bound: vec![
PrimitiveType(INT32_TYPE).into(),
PrimitiveType(INT64_TYPE).into(),
PrimitiveType(FLOAT_TYPE).into(),
],
});
let args = vec![TypeVariable(i).into()];
ctx.add_fn(
"int32",
FnDef {
args: args.clone(),
result: Some(PrimitiveType(INT32_TYPE).into()),
},
);
ctx.add_fn(
"int64",
FnDef {
args: args.clone(),
result: Some(PrimitiveType(INT64_TYPE).into()),
},
);
ctx.add_fn(
"float",
FnDef {
args,
result: Some(PrimitiveType(FLOAT_TYPE).into()),
},
);
ctx
}


@ -0,0 +1,174 @@
use std::cell::RefCell;
use std::collections::HashMap;
use crate::top_level::{DefinitionId, TopLevelContext, TopLevelDef};
use crate::typecheck::{
type_inferencer::PrimitiveStore,
typedef::{Type, Unifier},
};
use crate::{location::Location, typecheck::typedef::TypeEnum};
use itertools::{chain, izip};
use rustpython_parser::ast::Expr;
#[derive(Clone, PartialEq)]
pub enum SymbolValue {
I32(i32),
I64(i64),
Double(f64),
Bool(bool),
Tuple(Vec<SymbolValue>),
// we should think about how to implement bytes later...
// Bytes(&'a [u8]),
}
pub trait SymbolResolver {
// get the type of a type variable identifier or of a top-level function
fn get_symbol_type(
&self,
unifier: &mut Unifier,
primitives: &PrimitiveStore,
str: &str,
) -> Option<Type>;
// get the top-level definition of identifiers
fn get_identifier_def(&self, str: &str) -> Option<DefinitionId>;
fn get_symbol_value(&self, str: &str) -> Option<SymbolValue>;
fn get_symbol_location(&self, str: &str) -> Option<Location>;
// handle function call etc.
}
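// Illustrative sketch, not part of this commit: a trivial resolver that only knows a
// fixed set of int32-typed globals. `StubResolver` and its field are made-up names;
// a real resolver would be backed by the host interpreter.
pub struct StubResolver {
    pub int32_globals: Vec<String>,
}

impl SymbolResolver for StubResolver {
    fn get_symbol_type(
        &self,
        _unifier: &mut Unifier,
        primitives: &PrimitiveStore,
        str: &str,
    ) -> Option<Type> {
        if self.int32_globals.iter().any(|s| s == str) {
            Some(primitives.int32)
        } else {
            None
        }
    }

    fn get_identifier_def(&self, _str: &str) -> Option<DefinitionId> {
        None
    }

    fn get_symbol_value(&self, _str: &str) -> Option<SymbolValue> {
        None
    }

    fn get_symbol_location(&self, _str: &str) -> Option<Location> {
        None
    }
}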
// convert type annotation into type
pub fn parse_type_annotation<T>(
resolver: &dyn SymbolResolver,
top_level: &TopLevelContext,
unifier: &mut Unifier,
primitives: &PrimitiveStore,
expr: &Expr<T>,
) -> Result<Type, String> {
use rustpython_parser::ast::ExprKind::*;
match &expr.node {
Name { id, .. } => match id.as_str() {
"int32" => Ok(primitives.int32),
"int64" => Ok(primitives.int64),
"float" => Ok(primitives.float),
"bool" => Ok(primitives.bool),
"None" => Ok(primitives.none),
x => {
let obj_id = resolver.get_identifier_def(x);
if let Some(obj_id) = obj_id {
let defs = top_level.definitions.read();
let def = defs[obj_id.0].read();
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
if !type_vars.is_empty() {
return Err(format!(
"Unexpected number of type parameters: expected {} but got 0",
type_vars.len()
));
}
let fields = RefCell::new(
chain(
fields.iter().map(|(k, v)| (k.clone(), *v)),
methods.iter().map(|(k, v, _)| (k.clone(), *v)),
)
.collect(),
);
Ok(unifier.add_ty(TypeEnum::TObj {
obj_id,
fields,
params: Default::default(),
}))
} else {
Err("Cannot use function name as type".into())
}
} else {
// it could be a type variable
let ty = resolver
.get_symbol_type(unifier, primitives, x)
.ok_or_else(|| "Cannot use function name as type".to_owned())?;
if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) {
Ok(ty)
} else {
Err(format!("Unknown type annotation {}", x))
}
}
}
},
Subscript { value, slice, .. } => {
if let Name { id, .. } = &value.node {
if id == "virtual" {
let ty =
parse_type_annotation(resolver, top_level, unifier, primitives, slice)?;
Ok(unifier.add_ty(TypeEnum::TVirtual { ty }))
} else {
let types = if let Tuple { elts, .. } = &slice.node {
elts.iter()
.map(|v| {
parse_type_annotation(resolver, top_level, unifier, primitives, v)
})
.collect::<Result<Vec<_>, _>>()?
} else {
vec![parse_type_annotation(
resolver, top_level, unifier, primitives, slice,
)?]
};
let obj_id = resolver
.get_identifier_def(id)
.ok_or_else(|| format!("Unknown type annotation {}", id))?;
let defs = top_level.definitions.read();
let def = defs[obj_id.0].read();
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
if types.len() != type_vars.len() {
return Err(format!(
"Unexpected number of type parameters: expected {} but got {}",
type_vars.len(),
types.len()
));
}
let mut subst = HashMap::new();
for (var, ty) in izip!(type_vars.iter(), types.iter()) {
let id = if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) {
*id
} else {
unreachable!()
};
subst.insert(id, *ty);
}
let mut fields = fields
.iter()
.map(|(attr, ty)| {
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(attr.clone(), ty)
})
.collect::<HashMap<_, _>>();
fields.extend(methods.iter().map(|(attr, ty, _)| {
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(attr.clone(), ty)
}));
Ok(unifier.add_ty(TypeEnum::TObj {
obj_id,
fields: fields.into(),
params: subst.into(),
}))
} else {
Err("Cannot use function name as type".into())
}
}
} else {
Err("unsupported type expression".into())
}
}
_ => Err("unsupported type expression".into()),
}
}
impl dyn SymbolResolver + Send + Sync {
pub fn parse_type_annotation<T>(
&self,
top_level: &TopLevelContext,
unifier: &mut Unifier,
primitives: &PrimitiveStore,
expr: &Expr<T>,
) -> Result<Type, String> {
parse_type_annotation(self, top_level, unifier, primitives, expr)
}
}

778 nac3core/src/top_level.rs Normal file
View File

@ -0,0 +1,778 @@
use std::borrow::BorrowMut;
use std::ops::{Deref, DerefMut};
use std::{collections::HashMap, collections::HashSet, sync::Arc};
use super::typecheck::type_inferencer::PrimitiveStore;
use super::typecheck::typedef::{SharedUnifier, Type, TypeEnum, Unifier};
use crate::typecheck::typedef::{FunSignature, FuncArg};
use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Mapping};
use itertools::Itertools;
use parking_lot::{Mutex, RwLock};
use rustpython_parser::ast::{self, Stmt};
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)]
pub struct DefinitionId(pub usize);
pub enum TopLevelDef {
Class {
// object ID used for TypeEnum
object_id: DefinitionId,
        // type variables bound to the class.
type_vars: Vec<Type>,
// class fields
fields: Vec<(String, Type)>,
// class methods, pointing to the corresponding function definition.
methods: Vec<(String, Type, DefinitionId)>,
// ancestor classes, including itself.
ancestors: Vec<DefinitionId>,
        // symbol resolver of the module that defined the class; None if it is a built-in type
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>,
},
Function {
        // prefix for the symbol; should be globally unique and must not end with a number
name: String,
// function signature.
signature: Type,
/// Function instance to symbol mapping
/// Key: string representation of type variable values, sorted by variable ID in ascending
/// order, including type variables associated with the class.
/// Value: function symbol name.
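    /// Illustrative example (the exact format depends on the unifier's stringification): a
    /// method of `A[T]` instantiated with `T = int32` and a function-level `V = bool` could
    /// be keyed as `"int32, bool"`.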
instance_to_symbol: HashMap<String, String>,
/// Function instances to annotated AST mapping
/// Key: string representation of type variable values, sorted by variable ID in ascending
/// order, including type variables associated with the class. Excluding rigid type
/// variables.
/// Value: AST annotated with types together with a unification table index. Could contain
/// rigid type variables that would be substituted when the function is instantiated.
instance_to_stmt: HashMap<String, (Stmt<Option<Type>>, usize)>,
        // symbol resolver of the module that defined the function
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>,
},
Initializer {
class_id: DefinitionId,
},
}
impl TopLevelDef {
fn get_function_type(&self) -> Result<Type, String> {
if let Self::Function { signature, .. } = self {
Ok(*signature)
} else {
Err("only expect function def here".into())
}
}
}
pub struct TopLevelContext {
pub definitions: Arc<RwLock<Vec<Arc<RwLock<TopLevelDef>>>>>,
pub unifiers: Arc<RwLock<Vec<(SharedUnifier, PrimitiveStore)>>>,
}
pub struct TopLevelComposer {
// list of top level definitions, same as top level context
pub definition_ast_list: Arc<RwLock<Vec<(Arc<RwLock<TopLevelDef>>, Option<ast::Stmt<()>>)>>>,
    // starts with only the primitive types; more top-level defs are added during analysis
pub unifier: Unifier,
// primitive store
pub primitives: PrimitiveStore,
// mangled class method name to def_id
pub class_method_to_def_id: HashMap<String, DefinitionId>,
    // record the def ids of the classes whose fields and methods are still to be analyzed
pub to_be_analyzed_class: Vec<DefinitionId>,
}
impl TopLevelComposer {
pub fn to_top_level_context(&self) -> TopLevelContext {
let def_list =
self.definition_ast_list.read().iter().map(|(x, _)| x.clone()).collect::<Vec<_>>();
TopLevelContext {
definitions: RwLock::new(def_list).into(),
            // FIXME: return the single shared unifier here, or one per module?
unifiers: Default::default(),
}
}
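    // Illustrative note: method symbols are mangled by plain concatenation, so method `foo`
    // of class `A` becomes "Afoo"; the same scheme is used below when looking methods up in
    // `class_method_to_def_id`.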
fn name_mangling(mut class_name: String, method_name: &str) -> String {
class_name.push_str(method_name);
class_name
}
pub fn make_primitives() -> (PrimitiveStore, Unifier) {
let mut unifier = Unifier::new();
let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(0),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(1),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let float = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(2),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let bool = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(3),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let none = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(4),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let primitives = PrimitiveStore { int32, int64, float, bool, none };
crate::typecheck::magic_methods::set_primitives_magic_methods(&primitives, &mut unifier);
(primitives, unifier)
}
    /// Return a composer together with the information needed to build a "primitive" symbol
    /// resolver, so that the resolver can later map primitive type names to their definitions.
pub fn new() -> (Vec<(String, DefinitionId, Type)>, Self) {
let primitives = Self::make_primitives();
let top_level_def_list = vec![
Arc::new(RwLock::new(Self::make_top_level_class_def(0, None))),
Arc::new(RwLock::new(Self::make_top_level_class_def(1, None))),
Arc::new(RwLock::new(Self::make_top_level_class_def(2, None))),
Arc::new(RwLock::new(Self::make_top_level_class_def(3, None))),
Arc::new(RwLock::new(Self::make_top_level_class_def(4, None))),
];
let ast_list: Vec<Option<ast::Stmt<()>>> = vec![None, None, None, None, None];
let composer = TopLevelComposer {
definition_ast_list: RwLock::new(
top_level_def_list.into_iter().zip(ast_list).collect_vec(),
)
.into(),
primitives: primitives.0,
unifier: primitives.1,
class_method_to_def_id: Default::default(),
to_be_analyzed_class: Default::default(),
};
(
vec![
("int32".into(), DefinitionId(0), composer.primitives.int32),
("int64".into(), DefinitionId(1), composer.primitives.int64),
("float".into(), DefinitionId(2), composer.primitives.float),
("bool".into(), DefinitionId(3), composer.primitives.bool),
("none".into(), DefinitionId(4), composer.primitives.none),
],
composer,
)
}
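    // Hedged usage sketch (identifiers are illustrative): a module-level resolver can seed
    // its tables from the returned list, e.g.
    //   let (builtins, mut composer) = TopLevelComposer::new();
    //   for (name, def_id, ty) in &builtins { /* map name -> (def_id, ty) in the resolver */ }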
    /// The returned def already includes its own definition_id in the ancestors vector;
    /// when first registered, type_vars, fields and methods are still empty placeholders.
pub fn make_top_level_class_def(
index: usize,
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>,
) -> TopLevelDef {
TopLevelDef::Class {
object_id: DefinitionId(index),
type_vars: Default::default(),
fields: Default::default(),
methods: Default::default(),
ancestors: vec![DefinitionId(index)],
resolver,
}
}
    /// When first registered, the signature type is a placeholder (invalid) value.
pub fn make_top_level_function_def(
name: String,
ty: Type,
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>,
) -> TopLevelDef {
TopLevelDef::Function {
name,
signature: ty,
instance_to_symbol: Default::default(),
instance_to_stmt: Default::default(),
resolver,
}
}
    /// Step 0: registration; just remember the names of top-level classes/functions.
pub fn register_top_level(
&mut self,
ast: ast::Stmt<()>,
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>,
) -> Result<(String, DefinitionId), String> {
let mut def_list = self.definition_ast_list.write();
match &ast.node {
ast::StmtKind::ClassDef { name, body, .. } => {
let class_name = name.to_string();
let class_def_id = def_list.len();
                // add the class to the definition list;
                // the ast is still needed below when registering the class methods,
                // so push None temporarily and move the ast in afterwards
let mut class_def_ast = (
Arc::new(RwLock::new(Self::make_top_level_class_def(
class_def_id,
resolver.clone(),
))),
None,
);
                // parse the class body and register the class methods into the def list;
                // the module's symbol resolver does not know the names of class methods,
                // so it cannot return their definition_id
let mut class_method_name_def_ids: Vec<(
String,
Arc<RwLock<TopLevelDef>>,
DefinitionId,
)> = Vec::new();
let mut class_method_index_offset = 0;
for b in body {
if let ast::StmtKind::FunctionDef { name: method_name, .. } = &b.node {
let method_name = Self::name_mangling(class_name.clone(), method_name);
let method_def_id = def_list.len() + {
class_method_index_offset += 1;
class_method_index_offset
};
                        // register a dummy method definition here;
                        // the ast of a class method stays inside the class ast, so push None into the list
class_method_name_def_ids.push((
method_name.clone(),
RwLock::new(Self::make_top_level_function_def(
method_name.clone(),
self.primitives.none,
resolver.clone(),
))
.into(),
DefinitionId(method_def_id),
));
}
}
// move the ast to the entry of the class in the ast_list
class_def_ast.1 = Some(ast);
                // now class_def_ast and class_method_name_def_ids are ready; push them into the actual def list in order
def_list.push(class_def_ast);
for (name, def, id) in class_method_name_def_ids {
def_list.push((def, None));
self.class_method_to_def_id.insert(name, id);
}
// put the constructor into the def_list
def_list.push((
RwLock::new(TopLevelDef::Initializer { class_id: DefinitionId(class_def_id) })
.into(),
None,
));
                // this is a class, so record its def_id in the to-be-analyzed list
self.to_be_analyzed_class.push(DefinitionId(class_def_id));
Ok((class_name, DefinitionId(class_def_id)))
}
ast::StmtKind::FunctionDef { name, .. } => {
let fun_name = name.to_string();
// add to the definition list
def_list.push((
RwLock::new(Self::make_top_level_function_def(
name.into(),
self.primitives.none,
resolver,
))
.into(),
Some(ast),
));
// return
Ok((fun_name, DefinitionId(def_list.len() - 1)))
}
_ => Err("only registrations of top level classes/functions are supprted".into()),
}
}
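    // Hedged illustration of the def list layout after registering `class A` with methods
    // `foo` and `bar`, where n is the index the class receives:
    //   def_list[n]     = class A                      (class ast stored alongside)
    //   def_list[n + 1] = function "Afoo"              (ast stays inside the class ast)
    //   def_list[n + 2] = function "Abar"
    //   def_list[n + 3] = Initializer { class_id: DefinitionId(n) }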
    /// Step 1: analyze the type vars associated with top-level classes.
fn analyze_top_level_class_type_var(&mut self) -> Result<(), String> {
        // build the read-only context before taking the write lock; otherwise
        // to_top_level_context() would try to read the same RwLock and deadlock
        let converted_top_level = &self.to_top_level_context();
        let mut def_list = self.definition_ast_list.write();
let primitives = &self.primitives;
let unifier = &mut self.unifier;
for (class_def, class_ast) in def_list.iter_mut() {
// only deal with class def here
let mut class_def = class_def.write();
let (class_bases_ast, class_def_type_vars, class_resolver) = {
if let TopLevelDef::Class { type_vars, resolver, .. } = class_def.deref_mut() {
if let Some(ast::Located {
node: ast::StmtKind::ClassDef { bases, .. }, ..
}) = class_ast
{
(bases, type_vars, resolver)
} else {
unreachable!("must be both class")
}
} else {
continue;
}
};
let class_resolver = class_resolver.as_ref().unwrap().lock();
let mut is_generic = false;
for b in class_bases_ast {
match &b.node {
                    // analyze type vars bound to the class;
                    // only support things like `class A(Generic[T, V])`,
                    // things like `class A(Generic[T, V, ImportedModule.T])` are not supported,
                    // i.e. only simple names are allowed in the subscript;
                    // should update TopLevelDef::Class.type_vars and TypeEnum::TObj.params
ast::ExprKind::Subscript { value, slice, .. } if matches!(&value.node, ast::ExprKind::Name { id, .. } if id == "Generic") =>
{
if !is_generic {
is_generic = true;
} else {
return Err("Only single Generic[...] can be in bases".into());
}
// if `class A(Generic[T, V, G])`
if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
// parse the type vars
let type_vars = elts
.iter()
.map(|e| {
class_resolver.parse_type_annotation(
converted_top_level,
unifier.borrow_mut(),
primitives,
e,
)
})
.collect::<Result<Vec<_>, _>>()?;
// check if all are unique type vars
                            let mut occurred_type_var_id: HashSet<u32> = HashSet::new();
                            let all_unique_type_var = type_vars.iter().all(|x| {
                                let ty = unifier.get_ty(*x);
                                if let TypeEnum::TVar { id, .. } = ty.as_ref() {
                                    occurred_type_var_id.insert(*id)
} else {
false
}
});
if !all_unique_type_var {
return Err("expect unique type variables".into());
}
// add to TopLevelDef
class_def_type_vars.extend(type_vars);
// `class A(Generic[T])`
} else {
let ty = class_resolver.parse_type_annotation(
converted_top_level,
unifier.borrow_mut(),
primitives,
&slice,
)?;
// check if it is type var
let is_type_var =
matches!(unifier.get_ty(ty).as_ref(), &TypeEnum::TVar { .. });
if !is_type_var {
return Err("expect type variable here".into());
}
// add to TopLevelDef
class_def_type_vars.push(ty);
}
}
// if others, do nothing in this function
_ => continue,
}
}
}
Ok(())
}
    /// Step 2: base classes. Steps 1 and 2 must be separate for this reason: given
    /// `class B(Generic[T, V]);
    /// class A(B[int, bool])`,
    /// if the type vars associated with class `B` have not been handled properly,
    /// parsing the type annotation `B[int, bool]` will fail.
fn analyze_top_level_class_bases(&mut self) -> Result<(), String> {
        // as in step 1, build the context before taking the write lock to avoid a deadlock
        let converted_top_level = &self.to_top_level_context();
        let mut def_list = self.definition_ast_list.write();
let primitives = &self.primitives;
let unifier = &mut self.unifier;
for (class_def, class_ast) in def_list.iter_mut() {
let mut class_def = class_def.write();
let (class_bases, class_ancestors, class_resolver) = {
if let TopLevelDef::Class { ancestors, resolver, .. } = class_def.deref_mut() {
if let Some(ast::Located {
node: ast::StmtKind::ClassDef { bases, .. }, ..
}) = class_ast
{
(bases, ancestors, resolver)
} else {
unreachable!("must be both class")
}
} else {
continue;
}
};
let class_resolver = class_resolver.as_ref().unwrap().lock();
for b in class_bases {
                // type vars have already been handled, so skip `Generic[...]` here
if let ast::ExprKind::Subscript { value, .. } = &b.node {
if let ast::ExprKind::Name { id, .. } = &value.node {
if id == "Generic" {
continue;
}
}
}
// get the def id of the base class
let base_ty = class_resolver.parse_type_annotation(
converted_top_level,
unifier.borrow_mut(),
primitives,
b,
)?;
let base_id =
if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(base_ty).as_ref() {
*obj_id
} else {
return Err("expect concrete class/type to be base class".into());
};
                // record the base class in ancestors, making sure it is unique
if !class_ancestors.contains(&base_id) {
class_ancestors.push(base_id);
} else {
return Err("cannot specify the same base class twice".into());
}
}
}
Ok(())
}
    /// Step 3: class fields and methods.
// FIXME: analyze base classes here
// FIXME: deal with self type
    // NOTE: cycle prevention is only roughly done
fn analyze_top_level_class_fields_methods(&mut self) -> Result<(), String> {
        // as in the previous steps, build the context before taking the write lock
        let converted_top_level = &self.to_top_level_context();
        let mut def_ast_list = self.definition_ast_list.write();
let primitives = &self.primitives;
let to_be_analyzed_class = &mut self.to_be_analyzed_class;
let unifier = &mut self.unifier;
// NOTE: roughly prevent infinite loop
let mut max_iter = to_be_analyzed_class.len() * 4;
'class: loop {
            // stop when nothing is left to analyze, or when the iteration bound is exhausted
            if to_be_analyzed_class.is_empty() || {
                max_iter -= 1;
                max_iter == 0
            } {
                break;
            }
let class_ind = to_be_analyzed_class.remove(0).0;
let (class_name, class_body_ast, class_bases_ast, class_resolver, class_ancestors) = {
let (class_def, class_ast) = &mut def_ast_list[class_ind];
if let Some(ast::Located {
node: ast::StmtKind::ClassDef { name, body, bases, .. },
..
}) = class_ast.as_ref()
{
if let TopLevelDef::Class { resolver, ancestors, .. } =
class_def.write().deref()
{
(name, body, bases, resolver.as_ref().unwrap().clone(), ancestors.clone())
} else {
unreachable!()
}
} else {
unreachable!("should be class def ast")
}
};
let all_base_class_analyzed = {
let not_yet_analyzed =
to_be_analyzed_class.clone().into_iter().collect::<HashSet<_>>();
let base = class_ancestors.clone().into_iter().collect::<HashSet<_>>();
let intersection = not_yet_analyzed.intersection(&base).collect_vec();
intersection.is_empty()
};
if !all_base_class_analyzed {
to_be_analyzed_class.push(DefinitionId(class_ind));
continue 'class;
}
            // get the base types; this can be done directly since they
            // already passed the checks in the previous steps
let class_bases_ty = class_bases_ast
.iter()
.filter_map(|x| {
class_resolver
.as_ref()
.lock()
.parse_type_annotation(
converted_top_level,
unifier.borrow_mut(),
primitives,
x,
)
.ok()
})
.collect_vec();
            // these vectors are needed to check for redefined methods and class fields,
            // and to store parsed results in case some method cannot be typed yet
let mut class_methods_parsing_result: Vec<(String, Type, DefinitionId)> = vec![];
let mut class_fields_parsing_result: Vec<(String, Type)> = vec![];
for b in class_body_ast {
if let ast::StmtKind::FunctionDef {
args: method_args_ast,
body: method_body_ast,
name: method_name,
returns: method_returns_ast,
..
} = &b.node
{
let arg_name_tys: Vec<(String, Type)> = {
let mut result = vec![];
for a in &method_args_ast.args {
if a.node.arg != "self" {
let annotation = a
.node
.annotation
.as_ref()
.ok_or_else(|| {
"type annotation for function parameter is needed"
.to_string()
})?
.as_ref();
let ty = class_resolver.as_ref().lock().parse_type_annotation(
converted_top_level,
unifier.borrow_mut(),
primitives,
annotation,
)?;
if !Self::check_ty_analyzed(ty, unifier, to_be_analyzed_class) {
to_be_analyzed_class.push(DefinitionId(class_ind));
continue 'class;
}
result.push((a.node.arg.to_string(), ty));
} else {
// TODO: handle self, how
unimplemented!()
}
}
result
};
let method_type_var = arg_name_tys
.iter()
.filter_map(|(_, ty)| {
let ty_enum = unifier.get_ty(*ty);
if let TypeEnum::TVar { id, .. } = ty_enum.as_ref() {
Some((*id, *ty))
} else {
None
}
})
.collect::<Mapping<u32>>();
let ret_ty = {
if method_name != "__init__" {
let ty = method_returns_ast
.as_ref()
.map(|x| {
class_resolver.as_ref().lock().parse_type_annotation(
converted_top_level,
unifier.borrow_mut(),
primitives,
x.as_ref(),
)
})
.ok_or_else(|| "return type annotation error".to_string())??;
if !Self::check_ty_analyzed(ty, unifier, to_be_analyzed_class) {
to_be_analyzed_class.push(DefinitionId(class_ind));
continue 'class;
} else {
ty
}
} else {
// TODO: __init__ function, self type, how
unimplemented!()
}
};
// handle fields
let class_field_name_tys: Option<Vec<(String, Type)>> = if method_name
== "__init__"
{
let mut result: Vec<(String, Type)> = vec![];
for body in method_body_ast {
match &body.node {
ast::StmtKind::AnnAssign { target, annotation, .. }
if {
if let ast::ExprKind::Attribute { value, .. } = &target.node
{
matches!(
&value.node,
ast::ExprKind::Name { id, .. } if id == "self")
} else {
false
}
} =>
{
let field_ty =
class_resolver.as_ref().lock().parse_type_annotation(
converted_top_level,
unifier.borrow_mut(),
primitives,
annotation.as_ref(),
)?;
if !Self::check_ty_analyzed(
field_ty,
unifier,
to_be_analyzed_class,
) {
to_be_analyzed_class.push(DefinitionId(class_ind));
continue 'class;
} else {
result.push((
if let ast::ExprKind::Attribute { attr, .. } =
&target.node
{
attr.to_string()
} else {
unreachable!()
},
field_ty,
))
}
}
                            // reject field assignments that lack a type annotation
ast::StmtKind::Assign { targets, .. }
if {
if let ast::ExprKind::Attribute { value, .. } =
&targets[0].node
{
matches!(
&value.node,
ast::ExprKind::Name {id, ..} if id == "self")
} else {
false
}
} =>
{
return Err("class fields type annotation needed".into())
}
// do nothing
_ => {}
}
}
Some(result)
} else {
None
};
                // all types of the current method resolved; record it in the list
if class_methods_parsing_result.iter().any(|(name, _, _)| name == method_name) {
return Err("duplicate method definition".into());
} else {
class_methods_parsing_result.push((
method_name.clone(),
unifier.add_ty(TypeEnum::TFunc(
FunSignature {
ret: ret_ty,
args: arg_name_tys
.into_iter()
.map(|(name, ty)| FuncArg { name, ty, default_value: None })
.collect_vec(),
vars: method_type_var,
}
.into(),
)),
*self
.class_method_to_def_id
.get(&Self::name_mangling(class_name.clone(), method_name))
.unwrap(),
))
}
                // record the fields
if let Some(class_field_name_tys) = class_field_name_tys {
assert!(class_fields_parsing_result.is_empty());
class_fields_parsing_result.extend(class_field_name_tys);
}
} else {
// what should we do with `class A: a = 3`?
// do nothing, continue the for loop to iterate class ast
continue;
}
}
            // at this point every method and field of the class is confirmed to type check;
            // put the results into the methods and fields of the actual class def
let (class_def, _) = &def_ast_list[class_ind];
let mut class_def = class_def.write();
if let TopLevelDef::Class { fields, methods, .. } = class_def.deref_mut() {
for (ref n, ref t) in class_fields_parsing_result {
fields.push((n.clone(), *t));
}
for (n, t, id) in &class_methods_parsing_result {
methods.push((n.clone(), *t, *id));
}
} else {
unreachable!()
}
// change the signature field of the class methods
for (_, ty, id) in &class_methods_parsing_result {
let (method_def, _) = &def_ast_list[id.0];
let mut method_def = method_def.write();
if let TopLevelDef::Function { signature, .. } = method_def.deref_mut() {
*signature = *ty;
}
}
}
Ok(())
}
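    // Hedged illustration of the re-queueing above: if a method parameter or field of class A
    // is annotated with a class B that is still in `to_be_analyzed_class`, A is pushed back
    // onto the list and retried on a later pass (bounded by `max_iter`).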
fn analyze_top_level_function(&mut self) -> Result<(), String> {
unimplemented!()
}
fn analyze_top_level_field_instantiation(&mut self) -> Result<(), String> {
unimplemented!()
}
fn check_ty_analyzed(ty: Type, unifier: &mut Unifier, to_be_analyzed: &[DefinitionId]) -> bool {
let type_enum = unifier.get_ty(ty);
match type_enum.as_ref() {
TypeEnum::TObj { obj_id, .. } => !to_be_analyzed.contains(obj_id),
TypeEnum::TVirtual { ty } => {
if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(*ty).as_ref() {
!to_be_analyzed.contains(obj_id)
} else {
unreachable!()
}
}
TypeEnum::TVar { .. } => true,
_ => unreachable!(),
}
}
}

View File

@ -0,0 +1,216 @@
use super::type_inferencer::Inferencer;
use super::typedef::Type;
use rustpython_parser::ast::{self, Expr, ExprKind, Stmt, StmtKind};
use std::iter::once;
impl<'a> Inferencer<'a> {
fn check_pattern(
&mut self,
pattern: &Expr<Option<Type>>,
defined_identifiers: &mut Vec<String>,
) -> Result<(), String> {
match &pattern.node {
ExprKind::Name { id, .. } => {
if !defined_identifiers.contains(id) {
defined_identifiers.push(id.clone());
}
Ok(())
}
ExprKind::Tuple { elts, .. } => {
for elt in elts.iter() {
self.check_pattern(elt, defined_identifiers)?;
}
Ok(())
}
_ => self.check_expr(pattern, defined_identifiers),
}
}
fn check_expr(
&mut self,
expr: &Expr<Option<Type>>,
defined_identifiers: &[String],
) -> Result<(), String> {
// there are some cases where the custom field is None
if let Some(ty) = &expr.custom {
if !self.unifier.is_concrete(*ty, &self.function_data.bound_variables) {
return Err(format!(
"expected concrete type at {} but got {}",
expr.location,
self.unifier.get_ty(*ty).get_type_name()
));
}
}
match &expr.node {
ExprKind::Name { id, .. } => {
if !defined_identifiers.contains(id) {
return Err(format!(
"unknown identifier {} (use before def?) at {}",
id, expr.location
));
}
}
ExprKind::List { elts, .. }
| ExprKind::Tuple { elts, .. }
| ExprKind::BoolOp { values: elts, .. } => {
for elt in elts.iter() {
self.check_expr(elt, defined_identifiers)?;
}
}
ExprKind::Attribute { value, .. } => {
self.check_expr(value, defined_identifiers)?;
}
ExprKind::BinOp { left, right, .. } => {
self.check_expr(left, defined_identifiers)?;
self.check_expr(right, defined_identifiers)?;
}
ExprKind::UnaryOp { operand, .. } => {
self.check_expr(operand, defined_identifiers)?;
}
ExprKind::Compare { left, comparators, .. } => {
for elt in once(left.as_ref()).chain(comparators.iter()) {
self.check_expr(elt, defined_identifiers)?;
}
}
ExprKind::Subscript { value, slice, .. } => {
self.check_expr(value, defined_identifiers)?;
self.check_expr(slice, defined_identifiers)?;
}
ExprKind::IfExp { test, body, orelse } => {
self.check_expr(test, defined_identifiers)?;
self.check_expr(body, defined_identifiers)?;
self.check_expr(orelse, defined_identifiers)?;
}
ExprKind::Slice { lower, upper, step } => {
for elt in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() {
self.check_expr(elt, defined_identifiers)?;
}
}
ExprKind::Lambda { args, body } => {
let mut defined_identifiers = defined_identifiers.to_vec();
for arg in args.args.iter() {
if !defined_identifiers.contains(&arg.node.arg) {
defined_identifiers.push(arg.node.arg.clone());
}
}
self.check_expr(body, &defined_identifiers)?;
}
ExprKind::ListComp { elt, generators, .. } => {
// in our type inference stage, we already make sure that there is only 1 generator
let ast::Comprehension { target, iter, ifs, .. } = &generators[0];
self.check_expr(iter, defined_identifiers)?;
let mut defined_identifiers = defined_identifiers.to_vec();
self.check_pattern(target, &mut defined_identifiers)?;
for term in once(elt.as_ref()).chain(ifs.iter()) {
self.check_expr(term, &defined_identifiers)?;
}
}
ExprKind::Call { func, args, keywords } => {
for expr in once(func.as_ref())
.chain(args.iter())
.chain(keywords.iter().map(|v| v.node.value.as_ref()))
{
self.check_expr(expr, defined_identifiers)?;
}
}
ExprKind::Constant { .. } => {}
_ => {
println!("{:?}", expr.node);
unimplemented!()
}
}
Ok(())
}
// check statements for proper identifier def-use and return on all paths
fn check_stmt(
&mut self,
stmt: &Stmt<Option<Type>>,
defined_identifiers: &mut Vec<String>,
) -> Result<bool, String> {
match &stmt.node {
StmtKind::For { target, iter, body, orelse, .. } => {
self.check_expr(iter, defined_identifiers)?;
for stmt in orelse.iter() {
self.check_stmt(stmt, defined_identifiers)?;
}
let mut defined_identifiers = defined_identifiers.clone();
self.check_pattern(target, &mut defined_identifiers)?;
for stmt in body.iter() {
self.check_stmt(stmt, &mut defined_identifiers)?;
}
Ok(false)
}
StmtKind::If { test, body, orelse } => {
self.check_expr(test, defined_identifiers)?;
let mut body_identifiers = defined_identifiers.clone();
let mut orelse_identifiers = defined_identifiers.clone();
let body_returned = self.check_block(body, &mut body_identifiers)?;
let orelse_returned = self.check_block(orelse, &mut orelse_identifiers)?;
for ident in body_identifiers.iter() {
if !defined_identifiers.contains(ident) && orelse_identifiers.contains(ident) {
defined_identifiers.push(ident.clone())
}
}
Ok(body_returned && orelse_returned)
}
StmtKind::While { test, body, orelse } => {
self.check_expr(test, defined_identifiers)?;
let mut defined_identifiers = defined_identifiers.clone();
self.check_block(body, &mut defined_identifiers)?;
self.check_block(orelse, &mut defined_identifiers)?;
Ok(false)
}
StmtKind::Expr { value } => {
self.check_expr(value, defined_identifiers)?;
Ok(false)
}
StmtKind::Assign { targets, value, .. } => {
self.check_expr(value, defined_identifiers)?;
for target in targets {
self.check_pattern(target, defined_identifiers)?;
}
Ok(false)
}
StmtKind::AnnAssign { target, value, .. } => {
if let Some(value) = value {
self.check_expr(value, defined_identifiers)?;
self.check_pattern(target, defined_identifiers)?;
}
Ok(false)
}
StmtKind::Return { value } => {
if let Some(value) = value {
self.check_expr(value, defined_identifiers)?;
}
Ok(true)
}
StmtKind::Raise { exc, .. } => {
if let Some(value) = exc {
self.check_expr(value, defined_identifiers)?;
}
Ok(true)
}
            // break, continue, pass, etc.
_ => Ok(false),
}
}
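    // Checks a block of statements in order and returns whether the block is guaranteed to
    // return; any statement that appears after a guaranteed return is reported as dead code.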
pub fn check_block(
&mut self,
block: &[Stmt<Option<Type>>],
defined_identifiers: &mut Vec<String>,
) -> Result<bool, String> {
let mut ret = false;
for stmt in block {
if ret {
return Err(format!("dead code at {:?}", stmt.location));
}
if self.check_stmt(stmt, defined_identifiers)? {
ret = true;
}
}
Ok(ret)
}
}

View File

@ -0,0 +1,322 @@
use crate::typecheck::{
type_inferencer::*,
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier},
};
use rustpython_parser::ast;
use rustpython_parser::ast::{Cmpop, Operator, Unaryop};
use std::borrow::Borrow;
use std::collections::HashMap;
pub fn binop_name(op: &Operator) -> &'static str {
match op {
Operator::Add => "__add__",
Operator::Sub => "__sub__",
Operator::Div => "__truediv__",
Operator::Mod => "__mod__",
Operator::Mult => "__mul__",
Operator::Pow => "__pow__",
Operator::BitOr => "__or__",
Operator::BitXor => "__xor__",
Operator::BitAnd => "__and__",
Operator::LShift => "__lshift__",
Operator::RShift => "__rshift__",
Operator::FloorDiv => "__floordiv__",
Operator::MatMult => "__matmul__",
}
}
pub fn binop_assign_name(op: &Operator) -> &'static str {
match op {
Operator::Add => "__iadd__",
Operator::Sub => "__isub__",
Operator::Div => "__itruediv__",
Operator::Mod => "__imod__",
Operator::Mult => "__imul__",
Operator::Pow => "__ipow__",
Operator::BitOr => "__ior__",
Operator::BitXor => "__ixor__",
Operator::BitAnd => "__iand__",
Operator::LShift => "__ilshift__",
Operator::RShift => "__irshift__",
Operator::FloorDiv => "__ifloordiv__",
Operator::MatMult => "__imatmul__",
}
}
pub fn unaryop_name(op: &Unaryop) -> &'static str {
match op {
Unaryop::UAdd => "__pos__",
Unaryop::USub => "__neg__",
Unaryop::Not => "__not__",
Unaryop::Invert => "__inv__",
}
}
pub fn comparison_name(op: &Cmpop) -> Option<&'static str> {
match op {
Cmpop::Lt => Some("__lt__"),
Cmpop::LtE => Some("__le__"),
Cmpop::Gt => Some("__gt__"),
Cmpop::GtE => Some("__ge__"),
Cmpop::Eq => Some("__eq__"),
Cmpop::NotEq => Some("__ne__"),
_ => None,
}
}
pub fn impl_binop(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Type,
ops: &[ast::Operator],
) {
if let TypeEnum::TObj { fields, .. } = unifier.get_ty(ty).borrow() {
let (other_ty, other_var_id) = if other_ty.len() == 1 {
(other_ty[0], None)
} else {
let (ty, var_id) = unifier.get_fresh_var_with_range(other_ty);
(ty, Some(var_id))
};
let function_vars = if let Some(var_id) = other_var_id {
vec![(var_id, other_ty)].into_iter().collect::<HashMap<_, _>>()
} else {
HashMap::new()
};
for op in ops {
fields.borrow_mut().insert(binop_name(op).into(), {
unifier.add_ty(TypeEnum::TFunc(
FunSignature {
ret: ret_ty,
vars: function_vars.clone(),
args: vec![FuncArg {
ty: other_ty,
default_value: None,
name: "other".into(),
}],
}
.into(),
))
});
fields.borrow_mut().insert(binop_assign_name(op).into(), {
unifier.add_ty(TypeEnum::TFunc(
FunSignature {
ret: store.none,
vars: function_vars.clone(),
args: vec![FuncArg {
ty: other_ty,
default_value: None,
name: "other".into(),
}],
}
.into(),
))
});
}
} else {
unreachable!("")
}
}
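// Hedged illustration: impl_binop(unifier, store, int32, &[int32], int32, &[Operator::Add])
// gives int32 both `__add__`, typed roughly as fn[[other=int32], int32], and `__iadd__`,
// typed roughly as fn[[other=int32], none] (the in-place form returns none).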
pub fn impl_unaryop(
unifier: &mut Unifier,
_store: &PrimitiveStore,
ty: Type,
ret_ty: Type,
ops: &[ast::Unaryop],
) {
if let TypeEnum::TObj { fields, .. } = unifier.get_ty(ty).borrow() {
for op in ops {
fields.borrow_mut().insert(
unaryop_name(op).into(),
unifier.add_ty(TypeEnum::TFunc(
FunSignature { ret: ret_ty, vars: HashMap::new(), args: vec![] }.into(),
)),
);
}
} else {
unreachable!()
}
}
pub fn impl_cmpop(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: Type,
ops: &[ast::Cmpop],
) {
if let TypeEnum::TObj { fields, .. } = unifier.get_ty(ty).borrow() {
for op in ops {
fields.borrow_mut().insert(
comparison_name(op).unwrap().into(),
unifier.add_ty(TypeEnum::TFunc(
FunSignature {
ret: store.bool,
vars: HashMap::new(),
args: vec![FuncArg {
ty: other_ty,
default_value: None,
name: "other".into(),
}],
}
.into(),
)),
);
}
} else {
unreachable!()
}
}
/// Add, Sub, Mult, Pow
pub fn impl_basic_arithmetic(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Type,
) {
impl_binop(
unifier,
store,
ty,
other_ty,
ret_ty,
&[ast::Operator::Add, ast::Operator::Sub, ast::Operator::Mult],
)
}
pub fn impl_pow(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Type,
) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[ast::Operator::Pow])
}
/// BitOr, BitXor, BitAnd
pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_binop(
unifier,
store,
ty,
&[ty],
ty,
&[ast::Operator::BitAnd, ast::Operator::BitOr, ast::Operator::BitXor],
)
}
/// LShift, RShift
pub fn impl_bitwise_shift(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_binop(unifier, store, ty, &[ty], ty, &[ast::Operator::LShift, ast::Operator::RShift])
}
/// Div
pub fn impl_div(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type]) {
impl_binop(unifier, store, ty, other_ty, store.float, &[ast::Operator::Div])
}
/// FloorDiv
pub fn impl_floordiv(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Type,
) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[ast::Operator::FloorDiv])
}
/// Mod
pub fn impl_mod(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Type,
) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[ast::Operator::Mod])
}
/// UAdd, USub
pub fn impl_sign(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_unaryop(unifier, store, ty, ty, &[ast::Unaryop::UAdd, ast::Unaryop::USub])
}
/// Invert
pub fn impl_invert(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_unaryop(unifier, store, ty, ty, &[ast::Unaryop::Invert])
}
/// Not
pub fn impl_not(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_unaryop(unifier, store, ty, store.bool, &[ast::Unaryop::Not])
}
/// Lt, LtE, Gt, GtE
pub fn impl_comparison(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type) {
impl_cmpop(
unifier,
store,
ty,
other_ty,
&[ast::Cmpop::Lt, ast::Cmpop::Gt, ast::Cmpop::LtE, ast::Cmpop::GtE],
)
}
/// Eq, NotEq
pub fn impl_eq(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_cmpop(unifier, store, ty, ty, &[ast::Cmpop::Eq, ast::Cmpop::NotEq])
}
pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) {
let PrimitiveStore { int32: int32_t, int64: int64_t, float: float_t, bool: bool_t, .. } =
*store;
/* int32 ======== */
impl_basic_arithmetic(unifier, store, int32_t, &[int32_t], int32_t);
impl_pow(unifier, store, int32_t, &[int32_t], int32_t);
impl_bitwise_arithmetic(unifier, store, int32_t);
impl_bitwise_shift(unifier, store, int32_t);
impl_div(unifier, store, int32_t, &[int32_t]);
impl_floordiv(unifier, store, int32_t, &[int32_t], int32_t);
impl_mod(unifier, store, int32_t, &[int32_t], int32_t);
impl_sign(unifier, store, int32_t);
impl_invert(unifier, store, int32_t);
impl_not(unifier, store, int32_t);
impl_comparison(unifier, store, int32_t, int32_t);
impl_eq(unifier, store, int32_t);
/* int64 ======== */
impl_basic_arithmetic(unifier, store, int64_t, &[int64_t], int64_t);
impl_pow(unifier, store, int64_t, &[int64_t], int64_t);
impl_bitwise_arithmetic(unifier, store, int64_t);
impl_bitwise_shift(unifier, store, int64_t);
impl_div(unifier, store, int64_t, &[int64_t]);
impl_floordiv(unifier, store, int64_t, &[int64_t], int64_t);
impl_mod(unifier, store, int64_t, &[int64_t], int64_t);
impl_sign(unifier, store, int64_t);
impl_invert(unifier, store, int64_t);
impl_not(unifier, store, int64_t);
impl_comparison(unifier, store, int64_t, int64_t);
impl_eq(unifier, store, int64_t);
/* float ======== */
impl_basic_arithmetic(unifier, store, float_t, &[float_t], float_t);
impl_pow(unifier, store, float_t, &[int32_t, float_t], float_t);
impl_div(unifier, store, float_t, &[float_t]);
impl_floordiv(unifier, store, float_t, &[float_t], float_t);
impl_mod(unifier, store, float_t, &[float_t], float_t);
impl_sign(unifier, store, float_t);
impl_not(unifier, store, float_t);
impl_comparison(unifier, store, float_t, float_t);
impl_eq(unifier, store, float_t);
/* bool ======== */
impl_not(unifier, store, bool_t);
impl_eq(unifier, store, bool_t);
}
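// Hedged illustration of the effect: after this call, an expression such as `1 + 2`
// type-checks because int32 now carries `__add__` typed roughly as fn[[other=int32], int32],
// and `1.0 ** 2` works because float's `__pow__` accepts either int32 or float.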

View File

@ -0,0 +1,5 @@
mod function_check;
pub mod magic_methods;
pub mod type_inferencer;
pub mod typedef;
mod unification_table;

View File

@ -0,0 +1,582 @@
use std::collections::HashMap;
use std::convert::{From, TryInto};
use std::iter::once;
use std::{cell::RefCell, sync::Arc};
use super::typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier};
use super::{magic_methods::*, typedef::CallId};
use crate::{symbol_resolver::SymbolResolver, top_level::TopLevelContext};
use itertools::izip;
use rustpython_parser::ast::{
self,
fold::{self, Fold},
Arguments, Comprehension, ExprKind, Located, Location,
};
#[cfg(test)]
mod test;
#[derive(PartialEq, Eq, Hash, Copy, Clone, Debug)]
pub struct CodeLocation {
row: usize,
col: usize,
}
impl From<Location> for CodeLocation {
fn from(loc: Location) -> CodeLocation {
CodeLocation { row: loc.row(), col: loc.column() }
}
}
#[derive(Clone, Copy)]
pub struct PrimitiveStore {
pub int32: Type,
pub int64: Type,
pub float: Type,
pub bool: Type,
pub none: Type,
}
pub struct FunctionData {
pub resolver: Arc<dyn SymbolResolver + Send + Sync>,
pub return_type: Option<Type>,
pub bound_variables: Vec<Type>,
}
pub struct Inferencer<'a> {
pub top_level: &'a TopLevelContext,
pub function_data: &'a mut FunctionData,
pub unifier: &'a mut Unifier,
pub primitives: &'a PrimitiveStore,
pub virtual_checks: &'a mut Vec<(Type, Type)>,
pub variable_mapping: HashMap<String, Type>,
pub calls: &'a mut HashMap<CodeLocation, CallId>,
}
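// Folder that only maps the user data to `None` without doing any inference; used below for
// sub-trees we want to keep untouched, such as the annotation in `AnnAssign`.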
struct NaiveFolder();
impl fold::Fold<()> for NaiveFolder {
type TargetU = Option<Type>;
type Error = String;
fn map_user(&mut self, _: ()) -> Result<Self::TargetU, Self::Error> {
Ok(None)
}
}
impl<'a> fold::Fold<()> for Inferencer<'a> {
type TargetU = Option<Type>;
type Error = String;
fn map_user(&mut self, _: ()) -> Result<Self::TargetU, Self::Error> {
Ok(None)
}
fn fold_stmt(&mut self, node: ast::Stmt<()>) -> Result<ast::Stmt<Self::TargetU>, Self::Error> {
let stmt = match node.node {
            // we don't want to fold over the type annotation
ast::StmtKind::AnnAssign { target, annotation, value, simple } => {
let target = Box::new(self.fold_expr(*target)?);
let value = if let Some(v) = value {
let ty = Box::new(self.fold_expr(*v)?);
self.unifier.unify(target.custom.unwrap(), ty.custom.unwrap())?;
Some(ty)
} else {
None
};
let annotation_type = self.function_data.resolver.parse_type_annotation(
self.top_level,
self.unifier,
&self.primitives,
annotation.as_ref(),
)?;
self.unifier.unify(annotation_type, target.custom.unwrap())?;
let annotation = Box::new(NaiveFolder().fold_expr(*annotation)?);
Located {
location: node.location,
custom: None,
node: ast::StmtKind::AnnAssign { target, annotation, value, simple },
}
}
_ => fold::fold_stmt(self, node)?,
};
match &stmt.node {
ast::StmtKind::For { target, iter, .. } => {
let list = self.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() });
self.unifier.unify(list, iter.custom.unwrap())?;
}
ast::StmtKind::If { test, .. } | ast::StmtKind::While { test, .. } => {
self.unifier.unify(test.custom.unwrap(), self.primitives.bool)?;
}
ast::StmtKind::Assign { targets, value, .. } => {
for target in targets.iter() {
self.unifier.unify(target.custom.unwrap(), value.custom.unwrap())?;
}
}
ast::StmtKind::AnnAssign { .. } | ast::StmtKind::Expr { .. } => {}
ast::StmtKind::Break | ast::StmtKind::Continue => {}
ast::StmtKind::Return { value } => match (value, self.function_data.return_type) {
(Some(v), Some(v1)) => {
self.unifier.unify(v.custom.unwrap(), v1)?;
}
(Some(_), None) => {
return Err("Unexpected return value".to_string());
}
(None, Some(_)) => {
return Err("Expected return value".to_string());
}
(None, None) => {}
},
_ => return Err("Unsupported statement type".to_string()),
};
Ok(stmt)
}
fn fold_expr(&mut self, node: ast::Expr<()>) -> Result<ast::Expr<Self::TargetU>, Self::Error> {
let expr = match node.node {
ast::ExprKind::Call { func, args, keywords } => {
return self.fold_call(node.location, *func, args, keywords);
}
ast::ExprKind::Lambda { args, body } => {
return self.fold_lambda(node.location, *args, *body);
}
ast::ExprKind::ListComp { elt, generators } => {
return self.fold_listcomp(node.location, *elt, generators);
}
_ => fold::fold_expr(self, node)?,
};
let custom = match &expr.node {
ast::ExprKind::Constant { value, .. } => Some(self.infer_constant(value)?),
ast::ExprKind::Name { id, .. } => Some(self.infer_identifier(id)?),
ast::ExprKind::List { elts, .. } => Some(self.infer_list(elts)?),
ast::ExprKind::Tuple { elts, .. } => Some(self.infer_tuple(elts)?),
ast::ExprKind::Attribute { value, attr, ctx: _ } => {
Some(self.infer_attribute(value, attr)?)
}
ast::ExprKind::BoolOp { values, .. } => Some(self.infer_bool_ops(values)?),
ast::ExprKind::BinOp { left, op, right } => Some(self.infer_bin_ops(left, op, right)?),
ast::ExprKind::UnaryOp { op, operand } => Some(self.infer_unary_ops(op, operand)?),
ast::ExprKind::Compare { left, ops, comparators } => {
Some(self.infer_compare(left, ops, comparators)?)
}
ast::ExprKind::Subscript { value, slice, .. } => {
Some(self.infer_subscript(value.as_ref(), slice.as_ref())?)
}
ast::ExprKind::IfExp { test, body, orelse } => {
Some(self.infer_if_expr(test, body.as_ref(), orelse.as_ref())?)
}
ast::ExprKind::ListComp { .. }
| ast::ExprKind::Lambda { .. }
| ast::ExprKind::Call { .. } => expr.custom, // already computed
ast::ExprKind::Slice { .. } => None, // we don't need it for slice
_ => return Err("not supported yet".into()),
};
Ok(ast::Expr { custom, location: expr.location, node: expr.node })
}
}
type InferenceResult = Result<Type, String>;
impl<'a> Inferencer<'a> {
/// Constrain a <: b
/// Currently implemented as unification
fn constrain(&mut self, a: Type, b: Type) -> Result<(), String> {
self.unifier.unify(a, b)
}
fn build_method_call(
&mut self,
location: Location,
method: String,
obj: Type,
params: Vec<Type>,
ret: Type,
) -> InferenceResult {
let call = self.unifier.add_call(Call {
posargs: params,
kwargs: HashMap::new(),
ret,
fun: RefCell::new(None),
});
self.calls.insert(location.into(), call);
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
let fields = once((method, call)).collect();
let record = self.unifier.add_record(fields);
self.constrain(obj, record)?;
Ok(ret)
}
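    // Hedged illustration: for `a + b`, infer_bin_ops below calls this as
    //   build_method_call(loc, "__add__".into(), type_of_a, vec![type_of_b], fresh_ret)
    // which records the call and constrains type_of_a against a record containing "__add__",
    // leaving the unifier to pick the matching signature later.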
fn fold_lambda(
&mut self,
location: Location,
args: Arguments,
body: ast::Expr<()>,
) -> Result<ast::Expr<Option<Type>>, String> {
if !args.posonlyargs.is_empty()
|| args.vararg.is_some()
|| !args.kwonlyargs.is_empty()
|| args.kwarg.is_some()
|| !args.defaults.is_empty()
{
            // actually I'm not sure whether programs violating this are valid Python programs.
return Err(
"We only support positional or keyword arguments without defaults for lambdas."
.to_string(),
);
}
let fn_args: Vec<_> = args
.args
.iter()
.map(|v| (v.node.arg.clone(), self.unifier.get_fresh_var().0))
.collect();
let mut variable_mapping = self.variable_mapping.clone();
variable_mapping.extend(fn_args.iter().cloned());
let ret = self.unifier.get_fresh_var().0;
let mut new_context = Inferencer {
function_data: self.function_data,
unifier: self.unifier,
primitives: self.primitives,
virtual_checks: self.virtual_checks,
calls: self.calls,
top_level: self.top_level,
variable_mapping,
};
let fun = FunSignature {
args: fn_args
.iter()
.map(|(k, ty)| FuncArg { name: k.clone(), ty: *ty, default_value: None })
.collect(),
ret,
vars: Default::default(),
};
let body = new_context.fold_expr(body)?;
new_context.unifier.unify(fun.ret, body.custom.unwrap())?;
let mut args = new_context.fold_arguments(args)?;
for (arg, (name, ty)) in args.args.iter_mut().zip(fn_args.iter()) {
assert_eq!(&arg.node.arg, name);
arg.custom = Some(*ty);
}
Ok(Located {
location,
node: ExprKind::Lambda { args: args.into(), body: body.into() },
custom: Some(self.unifier.add_ty(TypeEnum::TFunc(fun.into()))),
})
}
fn fold_listcomp(
&mut self,
location: Location,
elt: ast::Expr<()>,
mut generators: Vec<Comprehension>,
) -> Result<ast::Expr<Option<Type>>, String> {
if generators.len() != 1 {
return Err(
"Only 1 generator statement for list comprehension is supported.".to_string()
);
}
let variable_mapping = self.variable_mapping.clone();
let mut new_context = Inferencer {
function_data: self.function_data,
unifier: self.unifier,
virtual_checks: self.virtual_checks,
top_level: self.top_level,
variable_mapping,
primitives: self.primitives,
calls: self.calls,
};
let elt = new_context.fold_expr(elt)?;
let generator = generators.pop().unwrap();
if generator.is_async {
return Err("Async iterator not supported.".to_string());
}
let target = new_context.fold_expr(*generator.target)?;
let iter = new_context.fold_expr(*generator.iter)?;
let ifs: Vec<_> = generator
.ifs
.into_iter()
.map(|v| new_context.fold_expr(v))
.collect::<Result<_, _>>()?;
        // iter should be a list of targets...
        // actually it should be an iterator of targets, but we don't have an iterator type for now
let list = new_context.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() });
new_context.unifier.unify(iter.custom.unwrap(), list)?;
// if conditions should be bool
for v in ifs.iter() {
new_context.unifier.unify(v.custom.unwrap(), new_context.primitives.bool)?;
}
Ok(Located {
location,
custom: Some(new_context.unifier.add_ty(TypeEnum::TList { ty: elt.custom.unwrap() })),
node: ExprKind::ListComp {
elt: Box::new(elt),
generators: vec![ast::Comprehension {
target: Box::new(target),
iter: Box::new(iter),
ifs,
is_async: false,
}],
},
})
}
fn fold_call(
&mut self,
location: Location,
func: ast::Expr<()>,
mut args: Vec<ast::Expr<()>>,
keywords: Vec<Located<ast::KeywordData>>,
) -> Result<ast::Expr<Option<Type>>, String> {
let func =
if let Located { location: func_location, custom, node: ExprKind::Name { id, ctx } } =
func
{
// handle special functions that cannot be typed in the usual way...
if id == "virtual" {
if args.is_empty() || args.len() > 2 || !keywords.is_empty() {
return Err(
"`virtual` can only accept 1/2 positional arguments.".to_string()
);
}
let arg0 = self.fold_expr(args.remove(0))?;
let ty = if let Some(arg) = args.pop() {
self.function_data.resolver.parse_type_annotation(
self.top_level,
self.unifier,
self.primitives,
&arg,
)?
} else {
self.unifier.get_fresh_var().0
};
self.virtual_checks.push((arg0.custom.unwrap(), ty));
let custom = Some(self.unifier.add_ty(TypeEnum::TVirtual { ty }));
return Ok(Located {
location,
custom,
node: ExprKind::Call {
func: Box::new(Located {
custom: None,
location: func.location,
node: ExprKind::Name { id, ctx },
}),
args: vec![arg0],
keywords: vec![],
},
});
}
                // int64 is special because its argument can be a constant outside the int32 range
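                // e.g. `int64(2147483648)`: the literal does not fit in i32, so the call is
                // folded here directly into an int64 constant instead of going through normal
                // call inference.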
if id == "int64" && args.len() == 1 {
if let ExprKind::Constant { value: ast::Constant::Int(val), kind } =
&args[0].node
{
let int64: Result<i64, _> = val.try_into();
let custom;
if int64.is_ok() {
custom = Some(self.primitives.int64);
} else {
return Err("Integer out of bound".into());
}
return Ok(Located {
location: args[0].location,
custom,
node: ExprKind::Constant {
value: ast::Constant::Int(val.clone()),
kind: kind.clone(),
},
});
}
}
Located { location: func_location, custom, node: ExprKind::Name { id, ctx } }
} else {
func
};
let func = Box::new(self.fold_expr(func)?);
let args = args.into_iter().map(|v| self.fold_expr(v)).collect::<Result<Vec<_>, _>>()?;
let keywords = keywords
.into_iter()
.map(|v| fold::fold_keyword(self, v))
.collect::<Result<Vec<_>, _>>()?;
let ret = self.unifier.get_fresh_var().0;
let call = self.unifier.add_call(Call {
posargs: args.iter().map(|v| v.custom.unwrap()).collect(),
kwargs: keywords
.iter()
.map(|v| (v.node.arg.as_ref().unwrap().clone(), v.custom.unwrap()))
.collect(),
fun: RefCell::new(None),
ret,
});
self.calls.insert(location.into(), call);
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
self.unifier.unify(func.custom.unwrap(), call)?;
Ok(Located { location, custom: Some(ret), node: ExprKind::Call { func, args, keywords } })
}
fn infer_identifier(&mut self, id: &str) -> InferenceResult {
if let Some(ty) = self.variable_mapping.get(id) {
Ok(*ty)
} else {
Ok(self
.function_data
.resolver
.get_symbol_type(self.unifier, self.primitives, id)
.unwrap_or_else(|| {
let ty = self.unifier.get_fresh_var().0;
self.variable_mapping.insert(id.to_string(), ty);
ty
}))
}
}
fn infer_constant(&mut self, constant: &ast::Constant) -> InferenceResult {
match constant {
ast::Constant::Bool(_) => Ok(self.primitives.bool),
ast::Constant::Int(val) => {
let int32: Result<i32, _> = val.try_into();
                // int64 constants are handled separately (via the int64() call in fold_call)
if int32.is_ok() {
Ok(self.primitives.int32)
} else {
Err("Integer out of bound".into())
}
}
ast::Constant::Float(_) => Ok(self.primitives.float),
ast::Constant::Tuple(vals) => {
let ty: Result<Vec<_>, _> = vals.iter().map(|x| self.infer_constant(x)).collect();
Ok(self.unifier.add_ty(TypeEnum::TTuple { ty: ty? }))
}
_ => Err("not supported".into()),
}
}
fn infer_list(&mut self, elts: &[ast::Expr<Option<Type>>]) -> InferenceResult {
let (ty, _) = self.unifier.get_fresh_var();
for t in elts.iter() {
self.unifier.unify(ty, t.custom.unwrap())?;
}
Ok(self.unifier.add_ty(TypeEnum::TList { ty }))
}
fn infer_tuple(&mut self, elts: &[ast::Expr<Option<Type>>]) -> InferenceResult {
let ty = elts.iter().map(|x| x.custom.unwrap()).collect();
Ok(self.unifier.add_ty(TypeEnum::TTuple { ty }))
}
fn infer_attribute(&mut self, value: &ast::Expr<Option<Type>>, attr: &str) -> InferenceResult {
let (attr_ty, _) = self.unifier.get_fresh_var();
let fields = once((attr.to_string(), attr_ty)).collect();
let record = self.unifier.add_record(fields);
self.constrain(value.custom.unwrap(), record)?;
Ok(attr_ty)
}
fn infer_bool_ops(&mut self, values: &[ast::Expr<Option<Type>>]) -> InferenceResult {
let b = self.primitives.bool;
for v in values {
self.constrain(v.custom.unwrap(), b)?;
}
Ok(b)
}
fn infer_bin_ops(
&mut self,
left: &ast::Expr<Option<Type>>,
op: &ast::Operator,
right: &ast::Expr<Option<Type>>,
) -> InferenceResult {
let method = binop_name(op);
let ret = self.unifier.get_fresh_var().0;
self.build_method_call(
left.location,
method.to_string(),
left.custom.unwrap(),
vec![right.custom.unwrap()],
ret,
)
}
fn infer_unary_ops(
&mut self,
op: &ast::Unaryop,
operand: &ast::Expr<Option<Type>>,
) -> InferenceResult {
let method = unaryop_name(op);
let ret = self.unifier.get_fresh_var().0;
self.build_method_call(
operand.location,
method.to_string(),
operand.custom.unwrap(),
vec![],
ret,
)
}
fn infer_compare(
&mut self,
left: &ast::Expr<Option<Type>>,
ops: &[ast::Cmpop],
comparators: &[ast::Expr<Option<Type>>],
) -> InferenceResult {
let boolean = self.primitives.bool;
for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) {
let method =
comparison_name(c).ok_or_else(|| "unsupported comparator".to_string())?.to_string();
self.build_method_call(
a.location,
method,
a.custom.unwrap(),
vec![b.custom.unwrap()],
boolean,
)?;
}
Ok(boolean)
}
fn infer_subscript(
&mut self,
value: &ast::Expr<Option<Type>>,
slice: &ast::Expr<Option<Type>>,
) -> InferenceResult {
let ty = self.unifier.get_fresh_var().0;
match &slice.node {
ast::ExprKind::Slice { lower, upper, step } => {
for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() {
self.constrain(v.custom.unwrap(), self.primitives.int32)?;
}
let list = self.unifier.add_ty(TypeEnum::TList { ty });
self.constrain(value.custom.unwrap(), list)?;
Ok(list)
}
ast::ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
// the index is a constant, so value can be a sequence.
let ind: i32 = val.try_into().map_err(|_| "Index must be int32".to_string())?;
let map = once((ind, ty)).collect();
let seq = self.unifier.add_sequence(map);
self.constrain(value.custom.unwrap(), seq)?;
Ok(ty)
}
_ => {
// the index is not a constant, so value can only be a list
self.constrain(slice.custom.unwrap(), self.primitives.int32)?;
let list = self.unifier.add_ty(TypeEnum::TList { ty });
self.constrain(value.custom.unwrap(), list)?;
Ok(ty)
}
}
}
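    // Hedged illustration: `x[0]` lets `x` be any sequence (the element type is looked up by
    // the constant index), while `x[i]` with a non-constant `i` forces `x` to be a list.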
fn infer_if_expr(
&mut self,
test: &ast::Expr<Option<Type>>,
body: &ast::Expr<Option<Type>>,
orelse: &ast::Expr<Option<Type>>,
) -> InferenceResult {
self.constrain(test.custom.unwrap(), self.primitives.bool)?;
let ty = self.unifier.get_fresh_var().0;
self.constrain(body.custom.unwrap(), ty)?;
self.constrain(orelse.custom.unwrap(), ty)?;
Ok(ty)
}
}

View File

@ -0,0 +1,546 @@
use super::super::typedef::*;
use super::*;
use crate::symbol_resolver::*;
use crate::top_level::DefinitionId;
use crate::{location::Location, top_level::TopLevelDef};
use indoc::indoc;
use itertools::zip;
use parking_lot::RwLock;
use rustpython_parser::parser::parse_program;
use test_case::test_case;
struct Resolver {
id_to_type: HashMap<String, Type>,
id_to_def: HashMap<String, DefinitionId>,
class_names: HashMap<String, Type>,
}
impl SymbolResolver for Resolver {
fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option<Type> {
self.id_to_type.get(str).cloned()
}
fn get_symbol_value(&self, _: &str) -> Option<SymbolValue> {
unimplemented!()
}
fn get_symbol_location(&self, _: &str) -> Option<Location> {
unimplemented!()
}
fn get_identifier_def(&self, id: &str) -> Option<DefinitionId> {
self.id_to_def.get(id).cloned()
}
}
struct TestEnvironment {
pub unifier: Unifier,
pub function_data: FunctionData,
pub primitives: PrimitiveStore,
pub id_to_name: HashMap<usize, String>,
pub identifier_mapping: HashMap<String, Type>,
pub virtual_checks: Vec<(Type, Type)>,
pub calls: HashMap<CodeLocation, CallId>,
pub top_level: TopLevelContext,
}
impl TestEnvironment {
pub fn basic_test_env() -> TestEnvironment {
let mut unifier = Unifier::new();
let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(0),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(1),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let float = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(2),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let bool = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(3),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let none = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(4),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let primitives = PrimitiveStore { int32, int64, float, bool, none };
set_primitives_magic_methods(&primitives, &mut unifier);
let id_to_name = [
(0, "int32".to_string()),
(1, "int64".to_string()),
(2, "float".to_string()),
(3, "bool".to_string()),
(4, "none".to_string()),
]
.iter()
.cloned()
.collect();
let mut identifier_mapping = HashMap::new();
identifier_mapping.insert("None".into(), none);
let resolver = Arc::new(Resolver {
id_to_type: identifier_mapping.clone(),
id_to_def: Default::default(),
class_names: Default::default(),
}) as Arc<dyn SymbolResolver + Send + Sync>;
TestEnvironment {
top_level: TopLevelContext {
definitions: Default::default(),
unifiers: Default::default(),
},
unifier,
function_data: FunctionData {
resolver,
bound_variables: Vec::new(),
return_type: None,
},
primitives,
id_to_name,
identifier_mapping,
virtual_checks: Vec::new(),
calls: HashMap::new(),
}
}
fn new() -> TestEnvironment {
let mut unifier = Unifier::new();
let mut identifier_mapping = HashMap::new();
let mut top_level_defs: Vec<Arc<RwLock<TopLevelDef>>> = Vec::new();
let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(0),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(1),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let float = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(2),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let bool = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(3),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let none = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(4),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
identifier_mapping.insert("None".into(), none);
for i in 0..5 {
top_level_defs.push(
RwLock::new(TopLevelDef::Class {
object_id: DefinitionId(i),
type_vars: Default::default(),
fields: Default::default(),
methods: Default::default(),
ancestors: Default::default(),
resolver: None,
})
.into(),
);
}
let primitives = PrimitiveStore { int32, int64, float, bool, none };
let (v0, id) = unifier.get_fresh_var();
let foo_ty = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(5),
fields: [("a".into(), v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
params: [(id, v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
});
top_level_defs.push(
RwLock::new(TopLevelDef::Class {
object_id: DefinitionId(5),
type_vars: vec![v0],
fields: [("a".into(), v0)].into(),
methods: Default::default(),
ancestors: Default::default(),
resolver: None,
})
.into(),
);
identifier_mapping.insert(
"Foo".into(),
unifier.add_ty(TypeEnum::TFunc(
FunSignature {
args: vec![],
ret: foo_ty,
vars: [(id, v0)].iter().cloned().collect(),
}
.into(),
)),
);
let fun = unifier.add_ty(TypeEnum::TFunc(
FunSignature { args: vec![], ret: int32, vars: Default::default() }.into(),
));
let bar = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(6),
fields: [("a".into(), int32), ("b".into(), fun)]
.iter()
.cloned()
.collect::<HashMap<_, _>>()
.into(),
params: Default::default(),
});
top_level_defs.push(
RwLock::new(TopLevelDef::Class {
object_id: DefinitionId(6),
type_vars: Default::default(),
fields: [("a".into(), int32), ("b".into(), fun)].into(),
methods: Default::default(),
ancestors: Default::default(),
resolver: None,
})
.into(),
);
identifier_mapping.insert(
"Bar".into(),
unifier.add_ty(TypeEnum::TFunc(
FunSignature { args: vec![], ret: bar, vars: Default::default() }.into(),
)),
);
let bar2 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(7),
fields: [("a".into(), bool), ("b".into(), fun)]
.iter()
.cloned()
.collect::<HashMap<_, _>>()
.into(),
params: Default::default(),
});
top_level_defs.push(
RwLock::new(TopLevelDef::Class {
object_id: DefinitionId(7),
type_vars: Default::default(),
fields: [("a".into(), bool), ("b".into(), fun)].into(),
methods: Default::default(),
ancestors: Default::default(),
resolver: None,
})
.into(),
);
identifier_mapping.insert(
"Bar2".into(),
unifier.add_ty(TypeEnum::TFunc(
FunSignature { args: vec![], ret: bar2, vars: Default::default() }.into(),
)),
);
let class_names = [("Bar".into(), bar), ("Bar2".into(), bar2)].iter().cloned().collect();
let id_to_name = [
(0, "int32".to_string()),
(1, "int64".to_string()),
(2, "float".to_string()),
(3, "bool".to_string()),
(4, "none".to_string()),
(5, "Foo".to_string()),
(6, "Bar".to_string()),
(7, "Bar2".to_string()),
]
.iter()
.cloned()
.collect();
let top_level = TopLevelContext {
definitions: Arc::new(RwLock::new(top_level_defs)),
unifiers: Default::default(),
};
let resolver = Arc::new(Resolver {
id_to_type: identifier_mapping.clone(),
id_to_def: [
("Foo".into(), DefinitionId(5)),
("Bar".into(), DefinitionId(6)),
("Bar2".into(), DefinitionId(7)),
]
.iter()
.cloned()
.collect(),
class_names,
}) as Arc<dyn SymbolResolver + Send + Sync>;
TestEnvironment {
unifier,
top_level,
function_data: FunctionData {
resolver,
bound_variables: Vec::new(),
return_type: None,
},
primitives,
id_to_name,
identifier_mapping,
virtual_checks: Vec::new(),
calls: HashMap::new(),
}
}
fn get_inferencer(&mut self) -> Inferencer {
Inferencer {
top_level: &self.top_level,
function_data: &mut self.function_data,
unifier: &mut self.unifier,
variable_mapping: Default::default(),
primitives: &mut self.primitives,
virtual_checks: &mut self.virtual_checks,
calls: &mut self.calls,
}
}
}
#[test_case(indoc! {"
a = 1234
b = int64(2147483648)
c = 1.234
d = True
"},
[("a", "int32"), ("b", "int64"), ("c", "float"), ("d", "bool")].iter().cloned().collect(),
&[]
; "primitives test")]
#[test_case(indoc! {"
a = lambda x, y: x
b = lambda x: a(x, x)
c = 1.234
d = b(c)
"},
[("a", "fn[[x=float, y=float], float]"), ("b", "fn[[x=float], float]"), ("c", "float"), ("d", "float")].iter().cloned().collect(),
&[]
; "lambda test")]
#[test_case(indoc! {"
a = lambda x: x
b = lambda x: x
foo1 = Foo()
foo2 = Foo()
c = a(foo1.a)
d = b(foo2.a)
a(True)
b(123)
"},
[("a", "fn[[x=bool], bool]"), ("b", "fn[[x=int32], int32]"), ("c", "bool"),
("d", "int32"), ("foo1", "Foo[bool]"), ("foo2", "Foo[int32]")].iter().cloned().collect(),
&[]
; "obj test")]
#[test_case(indoc! {"
f = lambda x: True
a = [1, 2, 3]
b = [f(x) for x in a if f(x)]
"},
[("a", "list[int32]"), ("b", "list[bool]"), ("f", "fn[[x=int32], bool]")].iter().cloned().collect(),
&[]
; "listcomp test")]
#[test_case(indoc! {"
a = virtual(Bar(), Bar)
b = a.b()
a = virtual(Bar2())
"},
[("a", "virtual[Bar]"), ("b", "int32")].iter().cloned().collect(),
&[("Bar", "Bar"), ("Bar2", "Bar")]
; "virtual test")]
#[test_case(indoc! {"
a = [virtual(Bar(), Bar), virtual(Bar2())]
b = [x.b() for x in a]
"},
[("a", "list[virtual[Bar]]"), ("b", "list[int32]")].iter().cloned().collect(),
&[("Bar", "Bar"), ("Bar2", "Bar")]
; "virtual list test")]
fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &str)]) {
println!("source:\n{}", source);
let mut env = TestEnvironment::new();
let id_to_name = std::mem::take(&mut env.id_to_name);
let mut defined_identifiers: Vec<_> = env.identifier_mapping.keys().cloned().collect();
defined_identifiers.push("virtual".to_string());
let mut inferencer = env.get_inferencer();
let statements = parse_program(source).unwrap();
let statements = statements
.into_iter()
.map(|v| inferencer.fold_stmt(v))
.collect::<Result<Vec<_>, _>>()
.unwrap();
inferencer.check_block(&statements, &mut defined_identifiers).unwrap();
for (k, v) in inferencer.variable_mapping.iter() {
let name = inferencer.unifier.stringify(
*v,
&mut |v| id_to_name.get(&v).unwrap().clone(),
&mut |v| format!("v{}", v),
);
println!("{}: {}", k, name);
}
for (k, v) in mapping.iter() {
let ty = inferencer.variable_mapping.get(*k).unwrap();
let name = inferencer.unifier.stringify(
*ty,
&mut |v| id_to_name.get(&v).unwrap().clone(),
&mut |v| format!("v{}", v),
);
assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name));
}
assert_eq!(inferencer.virtual_checks.len(), virtuals.len());
for ((a, b), (x, y)) in zip(inferencer.virtual_checks.iter(), virtuals) {
let a = inferencer.unifier.stringify(
*a,
&mut |v| id_to_name.get(&v).unwrap().clone(),
&mut |v| format!("v{}", v),
);
let b = inferencer.unifier.stringify(
*b,
&mut |v| id_to_name.get(&v).unwrap().clone(),
&mut |v| format!("v{}", v),
);
assert_eq!(&a, x);
assert_eq!(&b, y);
}
}
#[test_case(indoc! {"
a = 2
b = 2
c = a + b
d = a - b
e = a * b
f = a / b
g = a // b
h = a % b
"},
[("a", "int32"),
("b", "int32"),
("c", "int32"),
("d", "int32"),
("e", "int32"),
("f", "float"),
("g", "int32"),
("h", "int32")].iter().cloned().collect()
; "int32")]
#[test_case(
indoc! {"
a = 2.4
b = 3.6
c = a + b
d = a - b
e = a * b
f = a / b
g = a // b
h = a % b
i = a ** b
ii = 3
j = a ** b
"},
[("a", "float"),
("b", "float"),
("c", "float"),
("d", "float"),
("e", "float"),
("f", "float"),
("g", "float"),
("h", "float"),
("i", "float"),
("ii", "int32"),
("j", "float")].iter().cloned().collect()
; "float"
)]
#[test_case(
indoc! {"
a = int64(12312312312)
b = int64(24242424424)
c = a + b
d = a - b
e = a * b
f = a / b
g = a // b
h = a % b
i = a == b
j = a > b
k = a < b
l = a != b
"},
[("a", "int64"),
("b", "int64"),
("c", "int64"),
("d", "int64"),
("e", "int64"),
("f", "float"),
("g", "int64"),
("h", "int64"),
("i", "bool"),
("j", "bool"),
("k", "bool"),
("l", "bool")].iter().cloned().collect()
; "int64"
)]
#[test_case(
indoc! {"
a = True
b = False
c = a == b
d = not a
e = a != b
"},
[("a", "bool"),
("b", "bool"),
("c", "bool"),
("d", "bool"),
("e", "bool")].iter().cloned().collect()
; "boolean"
)]
fn test_primitive_magic_methods(source: &str, mapping: HashMap<&str, &str>) {
println!("source:\n{}", source);
let mut env = TestEnvironment::basic_test_env();
let id_to_name = std::mem::take(&mut env.id_to_name);
let mut defined_identifiers: Vec<_> = env.identifier_mapping.keys().cloned().collect();
defined_identifiers.push("virtual".to_string());
let mut inferencer = env.get_inferencer();
let statements = parse_program(source).unwrap();
let statements = statements
.into_iter()
.map(|v| inferencer.fold_stmt(v))
.collect::<Result<Vec<_>, _>>()
.unwrap();
inferencer.check_block(&statements, &mut defined_identifiers).unwrap();
for (k, v) in inferencer.variable_mapping.iter() {
let name = inferencer.unifier.stringify(
*v,
&mut |v| id_to_name.get(&v).unwrap().clone(),
&mut |v| format!("v{}", v),
);
println!("{}: {}", k, name);
}
for (k, v) in mapping.iter() {
let ty = inferencer.variable_mapping.get(*k).unwrap();
let name = inferencer.unifier.stringify(
*ty,
&mut |v| id_to_name.get(&v).unwrap().clone(),
&mut |v| format!("v{}", v),
);
assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name));
}
}
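A minimal sketch of how another case could be added to this test module, reusing the helpers defined above (`TestEnvironment::basic_test_env`, `get_inferencer`) and the imports already present in the file; the source snippet and the expected `int32` results are illustrative only.
#[test]
fn test_inference_sketch() {
    // hedged sketch: run the inferencer on a tiny hypothetical program and
    // check that both bindings come out as int32
    let mut env = TestEnvironment::basic_test_env();
    let id_to_name = std::mem::take(&mut env.id_to_name);
    let mut defined_identifiers: Vec<_> = env.identifier_mapping.keys().cloned().collect();
    let mut inferencer = env.get_inferencer();
    let statements = parse_program("a = 1\nb = a + 2").unwrap();
    let statements = statements
        .into_iter()
        .map(|v| inferencer.fold_stmt(v))
        .collect::<Result<Vec<_>, _>>()
        .unwrap();
    inferencer.check_block(&statements, &mut defined_identifiers).unwrap();
    for key in ["a", "b"].iter() {
        let ty = inferencer.variable_mapping.get(*key).unwrap();
        let name = inferencer.unifier.stringify(
            *ty,
            &mut |v| id_to_name.get(&v).unwrap().clone(),
            &mut |v| format!("v{}", v),
        );
        assert_eq!(name, "int32");
    }
}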

View File

@ -0,0 +1,947 @@
use itertools::{chain, zip, Itertools};
use std::borrow::Cow;
use std::cell::RefCell;
use std::collections::HashMap;
use std::iter::once;
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use super::unification_table::{UnificationKey, UnificationTable};
use crate::symbol_resolver::SymbolValue;
use crate::top_level::DefinitionId;
#[cfg(test)]
mod test;
/// Handle for a type, implemented as a key in the unification table.
pub type Type = UnificationKey;
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct CallId(usize);
pub type Mapping<K, V = Type> = HashMap<K, V>;
type VarMap = Mapping<u32>;
#[derive(Clone)]
pub struct Call {
pub posargs: Vec<Type>,
pub kwargs: HashMap<String, Type>,
pub ret: Type,
pub fun: RefCell<Option<Type>>,
}
#[derive(Clone)]
pub struct FuncArg {
pub name: String,
pub ty: Type,
pub default_value: Option<SymbolValue>,
}
#[derive(Clone)]
pub struct FunSignature {
pub args: Vec<FuncArg>,
pub ret: Type,
pub vars: VarMap,
}
#[derive(Clone)]
pub enum TypeVarMeta {
Generic,
Sequence(RefCell<Mapping<i32>>),
Record(RefCell<Mapping<String>>),
}
#[derive(Clone)]
pub enum TypeEnum {
TRigidVar {
id: u32,
},
TVar {
id: u32,
meta: TypeVarMeta,
// empty indicates no restriction
range: RefCell<Vec<Type>>,
},
TTuple {
ty: Vec<Type>,
},
TList {
ty: Type,
},
TObj {
obj_id: DefinitionId,
fields: RefCell<Mapping<String>>,
params: RefCell<VarMap>,
},
TVirtual {
ty: Type,
},
TCall(RefCell<Vec<CallId>>),
TFunc(RefCell<FunSignature>),
}
impl TypeEnum {
pub fn get_type_name(&self) -> &'static str {
match self {
TypeEnum::TRigidVar { .. } => "TRigidVar",
TypeEnum::TVar { .. } => "TVar",
TypeEnum::TTuple { .. } => "TTuple",
TypeEnum::TList { .. } => "TList",
TypeEnum::TObj { .. } => "TObj",
TypeEnum::TVirtual { .. } => "TVirtual",
TypeEnum::TCall { .. } => "TCall",
TypeEnum::TFunc { .. } => "TFunc",
}
}
}
pub type SharedUnifier = Arc<Mutex<(UnificationTable<TypeEnum>, u32, Vec<Call>)>>;
pub struct Unifier {
unification_table: UnificationTable<Rc<TypeEnum>>,
calls: Vec<Rc<Call>>,
var_id: u32,
}
impl Unifier {
/// Get an empty unifier
pub fn new() -> Unifier {
Unifier { unification_table: UnificationTable::new(), var_id: 0, calls: Vec::new() }
}
/// Determine whether two types belong to the same equivalence class, i.e. have been unified.
pub fn unioned(&mut self, a: Type, b: Type) -> bool {
self.unification_table.unioned(a, b)
}
pub fn from_shared_unifier(unifier: &SharedUnifier) -> Unifier {
let lock = unifier.lock().unwrap();
Unifier {
unification_table: UnificationTable::from_send(&lock.0),
var_id: lock.1,
calls: lock.2.iter().map(|v| Rc::new(v.clone())).collect_vec(),
}
}
pub fn get_shared_unifier(&self) -> SharedUnifier {
Arc::new(Mutex::new((
self.unification_table.get_send(),
self.var_id,
self.calls.iter().map(|v| v.as_ref().clone()).collect_vec(),
)))
}
/// Register a type with the unifier.
/// Returns its key in the unification table.
pub fn add_ty(&mut self, a: TypeEnum) -> Type {
self.unification_table.new_key(Rc::new(a))
}
pub fn add_record(&mut self, fields: Mapping<String>) -> Type {
let id = self.var_id + 1;
self.var_id += 1;
self.add_ty(TypeEnum::TVar {
id,
range: vec![].into(),
meta: TypeVarMeta::Record(fields.into()),
})
}
pub fn add_call(&mut self, call: Call) -> CallId {
let id = CallId(self.calls.len());
self.calls.push(Rc::new(call));
id
}
pub fn get_representative(&mut self, ty: Type) -> Type {
self.unification_table.get_representative(ty)
}
pub fn add_sequence(&mut self, sequence: Mapping<i32>) -> Type {
let id = self.var_id + 1;
self.var_id += 1;
self.add_ty(TypeEnum::TVar {
id,
range: vec![].into(),
meta: TypeVarMeta::Sequence(sequence.into()),
})
}
/// Get the TypeEnum of a type.
pub fn get_ty(&mut self, a: Type) -> Rc<TypeEnum> {
self.unification_table.probe_value(a).clone()
}
pub fn get_fresh_rigid_var(&mut self) -> (Type, u32) {
let id = self.var_id + 1;
self.var_id += 1;
(self.add_ty(TypeEnum::TRigidVar { id }), id)
}
pub fn get_fresh_var(&mut self) -> (Type, u32) {
self.get_fresh_var_with_range(&[])
}
/// Get a fresh type variable.
pub fn get_fresh_var_with_range(&mut self, range: &[Type]) -> (Type, u32) {
let id = self.var_id + 1;
self.var_id += 1;
let range = range.to_vec().into();
(self.add_ty(TypeEnum::TVar { id, range, meta: TypeVarMeta::Generic }), id)
}
/// Unification does not unify rigid variables with other types, but we want to do this when
/// instantiating functions, so we make it explicit.
pub fn replace_rigid_var(&mut self, rigid: Type, b: Type) {
assert!(matches!(&*self.get_ty(rigid), TypeEnum::TRigidVar { .. }));
self.set_a_to_b(rigid, b);
}
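/// Compute the possible concrete instantiations of a type by enumerating the value
/// ranges of the type variables it contains. Returns None if the type contains no
/// constrained type variables, i.e. it already stands for itself.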
pub fn get_instantiations(&mut self, ty: Type) -> Option<Vec<Type>> {
match &*self.get_ty(ty) {
TypeEnum::TVar { range, .. } => {
let range = range.borrow();
if range.is_empty() {
None
} else {
Some(
range
.iter()
.map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty]))
.flatten()
.collect_vec(),
)
}
}
TypeEnum::TList { ty } => self
.get_instantiations(*ty)
.map(|ty| ty.iter().map(|&ty| self.add_ty(TypeEnum::TList { ty })).collect_vec()),
TypeEnum::TVirtual { ty } => self.get_instantiations(*ty).map(|ty| {
ty.iter().map(|&ty| self.add_ty(TypeEnum::TVirtual { ty })).collect_vec()
}),
TypeEnum::TTuple { ty } => {
let tuples = ty
.iter()
.map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty]))
.multi_cartesian_product()
.collect_vec();
if tuples.len() == 1 {
None
} else {
Some(
tuples.into_iter().map(|ty| self.add_ty(TypeEnum::TTuple { ty })).collect(),
)
}
}
TypeEnum::TObj { params, .. } => {
let params = params.borrow();
let (keys, params): (Vec<&u32>, Vec<&Type>) = params.iter().unzip();
let params = params
.into_iter()
.map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty]))
.multi_cartesian_product()
.collect_vec();
if params.len() <= 1 {
None
} else {
Some(
params
.into_iter()
.map(|params| {
self.subst(
ty,
&zip(keys.iter().cloned().cloned(), params.iter().cloned())
.collect(),
)
.unwrap_or(ty)
})
.collect(),
)
}
}
_ => None,
}
}
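/// Check whether a type is concrete, i.e. it contains no free type variables other
/// than those in `allowed_typevars` (rigid variables are considered concrete).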
pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool {
use TypeEnum::*;
match &*self.get_ty(a) {
TRigidVar { .. } => true,
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
TCall { .. } => false,
TList { ty } => self.is_concrete(*ty, allowed_typevars),
TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)),
TObj { params: vars, .. } => {
vars.borrow().values().all(|ty| self.is_concrete(*ty, allowed_typevars))
}
// functions are instantiated at each call site, so the function type can contain
// type variables.
TFunc { .. } => true,
TVirtual { ty } => self.is_concrete(*ty, allowed_typevars),
}
}
pub fn unify(&mut self, a: Type, b: Type) -> Result<(), String> {
if self.unification_table.unioned(a, b) {
Ok(())
} else {
self.unify_impl(a, b, false)
}
}
fn unify_impl(&mut self, a: Type, b: Type, swapped: bool) -> Result<(), String> {
use TypeEnum::*;
use TypeVarMeta::*;
let (ty_a, ty_b) = {
(
self.unification_table.probe_value(a).clone(),
self.unification_table.probe_value(b).clone(),
)
};
match (&*ty_a, &*ty_b) {
(TVar { meta: meta1, range: range1, .. }, TVar { meta: meta2, range: range2, .. }) => {
self.occur_check(a, b)?;
self.occur_check(b, a)?;
match (meta1, meta2) {
(Generic, _) => {}
(_, Generic) => {
return self.unify_impl(b, a, true);
}
(Record(fields1), Record(fields2)) => {
let mut fields2 = fields2.borrow_mut();
for (key, value) in fields1.borrow().iter() {
if let Some(ty) = fields2.get(key) {
self.unify(*ty, *value)?;
} else {
fields2.insert(key.clone(), *value);
}
}
}
(Sequence(map1), Sequence(map2)) => {
let mut map2 = map2.borrow_mut();
for (key, value) in map1.borrow().iter() {
if let Some(ty) = map2.get(key) {
self.unify(*ty, *value)?;
} else {
map2.insert(*key, *value);
}
}
}
_ => {
return Err("Incompatible".to_string());
}
}
let range1 = range1.borrow();
// the new range is the intersection of the two ranges;
// an empty range indicates no constraint
if !range1.is_empty() {
let old_range2 = range2.take();
let mut range2 = range2.borrow_mut();
if old_range2.is_empty() {
range2.extend_from_slice(&range1);
}
for v1 in old_range2.iter() {
for v2 in range1.iter() {
if let Ok(result) = self.get_intersection(*v1, *v2) {
range2.push(result.unwrap_or(*v2));
}
}
}
if range2.is_empty() {
return Err(
"cannot unify type variables with incompatible value range".to_string()
);
}
}
self.set_a_to_b(a, b);
}
(TVar { meta: Generic, id, range, .. }, _) => {
self.occur_check(a, b)?;
// We check the range of the type variable to see if unification is allowed.
// Note that although b may be compatible with a, we may have to constrain type
// variables in b to make sure that instantiations of b would always be compatible
// with a.
// The return value x of check_var_compatibility would be a new type that is
// guaranteed to be compatible with a under all possible instantiations. So we
// unify x with b to recursively apply the constraints, and then set a to x.
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
self.unify(x, b)?;
self.set_a_to_b(a, x);
}
(TVar { meta: Sequence(map), id, range, .. }, TTuple { ty }) => {
self.occur_check(a, b)?;
let len = ty.len() as i32;
for (k, v) in map.borrow().iter() {
// handle negative index
let ind = if *k < 0 { len + *k } else { *k };
if ind >= len || ind < 0 {
return Err(format!(
"Tuple index out of range. (Length: {}, Index: {})",
len, k
));
}
self.unify(*v, ty[ind as usize])?;
}
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
self.unify(x, b)?;
self.set_a_to_b(a, x);
}
(TVar { meta: Sequence(map), id, range, .. }, TList { ty }) => {
self.occur_check(a, b)?;
for v in map.borrow().values() {
self.unify(*v, *ty)?;
}
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
self.unify(x, b)?;
self.set_a_to_b(a, x);
}
(TTuple { ty: ty1 }, TTuple { ty: ty2 }) => {
if ty1.len() != ty2.len() {
return Err(format!(
"Cannot unify tuples with length {} and {}",
ty1.len(),
ty2.len()
));
}
for (x, y) in ty1.iter().zip(ty2.iter()) {
self.unify(*x, *y)?;
}
self.set_a_to_b(a, b);
}
(TList { ty: ty1 }, TList { ty: ty2 }) => {
self.unify(*ty1, *ty2)?;
self.set_a_to_b(a, b);
}
(TVar { meta: Record(map), id, range, .. }, TObj { fields, .. }) => {
self.occur_check(a, b)?;
for (k, v) in map.borrow().iter() {
let ty = fields
.borrow()
.get(k)
.copied()
.ok_or_else(|| format!("No such attribute {}", k))?;
self.unify(ty, *v)?;
}
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
self.unify(x, b)?;
self.set_a_to_b(a, x);
}
(TVar { meta: Record(map), id, range, .. }, TVirtual { ty }) => {
self.occur_check(a, b)?;
let ty = self.get_ty(*ty);
if let TObj { fields, .. } = ty.as_ref() {
for (k, v) in map.borrow().iter() {
let ty = fields
.borrow()
.get(k)
.copied()
.ok_or_else(|| format!("No such attribute {}", k))?;
if !matches!(self.get_ty(ty).as_ref(), TFunc { .. }) {
return Err(format!("Cannot access field {} for virtual type", k));
}
self.unify(*v, ty)?;
}
} else {
// require annotation...
return Err("Requires type annotation for virtual".to_string());
}
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
self.unify(x, b)?;
self.set_a_to_b(a, x);
}
(
TObj { obj_id: id1, params: params1, .. },
TObj { obj_id: id2, params: params2, .. },
) => {
if id1 != id2 {
return Err(format!("Cannot unify objects with ID {} and {}", id1.0, id2.0));
}
for (x, y) in zip(params1.borrow().values(), params2.borrow().values()) {
self.unify(*x, *y)?;
}
self.set_a_to_b(a, b);
}
(TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => {
self.unify(*ty1, *ty2)?;
self.set_a_to_b(a, b);
}
(TCall(calls1), TCall(calls2)) => {
// we do not unify individual calls; instead, we defer until unification with a
// function definition.
calls2.borrow_mut().extend_from_slice(&calls1.borrow());
}
(TCall(calls), TFunc(signature)) => {
self.occur_check(a, b)?;
let required: Vec<String> = signature
.borrow()
.args
.iter()
.filter(|v| v.default_value.is_none())
.map(|v| v.name.clone())
.rev()
.collect();
// we unify every call with the function signature.
for c in calls.borrow().iter() {
let Call { posargs, kwargs, ret, fun } = &*self.calls[c.0].clone();
let instantiated = self.instantiate_fun(b, &*signature.borrow());
let r = self.get_ty(instantiated);
let r = r.as_ref();
let signature;
if let TypeEnum::TFunc(s) = &*r {
signature = s;
} else {
unreachable!();
}
// we check to make sure that all required arguments (those without default
// values) are provided, and that no argument is provided more than once.
let mut required = required.clone();
let mut all_names: Vec<_> = signature
.borrow()
.args
.iter()
.map(|v| (v.name.clone(), v.ty))
.rev()
.collect();
for (i, t) in posargs.iter().enumerate() {
if signature.borrow().args.len() <= i {
return Err("Too many arguments.".to_string());
}
if !required.is_empty() {
required.pop();
}
self.unify(all_names.pop().unwrap().1, *t)?;
}
for (k, t) in kwargs.iter() {
if let Some(i) = required.iter().position(|v| v == k) {
required.remove(i);
}
let i = all_names
.iter()
.position(|v| &v.0 == k)
.ok_or_else(|| format!("Unknown keyword argument {}", k))?;
self.unify(all_names.remove(i).1, *t)?;
}
if !required.is_empty() {
return Err("Expected more arguments".to_string());
}
self.unify(*ret, signature.borrow().ret)?;
*fun.borrow_mut() = Some(instantiated);
}
self.set_a_to_b(a, b);
}
(TFunc(sign1), TFunc(sign2)) => {
let (sign1, sign2) = (&*sign1.borrow(), &*sign2.borrow());
if !sign1.vars.is_empty() || !sign2.vars.is_empty() {
return Err("Polymorphic function pointer is prohibited.".to_string());
}
if sign1.args.len() != sign2.args.len() {
return Err("Functions differ in number of parameters.".to_string());
}
for (x, y) in sign1.args.iter().zip(sign2.args.iter()) {
if x.name != y.name {
return Err("Functions differ in parameter names.".to_string());
}
if x.default_value != y.default_value {
return Err("Functions differ in optional parameters.".to_string());
}
self.unify(x.ty, y.ty)?;
}
self.unify(sign1.ret, sign2.ret)?;
self.set_a_to_b(a, b);
}
_ => {
if swapped {
return self.incompatible_types(&*ty_a, &*ty_b);
} else {
self.unify_impl(b, a, true)?;
}
}
}
Ok(())
}
/// Get a string representation of the type.
pub fn stringify<F, G>(&mut self, ty: Type, obj_to_name: &mut F, var_to_name: &mut G) -> String
where
F: FnMut(usize) -> String,
G: FnMut(u32) -> String,
{
use TypeVarMeta::*;
let ty = self.unification_table.probe_value(ty).clone();
match ty.as_ref() {
TypeEnum::TRigidVar { id } => var_to_name(*id),
TypeEnum::TVar { id, meta: Generic, .. } => var_to_name(*id),
TypeEnum::TVar { meta: Sequence(map), .. } => {
let fields = map
.borrow()
.iter()
.map(|(k, v)| format!("{}={}", k, self.stringify(*v, obj_to_name, var_to_name)))
.join(", ");
format!("seq[{}]", fields)
}
TypeEnum::TVar { meta: Record(fields), .. } => {
let fields = fields
.borrow()
.iter()
.map(|(k, v)| format!("{}={}", k, self.stringify(*v, obj_to_name, var_to_name)))
.join(", ");
format!("record[{}]", fields)
}
TypeEnum::TTuple { ty } => {
let mut fields = ty.iter().map(|v| self.stringify(*v, obj_to_name, var_to_name));
format!("tuple[{}]", fields.join(", "))
}
TypeEnum::TList { ty } => {
format!("list[{}]", self.stringify(*ty, obj_to_name, var_to_name))
}
TypeEnum::TVirtual { ty } => {
format!("virtual[{}]", self.stringify(*ty, obj_to_name, var_to_name))
}
TypeEnum::TObj { obj_id, params, .. } => {
let name = obj_to_name(obj_id.0);
let params = params.borrow();
if !params.is_empty() {
let mut params =
params.values().map(|v| self.stringify(*v, obj_to_name, var_to_name));
format!("{}[{}]", name, params.join(", "))
} else {
name
}
}
TypeEnum::TCall { .. } => "call".to_owned(),
TypeEnum::TFunc(signature) => {
let params = signature
.borrow()
.args
.iter()
.map(|arg| {
format!("{}={}", arg.name, self.stringify(arg.ty, obj_to_name, var_to_name))
})
.join(", ");
let ret = self.stringify(signature.borrow().ret, obj_to_name, var_to_name);
format!("fn[[{}], {}]", params, ret)
}
}
}
fn set_a_to_b(&mut self, a: Type, b: Type) {
// unify a and b together, and set the value to b's value.
let table = &mut self.unification_table;
let ty_b = table.probe_value(b).clone();
table.unify(a, b);
table.set_value(a, ty_b)
}
fn incompatible_types(&self, a: &TypeEnum, b: &TypeEnum) -> Result<(), String> {
Err(format!("Cannot unify {} with {}", a.get_type_name(), b.get_type_name()))
}
/// Instantiate a function by replacing its bound type variables with fresh ones.
/// If the signature has already been instantiated (its variables no longer refer to
/// the generic variables recorded in `vars`), the original type is returned unchanged.
fn instantiate_fun(&mut self, ty: Type, fun: &FunSignature) -> Type {
let mut instantiated = false;
let mut vars = Vec::new();
for (k, v) in fun.vars.iter() {
if let TypeEnum::TVar { id, range, .. } =
self.unification_table.probe_value(*v).as_ref()
{
if k != id {
instantiated = true;
break;
}
// actually, if the check above succeeded, the function should be uninstantiated,
// so the cloned ranges will be used and the clone is not wasted.
vars.push((*k, range.clone()));
} else {
instantiated = true;
break;
}
}
if instantiated {
ty
} else {
let mapping = vars
.into_iter()
.map(|(k, range)| (k, self.get_fresh_var_with_range(range.borrow().as_ref()).0))
.collect();
self.subst(ty, &mapping).unwrap_or(ty)
}
}
/// Substitute the type variables within a type according to `mapping`.
/// If this returns Some(T), T is the substituted type.
/// If this returns None, no substitution is needed and the original type
/// can be used as the result.
pub fn subst(&mut self, a: Type, mapping: &VarMap) -> Option<Type> {
use TypeVarMeta::*;
let ty = self.unification_table.probe_value(a).clone();
// this function is only called when we instantiate functions.
// A function type signature should ONLY contain concrete types and type
// variables, i.e. things like TRecord and TCall should not occur, so it is
// safe to leave substitution unimplemented for those variants.
match &*ty {
TypeEnum::TRigidVar { .. } => None,
TypeEnum::TVar { id, meta: Generic, .. } => mapping.get(&id).cloned(),
TypeEnum::TTuple { ty } => {
let mut new_ty = Cow::from(ty);
for (i, t) in ty.iter().enumerate() {
if let Some(t1) = self.subst(*t, mapping) {
new_ty.to_mut()[i] = t1;
}
}
if matches!(new_ty, Cow::Owned(_)) {
Some(self.add_ty(TypeEnum::TTuple { ty: new_ty.into_owned() }))
} else {
None
}
}
TypeEnum::TList { ty } => {
self.subst(*ty, mapping).map(|t| self.add_ty(TypeEnum::TList { ty: t }))
}
TypeEnum::TVirtual { ty } => {
self.subst(*ty, mapping).map(|t| self.add_ty(TypeEnum::TVirtual { ty: t }))
}
TypeEnum::TObj { obj_id, fields, params } => {
// Type variables appearing in field types must also appear in the type parameters.
// If the mapping does not contain any of the type variables in the
// parameter list, we don't need to substitute the fields.
// This also prevents infinite substitution...
let params = params.borrow();
let need_subst = params.values().any(|v| {
let ty = self.unification_table.probe_value(*v);
if let TypeEnum::TVar { id, .. } = ty.as_ref() {
mapping.contains_key(&id)
} else {
false
}
});
if need_subst {
let obj_id = *obj_id;
let params = self.subst_map(&params, mapping).unwrap_or_else(|| params.clone());
let fields = self
.subst_map(&fields.borrow(), mapping)
.unwrap_or_else(|| fields.borrow().clone());
Some(self.add_ty(TypeEnum::TObj {
obj_id,
params: params.into(),
fields: fields.into(),
}))
} else {
None
}
}
TypeEnum::TFunc(sig) => {
let FunSignature { args, ret, vars: params } = &*sig.borrow();
let new_params = self.subst_map(params, mapping);
let new_ret = self.subst(*ret, mapping);
let mut new_args = Cow::from(args);
for (i, t) in args.iter().enumerate() {
if let Some(t1) = self.subst(t.ty, mapping) {
let mut t = t.clone();
t.ty = t1;
new_args.to_mut()[i] = t;
}
}
if new_params.is_some() || new_ret.is_some() || matches!(new_args, Cow::Owned(..)) {
let params = new_params.unwrap_or_else(|| params.clone());
let ret = new_ret.unwrap_or_else(|| *ret);
let args = new_args.into_owned();
Some(
self.add_ty(TypeEnum::TFunc(
FunSignature { args, ret, vars: params }.into(),
)),
)
} else {
None
}
}
_ => unimplemented!(),
}
}
fn subst_map<K>(&mut self, map: &Mapping<K>, mapping: &VarMap) -> Option<Mapping<K>>
where
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
{
let mut map2 = None;
for (k, v) in map.iter() {
if let Some(v1) = self.subst(*v, mapping) {
if map2.is_none() {
map2 = Some(map.clone());
}
*map2.as_mut().unwrap().get_mut(k).unwrap() = v1;
}
}
map2
}
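/// Occurs check: ensure that type variable `a` does not occur inside type `b`,
/// since recursive types are prohibited.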
fn occur_check(&mut self, a: Type, b: Type) -> Result<(), String> {
use TypeVarMeta::*;
if self.unification_table.unioned(a, b) {
return Err("Recursive type is prohibited.".to_owned());
}
let ty = self.unification_table.probe_value(b).clone();
match ty.as_ref() {
TypeEnum::TRigidVar { .. } | TypeEnum::TVar { meta: Generic, .. } => {}
TypeEnum::TVar { meta: Sequence(map), .. } => {
for t in map.borrow().values() {
self.occur_check(a, *t)?;
}
}
TypeEnum::TVar { meta: Record(map), .. } => {
for t in map.borrow().values() {
self.occur_check(a, *t)?;
}
}
TypeEnum::TCall(calls) => {
let call_store = self.calls.clone();
for t in calls
.borrow()
.iter()
.map(|call| {
let call = call_store[call.0].as_ref();
chain!(call.posargs.iter(), call.kwargs.values(), once(&call.ret))
})
.flatten()
{
self.occur_check(a, *t)?;
}
}
TypeEnum::TTuple { ty } => {
for t in ty.iter() {
self.occur_check(a, *t)?;
}
}
TypeEnum::TList { ty } | TypeEnum::TVirtual { ty } => {
self.occur_check(a, *ty)?;
}
TypeEnum::TObj { params: map, .. } => {
for t in map.borrow().values() {
self.occur_check(a, *t)?;
}
}
TypeEnum::TFunc(sig) => {
let FunSignature { args, ret, vars: params } = &*sig.borrow();
for t in chain!(args.iter().map(|v| &v.ty), params.values(), once(ret)) {
self.occur_check(a, *t)?;
}
}
}
Ok(())
}
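/// Compute the intersection of two types, used when restricting the value range of
/// type variables. Returns Ok(Some(ty)) if a new, more constrained type is required,
/// Ok(None) if no new type needs to be created, and Err(()) if the two types are
/// incompatible.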
fn get_intersection(&mut self, a: Type, b: Type) -> Result<Option<Type>, ()> {
use TypeEnum::*;
let x = self.get_ty(a);
let y = self.get_ty(b);
match (x.as_ref(), y.as_ref()) {
(TVar { range: range1, .. }, TVar { meta, range: range2, .. }) => {
// we should restrict range2
let range1 = range1.borrow();
// the new range is the intersection of the two ranges;
// an empty range indicates no constraint
if !range1.is_empty() {
let range2 = range2.borrow();
let mut range = Vec::new();
if range2.is_empty() {
range.extend_from_slice(&range1);
}
for v1 in range2.iter() {
for v2 in range1.iter() {
let result = self.get_intersection(*v1, *v2);
if let Ok(result) = result {
range.push(result.unwrap_or(*v2));
}
}
}
if range.is_empty() {
Err(())
} else {
let id = self.var_id + 1;
self.var_id += 1;
let ty = TVar { id, meta: meta.clone(), range: range.into() };
Ok(Some(self.unification_table.new_key(ty.into())))
}
} else {
Ok(Some(b))
}
}
(_, TVar { range, .. }) => {
// range should be restricted to the left hand side
let range = range.borrow();
if range.is_empty() {
Ok(Some(a))
} else {
for v in range.iter() {
let result = self.get_intersection(a, *v);
if let Ok(result) = result {
return Ok(result.or(Some(a)));
}
}
Err(())
}
}
(TVar { id, range, .. }, _) => {
self.check_var_compatibility(*id, b, &range.borrow()).or(Err(()))
}
(TTuple { ty: ty1 }, TTuple { ty: ty2 }) => {
if ty1.len() != ty2.len() {
return Err(());
}
let mut need_new = false;
// build the element-wise intersection from scratch; starting from a clone of ty1
// would leave the original elements in place and double the tuple length
let mut ty = Vec::with_capacity(ty1.len());
for (a, b) in zip(ty1.iter(), ty2.iter()) {
let result = self.get_intersection(*a, *b)?;
ty.push(result.unwrap_or(*a));
if result.is_some() {
need_new = true;
}
}
if need_new {
Ok(Some(self.add_ty(TTuple { ty })))
} else {
Ok(None)
}
}
(TList { ty: ty1 }, TList { ty: ty2 }) => {
Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TList { ty })))
}
(TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => {
Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TVirtual { ty })))
}
(TObj { obj_id: id1, .. }, TObj { obj_id: id2, .. }) => {
if id1 == id2 {
Ok(None)
} else {
Err(())
}
}
// don't deal with function shape for now
_ => Err(()),
}
}
fn check_var_compatibility(
&mut self,
id: u32,
b: Type,
range: &[Type],
) -> Result<Option<Type>, String> {
if range.is_empty() {
return Ok(None);
}
for t in range.iter() {
let result = self.get_intersection(*t, b);
if let Ok(result) = result {
return Ok(result);
}
}
return Err(format!(
"Cannot unify type variable {} with {} due to incompatible value range",
id,
self.get_ty(b).get_type_name()
));
}
}
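As a quick illustration of the API above, a minimal sketch written as an extra test for the accompanying test module (it assumes the `use super::*;` and `std::collections::HashMap` imports already present there); the `DefinitionId` value and the expected output string are illustrative.
#[test]
fn unifier_api_sketch() {
    // hedged sketch: register an object type, build list[obj], then unify a
    // range-restricted type variable against it and render the result
    let mut unifier = Unifier::new();
    let int = unifier.add_ty(TypeEnum::TObj {
        obj_id: DefinitionId(0),
        fields: HashMap::new().into(),
        params: HashMap::new().into(),
    });
    let list_int = unifier.add_ty(TypeEnum::TList { ty: int });
    // a type variable that may only be int or list[int]
    let (v, _) = unifier.get_fresh_var_with_range(&[int, list_int]);
    unifier.unify(v, list_int).unwrap();
    let rendered =
        unifier.stringify(v, &mut |id| format!("obj{}", id), &mut |id| format!("v{}", id));
    assert_eq!(rendered, "list[obj0]");
}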

View File

@ -0,0 +1,534 @@
use super::*;
use indoc::indoc;
use itertools::Itertools;
use std::collections::HashMap;
use test_case::test_case;
impl Unifier {
/// Check whether two types are equal.
fn eq(&mut self, a: Type, b: Type) -> bool {
use TypeVarMeta::*;
if a == b {
return true;
}
let (ty_a, ty_b) = {
let table = &mut self.unification_table;
if table.unioned(a, b) {
return true;
}
(table.probe_value(a).clone(), table.probe_value(b).clone())
};
match (&*ty_a, &*ty_b) {
(
TypeEnum::TVar { meta: Generic, id: id1, .. },
TypeEnum::TVar { meta: Generic, id: id2, .. },
) => id1 == id2,
(
TypeEnum::TVar { meta: Sequence(map1), .. },
TypeEnum::TVar { meta: Sequence(map2), .. },
) => self.map_eq(&map1.borrow(), &map2.borrow()),
(TypeEnum::TTuple { ty: ty1 }, TypeEnum::TTuple { ty: ty2 }) => {
ty1.len() == ty2.len()
&& ty1.iter().zip(ty2.iter()).all(|(t1, t2)| self.eq(*t1, *t2))
}
(TypeEnum::TList { ty: ty1 }, TypeEnum::TList { ty: ty2 })
| (TypeEnum::TVirtual { ty: ty1 }, TypeEnum::TVirtual { ty: ty2 }) => {
self.eq(*ty1, *ty2)
}
(
TypeEnum::TVar { meta: Record(fields1), .. },
TypeEnum::TVar { meta: Record(fields2), .. },
) => self.map_eq(&fields1.borrow(), &fields2.borrow()),
(
TypeEnum::TObj { obj_id: id1, params: params1, .. },
TypeEnum::TObj { obj_id: id2, params: params2, .. },
) => id1 == id2 && self.map_eq(&params1.borrow(), &params2.borrow()),
// TCall and TFunc are not yet implemented
_ => false,
}
}
fn map_eq<K>(&mut self, map1: &Mapping<K>, map2: &Mapping<K>) -> bool
where
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
{
if map1.len() != map2.len() {
return false;
}
for (k, v) in map1.iter() {
if !map2.get(k).map(|v1| self.eq(*v, *v1)).unwrap_or(false) {
return false;
}
}
true
}
}
struct TestEnvironment {
pub unifier: Unifier,
type_mapping: HashMap<String, Type>,
}
impl TestEnvironment {
fn new() -> TestEnvironment {
let mut unifier = Unifier::new();
let mut type_mapping = HashMap::new();
type_mapping.insert(
"int".into(),
unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(0),
fields: HashMap::new().into(),
params: HashMap::new().into(),
}),
);
type_mapping.insert(
"float".into(),
unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(1),
fields: HashMap::new().into(),
params: HashMap::new().into(),
}),
);
type_mapping.insert(
"bool".into(),
unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(2),
fields: HashMap::new().into(),
params: HashMap::new().into(),
}),
);
let (v0, id) = unifier.get_fresh_var();
type_mapping.insert(
"Foo".into(),
unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(3),
fields: [("a".into(), v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
params: [(id, v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
}),
);
TestEnvironment { unifier, type_mapping }
}
fn parse(&mut self, typ: &str, mapping: &Mapping<String>) -> Type {
let result = self.internal_parse(typ, mapping);
assert!(result.1.is_empty());
result.0
}
fn internal_parse<'a, 'b>(
&'a mut self,
typ: &'b str,
mapping: &Mapping<String>,
) -> (Type, &'b str) {
// for testing only, so we can just panic when the input is malformed
let end = typ.find(|c| ['[', ',', ']', '='].contains(&c)).unwrap_or_else(|| typ.len());
match &typ[..end] {
"Tuple" => {
let mut s = &typ[end..];
assert!(&s[0..1] == "[");
let mut ty = Vec::new();
while &s[0..1] != "]" {
let result = self.internal_parse(&s[1..], mapping);
ty.push(result.0);
s = result.1;
}
(self.unifier.add_ty(TypeEnum::TTuple { ty }), &s[1..])
}
"List" => {
assert!(&typ[end..end + 1] == "[");
let (ty, s) = self.internal_parse(&typ[end + 1..], mapping);
assert!(&s[0..1] == "]");
(self.unifier.add_ty(TypeEnum::TList { ty }), &s[1..])
}
"Record" => {
let mut s = &typ[end..];
assert!(&s[0..1] == "[");
let mut fields = HashMap::new();
while &s[0..1] != "]" {
let eq = s.find('=').unwrap();
let key = s[1..eq].to_string();
let result = self.internal_parse(&s[eq + 1..], mapping);
fields.insert(key, result.0);
s = result.1;
}
(self.unifier.add_record(fields), &s[1..])
}
x => {
let mut s = &typ[end..];
let ty = mapping.get(x).cloned().unwrap_or_else(|| {
// `mapping` should contain type variables, while `type_mapping` contains concrete types;
// we should not resolve the type of type variables here.
let mut ty = *self.type_mapping.get(x).unwrap();
let te = self.unifier.get_ty(ty);
if let TypeEnum::TObj { params, .. } = &*te.as_ref() {
let params = params.borrow();
if !params.is_empty() {
assert!(&s[0..1] == "[");
let mut p = Vec::new();
while &s[0..1] != "]" {
let result = self.internal_parse(&s[1..], mapping);
p.push(result.0);
s = result.1;
}
s = &s[1..];
ty = self
.unifier
.subst(ty, &params.keys().cloned().zip(p.into_iter()).collect())
.unwrap_or(ty);
}
}
ty
});
(ty, s)
}
}
}
}
#[test_case(2,
&[("v1", "v2"), ("v2", "float")],
&[("v1", "float"), ("v2", "float")]
; "simple variable"
)]
#[test_case(2,
&[("v1", "List[v2]"), ("v1", "List[float]")],
&[("v1", "List[float]"), ("v2", "float")]
; "list element"
)]
#[test_case(3,
&[
("v1", "Record[a=v3,b=v3]"),
("v2", "Record[b=float,c=v3]"),
("v1", "v2")
],
&[
("v1", "Record[a=float,b=float,c=float]"),
("v2", "Record[a=float,b=float,c=float]"),
("v3", "float")
]
; "record merge"
)]
#[test_case(3,
&[
("v1", "Record[a=float]"),
("v2", "Foo[v3]"),
("v1", "v2")
],
&[
("v1", "Foo[float]"),
("v3", "float")
]
; "record obj merge"
)]
/// Test cases for valid unifications.
fn test_unify(
variable_count: u32,
unify_pairs: &[(&'static str, &'static str)],
verify_pairs: &[(&'static str, &'static str)],
) {
let unify_count = unify_pairs.len();
// test all permutations...
for perm in unify_pairs.iter().permutations(unify_count) {
let mut env = TestEnvironment::new();
let mut mapping = HashMap::new();
for i in 1..=variable_count {
let v = env.unifier.get_fresh_var();
mapping.insert(format!("v{}", i), v.0);
}
// unification may have side effects when we do type resolution, so freeze the types
// before doing unification.
let mut pairs = Vec::new();
for (a, b) in perm.iter() {
let t1 = env.parse(a, &mapping);
let t2 = env.parse(b, &mapping);
pairs.push((t1, t2));
}
for (t1, t2) in pairs {
env.unifier.unify(t1, t2).unwrap();
}
for (a, b) in verify_pairs.iter() {
println!("{} = {}", a, b);
let t1 = env.parse(a, &mapping);
let t2 = env.parse(b, &mapping);
assert!(env.unifier.eq(t1, t2));
}
}
}
#[test_case(2,
&[
("v1", "Tuple[int]"),
("v2", "List[int]"),
],
(("v1", "v2"), "Cannot unify TList with TTuple")
; "type mismatch"
)]
#[test_case(2,
&[
("v1", "Tuple[int]"),
("v2", "Tuple[float]"),
],
(("v1", "v2"), "Cannot unify objects with ID 0 and 1")
; "tuple parameter mismatch"
)]
#[test_case(2,
&[
("v1", "Tuple[int,int]"),
("v2", "Tuple[int]"),
],
(("v1", "v2"), "Cannot unify tuples with length 2 and 1")
; "tuple length mismatch"
)]
#[test_case(3,
&[
("v1", "Record[a=float,b=int]"),
("v2", "Foo[v3]"),
],
(("v1", "v2"), "No such attribute b")
; "record obj merge"
)]
#[test_case(2,
&[
("v1", "List[v2]"),
],
(("v1", "v2"), "Recursive type is prohibited.")
; "recursive type for lists"
)]
/// Test cases for invalid unifications.
fn test_invalid_unification(
variable_count: u32,
unify_pairs: &[(&'static str, &'static str)],
errornous_pair: ((&'static str, &'static str), &'static str),
) {
let mut env = TestEnvironment::new();
let mut mapping = HashMap::new();
for i in 1..=variable_count {
let v = env.unifier.get_fresh_var();
mapping.insert(format!("v{}", i), v.0);
}
// unification may have side effects when we do type resolution, so freeze the types
// before doing unification.
let mut pairs = Vec::new();
for (a, b) in unify_pairs.iter() {
let t1 = env.parse(a, &mapping);
let t2 = env.parse(b, &mapping);
pairs.push((t1, t2));
}
let (t1, t2) =
(env.parse(errornous_pair.0 .0, &mapping), env.parse(errornous_pair.0 .1, &mapping));
for (a, b) in pairs {
env.unifier.unify(a, b).unwrap();
}
assert_eq!(env.unifier.unify(t1, t2), Err(errornous_pair.1.to_string()));
}
#[test]
fn test_virtual() {
let mut env = TestEnvironment::new();
let int = env.parse("int", &HashMap::new());
let fun = env.unifier.add_ty(TypeEnum::TFunc(
FunSignature { args: vec![], ret: int, vars: HashMap::new() }.into(),
));
let bar = env.unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(5),
fields: [("f".to_string(), fun), ("a".to_string(), int)]
.iter()
.cloned()
.collect::<HashMap<_, _>>()
.into(),
params: HashMap::new().into(),
});
let v0 = env.unifier.get_fresh_var().0;
let v1 = env.unifier.get_fresh_var().0;
let a = env.unifier.add_ty(TypeEnum::TVirtual { ty: bar });
let b = env.unifier.add_ty(TypeEnum::TVirtual { ty: v0 });
let c = env.unifier.add_record([("f".to_string(), v1)].iter().cloned().collect());
env.unifier.unify(a, b).unwrap();
env.unifier.unify(b, c).unwrap();
assert!(env.unifier.eq(v1, fun));
let d = env.unifier.add_record([("a".to_string(), v1)].iter().cloned().collect());
assert_eq!(env.unifier.unify(b, d), Err("Cannot access field a for virtual type".to_string()));
let d = env.unifier.add_record([("b".to_string(), v1)].iter().cloned().collect());
assert_eq!(env.unifier.unify(b, d), Err("No such attribute b".to_string()));
}
#[test]
fn test_typevar_range() {
let mut env = TestEnvironment::new();
let int = env.parse("int", &HashMap::new());
let boolean = env.parse("bool", &HashMap::new());
let float = env.parse("float", &HashMap::new());
let int_list = env.parse("List[int]", &HashMap::new());
let float_list = env.parse("List[float]", &HashMap::new());
// unification between v and int
// where v in (int, bool)
let v = env.unifier.get_fresh_var_with_range(&[int, boolean]).0;
env.unifier.unify(int, v).unwrap();
// unification between v and List[int]
// where v in (int, bool)
let v = env.unifier.get_fresh_var_with_range(&[int, boolean]).0;
assert_eq!(
env.unifier.unify(int_list, v),
Err("Cannot unify type variable 3 with TList due to incompatible value range".to_string())
);
// unification between v and float
// where v in (int, bool)
let v = env.unifier.get_fresh_var_with_range(&[int, boolean]).0;
assert_eq!(
env.unifier.unify(float, v),
Err("Cannot unify type variable 4 with TObj due to incompatible value range".to_string())
);
let v1 = env.unifier.get_fresh_var_with_range(&[int, boolean]).0;
let v1_list = env.unifier.add_ty(TypeEnum::TList { ty: v1 });
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list]).0;
// unification between v and int
// where v in (int, List[v1]), v1 in (int, bool)
env.unifier.unify(int, v).unwrap();
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list]).0;
// unification between v and List[int]
// where v in (int, List[v1]), v1 in (int, bool)
env.unifier.unify(int_list, v).unwrap();
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list]).0;
// unification between v and List[float]
// where v in (int, List[v1]), v1 in (int, bool)
assert_eq!(
env.unifier.unify(float_list, v),
Err("Cannot unify type variable 8 with TList due to incompatible value range".to_string())
);
let a = env.unifier.get_fresh_var_with_range(&[int, float]).0;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0;
env.unifier.unify(a, b).unwrap();
env.unifier.unify(a, float).unwrap();
let a = env.unifier.get_fresh_var_with_range(&[int, float]).0;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0;
env.unifier.unify(a, b).unwrap();
assert_eq!(
env.unifier.unify(a, int),
Err("Cannot unify type variable 12 with TObj due to incompatible value range".into())
);
let a = env.unifier.get_fresh_var_with_range(&[int, float]).0;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0;
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a });
let a_list = env.unifier.get_fresh_var_with_range(&[a_list]).0;
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b });
let b_list = env.unifier.get_fresh_var_with_range(&[b_list]).0;
env.unifier.unify(a_list, b_list).unwrap();
let float_list = env.unifier.add_ty(TypeEnum::TList { ty: float });
env.unifier.unify(a_list, float_list).unwrap();
// previous unifications should not affect a and b
env.unifier.unify(a, int).unwrap();
let a = env.unifier.get_fresh_var_with_range(&[int, float]).0;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0;
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a });
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b });
env.unifier.unify(a_list, b_list).unwrap();
let int_list = env.unifier.add_ty(TypeEnum::TList { ty: int });
assert_eq!(
env.unifier.unify(a_list, int_list),
Err("Cannot unify type variable 19 with TObj due to incompatible value range".into())
);
let a = env.unifier.get_fresh_var_with_range(&[int, float]).0;
let b = env.unifier.get_fresh_var().0;
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a });
let a_list = env.unifier.get_fresh_var_with_range(&[a_list]).0;
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b });
env.unifier.unify(a_list, b_list).unwrap();
assert_eq!(
env.unifier.unify(b, boolean),
Err("Cannot unify type variable 21 with TObj due to incompatible value range".into())
);
}
#[test]
fn test_rigid_var() {
let mut env = TestEnvironment::new();
let a = env.unifier.get_fresh_rigid_var().0;
let b = env.unifier.get_fresh_rigid_var().0;
let x = env.unifier.get_fresh_var().0;
let list_a = env.unifier.add_ty(TypeEnum::TList { ty: a });
let list_x = env.unifier.add_ty(TypeEnum::TList { ty: x });
let int = env.parse("int", &HashMap::new());
let list_int = env.parse("List[int]", &HashMap::new());
assert_eq!(env.unifier.unify(a, b), Err("Cannot unify TRigidVar with TRigidVar".to_string()));
env.unifier.unify(list_a, list_x).unwrap();
assert_eq!(
env.unifier.unify(list_x, list_int),
Err("Cannot unify TObj with TRigidVar".to_string())
);
env.unifier.replace_rigid_var(a, int);
env.unifier.unify(list_x, list_int).unwrap();
}
#[test]
fn test_instantiation() {
let mut env = TestEnvironment::new();
let int = env.parse("int", &HashMap::new());
let boolean = env.parse("bool", &HashMap::new());
let float = env.parse("float", &HashMap::new());
let list_int = env.parse("List[int]", &HashMap::new());
let obj_map: HashMap<_, _> =
[(0usize, "int"), (1, "float"), (2, "bool")].iter().cloned().collect();
let v = env.unifier.get_fresh_var_with_range(&[int, boolean]).0;
let list_v = env.unifier.add_ty(TypeEnum::TList { ty: v });
let v1 = env.unifier.get_fresh_var_with_range(&[list_v, int]).0;
let v2 = env.unifier.get_fresh_var_with_range(&[list_int, float]).0;
let t = env.unifier.get_fresh_rigid_var().0;
let tuple = env.unifier.add_ty(TypeEnum::TTuple { ty: vec![v, v1, v2] });
let v3 = env.unifier.get_fresh_var_with_range(&[tuple, t]).0;
// t = TypeVar('t')
// v = TypeVar('v', int, bool)
// v1 = TypeVar('v1', 'list[v]', int)
// v2 = TypeVar('v2', 'list[int]', float)
// v3 = TypeVar('v3', tuple[v, v1, v2], t)
// what values can v3 take?
let types = env.unifier.get_instantiations(v3).unwrap();
let expected_types = indoc! {"
tuple[bool, int, float]
tuple[bool, int, list[int]]
tuple[bool, list[bool], float]
tuple[bool, list[bool], list[int]]
tuple[bool, list[int], float]
tuple[bool, list[int], list[int]]
tuple[int, int, float]
tuple[int, int, list[int]]
tuple[int, list[bool], float]
tuple[int, list[bool], list[int]]
tuple[int, list[int], float]
tuple[int, list[int], list[int]]
v5"
}
.split('\n')
.collect_vec();
let types = types
.iter()
.map(|ty| {
env.unifier.stringify(*ty, &mut |i| obj_map.get(&i).unwrap().to_string(), &mut |i| {
format!("v{}", i)
})
})
.sorted()
.collect_vec();
assert_eq!(expected_types, types);
}

View File

@ -0,0 +1,87 @@
use std::rc::Rc;
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub struct UnificationKey(usize);
pub struct UnificationTable<V> {
parents: Vec<usize>,
ranks: Vec<u32>,
values: Vec<V>,
}
impl<V> UnificationTable<V> {
pub fn new() -> UnificationTable<V> {
UnificationTable { parents: Vec::new(), ranks: Vec::new(), values: Vec::new() }
}
pub fn new_key(&mut self, v: V) -> UnificationKey {
let index = self.parents.len();
self.parents.push(index);
self.ranks.push(0);
self.values.push(v);
UnificationKey(index)
}
pub fn unify(&mut self, a: UnificationKey, b: UnificationKey) {
let mut a = self.find(a);
let mut b = self.find(b);
if a == b {
return;
}
if self.ranks[a] < self.ranks[b] {
std::mem::swap(&mut a, &mut b);
}
self.parents[b] = a;
if self.ranks[a] == self.ranks[b] {
self.ranks[a] += 1;
}
}
pub fn probe_value(&mut self, a: UnificationKey) -> &V {
let index = self.find(a);
&self.values[index]
}
pub fn set_value(&mut self, a: UnificationKey, v: V) {
let index = self.find(a);
self.values[index] = v;
}
pub fn unioned(&mut self, a: UnificationKey, b: UnificationKey) -> bool {
self.find(a) == self.find(b)
}
pub fn get_representative(&mut self, key: UnificationKey) -> UnificationKey {
UnificationKey(self.find(key))
}
fn find(&mut self, key: UnificationKey) -> usize {
let mut root = key.0;
let mut parent = self.parents[root];
while root != parent {
// a = parent.parent
let a = self.parents[parent];
// root.parent = parent.parent
self.parents[root] = a;
root = parent;
// parent = root.parent
parent = a;
}
parent
}
}
impl<V> UnificationTable<Rc<V>>
where
V: Clone,
{
pub fn get_send(&self) -> UnificationTable<V> {
let values = self.values.iter().map(|v| v.as_ref().clone()).collect();
UnificationTable { parents: self.parents.clone(), ranks: self.ranks.clone(), values }
}
pub fn from_send(table: &UnificationTable<V>) -> UnificationTable<Rc<V>> {
let values = table.values.iter().cloned().map(Rc::new).collect();
UnificationTable { parents: table.parents.clone(), ranks: table.ranks.clone(), values }
}
}
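A minimal usage sketch for the union-find table above, written as a test module one might append to this file; the stored integer values are arbitrary.
#[cfg(test)]
mod sketch {
    use super::*;

    #[test]
    fn union_find_basics() {
        // three keys; unify two of them, then observe that set_value and
        // probe_value operate on the shared representative
        let mut table: UnificationTable<i32> = UnificationTable::new();
        let a = table.new_key(1);
        let b = table.new_key(2);
        let c = table.new_key(3);
        table.unify(a, b);
        table.set_value(a, 5);
        assert!(table.unioned(a, b));
        assert!(!table.unioned(a, c));
        assert_eq!(*table.probe_value(b), 5);
    }
}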

View File

@ -1,60 +0,0 @@
use std::collections::HashMap;
use std::rc::Rc;
#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
pub struct PrimitiveId(pub(crate) usize);
#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
pub struct ClassId(pub(crate) usize);
#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
pub struct ParamId(pub(crate) usize);
#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
pub struct VariableId(pub(crate) usize);
#[derive(PartialEq, Eq, Clone, Hash, Debug)]
pub enum TypeEnum {
BotType,
SelfType,
PrimitiveType(PrimitiveId),
ClassType(ClassId),
VirtualClassType(ClassId),
ParametricType(ParamId, Vec<Rc<TypeEnum>>),
TypeVariable(VariableId),
}
pub type Type = Rc<TypeEnum>;
#[derive(Clone)]
pub struct FnDef {
// we assume methods first argument to be SelfType,
// so the first argument is not contained here
pub args: Vec<Type>,
pub result: Option<Type>,
}
#[derive(Clone)]
pub struct TypeDef<'a> {
pub name: &'a str,
pub fields: HashMap<&'a str, Type>,
pub methods: HashMap<&'a str, FnDef>,
}
#[derive(Clone)]
pub struct ClassDef<'a> {
pub base: TypeDef<'a>,
pub parents: Vec<ClassId>,
}
#[derive(Clone)]
pub struct ParametricDef<'a> {
pub base: TypeDef<'a>,
pub params: Vec<VariableId>,
}
#[derive(Clone)]
pub struct VarDef<'a> {
pub name: &'a str,
pub bound: Vec<Type>,
}

View File

@ -4,6 +4,6 @@ in
pkgs.stdenv.mkDerivation {
name = "nac3-env";
buildInputs = with pkgs; [
llvm_10 clang_10 cargo rustc libffi libxml2 clippy
llvm_11 clang_11 cargo rustc libffi libxml2 clippy
];
}