hm-inference #6

Merged
sb10q merged 136 commits from hm-inference into master 2021-08-19 11:46:50 +08:00
29 changed files with 5928 additions and 2948 deletions

668
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -15,20 +15,20 @@ caller to specify which methods should be compiled). After type checking, the
compiler would analyse the set of functions/classes that are used and perform compiler would analyse the set of functions/classes that are used and perform
code generation. code generation.
Symbol resolver:
- Str -> Nac3Type
- Str -> Value
value could be integer values, boolean values, bytes (for memcpy), function ID value could be integer values, boolean values, bytes (for memcpy), function ID
(full name + concrete type) (full name + concrete type)
## Current Plan ## Current Plan
1. Write out the syntax-directed type checking/inferencing rules. Fix the rule Type checking:
for type variable instantiation.
2. Update the library dependencies and rewrite some of the type checking code. - [x] Basic interface for symbol resolver.
3. Design the symbol resolver API. - [x] Track location information in context object (for diagnostics).
4. Move tests from code to external files to cleanup the code. - [ ] Refactor old expression and statement type inference code. (anto)
- [ ] Error diagnostics utilities. (pca)
- [ ] Move tests to external files, write scripts for testing. (pca)
- [ ] Implement function type checking (instantiate bounded type parameters),
loop unrolling, type inference for lists with virtual objects. (pca)

View File

@ -7,5 +7,13 @@ edition = "2018"
[dependencies] [dependencies]
num-bigint = "0.3" num-bigint = "0.3"
num-traits = "0.2" num-traits = "0.2"
inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm10-0"] } inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm11-0"] }
rustpython-parser = { git = "https://github.com/RustPython/RustPython", branch = "master" } rustpython-parser = { git = "https://github.com/RustPython/RustPython", branch = "master" }
itertools = "0.10.1"
crossbeam = "0.8.1"
parking_lot = "0.11.1"
rayon = "1.5.1"
[dev-dependencies]
test-case = "1.2.0"
indoc = "1.0"

1
nac3core/rustfmt.toml Normal file
View File

@ -0,0 +1 @@
use_small_heuristics = "Max"

View File

@ -0,0 +1,527 @@
use std::{collections::HashMap, convert::TryInto, iter::once};
use super::{get_llvm_type, CodeGenContext};
use crate::{
symbol_resolver::SymbolValue,
top_level::{DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type, TypeEnum},
};
use inkwell::{
types::{BasicType, BasicTypeEnum},
values::BasicValueEnum,
AddressSpace,
};
use itertools::{chain, izip, zip, Itertools};
use rustpython_parser::ast::{self, Boolop, Constant, Expr, ExprKind, Operator};
impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
fn get_subst_key(&mut self, obj: Option<Type>, fun: &FunSignature) -> String {
let mut vars = obj
.map(|ty| {
if let TypeEnum::TObj { params, .. } = &*self.unifier.get_ty(ty) {
params.borrow().clone()
} else {
unreachable!()
}
})
.unwrap_or_default();
vars.extend(fun.vars.iter());
let sorted = vars.keys().sorted();
sorted
.map(|id| {
self.unifier.stringify(vars[id], &mut |id| id.to_string(), &mut |id| id.to_string())
})
.join(", ")
}
pub fn get_attr_index(&mut self, ty: Type, attr: &str) -> usize {
let obj_id = match &*self.unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } => *obj_id,
// we cannot have other types, virtual type should be handled by function calls
_ => unreachable!(),
};
let def = &self.top_level.definitions.read()[obj_id.0];
let index = if let TopLevelDef::Class { fields, .. } = &*def.read() {
fields.iter().find_position(|x| x.0 == attr).unwrap().0
} else {
unreachable!()
};
index
}
fn gen_symbol_val(&mut self, val: &SymbolValue) -> BasicValueEnum<'ctx> {
match val {
SymbolValue::I32(v) => self.ctx.i32_type().const_int(*v as u64, true).into(),
SymbolValue::I64(v) => self.ctx.i64_type().const_int(*v as u64, true).into(),
SymbolValue::Bool(v) => self.ctx.bool_type().const_int(*v as u64, true).into(),
SymbolValue::Double(v) => self.ctx.f64_type().const_float(*v).into(),
SymbolValue::Tuple(ls) => {
let vals = ls.iter().map(|v| self.gen_symbol_val(v)).collect_vec();
let fields = vals.iter().map(|v| v.get_type()).collect_vec();
let ty = self.ctx.struct_type(&fields, false);
let ptr = self.builder.build_alloca(ty, "tuple");
let zero = self.ctx.i32_type().const_zero();
unsafe {
for (i, val) in vals.into_iter().enumerate() {
let p = ptr.const_in_bounds_gep(&[
zero,
self.ctx.i32_type().const_int(i as u64, false),
]);
self.builder.build_store(p, val);
}
}
ptr.into()
}
}
}
pub fn get_llvm_type(&mut self, ty: Type) -> BasicTypeEnum<'ctx> {
get_llvm_type(self.ctx, &mut self.unifier, self.top_level, &mut self.type_cache, ty)
}
fn gen_call(
&mut self,
obj: Option<(Type, BasicValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
params: Vec<(Option<String>, BasicValueEnum<'ctx>)>,
ret: Type,
) -> Option<BasicValueEnum<'ctx>> {
let key = self.get_subst_key(obj.map(|(a, _)| a), fun.0);
let defs = self.top_level.definitions.read();
let definition = defs.get(fun.1 .0).unwrap();
let val = if let TopLevelDef::Function { instance_to_symbol, .. } = &*definition.read() {
let symbol = instance_to_symbol.get(&key).unwrap_or_else(|| {
// TODO: codegen for function that are not yet generated
unimplemented!()
});
let fun_val = self.module.get_function(symbol).unwrap_or_else(|| {
let params = fun.0.args.iter().map(|arg| self.get_llvm_type(arg.ty)).collect_vec();
let fun_ty = if self.unifier.unioned(ret, self.primitives.none) {
self.ctx.void_type().fn_type(&params, false)
} else {
self.get_llvm_type(ret).fn_type(&params, false)
};
self.module.add_function(symbol, fun_ty, None)
});
let mut keys = fun.0.args.clone();
let mut mapping = HashMap::new();
for (key, value) in params.into_iter() {
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value);
}
// default value handling
for k in keys.into_iter() {
mapping.insert(k.name, self.gen_symbol_val(&k.default_value.unwrap()));
}
// reorder the parameters
let params =
fun.0.args.iter().map(|arg| mapping.remove(&arg.name).unwrap()).collect_vec();
self.builder.build_call(fun_val, &params, "call").try_as_basic_value().left()
} else {
unreachable!()
};
val
}
fn gen_const(&mut self, value: &Constant, ty: Type) -> BasicValueEnum<'ctx> {
match value {
Constant::Bool(v) => {
assert!(self.unifier.unioned(ty, self.primitives.bool));
let ty = self.ctx.bool_type();
ty.const_int(if *v { 1 } else { 0 }, false).into()
}
Constant::Int(v) => {
let ty = if self.unifier.unioned(ty, self.primitives.int32) {
self.ctx.i32_type()
} else if self.unifier.unioned(ty, self.primitives.int64) {
self.ctx.i64_type()
} else {
unreachable!();
};
ty.const_int(v.try_into().unwrap(), false).into()
}
Constant::Float(v) => {
assert!(self.unifier.unioned(ty, self.primitives.float));
let ty = self.ctx.f64_type();
ty.const_float(*v).into()
}
Constant::Tuple(v) => {
let ty = self.unifier.get_ty(ty);
let types =
if let TypeEnum::TTuple { ty } = &*ty { ty.clone() } else { unreachable!() };
let values = zip(types.into_iter(), v.iter())
.map(|(ty, v)| self.gen_const(v, ty))
.collect_vec();
let types = values.iter().map(BasicValueEnum::get_type).collect_vec();
let ty = self.ctx.struct_type(&types, false);
ty.const_named_struct(&values).into()
}
_ => unreachable!(),
}
}
fn gen_int_ops(
&mut self,
op: &Operator,
lhs: BasicValueEnum<'ctx>,
rhs: BasicValueEnum<'ctx>,
) -> BasicValueEnum<'ctx> {
let (lhs, rhs) =
if let (BasicValueEnum::IntValue(lhs), BasicValueEnum::IntValue(rhs)) = (lhs, rhs) {
(lhs, rhs)
} else {
unreachable!()
};
match op {
Operator::Add => self.builder.build_int_add(lhs, rhs, "add").into(),
Operator::Sub => self.builder.build_int_sub(lhs, rhs, "sub").into(),
Operator::Mult => self.builder.build_int_mul(lhs, rhs, "mul").into(),
Operator::Div => {
let float = self.ctx.f64_type();
let left = self.builder.build_signed_int_to_float(lhs, float, "i2f");
let right = self.builder.build_signed_int_to_float(rhs, float, "i2f");
self.builder.build_float_div(left, right, "fdiv").into()
}
Operator::Mod => self.builder.build_int_signed_rem(lhs, rhs, "mod").into(),
Operator::BitOr => self.builder.build_or(lhs, rhs, "or").into(),
Operator::BitXor => self.builder.build_xor(lhs, rhs, "xor").into(),
Operator::BitAnd => self.builder.build_and(lhs, rhs, "and").into(),
Operator::LShift => self.builder.build_left_shift(lhs, rhs, "lshift").into(),
Operator::RShift => self.builder.build_right_shift(lhs, rhs, true, "rshift").into(),
Operator::FloorDiv => self.builder.build_int_signed_div(lhs, rhs, "floordiv").into(),
// special implementation?
Operator::Pow => unimplemented!(),
Operator::MatMult => unreachable!(),
}
}
fn gen_float_ops(
&mut self,
op: &Operator,
lhs: BasicValueEnum<'ctx>,
rhs: BasicValueEnum<'ctx>,
) -> BasicValueEnum<'ctx> {
let (lhs, rhs) = if let (BasicValueEnum::FloatValue(lhs), BasicValueEnum::FloatValue(rhs)) =
(lhs, rhs)
{
(lhs, rhs)
} else {
unreachable!()
};
match op {
Operator::Add => self.builder.build_float_add(lhs, rhs, "fadd").into(),
Operator::Sub => self.builder.build_float_sub(lhs, rhs, "fsub").into(),
Operator::Mult => self.builder.build_float_mul(lhs, rhs, "fmul").into(),
Operator::Div => self.builder.build_float_div(lhs, rhs, "fdiv").into(),
Operator::Mod => self.builder.build_float_rem(lhs, rhs, "fmod").into(),
Operator::FloorDiv => {
let div = self.builder.build_float_div(lhs, rhs, "fdiv");
let floor_intrinsic =
self.module.get_function("llvm.floor.f64").unwrap_or_else(|| {
let float = self.ctx.f64_type();
let fn_type = float.fn_type(&[float.into()], false);
self.module.add_function("llvm.floor.f64", fn_type, None)
});
self.builder
.build_call(floor_intrinsic, &[div.into()], "floor")
.try_as_basic_value()
.left()
.unwrap()
}
// special implementation?
_ => unimplemented!(),
}
}
pub fn gen_expr(&mut self, expr: &Expr<Option<Type>>) -> BasicValueEnum<'ctx> {
let zero = self.ctx.i32_type().const_int(0, false);
match &expr.node {
ExprKind::Constant { value, .. } => {
let ty = expr.custom.unwrap();
self.gen_const(value, ty)
}
ExprKind::Name { id, .. } => {
let ptr = self.var_assignment.get(id).unwrap();
let primitives = &self.primitives;
// we should only dereference primitive types
if [primitives.int32, primitives.int64, primitives.float, primitives.bool]
.contains(&self.unifier.get_representative(expr.custom.unwrap()))
{
self.builder.build_load(*ptr, "load")
} else {
(*ptr).into()
}
}
ExprKind::List { elts, .. } => {
// this shall be optimized later for constant primitive lists...
// we should use memcpy for that instead of generating thousands of stores
let elements = elts.iter().map(|x| self.gen_expr(x)).collect_vec();
let ty = if elements.is_empty() {
self.ctx.i32_type().into()
} else {
elements[0].get_type()
};
let arr_ptr = self.builder.build_array_alloca(
ty,
self.ctx.i32_type().const_int(elements.len() as u64, false),
"tmparr",
);
let arr_ty = self.ctx.struct_type(
&[self.ctx.i32_type().into(), ty.ptr_type(AddressSpace::Generic).into()],
false,
);
let arr_str_ptr = self.builder.build_alloca(arr_ty, "tmparrstr");
unsafe {
self.builder.build_store(
arr_str_ptr.const_in_bounds_gep(&[zero, zero]),
self.ctx.i32_type().const_int(elements.len() as u64, false),
);
self.builder.build_store(
arr_str_ptr
.const_in_bounds_gep(&[zero, self.ctx.i32_type().const_int(1, false)]),
arr_ptr,
);
let arr_offset = self.ctx.i32_type().const_int(1, false);
for (i, v) in elements.iter().enumerate() {
let ptr = self.builder.build_in_bounds_gep(
arr_ptr,
&[zero, arr_offset, self.ctx.i32_type().const_int(i as u64, false)],
"arr_element",
);
self.builder.build_store(ptr, *v);
}
}
arr_str_ptr.into()
}
ExprKind::Tuple { elts, .. } => {
let element_val = elts.iter().map(|x| self.gen_expr(x)).collect_vec();
let element_ty = element_val.iter().map(BasicValueEnum::get_type).collect_vec();
let tuple_ty = self.ctx.struct_type(&element_ty, false);
let tuple_ptr = self.builder.build_alloca(tuple_ty, "tuple");
for (i, v) in element_val.into_iter().enumerate() {
unsafe {
let ptr = tuple_ptr.const_in_bounds_gep(&[
zero,
self.ctx.i32_type().const_int(i as u64, false),
]);
self.builder.build_store(ptr, v);
}
}
tuple_ptr.into()
}
ExprKind::Attribute { value, attr, .. } => {
// note that we would handle class methods directly in calls
let index = self.get_attr_index(value.custom.unwrap(), attr);
let val = self.gen_expr(value);
let ptr = if let BasicValueEnum::PointerValue(v) = val {
v
} else {
unreachable!();
};
unsafe {
let ptr = ptr.const_in_bounds_gep(&[
zero,
self.ctx.i32_type().const_int(index as u64, false),
]);
self.builder.build_load(ptr, "field")
}
}
ExprKind::BoolOp { op, values } => {
// requires conditional branches for short-circuiting...
let left = if let BasicValueEnum::IntValue(left) = self.gen_expr(&values[0]) {
left
} else {
unreachable!()
};
let current = self.builder.get_insert_block().unwrap().get_parent().unwrap();
let a_bb = self.ctx.append_basic_block(current, "a");
let b_bb = self.ctx.append_basic_block(current, "b");
let cont_bb = self.ctx.append_basic_block(current, "cont");
self.builder.build_conditional_branch(left, a_bb, b_bb);
let (a, b) = match op {
Boolop::Or => {
self.builder.position_at_end(a_bb);
let a = self.ctx.bool_type().const_int(1, false);
self.builder.build_unconditional_branch(cont_bb);
self.builder.position_at_end(b_bb);
let b = if let BasicValueEnum::IntValue(b) = self.gen_expr(&values[1]) {
b
} else {
unreachable!()
};
self.builder.build_unconditional_branch(cont_bb);
(a, b)
}
Boolop::And => {
self.builder.position_at_end(a_bb);
let a = if let BasicValueEnum::IntValue(a) = self.gen_expr(&values[1]) {
a
} else {
unreachable!()
};
self.builder.build_unconditional_branch(cont_bb);
self.builder.position_at_end(b_bb);
let b = self.ctx.bool_type().const_int(0, false);
self.builder.build_unconditional_branch(cont_bb);
(a, b)
}
};
self.builder.position_at_end(cont_bb);
let phi = self.builder.build_phi(self.ctx.bool_type(), "phi");
phi.add_incoming(&[(&a, a_bb), (&b, b_bb)]);
phi.as_basic_value()
}
ExprKind::BinOp { op, left, right } => {
let ty1 = self.unifier.get_representative(left.custom.unwrap());
let ty2 = self.unifier.get_representative(right.custom.unwrap());
let left = self.gen_expr(left);
let right = self.gen_expr(right);
// we can directly compare the types, because we've got their representatives
// which would be unchanged until further unification, which we would never do
// when doing code generation for function instances
if ty1 == ty2 && [self.primitives.int32, self.primitives.int64].contains(&ty1) {
self.gen_int_ops(op, left, right)
} else if ty1 == ty2 && self.primitives.float == ty1 {
self.gen_float_ops(op, left, right)
} else {
unimplemented!()
}
}
ExprKind::UnaryOp { op, operand } => {
let ty = self.unifier.get_representative(operand.custom.unwrap());
let val = self.gen_expr(operand);
if ty == self.primitives.bool {
let val =
if let BasicValueEnum::IntValue(val) = val { val } else { unreachable!() };
match op {
ast::Unaryop::Invert | ast::Unaryop::Not => {
self.builder.build_not(val, "not").into()
}
_ => val.into(),
}
} else if [self.primitives.int32, self.primitives.int64].contains(&ty) {
let val =
if let BasicValueEnum::IntValue(val) = val { val } else { unreachable!() };
match op {
ast::Unaryop::USub => self.builder.build_int_neg(val, "neg").into(),
ast::Unaryop::Invert => self.builder.build_not(val, "not").into(),
ast::Unaryop::Not => self
.builder
.build_int_compare(
inkwell::IntPredicate::EQ,
val,
val.get_type().const_zero(),
"not",
)
.into(),
_ => val.into(),
}
} else if ty == self.primitives.float {
let val = if let BasicValueEnum::FloatValue(val) = val {
val
} else {
unreachable!()
};
match op {
ast::Unaryop::USub => self.builder.build_float_neg(val, "neg").into(),
ast::Unaryop::Not => self
.builder
.build_float_compare(
inkwell::FloatPredicate::OEQ,
val,
val.get_type().const_zero(),
"not",
)
.into(),
_ => val.into(),
}
} else {
unimplemented!()
}
}
ExprKind::Compare { left, ops, comparators } => {
izip!(
chain(once(left.as_ref()), comparators.iter()),
comparators.iter(),
ops.iter(),
)
.fold(None, |prev, (lhs, rhs, op)| {
let ty = self.unifier.get_representative(lhs.custom.unwrap());
let current =
if [self.primitives.int32, self.primitives.int64, self.primitives.bool]
.contains(&ty)
{
let (lhs, rhs) = if let (
BasicValueEnum::IntValue(lhs),
BasicValueEnum::IntValue(rhs),
) = (self.gen_expr(lhs), self.gen_expr(rhs))
{
(lhs, rhs)
} else {
unreachable!()
};
let op = match op {
ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::IntPredicate::EQ,
ast::Cmpop::NotEq => inkwell::IntPredicate::NE,
ast::Cmpop::Lt => inkwell::IntPredicate::SLT,
ast::Cmpop::LtE => inkwell::IntPredicate::SLE,
ast::Cmpop::Gt => inkwell::IntPredicate::SGT,
ast::Cmpop::GtE => inkwell::IntPredicate::SGE,
_ => unreachable!(),
};
self.builder.build_int_compare(op, lhs, rhs, "cmp")
} else if ty == self.primitives.float {
let (lhs, rhs) = if let (
BasicValueEnum::FloatValue(lhs),
BasicValueEnum::FloatValue(rhs),
) = (self.gen_expr(lhs), self.gen_expr(rhs))
{
(lhs, rhs)
} else {
unreachable!()
};
let op = match op {
ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::FloatPredicate::OEQ,
ast::Cmpop::NotEq => inkwell::FloatPredicate::ONE,
ast::Cmpop::Lt => inkwell::FloatPredicate::OLT,
ast::Cmpop::LtE => inkwell::FloatPredicate::OLE,
ast::Cmpop::Gt => inkwell::FloatPredicate::OGT,
ast::Cmpop::GtE => inkwell::FloatPredicate::OGE,
_ => unreachable!(),
};
self.builder.build_float_compare(op, lhs, rhs, "cmp")
} else {
unimplemented!()
};
prev.map(|v| self.builder.build_and(v, current, "cmp")).or(Some(current))
})
.unwrap()
.into() // as there should be at least 1 element, it should never be none
}
ExprKind::IfExp { test, body, orelse } => {
let test = if let BasicValueEnum::IntValue(test) = self.gen_expr(test) {
test
} else {
unreachable!()
};
let current = self.builder.get_insert_block().unwrap().get_parent().unwrap();
let then_bb = self.ctx.append_basic_block(current, "then");
let else_bb = self.ctx.append_basic_block(current, "else");
let cont_bb = self.ctx.append_basic_block(current, "cont");
self.builder.build_conditional_branch(test, then_bb, else_bb);
self.builder.position_at_end(then_bb);
let a = self.gen_expr(body);
self.builder.build_unconditional_branch(cont_bb);
self.builder.position_at_end(else_bb);
let b = self.gen_expr(orelse);
self.builder.build_unconditional_branch(cont_bb);
self.builder.position_at_end(cont_bb);
let phi = self.builder.build_phi(a.get_type(), "ifexpr");
phi.add_incoming(&[(&a, then_bb), (&b, else_bb)]);
phi.as_basic_value()
}
_ => unimplemented!(),
}
}
}

343
nac3core/src/codegen/mod.rs Normal file
View File

@ -0,0 +1,343 @@
use crate::{
symbol_resolver::SymbolResolver,
top_level::{TopLevelContext, TopLevelDef},
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{FunSignature, Type, TypeEnum, Unifier},
},
};
use crossbeam::channel::{unbounded, Receiver, Sender};
use inkwell::{
basic_block::BasicBlock,
builder::Builder,
context::Context,
module::Module,
types::{BasicType, BasicTypeEnum},
values::PointerValue,
AddressSpace,
};
use itertools::Itertools;
use parking_lot::{Condvar, Mutex};
use rustpython_parser::ast::Stmt;
use std::collections::HashMap;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use std::thread;
mod expr;
mod stmt;
#[cfg(test)]
mod test;
pub struct CodeGenContext<'ctx, 'a> {
pub ctx: &'ctx Context,
pub builder: Builder<'ctx>,
pub module: Module<'ctx>,
pub top_level: &'a TopLevelContext,
pub unifier: Unifier,
pub resolver: Arc<dyn SymbolResolver>,
pub var_assignment: HashMap<String, PointerValue<'ctx>>,
pub type_cache: HashMap<Type, BasicTypeEnum<'ctx>>,
pub primitives: PrimitiveStore,
// stores the alloca for variables
pub init_bb: BasicBlock<'ctx>,
// where continue and break should go to respectively
// the first one is the test_bb, and the second one is bb after the loop
pub loop_bb: Option<(BasicBlock<'ctx>, BasicBlock<'ctx>)>,
}
type Fp = Box<dyn Fn(&Module) + Send + Sync>;
pub struct WithCall {
fp: Fp,
}
impl WithCall {
pub fn new(fp: Fp) -> WithCall {
WithCall { fp }
}
pub fn run<'ctx>(&self, m: &Module<'ctx>) {
(self.fp)(m)
}
}
pub struct WorkerRegistry {
sender: Arc<Sender<Option<CodeGenTask>>>,
receiver: Arc<Receiver<Option<CodeGenTask>>>,
panicked: AtomicBool,
task_count: Mutex<usize>,
thread_count: usize,
wait_condvar: Condvar,
}
impl WorkerRegistry {
pub fn create_workers(
names: &[&str],
top_level_ctx: Arc<TopLevelContext>,
f: Arc<WithCall>,
) -> (Arc<WorkerRegistry>, Vec<thread::JoinHandle<()>>) {
let (sender, receiver) = unbounded();
let task_count = Mutex::new(0);
let wait_condvar = Condvar::new();
let registry = Arc::new(WorkerRegistry {
sender: Arc::new(sender),
receiver: Arc::new(receiver),
thread_count: names.len(),
panicked: AtomicBool::new(false),
task_count,
wait_condvar,
});
let mut handles = Vec::new();
for name in names.iter() {
let top_level_ctx = top_level_ctx.clone();
let registry = registry.clone();
let registry2 = registry.clone();
let name = name.to_string();
let f = f.clone();
let handle = thread::spawn(move || {
registry.worker_thread(name, top_level_ctx, f);
});
let handle = thread::spawn(move || {
if let Err(e) = handle.join() {
if let Some(e) = e.downcast_ref::<&'static str>() {
eprintln!("Got an error: {}", e);
} else {
eprintln!("Got an unknown error: {:?}", e);
}
registry2.panicked.store(true, Ordering::SeqCst);
registry2.wait_condvar.notify_all();
}
});
handles.push(handle);
}
(registry, handles)
}
pub fn wait_tasks_complete(&self, handles: Vec<thread::JoinHandle<()>>) {
{
let mut count = self.task_count.lock();
while *count != 0 {
if self.panicked.load(Ordering::SeqCst) {
break;
}
self.wait_condvar.wait(&mut count);
}
}
for _ in 0..self.thread_count {
self.sender.send(None).unwrap();
}
{
let mut count = self.task_count.lock();
while *count != self.thread_count {
if self.panicked.load(Ordering::SeqCst) {
break;
}
self.wait_condvar.wait(&mut count);
}
}
for handle in handles {
handle.join().unwrap();
}
if self.panicked.load(Ordering::SeqCst) {
panic!("tasks panicked");
}
}
pub fn add_task(&self, task: CodeGenTask) {
*self.task_count.lock() += 1;
self.sender.send(Some(task)).unwrap();
}
fn worker_thread(
&self,
module_name: String,
top_level_ctx: Arc<TopLevelContext>,
f: Arc<WithCall>,
) {
let context = Context::create();
let mut builder = context.create_builder();
let mut module = context.create_module(&module_name);
while let Some(task) = self.receiver.recv().unwrap() {
let result = gen_func(&context, builder, module, task, top_level_ctx.clone());
builder = result.0;
module = result.1;
*self.task_count.lock() -= 1;
self.wait_condvar.notify_all();
}
// do whatever...
let mut lock = self.task_count.lock();
module.verify().unwrap();
f.run(&module);
*lock += 1;
self.wait_condvar.notify_all();
}
}
pub struct CodeGenTask {
pub subst: Vec<(Type, Type)>,
pub symbol_name: String,
pub signature: FunSignature,
pub body: Vec<Stmt<Option<Type>>>,
pub unifier_index: usize,
pub resolver: Arc<dyn SymbolResolver + Send + Sync>,
}
fn get_llvm_type<'ctx>(
ctx: &'ctx Context,
unifier: &mut Unifier,
top_level: &TopLevelContext,
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
ty: Type,
) -> BasicTypeEnum<'ctx> {
use TypeEnum::*;
// we assume the type cache should already contain primitive types,
// and they should be passed by value instead of passing as pointer.
type_cache.get(&unifier.get_representative(ty)).cloned().unwrap_or_else(|| {
match &*unifier.get_ty(ty) {
TObj { obj_id, fields, .. } => {
// a struct with fields in the order of declaration
let defs = top_level.definitions.read();
let definition = defs.get(obj_id.0).unwrap();
let ty = if let TopLevelDef::Class { fields: fields_list, .. } = &*definition.read()
{
let fields = fields.borrow();
let fields = fields_list
.iter()
.map(|f| get_llvm_type(ctx, unifier, top_level, type_cache, fields[&f.0]))
.collect_vec();
ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into()
} else {
unreachable!()
};
ty
}
TTuple { ty } => {
// a struct with fields in the order present in the tuple
let fields = ty
.iter()
.map(|ty| get_llvm_type(ctx, unifier, top_level, type_cache, *ty))
.collect_vec();
ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into()
}
TList { ty } => {
// a struct with an integer and a pointer to an array
let element_type = get_llvm_type(ctx, unifier, top_level, type_cache, *ty);
let fields =
[ctx.i32_type().into(), element_type.ptr_type(AddressSpace::Generic).into()];
ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into()
}
TVirtual { .. } => unimplemented!(),
_ => unreachable!(),
}
})
}
pub fn gen_func<'ctx>(
context: &'ctx Context,
builder: Builder<'ctx>,
module: Module<'ctx>,
task: CodeGenTask,
top_level_ctx: Arc<TopLevelContext>,
) -> (Builder<'ctx>, Module<'ctx>) {
// unwrap_or(0) is for unit tests without using rayon
let (mut unifier, primitives) = {
let unifiers = top_level_ctx.unifiers.read();
let (unifier, primitives) = &unifiers[task.unifier_index];
(Unifier::from_shared_unifier(unifier), *primitives)
};
for (a, b) in task.subst.iter() {
// this should be unification between variables and concrete types
// and should not cause any problem...
unifier.unify(*a, *b).unwrap();
}
// rebuild primitive store with unique representatives
let primitives = PrimitiveStore {
int32: unifier.get_representative(primitives.int32),
int64: unifier.get_representative(primitives.int64),
float: unifier.get_representative(primitives.float),
bool: unifier.get_representative(primitives.bool),
none: unifier.get_representative(primitives.none),
};
let mut type_cache: HashMap<_, _> = [
(unifier.get_representative(primitives.int32), context.i32_type().into()),
(unifier.get_representative(primitives.int64), context.i64_type().into()),
(unifier.get_representative(primitives.float), context.f64_type().into()),
(unifier.get_representative(primitives.bool), context.bool_type().into()),
]
.iter()
.cloned()
.collect();
let params = task
.signature
.args
.iter()
.map(|arg| {
get_llvm_type(&context, &mut unifier, top_level_ctx.as_ref(), &mut type_cache, arg.ty)
})
.collect_vec();
let fn_type = if unifier.unioned(task.signature.ret, primitives.none) {
context.void_type().fn_type(&params, false)
} else {
get_llvm_type(
&context,
&mut unifier,
top_level_ctx.as_ref(),
&mut type_cache,
task.signature.ret,
)
.fn_type(&params, false)
};
let fn_val = module.add_function(&task.symbol_name, fn_type, None);
let init_bb = context.append_basic_block(fn_val, "init");
builder.position_at_end(init_bb);
let body_bb = context.append_basic_block(fn_val, "body");
let mut var_assignment = HashMap::new();
for (n, arg) in task.signature.args.iter().enumerate() {
let param = fn_val.get_nth_param(n as u32).unwrap();
let alloca = builder.build_alloca(
get_llvm_type(&context, &mut unifier, top_level_ctx.as_ref(), &mut type_cache, arg.ty),
&arg.name,
);
builder.build_store(alloca, param);
var_assignment.insert(arg.name.clone(), alloca);
}
builder.build_unconditional_branch(body_bb);
builder.position_at_end(body_bb);
let mut code_gen_context = CodeGenContext {
ctx: &context,
resolver: task.resolver,
top_level: top_level_ctx.as_ref(),
loop_bb: None,
var_assignment,
type_cache,
primitives,
init_bb,
builder,
module,
unifier,
};
for stmt in task.body.iter() {
code_gen_context.gen_stmt(stmt);
}
let CodeGenContext { builder, module, .. } = code_gen_context;
(builder, module)
}

View File

@ -0,0 +1,138 @@
use super::CodeGenContext;
use crate::typecheck::typedef::Type;
use inkwell::values::{BasicValue, BasicValueEnum, PointerValue};
use rustpython_parser::ast::{Expr, ExprKind, Stmt, StmtKind};
impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
fn gen_var(&mut self, ty: Type) -> PointerValue<'ctx> {
// put the alloca in init block
let current = self.builder.get_insert_block().unwrap();
// position before the last branching instruction...
self.builder.position_before(&self.init_bb.get_last_instruction().unwrap());
let ty = self.get_llvm_type(ty);
let ptr = self.builder.build_alloca(ty, "tmp");
self.builder.position_at_end(current);
ptr
}
fn parse_pattern(&mut self, pattern: &Expr<Option<Type>>) -> PointerValue<'ctx> {
// very similar to gen_expr, but we don't do an extra load at the end
// and we flatten nested tuples
match &pattern.node {
ExprKind::Name { id, .. } => {
self.var_assignment.get(id).cloned().unwrap_or_else(|| {
let ptr = self.gen_var(pattern.custom.unwrap());
self.var_assignment.insert(id.clone(), ptr);
ptr
})
}
ExprKind::Attribute { value, attr, .. } => {
let index = self.get_attr_index(value.custom.unwrap(), attr);
let val = self.gen_expr(value);
let ptr = if let BasicValueEnum::PointerValue(v) = val {
v
} else {
unreachable!();
};
unsafe {
ptr.const_in_bounds_gep(&[
self.ctx.i32_type().const_zero(),
self.ctx.i32_type().const_int(index as u64, false),
])
}
}
ExprKind::Subscript { .. } => unimplemented!(),
_ => unreachable!(),
}
}
fn gen_assignment(&mut self, target: &Expr<Option<Type>>, value: BasicValueEnum<'ctx>) {
if let ExprKind::Tuple { elts, .. } = &target.node {
if let BasicValueEnum::PointerValue(ptr) = value {
for (i, elt) in elts.iter().enumerate() {
unsafe {
let t = ptr.const_in_bounds_gep(&[
self.ctx.i32_type().const_zero(),
self.ctx.i32_type().const_int(i as u64, false),
]);
let v = self.builder.build_load(t, "tmpload");
self.gen_assignment(elt, v);
}
}
} else {
unreachable!()
}
} else {
let ptr = self.parse_pattern(target);
self.builder.build_store(ptr, value);
}
}
pub fn gen_stmt(&mut self, stmt: &Stmt<Option<Type>>) {
match &stmt.node {
StmtKind::Expr { value } => {
self.gen_expr(&value);
}
StmtKind::Return { value } => {
let value = value.as_ref().map(|v| self.gen_expr(&v));
let value = value.as_ref().map(|v| v as &dyn BasicValue);
self.builder.build_return(value);
}
StmtKind::AnnAssign { target, value, .. } => {
if let Some(value) = value {
let value = self.gen_expr(&value);
self.gen_assignment(target, value);
}
}
StmtKind::Assign { targets, value, .. } => {
let value = self.gen_expr(&value);
for target in targets.iter() {
self.gen_assignment(target, value);
}
}
StmtKind::Continue => {
self.builder.build_unconditional_branch(self.loop_bb.unwrap().0);
}
StmtKind::Break => {
self.builder.build_unconditional_branch(self.loop_bb.unwrap().1);
}
StmtKind::While { test, body, orelse } => {
let current = self.builder.get_insert_block().unwrap().get_parent().unwrap();
let test_bb = self.ctx.append_basic_block(current, "test");
let body_bb = self.ctx.append_basic_block(current, "body");
let cont_bb = self.ctx.append_basic_block(current, "cont");
// if there is no orelse, we just go to cont_bb
let orelse_bb = if orelse.is_empty() {
cont_bb
} else {
self.ctx.append_basic_block(current, "orelse")
};
// store loop bb information and restore it later
let loop_bb = self.loop_bb.replace((test_bb, cont_bb));
self.builder.build_unconditional_branch(test_bb);
self.builder.position_at_end(test_bb);
let test = self.gen_expr(test);
if let BasicValueEnum::IntValue(test) = test {
self.builder.build_conditional_branch(test, body_bb, orelse_bb);
} else {
unreachable!()
};
self.builder.position_at_end(body_bb);
for stmt in body.iter() {
self.gen_stmt(stmt);
}
self.builder.build_unconditional_branch(test_bb);
if !orelse.is_empty() {
self.builder.position_at_end(orelse_bb);
for stmt in orelse.iter() {
self.gen_stmt(stmt);
}
self.builder.build_unconditional_branch(cont_bb);
}
self.builder.position_at_end(cont_bb);
self.loop_bb = loop_bb;
}
_ => unimplemented!(),
}
}
}

View File

@ -0,0 +1,247 @@
use super::{CodeGenTask, WorkerRegistry};
use crate::{
codegen::WithCall,
location::Location,
symbol_resolver::{SymbolResolver, SymbolValue},
top_level::{DefinitionId, TopLevelContext},
typecheck::{
magic_methods::set_primitives_magic_methods,
type_inferencer::{CodeLocation, FunctionData, Inferencer, PrimitiveStore},
typedef::{CallId, FunSignature, FuncArg, Type, TypeEnum, Unifier},
},
};
use indoc::indoc;
use parking_lot::RwLock;
use rustpython_parser::{ast::fold::Fold, parser::parse_program};
use std::collections::HashMap;
use std::sync::Arc;
#[derive(Clone)]
struct Resolver {
id_to_type: HashMap<String, Type>,
id_to_def: HashMap<String, DefinitionId>,
class_names: HashMap<String, Type>,
}
impl SymbolResolver for Resolver {
fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option<Type> {
self.id_to_type.get(str).cloned()
}
fn get_symbol_value(&self, _: &str) -> Option<SymbolValue> {
unimplemented!()
}
fn get_symbol_location(&self, _: &str) -> Option<Location> {
unimplemented!()
}
fn get_identifier_def(&self, id: &str) -> Option<DefinitionId> {
self.id_to_def.get(id).cloned()
}
}
struct TestEnvironment {
pub unifier: Unifier,
pub function_data: FunctionData,
pub primitives: PrimitiveStore,
pub id_to_name: HashMap<usize, String>,
pub identifier_mapping: HashMap<String, Type>,
pub virtual_checks: Vec<(Type, Type)>,
pub calls: HashMap<CodeLocation, CallId>,
pub top_level: TopLevelContext,
}
impl TestEnvironment {
pub fn basic_test_env() -> TestEnvironment {
let mut unifier = Unifier::new();
let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(0),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(1),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let float = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(2),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let bool = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(3),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let none = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(4),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let primitives = PrimitiveStore { int32, int64, float, bool, none };
set_primitives_magic_methods(&primitives, &mut unifier);
let id_to_name = [
(0, "int32".to_string()),
(1, "int64".to_string()),
(2, "float".to_string()),
(3, "bool".to_string()),
(4, "none".to_string()),
]
.iter()
.cloned()
.collect();
let mut identifier_mapping = HashMap::new();
identifier_mapping.insert("None".into(), none);
let resolver = Arc::new(Resolver {
id_to_type: identifier_mapping.clone(),
id_to_def: Default::default(),
class_names: Default::default(),
}) as Arc<dyn SymbolResolver + Send + Sync>;
TestEnvironment {
unifier,
top_level: TopLevelContext {
definitions: Default::default(),
unifiers: Default::default(),
// conetexts: Default::default(),
},
function_data: FunctionData {
resolver,
bound_variables: Vec::new(),
return_type: Some(primitives.int32),
},
primitives,
id_to_name,
identifier_mapping,
virtual_checks: Vec::new(),
calls: HashMap::new(),
}
}
fn get_inferencer(&mut self) -> Inferencer {
Inferencer {
top_level: &self.top_level,
function_data: &mut self.function_data,
unifier: &mut self.unifier,
variable_mapping: Default::default(),
primitives: &mut self.primitives,
virtual_checks: &mut self.virtual_checks,
calls: &mut self.calls,
}
}
}
#[test]
fn test_primitives() {
let mut env = TestEnvironment::basic_test_env();
let threads = ["test"];
let signature = FunSignature {
args: vec![
FuncArg { name: "a".to_string(), ty: env.primitives.int32, default_value: None },
FuncArg { name: "b".to_string(), ty: env.primitives.int32, default_value: None },
],
ret: env.primitives.int32,
vars: HashMap::new(),
};
let mut inferencer = env.get_inferencer();
inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32);
inferencer.variable_mapping.insert("b".into(), inferencer.primitives.int32);
let source = indoc! { "
c = a + b
d = a if c == 1 else 0
return d
"};
let statements = parse_program(source).unwrap();
let statements = statements
.into_iter()
.map(|v| inferencer.fold_stmt(v))
.collect::<Result<Vec<_>, _>>()
.unwrap();
let mut identifiers = vec!["a".to_string(), "b".to_string()];
inferencer.check_block(&statements, &mut identifiers).unwrap();
let top_level = Arc::new(TopLevelContext {
definitions: Default::default(),
unifiers: Arc::new(RwLock::new(vec![(env.unifier.get_shared_unifier(), env.primitives)])),
// conetexts: Default::default(),
});
let task = CodeGenTask {
subst: Default::default(),
symbol_name: "testing".to_string(),
body: statements,
unifier_index: 0,
resolver: env.function_data.resolver.clone(),
signature,
};
let f = Arc::new(WithCall::new(Box::new(|module| {
// the following IR is equivalent to
// ```
// ; ModuleID = 'test.ll'
// source_filename = "test"
//
// ; Function Attrs: norecurse nounwind readnone
// define i32 @testing(i32 %0, i32 %1) local_unnamed_addr #0 {
// init:
// %add = add i32 %1, %0
// %cmp = icmp eq i32 %add, 1
// %ifexpr = select i1 %cmp, i32 %0, i32 0
// ret i32 %ifexpr
// }
//
// attributes #0 = { norecurse nounwind readnone }
// ```
// after O2 optimization
let expected = indoc! {"
; ModuleID = 'test'
source_filename = \"test\"
define i32 @testing(i32 %0, i32 %1) {
init:
%a = alloca i32, align 4
store i32 %0, i32* %a, align 4
%b = alloca i32, align 4
store i32 %1, i32* %b, align 4
%tmp = alloca i32, align 4
%tmp4 = alloca i32, align 4
br label %body
body: ; preds = %init
%load = load i32, i32* %a, align 4
%load1 = load i32, i32* %b, align 4
%add = add i32 %load, %load1
store i32 %add, i32* %tmp, align 4
%load2 = load i32, i32* %tmp, align 4
%cmp = icmp eq i32 %load2, 1
br i1 %cmp, label %then, label %else
then: ; preds = %body
%load3 = load i32, i32* %a, align 4
br label %cont
else: ; preds = %body
br label %cont
cont: ; preds = %else, %then
%ifexpr = phi i32 [ %load3, %then ], [ 0, %else ]
store i32 %ifexpr, i32* %tmp4, align 4
%load5 = load i32, i32* %tmp4, align 4
ret i32 %load5
}
"}
.trim();
assert_eq!(expected, module.print_to_string().to_str().unwrap().trim());
})));
let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f);
registry.add_task(task);
registry.wait_tasks_complete(handles);
}

View File

@ -1,212 +0,0 @@
use super::TopLevelContext;
use crate::typedef::*;
use std::boxed::Box;
use std::collections::HashMap;
struct ContextStack<'a> {
/// stack level, starts from 0
level: u32,
/// stack of variable definitions containing (id, def, level) where `def` is the original
/// definition in `level-1`.
var_defs: Vec<(usize, VarDef<'a>, u32)>,
/// stack of symbol definitions containing (name, level) where `level` is the smallest level
/// where the name is assigned a value
sym_def: Vec<(&'a str, u32)>,
}
pub struct InferenceContext<'a> {
/// top level context
top_level: TopLevelContext<'a>,
/// list of primitive instances
primitives: Vec<Type>,
/// list of variable instances
variables: Vec<Type>,
/// identifier to (type, readable) mapping.
/// an identifier might be defined earlier but has no value (for some code path), thus not
/// readable.
sym_table: HashMap<&'a str, (Type, bool)>,
/// resolution function reference, that may resolve unbounded identifiers to some type
resolution_fn: Box<dyn FnMut(&str) -> Result<Type, String>>,
/// stack
stack: ContextStack<'a>,
}
// non-trivial implementations here
impl<'a> InferenceContext<'a> {
/// return a new `InferenceContext` from `TopLevelContext` and resolution function.
pub fn new(
top_level: TopLevelContext,
resolution_fn: Box<dyn FnMut(&str) -> Result<Type, String>>,
) -> InferenceContext {
let primitives = (0..top_level.primitive_defs.len())
.map(|v| TypeEnum::PrimitiveType(PrimitiveId(v)).into())
.collect();
let variables = (0..top_level.var_defs.len())
.map(|v| TypeEnum::TypeVariable(VariableId(v)).into())
.collect();
InferenceContext {
top_level,
primitives,
variables,
sym_table: HashMap::new(),
resolution_fn,
stack: ContextStack {
level: 0,
var_defs: Vec::new(),
sym_def: Vec::new(),
},
}
}
/// execute the function with new scope.
/// variable assignment would be limited within the scope (not readable outside), and type
/// variable type guard would be limited within the scope.
/// returns the list of variables assigned within the scope, and the result of the function
pub fn with_scope<F, R>(&mut self, f: F) -> (Vec<&'a str>, R)
where
F: FnOnce(&mut Self) -> R,
{
self.stack.level += 1;
let result = f(self);
self.stack.level -= 1;
while !self.stack.var_defs.is_empty() {
let (_, _, level) = self.stack.var_defs.last().unwrap();
if *level > self.stack.level {
let (id, def, _) = self.stack.var_defs.pop().unwrap();
self.top_level.var_defs[id] = def;
} else {
break;
}
}
let mut poped_names = Vec::new();
while !self.stack.sym_def.is_empty() {
let (_, level) = self.stack.sym_def.last().unwrap();
if *level > self.stack.level {
let (name, _) = self.stack.sym_def.pop().unwrap();
self.sym_table.remove(name).unwrap();
poped_names.push(name);
} else {
break;
}
}
(poped_names, result)
}
/// assign a type to an identifier.
/// may return error if the identifier was defined but with different type
pub fn assign(&mut self, name: &'a str, ty: Type) -> Result<Type, String> {
if let Some((t, x)) = self.sym_table.get_mut(name) {
if t == &ty {
if !*x {
self.stack.sym_def.push((name, self.stack.level));
}
*x = true;
Ok(ty)
} else {
Err("different types".into())
}
} else {
self.stack.sym_def.push((name, self.stack.level));
self.sym_table.insert(name, (ty.clone(), true));
Ok(ty)
}
}
/// check if an identifier is already defined
pub fn defined(&self, name: &str) -> bool {
self.sym_table.get(name).is_some()
}
/// get the type of an identifier
/// may return error if the identifier is not defined, and cannot be resolved with the
/// resolution function.
pub fn resolve(&mut self, name: &str) -> Result<Type, String> {
if let Some((t, x)) = self.sym_table.get(name) {
if *x {
Ok(t.clone())
} else {
Err("may not have value".into())
}
} else {
self.resolution_fn.as_mut()(name)
}
}
/// restrict the bound of a type variable by replacing its definition.
/// used for implementing type guard
pub fn restrict(&mut self, id: VariableId, mut def: VarDef<'a>) {
std::mem::swap(self.top_level.var_defs.get_mut(id.0).unwrap(), &mut def);
self.stack.var_defs.push((id.0, def, self.stack.level));
}
}
// trivial getters:
impl<'a> InferenceContext<'a> {
pub fn get_primitive(&self, id: PrimitiveId) -> Type {
self.primitives.get(id.0).unwrap().clone()
}
pub fn get_variable(&self, id: VariableId) -> Type {
self.variables.get(id.0).unwrap().clone()
}
pub fn get_fn_def(&self, name: &str) -> Option<&FnDef> {
self.top_level.fn_table.get(name)
}
pub fn get_primitive_def(&self, id: PrimitiveId) -> &TypeDef {
self.top_level.primitive_defs.get(id.0).unwrap()
}
pub fn get_class_def(&self, id: ClassId) -> &ClassDef {
self.top_level.class_defs.get(id.0).unwrap()
}
pub fn get_parametric_def(&self, id: ParamId) -> &ParametricDef {
self.top_level.parametric_defs.get(id.0).unwrap()
}
pub fn get_variable_def(&self, id: VariableId) -> &VarDef {
self.top_level.var_defs.get(id.0).unwrap()
}
pub fn get_type(&self, name: &str) -> Option<Type> {
self.top_level.get_type(name)
}
}
impl TypeEnum {
pub fn subst(&self, map: &HashMap<VariableId, Type>) -> TypeEnum {
match self {
TypeEnum::TypeVariable(id) => map.get(id).map(|v| v.as_ref()).unwrap_or(self).clone(),
TypeEnum::ParametricType(id, params) => TypeEnum::ParametricType(
*id,
params
.iter()
.map(|v| v.as_ref().subst(map).into())
.collect(),
),
_ => self.clone(),
}
}
pub fn get_subst(&self, ctx: &InferenceContext) -> HashMap<VariableId, Type> {
match self {
TypeEnum::ParametricType(id, params) => {
let vars = &ctx.get_parametric_def(*id).params;
vars.iter()
.zip(params)
.map(|(v, p)| (*v, p.as_ref().clone().into()))
.collect()
}
// if this proves to be slow, we can use option type
_ => HashMap::new(),
}
}
pub fn get_base<'b: 'a, 'a>(&'a self, ctx: &'b InferenceContext) -> Option<&'b TypeDef> {
match self {
TypeEnum::PrimitiveType(id) => Some(ctx.get_primitive_def(*id)),
TypeEnum::ClassType(id) | TypeEnum::VirtualClassType(id) => {
Some(&ctx.get_class_def(*id).base)
}
TypeEnum::ParametricType(id, _) => Some(&ctx.get_parametric_def(*id).base),
_ => None,
}
}
}

View File

@ -1,4 +0,0 @@
mod inference_context;
mod top_level_context;
pub use inference_context::InferenceContext;
pub use top_level_context::TopLevelContext;

View File

@ -1,136 +0,0 @@
use crate::typedef::*;
use std::collections::HashMap;
use std::rc::Rc;
/// Structure for storing top-level type definitions.
/// Used for collecting type signature from source code.
/// Can be converted to `InferenceContext` for type inference in functions.
pub struct TopLevelContext<'a> {
/// List of primitive definitions.
pub(super) primitive_defs: Vec<TypeDef<'a>>,
/// List of class definitions.
pub(super) class_defs: Vec<ClassDef<'a>>,
/// List of parametric type definitions.
pub(super) parametric_defs: Vec<ParametricDef<'a>>,
/// List of type variable definitions.
pub(super) var_defs: Vec<VarDef<'a>>,
/// Function name to signature mapping.
pub(super) fn_table: HashMap<&'a str, FnDef>,
/// Type name to type mapping.
pub(super) sym_table: HashMap<&'a str, Type>,
primitives: Vec<Type>,
variables: Vec<Type>,
}
impl<'a> TopLevelContext<'a> {
pub fn new(primitive_defs: Vec<TypeDef<'a>>) -> TopLevelContext {
let mut sym_table = HashMap::new();
let mut primitives = Vec::new();
for (i, t) in primitive_defs.iter().enumerate() {
primitives.push(TypeEnum::PrimitiveType(PrimitiveId(i)).into());
sym_table.insert(t.name, TypeEnum::PrimitiveType(PrimitiveId(i)).into());
}
TopLevelContext {
primitive_defs,
class_defs: Vec::new(),
parametric_defs: Vec::new(),
var_defs: Vec::new(),
fn_table: HashMap::new(),
sym_table,
primitives,
variables: Vec::new(),
}
}
pub fn add_class(&mut self, def: ClassDef<'a>) -> ClassId {
self.sym_table.insert(
def.base.name,
TypeEnum::ClassType(ClassId(self.class_defs.len())).into(),
);
self.class_defs.push(def);
ClassId(self.class_defs.len() - 1)
}
pub fn add_parametric(&mut self, def: ParametricDef<'a>) -> ParamId {
let params = def
.params
.iter()
.map(|&v| Rc::new(TypeEnum::TypeVariable(v)))
.collect();
self.sym_table.insert(
def.base.name,
TypeEnum::ParametricType(ParamId(self.parametric_defs.len()), params).into(),
);
self.parametric_defs.push(def);
ParamId(self.parametric_defs.len() - 1)
}
pub fn add_variable(&mut self, def: VarDef<'a>) -> VariableId {
self.sym_table.insert(
def.name,
TypeEnum::TypeVariable(VariableId(self.var_defs.len())).into(),
);
self.add_variable_private(def)
}
pub fn add_variable_private(&mut self, def: VarDef<'a>) -> VariableId {
self.var_defs.push(def);
self.variables
.push(TypeEnum::TypeVariable(VariableId(self.var_defs.len() - 1)).into());
VariableId(self.var_defs.len() - 1)
}
pub fn add_fn(&mut self, name: &'a str, def: FnDef) {
self.fn_table.insert(name, def);
}
pub fn get_fn_def(&self, name: &str) -> Option<&FnDef> {
self.fn_table.get(name)
}
pub fn get_primitive_def_mut(&mut self, id: PrimitiveId) -> &mut TypeDef<'a> {
self.primitive_defs.get_mut(id.0).unwrap()
}
pub fn get_primitive_def(&self, id: PrimitiveId) -> &TypeDef {
self.primitive_defs.get(id.0).unwrap()
}
pub fn get_class_def_mut(&mut self, id: ClassId) -> &mut ClassDef<'a> {
self.class_defs.get_mut(id.0).unwrap()
}
pub fn get_class_def(&self, id: ClassId) -> &ClassDef {
self.class_defs.get(id.0).unwrap()
}
pub fn get_parametric_def_mut(&mut self, id: ParamId) -> &mut ParametricDef<'a> {
self.parametric_defs.get_mut(id.0).unwrap()
}
pub fn get_parametric_def(&self, id: ParamId) -> &ParametricDef {
self.parametric_defs.get(id.0).unwrap()
}
pub fn get_variable_def_mut(&mut self, id: VariableId) -> &mut VarDef<'a> {
self.var_defs.get_mut(id.0).unwrap()
}
pub fn get_variable_def(&self, id: VariableId) -> &VarDef {
self.var_defs.get(id.0).unwrap()
}
pub fn get_primitive(&self, id: PrimitiveId) -> Type {
self.primitives.get(id.0).unwrap().clone()
}
pub fn get_variable(&self, id: VariableId) -> Type {
self.variables.get(id.0).unwrap().clone()
}
pub fn get_type(&self, name: &str) -> Option<Type> {
// TODO: handle parametric types
self.sym_table.get(name).cloned()
}
}

View File

@ -1,922 +0,0 @@
use crate::context::InferenceContext;
use crate::inference_core::resolve_call;
use crate::magic_methods::*;
use crate::primitives::*;
use crate::typedef::{Type, TypeEnum::*};
use rustpython_parser::ast::{
Comparison, Comprehension, ComprehensionKind, Expression, ExpressionType, Operator,
UnaryOperator,
};
use std::convert::TryInto;
type ParserResult = Result<Option<Type>, String>;
pub fn infer_expr<'b: 'a, 'a>(
ctx: &mut InferenceContext<'a>,
expr: &'b Expression,
) -> ParserResult {
match &expr.node {
ExpressionType::Number { value } => infer_constant(ctx, value),
ExpressionType::Identifier { name } => infer_identifier(ctx, name),
ExpressionType::List { elements } => infer_list(ctx, elements),
ExpressionType::Tuple { elements } => infer_tuple(ctx, elements),
ExpressionType::Attribute { value, name } => infer_attribute(ctx, value, name),
ExpressionType::BoolOp { values, .. } => infer_bool_ops(ctx, values),
ExpressionType::Binop { a, b, op } => infer_bin_ops(ctx, op, a, b),
ExpressionType::Unop { op, a } => infer_unary_ops(ctx, op, a),
ExpressionType::Compare { vals, ops } => infer_compare(ctx, vals, ops),
ExpressionType::Call {
args,
function,
keywords,
} => {
if !keywords.is_empty() {
Err("keyword is not supported".into())
} else {
infer_call(ctx, &args, &function)
}
}
ExpressionType::Subscript { a, b } => infer_subscript(ctx, a, b),
ExpressionType::IfExpression { test, body, orelse } => {
infer_if_expr(ctx, &test, &body, orelse)
}
ExpressionType::Comprehension { kind, generators } => match kind.as_ref() {
ComprehensionKind::List { element } => {
if generators.len() == 1 {
infer_list_comprehension(ctx, element, &generators[0])
} else {
Err("only 1 generator statement is supported".into())
}
}
_ => Err("only list comprehension is supported".into()),
},
ExpressionType::True | ExpressionType::False => Ok(Some(ctx.get_primitive(BOOL_TYPE))),
_ => Err("not supported".into()),
}
}
fn infer_constant(
ctx: &mut InferenceContext,
value: &rustpython_parser::ast::Number,
) -> ParserResult {
use rustpython_parser::ast::Number;
match value {
Number::Integer { value } => {
let int32: Result<i32, _> = value.try_into();
if int32.is_ok() {
Ok(Some(ctx.get_primitive(INT32_TYPE)))
} else {
Err("integer out of range".into())
}
}
Number::Float { .. } => Ok(Some(ctx.get_primitive(FLOAT_TYPE))),
_ => Err("not supported".into()),
}
}
fn infer_identifier(ctx: &mut InferenceContext, name: &str) -> ParserResult {
Ok(Some(ctx.resolve(name)?))
}
fn infer_list<'b: 'a, 'a>(
ctx: &mut InferenceContext<'a>,
elements: &'b [Expression],
) -> ParserResult {
if elements.is_empty() {
return Ok(Some(ParametricType(LIST_TYPE, vec![BotType.into()]).into()));
}
let mut types = elements.iter().map(|v| infer_expr(ctx, v));
let head = types.next().unwrap()?;
if head.is_none() {
return Err("list elements must have some type".into());
}
for v in types {
// TODO: try virtual type...
if v? != head {
return Err("inhomogeneous list is not allowed".into());
}
}
Ok(Some(ParametricType(LIST_TYPE, vec![head.unwrap()]).into()))
}
fn infer_tuple<'b: 'a, 'a>(
ctx: &mut InferenceContext<'a>,
elements: &'b [Expression],
) -> ParserResult {
let types: Result<Option<Vec<_>>, String> =
elements.iter().map(|v| infer_expr(ctx, v)).collect();
if let Some(t) = types? {
Ok(Some(ParametricType(TUPLE_TYPE, t).into()))
} else {
Err("tuple elements must have some type".into())
}
}
fn infer_attribute<'a>(
ctx: &mut InferenceContext<'a>,
value: &'a Expression,
name: &str,
) -> ParserResult {
let value = infer_expr(ctx, value)?.ok_or_else(|| "no value".to_string())?;
if let TypeVariable(_) = value.as_ref() {
return Err("no fields for type variable".into());
}
value
.get_base(ctx)
.and_then(|b| b.fields.get(name).cloned())
.map_or_else(|| Err("no such field".to_string()), |v| Ok(Some(v)))
}
fn infer_bool_ops<'a>(ctx: &mut InferenceContext<'a>, values: &'a [Expression]) -> ParserResult {
assert_eq!(values.len(), 2);
let left = infer_expr(ctx, &values[0])?.ok_or_else(|| "no value".to_string())?;
let right = infer_expr(ctx, &values[1])?.ok_or_else(|| "no value".to_string())?;
let b = ctx.get_primitive(BOOL_TYPE);
if left == b && right == b {
Ok(Some(b))
} else {
Err("bool operands must be bool".into())
}
}
fn infer_bin_ops<'b: 'a, 'a>(
ctx: &mut InferenceContext<'a>,
op: &Operator,
left: &'b Expression,
right: &'b Expression,
) -> ParserResult {
let left = infer_expr(ctx, left)?.ok_or_else(|| "no value".to_string())?;
let right = infer_expr(ctx, right)?.ok_or_else(|| "no value".to_string())?;
let fun = binop_name(op);
resolve_call(ctx, Some(left), fun, &[right])
}
fn infer_unary_ops<'b: 'a, 'a>(
ctx: &mut InferenceContext<'a>,
op: &UnaryOperator,
obj: &'b Expression,
) -> ParserResult {
let ty = infer_expr(ctx, obj)?.ok_or_else(|| "no value".to_string())?;
if let UnaryOperator::Not = op {
if ty == ctx.get_primitive(BOOL_TYPE) {
Ok(Some(ty))
} else {
Err("logical not must be applied to bool".into())
}
} else {
resolve_call(ctx, Some(ty), unaryop_name(op), &[])
}
}
fn infer_compare<'b: 'a, 'a>(
ctx: &mut InferenceContext<'a>,
vals: &'b [Expression],
ops: &'b [Comparison],
) -> ParserResult {
let types: Result<Option<Vec<_>>, _> = vals.iter().map(|v| infer_expr(ctx, v)).collect();
let types = types?;
if types.is_none() {
return Err("comparison operands must have type".into());
}
let types = types.unwrap();
let boolean = ctx.get_primitive(BOOL_TYPE);
let left = &types[..types.len() - 1];
let right = &types[1..];
for ((a, b), op) in left.iter().zip(right.iter()).zip(ops.iter()) {
let fun = comparison_name(op).ok_or_else(|| "unsupported comparison".to_string())?;
let ty = resolve_call(ctx, Some(a.clone()), fun, &[b.clone()])?;
if ty.is_none() || ty.unwrap() != boolean {
return Err("comparison result must be boolean".into());
}
}
Ok(Some(boolean))
}
fn infer_call<'b: 'a, 'a>(
ctx: &mut InferenceContext<'a>,
args: &'b [Expression],
function: &'b Expression,
) -> ParserResult {
// TODO: special handling for int64 constant
let types: Result<Option<Vec<_>>, _> = args.iter().map(|v| infer_expr(ctx, v)).collect();
let types = types?;
if types.is_none() {
return Err("function params must have type".into());
}
let (obj, fun) = match &function.node {
ExpressionType::Identifier { name } => (None, name),
ExpressionType::Attribute { value, name } => (
Some(infer_expr(ctx, &value)?.ok_or_else(|| "no value".to_string())?),
name,
),
_ => return Err("not supported".into()),
};
resolve_call(ctx, obj, fun.as_str(), &types.unwrap())
}
fn infer_subscript<'b: 'a, 'a>(
ctx: &mut InferenceContext<'a>,
a: &'b Expression,
b: &'b Expression,
) -> ParserResult {
let a = infer_expr(ctx, a)?.ok_or_else(|| "no value".to_string())?;
let t = if let ParametricType(LIST_TYPE, ls) = a.as_ref() {
ls[0].clone()
} else {
return Err("subscript is not supported for types other than list".into());
};
match &b.node {
ExpressionType::Slice { elements } => {
let int32 = ctx.get_primitive(INT32_TYPE);
let types: Result<Option<Vec<_>>, _> = elements
.iter()
.map(|v| {
if let ExpressionType::None = v.node {
Ok(Some(int32.clone()))
} else {
infer_expr(ctx, v)
}
})
.collect();
let types = types?.ok_or_else(|| "slice must have type".to_string())?;
if types.iter().all(|v| v == &int32) {
Ok(Some(a))
} else {
Err("slice must be int32 type".into())
}
}
_ => {
let b = infer_expr(ctx, b)?.ok_or_else(|| "no value".to_string())?;
if b == ctx.get_primitive(INT32_TYPE) {
Ok(Some(t))
} else {
Err("index must be either slice or int32".into())
}
}
}
}
fn infer_if_expr<'b: 'a, 'a>(
ctx: &mut InferenceContext<'a>,
test: &'b Expression,
body: &'b Expression,
orelse: &'b Expression,
) -> ParserResult {
let test = infer_expr(ctx, test)?.ok_or_else(|| "no value".to_string())?;
if test != ctx.get_primitive(BOOL_TYPE) {
return Err("test should be bool".into());
}
let body = infer_expr(ctx, body)?;
let orelse = infer_expr(ctx, orelse)?;
if body.as_ref() == orelse.as_ref() {
Ok(body)
} else {
Err("divergent type".into())
}
}
fn infer_simple_binding<'a: 'b, 'b>(
ctx: &mut InferenceContext<'b>,
name: &'a Expression,
ty: Type,
) -> Result<(), String> {
match &name.node {
ExpressionType::Identifier { name } => {
if name == "_" {
Ok(())
} else if ctx.defined(name.as_str()) {
Err("duplicated naming".into())
} else {
ctx.assign(name.as_str(), ty)?;
Ok(())
}
}
ExpressionType::Tuple { elements } => {
if let ParametricType(TUPLE_TYPE, ls) = ty.as_ref() {
if elements.len() == ls.len() {
for (a, b) in elements.iter().zip(ls.iter()) {
infer_simple_binding(ctx, a, b.clone())?;
}
Ok(())
} else {
Err("different length".into())
}
} else {
Err("not supported".into())
}
}
_ => Err("not supported".into()),
}
}
fn infer_list_comprehension<'b: 'a, 'a>(
ctx: &mut InferenceContext<'a>,
element: &'b Expression,
comprehension: &'b Comprehension,
) -> ParserResult {
if comprehension.is_async {
return Err("async is not supported".into());
}
let iter = infer_expr(ctx, &comprehension.iter)?.ok_or_else(|| "no value".to_string())?;
if let ParametricType(LIST_TYPE, ls) = iter.as_ref() {
ctx.with_scope(|ctx| {
infer_simple_binding(ctx, &comprehension.target, ls[0].clone())?;
let boolean = ctx.get_primitive(BOOL_TYPE);
for test in comprehension.ifs.iter() {
let result =
infer_expr(ctx, test)?.ok_or_else(|| "no value in test".to_string())?;
if result != boolean {
return Err("test must be bool".into());
}
}
let result = infer_expr(ctx, element)?.ok_or_else(|| "no value")?;
Ok(Some(ParametricType(LIST_TYPE, vec![result]).into()))
})
.1
} else {
Err("iteration is supported for list only".into())
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::context::*;
use crate::typedef::*;
use rustpython_parser::parser::parse_expression;
use std::collections::HashMap;
use std::rc::Rc;
fn get_inference_context(ctx: TopLevelContext) -> InferenceContext {
InferenceContext::new(ctx, Box::new(|_| Err("unbounded identifier".into())))
}
#[test]
fn test_constants() {
let ctx = basic_ctx();
let mut ctx = get_inference_context(ctx);
let ast = parse_expression("123").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("2147483647").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("2147483648").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("integer out of range".into()));
//
// let ast = parse_expression("2147483648").unwrap();
// let result = infer_expr(&mut ctx, &ast);
// assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT64_TYPE));
// let ast = parse_expression("9223372036854775807").unwrap();
// let result = infer_expr(&mut ctx, &ast);
// assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT64_TYPE));
// let ast = parse_expression("9223372036854775808").unwrap();
// let result = infer_expr(&mut ctx, &ast);
// assert_eq!(result, Err("integer out of range".into()));
let ast = parse_expression("123.456").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(FLOAT_TYPE));
let ast = parse_expression("True").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE));
let ast = parse_expression("False").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE));
}
#[test]
fn test_identifier() {
let ctx = basic_ctx();
let mut ctx = get_inference_context(ctx);
ctx.assign("abc", ctx.get_primitive(INT32_TYPE)).unwrap();
let ast = parse_expression("abc").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("ab").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("unbounded identifier".into()));
}
#[test]
fn test_list() {
let mut ctx = basic_ctx();
ctx.add_fn(
"foo",
FnDef {
args: vec![],
result: None,
},
);
let mut ctx = get_inference_context(ctx);
ctx.assign("abc", ctx.get_primitive(INT32_TYPE)).unwrap();
// def is reserved...
ctx.assign("efg", ctx.get_primitive(INT32_TYPE)).unwrap();
ctx.assign("xyz", ctx.get_primitive(FLOAT_TYPE)).unwrap();
let ast = parse_expression("[]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result.unwrap().unwrap(),
ParametricType(LIST_TYPE, vec![BotType.into()]).into()
);
let ast = parse_expression("[abc]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result.unwrap().unwrap(),
ParametricType(LIST_TYPE, vec![ctx.get_primitive(INT32_TYPE)]).into()
);
let ast = parse_expression("[abc, efg]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result.unwrap().unwrap(),
ParametricType(LIST_TYPE, vec![ctx.get_primitive(INT32_TYPE)]).into()
);
let ast = parse_expression("[abc, efg, xyz]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("inhomogeneous list is not allowed".into()));
let ast = parse_expression("[foo()]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("list elements must have some type".into()));
}
#[test]
fn test_tuple() {
let mut ctx = basic_ctx();
ctx.add_fn(
"foo",
FnDef {
args: vec![],
result: None,
},
);
let mut ctx = get_inference_context(ctx);
ctx.assign("abc", ctx.get_primitive(INT32_TYPE)).unwrap();
ctx.assign("efg", ctx.get_primitive(FLOAT_TYPE)).unwrap();
let ast = parse_expression("(abc, efg)").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result.unwrap().unwrap(),
ParametricType(
TUPLE_TYPE,
vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(FLOAT_TYPE)]
)
.into()
);
let ast = parse_expression("(abc, efg, foo())").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("tuple elements must have some type".into()));
}
#[test]
fn test_attribute() {
let mut ctx = basic_ctx();
ctx.add_fn(
"none",
FnDef {
args: vec![],
result: None,
},
);
let int32 = ctx.get_primitive(INT32_TYPE);
let float = ctx.get_primitive(FLOAT_TYPE);
let foo = ctx.add_class(ClassDef {
base: TypeDef {
name: "Foo",
fields: HashMap::new(),
methods: HashMap::new(),
},
parents: vec![],
});
let foo_def = ctx.get_class_def_mut(foo);
foo_def.base.fields.insert("a", int32.clone());
foo_def.base.fields.insert("b", ClassType(foo).into());
foo_def.base.fields.insert("c", int32.clone());
let bar = ctx.add_class(ClassDef {
base: TypeDef {
name: "Bar",
fields: HashMap::new(),
methods: HashMap::new(),
},
parents: vec![],
});
let bar_def = ctx.get_class_def_mut(bar);
bar_def.base.fields.insert("a", int32);
bar_def.base.fields.insert("b", ClassType(bar).into());
bar_def.base.fields.insert("c", float);
let v0 = ctx.add_variable(VarDef {
name: "v0",
bound: vec![],
});
let v1 = ctx.add_variable(VarDef {
name: "v1",
bound: vec![ClassType(foo).into(), ClassType(bar).into()],
});
let mut ctx = get_inference_context(ctx);
ctx.assign("foo", Rc::new(ClassType(foo))).unwrap();
ctx.assign("bar", Rc::new(ClassType(bar))).unwrap();
ctx.assign("foobar", Rc::new(VirtualClassType(foo)))
.unwrap();
ctx.assign("v0", ctx.get_variable(v0)).unwrap();
ctx.assign("v1", ctx.get_variable(v1)).unwrap();
ctx.assign("bot", Rc::new(BotType)).unwrap();
let ast = parse_expression("foo.a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("foo.d").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no such field".into()));
let ast = parse_expression("foobar.a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("v0.a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no fields for type variable".into()));
let ast = parse_expression("v1.a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no fields for type variable".into()));
let ast = parse_expression("none().a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no value".into()));
let ast = parse_expression("bot.a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no such field".into()));
}
#[test]
fn test_bool_ops() {
let mut ctx = basic_ctx();
ctx.add_fn(
"none",
FnDef {
args: vec![],
result: None,
},
);
let mut ctx = get_inference_context(ctx);
let ast = parse_expression("True and False").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE));
let ast = parse_expression("True and none()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no value".into()));
let ast = parse_expression("True and 123").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("bool operands must be bool".into()));
}
#[test]
fn test_bin_ops() {
let mut ctx = basic_ctx();
let v0 = ctx.add_variable(VarDef {
name: "v0",
bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(INT64_TYPE)],
});
let mut ctx = get_inference_context(ctx);
ctx.assign("a", TypeVariable(v0).into()).unwrap();
let ast = parse_expression("1 + 2 + 3").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("a + a + a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("not supported".into()));
}
#[test]
fn test_unary_ops() {
let mut ctx = basic_ctx();
let v0 = ctx.add_variable(VarDef {
name: "v0",
bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(INT64_TYPE)],
});
let mut ctx = get_inference_context(ctx);
ctx.assign("a", TypeVariable(v0).into()).unwrap();
let ast = parse_expression("-(123)").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("-a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("not supported".into()));
let ast = parse_expression("not True").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE));
let ast = parse_expression("not (1)").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("logical not must be applied to bool".into()));
}
#[test]
fn test_compare() {
let mut ctx = basic_ctx();
let v0 = ctx.add_variable(VarDef {
name: "v0",
bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(INT64_TYPE)],
});
let mut ctx = get_inference_context(ctx);
ctx.assign("a", TypeVariable(v0).into()).unwrap();
let ast = parse_expression("a == a == a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("not supported".into()));
let ast = parse_expression("a == a == 1").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("not supported".into()));
let ast = parse_expression("True > False").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no such function".into()));
let ast = parse_expression("True in False").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("unsupported comparison".into()));
}
#[test]
fn test_call() {
let mut ctx = basic_ctx();
ctx.add_fn(
"none",
FnDef {
args: vec![],
result: None,
},
);
let foo = ctx.add_class(ClassDef {
base: TypeDef {
name: "Foo",
fields: HashMap::new(),
methods: HashMap::new(),
},
parents: vec![],
});
let foo_def = ctx.get_class_def_mut(foo);
foo_def.base.methods.insert(
"a",
FnDef {
args: vec![],
result: Some(Rc::new(ClassType(foo))),
},
);
let bar = ctx.add_class(ClassDef {
base: TypeDef {
name: "Bar",
fields: HashMap::new(),
methods: HashMap::new(),
},
parents: vec![],
});
let bar_def = ctx.get_class_def_mut(bar);
bar_def.base.methods.insert(
"a",
FnDef {
args: vec![],
result: Some(Rc::new(ClassType(bar))),
},
);
let v0 = ctx.add_variable(VarDef {
name: "v0",
bound: vec![],
});
let v1 = ctx.add_variable(VarDef {
name: "v1",
bound: vec![ClassType(foo).into(), ClassType(bar).into()],
});
let v2 = ctx.add_variable(VarDef {
name: "v2",
bound: vec![
ClassType(foo).into(),
ClassType(bar).into(),
ctx.get_primitive(INT32_TYPE),
],
});
let mut ctx = get_inference_context(ctx);
ctx.assign("foo", Rc::new(ClassType(foo))).unwrap();
ctx.assign("bar", Rc::new(ClassType(bar))).unwrap();
ctx.assign("foobar", Rc::new(VirtualClassType(foo)))
.unwrap();
ctx.assign("v0", ctx.get_variable(v0)).unwrap();
ctx.assign("v1", ctx.get_variable(v1)).unwrap();
ctx.assign("v2", ctx.get_variable(v2)).unwrap();
ctx.assign("bot", Rc::new(BotType)).unwrap();
let ast = parse_expression("foo.a()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ClassType(foo).into());
let ast = parse_expression("v1.a()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("not supported".into()));
let ast = parse_expression("foobar.a()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ClassType(foo).into());
let ast = parse_expression("none().a()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no value".into()));
let ast = parse_expression("bot.a()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("not supported".into()));
let ast = parse_expression("[][0].a()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("not supported".into()));
}
#[test]
fn infer_subscript() {
let mut ctx = basic_ctx();
ctx.add_fn(
"none",
FnDef {
args: vec![],
result: None,
},
);
let mut ctx = get_inference_context(ctx);
let ast = parse_expression("[1, 2, 3][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("[[1]][0][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("[1, 2, 3][1:2]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result.unwrap().unwrap(),
ParametricType(LIST_TYPE, vec![ctx.get_primitive(INT32_TYPE)]).into()
);
let ast = parse_expression("[1, 2, 3][1:2:2]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result.unwrap().unwrap(),
ParametricType(LIST_TYPE, vec![ctx.get_primitive(INT32_TYPE)]).into()
);
let ast = parse_expression("[1, 2, 3][1:1.2]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("slice must be int32 type".into()));
let ast = parse_expression("[1, 2, 3][1:none()]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("slice must have type".into()));
let ast = parse_expression("[1, 2, 3][1.2]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("index must be either slice or int32".into()));
let ast = parse_expression("[1, 2, 3][none()]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no value".into()));
let ast = parse_expression("none()[1.2]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no value".into()));
let ast = parse_expression("123[1]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result,
Err("subscript is not supported for types other than list".into())
);
}
#[test]
fn test_if_expr() {
let mut ctx = basic_ctx();
ctx.add_fn(
"none",
FnDef {
args: vec![],
result: None,
},
);
let mut ctx = get_inference_context(ctx);
let ast = parse_expression("1 if True else 0").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("none() if True else none()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap(), None);
let ast = parse_expression("none() if 1 else none()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("test should be bool".into()));
let ast = parse_expression("1 if True else none()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("divergent type".into()));
}
#[test]
fn test_list_comp() {
let mut ctx = basic_ctx();
ctx.add_fn(
"none",
FnDef {
args: vec![],
result: None,
},
);
let int32 = ctx.get_primitive(INT32_TYPE);
let mut ctx = get_inference_context(ctx);
ctx.assign("z", int32.clone()).unwrap();
let ast = parse_expression("[x for x in [(1, 2), (2, 3), (3, 4)]][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result.unwrap().unwrap(),
ParametricType(TUPLE_TYPE, vec![int32.clone(), int32.clone()]).into()
);
let ast = parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)]][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), int32);
let ast =
parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)] if x > 0][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), int32);
let ast = parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)] if x][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("test must be bool".into()));
let ast = parse_expression("[y for x in []][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("unbounded identifier".into()));
let ast = parse_expression("[none() for x in []][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no value".into()));
let ast = parse_expression("[z for z in []][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("duplicated naming".into()));
let ast = parse_expression("[x for x in [] for y in []]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result,
Err("only 1 generator statement is supported".into())
);
}
}

View File

@ -1,525 +0,0 @@
use crate::context::InferenceContext;
use crate::typedef::{TypeEnum::*, *};
use std::collections::HashMap;
fn find_subst(
ctx: &InferenceContext,
valuation: &Option<(VariableId, Type)>,
sub: &mut HashMap<VariableId, Type>,
mut a: Type,
mut b: Type,
) -> Result<(), String> {
// TODO: fix error messages later
if let TypeVariable(id) = a.as_ref() {
if let Some((assumption_id, t)) = valuation {
if assumption_id == id {
a = t.clone();
}
}
}
let mut substituted = false;
if let TypeVariable(id) = b.as_ref() {
if let Some(c) = sub.get(&id) {
b = c.clone();
substituted = true;
}
}
match (a.as_ref(), b.as_ref()) {
(BotType, _) => Ok(()),
(TypeVariable(id_a), TypeVariable(id_b)) => {
if substituted {
return if id_a == id_b {
Ok(())
} else {
Err("different variables".to_string())
};
}
let v_a = ctx.get_variable_def(*id_a);
let v_b = ctx.get_variable_def(*id_b);
if !v_b.bound.is_empty() {
if v_a.bound.is_empty() {
return Err("unbounded a".to_string());
} else {
let diff: Vec<_> = v_a
.bound
.iter()
.filter(|x| !v_b.bound.contains(x))
.collect();
if !diff.is_empty() {
return Err("different domain".to_string());
}
}
}
sub.insert(*id_b, a.clone());
Ok(())
}
(TypeVariable(id_a), _) => {
let v_a = ctx.get_variable_def(*id_a);
if v_a.bound.len() == 1 && v_a.bound[0].as_ref() == b.as_ref() {
Ok(())
} else {
Err("different domain".to_string())
}
}
(_, TypeVariable(id_b)) => {
let v_b = ctx.get_variable_def(*id_b);
if v_b.bound.is_empty() || v_b.bound.contains(&a) {
sub.insert(*id_b, a.clone());
Ok(())
} else {
Err("different domain".to_string())
}
}
(_, VirtualClassType(id_b)) => {
let mut parents;
match a.as_ref() {
ClassType(id_a) => {
parents = [*id_a].to_vec();
}
VirtualClassType(id_a) => {
parents = [*id_a].to_vec();
}
_ => {
return Err("cannot substitute non-class type into virtual class".to_string());
}
};
while !parents.is_empty() {
if *id_b == parents[0] {
return Ok(());
}
let c = ctx.get_class_def(parents.remove(0));
parents.extend_from_slice(&c.parents);
}
Err("not subtype".to_string())
}
(ParametricType(id_a, param_a), ParametricType(id_b, param_b)) => {
if id_a != id_b || param_a.len() != param_b.len() {
Err("different parametric types".to_string())
} else {
for (x, y) in param_a.iter().zip(param_b.iter()) {
find_subst(ctx, valuation, sub, x.clone(), y.clone())?;
}
Ok(())
}
}
(_, _) => {
if a == b {
Ok(())
} else {
Err("not equal".to_string())
}
}
}
}
fn resolve_call_rec(
ctx: &InferenceContext,
valuation: &Option<(VariableId, Type)>,
obj: Option<Type>,
func: &str,
args: &[Type],
) -> Result<Option<Type>, String> {
let mut subst = obj
.as_ref()
.map(|v| v.get_subst(ctx))
.unwrap_or_else(HashMap::new);
let fun = match &obj {
Some(obj) => {
let base = match obj.as_ref() {
PrimitiveType(id) => &ctx.get_primitive_def(*id),
ClassType(id) | VirtualClassType(id) => &ctx.get_class_def(*id).base,
ParametricType(id, _) => &ctx.get_parametric_def(*id).base,
_ => return Err("not supported".to_string()),
};
base.methods.get(func)
}
None => ctx.get_fn_def(func),
}
.ok_or_else(|| "no such function".to_string())?;
if args.len() != fun.args.len() {
return Err("incorrect parameter number".to_string());
}
for (a, b) in args.iter().zip(fun.args.iter()) {
find_subst(ctx, valuation, &mut subst, a.clone(), b.clone())?;
}
let result = fun.result.as_ref().map(|v| v.subst(&subst));
Ok(result.map(|result| {
if let SelfType = result {
obj.unwrap()
} else {
result.into()
}
}))
}
pub fn resolve_call(
ctx: &InferenceContext,
obj: Option<Type>,
func: &str,
args: &[Type],
) -> Result<Option<Type>, String> {
resolve_call_rec(ctx, &None, obj, func, args)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::context::TopLevelContext;
use crate::primitives::*;
use std::rc::Rc;
fn get_inference_context(ctx: TopLevelContext) -> InferenceContext {
InferenceContext::new(ctx, Box::new(|_| Err("unbounded identifier".into())))
}
#[test]
fn test_simple_generic() {
let mut ctx = basic_ctx();
let v1 = ctx.add_variable(VarDef {
name: "V1",
bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(FLOAT_TYPE)],
});
let v1 = ctx.get_variable(v1);
let v2 = ctx.add_variable(VarDef {
name: "V2",
bound: vec![
ctx.get_primitive(BOOL_TYPE),
ctx.get_primitive(INT32_TYPE),
ctx.get_primitive(FLOAT_TYPE),
],
});
let v2 = ctx.get_variable(v2);
let ctx = get_inference_context(ctx);
assert_eq!(
resolve_call(&ctx, None, "int32", &[ctx.get_primitive(FLOAT_TYPE)]),
Ok(Some(ctx.get_primitive(INT32_TYPE)))
);
assert_eq!(
resolve_call(&ctx, None, "int32", &[ctx.get_primitive(INT32_TYPE)],),
Ok(Some(ctx.get_primitive(INT32_TYPE)))
);
assert_eq!(
resolve_call(&ctx, None, "float", &[ctx.get_primitive(INT32_TYPE)]),
Ok(Some(ctx.get_primitive(FLOAT_TYPE)))
);
assert_eq!(
resolve_call(&ctx, None, "float", &[ctx.get_primitive(BOOL_TYPE)]),
Err("different domain".to_string())
);
assert_eq!(
resolve_call(&ctx, None, "float", &[]),
Err("incorrect parameter number".to_string())
);
assert_eq!(
resolve_call(&ctx, None, "float", &[v1]),
Ok(Some(ctx.get_primitive(FLOAT_TYPE)))
);
assert_eq!(
resolve_call(&ctx, None, "float", &[v2]),
Err("different domain".to_string())
);
}
#[test]
fn test_methods() {
let mut ctx = basic_ctx();
let v0 = ctx.add_variable(VarDef {
name: "V0",
bound: vec![],
});
let v0 = ctx.get_variable(v0);
let int32 = ctx.get_primitive(INT32_TYPE);
let int64 = ctx.get_primitive(INT64_TYPE);
let ctx = get_inference_context(ctx);
// simple cases
assert_eq!(
resolve_call(&ctx, Some(int32.clone()), "__add__", &[int32.clone()]),
Ok(Some(int32.clone()))
);
assert_ne!(
resolve_call(&ctx, Some(int32.clone()), "__add__", &[int32.clone()]),
Ok(Some(int64.clone()))
);
assert_eq!(
resolve_call(&ctx, Some(int32), "__add__", &[int64]),
Err("not equal".to_string())
);
// with type variables
assert_eq!(
resolve_call(&ctx, Some(v0.clone()), "__add__", &[v0.clone()]),
Err("not supported".into())
);
}
#[test]
fn test_multi_generic() {
let mut ctx = basic_ctx();
let v0 = ctx.add_variable(VarDef {
name: "V0",
bound: vec![],
});
let v0 = ctx.get_variable(v0);
let v1 = ctx.add_variable(VarDef {
name: "V1",
bound: vec![],
});
let v1 = ctx.get_variable(v1);
let v2 = ctx.add_variable(VarDef {
name: "V2",
bound: vec![],
});
let v2 = ctx.get_variable(v2);
let v3 = ctx.add_variable(VarDef {
name: "V3",
bound: vec![],
});
let v3 = ctx.get_variable(v3);
ctx.add_fn(
"foo",
FnDef {
args: vec![v0.clone(), v0.clone(), v1.clone()],
result: Some(v0.clone()),
},
);
ctx.add_fn(
"foo1",
FnDef {
args: vec![ParametricType(TUPLE_TYPE, vec![v0.clone(), v0.clone(), v1]).into()],
result: Some(v0),
},
);
let ctx = get_inference_context(ctx);
assert_eq!(
resolve_call(&ctx, None, "foo", &[v2.clone(), v2.clone(), v2.clone()]),
Ok(Some(v2.clone()))
);
assert_eq!(
resolve_call(&ctx, None, "foo", &[v2.clone(), v2.clone(), v3.clone()]),
Ok(Some(v2.clone()))
);
assert_eq!(
resolve_call(&ctx, None, "foo", &[v2.clone(), v3.clone(), v3.clone()]),
Err("different variables".to_string())
);
assert_eq!(
resolve_call(
&ctx,
None,
"foo1",
&[ParametricType(TUPLE_TYPE, vec![v2.clone(), v2.clone(), v2.clone()]).into()]
),
Ok(Some(v2.clone()))
);
assert_eq!(
resolve_call(
&ctx,
None,
"foo1",
&[ParametricType(TUPLE_TYPE, vec![v2.clone(), v2.clone(), v3.clone()]).into()]
),
Ok(Some(v2.clone()))
);
assert_eq!(
resolve_call(
&ctx,
None,
"foo1",
&[ParametricType(TUPLE_TYPE, vec![v2, v3.clone(), v3]).into()]
),
Err("different variables".to_string())
);
}
#[test]
fn test_class_generics() {
let mut ctx = basic_ctx();
let list = ctx.get_parametric_def_mut(LIST_TYPE);
let t = Rc::new(TypeVariable(list.params[0]));
list.base.methods.insert(
"head",
FnDef {
args: vec![],
result: Some(t.clone()),
},
);
list.base.methods.insert(
"append",
FnDef {
args: vec![t],
result: None,
},
);
let v0 = ctx.add_variable(VarDef {
name: "V0",
bound: vec![],
});
let v0 = ctx.get_variable(v0);
let v1 = ctx.add_variable(VarDef {
name: "V1",
bound: vec![],
});
let v1 = ctx.get_variable(v1);
let ctx = get_inference_context(ctx);
assert_eq!(
resolve_call(
&ctx,
Some(ParametricType(LIST_TYPE, vec![v0.clone()]).into()),
"head",
&[]
),
Ok(Some(v0.clone()))
);
assert_eq!(
resolve_call(
&ctx,
Some(ParametricType(LIST_TYPE, vec![v0.clone()]).into()),
"append",
&[v0.clone()]
),
Ok(None)
);
assert_eq!(
resolve_call(
&ctx,
Some(ParametricType(LIST_TYPE, vec![v0]).into()),
"append",
&[v1]
),
Err("different variables".to_string())
);
}
#[test]
fn test_virtual_class() {
let mut ctx = basic_ctx();
let foo = ctx.add_class(ClassDef {
base: TypeDef {
name: "Foo",
methods: HashMap::new(),
fields: HashMap::new(),
},
parents: vec![],
});
let foo1 = ctx.add_class(ClassDef {
base: TypeDef {
name: "Foo1",
methods: HashMap::new(),
fields: HashMap::new(),
},
parents: vec![foo],
});
let foo2 = ctx.add_class(ClassDef {
base: TypeDef {
name: "Foo2",
methods: HashMap::new(),
fields: HashMap::new(),
},
parents: vec![foo1],
});
let bar = ctx.add_class(ClassDef {
base: TypeDef {
name: "bar",
methods: HashMap::new(),
fields: HashMap::new(),
},
parents: vec![],
});
ctx.add_fn(
"foo",
FnDef {
args: vec![VirtualClassType(foo).into()],
result: None,
},
);
ctx.add_fn(
"foo1",
FnDef {
args: vec![VirtualClassType(foo1).into()],
result: None,
},
);
let ctx = get_inference_context(ctx);
assert_eq!(
resolve_call(&ctx, None, "foo", &[ClassType(foo).into()]),
Ok(None)
);
assert_eq!(
resolve_call(&ctx, None, "foo", &[ClassType(foo1).into()]),
Ok(None)
);
assert_eq!(
resolve_call(&ctx, None, "foo", &[ClassType(foo2).into()]),
Ok(None)
);
assert_eq!(
resolve_call(&ctx, None, "foo", &[ClassType(bar).into()]),
Err("not subtype".to_string())
);
assert_eq!(
resolve_call(&ctx, None, "foo1", &[ClassType(foo1).into()]),
Ok(None)
);
assert_eq!(
resolve_call(&ctx, None, "foo1", &[ClassType(foo2).into()]),
Ok(None)
);
assert_eq!(
resolve_call(&ctx, None, "foo1", &[ClassType(foo).into()]),
Err("not subtype".to_string())
);
// virtual class substitution
assert_eq!(
resolve_call(&ctx, None, "foo", &[VirtualClassType(foo).into()]),
Ok(None)
);
assert_eq!(
resolve_call(&ctx, None, "foo", &[VirtualClassType(foo1).into()]),
Ok(None)
);
assert_eq!(
resolve_call(&ctx, None, "foo", &[VirtualClassType(foo2).into()]),
Ok(None)
);
assert_eq!(
resolve_call(&ctx, None, "foo", &[VirtualClassType(bar).into()]),
Err("not subtype".to_string())
);
}
}

View File

@ -1,593 +1,8 @@
#![warn(clippy::all)] #![warn(clippy::all)]
#![allow(clippy::clone_double_ref)] #![allow(dead_code)]
extern crate num_bigint; mod codegen;
extern crate inkwell; mod location;
extern crate rustpython_parser; mod symbol_resolver;
mod top_level;
pub mod expression_inference; mod typecheck;
pub mod inference_core;
mod magic_methods;
pub mod primitives;
pub mod typedef;
pub mod context;
use std::error::Error;
use std::fmt;
use std::path::Path;
use std::collections::HashMap;
use num_traits::cast::ToPrimitive;
use rustpython_parser::ast;
use inkwell::OptimizationLevel;
use inkwell::builder::Builder;
use inkwell::context::Context;
use inkwell::module::Module;
use inkwell::targets::*;
use inkwell::types;
use inkwell::types::BasicType;
use inkwell::values;
use inkwell::{IntPredicate, FloatPredicate};
use inkwell::basic_block;
use inkwell::passes;
#[derive(Debug)]
enum CompileErrorKind {
Unsupported(&'static str),
MissingTypeAnnotation,
UnknownTypeAnnotation,
IncompatibleTypes,
UnboundIdentifier,
BreakOutsideLoop,
Internal(&'static str)
}
impl fmt::Display for CompileErrorKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
CompileErrorKind::Unsupported(feature)
=> write!(f, "The following Python feature is not supported by NAC3: {}", feature),
CompileErrorKind::MissingTypeAnnotation
=> write!(f, "Missing type annotation"),
CompileErrorKind::UnknownTypeAnnotation
=> write!(f, "Unknown type annotation"),
CompileErrorKind::IncompatibleTypes
=> write!(f, "Incompatible types"),
CompileErrorKind::UnboundIdentifier
=> write!(f, "Unbound identifier"),
CompileErrorKind::BreakOutsideLoop
=> write!(f, "Break outside loop"),
CompileErrorKind::Internal(details)
=> write!(f, "Internal compiler error: {}", details),
}
}
}
#[derive(Debug)]
pub struct CompileError {
location: ast::Location,
kind: CompileErrorKind,
}
impl fmt::Display for CompileError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}, at {}", self.kind, self.location)
}
}
impl Error for CompileError {}
type CompileResult<T> = Result<T, CompileError>;
pub struct CodeGen<'ctx> {
context: &'ctx Context,
module: Module<'ctx>,
pass_manager: passes::PassManager<values::FunctionValue<'ctx>>,
builder: Builder<'ctx>,
current_source_location: ast::Location,
namespace: HashMap<String, values::PointerValue<'ctx>>,
break_bb: Option<basic_block::BasicBlock<'ctx>>,
}
impl<'ctx> CodeGen<'ctx> {
pub fn new(context: &'ctx Context) -> CodeGen<'ctx> {
let module = context.create_module("kernel");
let pass_manager = passes::PassManager::create(&module);
pass_manager.add_instruction_combining_pass();
pass_manager.add_reassociate_pass();
pass_manager.add_gvn_pass();
pass_manager.add_cfg_simplification_pass();
pass_manager.add_basic_alias_analysis_pass();
pass_manager.add_promote_memory_to_register_pass();
pass_manager.add_instruction_combining_pass();
pass_manager.add_reassociate_pass();
pass_manager.initialize();
let i32_type = context.i32_type();
let fn_type = i32_type.fn_type(&[i32_type.into()], false);
module.add_function("output", fn_type, None);
CodeGen {
context, module, pass_manager,
builder: context.create_builder(),
current_source_location: ast::Location::default(),
namespace: HashMap::new(),
break_bb: None,
}
}
fn set_source_location(&mut self, location: ast::Location) {
self.current_source_location = location;
}
fn compile_error(&self, kind: CompileErrorKind) -> CompileError {
CompileError {
location: self.current_source_location,
kind
}
}
fn get_basic_type(&self, name: &str) -> CompileResult<types::BasicTypeEnum<'ctx>> {
match name {
"bool" => Ok(self.context.bool_type().into()),
"int32" => Ok(self.context.i32_type().into()),
"int64" => Ok(self.context.i64_type().into()),
"float32" => Ok(self.context.f32_type().into()),
"float64" => Ok(self.context.f64_type().into()),
_ => Err(self.compile_error(CompileErrorKind::UnknownTypeAnnotation))
}
}
fn compile_function_def(
&mut self,
name: &str,
args: &ast::Parameters,
body: &ast::Suite,
decorator_list: &[ast::Expression],
returns: &Option<ast::Expression>,
is_async: bool,
) -> CompileResult<values::FunctionValue<'ctx>> {
if is_async {
return Err(self.compile_error(CompileErrorKind::Unsupported("async functions")))
}
for decorator in decorator_list.iter() {
self.set_source_location(decorator.location);
if let ast::ExpressionType::Identifier { name } = &decorator.node {
if name != "kernel" && name != "portable" {
return Err(self.compile_error(CompileErrorKind::Unsupported("custom decorators")))
}
} else {
return Err(self.compile_error(CompileErrorKind::Unsupported("decorator must be an identifier")))
}
}
let args_type = args.args.iter().map(|val| {
self.set_source_location(val.location);
if let Some(annotation) = &val.annotation {
if let ast::ExpressionType::Identifier { name } = &annotation.node {
Ok(self.get_basic_type(&name)?)
} else {
Err(self.compile_error(CompileErrorKind::Unsupported("type annotation must be an identifier")))
}
} else {
Err(self.compile_error(CompileErrorKind::MissingTypeAnnotation))
}
}).collect::<CompileResult<Vec<types::BasicTypeEnum>>>()?;
let return_type = if let Some(returns) = returns {
self.set_source_location(returns.location);
if let ast::ExpressionType::Identifier { name } = &returns.node {
if name == "None" { None } else { Some(self.get_basic_type(name)?) }
} else {
return Err(self.compile_error(CompileErrorKind::Unsupported("type annotation must be an identifier")))
}
} else {
None
};
let fn_type = match return_type {
Some(ty) => ty.fn_type(&args_type, false),
None => self.context.void_type().fn_type(&args_type, false)
};
let function = self.module.add_function(name, fn_type, None);
let basic_block = self.context.append_basic_block(function, "entry");
self.builder.position_at_end(basic_block);
for (n, arg) in args.args.iter().enumerate() {
let param = function.get_nth_param(n as u32).unwrap();
let alloca = self.builder.build_alloca(param.get_type(), &arg.arg);
self.builder.build_store(alloca, param);
self.namespace.insert(arg.arg.clone(), alloca);
}
self.compile_suite(body, return_type)?;
Ok(function)
}
fn compile_expression(
&mut self,
expression: &ast::Expression
) -> CompileResult<values::BasicValueEnum<'ctx>> {
self.set_source_location(expression.location);
match &expression.node {
ast::ExpressionType::True => Ok(self.context.bool_type().const_int(1, false).into()),
ast::ExpressionType::False => Ok(self.context.bool_type().const_int(0, false).into()),
ast::ExpressionType::Number { value: ast::Number::Integer { value } } => {
let mut bits = value.bits();
if value.sign() == num_bigint::Sign::Minus {
bits += 1;
}
match bits {
0..=32 => Ok(self.context.i32_type().const_int(value.to_i32().unwrap() as _, true).into()),
33..=64 => Ok(self.context.i64_type().const_int(value.to_i64().unwrap() as _, true).into()),
_ => Err(self.compile_error(CompileErrorKind::Unsupported("integers larger than 64 bits")))
}
},
ast::ExpressionType::Number { value: ast::Number::Float { value } } => {
Ok(self.context.f64_type().const_float(*value).into())
},
ast::ExpressionType::Identifier { name } => {
match self.namespace.get(name) {
Some(value) => Ok(self.builder.build_load(*value, name).into()),
None => Err(self.compile_error(CompileErrorKind::UnboundIdentifier))
}
},
ast::ExpressionType::Unop { op, a } => {
let a = self.compile_expression(&a)?;
match (op, a) {
(ast::UnaryOperator::Pos, values::BasicValueEnum::IntValue(a))
=> Ok(a.into()),
(ast::UnaryOperator::Pos, values::BasicValueEnum::FloatValue(a))
=> Ok(a.into()),
(ast::UnaryOperator::Neg, values::BasicValueEnum::IntValue(a))
=> Ok(self.builder.build_int_neg(a, "tmpneg").into()),
(ast::UnaryOperator::Neg, values::BasicValueEnum::FloatValue(a))
=> Ok(self.builder.build_float_neg(a, "tmpneg").into()),
(ast::UnaryOperator::Inv, values::BasicValueEnum::IntValue(a))
=> Ok(self.builder.build_not(a, "tmpnot").into()),
(ast::UnaryOperator::Not, values::BasicValueEnum::IntValue(a)) => {
// boolean "not"
if a.get_type().get_bit_width() != 1 {
Err(self.compile_error(CompileErrorKind::Unsupported("unimplemented unary operation")))
} else {
Ok(self.builder.build_not(a, "tmpnot").into())
}
},
_ => Err(self.compile_error(CompileErrorKind::Unsupported("unimplemented unary operation"))),
}
},
ast::ExpressionType::Binop { a, op, b } => {
let a = self.compile_expression(&a)?;
let b = self.compile_expression(&b)?;
if a.get_type() != b.get_type() {
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
}
use ast::Operator::*;
match (op, a, b) {
(Add, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b))
=> Ok(self.builder.build_int_add(a, b, "tmpadd").into()),
(Sub, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b))
=> Ok(self.builder.build_int_sub(a, b, "tmpsub").into()),
(Mult, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b))
=> Ok(self.builder.build_int_mul(a, b, "tmpmul").into()),
(Add, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b))
=> Ok(self.builder.build_float_add(a, b, "tmpadd").into()),
(Sub, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b))
=> Ok(self.builder.build_float_sub(a, b, "tmpsub").into()),
(Mult, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b))
=> Ok(self.builder.build_float_mul(a, b, "tmpmul").into()),
(Div, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b))
=> Ok(self.builder.build_float_div(a, b, "tmpdiv").into()),
(FloorDiv, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b))
=> Ok(self.builder.build_int_signed_div(a, b, "tmpdiv").into()),
_ => Err(self.compile_error(CompileErrorKind::Unsupported("unimplemented binary operation"))),
}
},
ast::ExpressionType::Compare { vals, ops } => {
let mut vals = vals.iter();
let mut ops = ops.iter();
let mut result = None;
let mut a = self.compile_expression(vals.next().unwrap())?;
loop {
if let Some(op) = ops.next() {
let b = self.compile_expression(vals.next().unwrap())?;
if a.get_type() != b.get_type() {
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
}
let this_result = match (a, b) {
(values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b)) => {
match op {
ast::Comparison::Equal
=> self.builder.build_int_compare(IntPredicate::EQ, a, b, "tmpeq"),
ast::Comparison::NotEqual
=> self.builder.build_int_compare(IntPredicate::NE, a, b, "tmpne"),
ast::Comparison::Less
=> self.builder.build_int_compare(IntPredicate::SLT, a, b, "tmpslt"),
ast::Comparison::LessOrEqual
=> self.builder.build_int_compare(IntPredicate::SLE, a, b, "tmpsle"),
ast::Comparison::Greater
=> self.builder.build_int_compare(IntPredicate::SGT, a, b, "tmpsgt"),
ast::Comparison::GreaterOrEqual
=> self.builder.build_int_compare(IntPredicate::SGE, a, b, "tmpsge"),
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("special comparison"))),
}
},
(values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b)) => {
match op {
ast::Comparison::Equal
=> self.builder.build_float_compare(FloatPredicate::OEQ, a, b, "tmpoeq"),
ast::Comparison::NotEqual
=> self.builder.build_float_compare(FloatPredicate::UNE, a, b, "tmpune"),
ast::Comparison::Less
=> self.builder.build_float_compare(FloatPredicate::OLT, a, b, "tmpolt"),
ast::Comparison::LessOrEqual
=> self.builder.build_float_compare(FloatPredicate::OLE, a, b, "tmpole"),
ast::Comparison::Greater
=> self.builder.build_float_compare(FloatPredicate::OGT, a, b, "tmpogt"),
ast::Comparison::GreaterOrEqual
=> self.builder.build_float_compare(FloatPredicate::OGE, a, b, "tmpoge"),
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("special comparison"))),
}
},
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("comparison of non-numerical types"))),
};
match result {
Some(last) => {
result = Some(self.builder.build_and(last, this_result, "tmpand"));
}
None => {
result = Some(this_result);
}
}
a = b;
} else {
return Ok(result.unwrap().into())
}
}
},
ast::ExpressionType::Call { function, args, keywords } => {
if !keywords.is_empty() {
return Err(self.compile_error(CompileErrorKind::Unsupported("keyword arguments")))
}
let args = args.iter().map(|val| self.compile_expression(val))
.collect::<CompileResult<Vec<values::BasicValueEnum>>>()?;
self.set_source_location(expression.location);
if let ast::ExpressionType::Identifier { name } = &function.node {
match (name.as_str(), args[0]) {
("int32", values::BasicValueEnum::IntValue(a)) => {
let nbits = a.get_type().get_bit_width();
if nbits < 32 {
Ok(self.builder.build_int_s_extend(a, self.context.i32_type(), "tmpsext").into())
} else if nbits > 32 {
Ok(self.builder.build_int_truncate(a, self.context.i32_type(), "tmptrunc").into())
} else {
Ok(a.into())
}
},
("int64", values::BasicValueEnum::IntValue(a)) => {
let nbits = a.get_type().get_bit_width();
if nbits < 64 {
Ok(self.builder.build_int_s_extend(a, self.context.i64_type(), "tmpsext").into())
} else {
Ok(a.into())
}
},
("int32", values::BasicValueEnum::FloatValue(a)) => {
Ok(self.builder.build_float_to_signed_int(a, self.context.i32_type(), "tmpfptosi").into())
},
("int64", values::BasicValueEnum::FloatValue(a)) => {
Ok(self.builder.build_float_to_signed_int(a, self.context.i64_type(), "tmpfptosi").into())
},
("float32", values::BasicValueEnum::IntValue(a)) => {
Ok(self.builder.build_signed_int_to_float(a, self.context.f32_type(), "tmpsitofp").into())
},
("float64", values::BasicValueEnum::IntValue(a)) => {
Ok(self.builder.build_signed_int_to_float(a, self.context.f64_type(), "tmpsitofp").into())
},
("float32", values::BasicValueEnum::FloatValue(a)) => {
if a.get_type() == self.context.f64_type() {
Ok(self.builder.build_float_trunc(a, self.context.f32_type(), "tmptrunc").into())
} else {
Ok(a.into())
}
},
("float64", values::BasicValueEnum::FloatValue(a)) => {
if a.get_type() == self.context.f32_type() {
Ok(self.builder.build_float_ext(a, self.context.f64_type(), "tmpext").into())
} else {
Ok(a.into())
}
},
("output", values::BasicValueEnum::IntValue(a)) => {
let fn_value = self.module.get_function("output").unwrap();
Ok(self.builder.build_call(fn_value, &[a.into()], "call")
.try_as_basic_value().left().unwrap())
},
_ => Err(self.compile_error(CompileErrorKind::Unsupported("unrecognized call")))
}
} else {
return Err(self.compile_error(CompileErrorKind::Unsupported("function must be an identifier")))
}
},
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("unimplemented expression"))),
}
}
fn compile_statement(
&mut self,
statement: &ast::Statement,
return_type: Option<types::BasicTypeEnum>
) -> CompileResult<()> {
self.set_source_location(statement.location);
use ast::StatementType::*;
match &statement.node {
Assign { targets, value } => {
let value = self.compile_expression(value)?;
for target in targets.iter() {
self.set_source_location(target.location);
if let ast::ExpressionType::Identifier { name } = &target.node {
let builder = &self.builder;
let target = self.namespace.entry(name.clone()).or_insert_with(
|| builder.build_alloca(value.get_type(), name));
if target.get_type() != value.get_type().ptr_type(inkwell::AddressSpace::Generic) {
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
}
builder.build_store(*target, value);
} else {
return Err(self.compile_error(CompileErrorKind::Unsupported("assignment target must be an identifier")))
}
}
},
Expression { expression } => { self.compile_expression(expression)?; },
If { test, body, orelse } => {
let test = self.compile_expression(test)?;
if test.get_type() != self.context.bool_type().into() {
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
}
let parent = self.builder.get_insert_block().unwrap().get_parent().unwrap();
let then_bb = self.context.append_basic_block(parent, "then");
let else_bb = self.context.append_basic_block(parent, "else");
let cont_bb = self.context.append_basic_block(parent, "ifcont");
self.builder.build_conditional_branch(test.into_int_value(), then_bb, else_bb);
self.builder.position_at_end(then_bb);
self.compile_suite(body, return_type)?;
self.builder.build_unconditional_branch(cont_bb);
self.builder.position_at_end(else_bb);
if let Some(orelse) = orelse {
self.compile_suite(orelse, return_type)?;
}
self.builder.build_unconditional_branch(cont_bb);
self.builder.position_at_end(cont_bb);
},
While { test, body, orelse } => {
let parent = self.builder.get_insert_block().unwrap().get_parent().unwrap();
let test_bb = self.context.append_basic_block(parent, "test");
self.builder.build_unconditional_branch(test_bb);
self.builder.position_at_end(test_bb);
let test = self.compile_expression(test)?;
if test.get_type() != self.context.bool_type().into() {
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
}
let then_bb = self.context.append_basic_block(parent, "then");
let else_bb = self.context.append_basic_block(parent, "else");
let cont_bb = self.context.append_basic_block(parent, "ifcont");
self.builder.build_conditional_branch(test.into_int_value(), then_bb, else_bb);
self.break_bb = Some(cont_bb);
self.builder.position_at_end(then_bb);
self.compile_suite(body, return_type)?;
self.builder.build_unconditional_branch(test_bb);
self.builder.position_at_end(else_bb);
if let Some(orelse) = orelse {
self.compile_suite(orelse, return_type)?;
}
self.builder.build_unconditional_branch(cont_bb);
self.builder.position_at_end(cont_bb);
self.break_bb = None;
},
Break => {
if let Some(bb) = self.break_bb {
self.builder.build_unconditional_branch(bb);
let parent = self.builder.get_insert_block().unwrap().get_parent().unwrap();
let unreachable_bb = self.context.append_basic_block(parent, "unreachable");
self.builder.position_at_end(unreachable_bb);
} else {
return Err(self.compile_error(CompileErrorKind::BreakOutsideLoop));
}
}
Return { value: Some(value) } => {
if let Some(return_type) = return_type {
let value = self.compile_expression(value)?;
if value.get_type() != return_type {
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
}
self.builder.build_return(Some(&value));
} else {
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
}
},
Return { value: None } => {
if !return_type.is_none() {
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
}
self.builder.build_return(None);
},
Pass => (),
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("special statement"))),
}
Ok(())
}
fn compile_suite(
&mut self,
suite: &ast::Suite,
return_type: Option<types::BasicTypeEnum>
) -> CompileResult<()> {
for statement in suite.iter() {
self.compile_statement(statement, return_type)?;
}
Ok(())
}
pub fn compile_toplevel(&mut self, statement: &ast::Statement) -> CompileResult<()> {
self.set_source_location(statement.location);
if let ast::StatementType::FunctionDef {
is_async,
name,
args,
body,
decorator_list,
returns,
} = &statement.node {
let function = self.compile_function_def(name, args, body, decorator_list, returns, *is_async)?;
self.pass_manager.run_on(&function);
Ok(())
} else {
Err(self.compile_error(CompileErrorKind::Internal("top-level is not a function definition")))
}
}
pub fn print_ir(&self) {
self.module.print_to_stderr();
}
pub fn output(&self, filename: &str) {
//let triple = TargetTriple::create("riscv32-none-linux-gnu");
let triple = TargetMachine::get_default_triple();
let target = Target::from_triple(&triple)
.expect("couldn't create target from target triple");
let target_machine = target
.create_target_machine(
&triple,
"",
"",
OptimizationLevel::Default,
RelocMode::Default,
CodeModel::Default,
)
.expect("couldn't create target machine");
target_machine
.write_to_file(&self.module, FileType::Object, Path::new(filename))
.expect("couldn't write module to file");
}
}

31
nac3core/src/location.rs Normal file
View File

@ -0,0 +1,31 @@
use rustpython_parser::ast;
use std::vec::Vec;
#[derive(Clone, Copy, PartialEq)]
pub struct FileID(u32);
#[derive(Clone, Copy, PartialEq)]
pub enum Location {
CodeRange(FileID, ast::Location),
Builtin,
}
pub struct FileRegistry {
files: Vec<String>,
}
impl FileRegistry {
pub fn new() -> FileRegistry {
FileRegistry { files: Vec::new() }
}
pub fn add_file(&mut self, path: &str) -> FileID {
let index = self.files.len() as u32;
self.files.push(path.to_owned());
FileID(index)
}
pub fn query_file(&self, id: FileID) -> &str {
&self.files[id.0 as usize]
}
}

View File

@ -1,58 +0,0 @@
use rustpython_parser::ast::{Comparison, Operator, UnaryOperator};
pub fn binop_name(op: &Operator) -> &'static str {
match op {
Operator::Add => "__add__",
Operator::Sub => "__sub__",
Operator::Div => "__truediv__",
Operator::Mod => "__mod__",
Operator::Mult => "__mul__",
Operator::Pow => "__pow__",
Operator::BitOr => "__or__",
Operator::BitXor => "__xor__",
Operator::BitAnd => "__and__",
Operator::LShift => "__lshift__",
Operator::RShift => "__rshift__",
Operator::FloorDiv => "__floordiv__",
Operator::MatMult => "__matmul__",
}
}
pub fn binop_assign_name(op: &Operator) -> &'static str {
match op {
Operator::Add => "__iadd__",
Operator::Sub => "__isub__",
Operator::Div => "__itruediv__",
Operator::Mod => "__imod__",
Operator::Mult => "__imul__",
Operator::Pow => "__ipow__",
Operator::BitOr => "__ior__",
Operator::BitXor => "__ixor__",
Operator::BitAnd => "__iand__",
Operator::LShift => "__ilshift__",
Operator::RShift => "__irshift__",
Operator::FloorDiv => "__ifloordiv__",
Operator::MatMult => "__imatmul__",
}
}
pub fn unaryop_name(op: &UnaryOperator) -> &'static str {
match op {
UnaryOperator::Pos => "__pos__",
UnaryOperator::Neg => "__neg__",
UnaryOperator::Not => "__not__",
UnaryOperator::Inv => "__inv__",
}
}
pub fn comparison_name(op: &Comparison) -> Option<&'static str> {
match op {
Comparison::Less => Some("__lt__"),
Comparison::LessOrEqual => Some("__le__"),
Comparison::Greater => Some("__gt__"),
Comparison::GreaterOrEqual => Some("__ge__"),
Comparison::Equal => Some("__eq__"),
Comparison::NotEqual => Some("__ne__"),
_ => None,
}
}

View File

@ -1,184 +0,0 @@
use super::typedef::{TypeEnum::*, *};
use crate::context::*;
use std::collections::HashMap;
pub const TUPLE_TYPE: ParamId = ParamId(0);
pub const LIST_TYPE: ParamId = ParamId(1);
pub const BOOL_TYPE: PrimitiveId = PrimitiveId(0);
pub const INT32_TYPE: PrimitiveId = PrimitiveId(1);
pub const INT64_TYPE: PrimitiveId = PrimitiveId(2);
pub const FLOAT_TYPE: PrimitiveId = PrimitiveId(3);
fn impl_math(def: &mut TypeDef, ty: &Type) {
let result = Some(ty.clone());
let fun = FnDef {
args: vec![ty.clone()],
result: result.clone(),
};
def.methods.insert("__add__", fun.clone());
def.methods.insert("__sub__", fun.clone());
def.methods.insert("__mul__", fun.clone());
def.methods.insert(
"__neg__",
FnDef {
args: vec![],
result,
},
);
def.methods.insert(
"__truediv__",
FnDef {
args: vec![ty.clone()],
result: Some(PrimitiveType(FLOAT_TYPE).into()),
},
);
def.methods.insert("__floordiv__", fun.clone());
def.methods.insert("__mod__", fun.clone());
def.methods.insert("__pow__", fun);
}
fn impl_bits(def: &mut TypeDef, ty: &Type) {
let result = Some(ty.clone());
let fun = FnDef {
args: vec![PrimitiveType(INT32_TYPE).into()],
result,
};
def.methods.insert("__lshift__", fun.clone());
def.methods.insert("__rshift__", fun);
def.methods.insert(
"__xor__",
FnDef {
args: vec![ty.clone()],
result: Some(ty.clone()),
},
);
}
fn impl_eq(def: &mut TypeDef, ty: &Type) {
let fun = FnDef {
args: vec![ty.clone()],
result: Some(PrimitiveType(BOOL_TYPE).into()),
};
def.methods.insert("__eq__", fun.clone());
def.methods.insert("__ne__", fun);
}
fn impl_order(def: &mut TypeDef, ty: &Type) {
let fun = FnDef {
args: vec![ty.clone()],
result: Some(PrimitiveType(BOOL_TYPE).into()),
};
def.methods.insert("__lt__", fun.clone());
def.methods.insert("__gt__", fun.clone());
def.methods.insert("__le__", fun.clone());
def.methods.insert("__ge__", fun);
}
pub fn basic_ctx() -> TopLevelContext<'static> {
let primitives = [
TypeDef {
name: "bool",
fields: HashMap::new(),
methods: HashMap::new(),
},
TypeDef {
name: "int32",
fields: HashMap::new(),
methods: HashMap::new(),
},
TypeDef {
name: "int64",
fields: HashMap::new(),
methods: HashMap::new(),
},
TypeDef {
name: "float",
fields: HashMap::new(),
methods: HashMap::new(),
},
]
.to_vec();
let mut ctx = TopLevelContext::new(primitives);
let b = ctx.get_primitive(BOOL_TYPE);
let b_def = ctx.get_primitive_def_mut(BOOL_TYPE);
impl_eq(b_def, &b);
let int32 = ctx.get_primitive(INT32_TYPE);
let int32_def = ctx.get_primitive_def_mut(INT32_TYPE);
impl_math(int32_def, &int32);
impl_bits(int32_def, &int32);
impl_order(int32_def, &int32);
impl_eq(int32_def, &int32);
let int64 = ctx.get_primitive(INT64_TYPE);
let int64_def = ctx.get_primitive_def_mut(INT64_TYPE);
impl_math(int64_def, &int64);
impl_bits(int64_def, &int64);
impl_order(int64_def, &int64);
impl_eq(int64_def, &int64);
let float = ctx.get_primitive(FLOAT_TYPE);
let float_def = ctx.get_primitive_def_mut(FLOAT_TYPE);
impl_math(float_def, &float);
impl_order(float_def, &float);
impl_eq(float_def, &float);
let t = ctx.add_variable_private(VarDef {
name: "T",
bound: vec![],
});
ctx.add_parametric(ParametricDef {
base: TypeDef {
name: "tuple",
fields: HashMap::new(),
methods: HashMap::new(),
},
// we have nothing for tuple, so no param def
params: vec![],
});
ctx.add_parametric(ParametricDef {
base: TypeDef {
name: "list",
fields: HashMap::new(),
methods: HashMap::new(),
},
params: vec![t],
});
let i = ctx.add_variable_private(VarDef {
name: "I",
bound: vec![
PrimitiveType(INT32_TYPE).into(),
PrimitiveType(INT64_TYPE).into(),
PrimitiveType(FLOAT_TYPE).into(),
],
});
let args = vec![TypeVariable(i).into()];
ctx.add_fn(
"int32",
FnDef {
args: args.clone(),
result: Some(PrimitiveType(INT32_TYPE).into()),
},
);
ctx.add_fn(
"int64",
FnDef {
args: args.clone(),
result: Some(PrimitiveType(INT64_TYPE).into()),
},
);
ctx.add_fn(
"float",
FnDef {
args,
result: Some(PrimitiveType(FLOAT_TYPE).into()),
},
);
ctx
}

View File

@ -0,0 +1,174 @@
use std::cell::RefCell;
use std::collections::HashMap;
use crate::top_level::{DefinitionId, TopLevelContext, TopLevelDef};
use crate::typecheck::{
type_inferencer::PrimitiveStore,
typedef::{Type, Unifier},
};
use crate::{location::Location, typecheck::typedef::TypeEnum};
use itertools::{chain, izip};
use rustpython_parser::ast::Expr;
#[derive(Clone, PartialEq)]
pub enum SymbolValue {
I32(i32),
I64(i64),
Double(f64),
Bool(bool),
Tuple(Vec<SymbolValue>),
// we should think about how to implement bytes later...
// Bytes(&'a [u8]),
}
pub trait SymbolResolver {
// get type of type variable identifier or top-level function type
fn get_symbol_type(
&self,
unifier: &mut Unifier,
primitives: &PrimitiveStore,
str: &str,
) -> Option<Type>;
// get the top-level definition of identifiers
fn get_identifier_def(&self, str: &str) -> Option<DefinitionId>;
fn get_symbol_value(&self, str: &str) -> Option<SymbolValue>;
fn get_symbol_location(&self, str: &str) -> Option<Location>;
// handle function call etc.
}
// convert type annotation into type
pub fn parse_type_annotation<T>(
resolver: &dyn SymbolResolver,
top_level: &TopLevelContext,
unifier: &mut Unifier,
primitives: &PrimitiveStore,
expr: &Expr<T>,
) -> Result<Type, String> {
use rustpython_parser::ast::ExprKind::*;
match &expr.node {
Name { id, .. } => match id.as_str() {
"int32" => Ok(primitives.int32),
"int64" => Ok(primitives.int64),
"float" => Ok(primitives.float),
"bool" => Ok(primitives.bool),
"None" => Ok(primitives.none),
x => {
let obj_id = resolver.get_identifier_def(x);
if let Some(obj_id) = obj_id {
let defs = top_level.definitions.read();
let def = defs[obj_id.0].read();
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
if !type_vars.is_empty() {
return Err(format!(
"Unexpected number of type parameters: expected {} but got 0",
type_vars.len()
));
}
let fields = RefCell::new(
chain(
fields.iter().map(|(k, v)| (k.clone(), *v)),
methods.iter().map(|(k, v, _)| (k.clone(), *v)),
)
.collect(),
);
Ok(unifier.add_ty(TypeEnum::TObj {
obj_id,
fields,
params: Default::default(),
}))
} else {
Err("Cannot use function name as type".into())
}
} else {
// it could be a type variable
let ty = resolver
.get_symbol_type(unifier, primitives, x)
.ok_or_else(|| "Cannot use function name as type".to_owned())?;
if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) {
Ok(ty)
} else {
Err(format!("Unknown type annotation {}", x))
}
}
}
},
Subscript { value, slice, .. } => {
if let Name { id, .. } = &value.node {
if id == "virtual" {
let ty =
parse_type_annotation(resolver, top_level, unifier, primitives, slice)?;
Ok(unifier.add_ty(TypeEnum::TVirtual { ty }))
} else {
let types = if let Tuple { elts, .. } = &slice.node {
elts.iter()
.map(|v| {
parse_type_annotation(resolver, top_level, unifier, primitives, v)
})
.collect::<Result<Vec<_>, _>>()?
} else {
vec![parse_type_annotation(
resolver, top_level, unifier, primitives, slice,
)?]
};
let obj_id = resolver
.get_identifier_def(id)
.ok_or_else(|| format!("Unknown type annotation {}", id))?;
let defs = top_level.definitions.read();
let def = defs[obj_id.0].read();
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
if types.len() != type_vars.len() {
return Err(format!(
"Unexpected number of type parameters: expected {} but got {}",
type_vars.len(),
types.len()
));
}
let mut subst = HashMap::new();
for (var, ty) in izip!(type_vars.iter(), types.iter()) {
let id = if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) {
*id
} else {
unreachable!()
};
subst.insert(id, *ty);
}
let mut fields = fields
.iter()
.map(|(attr, ty)| {
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(attr.clone(), ty)
})
.collect::<HashMap<_, _>>();
fields.extend(methods.iter().map(|(attr, ty, _)| {
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(attr.clone(), ty)
}));
Ok(unifier.add_ty(TypeEnum::TObj {
obj_id,
fields: fields.into(),
params: subst.into(),
}))
} else {
Err("Cannot use function name as type".into())
}
}
} else {
Err("unsupported type expression".into())
}
}
_ => Err("unsupported type expression".into()),
}
}
impl dyn SymbolResolver + Send + Sync {
pub fn parse_type_annotation<T>(
&self,
top_level: &TopLevelContext,
unifier: &mut Unifier,
primitives: &PrimitiveStore,
expr: &Expr<T>,
) -> Result<Type, String> {
parse_type_annotation(self, top_level, unifier, primitives, expr)
}
}

778
nac3core/src/top_level.rs Normal file
View File

@ -0,0 +1,778 @@
use std::borrow::BorrowMut;
use std::ops::{Deref, DerefMut};
use std::{collections::HashMap, collections::HashSet, sync::Arc};
use super::typecheck::type_inferencer::PrimitiveStore;
use super::typecheck::typedef::{SharedUnifier, Type, TypeEnum, Unifier};
use crate::typecheck::typedef::{FunSignature, FuncArg};
use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Mapping};
use itertools::Itertools;
use parking_lot::{Mutex, RwLock};
use rustpython_parser::ast::{self, Stmt};
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)]
pub struct DefinitionId(pub usize);
pub enum TopLevelDef {
Class {
// object ID used for TypeEnum
object_id: DefinitionId,
// type variables bounded to the class.
type_vars: Vec<Type>,
// class fields
fields: Vec<(String, Type)>,
// class methods, pointing to the corresponding function definition.
methods: Vec<(String, Type, DefinitionId)>,
// ancestor classes, including itself.
ancestors: Vec<DefinitionId>,
// symbol resolver of the module defined the class, none if it is built-in type
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>,
},
Function {
// prefix for symbol, should be unique globally, and not ending with numbers
name: String,
// function signature.
signature: Type,
/// Function instance to symbol mapping
/// Key: string representation of type variable values, sorted by variable ID in ascending
/// order, including type variables associated with the class.
/// Value: function symbol name.
instance_to_symbol: HashMap<String, String>,
/// Function instances to annotated AST mapping
/// Key: string representation of type variable values, sorted by variable ID in ascending
/// order, including type variables associated with the class. Excluding rigid type
/// variables.
/// Value: AST annotated with types together with a unification table index. Could contain
/// rigid type variables that would be substituted when the function is instantiated.
instance_to_stmt: HashMap<String, (Stmt<Option<Type>>, usize)>,
// symbol resolver of the module defined the class
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>,
},
Initializer {
class_id: DefinitionId,
},
}
impl TopLevelDef {
fn get_function_type(&self) -> Result<Type, String> {
if let Self::Function { signature, .. } = self {
Ok(*signature)
} else {
Err("only expect function def here".into())
}
}
}
pub struct TopLevelContext {
pub definitions: Arc<RwLock<Vec<Arc<RwLock<TopLevelDef>>>>>,
pub unifiers: Arc<RwLock<Vec<(SharedUnifier, PrimitiveStore)>>>,
}
pub struct TopLevelComposer {
// list of top level definitions, same as top level context
pub definition_ast_list: Arc<RwLock<Vec<(Arc<RwLock<TopLevelDef>>, Option<ast::Stmt<()>>)>>>,
// start as a primitive unifier, will add more top_level defs inside
pub unifier: Unifier,
// primitive store
pub primitives: PrimitiveStore,
// mangled class method name to def_id
pub class_method_to_def_id: HashMap<String, DefinitionId>,
// record the def id of the classes whoses fields and methods are to be analyzed
pub to_be_analyzed_class: Vec<DefinitionId>,
}
impl TopLevelComposer {
pub fn to_top_level_context(&self) -> TopLevelContext {
let def_list =
self.definition_ast_list.read().iter().map(|(x, _)| x.clone()).collect::<Vec<_>>();
TopLevelContext {
definitions: RwLock::new(def_list).into(),
// FIXME: all the big unifier or?
unifiers: Default::default(),
}
}
fn name_mangling(mut class_name: String, method_name: &str) -> String {
class_name.push_str(method_name);
class_name
}
pub fn make_primitives() -> (PrimitiveStore, Unifier) {
let mut unifier = Unifier::new();
let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(0),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(1),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let float = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(2),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let bool = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(3),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let none = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(4),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let primitives = PrimitiveStore { int32, int64, float, bool, none };
crate::typecheck::magic_methods::set_primitives_magic_methods(&primitives, &mut unifier);
(primitives, unifier)
}
/// return a composer and things to make a "primitive" symbol resolver, so that the symbol
/// resolver can later figure out primitive type definitions when passed a primitive type name
pub fn new() -> (Vec<(String, DefinitionId, Type)>, Self) {
let primitives = Self::make_primitives();
let top_level_def_list = vec![
Arc::new(RwLock::new(Self::make_top_level_class_def(0, None))),
Arc::new(RwLock::new(Self::make_top_level_class_def(1, None))),
Arc::new(RwLock::new(Self::make_top_level_class_def(2, None))),
Arc::new(RwLock::new(Self::make_top_level_class_def(3, None))),
Arc::new(RwLock::new(Self::make_top_level_class_def(4, None))),
];
let ast_list: Vec<Option<ast::Stmt<()>>> = vec![None, None, None, None, None];
let composer = TopLevelComposer {
definition_ast_list: RwLock::new(
top_level_def_list.into_iter().zip(ast_list).collect_vec(),
)
.into(),
primitives: primitives.0,
unifier: primitives.1,
class_method_to_def_id: Default::default(),
to_be_analyzed_class: Default::default(),
};
(
vec![
("int32".into(), DefinitionId(0), composer.primitives.int32),
("int64".into(), DefinitionId(1), composer.primitives.int64),
("float".into(), DefinitionId(2), composer.primitives.float),
("bool".into(), DefinitionId(3), composer.primitives.bool),
("none".into(), DefinitionId(4), composer.primitives.none),
],
composer,
)
}
/// already include the definition_id of itself inside the ancestors vector
/// when first regitering, the type_vars, fields, methods, ancestors are invalid
pub fn make_top_level_class_def(
index: usize,
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>,
) -> TopLevelDef {
TopLevelDef::Class {
object_id: DefinitionId(index),
type_vars: Default::default(),
fields: Default::default(),
methods: Default::default(),
ancestors: vec![DefinitionId(index)],
resolver,
}
}
/// when first registering, the type is a invalid value
pub fn make_top_level_function_def(
name: String,
ty: Type,
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>,
) -> TopLevelDef {
TopLevelDef::Function {
name,
signature: ty,
instance_to_symbol: Default::default(),
instance_to_stmt: Default::default(),
resolver,
}
}
/// step 0, register, just remeber the names of top level classes/function
pub fn register_top_level(
&mut self,
ast: ast::Stmt<()>,
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>,
) -> Result<(String, DefinitionId), String> {
let mut def_list = self.definition_ast_list.write();
match &ast.node {
ast::StmtKind::ClassDef { name, body, .. } => {
let class_name = name.to_string();
let class_def_id = def_list.len();
// add the class to the definition lists
// since later when registering class method, ast will still be used,
// here push None temporarly, later will move the ast inside
let mut class_def_ast = (
Arc::new(RwLock::new(Self::make_top_level_class_def(
class_def_id,
resolver.clone(),
))),
None,
);
// parse class def body and register class methods into the def list.
// module's symbol resolver would not know the name of the class methods,
// thus cannot return their definition_id
let mut class_method_name_def_ids: Vec<(
String,
Arc<RwLock<TopLevelDef>>,
DefinitionId,
)> = Vec::new();
let mut class_method_index_offset = 0;
for b in body {
if let ast::StmtKind::FunctionDef { name: method_name, .. } = &b.node {
let method_name = Self::name_mangling(class_name.clone(), method_name);
let method_def_id = def_list.len() + {
class_method_index_offset += 1;
class_method_index_offset
};
// dummy method define here
// the ast of class method is in the class, push None in to the list here
class_method_name_def_ids.push((
method_name.clone(),
RwLock::new(Self::make_top_level_function_def(
method_name.clone(),
self.primitives.none,
resolver.clone(),
))
.into(),
DefinitionId(method_def_id),
));
}
}
// move the ast to the entry of the class in the ast_list
class_def_ast.1 = Some(ast);
// now class_def_ast and class_method_def_ast_ids are ok, put them into actual def list in correct order
def_list.push(class_def_ast);
for (name, def, id) in class_method_name_def_ids {
def_list.push((def, None));
self.class_method_to_def_id.insert(name, id);
}
// put the constructor into the def_list
def_list.push((
RwLock::new(TopLevelDef::Initializer { class_id: DefinitionId(class_def_id) })
.into(),
None,
));
// class, put its def_id into the to be analyzed set
self.to_be_analyzed_class.push(DefinitionId(class_def_id));
Ok((class_name, DefinitionId(class_def_id)))
}
ast::StmtKind::FunctionDef { name, .. } => {
let fun_name = name.to_string();
// add to the definition list
def_list.push((
RwLock::new(Self::make_top_level_function_def(
name.into(),
self.primitives.none,
resolver,
))
.into(),
Some(ast),
));
// return
Ok((fun_name, DefinitionId(def_list.len() - 1)))
}
_ => Err("only registrations of top level classes/functions are supprted".into()),
}
}
/// step 1, analyze the type vars associated with top level class
fn analyze_top_level_class_type_var(&mut self) -> Result<(), String> {
let mut def_list = self.definition_ast_list.write();
let converted_top_level = &self.to_top_level_context();
let primitives = &self.primitives;
let unifier = &mut self.unifier;
for (class_def, class_ast) in def_list.iter_mut() {
// only deal with class def here
let mut class_def = class_def.write();
let (class_bases_ast, class_def_type_vars, class_resolver) = {
if let TopLevelDef::Class { type_vars, resolver, .. } = class_def.deref_mut() {
if let Some(ast::Located {
node: ast::StmtKind::ClassDef { bases, .. }, ..
}) = class_ast
{
(bases, type_vars, resolver)
} else {
unreachable!("must be both class")
}
} else {
continue;
}
};
let class_resolver = class_resolver.as_ref().unwrap().lock();
let mut is_generic = false;
for b in class_bases_ast {
match &b.node {
// analyze typevars bounded to the class,
// only support things like `class A(Generic[T, V])`,
// things like `class A(Generic[T, V, ImportedModule.T])` is not supported
// i.e. only simple names are allowed in the subscript
// should update the TopLevelDef::Class.typevars and the TypeEnum::TObj.params
ast::ExprKind::Subscript { value, slice, .. } if matches!(&value.node, ast::ExprKind::Name { id, .. } if id == "Generic") =>
{
if !is_generic {
is_generic = true;
} else {
return Err("Only single Generic[...] can be in bases".into());
}
// if `class A(Generic[T, V, G])`
if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
// parse the type vars
let type_vars = elts
.iter()
.map(|e| {
class_resolver.parse_type_annotation(
converted_top_level,
unifier.borrow_mut(),
primitives,
e,
)
})
.collect::<Result<Vec<_>, _>>()?;
// check if all are unique type vars
let mut occured_type_var_id: HashSet<u32> = HashSet::new();
let all_unique_type_var = type_vars.iter().all(|x| {
let ty = unifier.get_ty(*x);
if let TypeEnum::TVar { id, .. } = ty.as_ref() {
occured_type_var_id.insert(*id)
} else {
false
}
});
if !all_unique_type_var {
return Err("expect unique type variables".into());
}
// add to TopLevelDef
class_def_type_vars.extend(type_vars);
// `class A(Generic[T])`
} else {
let ty = class_resolver.parse_type_annotation(
converted_top_level,
unifier.borrow_mut(),
primitives,
&slice,
)?;
// check if it is type var
let is_type_var =
matches!(unifier.get_ty(ty).as_ref(), &TypeEnum::TVar { .. });
if !is_type_var {
return Err("expect type variable here".into());
}
// add to TopLevelDef
class_def_type_vars.push(ty);
}
}
// if others, do nothing in this function
_ => continue,
}
}
}
Ok(())
}
/// step 2, base classes. Need to separate step1 and step2 for this reason:
/// `class B(Generic[T, V]);
/// class A(B[int, bool])`
/// if the type var associated with class `B` has not been handled properly,
/// the parse of type annotation of `B[int, bool]` will fail
fn analyze_top_level_class_bases(&mut self) -> Result<(), String> {
let mut def_list = self.definition_ast_list.write();
let converted_top_level = &self.to_top_level_context();
let primitives = &self.primitives;
let unifier = &mut self.unifier;
for (class_def, class_ast) in def_list.iter_mut() {
let mut class_def = class_def.write();
let (class_bases, class_ancestors, class_resolver) = {
if let TopLevelDef::Class { ancestors, resolver, .. } = class_def.deref_mut() {
if let Some(ast::Located {
node: ast::StmtKind::ClassDef { bases, .. }, ..
}) = class_ast
{
(bases, ancestors, resolver)
} else {
unreachable!("must be both class")
}
} else {
continue;
}
};
let class_resolver = class_resolver.as_ref().unwrap().lock();
for b in class_bases {
// type vars have already been handled, so skip on `Generic[...]`
if let ast::ExprKind::Subscript { value, .. } = &b.node {
if let ast::ExprKind::Name { id, .. } = &value.node {
if id == "Generic" {
continue;
}
}
}
// get the def id of the base class
let base_ty = class_resolver.parse_type_annotation(
converted_top_level,
unifier.borrow_mut(),
primitives,
b,
)?;
let base_id =
if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(base_ty).as_ref() {
*obj_id
} else {
return Err("expect concrete class/type to be base class".into());
};
// write to the class ancestors, make sure the uniqueness
if !class_ancestors.contains(&base_id) {
class_ancestors.push(base_id);
} else {
return Err("cannot specify the same base class twice".into());
}
}
}
Ok(())
}
/// step 3, class fields and methods
// FIXME: analyze base classes here
// FIXME: deal with self type
// NOTE: prevent cycles only roughly done
fn analyze_top_level_class_fields_methods(&mut self) -> Result<(), String> {
let mut def_ast_list = self.definition_ast_list.write();
let converted_top_level = &self.to_top_level_context();
let primitives = &self.primitives;
let to_be_analyzed_class = &mut self.to_be_analyzed_class;
let unifier = &mut self.unifier;
// NOTE: roughly prevent infinite loop
let mut max_iter = to_be_analyzed_class.len() * 4;
'class: loop {
if to_be_analyzed_class.is_empty() && {
max_iter -= 1;
max_iter > 0
} {
break;
}
let class_ind = to_be_analyzed_class.remove(0).0;
let (class_name, class_body_ast, class_bases_ast, class_resolver, class_ancestors) = {
let (class_def, class_ast) = &mut def_ast_list[class_ind];
if let Some(ast::Located {
node: ast::StmtKind::ClassDef { name, body, bases, .. },
..
}) = class_ast.as_ref()
{
if let TopLevelDef::Class { resolver, ancestors, .. } =
class_def.write().deref()
{
(name, body, bases, resolver.as_ref().unwrap().clone(), ancestors.clone())
} else {
unreachable!()
}
} else {
unreachable!("should be class def ast")
}
};
let all_base_class_analyzed = {
let not_yet_analyzed =
to_be_analyzed_class.clone().into_iter().collect::<HashSet<_>>();
let base = class_ancestors.clone().into_iter().collect::<HashSet<_>>();
let intersection = not_yet_analyzed.intersection(&base).collect_vec();
intersection.is_empty()
};
if !all_base_class_analyzed {
to_be_analyzed_class.push(DefinitionId(class_ind));
continue 'class;
}
// get the bases type, can directly do this since it
// already pass the check in the previous stages
let class_bases_ty = class_bases_ast
.iter()
.filter_map(|x| {
class_resolver
.as_ref()
.lock()
.parse_type_annotation(
converted_top_level,
unifier.borrow_mut(),
primitives,
x,
)
.ok()
})
.collect_vec();
// need these vectors to check re-defining methods, class fields
// and store the parsed result in case some method cannot be typed for now
let mut class_methods_parsing_result: Vec<(String, Type, DefinitionId)> = vec![];
let mut class_fields_parsing_result: Vec<(String, Type)> = vec![];
for b in class_body_ast {
if let ast::StmtKind::FunctionDef {
args: method_args_ast,
body: method_body_ast,
name: method_name,
returns: method_returns_ast,
..
} = &b.node
{
let arg_name_tys: Vec<(String, Type)> = {
let mut result = vec![];
for a in &method_args_ast.args {
if a.node.arg != "self" {
let annotation = a
.node
.annotation
.as_ref()
.ok_or_else(|| {
"type annotation for function parameter is needed"
.to_string()
})?
.as_ref();
let ty = class_resolver.as_ref().lock().parse_type_annotation(
converted_top_level,
unifier.borrow_mut(),
primitives,
annotation,
)?;
if !Self::check_ty_analyzed(ty, unifier, to_be_analyzed_class) {
to_be_analyzed_class.push(DefinitionId(class_ind));
continue 'class;
}
result.push((a.node.arg.to_string(), ty));
} else {
// TODO: handle self, how
unimplemented!()
}
}
result
};
let method_type_var = arg_name_tys
.iter()
.filter_map(|(_, ty)| {
let ty_enum = unifier.get_ty(*ty);
if let TypeEnum::TVar { id, .. } = ty_enum.as_ref() {
Some((*id, *ty))
} else {
None
}
})
.collect::<Mapping<u32>>();
let ret_ty = {
if method_name != "__init__" {
let ty = method_returns_ast
.as_ref()
.map(|x| {
class_resolver.as_ref().lock().parse_type_annotation(
converted_top_level,
unifier.borrow_mut(),
primitives,
x.as_ref(),
)
})
.ok_or_else(|| "return type annotation error".to_string())??;
if !Self::check_ty_analyzed(ty, unifier, to_be_analyzed_class) {
to_be_analyzed_class.push(DefinitionId(class_ind));
continue 'class;
} else {
ty
}
} else {
// TODO: __init__ function, self type, how
unimplemented!()
}
};
// handle fields
let class_field_name_tys: Option<Vec<(String, Type)>> = if method_name
== "__init__"
{
let mut result: Vec<(String, Type)> = vec![];
for body in method_body_ast {
match &body.node {
ast::StmtKind::AnnAssign { target, annotation, .. }
if {
if let ast::ExprKind::Attribute { value, .. } = &target.node
{
matches!(
&value.node,
ast::ExprKind::Name { id, .. } if id == "self")
} else {
false
}
} =>
{
let field_ty =
class_resolver.as_ref().lock().parse_type_annotation(
converted_top_level,
unifier.borrow_mut(),
primitives,
annotation.as_ref(),
)?;
if !Self::check_ty_analyzed(
field_ty,
unifier,
to_be_analyzed_class,
) {
to_be_analyzed_class.push(DefinitionId(class_ind));
continue 'class;
} else {
result.push((
if let ast::ExprKind::Attribute { attr, .. } =
&target.node
{
attr.to_string()
} else {
unreachable!()
},
field_ty,
))
}
}
// exclude those without type annotation
ast::StmtKind::Assign { targets, .. }
if {
if let ast::ExprKind::Attribute { value, .. } =
&targets[0].node
{
matches!(
&value.node,
ast::ExprKind::Name {id, ..} if id == "self")
} else {
false
}
} =>
{
return Err("class fields type annotation needed".into())
}
// do nothing
_ => {}
}
}
Some(result)
} else {
None
};
// current method all type ok, put the current method into the list
if class_methods_parsing_result.iter().any(|(name, _, _)| name == method_name) {
return Err("duplicate method definition".into());
} else {
class_methods_parsing_result.push((
method_name.clone(),
unifier.add_ty(TypeEnum::TFunc(
FunSignature {
ret: ret_ty,
args: arg_name_tys
.into_iter()
.map(|(name, ty)| FuncArg { name, ty, default_value: None })
.collect_vec(),
vars: method_type_var,
}
.into(),
)),
*self
.class_method_to_def_id
.get(&Self::name_mangling(class_name.clone(), method_name))
.unwrap(),
))
}
// put the fiedlds inside
if let Some(class_field_name_tys) = class_field_name_tys {
assert!(class_fields_parsing_result.is_empty());
class_fields_parsing_result.extend(class_field_name_tys);
}
} else {
// what should we do with `class A: a = 3`?
// do nothing, continue the for loop to iterate class ast
continue;
}
}
// now it should be confirmed that every
// methods and fields of the class can be correctly typed, put the results
// into the actual class def method and fields field
let (class_def, _) = &def_ast_list[class_ind];
let mut class_def = class_def.write();
if let TopLevelDef::Class { fields, methods, .. } = class_def.deref_mut() {
for (ref n, ref t) in class_fields_parsing_result {
fields.push((n.clone(), *t));
}
for (n, t, id) in &class_methods_parsing_result {
methods.push((n.clone(), *t, *id));
}
} else {
unreachable!()
}
// change the signature field of the class methods
for (_, ty, id) in &class_methods_parsing_result {
let (method_def, _) = &def_ast_list[id.0];
let mut method_def = method_def.write();
if let TopLevelDef::Function { signature, .. } = method_def.deref_mut() {
*signature = *ty;
}
}
}
Ok(())
}
fn analyze_top_level_function(&mut self) -> Result<(), String> {
unimplemented!()
}
fn analyze_top_level_field_instantiation(&mut self) -> Result<(), String> {
unimplemented!()
}
fn check_ty_analyzed(ty: Type, unifier: &mut Unifier, to_be_analyzed: &[DefinitionId]) -> bool {
let type_enum = unifier.get_ty(ty);
match type_enum.as_ref() {
TypeEnum::TObj { obj_id, .. } => !to_be_analyzed.contains(obj_id),
TypeEnum::TVirtual { ty } => {
if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(*ty).as_ref() {
!to_be_analyzed.contains(obj_id)
} else {
unreachable!()
}
}
TypeEnum::TVar { .. } => true,
_ => unreachable!(),
}
}
}

View File

@ -0,0 +1,216 @@
use super::type_inferencer::Inferencer;
use super::typedef::Type;
use rustpython_parser::ast::{self, Expr, ExprKind, Stmt, StmtKind};
use std::iter::once;
impl<'a> Inferencer<'a> {
fn check_pattern(
&mut self,
pattern: &Expr<Option<Type>>,
defined_identifiers: &mut Vec<String>,
) -> Result<(), String> {
match &pattern.node {
ExprKind::Name { id, .. } => {
if !defined_identifiers.contains(id) {
defined_identifiers.push(id.clone());
}
Ok(())
}
ExprKind::Tuple { elts, .. } => {
for elt in elts.iter() {
self.check_pattern(elt, defined_identifiers)?;
}
Ok(())
}
_ => self.check_expr(pattern, defined_identifiers),
}
}
fn check_expr(
&mut self,
expr: &Expr<Option<Type>>,
defined_identifiers: &[String],
) -> Result<(), String> {
// there are some cases where the custom field is None
if let Some(ty) = &expr.custom {
if !self.unifier.is_concrete(*ty, &self.function_data.bound_variables) {
return Err(format!(
"expected concrete type at {} but got {}",
expr.location,
self.unifier.get_ty(*ty).get_type_name()
));
}
}
match &expr.node {
ExprKind::Name { id, .. } => {
if !defined_identifiers.contains(id) {
return Err(format!(
"unknown identifier {} (use before def?) at {}",
id, expr.location
));
}
}
ExprKind::List { elts, .. }
| ExprKind::Tuple { elts, .. }
| ExprKind::BoolOp { values: elts, .. } => {
for elt in elts.iter() {
self.check_expr(elt, defined_identifiers)?;
}
}
ExprKind::Attribute { value, .. } => {
self.check_expr(value, defined_identifiers)?;
}
ExprKind::BinOp { left, right, .. } => {
self.check_expr(left, defined_identifiers)?;
self.check_expr(right, defined_identifiers)?;
}
ExprKind::UnaryOp { operand, .. } => {
self.check_expr(operand, defined_identifiers)?;
}
ExprKind::Compare { left, comparators, .. } => {
for elt in once(left.as_ref()).chain(comparators.iter()) {
self.check_expr(elt, defined_identifiers)?;
}
}
ExprKind::Subscript { value, slice, .. } => {
self.check_expr(value, defined_identifiers)?;
self.check_expr(slice, defined_identifiers)?;
}
ExprKind::IfExp { test, body, orelse } => {
self.check_expr(test, defined_identifiers)?;
self.check_expr(body, defined_identifiers)?;
self.check_expr(orelse, defined_identifiers)?;
}
ExprKind::Slice { lower, upper, step } => {
for elt in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() {
self.check_expr(elt, defined_identifiers)?;
}
}
ExprKind::Lambda { args, body } => {
let mut defined_identifiers = defined_identifiers.to_vec();
for arg in args.args.iter() {
if !defined_identifiers.contains(&arg.node.arg) {
defined_identifiers.push(arg.node.arg.clone());
}
}
self.check_expr(body, &defined_identifiers)?;
}
ExprKind::ListComp { elt, generators, .. } => {
// in our type inference stage, we already make sure that there is only 1 generator
let ast::Comprehension { target, iter, ifs, .. } = &generators[0];
self.check_expr(iter, defined_identifiers)?;
let mut defined_identifiers = defined_identifiers.to_vec();
self.check_pattern(target, &mut defined_identifiers)?;
for term in once(elt.as_ref()).chain(ifs.iter()) {
self.check_expr(term, &defined_identifiers)?;
}
}
ExprKind::Call { func, args, keywords } => {
for expr in once(func.as_ref())
.chain(args.iter())
.chain(keywords.iter().map(|v| v.node.value.as_ref()))
{
self.check_expr(expr, defined_identifiers)?;
}
}
ExprKind::Constant { .. } => {}
_ => {
println!("{:?}", expr.node);
unimplemented!()
}
}
Ok(())
}
// check statements for proper identifier def-use and return on all paths
fn check_stmt(
&mut self,
stmt: &Stmt<Option<Type>>,
defined_identifiers: &mut Vec<String>,
) -> Result<bool, String> {
match &stmt.node {
StmtKind::For { target, iter, body, orelse, .. } => {
self.check_expr(iter, defined_identifiers)?;
for stmt in orelse.iter() {
self.check_stmt(stmt, defined_identifiers)?;
}
let mut defined_identifiers = defined_identifiers.clone();
self.check_pattern(target, &mut defined_identifiers)?;
for stmt in body.iter() {
self.check_stmt(stmt, &mut defined_identifiers)?;
}
Ok(false)
}
StmtKind::If { test, body, orelse } => {
self.check_expr(test, defined_identifiers)?;
let mut body_identifiers = defined_identifiers.clone();
let mut orelse_identifiers = defined_identifiers.clone();
let body_returned = self.check_block(body, &mut body_identifiers)?;
let orelse_returned = self.check_block(orelse, &mut orelse_identifiers)?;
for ident in body_identifiers.iter() {
if !defined_identifiers.contains(ident) && orelse_identifiers.contains(ident) {
defined_identifiers.push(ident.clone())
}
}
Ok(body_returned && orelse_returned)
}
StmtKind::While { test, body, orelse } => {
self.check_expr(test, defined_identifiers)?;
let mut defined_identifiers = defined_identifiers.clone();
self.check_block(body, &mut defined_identifiers)?;
self.check_block(orelse, &mut defined_identifiers)?;
Ok(false)
}
StmtKind::Expr { value } => {
self.check_expr(value, defined_identifiers)?;
Ok(false)
}
StmtKind::Assign { targets, value, .. } => {
self.check_expr(value, defined_identifiers)?;
for target in targets {
self.check_pattern(target, defined_identifiers)?;
}
Ok(false)
}
StmtKind::AnnAssign { target, value, .. } => {
if let Some(value) = value {
self.check_expr(value, defined_identifiers)?;
self.check_pattern(target, defined_identifiers)?;
}
Ok(false)
}
StmtKind::Return { value } => {
if let Some(value) = value {
self.check_expr(value, defined_identifiers)?;
}
Ok(true)
}
StmtKind::Raise { exc, .. } => {
if let Some(value) = exc {
self.check_expr(value, defined_identifiers)?;
}
Ok(true)
}
// break, raise, etc.
_ => Ok(false),
}
}
pub fn check_block(
&mut self,
block: &[Stmt<Option<Type>>],
defined_identifiers: &mut Vec<String>,
) -> Result<bool, String> {
let mut ret = false;
for stmt in block {
if ret {
return Err(format!("dead code at {:?}", stmt.location));
}
if self.check_stmt(stmt, defined_identifiers)? {
ret = true;
}
}
Ok(ret)
}
}

View File

@ -0,0 +1,322 @@
use crate::typecheck::{
type_inferencer::*,
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier},
};
use rustpython_parser::ast;
use rustpython_parser::ast::{Cmpop, Operator, Unaryop};
use std::borrow::Borrow;
use std::collections::HashMap;
pub fn binop_name(op: &Operator) -> &'static str {
match op {
Operator::Add => "__add__",
Operator::Sub => "__sub__",
Operator::Div => "__truediv__",
Operator::Mod => "__mod__",
Operator::Mult => "__mul__",
Operator::Pow => "__pow__",
Operator::BitOr => "__or__",
Operator::BitXor => "__xor__",
Operator::BitAnd => "__and__",
Operator::LShift => "__lshift__",
Operator::RShift => "__rshift__",
Operator::FloorDiv => "__floordiv__",
Operator::MatMult => "__matmul__",
}
}
pub fn binop_assign_name(op: &Operator) -> &'static str {
match op {
Operator::Add => "__iadd__",
Operator::Sub => "__isub__",
Operator::Div => "__itruediv__",
Operator::Mod => "__imod__",
Operator::Mult => "__imul__",
Operator::Pow => "__ipow__",
Operator::BitOr => "__ior__",
Operator::BitXor => "__ixor__",
Operator::BitAnd => "__iand__",
Operator::LShift => "__ilshift__",
Operator::RShift => "__irshift__",
Operator::FloorDiv => "__ifloordiv__",
Operator::MatMult => "__imatmul__",
}
}
pub fn unaryop_name(op: &Unaryop) -> &'static str {
match op {
Unaryop::UAdd => "__pos__",
Unaryop::USub => "__neg__",
Unaryop::Not => "__not__",
Unaryop::Invert => "__inv__",
}
}
pub fn comparison_name(op: &Cmpop) -> Option<&'static str> {
match op {
Cmpop::Lt => Some("__lt__"),
Cmpop::LtE => Some("__le__"),
Cmpop::Gt => Some("__gt__"),
Cmpop::GtE => Some("__ge__"),
Cmpop::Eq => Some("__eq__"),
Cmpop::NotEq => Some("__ne__"),
_ => None,
}
}
pub fn impl_binop(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Type,
ops: &[ast::Operator],
) {
if let TypeEnum::TObj { fields, .. } = unifier.get_ty(ty).borrow() {
let (other_ty, other_var_id) = if other_ty.len() == 1 {
(other_ty[0], None)
} else {
let (ty, var_id) = unifier.get_fresh_var_with_range(other_ty);
(ty, Some(var_id))
};
let function_vars = if let Some(var_id) = other_var_id {
vec![(var_id, other_ty)].into_iter().collect::<HashMap<_, _>>()
} else {
HashMap::new()
};
for op in ops {
fields.borrow_mut().insert(binop_name(op).into(), {
unifier.add_ty(TypeEnum::TFunc(
FunSignature {
ret: ret_ty,
vars: function_vars.clone(),
args: vec![FuncArg {
ty: other_ty,
default_value: None,
name: "other".into(),
}],
}
.into(),
))
});
fields.borrow_mut().insert(binop_assign_name(op).into(), {
unifier.add_ty(TypeEnum::TFunc(
FunSignature {
ret: store.none,
vars: function_vars.clone(),
args: vec![FuncArg {
ty: other_ty,
default_value: None,
name: "other".into(),
}],
}
.into(),
))
});
}
} else {
unreachable!("")
}
}
pub fn impl_unaryop(
unifier: &mut Unifier,
_store: &PrimitiveStore,
ty: Type,
ret_ty: Type,
ops: &[ast::Unaryop],
) {
if let TypeEnum::TObj { fields, .. } = unifier.get_ty(ty).borrow() {
for op in ops {
fields.borrow_mut().insert(
unaryop_name(op).into(),
unifier.add_ty(TypeEnum::TFunc(
FunSignature { ret: ret_ty, vars: HashMap::new(), args: vec![] }.into(),
)),
);
}
} else {
unreachable!()
}
}
pub fn impl_cmpop(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: Type,
ops: &[ast::Cmpop],
) {
if let TypeEnum::TObj { fields, .. } = unifier.get_ty(ty).borrow() {
for op in ops {
fields.borrow_mut().insert(
comparison_name(op).unwrap().into(),
unifier.add_ty(TypeEnum::TFunc(
FunSignature {
ret: store.bool,
vars: HashMap::new(),
args: vec![FuncArg {
ty: other_ty,
default_value: None,
name: "other".into(),
}],
}
.into(),
)),
);
}
} else {
unreachable!()
}
}
/// Add, Sub, Mult, Pow
pub fn impl_basic_arithmetic(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Type,
) {
impl_binop(
unifier,
store,
ty,
other_ty,
ret_ty,
&[ast::Operator::Add, ast::Operator::Sub, ast::Operator::Mult],
)
}
pub fn impl_pow(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Type,
) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[ast::Operator::Pow])
}
/// BitOr, BitXor, BitAnd
pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_binop(
unifier,
store,
ty,
&[ty],
ty,
&[ast::Operator::BitAnd, ast::Operator::BitOr, ast::Operator::BitXor],
)
}
/// LShift, RShift
pub fn impl_bitwise_shift(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_binop(unifier, store, ty, &[ty], ty, &[ast::Operator::LShift, ast::Operator::RShift])
}
/// Div
pub fn impl_div(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type]) {
impl_binop(unifier, store, ty, other_ty, store.float, &[ast::Operator::Div])
}
/// FloorDiv
pub fn impl_floordiv(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Type,
) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[ast::Operator::FloorDiv])
}
/// Mod
pub fn impl_mod(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Type,
) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[ast::Operator::Mod])
}
/// UAdd, USub
pub fn impl_sign(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_unaryop(unifier, store, ty, ty, &[ast::Unaryop::UAdd, ast::Unaryop::USub])
}
/// Invert
pub fn impl_invert(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_unaryop(unifier, store, ty, ty, &[ast::Unaryop::Invert])
}
/// Not
pub fn impl_not(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_unaryop(unifier, store, ty, store.bool, &[ast::Unaryop::Not])
}
/// Lt, LtE, Gt, GtE
pub fn impl_comparison(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type) {
impl_cmpop(
unifier,
store,
ty,
other_ty,
&[ast::Cmpop::Lt, ast::Cmpop::Gt, ast::Cmpop::LtE, ast::Cmpop::GtE],
)
}
/// Eq, NotEq
pub fn impl_eq(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_cmpop(unifier, store, ty, ty, &[ast::Cmpop::Eq, ast::Cmpop::NotEq])
}
pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) {
let PrimitiveStore { int32: int32_t, int64: int64_t, float: float_t, bool: bool_t, .. } =
*store;
/* int32 ======== */
impl_basic_arithmetic(unifier, store, int32_t, &[int32_t], int32_t);
impl_pow(unifier, store, int32_t, &[int32_t], int32_t);
impl_bitwise_arithmetic(unifier, store, int32_t);
impl_bitwise_shift(unifier, store, int32_t);
impl_div(unifier, store, int32_t, &[int32_t]);
impl_floordiv(unifier, store, int32_t, &[int32_t], int32_t);
impl_mod(unifier, store, int32_t, &[int32_t], int32_t);
impl_sign(unifier, store, int32_t);
impl_invert(unifier, store, int32_t);
impl_not(unifier, store, int32_t);
impl_comparison(unifier, store, int32_t, int32_t);
impl_eq(unifier, store, int32_t);
/* int64 ======== */
impl_basic_arithmetic(unifier, store, int64_t, &[int64_t], int64_t);
impl_pow(unifier, store, int64_t, &[int64_t], int64_t);
impl_bitwise_arithmetic(unifier, store, int64_t);
impl_bitwise_shift(unifier, store, int64_t);
impl_div(unifier, store, int64_t, &[int64_t]);
impl_floordiv(unifier, store, int64_t, &[int64_t], int64_t);
impl_mod(unifier, store, int64_t, &[int64_t], int64_t);
impl_sign(unifier, store, int64_t);
impl_invert(unifier, store, int64_t);
impl_not(unifier, store, int64_t);
impl_comparison(unifier, store, int64_t, int64_t);
impl_eq(unifier, store, int64_t);
/* float ======== */
impl_basic_arithmetic(unifier, store, float_t, &[float_t], float_t);
impl_pow(unifier, store, float_t, &[int32_t, float_t], float_t);
impl_div(unifier, store, float_t, &[float_t]);
impl_floordiv(unifier, store, float_t, &[float_t], float_t);
impl_mod(unifier, store, float_t, &[float_t], float_t);
impl_sign(unifier, store, float_t);
impl_not(unifier, store, float_t);
impl_comparison(unifier, store, float_t, float_t);
impl_eq(unifier, store, float_t);
/* bool ======== */
impl_not(unifier, store, bool_t);
impl_eq(unifier, store, bool_t);
}

View File

@ -0,0 +1,5 @@
mod function_check;
pub mod magic_methods;
pub mod type_inferencer;
pub mod typedef;
mod unification_table;

View File

@ -0,0 +1,582 @@
use std::collections::HashMap;
use std::convert::{From, TryInto};
use std::iter::once;
use std::{cell::RefCell, sync::Arc};
use super::typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier};
use super::{magic_methods::*, typedef::CallId};
use crate::{symbol_resolver::SymbolResolver, top_level::TopLevelContext};
use itertools::izip;
use rustpython_parser::ast::{
self,
fold::{self, Fold},
Arguments, Comprehension, ExprKind, Located, Location,
};
#[cfg(test)]
mod test;
#[derive(PartialEq, Eq, Hash, Copy, Clone, Debug)]
pub struct CodeLocation {
row: usize,
col: usize,
}
impl From<Location> for CodeLocation {
fn from(loc: Location) -> CodeLocation {
CodeLocation { row: loc.row(), col: loc.column() }
}
}
#[derive(Clone, Copy)]
pub struct PrimitiveStore {
pub int32: Type,
pub int64: Type,
pub float: Type,
pub bool: Type,
pub none: Type,
}
pub struct FunctionData {
pub resolver: Arc<dyn SymbolResolver + Send + Sync>,
pub return_type: Option<Type>,
pub bound_variables: Vec<Type>,
}
pub struct Inferencer<'a> {
pub top_level: &'a TopLevelContext,
pub function_data: &'a mut FunctionData,
pub unifier: &'a mut Unifier,
pub primitives: &'a PrimitiveStore,
pub virtual_checks: &'a mut Vec<(Type, Type)>,
pub variable_mapping: HashMap<String, Type>,
pub calls: &'a mut HashMap<CodeLocation, CallId>,
}
struct NaiveFolder();
impl fold::Fold<()> for NaiveFolder {
type TargetU = Option<Type>;
type Error = String;
fn map_user(&mut self, _: ()) -> Result<Self::TargetU, Self::Error> {
Ok(None)
}
}
impl<'a> fold::Fold<()> for Inferencer<'a> {
type TargetU = Option<Type>;
type Error = String;
fn map_user(&mut self, _: ()) -> Result<Self::TargetU, Self::Error> {
Ok(None)
}
fn fold_stmt(&mut self, node: ast::Stmt<()>) -> Result<ast::Stmt<Self::TargetU>, Self::Error> {
let stmt = match node.node {
// we don't want fold over type annotation
ast::StmtKind::AnnAssign { target, annotation, value, simple } => {
let target = Box::new(self.fold_expr(*target)?);
let value = if let Some(v) = value {
let ty = Box::new(self.fold_expr(*v)?);
self.unifier.unify(target.custom.unwrap(), ty.custom.unwrap())?;
Some(ty)
} else {
None
};
let annotation_type = self.function_data.resolver.parse_type_annotation(
self.top_level,
self.unifier,
&self.primitives,
annotation.as_ref(),
)?;
self.unifier.unify(annotation_type, target.custom.unwrap())?;
let annotation = Box::new(NaiveFolder().fold_expr(*annotation)?);
Located {
location: node.location,
custom: None,
node: ast::StmtKind::AnnAssign { target, annotation, value, simple },
}
}
_ => fold::fold_stmt(self, node)?,
};
match &stmt.node {
ast::StmtKind::For { target, iter, .. } => {
let list = self.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() });
self.unifier.unify(list, iter.custom.unwrap())?;
}
ast::StmtKind::If { test, .. } | ast::StmtKind::While { test, .. } => {
self.unifier.unify(test.custom.unwrap(), self.primitives.bool)?;
}
ast::StmtKind::Assign { targets, value, .. } => {
for target in targets.iter() {
self.unifier.unify(target.custom.unwrap(), value.custom.unwrap())?;
}
}
ast::StmtKind::AnnAssign { .. } | ast::StmtKind::Expr { .. } => {}
ast::StmtKind::Break | ast::StmtKind::Continue => {}
ast::StmtKind::Return { value } => match (value, self.function_data.return_type) {
(Some(v), Some(v1)) => {
self.unifier.unify(v.custom.unwrap(), v1)?;
}
(Some(_), None) => {
return Err("Unexpected return value".to_string());
}
(None, Some(_)) => {
return Err("Expected return value".to_string());
}
(None, None) => {}
},
_ => return Err("Unsupported statement type".to_string()),
};
Ok(stmt)
}
fn fold_expr(&mut self, node: ast::Expr<()>) -> Result<ast::Expr<Self::TargetU>, Self::Error> {
let expr = match node.node {
ast::ExprKind::Call { func, args, keywords } => {
return self.fold_call(node.location, *func, args, keywords);
}
ast::ExprKind::Lambda { args, body } => {
return self.fold_lambda(node.location, *args, *body);
}
ast::ExprKind::ListComp { elt, generators } => {
return self.fold_listcomp(node.location, *elt, generators);
}
_ => fold::fold_expr(self, node)?,
};
let custom = match &expr.node {
ast::ExprKind::Constant { value, .. } => Some(self.infer_constant(value)?),
ast::ExprKind::Name { id, .. } => Some(self.infer_identifier(id)?),
ast::ExprKind::List { elts, .. } => Some(self.infer_list(elts)?),
ast::ExprKind::Tuple { elts, .. } => Some(self.infer_tuple(elts)?),
ast::ExprKind::Attribute { value, attr, ctx: _ } => {
Some(self.infer_attribute(value, attr)?)
}
ast::ExprKind::BoolOp { values, .. } => Some(self.infer_bool_ops(values)?),
ast::ExprKind::BinOp { left, op, right } => Some(self.infer_bin_ops(left, op, right)?),
ast::ExprKind::UnaryOp { op, operand } => Some(self.infer_unary_ops(op, operand)?),
ast::ExprKind::Compare { left, ops, comparators } => {
Some(self.infer_compare(left, ops, comparators)?)
}
ast::ExprKind::Subscript { value, slice, .. } => {
Some(self.infer_subscript(value.as_ref(), slice.as_ref())?)
}
ast::ExprKind::IfExp { test, body, orelse } => {
Some(self.infer_if_expr(test, body.as_ref(), orelse.as_ref())?)
}
ast::ExprKind::ListComp { .. }
| ast::ExprKind::Lambda { .. }
| ast::ExprKind::Call { .. } => expr.custom, // already computed
ast::ExprKind::Slice { .. } => None, // we don't need it for slice
_ => return Err("not supported yet".into()),
};
Ok(ast::Expr { custom, location: expr.location, node: expr.node })
}
}
type InferenceResult = Result<Type, String>;
impl<'a> Inferencer<'a> {
/// Constrain a <: b
/// Currently implemented as unification
fn constrain(&mut self, a: Type, b: Type) -> Result<(), String> {
self.unifier.unify(a, b)
}
fn build_method_call(
&mut self,
location: Location,
method: String,
obj: Type,
params: Vec<Type>,
ret: Type,
) -> InferenceResult {
let call = self.unifier.add_call(Call {
posargs: params,
kwargs: HashMap::new(),
ret,
fun: RefCell::new(None),
});
self.calls.insert(location.into(), call);
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
let fields = once((method, call)).collect();
let record = self.unifier.add_record(fields);
self.constrain(obj, record)?;
Ok(ret)
}
fn fold_lambda(
&mut self,
location: Location,
args: Arguments,
body: ast::Expr<()>,
) -> Result<ast::Expr<Option<Type>>, String> {
if !args.posonlyargs.is_empty()
|| args.vararg.is_some()
|| !args.kwonlyargs.is_empty()
|| args.kwarg.is_some()
|| !args.defaults.is_empty()
{
// actually I'm not sure whether programs violating this is a valid python program.
return Err(
"We only support positional or keyword arguments without defaults for lambdas."
.to_string(),
);
}
let fn_args: Vec<_> = args
.args
.iter()
.map(|v| (v.node.arg.clone(), self.unifier.get_fresh_var().0))
.collect();
let mut variable_mapping = self.variable_mapping.clone();
variable_mapping.extend(fn_args.iter().cloned());
let ret = self.unifier.get_fresh_var().0;
let mut new_context = Inferencer {
function_data: self.function_data,
unifier: self.unifier,
primitives: self.primitives,
virtual_checks: self.virtual_checks,
calls: self.calls,
top_level: self.top_level,
variable_mapping,
};
let fun = FunSignature {
args: fn_args
.iter()
.map(|(k, ty)| FuncArg { name: k.clone(), ty: *ty, default_value: None })
.collect(),
ret,
vars: Default::default(),
};
let body = new_context.fold_expr(body)?;
new_context.unifier.unify(fun.ret, body.custom.unwrap())?;
let mut args = new_context.fold_arguments(args)?;
for (arg, (name, ty)) in args.args.iter_mut().zip(fn_args.iter()) {
assert_eq!(&arg.node.arg, name);
arg.custom = Some(*ty);
}
Ok(Located {
location,
node: ExprKind::Lambda { args: args.into(), body: body.into() },
custom: Some(self.unifier.add_ty(TypeEnum::TFunc(fun.into()))),
})
}
fn fold_listcomp(
&mut self,
location: Location,
elt: ast::Expr<()>,
mut generators: Vec<Comprehension>,
) -> Result<ast::Expr<Option<Type>>, String> {
if generators.len() != 1 {
return Err(
"Only 1 generator statement for list comprehension is supported.".to_string()
);
}
let variable_mapping = self.variable_mapping.clone();
let mut new_context = Inferencer {
function_data: self.function_data,
unifier: self.unifier,
virtual_checks: self.virtual_checks,
top_level: self.top_level,
variable_mapping,
primitives: self.primitives,
calls: self.calls,
};
let elt = new_context.fold_expr(elt)?;
let generator = generators.pop().unwrap();
if generator.is_async {
return Err("Async iterator not supported.".to_string());
}
let target = new_context.fold_expr(*generator.target)?;
let iter = new_context.fold_expr(*generator.iter)?;
let ifs: Vec<_> = generator
.ifs
.into_iter()
.map(|v| new_context.fold_expr(v))
.collect::<Result<_, _>>()?;
// iter should be a list of targets...
// actually it should be an iterator of targets, but we don't have iter type for now
let list = new_context.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() });
new_context.unifier.unify(iter.custom.unwrap(), list)?;
// if conditions should be bool
for v in ifs.iter() {
new_context.unifier.unify(v.custom.unwrap(), new_context.primitives.bool)?;
}
Ok(Located {
location,
custom: Some(new_context.unifier.add_ty(TypeEnum::TList { ty: elt.custom.unwrap() })),
node: ExprKind::ListComp {
elt: Box::new(elt),
generators: vec![ast::Comprehension {
target: Box::new(target),
iter: Box::new(iter),
ifs,
is_async: false,
}],
},
})
}
fn fold_call(
&mut self,
location: Location,
func: ast::Expr<()>,
mut args: Vec<ast::Expr<()>>,
keywords: Vec<Located<ast::KeywordData>>,
) -> Result<ast::Expr<Option<Type>>, String> {
let func =
if let Located { location: func_location, custom, node: ExprKind::Name { id, ctx } } =
func
{
// handle special functions that cannot be typed in the usual way...
if id == "virtual" {
if args.is_empty() || args.len() > 2 || !keywords.is_empty() {
return Err(
"`virtual` can only accept 1/2 positional arguments.".to_string()
);
}
let arg0 = self.fold_expr(args.remove(0))?;
let ty = if let Some(arg) = args.pop() {
self.function_data.resolver.parse_type_annotation(
self.top_level,
self.unifier,
self.primitives,
&arg,
)?
} else {
self.unifier.get_fresh_var().0
};
self.virtual_checks.push((arg0.custom.unwrap(), ty));
let custom = Some(self.unifier.add_ty(TypeEnum::TVirtual { ty }));
return Ok(Located {
location,
custom,
node: ExprKind::Call {
func: Box::new(Located {
custom: None,
location: func.location,
node: ExprKind::Name { id, ctx },
}),
args: vec![arg0],
keywords: vec![],
},
});
}
// int64 is special because its argument can be a constant larger than int32
if id == "int64" && args.len() == 1 {
if let ExprKind::Constant { value: ast::Constant::Int(val), kind } =
&args[0].node
{
let int64: Result<i64, _> = val.try_into();
let custom;
if int64.is_ok() {
custom = Some(self.primitives.int64);
} else {
return Err("Integer out of bound".into());
}
return Ok(Located {
location: args[0].location,
custom,
node: ExprKind::Constant {
value: ast::Constant::Int(val.clone()),
kind: kind.clone(),
},
});
}
}
Located { location: func_location, custom, node: ExprKind::Name { id, ctx } }
} else {
func
};
let func = Box::new(self.fold_expr(func)?);
let args = args.into_iter().map(|v| self.fold_expr(v)).collect::<Result<Vec<_>, _>>()?;
let keywords = keywords
.into_iter()
.map(|v| fold::fold_keyword(self, v))
.collect::<Result<Vec<_>, _>>()?;
let ret = self.unifier.get_fresh_var().0;
let call = self.unifier.add_call(Call {
posargs: args.iter().map(|v| v.custom.unwrap()).collect(),
kwargs: keywords
.iter()
.map(|v| (v.node.arg.as_ref().unwrap().clone(), v.custom.unwrap()))
.collect(),
fun: RefCell::new(None),
ret,
});
self.calls.insert(location.into(), call);
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
self.unifier.unify(func.custom.unwrap(), call)?;
Ok(Located { location, custom: Some(ret), node: ExprKind::Call { func, args, keywords } })
}
fn infer_identifier(&mut self, id: &str) -> InferenceResult {
if let Some(ty) = self.variable_mapping.get(id) {
Ok(*ty)
} else {
Ok(self
.function_data
.resolver
.get_symbol_type(self.unifier, self.primitives, id)
.unwrap_or_else(|| {
let ty = self.unifier.get_fresh_var().0;
self.variable_mapping.insert(id.to_string(), ty);
ty
}))
}
}
fn infer_constant(&mut self, constant: &ast::Constant) -> InferenceResult {
match constant {
ast::Constant::Bool(_) => Ok(self.primitives.bool),
ast::Constant::Int(val) => {
let int32: Result<i32, _> = val.try_into();
// int64 would be handled separately in functions
if int32.is_ok() {
Ok(self.primitives.int32)
} else {
Err("Integer out of bound".into())
}
}
ast::Constant::Float(_) => Ok(self.primitives.float),
ast::Constant::Tuple(vals) => {
let ty: Result<Vec<_>, _> = vals.iter().map(|x| self.infer_constant(x)).collect();
Ok(self.unifier.add_ty(TypeEnum::TTuple { ty: ty? }))
}
_ => Err("not supported".into()),
}
}
fn infer_list(&mut self, elts: &[ast::Expr<Option<Type>>]) -> InferenceResult {
let (ty, _) = self.unifier.get_fresh_var();
for t in elts.iter() {
self.unifier.unify(ty, t.custom.unwrap())?;
}
Ok(self.unifier.add_ty(TypeEnum::TList { ty }))
}
fn infer_tuple(&mut self, elts: &[ast::Expr<Option<Type>>]) -> InferenceResult {
let ty = elts.iter().map(|x| x.custom.unwrap()).collect();
Ok(self.unifier.add_ty(TypeEnum::TTuple { ty }))
}
fn infer_attribute(&mut self, value: &ast::Expr<Option<Type>>, attr: &str) -> InferenceResult {
let (attr_ty, _) = self.unifier.get_fresh_var();
let fields = once((attr.to_string(), attr_ty)).collect();
let record = self.unifier.add_record(fields);
self.constrain(value.custom.unwrap(), record)?;
Ok(attr_ty)
}
fn infer_bool_ops(&mut self, values: &[ast::Expr<Option<Type>>]) -> InferenceResult {
let b = self.primitives.bool;
for v in values {
self.constrain(v.custom.unwrap(), b)?;
}
Ok(b)
}
fn infer_bin_ops(
&mut self,
left: &ast::Expr<Option<Type>>,
op: &ast::Operator,
right: &ast::Expr<Option<Type>>,
) -> InferenceResult {
let method = binop_name(op);
let ret = self.unifier.get_fresh_var().0;
self.build_method_call(
left.location,
method.to_string(),
left.custom.unwrap(),
vec![right.custom.unwrap()],
ret,
)
}
fn infer_unary_ops(
&mut self,
op: &ast::Unaryop,
operand: &ast::Expr<Option<Type>>,
) -> InferenceResult {
let method = unaryop_name(op);
let ret = self.unifier.get_fresh_var().0;
self.build_method_call(
operand.location,
method.to_string(),
operand.custom.unwrap(),
vec![],
ret,
)
}
fn infer_compare(
&mut self,
left: &ast::Expr<Option<Type>>,
ops: &[ast::Cmpop],
comparators: &[ast::Expr<Option<Type>>],
) -> InferenceResult {
let boolean = self.primitives.bool;
for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) {
let method =
comparison_name(c).ok_or_else(|| "unsupported comparator".to_string())?.to_string();
self.build_method_call(
a.location,
method,
a.custom.unwrap(),
vec![b.custom.unwrap()],
boolean,
)?;
}
Ok(boolean)
}
fn infer_subscript(
&mut self,
value: &ast::Expr<Option<Type>>,
slice: &ast::Expr<Option<Type>>,
) -> InferenceResult {
let ty = self.unifier.get_fresh_var().0;
match &slice.node {
ast::ExprKind::Slice { lower, upper, step } => {
for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() {
self.constrain(v.custom.unwrap(), self.primitives.int32)?;
}
let list = self.unifier.add_ty(TypeEnum::TList { ty });
self.constrain(value.custom.unwrap(), list)?;
Ok(list)
}
ast::ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
// the index is a constant, so value can be a sequence.
let ind: i32 = val.try_into().map_err(|_| "Index must be int32".to_string())?;
let map = once((ind, ty)).collect();
let seq = self.unifier.add_sequence(map);
self.constrain(value.custom.unwrap(), seq)?;
Ok(ty)
}
_ => {
// the index is not a constant, so value can only be a list
self.constrain(slice.custom.unwrap(), self.primitives.int32)?;
let list = self.unifier.add_ty(TypeEnum::TList { ty });
self.constrain(value.custom.unwrap(), list)?;
Ok(ty)
}
}
}
fn infer_if_expr(
&mut self,
test: &ast::Expr<Option<Type>>,
body: &ast::Expr<Option<Type>>,
orelse: &ast::Expr<Option<Type>>,
) -> InferenceResult {
self.constrain(test.custom.unwrap(), self.primitives.bool)?;
let ty = self.unifier.get_fresh_var().0;
self.constrain(body.custom.unwrap(), ty)?;
self.constrain(orelse.custom.unwrap(), ty)?;
Ok(ty)
}
}

View File

@ -0,0 +1,546 @@
use super::super::typedef::*;
use super::*;
use crate::symbol_resolver::*;
use crate::top_level::DefinitionId;
use crate::{location::Location, top_level::TopLevelDef};
use indoc::indoc;
use itertools::zip;
use parking_lot::RwLock;
use rustpython_parser::parser::parse_program;
use test_case::test_case;
struct Resolver {
id_to_type: HashMap<String, Type>,
id_to_def: HashMap<String, DefinitionId>,
class_names: HashMap<String, Type>,
}
impl SymbolResolver for Resolver {
fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option<Type> {
self.id_to_type.get(str).cloned()
}
fn get_symbol_value(&self, _: &str) -> Option<SymbolValue> {
unimplemented!()
}
fn get_symbol_location(&self, _: &str) -> Option<Location> {
unimplemented!()
}
fn get_identifier_def(&self, id: &str) -> Option<DefinitionId> {
self.id_to_def.get(id).cloned()
}
}
struct TestEnvironment {
pub unifier: Unifier,
pub function_data: FunctionData,
pub primitives: PrimitiveStore,
pub id_to_name: HashMap<usize, String>,
pub identifier_mapping: HashMap<String, Type>,
pub virtual_checks: Vec<(Type, Type)>,
pub calls: HashMap<CodeLocation, CallId>,
pub top_level: TopLevelContext,
}
impl TestEnvironment {
pub fn basic_test_env() -> TestEnvironment {
let mut unifier = Unifier::new();
let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(0),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(1),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let float = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(2),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let bool = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(3),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let none = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(4),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let primitives = PrimitiveStore { int32, int64, float, bool, none };
set_primitives_magic_methods(&primitives, &mut unifier);
let id_to_name = [
(0, "int32".to_string()),
(1, "int64".to_string()),
(2, "float".to_string()),
(3, "bool".to_string()),
(4, "none".to_string()),
]
.iter()
.cloned()
.collect();
let mut identifier_mapping = HashMap::new();
identifier_mapping.insert("None".into(), none);
let resolver = Arc::new(Resolver {
id_to_type: identifier_mapping.clone(),
id_to_def: Default::default(),
class_names: Default::default(),
}) as Arc<dyn SymbolResolver + Send + Sync>;
TestEnvironment {
top_level: TopLevelContext {
definitions: Default::default(),
unifiers: Default::default(),
},
unifier,
function_data: FunctionData {
resolver,
bound_variables: Vec::new(),
return_type: None,
},
primitives,
id_to_name,
identifier_mapping,
virtual_checks: Vec::new(),
calls: HashMap::new(),
}
}
fn new() -> TestEnvironment {
let mut unifier = Unifier::new();
let mut identifier_mapping = HashMap::new();
let mut top_level_defs: Vec<Arc<RwLock<TopLevelDef>>> = Vec::new();
let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(0),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(1),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let float = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(2),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let bool = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(3),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
let none = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(4),
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
identifier_mapping.insert("None".into(), none);
for i in 0..5 {
top_level_defs.push(
RwLock::new(TopLevelDef::Class {
object_id: DefinitionId(i),
type_vars: Default::default(),
fields: Default::default(),
methods: Default::default(),
ancestors: Default::default(),
resolver: None,
})
.into(),
);
}
let primitives = PrimitiveStore { int32, int64, float, bool, none };
let (v0, id) = unifier.get_fresh_var();
let foo_ty = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(5),
fields: [("a".into(), v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
params: [(id, v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
});
top_level_defs.push(
RwLock::new(TopLevelDef::Class {
object_id: DefinitionId(5),
type_vars: vec![v0],
fields: [("a".into(), v0)].into(),
methods: Default::default(),
ancestors: Default::default(),
resolver: None,
})
.into(),
);
identifier_mapping.insert(
"Foo".into(),
unifier.add_ty(TypeEnum::TFunc(
FunSignature {
args: vec![],
ret: foo_ty,
vars: [(id, v0)].iter().cloned().collect(),
}
.into(),
)),
);
let fun = unifier.add_ty(TypeEnum::TFunc(
FunSignature { args: vec![], ret: int32, vars: Default::default() }.into(),
));
let bar = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(6),
fields: [("a".into(), int32), ("b".into(), fun)]
.iter()
.cloned()
.collect::<HashMap<_, _>>()
.into(),
params: Default::default(),
});
top_level_defs.push(
RwLock::new(TopLevelDef::Class {
object_id: DefinitionId(6),
type_vars: Default::default(),
fields: [("a".into(), int32), ("b".into(), fun)].into(),
methods: Default::default(),
ancestors: Default::default(),
resolver: None,
})
.into(),
);
identifier_mapping.insert(
"Bar".into(),
unifier.add_ty(TypeEnum::TFunc(
FunSignature { args: vec![], ret: bar, vars: Default::default() }.into(),
)),
);
let bar2 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(7),
fields: [("a".into(), bool), ("b".into(), fun)]
.iter()
.cloned()
.collect::<HashMap<_, _>>()
.into(),
params: Default::default(),
});
top_level_defs.push(
RwLock::new(TopLevelDef::Class {
object_id: DefinitionId(7),
type_vars: Default::default(),
fields: [("a".into(), bool), ("b".into(), fun)].into(),
methods: Default::default(),
ancestors: Default::default(),
resolver: None,
})
.into(),
);
identifier_mapping.insert(
"Bar2".into(),
unifier.add_ty(TypeEnum::TFunc(
FunSignature { args: vec![], ret: bar2, vars: Default::default() }.into(),
)),
);
let class_names = [("Bar".into(), bar), ("Bar2".into(), bar2)].iter().cloned().collect();
let id_to_name = [
(0, "int32".to_string()),
(1, "int64".to_string()),
(2, "float".to_string()),
(3, "bool".to_string()),
(4, "none".to_string()),
(5, "Foo".to_string()),
(6, "Bar".to_string()),
(7, "Bar2".to_string()),
]
.iter()
.cloned()
.collect();
let top_level = TopLevelContext {
definitions: Arc::new(RwLock::new(top_level_defs)),
unifiers: Default::default(),
};
let resolver = Arc::new(Resolver {
id_to_type: identifier_mapping.clone(),
id_to_def: [
("Foo".into(), DefinitionId(5)),
("Bar".into(), DefinitionId(6)),
("Bar2".into(), DefinitionId(7)),
]
.iter()
.cloned()
.collect(),
class_names,
}) as Arc<dyn SymbolResolver + Send + Sync>;
TestEnvironment {
unifier,
top_level,
function_data: FunctionData {
resolver,
bound_variables: Vec::new(),
return_type: None,
},
primitives,
id_to_name,
identifier_mapping,
virtual_checks: Vec::new(),
calls: HashMap::new(),
}
}
fn get_inferencer(&mut self) -> Inferencer {
Inferencer {
top_level: &self.top_level,
function_data: &mut self.function_data,
unifier: &mut self.unifier,
variable_mapping: Default::default(),
primitives: &mut self.primitives,
virtual_checks: &mut self.virtual_checks,
calls: &mut self.calls,
}
}
}
#[test_case(indoc! {"
a = 1234
b = int64(2147483648)
c = 1.234
d = True
"},
[("a", "int32"), ("b", "int64"), ("c", "float"), ("d", "bool")].iter().cloned().collect(),
&[]
; "primitives test")]
#[test_case(indoc! {"
a = lambda x, y: x
b = lambda x: a(x, x)
c = 1.234
d = b(c)
"},
[("a", "fn[[x=float, y=float], float]"), ("b", "fn[[x=float], float]"), ("c", "float"), ("d", "float")].iter().cloned().collect(),
&[]
; "lambda test")]
#[test_case(indoc! {"
a = lambda x: x
b = lambda x: x
foo1 = Foo()
foo2 = Foo()
c = a(foo1.a)
d = b(foo2.a)
a(True)
b(123)
"},
[("a", "fn[[x=bool], bool]"), ("b", "fn[[x=int32], int32]"), ("c", "bool"),
("d", "int32"), ("foo1", "Foo[bool]"), ("foo2", "Foo[int32]")].iter().cloned().collect(),
&[]
; "obj test")]
#[test_case(indoc! {"
f = lambda x: True
a = [1, 2, 3]
b = [f(x) for x in a if f(x)]
"},
[("a", "list[int32]"), ("b", "list[bool]"), ("f", "fn[[x=int32], bool]")].iter().cloned().collect(),
&[]
; "listcomp test")]
#[test_case(indoc! {"
a = virtual(Bar(), Bar)
b = a.b()
a = virtual(Bar2())
"},
[("a", "virtual[Bar]"), ("b", "int32")].iter().cloned().collect(),
&[("Bar", "Bar"), ("Bar2", "Bar")]
; "virtual test")]
#[test_case(indoc! {"
a = [virtual(Bar(), Bar), virtual(Bar2())]
b = [x.b() for x in a]
"},
[("a", "list[virtual[Bar]]"), ("b", "list[int32]")].iter().cloned().collect(),
&[("Bar", "Bar"), ("Bar2", "Bar")]
; "virtual list test")]
fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &str)]) {
println!("source:\n{}", source);
let mut env = TestEnvironment::new();
let id_to_name = std::mem::take(&mut env.id_to_name);
let mut defined_identifiers: Vec<_> = env.identifier_mapping.keys().cloned().collect();
defined_identifiers.push("virtual".to_string());
let mut inferencer = env.get_inferencer();
let statements = parse_program(source).unwrap();
let statements = statements
.into_iter()
.map(|v| inferencer.fold_stmt(v))
.collect::<Result<Vec<_>, _>>()
.unwrap();
inferencer.check_block(&statements, &mut defined_identifiers).unwrap();
for (k, v) in inferencer.variable_mapping.iter() {
let name = inferencer.unifier.stringify(
*v,
&mut |v| id_to_name.get(&v).unwrap().clone(),
&mut |v| format!("v{}", v),
);
println!("{}: {}", k, name);
}
for (k, v) in mapping.iter() {
let ty = inferencer.variable_mapping.get(*k).unwrap();
let name = inferencer.unifier.stringify(
*ty,
&mut |v| id_to_name.get(&v).unwrap().clone(),
&mut |v| format!("v{}", v),
);
assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name));
}
assert_eq!(inferencer.virtual_checks.len(), virtuals.len());
for ((a, b), (x, y)) in zip(inferencer.virtual_checks.iter(), virtuals) {
let a = inferencer.unifier.stringify(
*a,
&mut |v| id_to_name.get(&v).unwrap().clone(),
&mut |v| format!("v{}", v),
);
let b = inferencer.unifier.stringify(
*b,
&mut |v| id_to_name.get(&v).unwrap().clone(),
&mut |v| format!("v{}", v),
);
assert_eq!(&a, x);
assert_eq!(&b, y);
}
}
#[test_case(indoc! {"
a = 2
b = 2
c = a + b
d = a - b
e = a * b
f = a / b
g = a // b
h = a % b
"},
[("a", "int32"),
("b", "int32"),
("c", "int32"),
("d", "int32"),
("e", "int32"),
("f", "float"),
("g", "int32"),
("h", "int32")].iter().cloned().collect()
; "int32")]
#[test_case(
indoc! {"
a = 2.4
b = 3.6
c = a + b
d = a - b
e = a * b
f = a / b
g = a // b
h = a % b
i = a ** b
ii = 3
j = a ** b
"},
[("a", "float"),
("b", "float"),
("c", "float"),
("d", "float"),
("e", "float"),
("f", "float"),
("g", "float"),
("h", "float"),
("i", "float"),
("ii", "int32"),
("j", "float")].iter().cloned().collect()
; "float"
)]
#[test_case(
indoc! {"
a = int64(12312312312)
b = int64(24242424424)
c = a + b
d = a - b
e = a * b
f = a / b
g = a // b
h = a % b
i = a == b
j = a > b
k = a < b
l = a != b
"},
[("a", "int64"),
("b", "int64"),
("c", "int64"),
("d", "int64"),
("e", "int64"),
("f", "float"),
("g", "int64"),
("h", "int64"),
("i", "bool"),
("j", "bool"),
("k", "bool"),
("l", "bool")].iter().cloned().collect()
; "int64"
)]
#[test_case(
indoc! {"
a = True
b = False
c = a == b
d = not a
e = a != b
"},
[("a", "bool"),
("b", "bool"),
("c", "bool"),
("d", "bool"),
("e", "bool")].iter().cloned().collect()
; "boolean"
)]
fn test_primitive_magic_methods(source: &str, mapping: HashMap<&str, &str>) {
println!("source:\n{}", source);
let mut env = TestEnvironment::basic_test_env();
let id_to_name = std::mem::take(&mut env.id_to_name);
let mut defined_identifiers: Vec<_> = env.identifier_mapping.keys().cloned().collect();
defined_identifiers.push("virtual".to_string());
let mut inferencer = env.get_inferencer();
let statements = parse_program(source).unwrap();
let statements = statements
.into_iter()
.map(|v| inferencer.fold_stmt(v))
.collect::<Result<Vec<_>, _>>()
.unwrap();
inferencer.check_block(&statements, &mut defined_identifiers).unwrap();
for (k, v) in inferencer.variable_mapping.iter() {
let name = inferencer.unifier.stringify(
*v,
&mut |v| id_to_name.get(&v).unwrap().clone(),
&mut |v| format!("v{}", v),
);
println!("{}: {}", k, name);
}
for (k, v) in mapping.iter() {
let ty = inferencer.variable_mapping.get(*k).unwrap();
let name = inferencer.unifier.stringify(
*ty,
&mut |v| id_to_name.get(&v).unwrap().clone(),
&mut |v| format!("v{}", v),
);
assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name));
}
}

View File

@ -0,0 +1,947 @@
use itertools::{chain, zip, Itertools};
use std::borrow::Cow;
use std::cell::RefCell;
use std::collections::HashMap;
use std::iter::once;
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use super::unification_table::{UnificationKey, UnificationTable};
use crate::symbol_resolver::SymbolValue;
use crate::top_level::DefinitionId;
#[cfg(test)]
mod test;
/// Handle for a type, implementated as a key in the unification table.
pub type Type = UnificationKey;
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct CallId(usize);
pub type Mapping<K, V = Type> = HashMap<K, V>;
type VarMap = Mapping<u32>;
#[derive(Clone)]
pub struct Call {
pub posargs: Vec<Type>,
pub kwargs: HashMap<String, Type>,
pub ret: Type,
pub fun: RefCell<Option<Type>>,
}
#[derive(Clone)]
pub struct FuncArg {
pub name: String,
pub ty: Type,
pub default_value: Option<SymbolValue>,
}
#[derive(Clone)]
pub struct FunSignature {
pub args: Vec<FuncArg>,
pub ret: Type,
pub vars: VarMap,
}
#[derive(Clone)]
pub enum TypeVarMeta {
Generic,
Sequence(RefCell<Mapping<i32>>),
Record(RefCell<Mapping<String>>),
}
#[derive(Clone)]
pub enum TypeEnum {
TRigidVar {
id: u32,
},
TVar {
id: u32,
meta: TypeVarMeta,
// empty indicates no restriction
range: RefCell<Vec<Type>>,
},
TTuple {
ty: Vec<Type>,
},
TList {
ty: Type,
},
TObj {
obj_id: DefinitionId,
fields: RefCell<Mapping<String>>,
params: RefCell<VarMap>,
},
TVirtual {
ty: Type,
},
TCall(RefCell<Vec<CallId>>),
TFunc(RefCell<FunSignature>),
}
impl TypeEnum {
pub fn get_type_name(&self) -> &'static str {
match self {
TypeEnum::TRigidVar { .. } => "TRigidVar",
TypeEnum::TVar { .. } => "TVar",
TypeEnum::TTuple { .. } => "TTuple",
TypeEnum::TList { .. } => "TList",
TypeEnum::TObj { .. } => "TObj",
TypeEnum::TVirtual { .. } => "TVirtual",
TypeEnum::TCall { .. } => "TCall",
TypeEnum::TFunc { .. } => "TFunc",
}
}
}
pub type SharedUnifier = Arc<Mutex<(UnificationTable<TypeEnum>, u32, Vec<Call>)>>;
pub struct Unifier {
unification_table: UnificationTable<Rc<TypeEnum>>,
calls: Vec<Rc<Call>>,
var_id: u32,
}
impl Unifier {
/// Get an empty unifier
pub fn new() -> Unifier {
Unifier { unification_table: UnificationTable::new(), var_id: 0, calls: Vec::new() }
}
/// Determine if the two types are the same
pub fn unioned(&mut self, a: Type, b: Type) -> bool {
self.unification_table.unioned(a, b)
}
pub fn from_shared_unifier(unifier: &SharedUnifier) -> Unifier {
let lock = unifier.lock().unwrap();
Unifier {
unification_table: UnificationTable::from_send(&lock.0),
var_id: lock.1,
calls: lock.2.iter().map(|v| Rc::new(v.clone())).collect_vec(),
}
}
pub fn get_shared_unifier(&self) -> SharedUnifier {
Arc::new(Mutex::new((
self.unification_table.get_send(),
self.var_id,
self.calls.iter().map(|v| v.as_ref().clone()).collect_vec(),
)))
}
/// Register a type to the unifier.
/// Returns a key in the unification_table.
pub fn add_ty(&mut self, a: TypeEnum) -> Type {
self.unification_table.new_key(Rc::new(a))
}
pub fn add_record(&mut self, fields: Mapping<String>) -> Type {
let id = self.var_id + 1;
self.var_id += 1;
self.add_ty(TypeEnum::TVar {
id,
range: vec![].into(),
meta: TypeVarMeta::Record(fields.into()),
})
}
pub fn add_call(&mut self, call: Call) -> CallId {
let id = CallId(self.calls.len());
self.calls.push(Rc::new(call));
id
}
pub fn get_representative(&mut self, ty: Type) -> Type {
self.unification_table.get_representative(ty)
}
pub fn add_sequence(&mut self, sequence: Mapping<i32>) -> Type {
let id = self.var_id + 1;
self.var_id += 1;
self.add_ty(TypeEnum::TVar {
id,
range: vec![].into(),
meta: TypeVarMeta::Sequence(sequence.into()),
})
}
/// Get the TypeEnum of a type.
pub fn get_ty(&mut self, a: Type) -> Rc<TypeEnum> {
self.unification_table.probe_value(a).clone()
}
pub fn get_fresh_rigid_var(&mut self) -> (Type, u32) {
let id = self.var_id + 1;
self.var_id += 1;
(self.add_ty(TypeEnum::TRigidVar { id }), id)
}
pub fn get_fresh_var(&mut self) -> (Type, u32) {
self.get_fresh_var_with_range(&[])
}
/// Get a fresh type variable.
pub fn get_fresh_var_with_range(&mut self, range: &[Type]) -> (Type, u32) {
let id = self.var_id + 1;
self.var_id += 1;
let range = range.to_vec().into();
(self.add_ty(TypeEnum::TVar { id, range, meta: TypeVarMeta::Generic }), id)
}
/// Unification would not unify rigid variables with other types, but we want to do this for
/// function instantiations, so we make it explicit.
pub fn replace_rigid_var(&mut self, rigid: Type, b: Type) {
assert!(matches!(&*self.get_ty(rigid), TypeEnum::TRigidVar { .. }));
self.set_a_to_b(rigid, b);
}
pub fn get_instantiations(&mut self, ty: Type) -> Option<Vec<Type>> {
match &*self.get_ty(ty) {
TypeEnum::TVar { range, .. } => {
let range = range.borrow();
if range.is_empty() {
None
} else {
Some(
range
.iter()
.map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty]))
.flatten()
.collect_vec(),
)
}
}
TypeEnum::TList { ty } => self
.get_instantiations(*ty)
.map(|ty| ty.iter().map(|&ty| self.add_ty(TypeEnum::TList { ty })).collect_vec()),
TypeEnum::TVirtual { ty } => self.get_instantiations(*ty).map(|ty| {
ty.iter().map(|&ty| self.add_ty(TypeEnum::TVirtual { ty })).collect_vec()
}),
TypeEnum::TTuple { ty } => {
let tuples = ty
.iter()
.map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty]))
.multi_cartesian_product()
.collect_vec();
if tuples.len() == 1 {
None
} else {
Some(
tuples.into_iter().map(|ty| self.add_ty(TypeEnum::TTuple { ty })).collect(),
)
}
}
TypeEnum::TObj { params, .. } => {
let params = params.borrow();
let (keys, params): (Vec<&u32>, Vec<&Type>) = params.iter().unzip();
let params = params
.into_iter()
.map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty]))
.multi_cartesian_product()
.collect_vec();
if params.len() <= 1 {
None
} else {
Some(
params
.into_iter()
.map(|params| {
self.subst(
ty,
&zip(keys.iter().cloned().cloned(), params.iter().cloned())
.collect(),
)
.unwrap_or(ty)
})
.collect(),
)
}
}
_ => None,
}
}
pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool {
use TypeEnum::*;
match &*self.get_ty(a) {
TRigidVar { .. } => true,
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
TCall { .. } => false,
TList { ty } => self.is_concrete(*ty, allowed_typevars),
TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)),
TObj { params: vars, .. } => {
vars.borrow().values().all(|ty| self.is_concrete(*ty, allowed_typevars))
}
// functions are instantiated for each call sites, so the function type can contain
// type variables.
TFunc { .. } => true,
TVirtual { ty } => self.is_concrete(*ty, allowed_typevars),
}
}
pub fn unify(&mut self, a: Type, b: Type) -> Result<(), String> {
if self.unification_table.unioned(a, b) {
Ok(())
} else {
self.unify_impl(a, b, false)
}
}
fn unify_impl(&mut self, a: Type, b: Type, swapped: bool) -> Result<(), String> {
use TypeEnum::*;
use TypeVarMeta::*;
let (ty_a, ty_b) = {
(
self.unification_table.probe_value(a).clone(),
self.unification_table.probe_value(b).clone(),
)
};
match (&*ty_a, &*ty_b) {
(TVar { meta: meta1, range: range1, .. }, TVar { meta: meta2, range: range2, .. }) => {
self.occur_check(a, b)?;
self.occur_check(b, a)?;
match (meta1, meta2) {
(Generic, _) => {}
(_, Generic) => {
return self.unify_impl(b, a, true);
}
(Record(fields1), Record(fields2)) => {
let mut fields2 = fields2.borrow_mut();
for (key, value) in fields1.borrow().iter() {
if let Some(ty) = fields2.get(key) {
self.unify(*ty, *value)?;
} else {
fields2.insert(key.clone(), *value);
}
}
}
(Sequence(map1), Sequence(map2)) => {
let mut map2 = map2.borrow_mut();
for (key, value) in map1.borrow().iter() {
if let Some(ty) = map2.get(key) {
self.unify(*ty, *value)?;
} else {
map2.insert(*key, *value);
}
}
}
_ => {
return Err("Incompatible".to_string());
}
}
let range1 = range1.borrow();
// new range is the intersection of them
// empty range indicates no constraint
if !range1.is_empty() {
let old_range2 = range2.take();
let mut range2 = range2.borrow_mut();
if old_range2.is_empty() {
range2.extend_from_slice(&range1);
}
for v1 in old_range2.iter() {
for v2 in range1.iter() {
if let Ok(result) = self.get_intersection(*v1, *v2) {
range2.push(result.unwrap_or(*v2));
}
}
}
if range2.is_empty() {
return Err(
"cannot unify type variables with incompatible value range".to_string()
);
}
}
self.set_a_to_b(a, b);
}
(TVar { meta: Generic, id, range, .. }, _) => {
self.occur_check(a, b)?;
// We check for the range of the type variable to see if unification is allowed.
// Note that although b may be compatible with a, we may have to constrain type
// variables in b to make sure that instantiations of b would always be compatible
// with a.
// The return value x of check_var_compatibility would be a new type that is
// guaranteed to be compatible with a under all possible instantiations. So we
// unify x with b to recursively apply the constrains, and then set a to x.
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
self.unify(x, b)?;
self.set_a_to_b(a, x);
}
(TVar { meta: Sequence(map), id, range, .. }, TTuple { ty }) => {
self.occur_check(a, b)?;
let len = ty.len() as i32;
for (k, v) in map.borrow().iter() {
// handle negative index
let ind = if *k < 0 { len + *k } else { *k };
if ind >= len || ind < 0 {
return Err(format!(
"Tuple index out of range. (Length: {}, Index: {})",
len, k
));
}
self.unify(*v, ty[ind as usize])?;
}
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
self.unify(x, b)?;
self.set_a_to_b(a, x);
}
(TVar { meta: Sequence(map), id, range, .. }, TList { ty }) => {
self.occur_check(a, b)?;
for v in map.borrow().values() {
self.unify(*v, *ty)?;
}
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
self.unify(x, b)?;
self.set_a_to_b(a, x);
}
(TTuple { ty: ty1 }, TTuple { ty: ty2 }) => {
if ty1.len() != ty2.len() {
return Err(format!(
"Cannot unify tuples with length {} and {}",
ty1.len(),
ty2.len()
));
}
for (x, y) in ty1.iter().zip(ty2.iter()) {
self.unify(*x, *y)?;
}
self.set_a_to_b(a, b);
}
(TList { ty: ty1 }, TList { ty: ty2 }) => {
self.unify(*ty1, *ty2)?;
self.set_a_to_b(a, b);
}
(TVar { meta: Record(map), id, range, .. }, TObj { fields, .. }) => {
self.occur_check(a, b)?;
for (k, v) in map.borrow().iter() {
let ty = fields
.borrow()
.get(k)
.copied()
.ok_or_else(|| format!("No such attribute {}", k))?;
self.unify(ty, *v)?;
}
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
self.unify(x, b)?;
self.set_a_to_b(a, x);
}
(TVar { meta: Record(map), id, range, .. }, TVirtual { ty }) => {
self.occur_check(a, b)?;
let ty = self.get_ty(*ty);
if let TObj { fields, .. } = ty.as_ref() {
for (k, v) in map.borrow().iter() {
let ty = fields
.borrow()
.get(k)
.copied()
.ok_or_else(|| format!("No such attribute {}", k))?;
if !matches!(self.get_ty(ty).as_ref(), TFunc { .. }) {
return Err(format!("Cannot access field {} for virtual type", k));
}
self.unify(*v, ty)?;
}
} else {
// require annotation...
return Err("Requires type annotation for virtual".to_string());
}
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
self.unify(x, b)?;
self.set_a_to_b(a, x);
}
(
TObj { obj_id: id1, params: params1, .. },
TObj { obj_id: id2, params: params2, .. },
) => {
if id1 != id2 {
return Err(format!("Cannot unify objects with ID {} and {}", id1.0, id2.0));
}
for (x, y) in zip(params1.borrow().values(), params2.borrow().values()) {
self.unify(*x, *y)?;
}
self.set_a_to_b(a, b);
}
(TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => {
self.unify(*ty1, *ty2)?;
self.set_a_to_b(a, b);
}
(TCall(calls1), TCall(calls2)) => {
// we do not unify individual calls, instead we defer until the unification wtih a
// function definition.
calls2.borrow_mut().extend_from_slice(&calls1.borrow());
}
(TCall(calls), TFunc(signature)) => {
self.occur_check(a, b)?;
let required: Vec<String> = signature
.borrow()
.args
.iter()
.filter(|v| v.default_value.is_none())
.map(|v| v.name.clone())
.rev()
.collect();
// we unify every calls to the function signature.
for c in calls.borrow().iter() {
let Call { posargs, kwargs, ret, fun } = &*self.calls[c.0].clone();
let instantiated = self.instantiate_fun(b, &*signature.borrow());
let r = self.get_ty(instantiated);
let r = r.as_ref();
let signature;
if let TypeEnum::TFunc(s) = &*r {
signature = s;
} else {
unreachable!();
}
// we check to make sure that all required arguments (those without default
// arguments) are provided, and do not provide the same argument twice.
let mut required = required.clone();
let mut all_names: Vec<_> = signature
.borrow()
.args
.iter()
.map(|v| (v.name.clone(), v.ty))
.rev()
.collect();
for (i, t) in posargs.iter().enumerate() {
if signature.borrow().args.len() <= i {
return Err("Too many arguments.".to_string());
}
if !required.is_empty() {
required.pop();
}
self.unify(all_names.pop().unwrap().1, *t)?;
}
for (k, t) in kwargs.iter() {
if let Some(i) = required.iter().position(|v| v == k) {
required.remove(i);
}
let i = all_names
.iter()
.position(|v| &v.0 == k)
.ok_or_else(|| format!("Unknown keyword argument {}", k))?;
self.unify(all_names.remove(i).1, *t)?;
}
if !required.is_empty() {
return Err("Expected more arguments".to_string());
}
self.unify(*ret, signature.borrow().ret)?;
*fun.borrow_mut() = Some(instantiated);
}
self.set_a_to_b(a, b);
}
(TFunc(sign1), TFunc(sign2)) => {
let (sign1, sign2) = (&*sign1.borrow(), &*sign2.borrow());
if !sign1.vars.is_empty() || !sign2.vars.is_empty() {
return Err("Polymorphic function pointer is prohibited.".to_string());
}
if sign1.args.len() != sign2.args.len() {
return Err("Functions differ in number of parameters.".to_string());
}
for (x, y) in sign1.args.iter().zip(sign2.args.iter()) {
if x.name != y.name {
return Err("Functions differ in parameter names.".to_string());
}
if x.default_value != y.default_value {
return Err("Functions differ in optional parameters.".to_string());
}
self.unify(x.ty, y.ty)?;
}
self.unify(sign1.ret, sign2.ret)?;
self.set_a_to_b(a, b);
}
_ => {
if swapped {
return self.incompatible_types(&*ty_a, &*ty_b);
} else {
self.unify_impl(b, a, true)?;
}
}
}
Ok(())
}
/// Get string representation of the type
pub fn stringify<F, G>(&mut self, ty: Type, obj_to_name: &mut F, var_to_name: &mut G) -> String
where
F: FnMut(usize) -> String,
G: FnMut(u32) -> String,
{
use TypeVarMeta::*;
let ty = self.unification_table.probe_value(ty).clone();
match ty.as_ref() {
TypeEnum::TRigidVar { id } => var_to_name(*id),
TypeEnum::TVar { id, meta: Generic, .. } => var_to_name(*id),
TypeEnum::TVar { meta: Sequence(map), .. } => {
let fields = map
.borrow()
.iter()
.map(|(k, v)| format!("{}={}", k, self.stringify(*v, obj_to_name, var_to_name)))
.join(", ");
format!("seq[{}]", fields)
}
TypeEnum::TVar { meta: Record(fields), .. } => {
let fields = fields
.borrow()
.iter()
.map(|(k, v)| format!("{}={}", k, self.stringify(*v, obj_to_name, var_to_name)))
.join(", ");
format!("record[{}]", fields)
}
TypeEnum::TTuple { ty } => {
let mut fields = ty.iter().map(|v| self.stringify(*v, obj_to_name, var_to_name));
format!("tuple[{}]", fields.join(", "))
}
TypeEnum::TList { ty } => {
format!("list[{}]", self.stringify(*ty, obj_to_name, var_to_name))
}
TypeEnum::TVirtual { ty } => {
format!("virtual[{}]", self.stringify(*ty, obj_to_name, var_to_name))
}
TypeEnum::TObj { obj_id, params, .. } => {
let name = obj_to_name(obj_id.0);
let params = params.borrow();
if !params.is_empty() {
let mut params =
params.values().map(|v| self.stringify(*v, obj_to_name, var_to_name));
format!("{}[{}]", name, params.join(", "))
} else {
name
}
}
TypeEnum::TCall { .. } => "call".to_owned(),
TypeEnum::TFunc(signature) => {
let params = signature
.borrow()
.args
.iter()
.map(|arg| {
format!("{}={}", arg.name, self.stringify(arg.ty, obj_to_name, var_to_name))
})
.join(", ");
let ret = self.stringify(signature.borrow().ret, obj_to_name, var_to_name);
format!("fn[[{}], {}]", params, ret)
}
}
}
fn set_a_to_b(&mut self, a: Type, b: Type) {
// unify a and b together, and set the value to b's value.
let table = &mut self.unification_table;
let ty_b = table.probe_value(b).clone();
table.unify(a, b);
table.set_value(a, ty_b)
}
fn incompatible_types(&self, a: &TypeEnum, b: &TypeEnum) -> Result<(), String> {
Err(format!("Cannot unify {} with {}", a.get_type_name(), b.get_type_name()))
}
/// Instantiate a function if it hasn't been instantiated.
/// Returns Some(T) where T is the instantiated type.
/// Returns None if the function is already instantiated.
fn instantiate_fun(&mut self, ty: Type, fun: &FunSignature) -> Type {
let mut instantiated = false;
let mut vars = Vec::new();
for (k, v) in fun.vars.iter() {
if let TypeEnum::TVar { id, range, .. } =
self.unification_table.probe_value(*v).as_ref()
{
if k != id {
instantiated = true;
break;
}
// actually, if the first check succeeded, the function should be uninstatiated.
// The cloned values must be used and would not be wasted.
vars.push((*k, range.clone()));
} else {
instantiated = true;
break;
}
}
if instantiated {
ty
} else {
let mapping = vars
.into_iter()
.map(|(k, range)| (k, self.get_fresh_var_with_range(range.borrow().as_ref()).0))
.collect();
self.subst(ty, &mapping).unwrap_or(ty)
}
}
/// Substitute type variables within a type into other types.
/// If this returns Some(T), T would be the substituted type.
/// If this returns None, the result type would be the original type
/// (no substitution has to be done).
pub fn subst(&mut self, a: Type, mapping: &VarMap) -> Option<Type> {
use TypeVarMeta::*;
let ty = self.unification_table.probe_value(a).clone();
// this function would only be called when we instantiate functions.
// function type signature should ONLY contain concrete types and type
// variables, i.e. things like TRecord, TCall should not occur, and we
// should be safe to not implement the substitution for those variants.
match &*ty {
TypeEnum::TRigidVar { .. } => None,
TypeEnum::TVar { id, meta: Generic, .. } => mapping.get(&id).cloned(),
TypeEnum::TTuple { ty } => {
let mut new_ty = Cow::from(ty);
for (i, t) in ty.iter().enumerate() {
if let Some(t1) = self.subst(*t, mapping) {
new_ty.to_mut()[i] = t1;
}
}
if matches!(new_ty, Cow::Owned(_)) {
Some(self.add_ty(TypeEnum::TTuple { ty: new_ty.into_owned() }))
} else {
None
}
}
TypeEnum::TList { ty } => {
self.subst(*ty, mapping).map(|t| self.add_ty(TypeEnum::TList { ty: t }))
}
TypeEnum::TVirtual { ty } => {
self.subst(*ty, mapping).map(|t| self.add_ty(TypeEnum::TVirtual { ty: t }))
}
TypeEnum::TObj { obj_id, fields, params } => {
// Type variables in field types must be present in the type parameter.
// If the mapping does not contain any type variables in the
// parameter list, we don't need to substitute the fields.
// This is also used to prevent infinite substitution...
let params = params.borrow();
let need_subst = params.values().any(|v| {
let ty = self.unification_table.probe_value(*v);
if let TypeEnum::TVar { id, .. } = ty.as_ref() {
mapping.contains_key(&id)
} else {
false
}
});
if need_subst {
let obj_id = *obj_id;
let params = self.subst_map(&params, mapping).unwrap_or_else(|| params.clone());
let fields = self
.subst_map(&fields.borrow(), mapping)
.unwrap_or_else(|| fields.borrow().clone());
Some(self.add_ty(TypeEnum::TObj {
obj_id,
params: params.into(),
fields: fields.into(),
}))
} else {
None
}
}
TypeEnum::TFunc(sig) => {
let FunSignature { args, ret, vars: params } = &*sig.borrow();
let new_params = self.subst_map(params, mapping);
let new_ret = self.subst(*ret, mapping);
let mut new_args = Cow::from(args);
for (i, t) in args.iter().enumerate() {
if let Some(t1) = self.subst(t.ty, mapping) {
let mut t = t.clone();
t.ty = t1;
new_args.to_mut()[i] = t;
}
}
if new_params.is_some() || new_ret.is_some() || matches!(new_args, Cow::Owned(..)) {
let params = new_params.unwrap_or_else(|| params.clone());
let ret = new_ret.unwrap_or_else(|| *ret);
let args = new_args.into_owned();
Some(
self.add_ty(TypeEnum::TFunc(
FunSignature { args, ret, vars: params }.into(),
)),
)
} else {
None
}
}
_ => unimplemented!(),
}
}
fn subst_map<K>(&mut self, map: &Mapping<K>, mapping: &VarMap) -> Option<Mapping<K>>
where
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
{
let mut map2 = None;
for (k, v) in map.iter() {
if let Some(v1) = self.subst(*v, mapping) {
if map2.is_none() {
map2 = Some(map.clone());
}
*map2.as_mut().unwrap().get_mut(k).unwrap() = v1;
}
}
map2
}
fn occur_check(&mut self, a: Type, b: Type) -> Result<(), String> {
use TypeVarMeta::*;
if self.unification_table.unioned(a, b) {
return Err("Recursive type is prohibited.".to_owned());
}
let ty = self.unification_table.probe_value(b).clone();
match ty.as_ref() {
TypeEnum::TRigidVar { .. } | TypeEnum::TVar { meta: Generic, .. } => {}
TypeEnum::TVar { meta: Sequence(map), .. } => {
for t in map.borrow().values() {
self.occur_check(a, *t)?;
}
}
TypeEnum::TVar { meta: Record(map), .. } => {
for t in map.borrow().values() {
self.occur_check(a, *t)?;
}
}
TypeEnum::TCall(calls) => {
let call_store = self.calls.clone();
for t in calls
.borrow()
.iter()
.map(|call| {
let call = call_store[call.0].as_ref();
chain!(call.posargs.iter(), call.kwargs.values(), once(&call.ret))
})
.flatten()
{
self.occur_check(a, *t)?;
}
}
TypeEnum::TTuple { ty } => {
for t in ty.iter() {
self.occur_check(a, *t)?;
}
}
TypeEnum::TList { ty } | TypeEnum::TVirtual { ty } => {
self.occur_check(a, *ty)?;
}
TypeEnum::TObj { params: map, .. } => {
for t in map.borrow().values() {
self.occur_check(a, *t)?;
}
}
TypeEnum::TFunc(sig) => {
let FunSignature { args, ret, vars: params } = &*sig.borrow();
for t in chain!(args.iter().map(|v| &v.ty), params.values(), once(ret)) {
self.occur_check(a, *t)?;
}
}
}
Ok(())
}
fn get_intersection(&mut self, a: Type, b: Type) -> Result<Option<Type>, ()> {
use TypeEnum::*;
let x = self.get_ty(a);
let y = self.get_ty(b);
match (x.as_ref(), y.as_ref()) {
(TVar { range: range1, .. }, TVar { meta, range: range2, .. }) => {
// we should restrict range2
let range1 = range1.borrow();
// new range is the intersection of them
// empty range indicates no constraint
if !range1.is_empty() {
let range2 = range2.borrow();
let mut range = Vec::new();
if range2.is_empty() {
range.extend_from_slice(&range1);
}
for v1 in range2.iter() {
for v2 in range1.iter() {
let result = self.get_intersection(*v1, *v2);
if let Ok(result) = result {
range.push(result.unwrap_or(*v2));
}
}
}
if range.is_empty() {
Err(())
} else {
let id = self.var_id + 1;
self.var_id += 1;
let ty = TVar { id, meta: meta.clone(), range: range.into() };
Ok(Some(self.unification_table.new_key(ty.into())))
}
} else {
Ok(Some(b))
}
}
(_, TVar { range, .. }) => {
// range should be restricted to the left hand side
let range = range.borrow();
if range.is_empty() {
Ok(Some(a))
} else {
for v in range.iter() {
let result = self.get_intersection(a, *v);
if let Ok(result) = result {
return Ok(result.or(Some(a)));
}
}
Err(())
}
}
(TVar { id, range, .. }, _) => {
self.check_var_compatibility(*id, b, &range.borrow()).or(Err(()))
}
(TTuple { ty: ty1 }, TTuple { ty: ty2 }) => {
if ty1.len() != ty2.len() {
return Err(());
}
let mut need_new = false;
let mut ty = ty1.clone();
for (a, b) in zip(ty1.iter(), ty2.iter()) {
let result = self.get_intersection(*a, *b)?;
ty.push(result.unwrap_or(*a));
if result.is_some() {
need_new = true;
}
}
if need_new {
Ok(Some(self.add_ty(TTuple { ty })))
} else {
Ok(None)
}
}
(TList { ty: ty1 }, TList { ty: ty2 }) => {
Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TList { ty })))
}
(TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => {
Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TVirtual { ty })))
}
(TObj { obj_id: id1, .. }, TObj { obj_id: id2, .. }) => {
if id1 == id2 {
Ok(None)
} else {
Err(())
}
}
// don't deal with function shape for now
_ => Err(()),
}
}
fn check_var_compatibility(
&mut self,
id: u32,
b: Type,
range: &[Type],
) -> Result<Option<Type>, String> {
if range.is_empty() {
return Ok(None);
}
for t in range.iter() {
let result = self.get_intersection(*t, b);
if let Ok(result) = result {
return Ok(result);
}
}
return Err(format!(
"Cannot unify type variable {} with {} due to incompatible value range",
id,
self.get_ty(b).get_type_name()
));
}
}

View File

@ -0,0 +1,534 @@
use super::*;
use indoc::indoc;
use itertools::Itertools;
use std::collections::HashMap;
use test_case::test_case;
impl Unifier {
/// Check whether two types are equal.
fn eq(&mut self, a: Type, b: Type) -> bool {
use TypeVarMeta::*;
if a == b {
return true;
}
let (ty_a, ty_b) = {
let table = &mut self.unification_table;
if table.unioned(a, b) {
return true;
}
(table.probe_value(a).clone(), table.probe_value(b).clone())
};
match (&*ty_a, &*ty_b) {
(
TypeEnum::TVar { meta: Generic, id: id1, .. },
TypeEnum::TVar { meta: Generic, id: id2, .. },
) => id1 == id2,
(
TypeEnum::TVar { meta: Sequence(map1), .. },
TypeEnum::TVar { meta: Sequence(map2), .. },
) => self.map_eq(&map1.borrow(), &map2.borrow()),
(TypeEnum::TTuple { ty: ty1 }, TypeEnum::TTuple { ty: ty2 }) => {
ty1.len() == ty2.len()
&& ty1.iter().zip(ty2.iter()).all(|(t1, t2)| self.eq(*t1, *t2))
}
(TypeEnum::TList { ty: ty1 }, TypeEnum::TList { ty: ty2 })
| (TypeEnum::TVirtual { ty: ty1 }, TypeEnum::TVirtual { ty: ty2 }) => {
self.eq(*ty1, *ty2)
}
(
TypeEnum::TVar { meta: Record(fields1), .. },
TypeEnum::TVar { meta: Record(fields2), .. },
) => self.map_eq(&fields1.borrow(), &fields2.borrow()),
(
TypeEnum::TObj { obj_id: id1, params: params1, .. },
TypeEnum::TObj { obj_id: id2, params: params2, .. },
) => id1 == id2 && self.map_eq(&params1.borrow(), &params2.borrow()),
// TCall and TFunc are not yet implemented
_ => false,
}
}
fn map_eq<K>(&mut self, map1: &Mapping<K>, map2: &Mapping<K>) -> bool
where
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
{
if map1.len() != map2.len() {
return false;
}
for (k, v) in map1.iter() {
if !map2.get(k).map(|v1| self.eq(*v, *v1)).unwrap_or(false) {
return false;
}
}
true
}
}
struct TestEnvironment {
pub unifier: Unifier,
type_mapping: HashMap<String, Type>,
}
impl TestEnvironment {
fn new() -> TestEnvironment {
let mut unifier = Unifier::new();
let mut type_mapping = HashMap::new();
type_mapping.insert(
"int".into(),
unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(0),
fields: HashMap::new().into(),
params: HashMap::new().into(),
}),
);
type_mapping.insert(
"float".into(),
unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(1),
fields: HashMap::new().into(),
params: HashMap::new().into(),
}),
);
type_mapping.insert(
"bool".into(),
unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(2),
fields: HashMap::new().into(),
params: HashMap::new().into(),
}),
);
let (v0, id) = unifier.get_fresh_var();
type_mapping.insert(
"Foo".into(),
unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(3),
fields: [("a".into(), v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
params: [(id, v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
}),
);
TestEnvironment { unifier, type_mapping }
}
fn parse(&mut self, typ: &str, mapping: &Mapping<String>) -> Type {
let result = self.internal_parse(typ, mapping);
assert!(result.1.is_empty());
result.0
}
fn internal_parse<'a, 'b>(
&'a mut self,
typ: &'b str,
mapping: &Mapping<String>,
) -> (Type, &'b str) {
// for testing only, so we can just panic when the input is malformed
let end = typ.find(|c| ['[', ',', ']', '='].contains(&c)).unwrap_or_else(|| typ.len());
match &typ[..end] {
"Tuple" => {
let mut s = &typ[end..];
assert!(&s[0..1] == "[");
let mut ty = Vec::new();
while &s[0..1] != "]" {
let result = self.internal_parse(&s[1..], mapping);
ty.push(result.0);
s = result.1;
}
(self.unifier.add_ty(TypeEnum::TTuple { ty }), &s[1..])
}
"List" => {
assert!(&typ[end..end + 1] == "[");
let (ty, s) = self.internal_parse(&typ[end + 1..], mapping);
assert!(&s[0..1] == "]");
(self.unifier.add_ty(TypeEnum::TList { ty }), &s[1..])
}
"Record" => {
let mut s = &typ[end..];
assert!(&s[0..1] == "[");
let mut fields = HashMap::new();
while &s[0..1] != "]" {
let eq = s.find('=').unwrap();
let key = s[1..eq].to_string();
let result = self.internal_parse(&s[eq + 1..], mapping);
fields.insert(key, result.0);
s = result.1;
}
(self.unifier.add_record(fields), &s[1..])
}
x => {
let mut s = &typ[end..];
let ty = mapping.get(x).cloned().unwrap_or_else(|| {
// mapping should be type variables, type_mapping should be concrete types
// we should not resolve the type of type variables.
let mut ty = *self.type_mapping.get(x).unwrap();
let te = self.unifier.get_ty(ty);
if let TypeEnum::TObj { params, .. } = &*te.as_ref() {
let params = params.borrow();
if !params.is_empty() {
assert!(&s[0..1] == "[");
let mut p = Vec::new();
while &s[0..1] != "]" {
let result = self.internal_parse(&s[1..], mapping);
p.push(result.0);
s = result.1;
}
s = &s[1..];
ty = self
.unifier
.subst(ty, &params.keys().cloned().zip(p.into_iter()).collect())
.unwrap_or(ty);
}
}
ty
});
(ty, s)
}
}
}
}
#[test_case(2,
&[("v1", "v2"), ("v2", "float")],
&[("v1", "float"), ("v2", "float")]
; "simple variable"
)]
#[test_case(2,
&[("v1", "List[v2]"), ("v1", "List[float]")],
&[("v1", "List[float]"), ("v2", "float")]
; "list element"
)]
#[test_case(3,
&[
("v1", "Record[a=v3,b=v3]"),
("v2", "Record[b=float,c=v3]"),
("v1", "v2")
],
&[
("v1", "Record[a=float,b=float,c=float]"),
("v2", "Record[a=float,b=float,c=float]"),
("v3", "float")
]
; "record merge"
)]
#[test_case(3,
&[
("v1", "Record[a=float]"),
("v2", "Foo[v3]"),
("v1", "v2")
],
&[
("v1", "Foo[float]"),
("v3", "float")
]
; "record obj merge"
)]
/// Test cases for valid unifications.
fn test_unify(
variable_count: u32,
unify_pairs: &[(&'static str, &'static str)],
verify_pairs: &[(&'static str, &'static str)],
) {
let unify_count = unify_pairs.len();
// test all permutations...
for perm in unify_pairs.iter().permutations(unify_count) {
let mut env = TestEnvironment::new();
let mut mapping = HashMap::new();
for i in 1..=variable_count {
let v = env.unifier.get_fresh_var();
mapping.insert(format!("v{}", i), v.0);
}
// unification may have side effect when we do type resolution, so freeze the types
// before doing unification.
let mut pairs = Vec::new();
for (a, b) in perm.iter() {
let t1 = env.parse(a, &mapping);
let t2 = env.parse(b, &mapping);
pairs.push((t1, t2));
}
for (t1, t2) in pairs {
env.unifier.unify(t1, t2).unwrap();
}
for (a, b) in verify_pairs.iter() {
println!("{} = {}", a, b);
let t1 = env.parse(a, &mapping);
let t2 = env.parse(b, &mapping);
assert!(env.unifier.eq(t1, t2));
}
}
}
#[test_case(2,
&[
("v1", "Tuple[int]"),
("v2", "List[int]"),
],
(("v1", "v2"), "Cannot unify TList with TTuple")
; "type mismatch"
)]
#[test_case(2,
&[
("v1", "Tuple[int]"),
("v2", "Tuple[float]"),
],
(("v1", "v2"), "Cannot unify objects with ID 0 and 1")
; "tuple parameter mismatch"
)]
#[test_case(2,
&[
("v1", "Tuple[int,int]"),
("v2", "Tuple[int]"),
],
(("v1", "v2"), "Cannot unify tuples with length 2 and 1")
; "tuple length mismatch"
)]
#[test_case(3,
&[
("v1", "Record[a=float,b=int]"),
("v2", "Foo[v3]"),
],
(("v1", "v2"), "No such attribute b")
; "record obj merge"
)]
#[test_case(2,
&[
("v1", "List[v2]"),
],
(("v1", "v2"), "Recursive type is prohibited.")
; "recursive type for lists"
)]
/// Test cases for invalid unifications.
fn test_invalid_unification(
variable_count: u32,
unify_pairs: &[(&'static str, &'static str)],
errornous_pair: ((&'static str, &'static str), &'static str),
) {
let mut env = TestEnvironment::new();
let mut mapping = HashMap::new();
for i in 1..=variable_count {
let v = env.unifier.get_fresh_var();
mapping.insert(format!("v{}", i), v.0);
}
// unification may have side effect when we do type resolution, so freeze the types
// before doing unification.
let mut pairs = Vec::new();
for (a, b) in unify_pairs.iter() {
let t1 = env.parse(a, &mapping);
let t2 = env.parse(b, &mapping);
pairs.push((t1, t2));
}
let (t1, t2) =
(env.parse(errornous_pair.0 .0, &mapping), env.parse(errornous_pair.0 .1, &mapping));
for (a, b) in pairs {
env.unifier.unify(a, b).unwrap();
}
assert_eq!(env.unifier.unify(t1, t2), Err(errornous_pair.1.to_string()));
}
#[test]
fn test_virtual() {
let mut env = TestEnvironment::new();
let int = env.parse("int", &HashMap::new());
let fun = env.unifier.add_ty(TypeEnum::TFunc(
FunSignature { args: vec![], ret: int, vars: HashMap::new() }.into(),
));
let bar = env.unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(5),
fields: [("f".to_string(), fun), ("a".to_string(), int)]
.iter()
.cloned()
.collect::<HashMap<_, _>>()
.into(),
params: HashMap::new().into(),
});
let v0 = env.unifier.get_fresh_var().0;
let v1 = env.unifier.get_fresh_var().0;
let a = env.unifier.add_ty(TypeEnum::TVirtual { ty: bar });
let b = env.unifier.add_ty(TypeEnum::TVirtual { ty: v0 });
let c = env.unifier.add_record([("f".to_string(), v1)].iter().cloned().collect());
env.unifier.unify(a, b).unwrap();
env.unifier.unify(b, c).unwrap();
assert!(env.unifier.eq(v1, fun));
let d = env.unifier.add_record([("a".to_string(), v1)].iter().cloned().collect());
assert_eq!(env.unifier.unify(b, d), Err("Cannot access field a for virtual type".to_string()));
let d = env.unifier.add_record([("b".to_string(), v1)].iter().cloned().collect());
assert_eq!(env.unifier.unify(b, d), Err("No such attribute b".to_string()));
}
#[test]
fn test_typevar_range() {
let mut env = TestEnvironment::new();
let int = env.parse("int", &HashMap::new());
let boolean = env.parse("bool", &HashMap::new());
let float = env.parse("float", &HashMap::new());
let int_list = env.parse("List[int]", &HashMap::new());
let float_list = env.parse("List[float]", &HashMap::new());
// unification between v and int
// where v in (int, bool)
let v = env.unifier.get_fresh_var_with_range(&[int, boolean]).0;
env.unifier.unify(int, v).unwrap();
// unification between v and List[int]
// where v in (int, bool)
let v = env.unifier.get_fresh_var_with_range(&[int, boolean]).0;
assert_eq!(
env.unifier.unify(int_list, v),
Err("Cannot unify type variable 3 with TList due to incompatible value range".to_string())
);
// unification between v and float
// where v in (int, bool)
let v = env.unifier.get_fresh_var_with_range(&[int, boolean]).0;
assert_eq!(
env.unifier.unify(float, v),
Err("Cannot unify type variable 4 with TObj due to incompatible value range".to_string())
);
let v1 = env.unifier.get_fresh_var_with_range(&[int, boolean]).0;
let v1_list = env.unifier.add_ty(TypeEnum::TList { ty: v1 });
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list]).0;
// unification between v and int
// where v in (int, List[v1]), v1 in (int, bool)
env.unifier.unify(int, v).unwrap();
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list]).0;
// unification between v and List[int]
// where v in (int, List[v1]), v1 in (int, bool)
env.unifier.unify(int_list, v).unwrap();
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list]).0;
// unification between v and List[float]
// where v in (int, List[v1]), v1 in (int, bool)
assert_eq!(
env.unifier.unify(float_list, v),
Err("Cannot unify type variable 8 with TList due to incompatible value range".to_string())
);
let a = env.unifier.get_fresh_var_with_range(&[int, float]).0;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0;
env.unifier.unify(a, b).unwrap();
env.unifier.unify(a, float).unwrap();
let a = env.unifier.get_fresh_var_with_range(&[int, float]).0;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0;
env.unifier.unify(a, b).unwrap();
assert_eq!(
env.unifier.unify(a, int),
Err("Cannot unify type variable 12 with TObj due to incompatible value range".into())
);
let a = env.unifier.get_fresh_var_with_range(&[int, float]).0;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0;
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a });
let a_list = env.unifier.get_fresh_var_with_range(&[a_list]).0;
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b });
let b_list = env.unifier.get_fresh_var_with_range(&[b_list]).0;
env.unifier.unify(a_list, b_list).unwrap();
let float_list = env.unifier.add_ty(TypeEnum::TList { ty: float });
env.unifier.unify(a_list, float_list).unwrap();
// previous unifications should not affect a and b
env.unifier.unify(a, int).unwrap();
let a = env.unifier.get_fresh_var_with_range(&[int, float]).0;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0;
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a });
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b });
env.unifier.unify(a_list, b_list).unwrap();
let int_list = env.unifier.add_ty(TypeEnum::TList { ty: int });
assert_eq!(
env.unifier.unify(a_list, int_list),
Err("Cannot unify type variable 19 with TObj due to incompatible value range".into())
);
let a = env.unifier.get_fresh_var_with_range(&[int, float]).0;
let b = env.unifier.get_fresh_var().0;
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a });
let a_list = env.unifier.get_fresh_var_with_range(&[a_list]).0;
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b });
env.unifier.unify(a_list, b_list).unwrap();
assert_eq!(
env.unifier.unify(b, boolean),
Err("Cannot unify type variable 21 with TObj due to incompatible value range".into())
);
}
#[test]
fn test_rigid_var() {
let mut env = TestEnvironment::new();
let a = env.unifier.get_fresh_rigid_var().0;
let b = env.unifier.get_fresh_rigid_var().0;
let x = env.unifier.get_fresh_var().0;
let list_a = env.unifier.add_ty(TypeEnum::TList { ty: a });
let list_x = env.unifier.add_ty(TypeEnum::TList { ty: x });
let int = env.parse("int", &HashMap::new());
let list_int = env.parse("List[int]", &HashMap::new());
assert_eq!(env.unifier.unify(a, b), Err("Cannot unify TRigidVar with TRigidVar".to_string()));
env.unifier.unify(list_a, list_x).unwrap();
assert_eq!(
env.unifier.unify(list_x, list_int),
Err("Cannot unify TObj with TRigidVar".to_string())
);
env.unifier.replace_rigid_var(a, int);
env.unifier.unify(list_x, list_int).unwrap();
}
#[test]
fn test_instantiation() {
let mut env = TestEnvironment::new();
let int = env.parse("int", &HashMap::new());
let boolean = env.parse("bool", &HashMap::new());
let float = env.parse("float", &HashMap::new());
let list_int = env.parse("List[int]", &HashMap::new());
let obj_map: HashMap<_, _> =
[(0usize, "int"), (1, "float"), (2, "bool")].iter().cloned().collect();
let v = env.unifier.get_fresh_var_with_range(&[int, boolean]).0;
let list_v = env.unifier.add_ty(TypeEnum::TList { ty: v });
let v1 = env.unifier.get_fresh_var_with_range(&[list_v, int]).0;
let v2 = env.unifier.get_fresh_var_with_range(&[list_int, float]).0;
let t = env.unifier.get_fresh_rigid_var().0;
let tuple = env.unifier.add_ty(TypeEnum::TTuple { ty: vec![v, v1, v2] });
let v3 = env.unifier.get_fresh_var_with_range(&[tuple, t]).0;
// t = TypeVar('t')
// v = TypeVar('v', int, bool)
// v1 = TypeVar('v1', 'list[v]', int)
// v2 = TypeVar('v2', 'list[int]', float)
// v3 = TypeVar('v3', tuple[v, v1, v2], t)
// what values can v3 take?
let types = env.unifier.get_instantiations(v3).unwrap();
let expected_types = indoc! {"
tuple[bool, int, float]
tuple[bool, int, list[int]]
tuple[bool, list[bool], float]
tuple[bool, list[bool], list[int]]
tuple[bool, list[int], float]
tuple[bool, list[int], list[int]]
tuple[int, int, float]
tuple[int, int, list[int]]
tuple[int, list[bool], float]
tuple[int, list[bool], list[int]]
tuple[int, list[int], float]
tuple[int, list[int], list[int]]
v5"
}
.split('\n')
.collect_vec();
let types = types
.iter()
.map(|ty| {
env.unifier.stringify(*ty, &mut |i| obj_map.get(&i).unwrap().to_string(), &mut |i| {
format!("v{}", i)
})
})
.sorted()
.collect_vec();
assert_eq!(expected_types, types);
}

View File

@ -0,0 +1,87 @@
use std::rc::Rc;
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub struct UnificationKey(usize);
pub struct UnificationTable<V> {
parents: Vec<usize>,
ranks: Vec<u32>,
values: Vec<V>,
}
impl<V> UnificationTable<V> {
pub fn new() -> UnificationTable<V> {
UnificationTable { parents: Vec::new(), ranks: Vec::new(), values: Vec::new() }
}
pub fn new_key(&mut self, v: V) -> UnificationKey {
let index = self.parents.len();
self.parents.push(index);
self.ranks.push(0);
self.values.push(v);
UnificationKey(index)
}
pub fn unify(&mut self, a: UnificationKey, b: UnificationKey) {
let mut a = self.find(a);
let mut b = self.find(b);
if a == b {
return;
}
if self.ranks[a] < self.ranks[b] {
std::mem::swap(&mut a, &mut b);
}
self.parents[b] = a;
if self.ranks[a] == self.ranks[b] {
self.ranks[a] += 1;
}
}
pub fn probe_value(&mut self, a: UnificationKey) -> &V {
let index = self.find(a);
&self.values[index]
}
pub fn set_value(&mut self, a: UnificationKey, v: V) {
let index = self.find(a);
self.values[index] = v;
}
pub fn unioned(&mut self, a: UnificationKey, b: UnificationKey) -> bool {
self.find(a) == self.find(b)
}
pub fn get_representative(&mut self, key: UnificationKey) -> UnificationKey {
UnificationKey(self.find(key))
}
fn find(&mut self, key: UnificationKey) -> usize {
let mut root = key.0;
let mut parent = self.parents[root];
while root != parent {
// a = parent.parent
let a = self.parents[parent];
// root.parent = parent.parent
self.parents[root] = a;
root = parent;
// parent = root.parent
parent = a;
}
parent
}
}
impl<V> UnificationTable<Rc<V>>
where
V: Clone,
{
pub fn get_send(&self) -> UnificationTable<V> {
let values = self.values.iter().map(|v| v.as_ref().clone()).collect();
UnificationTable { parents: self.parents.clone(), ranks: self.ranks.clone(), values }
}
pub fn from_send(table: &UnificationTable<V>) -> UnificationTable<Rc<V>> {
let values = table.values.iter().cloned().map(Rc::new).collect();
UnificationTable { parents: table.parents.clone(), ranks: table.ranks.clone(), values }
}
}

View File

@ -1,60 +0,0 @@
use std::collections::HashMap;
use std::rc::Rc;
#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
pub struct PrimitiveId(pub(crate) usize);
#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
pub struct ClassId(pub(crate) usize);
#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
pub struct ParamId(pub(crate) usize);
#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
pub struct VariableId(pub(crate) usize);
#[derive(PartialEq, Eq, Clone, Hash, Debug)]
pub enum TypeEnum {
BotType,
SelfType,
PrimitiveType(PrimitiveId),
ClassType(ClassId),
VirtualClassType(ClassId),
ParametricType(ParamId, Vec<Rc<TypeEnum>>),
TypeVariable(VariableId),
}
pub type Type = Rc<TypeEnum>;
#[derive(Clone)]
pub struct FnDef {
// we assume methods first argument to be SelfType,
// so the first argument is not contained here
pub args: Vec<Type>,
pub result: Option<Type>,
}
#[derive(Clone)]
pub struct TypeDef<'a> {
pub name: &'a str,
pub fields: HashMap<&'a str, Type>,
pub methods: HashMap<&'a str, FnDef>,
}
#[derive(Clone)]
pub struct ClassDef<'a> {
pub base: TypeDef<'a>,
pub parents: Vec<ClassId>,
}
#[derive(Clone)]
pub struct ParametricDef<'a> {
pub base: TypeDef<'a>,
pub params: Vec<VariableId>,
}
#[derive(Clone)]
pub struct VarDef<'a> {
pub name: &'a str,
pub bound: Vec<Type>,
}

View File

@ -4,6 +4,6 @@ in
pkgs.stdenv.mkDerivation { pkgs.stdenv.mkDerivation {
name = "nac3-env"; name = "nac3-env";
buildInputs = with pkgs; [ buildInputs = with pkgs; [
llvm_10 clang_10 cargo rustc libffi libxml2 clippy llvm_11 clang_11 cargo rustc libffi libxml2 clippy
]; ];
} }