forked from M-Labs/nac3
Merge pull request 'hm-inference' (#6) from hm-inference into master
Reviewed-on: M-Labs/nac3#6
This commit is contained in:
commit
f205a8282a
File diff suppressed because it is too large
Load Diff
20
README.md
20
README.md
|
@ -15,20 +15,20 @@ caller to specify which methods should be compiled). After type checking, the
|
|||
compiler would analyse the set of functions/classes that are used and perform
|
||||
code generation.
|
||||
|
||||
|
||||
Symbol resolver:
|
||||
- Str -> Nac3Type
|
||||
- Str -> Value
|
||||
|
||||
value could be integer values, boolean values, bytes (for memcpy), function ID
|
||||
(full name + concrete type)
|
||||
|
||||
## Current Plan
|
||||
|
||||
1. Write out the syntax-directed type checking/inferencing rules. Fix the rule
|
||||
for type variable instantiation.
|
||||
2. Update the library dependencies and rewrite some of the type checking code.
|
||||
3. Design the symbol resolver API.
|
||||
4. Move tests from code to external files to cleanup the code.
|
||||
Type checking:
|
||||
|
||||
- [x] Basic interface for symbol resolver.
|
||||
- [x] Track location information in context object (for diagnostics).
|
||||
- [ ] Refactor old expression and statement type inference code. (anto)
|
||||
- [ ] Error diagnostics utilities. (pca)
|
||||
- [ ] Move tests to external files, write scripts for testing. (pca)
|
||||
- [ ] Implement function type checking (instantiate bounded type parameters),
|
||||
loop unrolling, type inference for lists with virtual objects. (pca)
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -7,5 +7,13 @@ edition = "2018"
|
|||
[dependencies]
|
||||
num-bigint = "0.3"
|
||||
num-traits = "0.2"
|
||||
inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm10-0"] }
|
||||
inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm11-0"] }
|
||||
rustpython-parser = { git = "https://github.com/RustPython/RustPython", branch = "master" }
|
||||
itertools = "0.10.1"
|
||||
crossbeam = "0.8.1"
|
||||
parking_lot = "0.11.1"
|
||||
rayon = "1.5.1"
|
||||
|
||||
[dev-dependencies]
|
||||
test-case = "1.2.0"
|
||||
indoc = "1.0"
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
use_small_heuristics = "Max"
|
|
@ -0,0 +1,527 @@
|
|||
use std::{collections::HashMap, convert::TryInto, iter::once};
|
||||
|
||||
use super::{get_llvm_type, CodeGenContext};
|
||||
use crate::{
|
||||
symbol_resolver::SymbolValue,
|
||||
top_level::{DefinitionId, TopLevelDef},
|
||||
typecheck::typedef::{FunSignature, Type, TypeEnum},
|
||||
};
|
||||
use inkwell::{
|
||||
types::{BasicType, BasicTypeEnum},
|
||||
values::BasicValueEnum,
|
||||
AddressSpace,
|
||||
};
|
||||
use itertools::{chain, izip, zip, Itertools};
|
||||
use rustpython_parser::ast::{self, Boolop, Constant, Expr, ExprKind, Operator};
|
||||
|
||||
impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||
fn get_subst_key(&mut self, obj: Option<Type>, fun: &FunSignature) -> String {
|
||||
let mut vars = obj
|
||||
.map(|ty| {
|
||||
if let TypeEnum::TObj { params, .. } = &*self.unifier.get_ty(ty) {
|
||||
params.borrow().clone()
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
})
|
||||
.unwrap_or_default();
|
||||
vars.extend(fun.vars.iter());
|
||||
let sorted = vars.keys().sorted();
|
||||
sorted
|
||||
.map(|id| {
|
||||
self.unifier.stringify(vars[id], &mut |id| id.to_string(), &mut |id| id.to_string())
|
||||
})
|
||||
.join(", ")
|
||||
}
|
||||
|
||||
pub fn get_attr_index(&mut self, ty: Type, attr: &str) -> usize {
|
||||
let obj_id = match &*self.unifier.get_ty(ty) {
|
||||
TypeEnum::TObj { obj_id, .. } => *obj_id,
|
||||
// we cannot have other types, virtual type should be handled by function calls
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let def = &self.top_level.definitions.read()[obj_id.0];
|
||||
let index = if let TopLevelDef::Class { fields, .. } = &*def.read() {
|
||||
fields.iter().find_position(|x| x.0 == attr).unwrap().0
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
index
|
||||
}
|
||||
|
||||
fn gen_symbol_val(&mut self, val: &SymbolValue) -> BasicValueEnum<'ctx> {
|
||||
match val {
|
||||
SymbolValue::I32(v) => self.ctx.i32_type().const_int(*v as u64, true).into(),
|
||||
SymbolValue::I64(v) => self.ctx.i64_type().const_int(*v as u64, true).into(),
|
||||
SymbolValue::Bool(v) => self.ctx.bool_type().const_int(*v as u64, true).into(),
|
||||
SymbolValue::Double(v) => self.ctx.f64_type().const_float(*v).into(),
|
||||
SymbolValue::Tuple(ls) => {
|
||||
let vals = ls.iter().map(|v| self.gen_symbol_val(v)).collect_vec();
|
||||
let fields = vals.iter().map(|v| v.get_type()).collect_vec();
|
||||
let ty = self.ctx.struct_type(&fields, false);
|
||||
let ptr = self.builder.build_alloca(ty, "tuple");
|
||||
let zero = self.ctx.i32_type().const_zero();
|
||||
unsafe {
|
||||
for (i, val) in vals.into_iter().enumerate() {
|
||||
let p = ptr.const_in_bounds_gep(&[
|
||||
zero,
|
||||
self.ctx.i32_type().const_int(i as u64, false),
|
||||
]);
|
||||
self.builder.build_store(p, val);
|
||||
}
|
||||
}
|
||||
ptr.into()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_llvm_type(&mut self, ty: Type) -> BasicTypeEnum<'ctx> {
|
||||
get_llvm_type(self.ctx, &mut self.unifier, self.top_level, &mut self.type_cache, ty)
|
||||
}
|
||||
|
||||
fn gen_call(
|
||||
&mut self,
|
||||
obj: Option<(Type, BasicValueEnum<'ctx>)>,
|
||||
fun: (&FunSignature, DefinitionId),
|
||||
params: Vec<(Option<String>, BasicValueEnum<'ctx>)>,
|
||||
ret: Type,
|
||||
) -> Option<BasicValueEnum<'ctx>> {
|
||||
let key = self.get_subst_key(obj.map(|(a, _)| a), fun.0);
|
||||
let defs = self.top_level.definitions.read();
|
||||
let definition = defs.get(fun.1 .0).unwrap();
|
||||
let val = if let TopLevelDef::Function { instance_to_symbol, .. } = &*definition.read() {
|
||||
let symbol = instance_to_symbol.get(&key).unwrap_or_else(|| {
|
||||
// TODO: codegen for function that are not yet generated
|
||||
unimplemented!()
|
||||
});
|
||||
let fun_val = self.module.get_function(symbol).unwrap_or_else(|| {
|
||||
let params = fun.0.args.iter().map(|arg| self.get_llvm_type(arg.ty)).collect_vec();
|
||||
let fun_ty = if self.unifier.unioned(ret, self.primitives.none) {
|
||||
self.ctx.void_type().fn_type(¶ms, false)
|
||||
} else {
|
||||
self.get_llvm_type(ret).fn_type(¶ms, false)
|
||||
};
|
||||
self.module.add_function(symbol, fun_ty, None)
|
||||
});
|
||||
let mut keys = fun.0.args.clone();
|
||||
let mut mapping = HashMap::new();
|
||||
for (key, value) in params.into_iter() {
|
||||
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value);
|
||||
}
|
||||
// default value handling
|
||||
for k in keys.into_iter() {
|
||||
mapping.insert(k.name, self.gen_symbol_val(&k.default_value.unwrap()));
|
||||
}
|
||||
// reorder the parameters
|
||||
let params =
|
||||
fun.0.args.iter().map(|arg| mapping.remove(&arg.name).unwrap()).collect_vec();
|
||||
self.builder.build_call(fun_val, ¶ms, "call").try_as_basic_value().left()
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
val
|
||||
}
|
||||
|
||||
fn gen_const(&mut self, value: &Constant, ty: Type) -> BasicValueEnum<'ctx> {
|
||||
match value {
|
||||
Constant::Bool(v) => {
|
||||
assert!(self.unifier.unioned(ty, self.primitives.bool));
|
||||
let ty = self.ctx.bool_type();
|
||||
ty.const_int(if *v { 1 } else { 0 }, false).into()
|
||||
}
|
||||
Constant::Int(v) => {
|
||||
let ty = if self.unifier.unioned(ty, self.primitives.int32) {
|
||||
self.ctx.i32_type()
|
||||
} else if self.unifier.unioned(ty, self.primitives.int64) {
|
||||
self.ctx.i64_type()
|
||||
} else {
|
||||
unreachable!();
|
||||
};
|
||||
ty.const_int(v.try_into().unwrap(), false).into()
|
||||
}
|
||||
Constant::Float(v) => {
|
||||
assert!(self.unifier.unioned(ty, self.primitives.float));
|
||||
let ty = self.ctx.f64_type();
|
||||
ty.const_float(*v).into()
|
||||
}
|
||||
Constant::Tuple(v) => {
|
||||
let ty = self.unifier.get_ty(ty);
|
||||
let types =
|
||||
if let TypeEnum::TTuple { ty } = &*ty { ty.clone() } else { unreachable!() };
|
||||
let values = zip(types.into_iter(), v.iter())
|
||||
.map(|(ty, v)| self.gen_const(v, ty))
|
||||
.collect_vec();
|
||||
let types = values.iter().map(BasicValueEnum::get_type).collect_vec();
|
||||
let ty = self.ctx.struct_type(&types, false);
|
||||
ty.const_named_struct(&values).into()
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_int_ops(
|
||||
&mut self,
|
||||
op: &Operator,
|
||||
lhs: BasicValueEnum<'ctx>,
|
||||
rhs: BasicValueEnum<'ctx>,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
let (lhs, rhs) =
|
||||
if let (BasicValueEnum::IntValue(lhs), BasicValueEnum::IntValue(rhs)) = (lhs, rhs) {
|
||||
(lhs, rhs)
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
match op {
|
||||
Operator::Add => self.builder.build_int_add(lhs, rhs, "add").into(),
|
||||
Operator::Sub => self.builder.build_int_sub(lhs, rhs, "sub").into(),
|
||||
Operator::Mult => self.builder.build_int_mul(lhs, rhs, "mul").into(),
|
||||
Operator::Div => {
|
||||
let float = self.ctx.f64_type();
|
||||
let left = self.builder.build_signed_int_to_float(lhs, float, "i2f");
|
||||
let right = self.builder.build_signed_int_to_float(rhs, float, "i2f");
|
||||
self.builder.build_float_div(left, right, "fdiv").into()
|
||||
}
|
||||
Operator::Mod => self.builder.build_int_signed_rem(lhs, rhs, "mod").into(),
|
||||
Operator::BitOr => self.builder.build_or(lhs, rhs, "or").into(),
|
||||
Operator::BitXor => self.builder.build_xor(lhs, rhs, "xor").into(),
|
||||
Operator::BitAnd => self.builder.build_and(lhs, rhs, "and").into(),
|
||||
Operator::LShift => self.builder.build_left_shift(lhs, rhs, "lshift").into(),
|
||||
Operator::RShift => self.builder.build_right_shift(lhs, rhs, true, "rshift").into(),
|
||||
Operator::FloorDiv => self.builder.build_int_signed_div(lhs, rhs, "floordiv").into(),
|
||||
// special implementation?
|
||||
Operator::Pow => unimplemented!(),
|
||||
Operator::MatMult => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_float_ops(
|
||||
&mut self,
|
||||
op: &Operator,
|
||||
lhs: BasicValueEnum<'ctx>,
|
||||
rhs: BasicValueEnum<'ctx>,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
let (lhs, rhs) = if let (BasicValueEnum::FloatValue(lhs), BasicValueEnum::FloatValue(rhs)) =
|
||||
(lhs, rhs)
|
||||
{
|
||||
(lhs, rhs)
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
match op {
|
||||
Operator::Add => self.builder.build_float_add(lhs, rhs, "fadd").into(),
|
||||
Operator::Sub => self.builder.build_float_sub(lhs, rhs, "fsub").into(),
|
||||
Operator::Mult => self.builder.build_float_mul(lhs, rhs, "fmul").into(),
|
||||
Operator::Div => self.builder.build_float_div(lhs, rhs, "fdiv").into(),
|
||||
Operator::Mod => self.builder.build_float_rem(lhs, rhs, "fmod").into(),
|
||||
Operator::FloorDiv => {
|
||||
let div = self.builder.build_float_div(lhs, rhs, "fdiv");
|
||||
let floor_intrinsic =
|
||||
self.module.get_function("llvm.floor.f64").unwrap_or_else(|| {
|
||||
let float = self.ctx.f64_type();
|
||||
let fn_type = float.fn_type(&[float.into()], false);
|
||||
self.module.add_function("llvm.floor.f64", fn_type, None)
|
||||
});
|
||||
self.builder
|
||||
.build_call(floor_intrinsic, &[div.into()], "floor")
|
||||
.try_as_basic_value()
|
||||
.left()
|
||||
.unwrap()
|
||||
}
|
||||
// special implementation?
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn gen_expr(&mut self, expr: &Expr<Option<Type>>) -> BasicValueEnum<'ctx> {
|
||||
let zero = self.ctx.i32_type().const_int(0, false);
|
||||
match &expr.node {
|
||||
ExprKind::Constant { value, .. } => {
|
||||
let ty = expr.custom.unwrap();
|
||||
self.gen_const(value, ty)
|
||||
}
|
||||
ExprKind::Name { id, .. } => {
|
||||
let ptr = self.var_assignment.get(id).unwrap();
|
||||
let primitives = &self.primitives;
|
||||
// we should only dereference primitive types
|
||||
if [primitives.int32, primitives.int64, primitives.float, primitives.bool]
|
||||
.contains(&self.unifier.get_representative(expr.custom.unwrap()))
|
||||
{
|
||||
self.builder.build_load(*ptr, "load")
|
||||
} else {
|
||||
(*ptr).into()
|
||||
}
|
||||
}
|
||||
ExprKind::List { elts, .. } => {
|
||||
// this shall be optimized later for constant primitive lists...
|
||||
// we should use memcpy for that instead of generating thousands of stores
|
||||
let elements = elts.iter().map(|x| self.gen_expr(x)).collect_vec();
|
||||
let ty = if elements.is_empty() {
|
||||
self.ctx.i32_type().into()
|
||||
} else {
|
||||
elements[0].get_type()
|
||||
};
|
||||
let arr_ptr = self.builder.build_array_alloca(
|
||||
ty,
|
||||
self.ctx.i32_type().const_int(elements.len() as u64, false),
|
||||
"tmparr",
|
||||
);
|
||||
let arr_ty = self.ctx.struct_type(
|
||||
&[self.ctx.i32_type().into(), ty.ptr_type(AddressSpace::Generic).into()],
|
||||
false,
|
||||
);
|
||||
let arr_str_ptr = self.builder.build_alloca(arr_ty, "tmparrstr");
|
||||
unsafe {
|
||||
self.builder.build_store(
|
||||
arr_str_ptr.const_in_bounds_gep(&[zero, zero]),
|
||||
self.ctx.i32_type().const_int(elements.len() as u64, false),
|
||||
);
|
||||
self.builder.build_store(
|
||||
arr_str_ptr
|
||||
.const_in_bounds_gep(&[zero, self.ctx.i32_type().const_int(1, false)]),
|
||||
arr_ptr,
|
||||
);
|
||||
let arr_offset = self.ctx.i32_type().const_int(1, false);
|
||||
for (i, v) in elements.iter().enumerate() {
|
||||
let ptr = self.builder.build_in_bounds_gep(
|
||||
arr_ptr,
|
||||
&[zero, arr_offset, self.ctx.i32_type().const_int(i as u64, false)],
|
||||
"arr_element",
|
||||
);
|
||||
self.builder.build_store(ptr, *v);
|
||||
}
|
||||
}
|
||||
arr_str_ptr.into()
|
||||
}
|
||||
ExprKind::Tuple { elts, .. } => {
|
||||
let element_val = elts.iter().map(|x| self.gen_expr(x)).collect_vec();
|
||||
let element_ty = element_val.iter().map(BasicValueEnum::get_type).collect_vec();
|
||||
let tuple_ty = self.ctx.struct_type(&element_ty, false);
|
||||
let tuple_ptr = self.builder.build_alloca(tuple_ty, "tuple");
|
||||
for (i, v) in element_val.into_iter().enumerate() {
|
||||
unsafe {
|
||||
let ptr = tuple_ptr.const_in_bounds_gep(&[
|
||||
zero,
|
||||
self.ctx.i32_type().const_int(i as u64, false),
|
||||
]);
|
||||
self.builder.build_store(ptr, v);
|
||||
}
|
||||
}
|
||||
tuple_ptr.into()
|
||||
}
|
||||
ExprKind::Attribute { value, attr, .. } => {
|
||||
// note that we would handle class methods directly in calls
|
||||
let index = self.get_attr_index(value.custom.unwrap(), attr);
|
||||
let val = self.gen_expr(value);
|
||||
let ptr = if let BasicValueEnum::PointerValue(v) = val {
|
||||
v
|
||||
} else {
|
||||
unreachable!();
|
||||
};
|
||||
unsafe {
|
||||
let ptr = ptr.const_in_bounds_gep(&[
|
||||
zero,
|
||||
self.ctx.i32_type().const_int(index as u64, false),
|
||||
]);
|
||||
self.builder.build_load(ptr, "field")
|
||||
}
|
||||
}
|
||||
ExprKind::BoolOp { op, values } => {
|
||||
// requires conditional branches for short-circuiting...
|
||||
let left = if let BasicValueEnum::IntValue(left) = self.gen_expr(&values[0]) {
|
||||
left
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
let current = self.builder.get_insert_block().unwrap().get_parent().unwrap();
|
||||
let a_bb = self.ctx.append_basic_block(current, "a");
|
||||
let b_bb = self.ctx.append_basic_block(current, "b");
|
||||
let cont_bb = self.ctx.append_basic_block(current, "cont");
|
||||
self.builder.build_conditional_branch(left, a_bb, b_bb);
|
||||
let (a, b) = match op {
|
||||
Boolop::Or => {
|
||||
self.builder.position_at_end(a_bb);
|
||||
let a = self.ctx.bool_type().const_int(1, false);
|
||||
self.builder.build_unconditional_branch(cont_bb);
|
||||
self.builder.position_at_end(b_bb);
|
||||
let b = if let BasicValueEnum::IntValue(b) = self.gen_expr(&values[1]) {
|
||||
b
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
self.builder.build_unconditional_branch(cont_bb);
|
||||
(a, b)
|
||||
}
|
||||
Boolop::And => {
|
||||
self.builder.position_at_end(a_bb);
|
||||
let a = if let BasicValueEnum::IntValue(a) = self.gen_expr(&values[1]) {
|
||||
a
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
self.builder.build_unconditional_branch(cont_bb);
|
||||
self.builder.position_at_end(b_bb);
|
||||
let b = self.ctx.bool_type().const_int(0, false);
|
||||
self.builder.build_unconditional_branch(cont_bb);
|
||||
(a, b)
|
||||
}
|
||||
};
|
||||
self.builder.position_at_end(cont_bb);
|
||||
let phi = self.builder.build_phi(self.ctx.bool_type(), "phi");
|
||||
phi.add_incoming(&[(&a, a_bb), (&b, b_bb)]);
|
||||
phi.as_basic_value()
|
||||
}
|
||||
ExprKind::BinOp { op, left, right } => {
|
||||
let ty1 = self.unifier.get_representative(left.custom.unwrap());
|
||||
let ty2 = self.unifier.get_representative(right.custom.unwrap());
|
||||
let left = self.gen_expr(left);
|
||||
let right = self.gen_expr(right);
|
||||
|
||||
// we can directly compare the types, because we've got their representatives
|
||||
// which would be unchanged until further unification, which we would never do
|
||||
// when doing code generation for function instances
|
||||
if ty1 == ty2 && [self.primitives.int32, self.primitives.int64].contains(&ty1) {
|
||||
self.gen_int_ops(op, left, right)
|
||||
} else if ty1 == ty2 && self.primitives.float == ty1 {
|
||||
self.gen_float_ops(op, left, right)
|
||||
} else {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
ExprKind::UnaryOp { op, operand } => {
|
||||
let ty = self.unifier.get_representative(operand.custom.unwrap());
|
||||
let val = self.gen_expr(operand);
|
||||
if ty == self.primitives.bool {
|
||||
let val =
|
||||
if let BasicValueEnum::IntValue(val) = val { val } else { unreachable!() };
|
||||
match op {
|
||||
ast::Unaryop::Invert | ast::Unaryop::Not => {
|
||||
self.builder.build_not(val, "not").into()
|
||||
}
|
||||
_ => val.into(),
|
||||
}
|
||||
} else if [self.primitives.int32, self.primitives.int64].contains(&ty) {
|
||||
let val =
|
||||
if let BasicValueEnum::IntValue(val) = val { val } else { unreachable!() };
|
||||
match op {
|
||||
ast::Unaryop::USub => self.builder.build_int_neg(val, "neg").into(),
|
||||
ast::Unaryop::Invert => self.builder.build_not(val, "not").into(),
|
||||
ast::Unaryop::Not => self
|
||||
.builder
|
||||
.build_int_compare(
|
||||
inkwell::IntPredicate::EQ,
|
||||
val,
|
||||
val.get_type().const_zero(),
|
||||
"not",
|
||||
)
|
||||
.into(),
|
||||
_ => val.into(),
|
||||
}
|
||||
} else if ty == self.primitives.float {
|
||||
let val = if let BasicValueEnum::FloatValue(val) = val {
|
||||
val
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
match op {
|
||||
ast::Unaryop::USub => self.builder.build_float_neg(val, "neg").into(),
|
||||
ast::Unaryop::Not => self
|
||||
.builder
|
||||
.build_float_compare(
|
||||
inkwell::FloatPredicate::OEQ,
|
||||
val,
|
||||
val.get_type().const_zero(),
|
||||
"not",
|
||||
)
|
||||
.into(),
|
||||
_ => val.into(),
|
||||
}
|
||||
} else {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
ExprKind::Compare { left, ops, comparators } => {
|
||||
izip!(
|
||||
chain(once(left.as_ref()), comparators.iter()),
|
||||
comparators.iter(),
|
||||
ops.iter(),
|
||||
)
|
||||
.fold(None, |prev, (lhs, rhs, op)| {
|
||||
let ty = self.unifier.get_representative(lhs.custom.unwrap());
|
||||
let current =
|
||||
if [self.primitives.int32, self.primitives.int64, self.primitives.bool]
|
||||
.contains(&ty)
|
||||
{
|
||||
let (lhs, rhs) = if let (
|
||||
BasicValueEnum::IntValue(lhs),
|
||||
BasicValueEnum::IntValue(rhs),
|
||||
) = (self.gen_expr(lhs), self.gen_expr(rhs))
|
||||
{
|
||||
(lhs, rhs)
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
let op = match op {
|
||||
ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::IntPredicate::EQ,
|
||||
ast::Cmpop::NotEq => inkwell::IntPredicate::NE,
|
||||
ast::Cmpop::Lt => inkwell::IntPredicate::SLT,
|
||||
ast::Cmpop::LtE => inkwell::IntPredicate::SLE,
|
||||
ast::Cmpop::Gt => inkwell::IntPredicate::SGT,
|
||||
ast::Cmpop::GtE => inkwell::IntPredicate::SGE,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
self.builder.build_int_compare(op, lhs, rhs, "cmp")
|
||||
} else if ty == self.primitives.float {
|
||||
let (lhs, rhs) = if let (
|
||||
BasicValueEnum::FloatValue(lhs),
|
||||
BasicValueEnum::FloatValue(rhs),
|
||||
) = (self.gen_expr(lhs), self.gen_expr(rhs))
|
||||
{
|
||||
(lhs, rhs)
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
let op = match op {
|
||||
ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::FloatPredicate::OEQ,
|
||||
ast::Cmpop::NotEq => inkwell::FloatPredicate::ONE,
|
||||
ast::Cmpop::Lt => inkwell::FloatPredicate::OLT,
|
||||
ast::Cmpop::LtE => inkwell::FloatPredicate::OLE,
|
||||
ast::Cmpop::Gt => inkwell::FloatPredicate::OGT,
|
||||
ast::Cmpop::GtE => inkwell::FloatPredicate::OGE,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
self.builder.build_float_compare(op, lhs, rhs, "cmp")
|
||||
} else {
|
||||
unimplemented!()
|
||||
};
|
||||
prev.map(|v| self.builder.build_and(v, current, "cmp")).or(Some(current))
|
||||
})
|
||||
.unwrap()
|
||||
.into() // as there should be at least 1 element, it should never be none
|
||||
}
|
||||
ExprKind::IfExp { test, body, orelse } => {
|
||||
let test = if let BasicValueEnum::IntValue(test) = self.gen_expr(test) {
|
||||
test
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let current = self.builder.get_insert_block().unwrap().get_parent().unwrap();
|
||||
let then_bb = self.ctx.append_basic_block(current, "then");
|
||||
let else_bb = self.ctx.append_basic_block(current, "else");
|
||||
let cont_bb = self.ctx.append_basic_block(current, "cont");
|
||||
self.builder.build_conditional_branch(test, then_bb, else_bb);
|
||||
self.builder.position_at_end(then_bb);
|
||||
let a = self.gen_expr(body);
|
||||
self.builder.build_unconditional_branch(cont_bb);
|
||||
self.builder.position_at_end(else_bb);
|
||||
let b = self.gen_expr(orelse);
|
||||
self.builder.build_unconditional_branch(cont_bb);
|
||||
self.builder.position_at_end(cont_bb);
|
||||
let phi = self.builder.build_phi(a.get_type(), "ifexpr");
|
||||
phi.add_incoming(&[(&a, then_bb), (&b, else_bb)]);
|
||||
phi.as_basic_value()
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,343 @@
|
|||
use crate::{
|
||||
symbol_resolver::SymbolResolver,
|
||||
top_level::{TopLevelContext, TopLevelDef},
|
||||
typecheck::{
|
||||
type_inferencer::PrimitiveStore,
|
||||
typedef::{FunSignature, Type, TypeEnum, Unifier},
|
||||
},
|
||||
};
|
||||
use crossbeam::channel::{unbounded, Receiver, Sender};
|
||||
use inkwell::{
|
||||
basic_block::BasicBlock,
|
||||
builder::Builder,
|
||||
context::Context,
|
||||
module::Module,
|
||||
types::{BasicType, BasicTypeEnum},
|
||||
values::PointerValue,
|
||||
AddressSpace,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use parking_lot::{Condvar, Mutex};
|
||||
use rustpython_parser::ast::Stmt;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
};
|
||||
use std::thread;
|
||||
|
||||
mod expr;
|
||||
mod stmt;
|
||||
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
|
||||
pub struct CodeGenContext<'ctx, 'a> {
|
||||
pub ctx: &'ctx Context,
|
||||
pub builder: Builder<'ctx>,
|
||||
pub module: Module<'ctx>,
|
||||
pub top_level: &'a TopLevelContext,
|
||||
pub unifier: Unifier,
|
||||
pub resolver: Arc<dyn SymbolResolver>,
|
||||
pub var_assignment: HashMap<String, PointerValue<'ctx>>,
|
||||
pub type_cache: HashMap<Type, BasicTypeEnum<'ctx>>,
|
||||
pub primitives: PrimitiveStore,
|
||||
// stores the alloca for variables
|
||||
pub init_bb: BasicBlock<'ctx>,
|
||||
// where continue and break should go to respectively
|
||||
// the first one is the test_bb, and the second one is bb after the loop
|
||||
pub loop_bb: Option<(BasicBlock<'ctx>, BasicBlock<'ctx>)>,
|
||||
}
|
||||
|
||||
type Fp = Box<dyn Fn(&Module) + Send + Sync>;
|
||||
|
||||
pub struct WithCall {
|
||||
fp: Fp,
|
||||
}
|
||||
|
||||
impl WithCall {
|
||||
pub fn new(fp: Fp) -> WithCall {
|
||||
WithCall { fp }
|
||||
}
|
||||
|
||||
pub fn run<'ctx>(&self, m: &Module<'ctx>) {
|
||||
(self.fp)(m)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct WorkerRegistry {
|
||||
sender: Arc<Sender<Option<CodeGenTask>>>,
|
||||
receiver: Arc<Receiver<Option<CodeGenTask>>>,
|
||||
panicked: AtomicBool,
|
||||
task_count: Mutex<usize>,
|
||||
thread_count: usize,
|
||||
wait_condvar: Condvar,
|
||||
}
|
||||
|
||||
impl WorkerRegistry {
|
||||
pub fn create_workers(
|
||||
names: &[&str],
|
||||
top_level_ctx: Arc<TopLevelContext>,
|
||||
f: Arc<WithCall>,
|
||||
) -> (Arc<WorkerRegistry>, Vec<thread::JoinHandle<()>>) {
|
||||
let (sender, receiver) = unbounded();
|
||||
let task_count = Mutex::new(0);
|
||||
let wait_condvar = Condvar::new();
|
||||
|
||||
let registry = Arc::new(WorkerRegistry {
|
||||
sender: Arc::new(sender),
|
||||
receiver: Arc::new(receiver),
|
||||
thread_count: names.len(),
|
||||
panicked: AtomicBool::new(false),
|
||||
task_count,
|
||||
wait_condvar,
|
||||
});
|
||||
|
||||
let mut handles = Vec::new();
|
||||
for name in names.iter() {
|
||||
let top_level_ctx = top_level_ctx.clone();
|
||||
let registry = registry.clone();
|
||||
let registry2 = registry.clone();
|
||||
let name = name.to_string();
|
||||
let f = f.clone();
|
||||
let handle = thread::spawn(move || {
|
||||
registry.worker_thread(name, top_level_ctx, f);
|
||||
});
|
||||
let handle = thread::spawn(move || {
|
||||
if let Err(e) = handle.join() {
|
||||
if let Some(e) = e.downcast_ref::<&'static str>() {
|
||||
eprintln!("Got an error: {}", e);
|
||||
} else {
|
||||
eprintln!("Got an unknown error: {:?}", e);
|
||||
}
|
||||
registry2.panicked.store(true, Ordering::SeqCst);
|
||||
registry2.wait_condvar.notify_all();
|
||||
}
|
||||
});
|
||||
handles.push(handle);
|
||||
}
|
||||
(registry, handles)
|
||||
}
|
||||
|
||||
pub fn wait_tasks_complete(&self, handles: Vec<thread::JoinHandle<()>>) {
|
||||
{
|
||||
let mut count = self.task_count.lock();
|
||||
while *count != 0 {
|
||||
if self.panicked.load(Ordering::SeqCst) {
|
||||
break;
|
||||
}
|
||||
self.wait_condvar.wait(&mut count);
|
||||
}
|
||||
}
|
||||
for _ in 0..self.thread_count {
|
||||
self.sender.send(None).unwrap();
|
||||
}
|
||||
{
|
||||
let mut count = self.task_count.lock();
|
||||
while *count != self.thread_count {
|
||||
if self.panicked.load(Ordering::SeqCst) {
|
||||
break;
|
||||
}
|
||||
self.wait_condvar.wait(&mut count);
|
||||
}
|
||||
}
|
||||
for handle in handles {
|
||||
handle.join().unwrap();
|
||||
}
|
||||
if self.panicked.load(Ordering::SeqCst) {
|
||||
panic!("tasks panicked");
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_task(&self, task: CodeGenTask) {
|
||||
*self.task_count.lock() += 1;
|
||||
self.sender.send(Some(task)).unwrap();
|
||||
}
|
||||
|
||||
fn worker_thread(
|
||||
&self,
|
||||
module_name: String,
|
||||
top_level_ctx: Arc<TopLevelContext>,
|
||||
f: Arc<WithCall>,
|
||||
) {
|
||||
let context = Context::create();
|
||||
let mut builder = context.create_builder();
|
||||
let mut module = context.create_module(&module_name);
|
||||
|
||||
while let Some(task) = self.receiver.recv().unwrap() {
|
||||
let result = gen_func(&context, builder, module, task, top_level_ctx.clone());
|
||||
builder = result.0;
|
||||
module = result.1;
|
||||
*self.task_count.lock() -= 1;
|
||||
self.wait_condvar.notify_all();
|
||||
}
|
||||
|
||||
// do whatever...
|
||||
let mut lock = self.task_count.lock();
|
||||
module.verify().unwrap();
|
||||
f.run(&module);
|
||||
*lock += 1;
|
||||
self.wait_condvar.notify_all();
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CodeGenTask {
|
||||
pub subst: Vec<(Type, Type)>,
|
||||
pub symbol_name: String,
|
||||
pub signature: FunSignature,
|
||||
pub body: Vec<Stmt<Option<Type>>>,
|
||||
pub unifier_index: usize,
|
||||
pub resolver: Arc<dyn SymbolResolver + Send + Sync>,
|
||||
}
|
||||
|
||||
fn get_llvm_type<'ctx>(
|
||||
ctx: &'ctx Context,
|
||||
unifier: &mut Unifier,
|
||||
top_level: &TopLevelContext,
|
||||
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
|
||||
ty: Type,
|
||||
) -> BasicTypeEnum<'ctx> {
|
||||
use TypeEnum::*;
|
||||
// we assume the type cache should already contain primitive types,
|
||||
// and they should be passed by value instead of passing as pointer.
|
||||
type_cache.get(&unifier.get_representative(ty)).cloned().unwrap_or_else(|| {
|
||||
match &*unifier.get_ty(ty) {
|
||||
TObj { obj_id, fields, .. } => {
|
||||
// a struct with fields in the order of declaration
|
||||
let defs = top_level.definitions.read();
|
||||
let definition = defs.get(obj_id.0).unwrap();
|
||||
let ty = if let TopLevelDef::Class { fields: fields_list, .. } = &*definition.read()
|
||||
{
|
||||
let fields = fields.borrow();
|
||||
let fields = fields_list
|
||||
.iter()
|
||||
.map(|f| get_llvm_type(ctx, unifier, top_level, type_cache, fields[&f.0]))
|
||||
.collect_vec();
|
||||
ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into()
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
ty
|
||||
}
|
||||
TTuple { ty } => {
|
||||
// a struct with fields in the order present in the tuple
|
||||
let fields = ty
|
||||
.iter()
|
||||
.map(|ty| get_llvm_type(ctx, unifier, top_level, type_cache, *ty))
|
||||
.collect_vec();
|
||||
ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into()
|
||||
}
|
||||
TList { ty } => {
|
||||
// a struct with an integer and a pointer to an array
|
||||
let element_type = get_llvm_type(ctx, unifier, top_level, type_cache, *ty);
|
||||
let fields =
|
||||
[ctx.i32_type().into(), element_type.ptr_type(AddressSpace::Generic).into()];
|
||||
ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into()
|
||||
}
|
||||
TVirtual { .. } => unimplemented!(),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn gen_func<'ctx>(
|
||||
context: &'ctx Context,
|
||||
builder: Builder<'ctx>,
|
||||
module: Module<'ctx>,
|
||||
task: CodeGenTask,
|
||||
top_level_ctx: Arc<TopLevelContext>,
|
||||
) -> (Builder<'ctx>, Module<'ctx>) {
|
||||
// unwrap_or(0) is for unit tests without using rayon
|
||||
let (mut unifier, primitives) = {
|
||||
let unifiers = top_level_ctx.unifiers.read();
|
||||
let (unifier, primitives) = &unifiers[task.unifier_index];
|
||||
(Unifier::from_shared_unifier(unifier), *primitives)
|
||||
};
|
||||
|
||||
for (a, b) in task.subst.iter() {
|
||||
// this should be unification between variables and concrete types
|
||||
// and should not cause any problem...
|
||||
unifier.unify(*a, *b).unwrap();
|
||||
}
|
||||
|
||||
// rebuild primitive store with unique representatives
|
||||
let primitives = PrimitiveStore {
|
||||
int32: unifier.get_representative(primitives.int32),
|
||||
int64: unifier.get_representative(primitives.int64),
|
||||
float: unifier.get_representative(primitives.float),
|
||||
bool: unifier.get_representative(primitives.bool),
|
||||
none: unifier.get_representative(primitives.none),
|
||||
};
|
||||
|
||||
let mut type_cache: HashMap<_, _> = [
|
||||
(unifier.get_representative(primitives.int32), context.i32_type().into()),
|
||||
(unifier.get_representative(primitives.int64), context.i64_type().into()),
|
||||
(unifier.get_representative(primitives.float), context.f64_type().into()),
|
||||
(unifier.get_representative(primitives.bool), context.bool_type().into()),
|
||||
]
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
let params = task
|
||||
.signature
|
||||
.args
|
||||
.iter()
|
||||
.map(|arg| {
|
||||
get_llvm_type(&context, &mut unifier, top_level_ctx.as_ref(), &mut type_cache, arg.ty)
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
let fn_type = if unifier.unioned(task.signature.ret, primitives.none) {
|
||||
context.void_type().fn_type(¶ms, false)
|
||||
} else {
|
||||
get_llvm_type(
|
||||
&context,
|
||||
&mut unifier,
|
||||
top_level_ctx.as_ref(),
|
||||
&mut type_cache,
|
||||
task.signature.ret,
|
||||
)
|
||||
.fn_type(¶ms, false)
|
||||
};
|
||||
|
||||
let fn_val = module.add_function(&task.symbol_name, fn_type, None);
|
||||
let init_bb = context.append_basic_block(fn_val, "init");
|
||||
builder.position_at_end(init_bb);
|
||||
let body_bb = context.append_basic_block(fn_val, "body");
|
||||
|
||||
let mut var_assignment = HashMap::new();
|
||||
for (n, arg) in task.signature.args.iter().enumerate() {
|
||||
let param = fn_val.get_nth_param(n as u32).unwrap();
|
||||
let alloca = builder.build_alloca(
|
||||
get_llvm_type(&context, &mut unifier, top_level_ctx.as_ref(), &mut type_cache, arg.ty),
|
||||
&arg.name,
|
||||
);
|
||||
builder.build_store(alloca, param);
|
||||
var_assignment.insert(arg.name.clone(), alloca);
|
||||
}
|
||||
builder.build_unconditional_branch(body_bb);
|
||||
builder.position_at_end(body_bb);
|
||||
|
||||
let mut code_gen_context = CodeGenContext {
|
||||
ctx: &context,
|
||||
resolver: task.resolver,
|
||||
top_level: top_level_ctx.as_ref(),
|
||||
loop_bb: None,
|
||||
var_assignment,
|
||||
type_cache,
|
||||
primitives,
|
||||
init_bb,
|
||||
builder,
|
||||
module,
|
||||
unifier,
|
||||
};
|
||||
|
||||
for stmt in task.body.iter() {
|
||||
code_gen_context.gen_stmt(stmt);
|
||||
}
|
||||
|
||||
let CodeGenContext { builder, module, .. } = code_gen_context;
|
||||
|
||||
(builder, module)
|
||||
}
|
|
@ -0,0 +1,138 @@
|
|||
use super::CodeGenContext;
|
||||
use crate::typecheck::typedef::Type;
|
||||
use inkwell::values::{BasicValue, BasicValueEnum, PointerValue};
|
||||
use rustpython_parser::ast::{Expr, ExprKind, Stmt, StmtKind};
|
||||
|
||||
impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||
fn gen_var(&mut self, ty: Type) -> PointerValue<'ctx> {
|
||||
// put the alloca in init block
|
||||
let current = self.builder.get_insert_block().unwrap();
|
||||
// position before the last branching instruction...
|
||||
self.builder.position_before(&self.init_bb.get_last_instruction().unwrap());
|
||||
let ty = self.get_llvm_type(ty);
|
||||
let ptr = self.builder.build_alloca(ty, "tmp");
|
||||
self.builder.position_at_end(current);
|
||||
ptr
|
||||
}
|
||||
|
||||
fn parse_pattern(&mut self, pattern: &Expr<Option<Type>>) -> PointerValue<'ctx> {
|
||||
// very similar to gen_expr, but we don't do an extra load at the end
|
||||
// and we flatten nested tuples
|
||||
match &pattern.node {
|
||||
ExprKind::Name { id, .. } => {
|
||||
self.var_assignment.get(id).cloned().unwrap_or_else(|| {
|
||||
let ptr = self.gen_var(pattern.custom.unwrap());
|
||||
self.var_assignment.insert(id.clone(), ptr);
|
||||
ptr
|
||||
})
|
||||
}
|
||||
ExprKind::Attribute { value, attr, .. } => {
|
||||
let index = self.get_attr_index(value.custom.unwrap(), attr);
|
||||
let val = self.gen_expr(value);
|
||||
let ptr = if let BasicValueEnum::PointerValue(v) = val {
|
||||
v
|
||||
} else {
|
||||
unreachable!();
|
||||
};
|
||||
unsafe {
|
||||
ptr.const_in_bounds_gep(&[
|
||||
self.ctx.i32_type().const_zero(),
|
||||
self.ctx.i32_type().const_int(index as u64, false),
|
||||
])
|
||||
}
|
||||
}
|
||||
ExprKind::Subscript { .. } => unimplemented!(),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_assignment(&mut self, target: &Expr<Option<Type>>, value: BasicValueEnum<'ctx>) {
|
||||
if let ExprKind::Tuple { elts, .. } = &target.node {
|
||||
if let BasicValueEnum::PointerValue(ptr) = value {
|
||||
for (i, elt) in elts.iter().enumerate() {
|
||||
unsafe {
|
||||
let t = ptr.const_in_bounds_gep(&[
|
||||
self.ctx.i32_type().const_zero(),
|
||||
self.ctx.i32_type().const_int(i as u64, false),
|
||||
]);
|
||||
let v = self.builder.build_load(t, "tmpload");
|
||||
self.gen_assignment(elt, v);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
} else {
|
||||
let ptr = self.parse_pattern(target);
|
||||
self.builder.build_store(ptr, value);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn gen_stmt(&mut self, stmt: &Stmt<Option<Type>>) {
|
||||
match &stmt.node {
|
||||
StmtKind::Expr { value } => {
|
||||
self.gen_expr(&value);
|
||||
}
|
||||
StmtKind::Return { value } => {
|
||||
let value = value.as_ref().map(|v| self.gen_expr(&v));
|
||||
let value = value.as_ref().map(|v| v as &dyn BasicValue);
|
||||
self.builder.build_return(value);
|
||||
}
|
||||
StmtKind::AnnAssign { target, value, .. } => {
|
||||
if let Some(value) = value {
|
||||
let value = self.gen_expr(&value);
|
||||
self.gen_assignment(target, value);
|
||||
}
|
||||
}
|
||||
StmtKind::Assign { targets, value, .. } => {
|
||||
let value = self.gen_expr(&value);
|
||||
for target in targets.iter() {
|
||||
self.gen_assignment(target, value);
|
||||
}
|
||||
}
|
||||
StmtKind::Continue => {
|
||||
self.builder.build_unconditional_branch(self.loop_bb.unwrap().0);
|
||||
}
|
||||
StmtKind::Break => {
|
||||
self.builder.build_unconditional_branch(self.loop_bb.unwrap().1);
|
||||
}
|
||||
StmtKind::While { test, body, orelse } => {
|
||||
let current = self.builder.get_insert_block().unwrap().get_parent().unwrap();
|
||||
let test_bb = self.ctx.append_basic_block(current, "test");
|
||||
let body_bb = self.ctx.append_basic_block(current, "body");
|
||||
let cont_bb = self.ctx.append_basic_block(current, "cont");
|
||||
// if there is no orelse, we just go to cont_bb
|
||||
let orelse_bb = if orelse.is_empty() {
|
||||
cont_bb
|
||||
} else {
|
||||
self.ctx.append_basic_block(current, "orelse")
|
||||
};
|
||||
// store loop bb information and restore it later
|
||||
let loop_bb = self.loop_bb.replace((test_bb, cont_bb));
|
||||
self.builder.build_unconditional_branch(test_bb);
|
||||
self.builder.position_at_end(test_bb);
|
||||
let test = self.gen_expr(test);
|
||||
if let BasicValueEnum::IntValue(test) = test {
|
||||
self.builder.build_conditional_branch(test, body_bb, orelse_bb);
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
self.builder.position_at_end(body_bb);
|
||||
for stmt in body.iter() {
|
||||
self.gen_stmt(stmt);
|
||||
}
|
||||
self.builder.build_unconditional_branch(test_bb);
|
||||
if !orelse.is_empty() {
|
||||
self.builder.position_at_end(orelse_bb);
|
||||
for stmt in orelse.iter() {
|
||||
self.gen_stmt(stmt);
|
||||
}
|
||||
self.builder.build_unconditional_branch(cont_bb);
|
||||
}
|
||||
self.builder.position_at_end(cont_bb);
|
||||
self.loop_bb = loop_bb;
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,247 @@
|
|||
use super::{CodeGenTask, WorkerRegistry};
|
||||
use crate::{
|
||||
codegen::WithCall,
|
||||
location::Location,
|
||||
symbol_resolver::{SymbolResolver, SymbolValue},
|
||||
top_level::{DefinitionId, TopLevelContext},
|
||||
typecheck::{
|
||||
magic_methods::set_primitives_magic_methods,
|
||||
type_inferencer::{CodeLocation, FunctionData, Inferencer, PrimitiveStore},
|
||||
typedef::{CallId, FunSignature, FuncArg, Type, TypeEnum, Unifier},
|
||||
},
|
||||
};
|
||||
use indoc::indoc;
|
||||
use parking_lot::RwLock;
|
||||
use rustpython_parser::{ast::fold::Fold, parser::parse_program};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Resolver {
|
||||
id_to_type: HashMap<String, Type>,
|
||||
id_to_def: HashMap<String, DefinitionId>,
|
||||
class_names: HashMap<String, Type>,
|
||||
}
|
||||
|
||||
impl SymbolResolver for Resolver {
|
||||
fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option<Type> {
|
||||
self.id_to_type.get(str).cloned()
|
||||
}
|
||||
|
||||
fn get_symbol_value(&self, _: &str) -> Option<SymbolValue> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn get_symbol_location(&self, _: &str) -> Option<Location> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn get_identifier_def(&self, id: &str) -> Option<DefinitionId> {
|
||||
self.id_to_def.get(id).cloned()
|
||||
}
|
||||
}
|
||||
|
||||
struct TestEnvironment {
|
||||
pub unifier: Unifier,
|
||||
pub function_data: FunctionData,
|
||||
pub primitives: PrimitiveStore,
|
||||
pub id_to_name: HashMap<usize, String>,
|
||||
pub identifier_mapping: HashMap<String, Type>,
|
||||
pub virtual_checks: Vec<(Type, Type)>,
|
||||
pub calls: HashMap<CodeLocation, CallId>,
|
||||
pub top_level: TopLevelContext,
|
||||
}
|
||||
|
||||
impl TestEnvironment {
|
||||
pub fn basic_test_env() -> TestEnvironment {
|
||||
let mut unifier = Unifier::new();
|
||||
|
||||
let int32 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(0),
|
||||
fields: HashMap::new().into(),
|
||||
params: HashMap::new().into(),
|
||||
});
|
||||
let int64 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(1),
|
||||
fields: HashMap::new().into(),
|
||||
params: HashMap::new().into(),
|
||||
});
|
||||
let float = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(2),
|
||||
fields: HashMap::new().into(),
|
||||
params: HashMap::new().into(),
|
||||
});
|
||||
let bool = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(3),
|
||||
fields: HashMap::new().into(),
|
||||
params: HashMap::new().into(),
|
||||
});
|
||||
let none = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(4),
|
||||
fields: HashMap::new().into(),
|
||||
params: HashMap::new().into(),
|
||||
});
|
||||
let primitives = PrimitiveStore { int32, int64, float, bool, none };
|
||||
set_primitives_magic_methods(&primitives, &mut unifier);
|
||||
|
||||
let id_to_name = [
|
||||
(0, "int32".to_string()),
|
||||
(1, "int64".to_string()),
|
||||
(2, "float".to_string()),
|
||||
(3, "bool".to_string()),
|
||||
(4, "none".to_string()),
|
||||
]
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
let mut identifier_mapping = HashMap::new();
|
||||
identifier_mapping.insert("None".into(), none);
|
||||
|
||||
let resolver = Arc::new(Resolver {
|
||||
id_to_type: identifier_mapping.clone(),
|
||||
id_to_def: Default::default(),
|
||||
class_names: Default::default(),
|
||||
}) as Arc<dyn SymbolResolver + Send + Sync>;
|
||||
|
||||
TestEnvironment {
|
||||
unifier,
|
||||
top_level: TopLevelContext {
|
||||
definitions: Default::default(),
|
||||
unifiers: Default::default(),
|
||||
// conetexts: Default::default(),
|
||||
},
|
||||
function_data: FunctionData {
|
||||
resolver,
|
||||
bound_variables: Vec::new(),
|
||||
return_type: Some(primitives.int32),
|
||||
},
|
||||
primitives,
|
||||
id_to_name,
|
||||
identifier_mapping,
|
||||
virtual_checks: Vec::new(),
|
||||
calls: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_inferencer(&mut self) -> Inferencer {
|
||||
Inferencer {
|
||||
top_level: &self.top_level,
|
||||
function_data: &mut self.function_data,
|
||||
unifier: &mut self.unifier,
|
||||
variable_mapping: Default::default(),
|
||||
primitives: &mut self.primitives,
|
||||
virtual_checks: &mut self.virtual_checks,
|
||||
calls: &mut self.calls,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_primitives() {
|
||||
let mut env = TestEnvironment::basic_test_env();
|
||||
let threads = ["test"];
|
||||
let signature = FunSignature {
|
||||
args: vec![
|
||||
FuncArg { name: "a".to_string(), ty: env.primitives.int32, default_value: None },
|
||||
FuncArg { name: "b".to_string(), ty: env.primitives.int32, default_value: None },
|
||||
],
|
||||
ret: env.primitives.int32,
|
||||
vars: HashMap::new(),
|
||||
};
|
||||
|
||||
let mut inferencer = env.get_inferencer();
|
||||
inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32);
|
||||
inferencer.variable_mapping.insert("b".into(), inferencer.primitives.int32);
|
||||
let source = indoc! { "
|
||||
c = a + b
|
||||
d = a if c == 1 else 0
|
||||
return d
|
||||
"};
|
||||
let statements = parse_program(source).unwrap();
|
||||
|
||||
let statements = statements
|
||||
.into_iter()
|
||||
.map(|v| inferencer.fold_stmt(v))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.unwrap();
|
||||
let mut identifiers = vec!["a".to_string(), "b".to_string()];
|
||||
inferencer.check_block(&statements, &mut identifiers).unwrap();
|
||||
|
||||
let top_level = Arc::new(TopLevelContext {
|
||||
definitions: Default::default(),
|
||||
unifiers: Arc::new(RwLock::new(vec![(env.unifier.get_shared_unifier(), env.primitives)])),
|
||||
// conetexts: Default::default(),
|
||||
});
|
||||
let task = CodeGenTask {
|
||||
subst: Default::default(),
|
||||
symbol_name: "testing".to_string(),
|
||||
body: statements,
|
||||
unifier_index: 0,
|
||||
resolver: env.function_data.resolver.clone(),
|
||||
signature,
|
||||
};
|
||||
|
||||
let f = Arc::new(WithCall::new(Box::new(|module| {
|
||||
// the following IR is equivalent to
|
||||
// ```
|
||||
// ; ModuleID = 'test.ll'
|
||||
// source_filename = "test"
|
||||
//
|
||||
// ; Function Attrs: norecurse nounwind readnone
|
||||
// define i32 @testing(i32 %0, i32 %1) local_unnamed_addr #0 {
|
||||
// init:
|
||||
// %add = add i32 %1, %0
|
||||
// %cmp = icmp eq i32 %add, 1
|
||||
// %ifexpr = select i1 %cmp, i32 %0, i32 0
|
||||
// ret i32 %ifexpr
|
||||
// }
|
||||
//
|
||||
// attributes #0 = { norecurse nounwind readnone }
|
||||
// ```
|
||||
// after O2 optimization
|
||||
|
||||
let expected = indoc! {"
|
||||
; ModuleID = 'test'
|
||||
source_filename = \"test\"
|
||||
|
||||
define i32 @testing(i32 %0, i32 %1) {
|
||||
init:
|
||||
%a = alloca i32, align 4
|
||||
store i32 %0, i32* %a, align 4
|
||||
%b = alloca i32, align 4
|
||||
store i32 %1, i32* %b, align 4
|
||||
%tmp = alloca i32, align 4
|
||||
%tmp4 = alloca i32, align 4
|
||||
br label %body
|
||||
|
||||
body: ; preds = %init
|
||||
%load = load i32, i32* %a, align 4
|
||||
%load1 = load i32, i32* %b, align 4
|
||||
%add = add i32 %load, %load1
|
||||
store i32 %add, i32* %tmp, align 4
|
||||
%load2 = load i32, i32* %tmp, align 4
|
||||
%cmp = icmp eq i32 %load2, 1
|
||||
br i1 %cmp, label %then, label %else
|
||||
|
||||
then: ; preds = %body
|
||||
%load3 = load i32, i32* %a, align 4
|
||||
br label %cont
|
||||
|
||||
else: ; preds = %body
|
||||
br label %cont
|
||||
|
||||
cont: ; preds = %else, %then
|
||||
%ifexpr = phi i32 [ %load3, %then ], [ 0, %else ]
|
||||
store i32 %ifexpr, i32* %tmp4, align 4
|
||||
%load5 = load i32, i32* %tmp4, align 4
|
||||
ret i32 %load5
|
||||
}
|
||||
"}
|
||||
.trim();
|
||||
assert_eq!(expected, module.print_to_string().to_str().unwrap().trim());
|
||||
})));
|
||||
let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f);
|
||||
registry.add_task(task);
|
||||
registry.wait_tasks_complete(handles);
|
||||
}
|
|
@ -1,212 +0,0 @@
|
|||
use super::TopLevelContext;
|
||||
use crate::typedef::*;
|
||||
use std::boxed::Box;
|
||||
use std::collections::HashMap;
|
||||
|
||||
struct ContextStack<'a> {
|
||||
/// stack level, starts from 0
|
||||
level: u32,
|
||||
/// stack of variable definitions containing (id, def, level) where `def` is the original
|
||||
/// definition in `level-1`.
|
||||
var_defs: Vec<(usize, VarDef<'a>, u32)>,
|
||||
/// stack of symbol definitions containing (name, level) where `level` is the smallest level
|
||||
/// where the name is assigned a value
|
||||
sym_def: Vec<(&'a str, u32)>,
|
||||
}
|
||||
|
||||
pub struct InferenceContext<'a> {
|
||||
/// top level context
|
||||
top_level: TopLevelContext<'a>,
|
||||
|
||||
/// list of primitive instances
|
||||
primitives: Vec<Type>,
|
||||
/// list of variable instances
|
||||
variables: Vec<Type>,
|
||||
/// identifier to (type, readable) mapping.
|
||||
/// an identifier might be defined earlier but has no value (for some code path), thus not
|
||||
/// readable.
|
||||
sym_table: HashMap<&'a str, (Type, bool)>,
|
||||
/// resolution function reference, that may resolve unbounded identifiers to some type
|
||||
resolution_fn: Box<dyn FnMut(&str) -> Result<Type, String>>,
|
||||
/// stack
|
||||
stack: ContextStack<'a>,
|
||||
}
|
||||
|
||||
// non-trivial implementations here
|
||||
impl<'a> InferenceContext<'a> {
|
||||
/// return a new `InferenceContext` from `TopLevelContext` and resolution function.
|
||||
pub fn new(
|
||||
top_level: TopLevelContext,
|
||||
resolution_fn: Box<dyn FnMut(&str) -> Result<Type, String>>,
|
||||
) -> InferenceContext {
|
||||
let primitives = (0..top_level.primitive_defs.len())
|
||||
.map(|v| TypeEnum::PrimitiveType(PrimitiveId(v)).into())
|
||||
.collect();
|
||||
let variables = (0..top_level.var_defs.len())
|
||||
.map(|v| TypeEnum::TypeVariable(VariableId(v)).into())
|
||||
.collect();
|
||||
InferenceContext {
|
||||
top_level,
|
||||
primitives,
|
||||
variables,
|
||||
sym_table: HashMap::new(),
|
||||
resolution_fn,
|
||||
stack: ContextStack {
|
||||
level: 0,
|
||||
var_defs: Vec::new(),
|
||||
sym_def: Vec::new(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// execute the function with new scope.
|
||||
/// variable assignment would be limited within the scope (not readable outside), and type
|
||||
/// variable type guard would be limited within the scope.
|
||||
/// returns the list of variables assigned within the scope, and the result of the function
|
||||
pub fn with_scope<F, R>(&mut self, f: F) -> (Vec<&'a str>, R)
|
||||
where
|
||||
F: FnOnce(&mut Self) -> R,
|
||||
{
|
||||
self.stack.level += 1;
|
||||
let result = f(self);
|
||||
self.stack.level -= 1;
|
||||
while !self.stack.var_defs.is_empty() {
|
||||
let (_, _, level) = self.stack.var_defs.last().unwrap();
|
||||
if *level > self.stack.level {
|
||||
let (id, def, _) = self.stack.var_defs.pop().unwrap();
|
||||
self.top_level.var_defs[id] = def;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let mut poped_names = Vec::new();
|
||||
while !self.stack.sym_def.is_empty() {
|
||||
let (_, level) = self.stack.sym_def.last().unwrap();
|
||||
if *level > self.stack.level {
|
||||
let (name, _) = self.stack.sym_def.pop().unwrap();
|
||||
self.sym_table.remove(name).unwrap();
|
||||
poped_names.push(name);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
(poped_names, result)
|
||||
}
|
||||
|
||||
/// assign a type to an identifier.
|
||||
/// may return error if the identifier was defined but with different type
|
||||
pub fn assign(&mut self, name: &'a str, ty: Type) -> Result<Type, String> {
|
||||
if let Some((t, x)) = self.sym_table.get_mut(name) {
|
||||
if t == &ty {
|
||||
if !*x {
|
||||
self.stack.sym_def.push((name, self.stack.level));
|
||||
}
|
||||
*x = true;
|
||||
Ok(ty)
|
||||
} else {
|
||||
Err("different types".into())
|
||||
}
|
||||
} else {
|
||||
self.stack.sym_def.push((name, self.stack.level));
|
||||
self.sym_table.insert(name, (ty.clone(), true));
|
||||
Ok(ty)
|
||||
}
|
||||
}
|
||||
|
||||
/// check if an identifier is already defined
|
||||
pub fn defined(&self, name: &str) -> bool {
|
||||
self.sym_table.get(name).is_some()
|
||||
}
|
||||
|
||||
/// get the type of an identifier
|
||||
/// may return error if the identifier is not defined, and cannot be resolved with the
|
||||
/// resolution function.
|
||||
pub fn resolve(&mut self, name: &str) -> Result<Type, String> {
|
||||
if let Some((t, x)) = self.sym_table.get(name) {
|
||||
if *x {
|
||||
Ok(t.clone())
|
||||
} else {
|
||||
Err("may not have value".into())
|
||||
}
|
||||
} else {
|
||||
self.resolution_fn.as_mut()(name)
|
||||
}
|
||||
}
|
||||
|
||||
/// restrict the bound of a type variable by replacing its definition.
|
||||
/// used for implementing type guard
|
||||
pub fn restrict(&mut self, id: VariableId, mut def: VarDef<'a>) {
|
||||
std::mem::swap(self.top_level.var_defs.get_mut(id.0).unwrap(), &mut def);
|
||||
self.stack.var_defs.push((id.0, def, self.stack.level));
|
||||
}
|
||||
}
|
||||
|
||||
// trivial getters:
|
||||
impl<'a> InferenceContext<'a> {
|
||||
pub fn get_primitive(&self, id: PrimitiveId) -> Type {
|
||||
self.primitives.get(id.0).unwrap().clone()
|
||||
}
|
||||
pub fn get_variable(&self, id: VariableId) -> Type {
|
||||
self.variables.get(id.0).unwrap().clone()
|
||||
}
|
||||
|
||||
pub fn get_fn_def(&self, name: &str) -> Option<&FnDef> {
|
||||
self.top_level.fn_table.get(name)
|
||||
}
|
||||
pub fn get_primitive_def(&self, id: PrimitiveId) -> &TypeDef {
|
||||
self.top_level.primitive_defs.get(id.0).unwrap()
|
||||
}
|
||||
pub fn get_class_def(&self, id: ClassId) -> &ClassDef {
|
||||
self.top_level.class_defs.get(id.0).unwrap()
|
||||
}
|
||||
pub fn get_parametric_def(&self, id: ParamId) -> &ParametricDef {
|
||||
self.top_level.parametric_defs.get(id.0).unwrap()
|
||||
}
|
||||
pub fn get_variable_def(&self, id: VariableId) -> &VarDef {
|
||||
self.top_level.var_defs.get(id.0).unwrap()
|
||||
}
|
||||
pub fn get_type(&self, name: &str) -> Option<Type> {
|
||||
self.top_level.get_type(name)
|
||||
}
|
||||
}
|
||||
|
||||
impl TypeEnum {
|
||||
pub fn subst(&self, map: &HashMap<VariableId, Type>) -> TypeEnum {
|
||||
match self {
|
||||
TypeEnum::TypeVariable(id) => map.get(id).map(|v| v.as_ref()).unwrap_or(self).clone(),
|
||||
TypeEnum::ParametricType(id, params) => TypeEnum::ParametricType(
|
||||
*id,
|
||||
params
|
||||
.iter()
|
||||
.map(|v| v.as_ref().subst(map).into())
|
||||
.collect(),
|
||||
),
|
||||
_ => self.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_subst(&self, ctx: &InferenceContext) -> HashMap<VariableId, Type> {
|
||||
match self {
|
||||
TypeEnum::ParametricType(id, params) => {
|
||||
let vars = &ctx.get_parametric_def(*id).params;
|
||||
vars.iter()
|
||||
.zip(params)
|
||||
.map(|(v, p)| (*v, p.as_ref().clone().into()))
|
||||
.collect()
|
||||
}
|
||||
// if this proves to be slow, we can use option type
|
||||
_ => HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_base<'b: 'a, 'a>(&'a self, ctx: &'b InferenceContext) -> Option<&'b TypeDef> {
|
||||
match self {
|
||||
TypeEnum::PrimitiveType(id) => Some(ctx.get_primitive_def(*id)),
|
||||
TypeEnum::ClassType(id) | TypeEnum::VirtualClassType(id) => {
|
||||
Some(&ctx.get_class_def(*id).base)
|
||||
}
|
||||
TypeEnum::ParametricType(id, _) => Some(&ctx.get_parametric_def(*id).base),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,4 +0,0 @@
|
|||
mod inference_context;
|
||||
mod top_level_context;
|
||||
pub use inference_context::InferenceContext;
|
||||
pub use top_level_context::TopLevelContext;
|
|
@ -1,136 +0,0 @@
|
|||
use crate::typedef::*;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
|
||||
/// Structure for storing top-level type definitions.
|
||||
/// Used for collecting type signature from source code.
|
||||
/// Can be converted to `InferenceContext` for type inference in functions.
|
||||
pub struct TopLevelContext<'a> {
|
||||
/// List of primitive definitions.
|
||||
pub(super) primitive_defs: Vec<TypeDef<'a>>,
|
||||
/// List of class definitions.
|
||||
pub(super) class_defs: Vec<ClassDef<'a>>,
|
||||
/// List of parametric type definitions.
|
||||
pub(super) parametric_defs: Vec<ParametricDef<'a>>,
|
||||
/// List of type variable definitions.
|
||||
pub(super) var_defs: Vec<VarDef<'a>>,
|
||||
/// Function name to signature mapping.
|
||||
pub(super) fn_table: HashMap<&'a str, FnDef>,
|
||||
/// Type name to type mapping.
|
||||
pub(super) sym_table: HashMap<&'a str, Type>,
|
||||
|
||||
primitives: Vec<Type>,
|
||||
variables: Vec<Type>,
|
||||
}
|
||||
|
||||
impl<'a> TopLevelContext<'a> {
|
||||
pub fn new(primitive_defs: Vec<TypeDef<'a>>) -> TopLevelContext {
|
||||
let mut sym_table = HashMap::new();
|
||||
let mut primitives = Vec::new();
|
||||
for (i, t) in primitive_defs.iter().enumerate() {
|
||||
primitives.push(TypeEnum::PrimitiveType(PrimitiveId(i)).into());
|
||||
sym_table.insert(t.name, TypeEnum::PrimitiveType(PrimitiveId(i)).into());
|
||||
}
|
||||
TopLevelContext {
|
||||
primitive_defs,
|
||||
class_defs: Vec::new(),
|
||||
parametric_defs: Vec::new(),
|
||||
var_defs: Vec::new(),
|
||||
fn_table: HashMap::new(),
|
||||
sym_table,
|
||||
primitives,
|
||||
variables: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_class(&mut self, def: ClassDef<'a>) -> ClassId {
|
||||
self.sym_table.insert(
|
||||
def.base.name,
|
||||
TypeEnum::ClassType(ClassId(self.class_defs.len())).into(),
|
||||
);
|
||||
self.class_defs.push(def);
|
||||
ClassId(self.class_defs.len() - 1)
|
||||
}
|
||||
|
||||
pub fn add_parametric(&mut self, def: ParametricDef<'a>) -> ParamId {
|
||||
let params = def
|
||||
.params
|
||||
.iter()
|
||||
.map(|&v| Rc::new(TypeEnum::TypeVariable(v)))
|
||||
.collect();
|
||||
self.sym_table.insert(
|
||||
def.base.name,
|
||||
TypeEnum::ParametricType(ParamId(self.parametric_defs.len()), params).into(),
|
||||
);
|
||||
self.parametric_defs.push(def);
|
||||
ParamId(self.parametric_defs.len() - 1)
|
||||
}
|
||||
|
||||
pub fn add_variable(&mut self, def: VarDef<'a>) -> VariableId {
|
||||
self.sym_table.insert(
|
||||
def.name,
|
||||
TypeEnum::TypeVariable(VariableId(self.var_defs.len())).into(),
|
||||
);
|
||||
self.add_variable_private(def)
|
||||
}
|
||||
|
||||
pub fn add_variable_private(&mut self, def: VarDef<'a>) -> VariableId {
|
||||
self.var_defs.push(def);
|
||||
self.variables
|
||||
.push(TypeEnum::TypeVariable(VariableId(self.var_defs.len() - 1)).into());
|
||||
VariableId(self.var_defs.len() - 1)
|
||||
}
|
||||
|
||||
pub fn add_fn(&mut self, name: &'a str, def: FnDef) {
|
||||
self.fn_table.insert(name, def);
|
||||
}
|
||||
|
||||
pub fn get_fn_def(&self, name: &str) -> Option<&FnDef> {
|
||||
self.fn_table.get(name)
|
||||
}
|
||||
|
||||
pub fn get_primitive_def_mut(&mut self, id: PrimitiveId) -> &mut TypeDef<'a> {
|
||||
self.primitive_defs.get_mut(id.0).unwrap()
|
||||
}
|
||||
|
||||
pub fn get_primitive_def(&self, id: PrimitiveId) -> &TypeDef {
|
||||
self.primitive_defs.get(id.0).unwrap()
|
||||
}
|
||||
|
||||
pub fn get_class_def_mut(&mut self, id: ClassId) -> &mut ClassDef<'a> {
|
||||
self.class_defs.get_mut(id.0).unwrap()
|
||||
}
|
||||
|
||||
pub fn get_class_def(&self, id: ClassId) -> &ClassDef {
|
||||
self.class_defs.get(id.0).unwrap()
|
||||
}
|
||||
|
||||
pub fn get_parametric_def_mut(&mut self, id: ParamId) -> &mut ParametricDef<'a> {
|
||||
self.parametric_defs.get_mut(id.0).unwrap()
|
||||
}
|
||||
|
||||
pub fn get_parametric_def(&self, id: ParamId) -> &ParametricDef {
|
||||
self.parametric_defs.get(id.0).unwrap()
|
||||
}
|
||||
|
||||
pub fn get_variable_def_mut(&mut self, id: VariableId) -> &mut VarDef<'a> {
|
||||
self.var_defs.get_mut(id.0).unwrap()
|
||||
}
|
||||
|
||||
pub fn get_variable_def(&self, id: VariableId) -> &VarDef {
|
||||
self.var_defs.get(id.0).unwrap()
|
||||
}
|
||||
|
||||
pub fn get_primitive(&self, id: PrimitiveId) -> Type {
|
||||
self.primitives.get(id.0).unwrap().clone()
|
||||
}
|
||||
|
||||
pub fn get_variable(&self, id: VariableId) -> Type {
|
||||
self.variables.get(id.0).unwrap().clone()
|
||||
}
|
||||
|
||||
pub fn get_type(&self, name: &str) -> Option<Type> {
|
||||
// TODO: handle parametric types
|
||||
self.sym_table.get(name).cloned()
|
||||
}
|
||||
}
|
|
@ -1,922 +0,0 @@
|
|||
use crate::context::InferenceContext;
|
||||
use crate::inference_core::resolve_call;
|
||||
use crate::magic_methods::*;
|
||||
use crate::primitives::*;
|
||||
use crate::typedef::{Type, TypeEnum::*};
|
||||
use rustpython_parser::ast::{
|
||||
Comparison, Comprehension, ComprehensionKind, Expression, ExpressionType, Operator,
|
||||
UnaryOperator,
|
||||
};
|
||||
use std::convert::TryInto;
|
||||
|
||||
type ParserResult = Result<Option<Type>, String>;
|
||||
|
||||
pub fn infer_expr<'b: 'a, 'a>(
|
||||
ctx: &mut InferenceContext<'a>,
|
||||
expr: &'b Expression,
|
||||
) -> ParserResult {
|
||||
match &expr.node {
|
||||
ExpressionType::Number { value } => infer_constant(ctx, value),
|
||||
ExpressionType::Identifier { name } => infer_identifier(ctx, name),
|
||||
ExpressionType::List { elements } => infer_list(ctx, elements),
|
||||
ExpressionType::Tuple { elements } => infer_tuple(ctx, elements),
|
||||
ExpressionType::Attribute { value, name } => infer_attribute(ctx, value, name),
|
||||
ExpressionType::BoolOp { values, .. } => infer_bool_ops(ctx, values),
|
||||
ExpressionType::Binop { a, b, op } => infer_bin_ops(ctx, op, a, b),
|
||||
ExpressionType::Unop { op, a } => infer_unary_ops(ctx, op, a),
|
||||
ExpressionType::Compare { vals, ops } => infer_compare(ctx, vals, ops),
|
||||
ExpressionType::Call {
|
||||
args,
|
||||
function,
|
||||
keywords,
|
||||
} => {
|
||||
if !keywords.is_empty() {
|
||||
Err("keyword is not supported".into())
|
||||
} else {
|
||||
infer_call(ctx, &args, &function)
|
||||
}
|
||||
}
|
||||
ExpressionType::Subscript { a, b } => infer_subscript(ctx, a, b),
|
||||
ExpressionType::IfExpression { test, body, orelse } => {
|
||||
infer_if_expr(ctx, &test, &body, orelse)
|
||||
}
|
||||
ExpressionType::Comprehension { kind, generators } => match kind.as_ref() {
|
||||
ComprehensionKind::List { element } => {
|
||||
if generators.len() == 1 {
|
||||
infer_list_comprehension(ctx, element, &generators[0])
|
||||
} else {
|
||||
Err("only 1 generator statement is supported".into())
|
||||
}
|
||||
}
|
||||
_ => Err("only list comprehension is supported".into()),
|
||||
},
|
||||
ExpressionType::True | ExpressionType::False => Ok(Some(ctx.get_primitive(BOOL_TYPE))),
|
||||
_ => Err("not supported".into()),
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_constant(
|
||||
ctx: &mut InferenceContext,
|
||||
value: &rustpython_parser::ast::Number,
|
||||
) -> ParserResult {
|
||||
use rustpython_parser::ast::Number;
|
||||
match value {
|
||||
Number::Integer { value } => {
|
||||
let int32: Result<i32, _> = value.try_into();
|
||||
if int32.is_ok() {
|
||||
Ok(Some(ctx.get_primitive(INT32_TYPE)))
|
||||
} else {
|
||||
Err("integer out of range".into())
|
||||
}
|
||||
}
|
||||
Number::Float { .. } => Ok(Some(ctx.get_primitive(FLOAT_TYPE))),
|
||||
_ => Err("not supported".into()),
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_identifier(ctx: &mut InferenceContext, name: &str) -> ParserResult {
|
||||
Ok(Some(ctx.resolve(name)?))
|
||||
}
|
||||
|
||||
fn infer_list<'b: 'a, 'a>(
|
||||
ctx: &mut InferenceContext<'a>,
|
||||
elements: &'b [Expression],
|
||||
) -> ParserResult {
|
||||
if elements.is_empty() {
|
||||
return Ok(Some(ParametricType(LIST_TYPE, vec![BotType.into()]).into()));
|
||||
}
|
||||
|
||||
let mut types = elements.iter().map(|v| infer_expr(ctx, v));
|
||||
|
||||
let head = types.next().unwrap()?;
|
||||
if head.is_none() {
|
||||
return Err("list elements must have some type".into());
|
||||
}
|
||||
for v in types {
|
||||
// TODO: try virtual type...
|
||||
if v? != head {
|
||||
return Err("inhomogeneous list is not allowed".into());
|
||||
}
|
||||
}
|
||||
Ok(Some(ParametricType(LIST_TYPE, vec![head.unwrap()]).into()))
|
||||
}
|
||||
|
||||
fn infer_tuple<'b: 'a, 'a>(
|
||||
ctx: &mut InferenceContext<'a>,
|
||||
elements: &'b [Expression],
|
||||
) -> ParserResult {
|
||||
let types: Result<Option<Vec<_>>, String> =
|
||||
elements.iter().map(|v| infer_expr(ctx, v)).collect();
|
||||
if let Some(t) = types? {
|
||||
Ok(Some(ParametricType(TUPLE_TYPE, t).into()))
|
||||
} else {
|
||||
Err("tuple elements must have some type".into())
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_attribute<'a>(
|
||||
ctx: &mut InferenceContext<'a>,
|
||||
value: &'a Expression,
|
||||
name: &str,
|
||||
) -> ParserResult {
|
||||
let value = infer_expr(ctx, value)?.ok_or_else(|| "no value".to_string())?;
|
||||
if let TypeVariable(_) = value.as_ref() {
|
||||
return Err("no fields for type variable".into());
|
||||
}
|
||||
|
||||
value
|
||||
.get_base(ctx)
|
||||
.and_then(|b| b.fields.get(name).cloned())
|
||||
.map_or_else(|| Err("no such field".to_string()), |v| Ok(Some(v)))
|
||||
}
|
||||
|
||||
fn infer_bool_ops<'a>(ctx: &mut InferenceContext<'a>, values: &'a [Expression]) -> ParserResult {
|
||||
assert_eq!(values.len(), 2);
|
||||
let left = infer_expr(ctx, &values[0])?.ok_or_else(|| "no value".to_string())?;
|
||||
let right = infer_expr(ctx, &values[1])?.ok_or_else(|| "no value".to_string())?;
|
||||
|
||||
let b = ctx.get_primitive(BOOL_TYPE);
|
||||
if left == b && right == b {
|
||||
Ok(Some(b))
|
||||
} else {
|
||||
Err("bool operands must be bool".into())
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_bin_ops<'b: 'a, 'a>(
|
||||
ctx: &mut InferenceContext<'a>,
|
||||
op: &Operator,
|
||||
left: &'b Expression,
|
||||
right: &'b Expression,
|
||||
) -> ParserResult {
|
||||
let left = infer_expr(ctx, left)?.ok_or_else(|| "no value".to_string())?;
|
||||
let right = infer_expr(ctx, right)?.ok_or_else(|| "no value".to_string())?;
|
||||
let fun = binop_name(op);
|
||||
resolve_call(ctx, Some(left), fun, &[right])
|
||||
}
|
||||
|
||||
fn infer_unary_ops<'b: 'a, 'a>(
|
||||
ctx: &mut InferenceContext<'a>,
|
||||
op: &UnaryOperator,
|
||||
obj: &'b Expression,
|
||||
) -> ParserResult {
|
||||
let ty = infer_expr(ctx, obj)?.ok_or_else(|| "no value".to_string())?;
|
||||
if let UnaryOperator::Not = op {
|
||||
if ty == ctx.get_primitive(BOOL_TYPE) {
|
||||
Ok(Some(ty))
|
||||
} else {
|
||||
Err("logical not must be applied to bool".into())
|
||||
}
|
||||
} else {
|
||||
resolve_call(ctx, Some(ty), unaryop_name(op), &[])
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_compare<'b: 'a, 'a>(
|
||||
ctx: &mut InferenceContext<'a>,
|
||||
vals: &'b [Expression],
|
||||
ops: &'b [Comparison],
|
||||
) -> ParserResult {
|
||||
let types: Result<Option<Vec<_>>, _> = vals.iter().map(|v| infer_expr(ctx, v)).collect();
|
||||
let types = types?;
|
||||
if types.is_none() {
|
||||
return Err("comparison operands must have type".into());
|
||||
}
|
||||
let types = types.unwrap();
|
||||
let boolean = ctx.get_primitive(BOOL_TYPE);
|
||||
let left = &types[..types.len() - 1];
|
||||
let right = &types[1..];
|
||||
|
||||
for ((a, b), op) in left.iter().zip(right.iter()).zip(ops.iter()) {
|
||||
let fun = comparison_name(op).ok_or_else(|| "unsupported comparison".to_string())?;
|
||||
let ty = resolve_call(ctx, Some(a.clone()), fun, &[b.clone()])?;
|
||||
if ty.is_none() || ty.unwrap() != boolean {
|
||||
return Err("comparison result must be boolean".into());
|
||||
}
|
||||
}
|
||||
Ok(Some(boolean))
|
||||
}
|
||||
|
||||
fn infer_call<'b: 'a, 'a>(
|
||||
ctx: &mut InferenceContext<'a>,
|
||||
args: &'b [Expression],
|
||||
function: &'b Expression,
|
||||
) -> ParserResult {
|
||||
// TODO: special handling for int64 constant
|
||||
let types: Result<Option<Vec<_>>, _> = args.iter().map(|v| infer_expr(ctx, v)).collect();
|
||||
let types = types?;
|
||||
if types.is_none() {
|
||||
return Err("function params must have type".into());
|
||||
}
|
||||
|
||||
let (obj, fun) = match &function.node {
|
||||
ExpressionType::Identifier { name } => (None, name),
|
||||
ExpressionType::Attribute { value, name } => (
|
||||
Some(infer_expr(ctx, &value)?.ok_or_else(|| "no value".to_string())?),
|
||||
name,
|
||||
),
|
||||
_ => return Err("not supported".into()),
|
||||
};
|
||||
resolve_call(ctx, obj, fun.as_str(), &types.unwrap())
|
||||
}
|
||||
|
||||
fn infer_subscript<'b: 'a, 'a>(
|
||||
ctx: &mut InferenceContext<'a>,
|
||||
a: &'b Expression,
|
||||
b: &'b Expression,
|
||||
) -> ParserResult {
|
||||
let a = infer_expr(ctx, a)?.ok_or_else(|| "no value".to_string())?;
|
||||
let t = if let ParametricType(LIST_TYPE, ls) = a.as_ref() {
|
||||
ls[0].clone()
|
||||
} else {
|
||||
return Err("subscript is not supported for types other than list".into());
|
||||
};
|
||||
|
||||
match &b.node {
|
||||
ExpressionType::Slice { elements } => {
|
||||
let int32 = ctx.get_primitive(INT32_TYPE);
|
||||
let types: Result<Option<Vec<_>>, _> = elements
|
||||
.iter()
|
||||
.map(|v| {
|
||||
if let ExpressionType::None = v.node {
|
||||
Ok(Some(int32.clone()))
|
||||
} else {
|
||||
infer_expr(ctx, v)
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let types = types?.ok_or_else(|| "slice must have type".to_string())?;
|
||||
if types.iter().all(|v| v == &int32) {
|
||||
Ok(Some(a))
|
||||
} else {
|
||||
Err("slice must be int32 type".into())
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
let b = infer_expr(ctx, b)?.ok_or_else(|| "no value".to_string())?;
|
||||
if b == ctx.get_primitive(INT32_TYPE) {
|
||||
Ok(Some(t))
|
||||
} else {
|
||||
Err("index must be either slice or int32".into())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_if_expr<'b: 'a, 'a>(
|
||||
ctx: &mut InferenceContext<'a>,
|
||||
test: &'b Expression,
|
||||
body: &'b Expression,
|
||||
orelse: &'b Expression,
|
||||
) -> ParserResult {
|
||||
let test = infer_expr(ctx, test)?.ok_or_else(|| "no value".to_string())?;
|
||||
if test != ctx.get_primitive(BOOL_TYPE) {
|
||||
return Err("test should be bool".into());
|
||||
}
|
||||
|
||||
let body = infer_expr(ctx, body)?;
|
||||
let orelse = infer_expr(ctx, orelse)?;
|
||||
if body.as_ref() == orelse.as_ref() {
|
||||
Ok(body)
|
||||
} else {
|
||||
Err("divergent type".into())
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_simple_binding<'a: 'b, 'b>(
|
||||
ctx: &mut InferenceContext<'b>,
|
||||
name: &'a Expression,
|
||||
ty: Type,
|
||||
) -> Result<(), String> {
|
||||
match &name.node {
|
||||
ExpressionType::Identifier { name } => {
|
||||
if name == "_" {
|
||||
Ok(())
|
||||
} else if ctx.defined(name.as_str()) {
|
||||
Err("duplicated naming".into())
|
||||
} else {
|
||||
ctx.assign(name.as_str(), ty)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
ExpressionType::Tuple { elements } => {
|
||||
if let ParametricType(TUPLE_TYPE, ls) = ty.as_ref() {
|
||||
if elements.len() == ls.len() {
|
||||
for (a, b) in elements.iter().zip(ls.iter()) {
|
||||
infer_simple_binding(ctx, a, b.clone())?;
|
||||
}
|
||||
Ok(())
|
||||
} else {
|
||||
Err("different length".into())
|
||||
}
|
||||
} else {
|
||||
Err("not supported".into())
|
||||
}
|
||||
}
|
||||
_ => Err("not supported".into()),
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_list_comprehension<'b: 'a, 'a>(
|
||||
ctx: &mut InferenceContext<'a>,
|
||||
element: &'b Expression,
|
||||
comprehension: &'b Comprehension,
|
||||
) -> ParserResult {
|
||||
if comprehension.is_async {
|
||||
return Err("async is not supported".into());
|
||||
}
|
||||
|
||||
let iter = infer_expr(ctx, &comprehension.iter)?.ok_or_else(|| "no value".to_string())?;
|
||||
if let ParametricType(LIST_TYPE, ls) = iter.as_ref() {
|
||||
ctx.with_scope(|ctx| {
|
||||
infer_simple_binding(ctx, &comprehension.target, ls[0].clone())?;
|
||||
|
||||
let boolean = ctx.get_primitive(BOOL_TYPE);
|
||||
for test in comprehension.ifs.iter() {
|
||||
let result =
|
||||
infer_expr(ctx, test)?.ok_or_else(|| "no value in test".to_string())?;
|
||||
if result != boolean {
|
||||
return Err("test must be bool".into());
|
||||
}
|
||||
}
|
||||
let result = infer_expr(ctx, element)?.ok_or_else(|| "no value")?;
|
||||
Ok(Some(ParametricType(LIST_TYPE, vec![result]).into()))
|
||||
})
|
||||
.1
|
||||
} else {
|
||||
Err("iteration is supported for list only".into())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use crate::context::*;
|
||||
use crate::typedef::*;
|
||||
use rustpython_parser::parser::parse_expression;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
|
||||
fn get_inference_context(ctx: TopLevelContext) -> InferenceContext {
|
||||
InferenceContext::new(ctx, Box::new(|_| Err("unbounded identifier".into())))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_constants() {
|
||||
let ctx = basic_ctx();
|
||||
let mut ctx = get_inference_context(ctx);
|
||||
|
||||
let ast = parse_expression("123").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
|
||||
|
||||
let ast = parse_expression("2147483647").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
|
||||
|
||||
let ast = parse_expression("2147483648").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("integer out of range".into()));
|
||||
//
|
||||
// let ast = parse_expression("2147483648").unwrap();
|
||||
// let result = infer_expr(&mut ctx, &ast);
|
||||
// assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT64_TYPE));
|
||||
|
||||
// let ast = parse_expression("9223372036854775807").unwrap();
|
||||
// let result = infer_expr(&mut ctx, &ast);
|
||||
// assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT64_TYPE));
|
||||
|
||||
// let ast = parse_expression("9223372036854775808").unwrap();
|
||||
// let result = infer_expr(&mut ctx, &ast);
|
||||
// assert_eq!(result, Err("integer out of range".into()));
|
||||
|
||||
let ast = parse_expression("123.456").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(FLOAT_TYPE));
|
||||
|
||||
let ast = parse_expression("True").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE));
|
||||
|
||||
let ast = parse_expression("False").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_identifier() {
|
||||
let ctx = basic_ctx();
|
||||
let mut ctx = get_inference_context(ctx);
|
||||
ctx.assign("abc", ctx.get_primitive(INT32_TYPE)).unwrap();
|
||||
|
||||
let ast = parse_expression("abc").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
|
||||
|
||||
let ast = parse_expression("ab").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("unbounded identifier".into()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_list() {
|
||||
let mut ctx = basic_ctx();
|
||||
ctx.add_fn(
|
||||
"foo",
|
||||
FnDef {
|
||||
args: vec![],
|
||||
result: None,
|
||||
},
|
||||
);
|
||||
let mut ctx = get_inference_context(ctx);
|
||||
ctx.assign("abc", ctx.get_primitive(INT32_TYPE)).unwrap();
|
||||
// def is reserved...
|
||||
ctx.assign("efg", ctx.get_primitive(INT32_TYPE)).unwrap();
|
||||
ctx.assign("xyz", ctx.get_primitive(FLOAT_TYPE)).unwrap();
|
||||
|
||||
let ast = parse_expression("[]").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(
|
||||
result.unwrap().unwrap(),
|
||||
ParametricType(LIST_TYPE, vec![BotType.into()]).into()
|
||||
);
|
||||
|
||||
let ast = parse_expression("[abc]").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(
|
||||
result.unwrap().unwrap(),
|
||||
ParametricType(LIST_TYPE, vec![ctx.get_primitive(INT32_TYPE)]).into()
|
||||
);
|
||||
|
||||
let ast = parse_expression("[abc, efg]").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(
|
||||
result.unwrap().unwrap(),
|
||||
ParametricType(LIST_TYPE, vec![ctx.get_primitive(INT32_TYPE)]).into()
|
||||
);
|
||||
|
||||
let ast = parse_expression("[abc, efg, xyz]").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("inhomogeneous list is not allowed".into()));
|
||||
|
||||
let ast = parse_expression("[foo()]").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("list elements must have some type".into()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tuple() {
|
||||
let mut ctx = basic_ctx();
|
||||
ctx.add_fn(
|
||||
"foo",
|
||||
FnDef {
|
||||
args: vec![],
|
||||
result: None,
|
||||
},
|
||||
);
|
||||
let mut ctx = get_inference_context(ctx);
|
||||
ctx.assign("abc", ctx.get_primitive(INT32_TYPE)).unwrap();
|
||||
ctx.assign("efg", ctx.get_primitive(FLOAT_TYPE)).unwrap();
|
||||
|
||||
let ast = parse_expression("(abc, efg)").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(
|
||||
result.unwrap().unwrap(),
|
||||
ParametricType(
|
||||
TUPLE_TYPE,
|
||||
vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(FLOAT_TYPE)]
|
||||
)
|
||||
.into()
|
||||
);
|
||||
|
||||
let ast = parse_expression("(abc, efg, foo())").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("tuple elements must have some type".into()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_attribute() {
|
||||
let mut ctx = basic_ctx();
|
||||
ctx.add_fn(
|
||||
"none",
|
||||
FnDef {
|
||||
args: vec![],
|
||||
result: None,
|
||||
},
|
||||
);
|
||||
let int32 = ctx.get_primitive(INT32_TYPE);
|
||||
let float = ctx.get_primitive(FLOAT_TYPE);
|
||||
|
||||
let foo = ctx.add_class(ClassDef {
|
||||
base: TypeDef {
|
||||
name: "Foo",
|
||||
fields: HashMap::new(),
|
||||
methods: HashMap::new(),
|
||||
},
|
||||
parents: vec![],
|
||||
});
|
||||
let foo_def = ctx.get_class_def_mut(foo);
|
||||
foo_def.base.fields.insert("a", int32.clone());
|
||||
foo_def.base.fields.insert("b", ClassType(foo).into());
|
||||
foo_def.base.fields.insert("c", int32.clone());
|
||||
|
||||
let bar = ctx.add_class(ClassDef {
|
||||
base: TypeDef {
|
||||
name: "Bar",
|
||||
fields: HashMap::new(),
|
||||
methods: HashMap::new(),
|
||||
},
|
||||
parents: vec![],
|
||||
});
|
||||
let bar_def = ctx.get_class_def_mut(bar);
|
||||
bar_def.base.fields.insert("a", int32);
|
||||
bar_def.base.fields.insert("b", ClassType(bar).into());
|
||||
bar_def.base.fields.insert("c", float);
|
||||
|
||||
let v0 = ctx.add_variable(VarDef {
|
||||
name: "v0",
|
||||
bound: vec![],
|
||||
});
|
||||
|
||||
let v1 = ctx.add_variable(VarDef {
|
||||
name: "v1",
|
||||
bound: vec![ClassType(foo).into(), ClassType(bar).into()],
|
||||
});
|
||||
|
||||
let mut ctx = get_inference_context(ctx);
|
||||
ctx.assign("foo", Rc::new(ClassType(foo))).unwrap();
|
||||
ctx.assign("bar", Rc::new(ClassType(bar))).unwrap();
|
||||
ctx.assign("foobar", Rc::new(VirtualClassType(foo)))
|
||||
.unwrap();
|
||||
ctx.assign("v0", ctx.get_variable(v0)).unwrap();
|
||||
ctx.assign("v1", ctx.get_variable(v1)).unwrap();
|
||||
ctx.assign("bot", Rc::new(BotType)).unwrap();
|
||||
|
||||
let ast = parse_expression("foo.a").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
|
||||
|
||||
let ast = parse_expression("foo.d").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("no such field".into()));
|
||||
|
||||
let ast = parse_expression("foobar.a").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
|
||||
|
||||
let ast = parse_expression("v0.a").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("no fields for type variable".into()));
|
||||
|
||||
let ast = parse_expression("v1.a").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("no fields for type variable".into()));
|
||||
|
||||
let ast = parse_expression("none().a").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("no value".into()));
|
||||
|
||||
let ast = parse_expression("bot.a").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("no such field".into()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bool_ops() {
|
||||
let mut ctx = basic_ctx();
|
||||
ctx.add_fn(
|
||||
"none",
|
||||
FnDef {
|
||||
args: vec![],
|
||||
result: None,
|
||||
},
|
||||
);
|
||||
let mut ctx = get_inference_context(ctx);
|
||||
|
||||
let ast = parse_expression("True and False").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE));
|
||||
|
||||
let ast = parse_expression("True and none()").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("no value".into()));
|
||||
|
||||
let ast = parse_expression("True and 123").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("bool operands must be bool".into()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bin_ops() {
|
||||
let mut ctx = basic_ctx();
|
||||
let v0 = ctx.add_variable(VarDef {
|
||||
name: "v0",
|
||||
bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(INT64_TYPE)],
|
||||
});
|
||||
let mut ctx = get_inference_context(ctx);
|
||||
ctx.assign("a", TypeVariable(v0).into()).unwrap();
|
||||
|
||||
let ast = parse_expression("1 + 2 + 3").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
|
||||
|
||||
let ast = parse_expression("a + a + a").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("not supported".into()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unary_ops() {
|
||||
let mut ctx = basic_ctx();
|
||||
let v0 = ctx.add_variable(VarDef {
|
||||
name: "v0",
|
||||
bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(INT64_TYPE)],
|
||||
});
|
||||
let mut ctx = get_inference_context(ctx);
|
||||
ctx.assign("a", TypeVariable(v0).into()).unwrap();
|
||||
|
||||
let ast = parse_expression("-(123)").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
|
||||
|
||||
let ast = parse_expression("-a").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("not supported".into()));
|
||||
|
||||
let ast = parse_expression("not True").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE));
|
||||
|
||||
let ast = parse_expression("not (1)").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("logical not must be applied to bool".into()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compare() {
|
||||
let mut ctx = basic_ctx();
|
||||
let v0 = ctx.add_variable(VarDef {
|
||||
name: "v0",
|
||||
bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(INT64_TYPE)],
|
||||
});
|
||||
let mut ctx = get_inference_context(ctx);
|
||||
ctx.assign("a", TypeVariable(v0).into()).unwrap();
|
||||
|
||||
let ast = parse_expression("a == a == a").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("not supported".into()));
|
||||
|
||||
let ast = parse_expression("a == a == 1").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("not supported".into()));
|
||||
|
||||
let ast = parse_expression("True > False").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("no such function".into()));
|
||||
|
||||
let ast = parse_expression("True in False").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("unsupported comparison".into()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_call() {
|
||||
let mut ctx = basic_ctx();
|
||||
ctx.add_fn(
|
||||
"none",
|
||||
FnDef {
|
||||
args: vec![],
|
||||
result: None,
|
||||
},
|
||||
);
|
||||
|
||||
let foo = ctx.add_class(ClassDef {
|
||||
base: TypeDef {
|
||||
name: "Foo",
|
||||
fields: HashMap::new(),
|
||||
methods: HashMap::new(),
|
||||
},
|
||||
parents: vec![],
|
||||
});
|
||||
let foo_def = ctx.get_class_def_mut(foo);
|
||||
foo_def.base.methods.insert(
|
||||
"a",
|
||||
FnDef {
|
||||
args: vec![],
|
||||
result: Some(Rc::new(ClassType(foo))),
|
||||
},
|
||||
);
|
||||
|
||||
let bar = ctx.add_class(ClassDef {
|
||||
base: TypeDef {
|
||||
name: "Bar",
|
||||
fields: HashMap::new(),
|
||||
methods: HashMap::new(),
|
||||
},
|
||||
parents: vec![],
|
||||
});
|
||||
let bar_def = ctx.get_class_def_mut(bar);
|
||||
bar_def.base.methods.insert(
|
||||
"a",
|
||||
FnDef {
|
||||
args: vec![],
|
||||
result: Some(Rc::new(ClassType(bar))),
|
||||
},
|
||||
);
|
||||
|
||||
let v0 = ctx.add_variable(VarDef {
|
||||
name: "v0",
|
||||
bound: vec![],
|
||||
});
|
||||
let v1 = ctx.add_variable(VarDef {
|
||||
name: "v1",
|
||||
bound: vec![ClassType(foo).into(), ClassType(bar).into()],
|
||||
});
|
||||
let v2 = ctx.add_variable(VarDef {
|
||||
name: "v2",
|
||||
bound: vec![
|
||||
ClassType(foo).into(),
|
||||
ClassType(bar).into(),
|
||||
ctx.get_primitive(INT32_TYPE),
|
||||
],
|
||||
});
|
||||
let mut ctx = get_inference_context(ctx);
|
||||
ctx.assign("foo", Rc::new(ClassType(foo))).unwrap();
|
||||
ctx.assign("bar", Rc::new(ClassType(bar))).unwrap();
|
||||
ctx.assign("foobar", Rc::new(VirtualClassType(foo)))
|
||||
.unwrap();
|
||||
ctx.assign("v0", ctx.get_variable(v0)).unwrap();
|
||||
ctx.assign("v1", ctx.get_variable(v1)).unwrap();
|
||||
ctx.assign("v2", ctx.get_variable(v2)).unwrap();
|
||||
ctx.assign("bot", Rc::new(BotType)).unwrap();
|
||||
|
||||
let ast = parse_expression("foo.a()").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result.unwrap().unwrap(), ClassType(foo).into());
|
||||
|
||||
let ast = parse_expression("v1.a()").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("not supported".into()));
|
||||
|
||||
let ast = parse_expression("foobar.a()").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result.unwrap().unwrap(), ClassType(foo).into());
|
||||
|
||||
let ast = parse_expression("none().a()").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("no value".into()));
|
||||
|
||||
let ast = parse_expression("bot.a()").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("not supported".into()));
|
||||
|
||||
let ast = parse_expression("[][0].a()").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("not supported".into()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn infer_subscript() {
|
||||
let mut ctx = basic_ctx();
|
||||
ctx.add_fn(
|
||||
"none",
|
||||
FnDef {
|
||||
args: vec![],
|
||||
result: None,
|
||||
},
|
||||
);
|
||||
let mut ctx = get_inference_context(ctx);
|
||||
|
||||
let ast = parse_expression("[1, 2, 3][0]").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
|
||||
|
||||
let ast = parse_expression("[[1]][0][0]").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
|
||||
|
||||
let ast = parse_expression("[1, 2, 3][1:2]").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(
|
||||
result.unwrap().unwrap(),
|
||||
ParametricType(LIST_TYPE, vec![ctx.get_primitive(INT32_TYPE)]).into()
|
||||
);
|
||||
|
||||
let ast = parse_expression("[1, 2, 3][1:2:2]").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(
|
||||
result.unwrap().unwrap(),
|
||||
ParametricType(LIST_TYPE, vec![ctx.get_primitive(INT32_TYPE)]).into()
|
||||
);
|
||||
|
||||
let ast = parse_expression("[1, 2, 3][1:1.2]").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("slice must be int32 type".into()));
|
||||
|
||||
let ast = parse_expression("[1, 2, 3][1:none()]").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("slice must have type".into()));
|
||||
|
||||
let ast = parse_expression("[1, 2, 3][1.2]").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("index must be either slice or int32".into()));
|
||||
|
||||
let ast = parse_expression("[1, 2, 3][none()]").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("no value".into()));
|
||||
|
||||
let ast = parse_expression("none()[1.2]").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("no value".into()));
|
||||
|
||||
let ast = parse_expression("123[1]").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(
|
||||
result,
|
||||
Err("subscript is not supported for types other than list".into())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_if_expr() {
|
||||
let mut ctx = basic_ctx();
|
||||
ctx.add_fn(
|
||||
"none",
|
||||
FnDef {
|
||||
args: vec![],
|
||||
result: None,
|
||||
},
|
||||
);
|
||||
let mut ctx = get_inference_context(ctx);
|
||||
|
||||
let ast = parse_expression("1 if True else 0").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
|
||||
|
||||
let ast = parse_expression("none() if True else none()").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result.unwrap(), None);
|
||||
|
||||
let ast = parse_expression("none() if 1 else none()").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("test should be bool".into()));
|
||||
|
||||
let ast = parse_expression("1 if True else none()").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("divergent type".into()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_list_comp() {
|
||||
let mut ctx = basic_ctx();
|
||||
ctx.add_fn(
|
||||
"none",
|
||||
FnDef {
|
||||
args: vec![],
|
||||
result: None,
|
||||
},
|
||||
);
|
||||
let int32 = ctx.get_primitive(INT32_TYPE);
|
||||
let mut ctx = get_inference_context(ctx);
|
||||
ctx.assign("z", int32.clone()).unwrap();
|
||||
|
||||
let ast = parse_expression("[x for x in [(1, 2), (2, 3), (3, 4)]][0]").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(
|
||||
result.unwrap().unwrap(),
|
||||
ParametricType(TUPLE_TYPE, vec![int32.clone(), int32.clone()]).into()
|
||||
);
|
||||
|
||||
let ast = parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)]][0]").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result.unwrap().unwrap(), int32);
|
||||
|
||||
let ast =
|
||||
parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)] if x > 0][0]").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result.unwrap().unwrap(), int32);
|
||||
|
||||
let ast = parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)] if x][0]").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("test must be bool".into()));
|
||||
|
||||
let ast = parse_expression("[y for x in []][0]").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("unbounded identifier".into()));
|
||||
|
||||
let ast = parse_expression("[none() for x in []][0]").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("no value".into()));
|
||||
|
||||
let ast = parse_expression("[z for z in []][0]").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(result, Err("duplicated naming".into()));
|
||||
|
||||
let ast = parse_expression("[x for x in [] for y in []]").unwrap();
|
||||
let result = infer_expr(&mut ctx, &ast);
|
||||
assert_eq!(
|
||||
result,
|
||||
Err("only 1 generator statement is supported".into())
|
||||
);
|
||||
}
|
||||
}
|
|
@ -1,525 +0,0 @@
|
|||
use crate::context::InferenceContext;
|
||||
use crate::typedef::{TypeEnum::*, *};
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn find_subst(
|
||||
ctx: &InferenceContext,
|
||||
valuation: &Option<(VariableId, Type)>,
|
||||
sub: &mut HashMap<VariableId, Type>,
|
||||
mut a: Type,
|
||||
mut b: Type,
|
||||
) -> Result<(), String> {
|
||||
// TODO: fix error messages later
|
||||
if let TypeVariable(id) = a.as_ref() {
|
||||
if let Some((assumption_id, t)) = valuation {
|
||||
if assumption_id == id {
|
||||
a = t.clone();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut substituted = false;
|
||||
if let TypeVariable(id) = b.as_ref() {
|
||||
if let Some(c) = sub.get(&id) {
|
||||
b = c.clone();
|
||||
substituted = true;
|
||||
}
|
||||
}
|
||||
|
||||
match (a.as_ref(), b.as_ref()) {
|
||||
(BotType, _) => Ok(()),
|
||||
(TypeVariable(id_a), TypeVariable(id_b)) => {
|
||||
if substituted {
|
||||
return if id_a == id_b {
|
||||
Ok(())
|
||||
} else {
|
||||
Err("different variables".to_string())
|
||||
};
|
||||
}
|
||||
let v_a = ctx.get_variable_def(*id_a);
|
||||
let v_b = ctx.get_variable_def(*id_b);
|
||||
if !v_b.bound.is_empty() {
|
||||
if v_a.bound.is_empty() {
|
||||
return Err("unbounded a".to_string());
|
||||
} else {
|
||||
let diff: Vec<_> = v_a
|
||||
.bound
|
||||
.iter()
|
||||
.filter(|x| !v_b.bound.contains(x))
|
||||
.collect();
|
||||
if !diff.is_empty() {
|
||||
return Err("different domain".to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
sub.insert(*id_b, a.clone());
|
||||
Ok(())
|
||||
}
|
||||
(TypeVariable(id_a), _) => {
|
||||
let v_a = ctx.get_variable_def(*id_a);
|
||||
if v_a.bound.len() == 1 && v_a.bound[0].as_ref() == b.as_ref() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err("different domain".to_string())
|
||||
}
|
||||
}
|
||||
(_, TypeVariable(id_b)) => {
|
||||
let v_b = ctx.get_variable_def(*id_b);
|
||||
if v_b.bound.is_empty() || v_b.bound.contains(&a) {
|
||||
sub.insert(*id_b, a.clone());
|
||||
Ok(())
|
||||
} else {
|
||||
Err("different domain".to_string())
|
||||
}
|
||||
}
|
||||
(_, VirtualClassType(id_b)) => {
|
||||
let mut parents;
|
||||
match a.as_ref() {
|
||||
ClassType(id_a) => {
|
||||
parents = [*id_a].to_vec();
|
||||
}
|
||||
VirtualClassType(id_a) => {
|
||||
parents = [*id_a].to_vec();
|
||||
}
|
||||
_ => {
|
||||
return Err("cannot substitute non-class type into virtual class".to_string());
|
||||
}
|
||||
};
|
||||
while !parents.is_empty() {
|
||||
if *id_b == parents[0] {
|
||||
return Ok(());
|
||||
}
|
||||
let c = ctx.get_class_def(parents.remove(0));
|
||||
parents.extend_from_slice(&c.parents);
|
||||
}
|
||||
Err("not subtype".to_string())
|
||||
}
|
||||
(ParametricType(id_a, param_a), ParametricType(id_b, param_b)) => {
|
||||
if id_a != id_b || param_a.len() != param_b.len() {
|
||||
Err("different parametric types".to_string())
|
||||
} else {
|
||||
for (x, y) in param_a.iter().zip(param_b.iter()) {
|
||||
find_subst(ctx, valuation, sub, x.clone(), y.clone())?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
(_, _) => {
|
||||
if a == b {
|
||||
Ok(())
|
||||
} else {
|
||||
Err("not equal".to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_call_rec(
|
||||
ctx: &InferenceContext,
|
||||
valuation: &Option<(VariableId, Type)>,
|
||||
obj: Option<Type>,
|
||||
func: &str,
|
||||
args: &[Type],
|
||||
) -> Result<Option<Type>, String> {
|
||||
let mut subst = obj
|
||||
.as_ref()
|
||||
.map(|v| v.get_subst(ctx))
|
||||
.unwrap_or_else(HashMap::new);
|
||||
|
||||
let fun = match &obj {
|
||||
Some(obj) => {
|
||||
let base = match obj.as_ref() {
|
||||
PrimitiveType(id) => &ctx.get_primitive_def(*id),
|
||||
ClassType(id) | VirtualClassType(id) => &ctx.get_class_def(*id).base,
|
||||
ParametricType(id, _) => &ctx.get_parametric_def(*id).base,
|
||||
_ => return Err("not supported".to_string()),
|
||||
};
|
||||
base.methods.get(func)
|
||||
}
|
||||
None => ctx.get_fn_def(func),
|
||||
}
|
||||
.ok_or_else(|| "no such function".to_string())?;
|
||||
|
||||
if args.len() != fun.args.len() {
|
||||
return Err("incorrect parameter number".to_string());
|
||||
}
|
||||
for (a, b) in args.iter().zip(fun.args.iter()) {
|
||||
find_subst(ctx, valuation, &mut subst, a.clone(), b.clone())?;
|
||||
}
|
||||
let result = fun.result.as_ref().map(|v| v.subst(&subst));
|
||||
Ok(result.map(|result| {
|
||||
if let SelfType = result {
|
||||
obj.unwrap()
|
||||
} else {
|
||||
result.into()
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn resolve_call(
|
||||
ctx: &InferenceContext,
|
||||
obj: Option<Type>,
|
||||
func: &str,
|
||||
args: &[Type],
|
||||
) -> Result<Option<Type>, String> {
|
||||
resolve_call_rec(ctx, &None, obj, func, args)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::context::TopLevelContext;
|
||||
use crate::primitives::*;
|
||||
use std::rc::Rc;
|
||||
|
||||
fn get_inference_context(ctx: TopLevelContext) -> InferenceContext {
|
||||
InferenceContext::new(ctx, Box::new(|_| Err("unbounded identifier".into())))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simple_generic() {
|
||||
let mut ctx = basic_ctx();
|
||||
let v1 = ctx.add_variable(VarDef {
|
||||
name: "V1",
|
||||
bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(FLOAT_TYPE)],
|
||||
});
|
||||
let v1 = ctx.get_variable(v1);
|
||||
let v2 = ctx.add_variable(VarDef {
|
||||
name: "V2",
|
||||
bound: vec![
|
||||
ctx.get_primitive(BOOL_TYPE),
|
||||
ctx.get_primitive(INT32_TYPE),
|
||||
ctx.get_primitive(FLOAT_TYPE),
|
||||
],
|
||||
});
|
||||
let v2 = ctx.get_variable(v2);
|
||||
let ctx = get_inference_context(ctx);
|
||||
|
||||
assert_eq!(
|
||||
resolve_call(&ctx, None, "int32", &[ctx.get_primitive(FLOAT_TYPE)]),
|
||||
Ok(Some(ctx.get_primitive(INT32_TYPE)))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
resolve_call(&ctx, None, "int32", &[ctx.get_primitive(INT32_TYPE)],),
|
||||
Ok(Some(ctx.get_primitive(INT32_TYPE)))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
resolve_call(&ctx, None, "float", &[ctx.get_primitive(INT32_TYPE)]),
|
||||
Ok(Some(ctx.get_primitive(FLOAT_TYPE)))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
resolve_call(&ctx, None, "float", &[ctx.get_primitive(BOOL_TYPE)]),
|
||||
Err("different domain".to_string())
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
resolve_call(&ctx, None, "float", &[]),
|
||||
Err("incorrect parameter number".to_string())
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
resolve_call(&ctx, None, "float", &[v1]),
|
||||
Ok(Some(ctx.get_primitive(FLOAT_TYPE)))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
resolve_call(&ctx, None, "float", &[v2]),
|
||||
Err("different domain".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_methods() {
|
||||
let mut ctx = basic_ctx();
|
||||
|
||||
let v0 = ctx.add_variable(VarDef {
|
||||
name: "V0",
|
||||
bound: vec![],
|
||||
});
|
||||
let v0 = ctx.get_variable(v0);
|
||||
|
||||
let int32 = ctx.get_primitive(INT32_TYPE);
|
||||
let int64 = ctx.get_primitive(INT64_TYPE);
|
||||
let ctx = get_inference_context(ctx);
|
||||
|
||||
// simple cases
|
||||
assert_eq!(
|
||||
resolve_call(&ctx, Some(int32.clone()), "__add__", &[int32.clone()]),
|
||||
Ok(Some(int32.clone()))
|
||||
);
|
||||
|
||||
assert_ne!(
|
||||
resolve_call(&ctx, Some(int32.clone()), "__add__", &[int32.clone()]),
|
||||
Ok(Some(int64.clone()))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
resolve_call(&ctx, Some(int32), "__add__", &[int64]),
|
||||
Err("not equal".to_string())
|
||||
);
|
||||
|
||||
// with type variables
|
||||
assert_eq!(
|
||||
resolve_call(&ctx, Some(v0.clone()), "__add__", &[v0.clone()]),
|
||||
Err("not supported".into())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multi_generic() {
|
||||
let mut ctx = basic_ctx();
|
||||
let v0 = ctx.add_variable(VarDef {
|
||||
name: "V0",
|
||||
bound: vec![],
|
||||
});
|
||||
let v0 = ctx.get_variable(v0);
|
||||
let v1 = ctx.add_variable(VarDef {
|
||||
name: "V1",
|
||||
bound: vec![],
|
||||
});
|
||||
let v1 = ctx.get_variable(v1);
|
||||
let v2 = ctx.add_variable(VarDef {
|
||||
name: "V2",
|
||||
bound: vec![],
|
||||
});
|
||||
let v2 = ctx.get_variable(v2);
|
||||
let v3 = ctx.add_variable(VarDef {
|
||||
name: "V3",
|
||||
bound: vec![],
|
||||
});
|
||||
let v3 = ctx.get_variable(v3);
|
||||
|
||||
ctx.add_fn(
|
||||
"foo",
|
||||
FnDef {
|
||||
args: vec![v0.clone(), v0.clone(), v1.clone()],
|
||||
result: Some(v0.clone()),
|
||||
},
|
||||
);
|
||||
|
||||
ctx.add_fn(
|
||||
"foo1",
|
||||
FnDef {
|
||||
args: vec![ParametricType(TUPLE_TYPE, vec![v0.clone(), v0.clone(), v1]).into()],
|
||||
result: Some(v0),
|
||||
},
|
||||
);
|
||||
let ctx = get_inference_context(ctx);
|
||||
|
||||
assert_eq!(
|
||||
resolve_call(&ctx, None, "foo", &[v2.clone(), v2.clone(), v2.clone()]),
|
||||
Ok(Some(v2.clone()))
|
||||
);
|
||||
assert_eq!(
|
||||
resolve_call(&ctx, None, "foo", &[v2.clone(), v2.clone(), v3.clone()]),
|
||||
Ok(Some(v2.clone()))
|
||||
);
|
||||
assert_eq!(
|
||||
resolve_call(&ctx, None, "foo", &[v2.clone(), v3.clone(), v3.clone()]),
|
||||
Err("different variables".to_string())
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
resolve_call(
|
||||
&ctx,
|
||||
None,
|
||||
"foo1",
|
||||
&[ParametricType(TUPLE_TYPE, vec![v2.clone(), v2.clone(), v2.clone()]).into()]
|
||||
),
|
||||
Ok(Some(v2.clone()))
|
||||
);
|
||||
assert_eq!(
|
||||
resolve_call(
|
||||
&ctx,
|
||||
None,
|
||||
"foo1",
|
||||
&[ParametricType(TUPLE_TYPE, vec![v2.clone(), v2.clone(), v3.clone()]).into()]
|
||||
),
|
||||
Ok(Some(v2.clone()))
|
||||
);
|
||||
assert_eq!(
|
||||
resolve_call(
|
||||
&ctx,
|
||||
None,
|
||||
"foo1",
|
||||
&[ParametricType(TUPLE_TYPE, vec![v2, v3.clone(), v3]).into()]
|
||||
),
|
||||
Err("different variables".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_class_generics() {
|
||||
let mut ctx = basic_ctx();
|
||||
|
||||
let list = ctx.get_parametric_def_mut(LIST_TYPE);
|
||||
let t = Rc::new(TypeVariable(list.params[0]));
|
||||
list.base.methods.insert(
|
||||
"head",
|
||||
FnDef {
|
||||
args: vec![],
|
||||
result: Some(t.clone()),
|
||||
},
|
||||
);
|
||||
list.base.methods.insert(
|
||||
"append",
|
||||
FnDef {
|
||||
args: vec![t],
|
||||
result: None,
|
||||
},
|
||||
);
|
||||
|
||||
let v0 = ctx.add_variable(VarDef {
|
||||
name: "V0",
|
||||
bound: vec![],
|
||||
});
|
||||
let v0 = ctx.get_variable(v0);
|
||||
let v1 = ctx.add_variable(VarDef {
|
||||
name: "V1",
|
||||
bound: vec![],
|
||||
});
|
||||
let v1 = ctx.get_variable(v1);
|
||||
let ctx = get_inference_context(ctx);
|
||||
|
||||
assert_eq!(
|
||||
resolve_call(
|
||||
&ctx,
|
||||
Some(ParametricType(LIST_TYPE, vec![v0.clone()]).into()),
|
||||
"head",
|
||||
&[]
|
||||
),
|
||||
Ok(Some(v0.clone()))
|
||||
);
|
||||
assert_eq!(
|
||||
resolve_call(
|
||||
&ctx,
|
||||
Some(ParametricType(LIST_TYPE, vec![v0.clone()]).into()),
|
||||
"append",
|
||||
&[v0.clone()]
|
||||
),
|
||||
Ok(None)
|
||||
);
|
||||
assert_eq!(
|
||||
resolve_call(
|
||||
&ctx,
|
||||
Some(ParametricType(LIST_TYPE, vec![v0]).into()),
|
||||
"append",
|
||||
&[v1]
|
||||
),
|
||||
Err("different variables".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_virtual_class() {
|
||||
let mut ctx = basic_ctx();
|
||||
|
||||
let foo = ctx.add_class(ClassDef {
|
||||
base: TypeDef {
|
||||
name: "Foo",
|
||||
methods: HashMap::new(),
|
||||
fields: HashMap::new(),
|
||||
},
|
||||
parents: vec![],
|
||||
});
|
||||
|
||||
let foo1 = ctx.add_class(ClassDef {
|
||||
base: TypeDef {
|
||||
name: "Foo1",
|
||||
methods: HashMap::new(),
|
||||
fields: HashMap::new(),
|
||||
},
|
||||
parents: vec![foo],
|
||||
});
|
||||
|
||||
let foo2 = ctx.add_class(ClassDef {
|
||||
base: TypeDef {
|
||||
name: "Foo2",
|
||||
methods: HashMap::new(),
|
||||
fields: HashMap::new(),
|
||||
},
|
||||
parents: vec![foo1],
|
||||
});
|
||||
|
||||
let bar = ctx.add_class(ClassDef {
|
||||
base: TypeDef {
|
||||
name: "bar",
|
||||
methods: HashMap::new(),
|
||||
fields: HashMap::new(),
|
||||
},
|
||||
parents: vec![],
|
||||
});
|
||||
|
||||
ctx.add_fn(
|
||||
"foo",
|
||||
FnDef {
|
||||
args: vec![VirtualClassType(foo).into()],
|
||||
result: None,
|
||||
},
|
||||
);
|
||||
ctx.add_fn(
|
||||
"foo1",
|
||||
FnDef {
|
||||
args: vec![VirtualClassType(foo1).into()],
|
||||
result: None,
|
||||
},
|
||||
);
|
||||
let ctx = get_inference_context(ctx);
|
||||
|
||||
assert_eq!(
|
||||
resolve_call(&ctx, None, "foo", &[ClassType(foo).into()]),
|
||||
Ok(None)
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
resolve_call(&ctx, None, "foo", &[ClassType(foo1).into()]),
|
||||
Ok(None)
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
resolve_call(&ctx, None, "foo", &[ClassType(foo2).into()]),
|
||||
Ok(None)
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
resolve_call(&ctx, None, "foo", &[ClassType(bar).into()]),
|
||||
Err("not subtype".to_string())
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
resolve_call(&ctx, None, "foo1", &[ClassType(foo1).into()]),
|
||||
Ok(None)
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
resolve_call(&ctx, None, "foo1", &[ClassType(foo2).into()]),
|
||||
Ok(None)
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
resolve_call(&ctx, None, "foo1", &[ClassType(foo).into()]),
|
||||
Err("not subtype".to_string())
|
||||
);
|
||||
|
||||
// virtual class substitution
|
||||
assert_eq!(
|
||||
resolve_call(&ctx, None, "foo", &[VirtualClassType(foo).into()]),
|
||||
Ok(None)
|
||||
);
|
||||
assert_eq!(
|
||||
resolve_call(&ctx, None, "foo", &[VirtualClassType(foo1).into()]),
|
||||
Ok(None)
|
||||
);
|
||||
assert_eq!(
|
||||
resolve_call(&ctx, None, "foo", &[VirtualClassType(foo2).into()]),
|
||||
Ok(None)
|
||||
);
|
||||
assert_eq!(
|
||||
resolve_call(&ctx, None, "foo", &[VirtualClassType(bar).into()]),
|
||||
Err("not subtype".to_string())
|
||||
);
|
||||
}
|
||||
}
|
|
@ -1,593 +1,8 @@
|
|||
#![warn(clippy::all)]
|
||||
#![allow(clippy::clone_double_ref)]
|
||||
#![allow(dead_code)]
|
||||
|
||||
extern crate num_bigint;
|
||||
extern crate inkwell;
|
||||
extern crate rustpython_parser;
|
||||
|
||||
pub mod expression_inference;
|
||||
pub mod inference_core;
|
||||
mod magic_methods;
|
||||
pub mod primitives;
|
||||
pub mod typedef;
|
||||
pub mod context;
|
||||
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use std::path::Path;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use num_traits::cast::ToPrimitive;
|
||||
|
||||
use rustpython_parser::ast;
|
||||
|
||||
use inkwell::OptimizationLevel;
|
||||
use inkwell::builder::Builder;
|
||||
use inkwell::context::Context;
|
||||
use inkwell::module::Module;
|
||||
use inkwell::targets::*;
|
||||
use inkwell::types;
|
||||
use inkwell::types::BasicType;
|
||||
use inkwell::values;
|
||||
use inkwell::{IntPredicate, FloatPredicate};
|
||||
use inkwell::basic_block;
|
||||
use inkwell::passes;
|
||||
|
||||
|
||||
#[derive(Debug)]
|
||||
enum CompileErrorKind {
|
||||
Unsupported(&'static str),
|
||||
MissingTypeAnnotation,
|
||||
UnknownTypeAnnotation,
|
||||
IncompatibleTypes,
|
||||
UnboundIdentifier,
|
||||
BreakOutsideLoop,
|
||||
Internal(&'static str)
|
||||
}
|
||||
|
||||
impl fmt::Display for CompileErrorKind {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
CompileErrorKind::Unsupported(feature)
|
||||
=> write!(f, "The following Python feature is not supported by NAC3: {}", feature),
|
||||
CompileErrorKind::MissingTypeAnnotation
|
||||
=> write!(f, "Missing type annotation"),
|
||||
CompileErrorKind::UnknownTypeAnnotation
|
||||
=> write!(f, "Unknown type annotation"),
|
||||
CompileErrorKind::IncompatibleTypes
|
||||
=> write!(f, "Incompatible types"),
|
||||
CompileErrorKind::UnboundIdentifier
|
||||
=> write!(f, "Unbound identifier"),
|
||||
CompileErrorKind::BreakOutsideLoop
|
||||
=> write!(f, "Break outside loop"),
|
||||
CompileErrorKind::Internal(details)
|
||||
=> write!(f, "Internal compiler error: {}", details),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CompileError {
|
||||
location: ast::Location,
|
||||
kind: CompileErrorKind,
|
||||
}
|
||||
|
||||
impl fmt::Display for CompileError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}, at {}", self.kind, self.location)
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for CompileError {}
|
||||
|
||||
type CompileResult<T> = Result<T, CompileError>;
|
||||
|
||||
pub struct CodeGen<'ctx> {
|
||||
context: &'ctx Context,
|
||||
module: Module<'ctx>,
|
||||
pass_manager: passes::PassManager<values::FunctionValue<'ctx>>,
|
||||
builder: Builder<'ctx>,
|
||||
current_source_location: ast::Location,
|
||||
namespace: HashMap<String, values::PointerValue<'ctx>>,
|
||||
break_bb: Option<basic_block::BasicBlock<'ctx>>,
|
||||
}
|
||||
|
||||
impl<'ctx> CodeGen<'ctx> {
|
||||
pub fn new(context: &'ctx Context) -> CodeGen<'ctx> {
|
||||
let module = context.create_module("kernel");
|
||||
|
||||
let pass_manager = passes::PassManager::create(&module);
|
||||
pass_manager.add_instruction_combining_pass();
|
||||
pass_manager.add_reassociate_pass();
|
||||
pass_manager.add_gvn_pass();
|
||||
pass_manager.add_cfg_simplification_pass();
|
||||
pass_manager.add_basic_alias_analysis_pass();
|
||||
pass_manager.add_promote_memory_to_register_pass();
|
||||
pass_manager.add_instruction_combining_pass();
|
||||
pass_manager.add_reassociate_pass();
|
||||
pass_manager.initialize();
|
||||
|
||||
let i32_type = context.i32_type();
|
||||
let fn_type = i32_type.fn_type(&[i32_type.into()], false);
|
||||
module.add_function("output", fn_type, None);
|
||||
|
||||
CodeGen {
|
||||
context, module, pass_manager,
|
||||
builder: context.create_builder(),
|
||||
current_source_location: ast::Location::default(),
|
||||
namespace: HashMap::new(),
|
||||
break_bb: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn set_source_location(&mut self, location: ast::Location) {
|
||||
self.current_source_location = location;
|
||||
}
|
||||
|
||||
fn compile_error(&self, kind: CompileErrorKind) -> CompileError {
|
||||
CompileError {
|
||||
location: self.current_source_location,
|
||||
kind
|
||||
}
|
||||
}
|
||||
|
||||
fn get_basic_type(&self, name: &str) -> CompileResult<types::BasicTypeEnum<'ctx>> {
|
||||
match name {
|
||||
"bool" => Ok(self.context.bool_type().into()),
|
||||
"int32" => Ok(self.context.i32_type().into()),
|
||||
"int64" => Ok(self.context.i64_type().into()),
|
||||
"float32" => Ok(self.context.f32_type().into()),
|
||||
"float64" => Ok(self.context.f64_type().into()),
|
||||
_ => Err(self.compile_error(CompileErrorKind::UnknownTypeAnnotation))
|
||||
}
|
||||
}
|
||||
|
||||
fn compile_function_def(
|
||||
&mut self,
|
||||
name: &str,
|
||||
args: &ast::Parameters,
|
||||
body: &ast::Suite,
|
||||
decorator_list: &[ast::Expression],
|
||||
returns: &Option<ast::Expression>,
|
||||
is_async: bool,
|
||||
) -> CompileResult<values::FunctionValue<'ctx>> {
|
||||
if is_async {
|
||||
return Err(self.compile_error(CompileErrorKind::Unsupported("async functions")))
|
||||
}
|
||||
for decorator in decorator_list.iter() {
|
||||
self.set_source_location(decorator.location);
|
||||
if let ast::ExpressionType::Identifier { name } = &decorator.node {
|
||||
if name != "kernel" && name != "portable" {
|
||||
return Err(self.compile_error(CompileErrorKind::Unsupported("custom decorators")))
|
||||
}
|
||||
} else {
|
||||
return Err(self.compile_error(CompileErrorKind::Unsupported("decorator must be an identifier")))
|
||||
}
|
||||
}
|
||||
|
||||
let args_type = args.args.iter().map(|val| {
|
||||
self.set_source_location(val.location);
|
||||
if let Some(annotation) = &val.annotation {
|
||||
if let ast::ExpressionType::Identifier { name } = &annotation.node {
|
||||
Ok(self.get_basic_type(&name)?)
|
||||
} else {
|
||||
Err(self.compile_error(CompileErrorKind::Unsupported("type annotation must be an identifier")))
|
||||
}
|
||||
} else {
|
||||
Err(self.compile_error(CompileErrorKind::MissingTypeAnnotation))
|
||||
}
|
||||
}).collect::<CompileResult<Vec<types::BasicTypeEnum>>>()?;
|
||||
let return_type = if let Some(returns) = returns {
|
||||
self.set_source_location(returns.location);
|
||||
if let ast::ExpressionType::Identifier { name } = &returns.node {
|
||||
if name == "None" { None } else { Some(self.get_basic_type(name)?) }
|
||||
} else {
|
||||
return Err(self.compile_error(CompileErrorKind::Unsupported("type annotation must be an identifier")))
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let fn_type = match return_type {
|
||||
Some(ty) => ty.fn_type(&args_type, false),
|
||||
None => self.context.void_type().fn_type(&args_type, false)
|
||||
};
|
||||
|
||||
let function = self.module.add_function(name, fn_type, None);
|
||||
let basic_block = self.context.append_basic_block(function, "entry");
|
||||
self.builder.position_at_end(basic_block);
|
||||
|
||||
for (n, arg) in args.args.iter().enumerate() {
|
||||
let param = function.get_nth_param(n as u32).unwrap();
|
||||
let alloca = self.builder.build_alloca(param.get_type(), &arg.arg);
|
||||
self.builder.build_store(alloca, param);
|
||||
self.namespace.insert(arg.arg.clone(), alloca);
|
||||
}
|
||||
|
||||
self.compile_suite(body, return_type)?;
|
||||
|
||||
Ok(function)
|
||||
}
|
||||
|
||||
fn compile_expression(
|
||||
&mut self,
|
||||
expression: &ast::Expression
|
||||
) -> CompileResult<values::BasicValueEnum<'ctx>> {
|
||||
self.set_source_location(expression.location);
|
||||
|
||||
match &expression.node {
|
||||
ast::ExpressionType::True => Ok(self.context.bool_type().const_int(1, false).into()),
|
||||
ast::ExpressionType::False => Ok(self.context.bool_type().const_int(0, false).into()),
|
||||
ast::ExpressionType::Number { value: ast::Number::Integer { value } } => {
|
||||
let mut bits = value.bits();
|
||||
if value.sign() == num_bigint::Sign::Minus {
|
||||
bits += 1;
|
||||
}
|
||||
match bits {
|
||||
0..=32 => Ok(self.context.i32_type().const_int(value.to_i32().unwrap() as _, true).into()),
|
||||
33..=64 => Ok(self.context.i64_type().const_int(value.to_i64().unwrap() as _, true).into()),
|
||||
_ => Err(self.compile_error(CompileErrorKind::Unsupported("integers larger than 64 bits")))
|
||||
}
|
||||
},
|
||||
ast::ExpressionType::Number { value: ast::Number::Float { value } } => {
|
||||
Ok(self.context.f64_type().const_float(*value).into())
|
||||
},
|
||||
ast::ExpressionType::Identifier { name } => {
|
||||
match self.namespace.get(name) {
|
||||
Some(value) => Ok(self.builder.build_load(*value, name).into()),
|
||||
None => Err(self.compile_error(CompileErrorKind::UnboundIdentifier))
|
||||
}
|
||||
},
|
||||
ast::ExpressionType::Unop { op, a } => {
|
||||
let a = self.compile_expression(&a)?;
|
||||
match (op, a) {
|
||||
(ast::UnaryOperator::Pos, values::BasicValueEnum::IntValue(a))
|
||||
=> Ok(a.into()),
|
||||
(ast::UnaryOperator::Pos, values::BasicValueEnum::FloatValue(a))
|
||||
=> Ok(a.into()),
|
||||
(ast::UnaryOperator::Neg, values::BasicValueEnum::IntValue(a))
|
||||
=> Ok(self.builder.build_int_neg(a, "tmpneg").into()),
|
||||
(ast::UnaryOperator::Neg, values::BasicValueEnum::FloatValue(a))
|
||||
=> Ok(self.builder.build_float_neg(a, "tmpneg").into()),
|
||||
(ast::UnaryOperator::Inv, values::BasicValueEnum::IntValue(a))
|
||||
=> Ok(self.builder.build_not(a, "tmpnot").into()),
|
||||
(ast::UnaryOperator::Not, values::BasicValueEnum::IntValue(a)) => {
|
||||
// boolean "not"
|
||||
if a.get_type().get_bit_width() != 1 {
|
||||
Err(self.compile_error(CompileErrorKind::Unsupported("unimplemented unary operation")))
|
||||
} else {
|
||||
Ok(self.builder.build_not(a, "tmpnot").into())
|
||||
}
|
||||
},
|
||||
_ => Err(self.compile_error(CompileErrorKind::Unsupported("unimplemented unary operation"))),
|
||||
}
|
||||
},
|
||||
ast::ExpressionType::Binop { a, op, b } => {
|
||||
let a = self.compile_expression(&a)?;
|
||||
let b = self.compile_expression(&b)?;
|
||||
if a.get_type() != b.get_type() {
|
||||
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
|
||||
}
|
||||
use ast::Operator::*;
|
||||
match (op, a, b) {
|
||||
(Add, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b))
|
||||
=> Ok(self.builder.build_int_add(a, b, "tmpadd").into()),
|
||||
(Sub, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b))
|
||||
=> Ok(self.builder.build_int_sub(a, b, "tmpsub").into()),
|
||||
(Mult, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b))
|
||||
=> Ok(self.builder.build_int_mul(a, b, "tmpmul").into()),
|
||||
|
||||
(Add, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b))
|
||||
=> Ok(self.builder.build_float_add(a, b, "tmpadd").into()),
|
||||
(Sub, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b))
|
||||
=> Ok(self.builder.build_float_sub(a, b, "tmpsub").into()),
|
||||
(Mult, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b))
|
||||
=> Ok(self.builder.build_float_mul(a, b, "tmpmul").into()),
|
||||
|
||||
(Div, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b))
|
||||
=> Ok(self.builder.build_float_div(a, b, "tmpdiv").into()),
|
||||
(FloorDiv, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b))
|
||||
=> Ok(self.builder.build_int_signed_div(a, b, "tmpdiv").into()),
|
||||
_ => Err(self.compile_error(CompileErrorKind::Unsupported("unimplemented binary operation"))),
|
||||
}
|
||||
},
|
||||
ast::ExpressionType::Compare { vals, ops } => {
|
||||
let mut vals = vals.iter();
|
||||
let mut ops = ops.iter();
|
||||
|
||||
let mut result = None;
|
||||
let mut a = self.compile_expression(vals.next().unwrap())?;
|
||||
loop {
|
||||
if let Some(op) = ops.next() {
|
||||
let b = self.compile_expression(vals.next().unwrap())?;
|
||||
if a.get_type() != b.get_type() {
|
||||
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
|
||||
}
|
||||
let this_result = match (a, b) {
|
||||
(values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b)) => {
|
||||
match op {
|
||||
ast::Comparison::Equal
|
||||
=> self.builder.build_int_compare(IntPredicate::EQ, a, b, "tmpeq"),
|
||||
ast::Comparison::NotEqual
|
||||
=> self.builder.build_int_compare(IntPredicate::NE, a, b, "tmpne"),
|
||||
ast::Comparison::Less
|
||||
=> self.builder.build_int_compare(IntPredicate::SLT, a, b, "tmpslt"),
|
||||
ast::Comparison::LessOrEqual
|
||||
=> self.builder.build_int_compare(IntPredicate::SLE, a, b, "tmpsle"),
|
||||
ast::Comparison::Greater
|
||||
=> self.builder.build_int_compare(IntPredicate::SGT, a, b, "tmpsgt"),
|
||||
ast::Comparison::GreaterOrEqual
|
||||
=> self.builder.build_int_compare(IntPredicate::SGE, a, b, "tmpsge"),
|
||||
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("special comparison"))),
|
||||
}
|
||||
},
|
||||
(values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b)) => {
|
||||
match op {
|
||||
ast::Comparison::Equal
|
||||
=> self.builder.build_float_compare(FloatPredicate::OEQ, a, b, "tmpoeq"),
|
||||
ast::Comparison::NotEqual
|
||||
=> self.builder.build_float_compare(FloatPredicate::UNE, a, b, "tmpune"),
|
||||
ast::Comparison::Less
|
||||
=> self.builder.build_float_compare(FloatPredicate::OLT, a, b, "tmpolt"),
|
||||
ast::Comparison::LessOrEqual
|
||||
=> self.builder.build_float_compare(FloatPredicate::OLE, a, b, "tmpole"),
|
||||
ast::Comparison::Greater
|
||||
=> self.builder.build_float_compare(FloatPredicate::OGT, a, b, "tmpogt"),
|
||||
ast::Comparison::GreaterOrEqual
|
||||
=> self.builder.build_float_compare(FloatPredicate::OGE, a, b, "tmpoge"),
|
||||
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("special comparison"))),
|
||||
}
|
||||
},
|
||||
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("comparison of non-numerical types"))),
|
||||
};
|
||||
match result {
|
||||
Some(last) => {
|
||||
result = Some(self.builder.build_and(last, this_result, "tmpand"));
|
||||
}
|
||||
None => {
|
||||
result = Some(this_result);
|
||||
}
|
||||
}
|
||||
a = b;
|
||||
} else {
|
||||
return Ok(result.unwrap().into())
|
||||
}
|
||||
}
|
||||
},
|
||||
ast::ExpressionType::Call { function, args, keywords } => {
|
||||
if !keywords.is_empty() {
|
||||
return Err(self.compile_error(CompileErrorKind::Unsupported("keyword arguments")))
|
||||
}
|
||||
let args = args.iter().map(|val| self.compile_expression(val))
|
||||
.collect::<CompileResult<Vec<values::BasicValueEnum>>>()?;
|
||||
self.set_source_location(expression.location);
|
||||
if let ast::ExpressionType::Identifier { name } = &function.node {
|
||||
match (name.as_str(), args[0]) {
|
||||
("int32", values::BasicValueEnum::IntValue(a)) => {
|
||||
let nbits = a.get_type().get_bit_width();
|
||||
if nbits < 32 {
|
||||
Ok(self.builder.build_int_s_extend(a, self.context.i32_type(), "tmpsext").into())
|
||||
} else if nbits > 32 {
|
||||
Ok(self.builder.build_int_truncate(a, self.context.i32_type(), "tmptrunc").into())
|
||||
} else {
|
||||
Ok(a.into())
|
||||
}
|
||||
},
|
||||
("int64", values::BasicValueEnum::IntValue(a)) => {
|
||||
let nbits = a.get_type().get_bit_width();
|
||||
if nbits < 64 {
|
||||
Ok(self.builder.build_int_s_extend(a, self.context.i64_type(), "tmpsext").into())
|
||||
} else {
|
||||
Ok(a.into())
|
||||
}
|
||||
},
|
||||
("int32", values::BasicValueEnum::FloatValue(a)) => {
|
||||
Ok(self.builder.build_float_to_signed_int(a, self.context.i32_type(), "tmpfptosi").into())
|
||||
},
|
||||
("int64", values::BasicValueEnum::FloatValue(a)) => {
|
||||
Ok(self.builder.build_float_to_signed_int(a, self.context.i64_type(), "tmpfptosi").into())
|
||||
},
|
||||
("float32", values::BasicValueEnum::IntValue(a)) => {
|
||||
Ok(self.builder.build_signed_int_to_float(a, self.context.f32_type(), "tmpsitofp").into())
|
||||
},
|
||||
("float64", values::BasicValueEnum::IntValue(a)) => {
|
||||
Ok(self.builder.build_signed_int_to_float(a, self.context.f64_type(), "tmpsitofp").into())
|
||||
},
|
||||
("float32", values::BasicValueEnum::FloatValue(a)) => {
|
||||
if a.get_type() == self.context.f64_type() {
|
||||
Ok(self.builder.build_float_trunc(a, self.context.f32_type(), "tmptrunc").into())
|
||||
} else {
|
||||
Ok(a.into())
|
||||
}
|
||||
},
|
||||
("float64", values::BasicValueEnum::FloatValue(a)) => {
|
||||
if a.get_type() == self.context.f32_type() {
|
||||
Ok(self.builder.build_float_ext(a, self.context.f64_type(), "tmpext").into())
|
||||
} else {
|
||||
Ok(a.into())
|
||||
}
|
||||
},
|
||||
|
||||
("output", values::BasicValueEnum::IntValue(a)) => {
|
||||
let fn_value = self.module.get_function("output").unwrap();
|
||||
Ok(self.builder.build_call(fn_value, &[a.into()], "call")
|
||||
.try_as_basic_value().left().unwrap())
|
||||
},
|
||||
_ => Err(self.compile_error(CompileErrorKind::Unsupported("unrecognized call")))
|
||||
}
|
||||
} else {
|
||||
return Err(self.compile_error(CompileErrorKind::Unsupported("function must be an identifier")))
|
||||
}
|
||||
},
|
||||
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("unimplemented expression"))),
|
||||
}
|
||||
}
|
||||
|
||||
fn compile_statement(
|
||||
&mut self,
|
||||
statement: &ast::Statement,
|
||||
return_type: Option<types::BasicTypeEnum>
|
||||
) -> CompileResult<()> {
|
||||
self.set_source_location(statement.location);
|
||||
|
||||
use ast::StatementType::*;
|
||||
match &statement.node {
|
||||
Assign { targets, value } => {
|
||||
let value = self.compile_expression(value)?;
|
||||
for target in targets.iter() {
|
||||
self.set_source_location(target.location);
|
||||
if let ast::ExpressionType::Identifier { name } = &target.node {
|
||||
let builder = &self.builder;
|
||||
let target = self.namespace.entry(name.clone()).or_insert_with(
|
||||
|| builder.build_alloca(value.get_type(), name));
|
||||
if target.get_type() != value.get_type().ptr_type(inkwell::AddressSpace::Generic) {
|
||||
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
|
||||
}
|
||||
builder.build_store(*target, value);
|
||||
} else {
|
||||
return Err(self.compile_error(CompileErrorKind::Unsupported("assignment target must be an identifier")))
|
||||
}
|
||||
}
|
||||
},
|
||||
Expression { expression } => { self.compile_expression(expression)?; },
|
||||
If { test, body, orelse } => {
|
||||
let test = self.compile_expression(test)?;
|
||||
if test.get_type() != self.context.bool_type().into() {
|
||||
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
|
||||
}
|
||||
|
||||
let parent = self.builder.get_insert_block().unwrap().get_parent().unwrap();
|
||||
let then_bb = self.context.append_basic_block(parent, "then");
|
||||
let else_bb = self.context.append_basic_block(parent, "else");
|
||||
let cont_bb = self.context.append_basic_block(parent, "ifcont");
|
||||
self.builder.build_conditional_branch(test.into_int_value(), then_bb, else_bb);
|
||||
|
||||
self.builder.position_at_end(then_bb);
|
||||
self.compile_suite(body, return_type)?;
|
||||
self.builder.build_unconditional_branch(cont_bb);
|
||||
|
||||
self.builder.position_at_end(else_bb);
|
||||
if let Some(orelse) = orelse {
|
||||
self.compile_suite(orelse, return_type)?;
|
||||
}
|
||||
self.builder.build_unconditional_branch(cont_bb);
|
||||
self.builder.position_at_end(cont_bb);
|
||||
},
|
||||
While { test, body, orelse } => {
|
||||
let parent = self.builder.get_insert_block().unwrap().get_parent().unwrap();
|
||||
let test_bb = self.context.append_basic_block(parent, "test");
|
||||
self.builder.build_unconditional_branch(test_bb);
|
||||
self.builder.position_at_end(test_bb);
|
||||
let test = self.compile_expression(test)?;
|
||||
if test.get_type() != self.context.bool_type().into() {
|
||||
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
|
||||
}
|
||||
|
||||
let then_bb = self.context.append_basic_block(parent, "then");
|
||||
let else_bb = self.context.append_basic_block(parent, "else");
|
||||
let cont_bb = self.context.append_basic_block(parent, "ifcont");
|
||||
self.builder.build_conditional_branch(test.into_int_value(), then_bb, else_bb);
|
||||
|
||||
self.break_bb = Some(cont_bb);
|
||||
|
||||
self.builder.position_at_end(then_bb);
|
||||
self.compile_suite(body, return_type)?;
|
||||
self.builder.build_unconditional_branch(test_bb);
|
||||
|
||||
self.builder.position_at_end(else_bb);
|
||||
if let Some(orelse) = orelse {
|
||||
self.compile_suite(orelse, return_type)?;
|
||||
}
|
||||
self.builder.build_unconditional_branch(cont_bb);
|
||||
self.builder.position_at_end(cont_bb);
|
||||
|
||||
self.break_bb = None;
|
||||
},
|
||||
Break => {
|
||||
if let Some(bb) = self.break_bb {
|
||||
self.builder.build_unconditional_branch(bb);
|
||||
let parent = self.builder.get_insert_block().unwrap().get_parent().unwrap();
|
||||
let unreachable_bb = self.context.append_basic_block(parent, "unreachable");
|
||||
self.builder.position_at_end(unreachable_bb);
|
||||
} else {
|
||||
return Err(self.compile_error(CompileErrorKind::BreakOutsideLoop));
|
||||
}
|
||||
}
|
||||
Return { value: Some(value) } => {
|
||||
if let Some(return_type) = return_type {
|
||||
let value = self.compile_expression(value)?;
|
||||
if value.get_type() != return_type {
|
||||
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
|
||||
}
|
||||
self.builder.build_return(Some(&value));
|
||||
} else {
|
||||
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
|
||||
}
|
||||
},
|
||||
Return { value: None } => {
|
||||
if !return_type.is_none() {
|
||||
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
|
||||
}
|
||||
self.builder.build_return(None);
|
||||
},
|
||||
Pass => (),
|
||||
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("special statement"))),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn compile_suite(
|
||||
&mut self,
|
||||
suite: &ast::Suite,
|
||||
return_type: Option<types::BasicTypeEnum>
|
||||
) -> CompileResult<()> {
|
||||
for statement in suite.iter() {
|
||||
self.compile_statement(statement, return_type)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn compile_toplevel(&mut self, statement: &ast::Statement) -> CompileResult<()> {
|
||||
self.set_source_location(statement.location);
|
||||
if let ast::StatementType::FunctionDef {
|
||||
is_async,
|
||||
name,
|
||||
args,
|
||||
body,
|
||||
decorator_list,
|
||||
returns,
|
||||
} = &statement.node {
|
||||
let function = self.compile_function_def(name, args, body, decorator_list, returns, *is_async)?;
|
||||
self.pass_manager.run_on(&function);
|
||||
Ok(())
|
||||
} else {
|
||||
Err(self.compile_error(CompileErrorKind::Internal("top-level is not a function definition")))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn print_ir(&self) {
|
||||
self.module.print_to_stderr();
|
||||
}
|
||||
|
||||
pub fn output(&self, filename: &str) {
|
||||
//let triple = TargetTriple::create("riscv32-none-linux-gnu");
|
||||
let triple = TargetMachine::get_default_triple();
|
||||
let target = Target::from_triple(&triple)
|
||||
.expect("couldn't create target from target triple");
|
||||
|
||||
let target_machine = target
|
||||
.create_target_machine(
|
||||
&triple,
|
||||
"",
|
||||
"",
|
||||
OptimizationLevel::Default,
|
||||
RelocMode::Default,
|
||||
CodeModel::Default,
|
||||
)
|
||||
.expect("couldn't create target machine");
|
||||
|
||||
target_machine
|
||||
.write_to_file(&self.module, FileType::Object, Path::new(filename))
|
||||
.expect("couldn't write module to file");
|
||||
}
|
||||
}
|
||||
mod codegen;
|
||||
mod location;
|
||||
mod symbol_resolver;
|
||||
mod top_level;
|
||||
mod typecheck;
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
use rustpython_parser::ast;
|
||||
use std::vec::Vec;
|
||||
|
||||
#[derive(Clone, Copy, PartialEq)]
|
||||
pub struct FileID(u32);
|
||||
|
||||
#[derive(Clone, Copy, PartialEq)]
|
||||
pub enum Location {
|
||||
CodeRange(FileID, ast::Location),
|
||||
Builtin,
|
||||
}
|
||||
|
||||
pub struct FileRegistry {
|
||||
files: Vec<String>,
|
||||
}
|
||||
|
||||
impl FileRegistry {
|
||||
pub fn new() -> FileRegistry {
|
||||
FileRegistry { files: Vec::new() }
|
||||
}
|
||||
|
||||
pub fn add_file(&mut self, path: &str) -> FileID {
|
||||
let index = self.files.len() as u32;
|
||||
self.files.push(path.to_owned());
|
||||
FileID(index)
|
||||
}
|
||||
|
||||
pub fn query_file(&self, id: FileID) -> &str {
|
||||
&self.files[id.0 as usize]
|
||||
}
|
||||
}
|
|
@ -1,58 +0,0 @@
|
|||
use rustpython_parser::ast::{Comparison, Operator, UnaryOperator};
|
||||
|
||||
pub fn binop_name(op: &Operator) -> &'static str {
|
||||
match op {
|
||||
Operator::Add => "__add__",
|
||||
Operator::Sub => "__sub__",
|
||||
Operator::Div => "__truediv__",
|
||||
Operator::Mod => "__mod__",
|
||||
Operator::Mult => "__mul__",
|
||||
Operator::Pow => "__pow__",
|
||||
Operator::BitOr => "__or__",
|
||||
Operator::BitXor => "__xor__",
|
||||
Operator::BitAnd => "__and__",
|
||||
Operator::LShift => "__lshift__",
|
||||
Operator::RShift => "__rshift__",
|
||||
Operator::FloorDiv => "__floordiv__",
|
||||
Operator::MatMult => "__matmul__",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn binop_assign_name(op: &Operator) -> &'static str {
|
||||
match op {
|
||||
Operator::Add => "__iadd__",
|
||||
Operator::Sub => "__isub__",
|
||||
Operator::Div => "__itruediv__",
|
||||
Operator::Mod => "__imod__",
|
||||
Operator::Mult => "__imul__",
|
||||
Operator::Pow => "__ipow__",
|
||||
Operator::BitOr => "__ior__",
|
||||
Operator::BitXor => "__ixor__",
|
||||
Operator::BitAnd => "__iand__",
|
||||
Operator::LShift => "__ilshift__",
|
||||
Operator::RShift => "__irshift__",
|
||||
Operator::FloorDiv => "__ifloordiv__",
|
||||
Operator::MatMult => "__imatmul__",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn unaryop_name(op: &UnaryOperator) -> &'static str {
|
||||
match op {
|
||||
UnaryOperator::Pos => "__pos__",
|
||||
UnaryOperator::Neg => "__neg__",
|
||||
UnaryOperator::Not => "__not__",
|
||||
UnaryOperator::Inv => "__inv__",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn comparison_name(op: &Comparison) -> Option<&'static str> {
|
||||
match op {
|
||||
Comparison::Less => Some("__lt__"),
|
||||
Comparison::LessOrEqual => Some("__le__"),
|
||||
Comparison::Greater => Some("__gt__"),
|
||||
Comparison::GreaterOrEqual => Some("__ge__"),
|
||||
Comparison::Equal => Some("__eq__"),
|
||||
Comparison::NotEqual => Some("__ne__"),
|
||||
_ => None,
|
||||
}
|
||||
}
|
|
@ -1,184 +0,0 @@
|
|||
use super::typedef::{TypeEnum::*, *};
|
||||
use crate::context::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub const TUPLE_TYPE: ParamId = ParamId(0);
|
||||
pub const LIST_TYPE: ParamId = ParamId(1);
|
||||
|
||||
pub const BOOL_TYPE: PrimitiveId = PrimitiveId(0);
|
||||
pub const INT32_TYPE: PrimitiveId = PrimitiveId(1);
|
||||
pub const INT64_TYPE: PrimitiveId = PrimitiveId(2);
|
||||
pub const FLOAT_TYPE: PrimitiveId = PrimitiveId(3);
|
||||
|
||||
fn impl_math(def: &mut TypeDef, ty: &Type) {
|
||||
let result = Some(ty.clone());
|
||||
let fun = FnDef {
|
||||
args: vec![ty.clone()],
|
||||
result: result.clone(),
|
||||
};
|
||||
def.methods.insert("__add__", fun.clone());
|
||||
def.methods.insert("__sub__", fun.clone());
|
||||
def.methods.insert("__mul__", fun.clone());
|
||||
def.methods.insert(
|
||||
"__neg__",
|
||||
FnDef {
|
||||
args: vec![],
|
||||
result,
|
||||
},
|
||||
);
|
||||
def.methods.insert(
|
||||
"__truediv__",
|
||||
FnDef {
|
||||
args: vec![ty.clone()],
|
||||
result: Some(PrimitiveType(FLOAT_TYPE).into()),
|
||||
},
|
||||
);
|
||||
def.methods.insert("__floordiv__", fun.clone());
|
||||
def.methods.insert("__mod__", fun.clone());
|
||||
def.methods.insert("__pow__", fun);
|
||||
}
|
||||
|
||||
fn impl_bits(def: &mut TypeDef, ty: &Type) {
|
||||
let result = Some(ty.clone());
|
||||
let fun = FnDef {
|
||||
args: vec![PrimitiveType(INT32_TYPE).into()],
|
||||
result,
|
||||
};
|
||||
|
||||
def.methods.insert("__lshift__", fun.clone());
|
||||
def.methods.insert("__rshift__", fun);
|
||||
def.methods.insert(
|
||||
"__xor__",
|
||||
FnDef {
|
||||
args: vec![ty.clone()],
|
||||
result: Some(ty.clone()),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
fn impl_eq(def: &mut TypeDef, ty: &Type) {
|
||||
let fun = FnDef {
|
||||
args: vec![ty.clone()],
|
||||
result: Some(PrimitiveType(BOOL_TYPE).into()),
|
||||
};
|
||||
|
||||
def.methods.insert("__eq__", fun.clone());
|
||||
def.methods.insert("__ne__", fun);
|
||||
}
|
||||
|
||||
fn impl_order(def: &mut TypeDef, ty: &Type) {
|
||||
let fun = FnDef {
|
||||
args: vec![ty.clone()],
|
||||
result: Some(PrimitiveType(BOOL_TYPE).into()),
|
||||
};
|
||||
|
||||
def.methods.insert("__lt__", fun.clone());
|
||||
def.methods.insert("__gt__", fun.clone());
|
||||
def.methods.insert("__le__", fun.clone());
|
||||
def.methods.insert("__ge__", fun);
|
||||
}
|
||||
|
||||
pub fn basic_ctx() -> TopLevelContext<'static> {
|
||||
let primitives = [
|
||||
TypeDef {
|
||||
name: "bool",
|
||||
fields: HashMap::new(),
|
||||
methods: HashMap::new(),
|
||||
},
|
||||
TypeDef {
|
||||
name: "int32",
|
||||
fields: HashMap::new(),
|
||||
methods: HashMap::new(),
|
||||
},
|
||||
TypeDef {
|
||||
name: "int64",
|
||||
fields: HashMap::new(),
|
||||
methods: HashMap::new(),
|
||||
},
|
||||
TypeDef {
|
||||
name: "float",
|
||||
fields: HashMap::new(),
|
||||
methods: HashMap::new(),
|
||||
},
|
||||
]
|
||||
.to_vec();
|
||||
let mut ctx = TopLevelContext::new(primitives);
|
||||
|
||||
let b = ctx.get_primitive(BOOL_TYPE);
|
||||
let b_def = ctx.get_primitive_def_mut(BOOL_TYPE);
|
||||
impl_eq(b_def, &b);
|
||||
let int32 = ctx.get_primitive(INT32_TYPE);
|
||||
let int32_def = ctx.get_primitive_def_mut(INT32_TYPE);
|
||||
impl_math(int32_def, &int32);
|
||||
impl_bits(int32_def, &int32);
|
||||
impl_order(int32_def, &int32);
|
||||
impl_eq(int32_def, &int32);
|
||||
let int64 = ctx.get_primitive(INT64_TYPE);
|
||||
let int64_def = ctx.get_primitive_def_mut(INT64_TYPE);
|
||||
impl_math(int64_def, &int64);
|
||||
impl_bits(int64_def, &int64);
|
||||
impl_order(int64_def, &int64);
|
||||
impl_eq(int64_def, &int64);
|
||||
let float = ctx.get_primitive(FLOAT_TYPE);
|
||||
let float_def = ctx.get_primitive_def_mut(FLOAT_TYPE);
|
||||
impl_math(float_def, &float);
|
||||
impl_order(float_def, &float);
|
||||
impl_eq(float_def, &float);
|
||||
|
||||
let t = ctx.add_variable_private(VarDef {
|
||||
name: "T",
|
||||
bound: vec![],
|
||||
});
|
||||
|
||||
ctx.add_parametric(ParametricDef {
|
||||
base: TypeDef {
|
||||
name: "tuple",
|
||||
fields: HashMap::new(),
|
||||
methods: HashMap::new(),
|
||||
},
|
||||
// we have nothing for tuple, so no param def
|
||||
params: vec![],
|
||||
});
|
||||
|
||||
ctx.add_parametric(ParametricDef {
|
||||
base: TypeDef {
|
||||
name: "list",
|
||||
fields: HashMap::new(),
|
||||
methods: HashMap::new(),
|
||||
},
|
||||
params: vec![t],
|
||||
});
|
||||
|
||||
let i = ctx.add_variable_private(VarDef {
|
||||
name: "I",
|
||||
bound: vec![
|
||||
PrimitiveType(INT32_TYPE).into(),
|
||||
PrimitiveType(INT64_TYPE).into(),
|
||||
PrimitiveType(FLOAT_TYPE).into(),
|
||||
],
|
||||
});
|
||||
let args = vec![TypeVariable(i).into()];
|
||||
ctx.add_fn(
|
||||
"int32",
|
||||
FnDef {
|
||||
args: args.clone(),
|
||||
result: Some(PrimitiveType(INT32_TYPE).into()),
|
||||
},
|
||||
);
|
||||
ctx.add_fn(
|
||||
"int64",
|
||||
FnDef {
|
||||
args: args.clone(),
|
||||
result: Some(PrimitiveType(INT64_TYPE).into()),
|
||||
},
|
||||
);
|
||||
ctx.add_fn(
|
||||
"float",
|
||||
FnDef {
|
||||
args,
|
||||
result: Some(PrimitiveType(FLOAT_TYPE).into()),
|
||||
},
|
||||
);
|
||||
|
||||
ctx
|
||||
}
|
|
@ -0,0 +1,174 @@
|
|||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::top_level::{DefinitionId, TopLevelContext, TopLevelDef};
|
||||
use crate::typecheck::{
|
||||
type_inferencer::PrimitiveStore,
|
||||
typedef::{Type, Unifier},
|
||||
};
|
||||
use crate::{location::Location, typecheck::typedef::TypeEnum};
|
||||
use itertools::{chain, izip};
|
||||
use rustpython_parser::ast::Expr;
|
||||
|
||||
#[derive(Clone, PartialEq)]
|
||||
pub enum SymbolValue {
|
||||
I32(i32),
|
||||
I64(i64),
|
||||
Double(f64),
|
||||
Bool(bool),
|
||||
Tuple(Vec<SymbolValue>),
|
||||
// we should think about how to implement bytes later...
|
||||
// Bytes(&'a [u8]),
|
||||
}
|
||||
|
||||
pub trait SymbolResolver {
|
||||
// get type of type variable identifier or top-level function type
|
||||
fn get_symbol_type(
|
||||
&self,
|
||||
unifier: &mut Unifier,
|
||||
primitives: &PrimitiveStore,
|
||||
str: &str,
|
||||
) -> Option<Type>;
|
||||
// get the top-level definition of identifiers
|
||||
fn get_identifier_def(&self, str: &str) -> Option<DefinitionId>;
|
||||
fn get_symbol_value(&self, str: &str) -> Option<SymbolValue>;
|
||||
fn get_symbol_location(&self, str: &str) -> Option<Location>;
|
||||
// handle function call etc.
|
||||
}
|
||||
|
||||
// convert type annotation into type
|
||||
pub fn parse_type_annotation<T>(
|
||||
resolver: &dyn SymbolResolver,
|
||||
top_level: &TopLevelContext,
|
||||
unifier: &mut Unifier,
|
||||
primitives: &PrimitiveStore,
|
||||
expr: &Expr<T>,
|
||||
) -> Result<Type, String> {
|
||||
use rustpython_parser::ast::ExprKind::*;
|
||||
match &expr.node {
|
||||
Name { id, .. } => match id.as_str() {
|
||||
"int32" => Ok(primitives.int32),
|
||||
"int64" => Ok(primitives.int64),
|
||||
"float" => Ok(primitives.float),
|
||||
"bool" => Ok(primitives.bool),
|
||||
"None" => Ok(primitives.none),
|
||||
x => {
|
||||
let obj_id = resolver.get_identifier_def(x);
|
||||
if let Some(obj_id) = obj_id {
|
||||
let defs = top_level.definitions.read();
|
||||
let def = defs[obj_id.0].read();
|
||||
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
|
||||
if !type_vars.is_empty() {
|
||||
return Err(format!(
|
||||
"Unexpected number of type parameters: expected {} but got 0",
|
||||
type_vars.len()
|
||||
));
|
||||
}
|
||||
let fields = RefCell::new(
|
||||
chain(
|
||||
fields.iter().map(|(k, v)| (k.clone(), *v)),
|
||||
methods.iter().map(|(k, v, _)| (k.clone(), *v)),
|
||||
)
|
||||
.collect(),
|
||||
);
|
||||
Ok(unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id,
|
||||
fields,
|
||||
params: Default::default(),
|
||||
}))
|
||||
} else {
|
||||
Err("Cannot use function name as type".into())
|
||||
}
|
||||
} else {
|
||||
// it could be a type variable
|
||||
let ty = resolver
|
||||
.get_symbol_type(unifier, primitives, x)
|
||||
.ok_or_else(|| "Cannot use function name as type".to_owned())?;
|
||||
if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) {
|
||||
Ok(ty)
|
||||
} else {
|
||||
Err(format!("Unknown type annotation {}", x))
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
Subscript { value, slice, .. } => {
|
||||
if let Name { id, .. } = &value.node {
|
||||
if id == "virtual" {
|
||||
let ty =
|
||||
parse_type_annotation(resolver, top_level, unifier, primitives, slice)?;
|
||||
Ok(unifier.add_ty(TypeEnum::TVirtual { ty }))
|
||||
} else {
|
||||
let types = if let Tuple { elts, .. } = &slice.node {
|
||||
elts.iter()
|
||||
.map(|v| {
|
||||
parse_type_annotation(resolver, top_level, unifier, primitives, v)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
} else {
|
||||
vec![parse_type_annotation(
|
||||
resolver, top_level, unifier, primitives, slice,
|
||||
)?]
|
||||
};
|
||||
|
||||
let obj_id = resolver
|
||||
.get_identifier_def(id)
|
||||
.ok_or_else(|| format!("Unknown type annotation {}", id))?;
|
||||
let defs = top_level.definitions.read();
|
||||
let def = defs[obj_id.0].read();
|
||||
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
|
||||
if types.len() != type_vars.len() {
|
||||
return Err(format!(
|
||||
"Unexpected number of type parameters: expected {} but got {}",
|
||||
type_vars.len(),
|
||||
types.len()
|
||||
));
|
||||
}
|
||||
let mut subst = HashMap::new();
|
||||
for (var, ty) in izip!(type_vars.iter(), types.iter()) {
|
||||
let id = if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) {
|
||||
*id
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
subst.insert(id, *ty);
|
||||
}
|
||||
let mut fields = fields
|
||||
.iter()
|
||||
.map(|(attr, ty)| {
|
||||
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
|
||||
(attr.clone(), ty)
|
||||
})
|
||||
.collect::<HashMap<_, _>>();
|
||||
fields.extend(methods.iter().map(|(attr, ty, _)| {
|
||||
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
|
||||
(attr.clone(), ty)
|
||||
}));
|
||||
Ok(unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id,
|
||||
fields: fields.into(),
|
||||
params: subst.into(),
|
||||
}))
|
||||
} else {
|
||||
Err("Cannot use function name as type".into())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Err("unsupported type expression".into())
|
||||
}
|
||||
}
|
||||
_ => Err("unsupported type expression".into()),
|
||||
}
|
||||
}
|
||||
|
||||
impl dyn SymbolResolver + Send + Sync {
|
||||
pub fn parse_type_annotation<T>(
|
||||
&self,
|
||||
top_level: &TopLevelContext,
|
||||
unifier: &mut Unifier,
|
||||
primitives: &PrimitiveStore,
|
||||
expr: &Expr<T>,
|
||||
) -> Result<Type, String> {
|
||||
parse_type_annotation(self, top_level, unifier, primitives, expr)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,778 @@
|
|||
use std::borrow::BorrowMut;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use std::{collections::HashMap, collections::HashSet, sync::Arc};
|
||||
|
||||
use super::typecheck::type_inferencer::PrimitiveStore;
|
||||
use super::typecheck::typedef::{SharedUnifier, Type, TypeEnum, Unifier};
|
||||
use crate::typecheck::typedef::{FunSignature, FuncArg};
|
||||
use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Mapping};
|
||||
use itertools::Itertools;
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use rustpython_parser::ast::{self, Stmt};
|
||||
|
||||
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)]
|
||||
pub struct DefinitionId(pub usize);
|
||||
|
||||
pub enum TopLevelDef {
|
||||
Class {
|
||||
// object ID used for TypeEnum
|
||||
object_id: DefinitionId,
|
||||
// type variables bounded to the class.
|
||||
type_vars: Vec<Type>,
|
||||
// class fields
|
||||
fields: Vec<(String, Type)>,
|
||||
// class methods, pointing to the corresponding function definition.
|
||||
methods: Vec<(String, Type, DefinitionId)>,
|
||||
// ancestor classes, including itself.
|
||||
ancestors: Vec<DefinitionId>,
|
||||
// symbol resolver of the module defined the class, none if it is built-in type
|
||||
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>,
|
||||
},
|
||||
Function {
|
||||
// prefix for symbol, should be unique globally, and not ending with numbers
|
||||
name: String,
|
||||
// function signature.
|
||||
signature: Type,
|
||||
/// Function instance to symbol mapping
|
||||
/// Key: string representation of type variable values, sorted by variable ID in ascending
|
||||
/// order, including type variables associated with the class.
|
||||
/// Value: function symbol name.
|
||||
instance_to_symbol: HashMap<String, String>,
|
||||
/// Function instances to annotated AST mapping
|
||||
/// Key: string representation of type variable values, sorted by variable ID in ascending
|
||||
/// order, including type variables associated with the class. Excluding rigid type
|
||||
/// variables.
|
||||
/// Value: AST annotated with types together with a unification table index. Could contain
|
||||
/// rigid type variables that would be substituted when the function is instantiated.
|
||||
instance_to_stmt: HashMap<String, (Stmt<Option<Type>>, usize)>,
|
||||
// symbol resolver of the module defined the class
|
||||
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>,
|
||||
},
|
||||
Initializer {
|
||||
class_id: DefinitionId,
|
||||
},
|
||||
}
|
||||
|
||||
impl TopLevelDef {
|
||||
fn get_function_type(&self) -> Result<Type, String> {
|
||||
if let Self::Function { signature, .. } = self {
|
||||
Ok(*signature)
|
||||
} else {
|
||||
Err("only expect function def here".into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TopLevelContext {
|
||||
pub definitions: Arc<RwLock<Vec<Arc<RwLock<TopLevelDef>>>>>,
|
||||
pub unifiers: Arc<RwLock<Vec<(SharedUnifier, PrimitiveStore)>>>,
|
||||
}
|
||||
|
||||
pub struct TopLevelComposer {
|
||||
// list of top level definitions, same as top level context
|
||||
pub definition_ast_list: Arc<RwLock<Vec<(Arc<RwLock<TopLevelDef>>, Option<ast::Stmt<()>>)>>>,
|
||||
// start as a primitive unifier, will add more top_level defs inside
|
||||
pub unifier: Unifier,
|
||||
// primitive store
|
||||
pub primitives: PrimitiveStore,
|
||||
// mangled class method name to def_id
|
||||
pub class_method_to_def_id: HashMap<String, DefinitionId>,
|
||||
// record the def id of the classes whoses fields and methods are to be analyzed
|
||||
pub to_be_analyzed_class: Vec<DefinitionId>,
|
||||
}
|
||||
|
||||
impl TopLevelComposer {
|
||||
pub fn to_top_level_context(&self) -> TopLevelContext {
|
||||
let def_list =
|
||||
self.definition_ast_list.read().iter().map(|(x, _)| x.clone()).collect::<Vec<_>>();
|
||||
TopLevelContext {
|
||||
definitions: RwLock::new(def_list).into(),
|
||||
// FIXME: all the big unifier or?
|
||||
unifiers: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn name_mangling(mut class_name: String, method_name: &str) -> String {
|
||||
class_name.push_str(method_name);
|
||||
class_name
|
||||
}
|
||||
|
||||
pub fn make_primitives() -> (PrimitiveStore, Unifier) {
|
||||
let mut unifier = Unifier::new();
|
||||
let int32 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(0),
|
||||
fields: HashMap::new().into(),
|
||||
params: HashMap::new().into(),
|
||||
});
|
||||
let int64 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(1),
|
||||
fields: HashMap::new().into(),
|
||||
params: HashMap::new().into(),
|
||||
});
|
||||
let float = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(2),
|
||||
fields: HashMap::new().into(),
|
||||
params: HashMap::new().into(),
|
||||
});
|
||||
let bool = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(3),
|
||||
fields: HashMap::new().into(),
|
||||
params: HashMap::new().into(),
|
||||
});
|
||||
let none = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(4),
|
||||
fields: HashMap::new().into(),
|
||||
params: HashMap::new().into(),
|
||||
});
|
||||
let primitives = PrimitiveStore { int32, int64, float, bool, none };
|
||||
crate::typecheck::magic_methods::set_primitives_magic_methods(&primitives, &mut unifier);
|
||||
(primitives, unifier)
|
||||
}
|
||||
|
||||
/// return a composer and things to make a "primitive" symbol resolver, so that the symbol
|
||||
/// resolver can later figure out primitive type definitions when passed a primitive type name
|
||||
pub fn new() -> (Vec<(String, DefinitionId, Type)>, Self) {
|
||||
let primitives = Self::make_primitives();
|
||||
|
||||
let top_level_def_list = vec![
|
||||
Arc::new(RwLock::new(Self::make_top_level_class_def(0, None))),
|
||||
Arc::new(RwLock::new(Self::make_top_level_class_def(1, None))),
|
||||
Arc::new(RwLock::new(Self::make_top_level_class_def(2, None))),
|
||||
Arc::new(RwLock::new(Self::make_top_level_class_def(3, None))),
|
||||
Arc::new(RwLock::new(Self::make_top_level_class_def(4, None))),
|
||||
];
|
||||
|
||||
let ast_list: Vec<Option<ast::Stmt<()>>> = vec![None, None, None, None, None];
|
||||
|
||||
let composer = TopLevelComposer {
|
||||
definition_ast_list: RwLock::new(
|
||||
top_level_def_list.into_iter().zip(ast_list).collect_vec(),
|
||||
)
|
||||
.into(),
|
||||
primitives: primitives.0,
|
||||
unifier: primitives.1,
|
||||
class_method_to_def_id: Default::default(),
|
||||
to_be_analyzed_class: Default::default(),
|
||||
};
|
||||
(
|
||||
vec![
|
||||
("int32".into(), DefinitionId(0), composer.primitives.int32),
|
||||
("int64".into(), DefinitionId(1), composer.primitives.int64),
|
||||
("float".into(), DefinitionId(2), composer.primitives.float),
|
||||
("bool".into(), DefinitionId(3), composer.primitives.bool),
|
||||
("none".into(), DefinitionId(4), composer.primitives.none),
|
||||
],
|
||||
composer,
|
||||
)
|
||||
}
|
||||
|
||||
/// already include the definition_id of itself inside the ancestors vector
|
||||
/// when first regitering, the type_vars, fields, methods, ancestors are invalid
|
||||
pub fn make_top_level_class_def(
|
||||
index: usize,
|
||||
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>,
|
||||
) -> TopLevelDef {
|
||||
TopLevelDef::Class {
|
||||
object_id: DefinitionId(index),
|
||||
type_vars: Default::default(),
|
||||
fields: Default::default(),
|
||||
methods: Default::default(),
|
||||
ancestors: vec![DefinitionId(index)],
|
||||
resolver,
|
||||
}
|
||||
}
|
||||
|
||||
/// when first registering, the type is a invalid value
|
||||
pub fn make_top_level_function_def(
|
||||
name: String,
|
||||
ty: Type,
|
||||
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>,
|
||||
) -> TopLevelDef {
|
||||
TopLevelDef::Function {
|
||||
name,
|
||||
signature: ty,
|
||||
instance_to_symbol: Default::default(),
|
||||
instance_to_stmt: Default::default(),
|
||||
resolver,
|
||||
}
|
||||
}
|
||||
|
||||
/// step 0, register, just remeber the names of top level classes/function
|
||||
pub fn register_top_level(
|
||||
&mut self,
|
||||
ast: ast::Stmt<()>,
|
||||
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>,
|
||||
) -> Result<(String, DefinitionId), String> {
|
||||
let mut def_list = self.definition_ast_list.write();
|
||||
match &ast.node {
|
||||
ast::StmtKind::ClassDef { name, body, .. } => {
|
||||
let class_name = name.to_string();
|
||||
let class_def_id = def_list.len();
|
||||
|
||||
// add the class to the definition lists
|
||||
// since later when registering class method, ast will still be used,
|
||||
// here push None temporarly, later will move the ast inside
|
||||
let mut class_def_ast = (
|
||||
Arc::new(RwLock::new(Self::make_top_level_class_def(
|
||||
class_def_id,
|
||||
resolver.clone(),
|
||||
))),
|
||||
None,
|
||||
);
|
||||
|
||||
// parse class def body and register class methods into the def list.
|
||||
// module's symbol resolver would not know the name of the class methods,
|
||||
// thus cannot return their definition_id
|
||||
let mut class_method_name_def_ids: Vec<(
|
||||
String,
|
||||
Arc<RwLock<TopLevelDef>>,
|
||||
DefinitionId,
|
||||
)> = Vec::new();
|
||||
let mut class_method_index_offset = 0;
|
||||
for b in body {
|
||||
if let ast::StmtKind::FunctionDef { name: method_name, .. } = &b.node {
|
||||
let method_name = Self::name_mangling(class_name.clone(), method_name);
|
||||
let method_def_id = def_list.len() + {
|
||||
class_method_index_offset += 1;
|
||||
class_method_index_offset
|
||||
};
|
||||
|
||||
// dummy method define here
|
||||
// the ast of class method is in the class, push None in to the list here
|
||||
class_method_name_def_ids.push((
|
||||
method_name.clone(),
|
||||
RwLock::new(Self::make_top_level_function_def(
|
||||
method_name.clone(),
|
||||
self.primitives.none,
|
||||
resolver.clone(),
|
||||
))
|
||||
.into(),
|
||||
DefinitionId(method_def_id),
|
||||
));
|
||||
}
|
||||
}
|
||||
// move the ast to the entry of the class in the ast_list
|
||||
class_def_ast.1 = Some(ast);
|
||||
|
||||
// now class_def_ast and class_method_def_ast_ids are ok, put them into actual def list in correct order
|
||||
def_list.push(class_def_ast);
|
||||
for (name, def, id) in class_method_name_def_ids {
|
||||
def_list.push((def, None));
|
||||
self.class_method_to_def_id.insert(name, id);
|
||||
}
|
||||
|
||||
// put the constructor into the def_list
|
||||
def_list.push((
|
||||
RwLock::new(TopLevelDef::Initializer { class_id: DefinitionId(class_def_id) })
|
||||
.into(),
|
||||
None,
|
||||
));
|
||||
|
||||
// class, put its def_id into the to be analyzed set
|
||||
self.to_be_analyzed_class.push(DefinitionId(class_def_id));
|
||||
|
||||
Ok((class_name, DefinitionId(class_def_id)))
|
||||
}
|
||||
|
||||
ast::StmtKind::FunctionDef { name, .. } => {
|
||||
let fun_name = name.to_string();
|
||||
|
||||
// add to the definition list
|
||||
def_list.push((
|
||||
RwLock::new(Self::make_top_level_function_def(
|
||||
name.into(),
|
||||
self.primitives.none,
|
||||
resolver,
|
||||
))
|
||||
.into(),
|
||||
Some(ast),
|
||||
));
|
||||
|
||||
// return
|
||||
Ok((fun_name, DefinitionId(def_list.len() - 1)))
|
||||
}
|
||||
|
||||
_ => Err("only registrations of top level classes/functions are supprted".into()),
|
||||
}
|
||||
}
|
||||
|
||||
/// step 1, analyze the type vars associated with top level class
|
||||
fn analyze_top_level_class_type_var(&mut self) -> Result<(), String> {
|
||||
let mut def_list = self.definition_ast_list.write();
|
||||
let converted_top_level = &self.to_top_level_context();
|
||||
let primitives = &self.primitives;
|
||||
let unifier = &mut self.unifier;
|
||||
|
||||
for (class_def, class_ast) in def_list.iter_mut() {
|
||||
// only deal with class def here
|
||||
let mut class_def = class_def.write();
|
||||
let (class_bases_ast, class_def_type_vars, class_resolver) = {
|
||||
if let TopLevelDef::Class { type_vars, resolver, .. } = class_def.deref_mut() {
|
||||
if let Some(ast::Located {
|
||||
node: ast::StmtKind::ClassDef { bases, .. }, ..
|
||||
}) = class_ast
|
||||
{
|
||||
(bases, type_vars, resolver)
|
||||
} else {
|
||||
unreachable!("must be both class")
|
||||
}
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let class_resolver = class_resolver.as_ref().unwrap().lock();
|
||||
|
||||
let mut is_generic = false;
|
||||
for b in class_bases_ast {
|
||||
match &b.node {
|
||||
// analyze typevars bounded to the class,
|
||||
// only support things like `class A(Generic[T, V])`,
|
||||
// things like `class A(Generic[T, V, ImportedModule.T])` is not supported
|
||||
// i.e. only simple names are allowed in the subscript
|
||||
// should update the TopLevelDef::Class.typevars and the TypeEnum::TObj.params
|
||||
ast::ExprKind::Subscript { value, slice, .. } if matches!(&value.node, ast::ExprKind::Name { id, .. } if id == "Generic") =>
|
||||
{
|
||||
if !is_generic {
|
||||
is_generic = true;
|
||||
} else {
|
||||
return Err("Only single Generic[...] can be in bases".into());
|
||||
}
|
||||
|
||||
// if `class A(Generic[T, V, G])`
|
||||
if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
|
||||
// parse the type vars
|
||||
let type_vars = elts
|
||||
.iter()
|
||||
.map(|e| {
|
||||
class_resolver.parse_type_annotation(
|
||||
converted_top_level,
|
||||
unifier.borrow_mut(),
|
||||
primitives,
|
||||
e,
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
// check if all are unique type vars
|
||||
let mut occured_type_var_id: HashSet<u32> = HashSet::new();
|
||||
let all_unique_type_var = type_vars.iter().all(|x| {
|
||||
let ty = unifier.get_ty(*x);
|
||||
if let TypeEnum::TVar { id, .. } = ty.as_ref() {
|
||||
occured_type_var_id.insert(*id)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
});
|
||||
|
||||
if !all_unique_type_var {
|
||||
return Err("expect unique type variables".into());
|
||||
}
|
||||
|
||||
// add to TopLevelDef
|
||||
class_def_type_vars.extend(type_vars);
|
||||
|
||||
// `class A(Generic[T])`
|
||||
} else {
|
||||
let ty = class_resolver.parse_type_annotation(
|
||||
converted_top_level,
|
||||
unifier.borrow_mut(),
|
||||
primitives,
|
||||
&slice,
|
||||
)?;
|
||||
// check if it is type var
|
||||
let is_type_var =
|
||||
matches!(unifier.get_ty(ty).as_ref(), &TypeEnum::TVar { .. });
|
||||
if !is_type_var {
|
||||
return Err("expect type variable here".into());
|
||||
}
|
||||
|
||||
// add to TopLevelDef
|
||||
class_def_type_vars.push(ty);
|
||||
}
|
||||
}
|
||||
|
||||
// if others, do nothing in this function
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// step 2, base classes. Need to separate step1 and step2 for this reason:
|
||||
/// `class B(Generic[T, V]);
|
||||
/// class A(B[int, bool])`
|
||||
/// if the type var associated with class `B` has not been handled properly,
|
||||
/// the parse of type annotation of `B[int, bool]` will fail
|
||||
fn analyze_top_level_class_bases(&mut self) -> Result<(), String> {
|
||||
let mut def_list = self.definition_ast_list.write();
|
||||
let converted_top_level = &self.to_top_level_context();
|
||||
let primitives = &self.primitives;
|
||||
let unifier = &mut self.unifier;
|
||||
|
||||
for (class_def, class_ast) in def_list.iter_mut() {
|
||||
let mut class_def = class_def.write();
|
||||
let (class_bases, class_ancestors, class_resolver) = {
|
||||
if let TopLevelDef::Class { ancestors, resolver, .. } = class_def.deref_mut() {
|
||||
if let Some(ast::Located {
|
||||
node: ast::StmtKind::ClassDef { bases, .. }, ..
|
||||
}) = class_ast
|
||||
{
|
||||
(bases, ancestors, resolver)
|
||||
} else {
|
||||
unreachable!("must be both class")
|
||||
}
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let class_resolver = class_resolver.as_ref().unwrap().lock();
|
||||
for b in class_bases {
|
||||
// type vars have already been handled, so skip on `Generic[...]`
|
||||
if let ast::ExprKind::Subscript { value, .. } = &b.node {
|
||||
if let ast::ExprKind::Name { id, .. } = &value.node {
|
||||
if id == "Generic" {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
// get the def id of the base class
|
||||
let base_ty = class_resolver.parse_type_annotation(
|
||||
converted_top_level,
|
||||
unifier.borrow_mut(),
|
||||
primitives,
|
||||
b,
|
||||
)?;
|
||||
let base_id =
|
||||
if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(base_ty).as_ref() {
|
||||
*obj_id
|
||||
} else {
|
||||
return Err("expect concrete class/type to be base class".into());
|
||||
};
|
||||
|
||||
// write to the class ancestors, make sure the uniqueness
|
||||
if !class_ancestors.contains(&base_id) {
|
||||
class_ancestors.push(base_id);
|
||||
} else {
|
||||
return Err("cannot specify the same base class twice".into());
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// step 3, class fields and methods
|
||||
// FIXME: analyze base classes here
|
||||
// FIXME: deal with self type
|
||||
// NOTE: prevent cycles only roughly done
|
||||
fn analyze_top_level_class_fields_methods(&mut self) -> Result<(), String> {
|
||||
let mut def_ast_list = self.definition_ast_list.write();
|
||||
let converted_top_level = &self.to_top_level_context();
|
||||
let primitives = &self.primitives;
|
||||
let to_be_analyzed_class = &mut self.to_be_analyzed_class;
|
||||
let unifier = &mut self.unifier;
|
||||
|
||||
// NOTE: roughly prevent infinite loop
|
||||
let mut max_iter = to_be_analyzed_class.len() * 4;
|
||||
'class: loop {
|
||||
if to_be_analyzed_class.is_empty() && {
|
||||
max_iter -= 1;
|
||||
max_iter > 0
|
||||
} {
|
||||
break;
|
||||
}
|
||||
|
||||
let class_ind = to_be_analyzed_class.remove(0).0;
|
||||
let (class_name, class_body_ast, class_bases_ast, class_resolver, class_ancestors) = {
|
||||
let (class_def, class_ast) = &mut def_ast_list[class_ind];
|
||||
if let Some(ast::Located {
|
||||
node: ast::StmtKind::ClassDef { name, body, bases, .. },
|
||||
..
|
||||
}) = class_ast.as_ref()
|
||||
{
|
||||
if let TopLevelDef::Class { resolver, ancestors, .. } =
|
||||
class_def.write().deref()
|
||||
{
|
||||
(name, body, bases, resolver.as_ref().unwrap().clone(), ancestors.clone())
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
} else {
|
||||
unreachable!("should be class def ast")
|
||||
}
|
||||
};
|
||||
|
||||
let all_base_class_analyzed = {
|
||||
let not_yet_analyzed =
|
||||
to_be_analyzed_class.clone().into_iter().collect::<HashSet<_>>();
|
||||
let base = class_ancestors.clone().into_iter().collect::<HashSet<_>>();
|
||||
let intersection = not_yet_analyzed.intersection(&base).collect_vec();
|
||||
intersection.is_empty()
|
||||
};
|
||||
if !all_base_class_analyzed {
|
||||
to_be_analyzed_class.push(DefinitionId(class_ind));
|
||||
continue 'class;
|
||||
}
|
||||
|
||||
// get the bases type, can directly do this since it
|
||||
// already pass the check in the previous stages
|
||||
let class_bases_ty = class_bases_ast
|
||||
.iter()
|
||||
.filter_map(|x| {
|
||||
class_resolver
|
||||
.as_ref()
|
||||
.lock()
|
||||
.parse_type_annotation(
|
||||
converted_top_level,
|
||||
unifier.borrow_mut(),
|
||||
primitives,
|
||||
x,
|
||||
)
|
||||
.ok()
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
// need these vectors to check re-defining methods, class fields
|
||||
// and store the parsed result in case some method cannot be typed for now
|
||||
let mut class_methods_parsing_result: Vec<(String, Type, DefinitionId)> = vec![];
|
||||
let mut class_fields_parsing_result: Vec<(String, Type)> = vec![];
|
||||
for b in class_body_ast {
|
||||
if let ast::StmtKind::FunctionDef {
|
||||
args: method_args_ast,
|
||||
body: method_body_ast,
|
||||
name: method_name,
|
||||
returns: method_returns_ast,
|
||||
..
|
||||
} = &b.node
|
||||
{
|
||||
let arg_name_tys: Vec<(String, Type)> = {
|
||||
let mut result = vec![];
|
||||
for a in &method_args_ast.args {
|
||||
if a.node.arg != "self" {
|
||||
let annotation = a
|
||||
.node
|
||||
.annotation
|
||||
.as_ref()
|
||||
.ok_or_else(|| {
|
||||
"type annotation for function parameter is needed"
|
||||
.to_string()
|
||||
})?
|
||||
.as_ref();
|
||||
|
||||
let ty = class_resolver.as_ref().lock().parse_type_annotation(
|
||||
converted_top_level,
|
||||
unifier.borrow_mut(),
|
||||
primitives,
|
||||
annotation,
|
||||
)?;
|
||||
if !Self::check_ty_analyzed(ty, unifier, to_be_analyzed_class) {
|
||||
to_be_analyzed_class.push(DefinitionId(class_ind));
|
||||
continue 'class;
|
||||
}
|
||||
result.push((a.node.arg.to_string(), ty));
|
||||
} else {
|
||||
// TODO: handle self, how
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
result
|
||||
};
|
||||
|
||||
let method_type_var = arg_name_tys
|
||||
.iter()
|
||||
.filter_map(|(_, ty)| {
|
||||
let ty_enum = unifier.get_ty(*ty);
|
||||
if let TypeEnum::TVar { id, .. } = ty_enum.as_ref() {
|
||||
Some((*id, *ty))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Mapping<u32>>();
|
||||
|
||||
let ret_ty = {
|
||||
if method_name != "__init__" {
|
||||
let ty = method_returns_ast
|
||||
.as_ref()
|
||||
.map(|x| {
|
||||
class_resolver.as_ref().lock().parse_type_annotation(
|
||||
converted_top_level,
|
||||
unifier.borrow_mut(),
|
||||
primitives,
|
||||
x.as_ref(),
|
||||
)
|
||||
})
|
||||
.ok_or_else(|| "return type annotation error".to_string())??;
|
||||
if !Self::check_ty_analyzed(ty, unifier, to_be_analyzed_class) {
|
||||
to_be_analyzed_class.push(DefinitionId(class_ind));
|
||||
continue 'class;
|
||||
} else {
|
||||
ty
|
||||
}
|
||||
} else {
|
||||
// TODO: __init__ function, self type, how
|
||||
unimplemented!()
|
||||
}
|
||||
};
|
||||
|
||||
// handle fields
|
||||
let class_field_name_tys: Option<Vec<(String, Type)>> = if method_name
|
||||
== "__init__"
|
||||
{
|
||||
let mut result: Vec<(String, Type)> = vec![];
|
||||
for body in method_body_ast {
|
||||
match &body.node {
|
||||
ast::StmtKind::AnnAssign { target, annotation, .. }
|
||||
if {
|
||||
if let ast::ExprKind::Attribute { value, .. } = &target.node
|
||||
{
|
||||
matches!(
|
||||
&value.node,
|
||||
ast::ExprKind::Name { id, .. } if id == "self")
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} =>
|
||||
{
|
||||
let field_ty =
|
||||
class_resolver.as_ref().lock().parse_type_annotation(
|
||||
converted_top_level,
|
||||
unifier.borrow_mut(),
|
||||
primitives,
|
||||
annotation.as_ref(),
|
||||
)?;
|
||||
if !Self::check_ty_analyzed(
|
||||
field_ty,
|
||||
unifier,
|
||||
to_be_analyzed_class,
|
||||
) {
|
||||
to_be_analyzed_class.push(DefinitionId(class_ind));
|
||||
continue 'class;
|
||||
} else {
|
||||
result.push((
|
||||
if let ast::ExprKind::Attribute { attr, .. } =
|
||||
&target.node
|
||||
{
|
||||
attr.to_string()
|
||||
} else {
|
||||
unreachable!()
|
||||
},
|
||||
field_ty,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
// exclude those without type annotation
|
||||
ast::StmtKind::Assign { targets, .. }
|
||||
if {
|
||||
if let ast::ExprKind::Attribute { value, .. } =
|
||||
&targets[0].node
|
||||
{
|
||||
matches!(
|
||||
&value.node,
|
||||
ast::ExprKind::Name {id, ..} if id == "self")
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} =>
|
||||
{
|
||||
return Err("class fields type annotation needed".into())
|
||||
}
|
||||
|
||||
// do nothing
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
Some(result)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// current method all type ok, put the current method into the list
|
||||
if class_methods_parsing_result.iter().any(|(name, _, _)| name == method_name) {
|
||||
return Err("duplicate method definition".into());
|
||||
} else {
|
||||
class_methods_parsing_result.push((
|
||||
method_name.clone(),
|
||||
unifier.add_ty(TypeEnum::TFunc(
|
||||
FunSignature {
|
||||
ret: ret_ty,
|
||||
args: arg_name_tys
|
||||
.into_iter()
|
||||
.map(|(name, ty)| FuncArg { name, ty, default_value: None })
|
||||
.collect_vec(),
|
||||
vars: method_type_var,
|
||||
}
|
||||
.into(),
|
||||
)),
|
||||
*self
|
||||
.class_method_to_def_id
|
||||
.get(&Self::name_mangling(class_name.clone(), method_name))
|
||||
.unwrap(),
|
||||
))
|
||||
}
|
||||
|
||||
// put the fiedlds inside
|
||||
if let Some(class_field_name_tys) = class_field_name_tys {
|
||||
assert!(class_fields_parsing_result.is_empty());
|
||||
class_fields_parsing_result.extend(class_field_name_tys);
|
||||
}
|
||||
} else {
|
||||
// what should we do with `class A: a = 3`?
|
||||
// do nothing, continue the for loop to iterate class ast
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// now it should be confirmed that every
|
||||
// methods and fields of the class can be correctly typed, put the results
|
||||
// into the actual class def method and fields field
|
||||
let (class_def, _) = &def_ast_list[class_ind];
|
||||
let mut class_def = class_def.write();
|
||||
if let TopLevelDef::Class { fields, methods, .. } = class_def.deref_mut() {
|
||||
for (ref n, ref t) in class_fields_parsing_result {
|
||||
fields.push((n.clone(), *t));
|
||||
}
|
||||
for (n, t, id) in &class_methods_parsing_result {
|
||||
methods.push((n.clone(), *t, *id));
|
||||
}
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
// change the signature field of the class methods
|
||||
for (_, ty, id) in &class_methods_parsing_result {
|
||||
let (method_def, _) = &def_ast_list[id.0];
|
||||
let mut method_def = method_def.write();
|
||||
if let TopLevelDef::Function { signature, .. } = method_def.deref_mut() {
|
||||
*signature = *ty;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn analyze_top_level_function(&mut self) -> Result<(), String> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn analyze_top_level_field_instantiation(&mut self) -> Result<(), String> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn check_ty_analyzed(ty: Type, unifier: &mut Unifier, to_be_analyzed: &[DefinitionId]) -> bool {
|
||||
let type_enum = unifier.get_ty(ty);
|
||||
match type_enum.as_ref() {
|
||||
TypeEnum::TObj { obj_id, .. } => !to_be_analyzed.contains(obj_id),
|
||||
TypeEnum::TVirtual { ty } => {
|
||||
if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(*ty).as_ref() {
|
||||
!to_be_analyzed.contains(obj_id)
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
TypeEnum::TVar { .. } => true,
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,216 @@
|
|||
use super::type_inferencer::Inferencer;
|
||||
use super::typedef::Type;
|
||||
use rustpython_parser::ast::{self, Expr, ExprKind, Stmt, StmtKind};
|
||||
use std::iter::once;
|
||||
|
||||
impl<'a> Inferencer<'a> {
|
||||
fn check_pattern(
|
||||
&mut self,
|
||||
pattern: &Expr<Option<Type>>,
|
||||
defined_identifiers: &mut Vec<String>,
|
||||
) -> Result<(), String> {
|
||||
match &pattern.node {
|
||||
ExprKind::Name { id, .. } => {
|
||||
if !defined_identifiers.contains(id) {
|
||||
defined_identifiers.push(id.clone());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
ExprKind::Tuple { elts, .. } => {
|
||||
for elt in elts.iter() {
|
||||
self.check_pattern(elt, defined_identifiers)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
_ => self.check_expr(pattern, defined_identifiers),
|
||||
}
|
||||
}
|
||||
|
||||
fn check_expr(
|
||||
&mut self,
|
||||
expr: &Expr<Option<Type>>,
|
||||
defined_identifiers: &[String],
|
||||
) -> Result<(), String> {
|
||||
// there are some cases where the custom field is None
|
||||
if let Some(ty) = &expr.custom {
|
||||
if !self.unifier.is_concrete(*ty, &self.function_data.bound_variables) {
|
||||
return Err(format!(
|
||||
"expected concrete type at {} but got {}",
|
||||
expr.location,
|
||||
self.unifier.get_ty(*ty).get_type_name()
|
||||
));
|
||||
}
|
||||
}
|
||||
match &expr.node {
|
||||
ExprKind::Name { id, .. } => {
|
||||
if !defined_identifiers.contains(id) {
|
||||
return Err(format!(
|
||||
"unknown identifier {} (use before def?) at {}",
|
||||
id, expr.location
|
||||
));
|
||||
}
|
||||
}
|
||||
ExprKind::List { elts, .. }
|
||||
| ExprKind::Tuple { elts, .. }
|
||||
| ExprKind::BoolOp { values: elts, .. } => {
|
||||
for elt in elts.iter() {
|
||||
self.check_expr(elt, defined_identifiers)?;
|
||||
}
|
||||
}
|
||||
ExprKind::Attribute { value, .. } => {
|
||||
self.check_expr(value, defined_identifiers)?;
|
||||
}
|
||||
ExprKind::BinOp { left, right, .. } => {
|
||||
self.check_expr(left, defined_identifiers)?;
|
||||
self.check_expr(right, defined_identifiers)?;
|
||||
}
|
||||
ExprKind::UnaryOp { operand, .. } => {
|
||||
self.check_expr(operand, defined_identifiers)?;
|
||||
}
|
||||
ExprKind::Compare { left, comparators, .. } => {
|
||||
for elt in once(left.as_ref()).chain(comparators.iter()) {
|
||||
self.check_expr(elt, defined_identifiers)?;
|
||||
}
|
||||
}
|
||||
ExprKind::Subscript { value, slice, .. } => {
|
||||
self.check_expr(value, defined_identifiers)?;
|
||||
self.check_expr(slice, defined_identifiers)?;
|
||||
}
|
||||
ExprKind::IfExp { test, body, orelse } => {
|
||||
self.check_expr(test, defined_identifiers)?;
|
||||
self.check_expr(body, defined_identifiers)?;
|
||||
self.check_expr(orelse, defined_identifiers)?;
|
||||
}
|
||||
ExprKind::Slice { lower, upper, step } => {
|
||||
for elt in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() {
|
||||
self.check_expr(elt, defined_identifiers)?;
|
||||
}
|
||||
}
|
||||
ExprKind::Lambda { args, body } => {
|
||||
let mut defined_identifiers = defined_identifiers.to_vec();
|
||||
for arg in args.args.iter() {
|
||||
if !defined_identifiers.contains(&arg.node.arg) {
|
||||
defined_identifiers.push(arg.node.arg.clone());
|
||||
}
|
||||
}
|
||||
self.check_expr(body, &defined_identifiers)?;
|
||||
}
|
||||
ExprKind::ListComp { elt, generators, .. } => {
|
||||
// in our type inference stage, we already make sure that there is only 1 generator
|
||||
let ast::Comprehension { target, iter, ifs, .. } = &generators[0];
|
||||
self.check_expr(iter, defined_identifiers)?;
|
||||
let mut defined_identifiers = defined_identifiers.to_vec();
|
||||
self.check_pattern(target, &mut defined_identifiers)?;
|
||||
for term in once(elt.as_ref()).chain(ifs.iter()) {
|
||||
self.check_expr(term, &defined_identifiers)?;
|
||||
}
|
||||
}
|
||||
ExprKind::Call { func, args, keywords } => {
|
||||
for expr in once(func.as_ref())
|
||||
.chain(args.iter())
|
||||
.chain(keywords.iter().map(|v| v.node.value.as_ref()))
|
||||
{
|
||||
self.check_expr(expr, defined_identifiers)?;
|
||||
}
|
||||
}
|
||||
ExprKind::Constant { .. } => {}
|
||||
_ => {
|
||||
println!("{:?}", expr.node);
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// check statements for proper identifier def-use and return on all paths
|
||||
fn check_stmt(
|
||||
&mut self,
|
||||
stmt: &Stmt<Option<Type>>,
|
||||
defined_identifiers: &mut Vec<String>,
|
||||
) -> Result<bool, String> {
|
||||
match &stmt.node {
|
||||
StmtKind::For { target, iter, body, orelse, .. } => {
|
||||
self.check_expr(iter, defined_identifiers)?;
|
||||
for stmt in orelse.iter() {
|
||||
self.check_stmt(stmt, defined_identifiers)?;
|
||||
}
|
||||
let mut defined_identifiers = defined_identifiers.clone();
|
||||
self.check_pattern(target, &mut defined_identifiers)?;
|
||||
for stmt in body.iter() {
|
||||
self.check_stmt(stmt, &mut defined_identifiers)?;
|
||||
}
|
||||
Ok(false)
|
||||
}
|
||||
StmtKind::If { test, body, orelse } => {
|
||||
self.check_expr(test, defined_identifiers)?;
|
||||
let mut body_identifiers = defined_identifiers.clone();
|
||||
let mut orelse_identifiers = defined_identifiers.clone();
|
||||
let body_returned = self.check_block(body, &mut body_identifiers)?;
|
||||
let orelse_returned = self.check_block(orelse, &mut orelse_identifiers)?;
|
||||
|
||||
for ident in body_identifiers.iter() {
|
||||
if !defined_identifiers.contains(ident) && orelse_identifiers.contains(ident) {
|
||||
defined_identifiers.push(ident.clone())
|
||||
}
|
||||
}
|
||||
Ok(body_returned && orelse_returned)
|
||||
}
|
||||
StmtKind::While { test, body, orelse } => {
|
||||
self.check_expr(test, defined_identifiers)?;
|
||||
let mut defined_identifiers = defined_identifiers.clone();
|
||||
self.check_block(body, &mut defined_identifiers)?;
|
||||
self.check_block(orelse, &mut defined_identifiers)?;
|
||||
Ok(false)
|
||||
}
|
||||
StmtKind::Expr { value } => {
|
||||
self.check_expr(value, defined_identifiers)?;
|
||||
Ok(false)
|
||||
}
|
||||
StmtKind::Assign { targets, value, .. } => {
|
||||
self.check_expr(value, defined_identifiers)?;
|
||||
for target in targets {
|
||||
self.check_pattern(target, defined_identifiers)?;
|
||||
}
|
||||
Ok(false)
|
||||
}
|
||||
StmtKind::AnnAssign { target, value, .. } => {
|
||||
if let Some(value) = value {
|
||||
self.check_expr(value, defined_identifiers)?;
|
||||
self.check_pattern(target, defined_identifiers)?;
|
||||
}
|
||||
Ok(false)
|
||||
}
|
||||
StmtKind::Return { value } => {
|
||||
if let Some(value) = value {
|
||||
self.check_expr(value, defined_identifiers)?;
|
||||
}
|
||||
Ok(true)
|
||||
}
|
||||
StmtKind::Raise { exc, .. } => {
|
||||
if let Some(value) = exc {
|
||||
self.check_expr(value, defined_identifiers)?;
|
||||
}
|
||||
Ok(true)
|
||||
}
|
||||
// break, raise, etc.
|
||||
_ => Ok(false),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check_block(
|
||||
&mut self,
|
||||
block: &[Stmt<Option<Type>>],
|
||||
defined_identifiers: &mut Vec<String>,
|
||||
) -> Result<bool, String> {
|
||||
let mut ret = false;
|
||||
for stmt in block {
|
||||
if ret {
|
||||
return Err(format!("dead code at {:?}", stmt.location));
|
||||
}
|
||||
if self.check_stmt(stmt, defined_identifiers)? {
|
||||
ret = true;
|
||||
}
|
||||
}
|
||||
Ok(ret)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,322 @@
|
|||
use crate::typecheck::{
|
||||
type_inferencer::*,
|
||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier},
|
||||
};
|
||||
use rustpython_parser::ast;
|
||||
use rustpython_parser::ast::{Cmpop, Operator, Unaryop};
|
||||
use std::borrow::Borrow;
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub fn binop_name(op: &Operator) -> &'static str {
|
||||
match op {
|
||||
Operator::Add => "__add__",
|
||||
Operator::Sub => "__sub__",
|
||||
Operator::Div => "__truediv__",
|
||||
Operator::Mod => "__mod__",
|
||||
Operator::Mult => "__mul__",
|
||||
Operator::Pow => "__pow__",
|
||||
Operator::BitOr => "__or__",
|
||||
Operator::BitXor => "__xor__",
|
||||
Operator::BitAnd => "__and__",
|
||||
Operator::LShift => "__lshift__",
|
||||
Operator::RShift => "__rshift__",
|
||||
Operator::FloorDiv => "__floordiv__",
|
||||
Operator::MatMult => "__matmul__",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn binop_assign_name(op: &Operator) -> &'static str {
|
||||
match op {
|
||||
Operator::Add => "__iadd__",
|
||||
Operator::Sub => "__isub__",
|
||||
Operator::Div => "__itruediv__",
|
||||
Operator::Mod => "__imod__",
|
||||
Operator::Mult => "__imul__",
|
||||
Operator::Pow => "__ipow__",
|
||||
Operator::BitOr => "__ior__",
|
||||
Operator::BitXor => "__ixor__",
|
||||
Operator::BitAnd => "__iand__",
|
||||
Operator::LShift => "__ilshift__",
|
||||
Operator::RShift => "__irshift__",
|
||||
Operator::FloorDiv => "__ifloordiv__",
|
||||
Operator::MatMult => "__imatmul__",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn unaryop_name(op: &Unaryop) -> &'static str {
|
||||
match op {
|
||||
Unaryop::UAdd => "__pos__",
|
||||
Unaryop::USub => "__neg__",
|
||||
Unaryop::Not => "__not__",
|
||||
Unaryop::Invert => "__inv__",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn comparison_name(op: &Cmpop) -> Option<&'static str> {
|
||||
match op {
|
||||
Cmpop::Lt => Some("__lt__"),
|
||||
Cmpop::LtE => Some("__le__"),
|
||||
Cmpop::Gt => Some("__gt__"),
|
||||
Cmpop::GtE => Some("__ge__"),
|
||||
Cmpop::Eq => Some("__eq__"),
|
||||
Cmpop::NotEq => Some("__ne__"),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn impl_binop(
|
||||
unifier: &mut Unifier,
|
||||
store: &PrimitiveStore,
|
||||
ty: Type,
|
||||
other_ty: &[Type],
|
||||
ret_ty: Type,
|
||||
ops: &[ast::Operator],
|
||||
) {
|
||||
if let TypeEnum::TObj { fields, .. } = unifier.get_ty(ty).borrow() {
|
||||
let (other_ty, other_var_id) = if other_ty.len() == 1 {
|
||||
(other_ty[0], None)
|
||||
} else {
|
||||
let (ty, var_id) = unifier.get_fresh_var_with_range(other_ty);
|
||||
(ty, Some(var_id))
|
||||
};
|
||||
let function_vars = if let Some(var_id) = other_var_id {
|
||||
vec![(var_id, other_ty)].into_iter().collect::<HashMap<_, _>>()
|
||||
} else {
|
||||
HashMap::new()
|
||||
};
|
||||
for op in ops {
|
||||
fields.borrow_mut().insert(binop_name(op).into(), {
|
||||
unifier.add_ty(TypeEnum::TFunc(
|
||||
FunSignature {
|
||||
ret: ret_ty,
|
||||
vars: function_vars.clone(),
|
||||
args: vec![FuncArg {
|
||||
ty: other_ty,
|
||||
default_value: None,
|
||||
name: "other".into(),
|
||||
}],
|
||||
}
|
||||
.into(),
|
||||
))
|
||||
});
|
||||
|
||||
fields.borrow_mut().insert(binop_assign_name(op).into(), {
|
||||
unifier.add_ty(TypeEnum::TFunc(
|
||||
FunSignature {
|
||||
ret: store.none,
|
||||
vars: function_vars.clone(),
|
||||
args: vec![FuncArg {
|
||||
ty: other_ty,
|
||||
default_value: None,
|
||||
name: "other".into(),
|
||||
}],
|
||||
}
|
||||
.into(),
|
||||
))
|
||||
});
|
||||
}
|
||||
} else {
|
||||
unreachable!("")
|
||||
}
|
||||
}
|
||||
|
||||
pub fn impl_unaryop(
|
||||
unifier: &mut Unifier,
|
||||
_store: &PrimitiveStore,
|
||||
ty: Type,
|
||||
ret_ty: Type,
|
||||
ops: &[ast::Unaryop],
|
||||
) {
|
||||
if let TypeEnum::TObj { fields, .. } = unifier.get_ty(ty).borrow() {
|
||||
for op in ops {
|
||||
fields.borrow_mut().insert(
|
||||
unaryop_name(op).into(),
|
||||
unifier.add_ty(TypeEnum::TFunc(
|
||||
FunSignature { ret: ret_ty, vars: HashMap::new(), args: vec![] }.into(),
|
||||
)),
|
||||
);
|
||||
}
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn impl_cmpop(
|
||||
unifier: &mut Unifier,
|
||||
store: &PrimitiveStore,
|
||||
ty: Type,
|
||||
other_ty: Type,
|
||||
ops: &[ast::Cmpop],
|
||||
) {
|
||||
if let TypeEnum::TObj { fields, .. } = unifier.get_ty(ty).borrow() {
|
||||
for op in ops {
|
||||
fields.borrow_mut().insert(
|
||||
comparison_name(op).unwrap().into(),
|
||||
unifier.add_ty(TypeEnum::TFunc(
|
||||
FunSignature {
|
||||
ret: store.bool,
|
||||
vars: HashMap::new(),
|
||||
args: vec![FuncArg {
|
||||
ty: other_ty,
|
||||
default_value: None,
|
||||
name: "other".into(),
|
||||
}],
|
||||
}
|
||||
.into(),
|
||||
)),
|
||||
);
|
||||
}
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
/// Add, Sub, Mult, Pow
|
||||
pub fn impl_basic_arithmetic(
|
||||
unifier: &mut Unifier,
|
||||
store: &PrimitiveStore,
|
||||
ty: Type,
|
||||
other_ty: &[Type],
|
||||
ret_ty: Type,
|
||||
) {
|
||||
impl_binop(
|
||||
unifier,
|
||||
store,
|
||||
ty,
|
||||
other_ty,
|
||||
ret_ty,
|
||||
&[ast::Operator::Add, ast::Operator::Sub, ast::Operator::Mult],
|
||||
)
|
||||
}
|
||||
|
||||
pub fn impl_pow(
|
||||
unifier: &mut Unifier,
|
||||
store: &PrimitiveStore,
|
||||
ty: Type,
|
||||
other_ty: &[Type],
|
||||
ret_ty: Type,
|
||||
) {
|
||||
impl_binop(unifier, store, ty, other_ty, ret_ty, &[ast::Operator::Pow])
|
||||
}
|
||||
|
||||
/// BitOr, BitXor, BitAnd
|
||||
pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
|
||||
impl_binop(
|
||||
unifier,
|
||||
store,
|
||||
ty,
|
||||
&[ty],
|
||||
ty,
|
||||
&[ast::Operator::BitAnd, ast::Operator::BitOr, ast::Operator::BitXor],
|
||||
)
|
||||
}
|
||||
|
||||
/// LShift, RShift
|
||||
pub fn impl_bitwise_shift(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
|
||||
impl_binop(unifier, store, ty, &[ty], ty, &[ast::Operator::LShift, ast::Operator::RShift])
|
||||
}
|
||||
|
||||
/// Div
|
||||
pub fn impl_div(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type]) {
|
||||
impl_binop(unifier, store, ty, other_ty, store.float, &[ast::Operator::Div])
|
||||
}
|
||||
|
||||
/// FloorDiv
|
||||
pub fn impl_floordiv(
|
||||
unifier: &mut Unifier,
|
||||
store: &PrimitiveStore,
|
||||
ty: Type,
|
||||
other_ty: &[Type],
|
||||
ret_ty: Type,
|
||||
) {
|
||||
impl_binop(unifier, store, ty, other_ty, ret_ty, &[ast::Operator::FloorDiv])
|
||||
}
|
||||
|
||||
/// Mod
|
||||
pub fn impl_mod(
|
||||
unifier: &mut Unifier,
|
||||
store: &PrimitiveStore,
|
||||
ty: Type,
|
||||
other_ty: &[Type],
|
||||
ret_ty: Type,
|
||||
) {
|
||||
impl_binop(unifier, store, ty, other_ty, ret_ty, &[ast::Operator::Mod])
|
||||
}
|
||||
|
||||
/// UAdd, USub
|
||||
pub fn impl_sign(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
|
||||
impl_unaryop(unifier, store, ty, ty, &[ast::Unaryop::UAdd, ast::Unaryop::USub])
|
||||
}
|
||||
|
||||
/// Invert
|
||||
pub fn impl_invert(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
|
||||
impl_unaryop(unifier, store, ty, ty, &[ast::Unaryop::Invert])
|
||||
}
|
||||
|
||||
/// Not
|
||||
pub fn impl_not(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
|
||||
impl_unaryop(unifier, store, ty, store.bool, &[ast::Unaryop::Not])
|
||||
}
|
||||
|
||||
/// Lt, LtE, Gt, GtE
|
||||
pub fn impl_comparison(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type) {
|
||||
impl_cmpop(
|
||||
unifier,
|
||||
store,
|
||||
ty,
|
||||
other_ty,
|
||||
&[ast::Cmpop::Lt, ast::Cmpop::Gt, ast::Cmpop::LtE, ast::Cmpop::GtE],
|
||||
)
|
||||
}
|
||||
|
||||
/// Eq, NotEq
|
||||
pub fn impl_eq(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
|
||||
impl_cmpop(unifier, store, ty, ty, &[ast::Cmpop::Eq, ast::Cmpop::NotEq])
|
||||
}
|
||||
|
||||
pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) {
|
||||
let PrimitiveStore { int32: int32_t, int64: int64_t, float: float_t, bool: bool_t, .. } =
|
||||
*store;
|
||||
/* int32 ======== */
|
||||
impl_basic_arithmetic(unifier, store, int32_t, &[int32_t], int32_t);
|
||||
impl_pow(unifier, store, int32_t, &[int32_t], int32_t);
|
||||
impl_bitwise_arithmetic(unifier, store, int32_t);
|
||||
impl_bitwise_shift(unifier, store, int32_t);
|
||||
impl_div(unifier, store, int32_t, &[int32_t]);
|
||||
impl_floordiv(unifier, store, int32_t, &[int32_t], int32_t);
|
||||
impl_mod(unifier, store, int32_t, &[int32_t], int32_t);
|
||||
impl_sign(unifier, store, int32_t);
|
||||
impl_invert(unifier, store, int32_t);
|
||||
impl_not(unifier, store, int32_t);
|
||||
impl_comparison(unifier, store, int32_t, int32_t);
|
||||
impl_eq(unifier, store, int32_t);
|
||||
|
||||
/* int64 ======== */
|
||||
impl_basic_arithmetic(unifier, store, int64_t, &[int64_t], int64_t);
|
||||
impl_pow(unifier, store, int64_t, &[int64_t], int64_t);
|
||||
impl_bitwise_arithmetic(unifier, store, int64_t);
|
||||
impl_bitwise_shift(unifier, store, int64_t);
|
||||
impl_div(unifier, store, int64_t, &[int64_t]);
|
||||
impl_floordiv(unifier, store, int64_t, &[int64_t], int64_t);
|
||||
impl_mod(unifier, store, int64_t, &[int64_t], int64_t);
|
||||
impl_sign(unifier, store, int64_t);
|
||||
impl_invert(unifier, store, int64_t);
|
||||
impl_not(unifier, store, int64_t);
|
||||
impl_comparison(unifier, store, int64_t, int64_t);
|
||||
impl_eq(unifier, store, int64_t);
|
||||
|
||||
/* float ======== */
|
||||
impl_basic_arithmetic(unifier, store, float_t, &[float_t], float_t);
|
||||
impl_pow(unifier, store, float_t, &[int32_t, float_t], float_t);
|
||||
impl_div(unifier, store, float_t, &[float_t]);
|
||||
impl_floordiv(unifier, store, float_t, &[float_t], float_t);
|
||||
impl_mod(unifier, store, float_t, &[float_t], float_t);
|
||||
impl_sign(unifier, store, float_t);
|
||||
impl_not(unifier, store, float_t);
|
||||
impl_comparison(unifier, store, float_t, float_t);
|
||||
impl_eq(unifier, store, float_t);
|
||||
|
||||
/* bool ======== */
|
||||
impl_not(unifier, store, bool_t);
|
||||
impl_eq(unifier, store, bool_t);
|
||||
}
|
|
@ -0,0 +1,5 @@
|
|||
mod function_check;
|
||||
pub mod magic_methods;
|
||||
pub mod type_inferencer;
|
||||
pub mod typedef;
|
||||
mod unification_table;
|
|
@ -0,0 +1,582 @@
|
|||
use std::collections::HashMap;
|
||||
use std::convert::{From, TryInto};
|
||||
use std::iter::once;
|
||||
use std::{cell::RefCell, sync::Arc};
|
||||
|
||||
use super::typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier};
|
||||
use super::{magic_methods::*, typedef::CallId};
|
||||
use crate::{symbol_resolver::SymbolResolver, top_level::TopLevelContext};
|
||||
use itertools::izip;
|
||||
use rustpython_parser::ast::{
|
||||
self,
|
||||
fold::{self, Fold},
|
||||
Arguments, Comprehension, ExprKind, Located, Location,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
|
||||
#[derive(PartialEq, Eq, Hash, Copy, Clone, Debug)]
|
||||
pub struct CodeLocation {
|
||||
row: usize,
|
||||
col: usize,
|
||||
}
|
||||
|
||||
impl From<Location> for CodeLocation {
|
||||
fn from(loc: Location) -> CodeLocation {
|
||||
CodeLocation { row: loc.row(), col: loc.column() }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct PrimitiveStore {
|
||||
pub int32: Type,
|
||||
pub int64: Type,
|
||||
pub float: Type,
|
||||
pub bool: Type,
|
||||
pub none: Type,
|
||||
}
|
||||
|
||||
pub struct FunctionData {
|
||||
pub resolver: Arc<dyn SymbolResolver + Send + Sync>,
|
||||
pub return_type: Option<Type>,
|
||||
pub bound_variables: Vec<Type>,
|
||||
}
|
||||
|
||||
pub struct Inferencer<'a> {
|
||||
pub top_level: &'a TopLevelContext,
|
||||
pub function_data: &'a mut FunctionData,
|
||||
pub unifier: &'a mut Unifier,
|
||||
pub primitives: &'a PrimitiveStore,
|
||||
pub virtual_checks: &'a mut Vec<(Type, Type)>,
|
||||
pub variable_mapping: HashMap<String, Type>,
|
||||
pub calls: &'a mut HashMap<CodeLocation, CallId>,
|
||||
}
|
||||
|
||||
struct NaiveFolder();
|
||||
impl fold::Fold<()> for NaiveFolder {
|
||||
type TargetU = Option<Type>;
|
||||
type Error = String;
|
||||
fn map_user(&mut self, _: ()) -> Result<Self::TargetU, Self::Error> {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> fold::Fold<()> for Inferencer<'a> {
|
||||
type TargetU = Option<Type>;
|
||||
type Error = String;
|
||||
|
||||
fn map_user(&mut self, _: ()) -> Result<Self::TargetU, Self::Error> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn fold_stmt(&mut self, node: ast::Stmt<()>) -> Result<ast::Stmt<Self::TargetU>, Self::Error> {
|
||||
let stmt = match node.node {
|
||||
// we don't want fold over type annotation
|
||||
ast::StmtKind::AnnAssign { target, annotation, value, simple } => {
|
||||
let target = Box::new(self.fold_expr(*target)?);
|
||||
let value = if let Some(v) = value {
|
||||
let ty = Box::new(self.fold_expr(*v)?);
|
||||
self.unifier.unify(target.custom.unwrap(), ty.custom.unwrap())?;
|
||||
Some(ty)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let annotation_type = self.function_data.resolver.parse_type_annotation(
|
||||
self.top_level,
|
||||
self.unifier,
|
||||
&self.primitives,
|
||||
annotation.as_ref(),
|
||||
)?;
|
||||
self.unifier.unify(annotation_type, target.custom.unwrap())?;
|
||||
let annotation = Box::new(NaiveFolder().fold_expr(*annotation)?);
|
||||
Located {
|
||||
location: node.location,
|
||||
custom: None,
|
||||
node: ast::StmtKind::AnnAssign { target, annotation, value, simple },
|
||||
}
|
||||
}
|
||||
_ => fold::fold_stmt(self, node)?,
|
||||
};
|
||||
match &stmt.node {
|
||||
ast::StmtKind::For { target, iter, .. } => {
|
||||
let list = self.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() });
|
||||
self.unifier.unify(list, iter.custom.unwrap())?;
|
||||
}
|
||||
ast::StmtKind::If { test, .. } | ast::StmtKind::While { test, .. } => {
|
||||
self.unifier.unify(test.custom.unwrap(), self.primitives.bool)?;
|
||||
}
|
||||
ast::StmtKind::Assign { targets, value, .. } => {
|
||||
for target in targets.iter() {
|
||||
self.unifier.unify(target.custom.unwrap(), value.custom.unwrap())?;
|
||||
}
|
||||
}
|
||||
ast::StmtKind::AnnAssign { .. } | ast::StmtKind::Expr { .. } => {}
|
||||
ast::StmtKind::Break | ast::StmtKind::Continue => {}
|
||||
ast::StmtKind::Return { value } => match (value, self.function_data.return_type) {
|
||||
(Some(v), Some(v1)) => {
|
||||
self.unifier.unify(v.custom.unwrap(), v1)?;
|
||||
}
|
||||
(Some(_), None) => {
|
||||
return Err("Unexpected return value".to_string());
|
||||
}
|
||||
(None, Some(_)) => {
|
||||
return Err("Expected return value".to_string());
|
||||
}
|
||||
(None, None) => {}
|
||||
},
|
||||
_ => return Err("Unsupported statement type".to_string()),
|
||||
};
|
||||
Ok(stmt)
|
||||
}
|
||||
|
||||
fn fold_expr(&mut self, node: ast::Expr<()>) -> Result<ast::Expr<Self::TargetU>, Self::Error> {
|
||||
let expr = match node.node {
|
||||
ast::ExprKind::Call { func, args, keywords } => {
|
||||
return self.fold_call(node.location, *func, args, keywords);
|
||||
}
|
||||
ast::ExprKind::Lambda { args, body } => {
|
||||
return self.fold_lambda(node.location, *args, *body);
|
||||
}
|
||||
ast::ExprKind::ListComp { elt, generators } => {
|
||||
return self.fold_listcomp(node.location, *elt, generators);
|
||||
}
|
||||
_ => fold::fold_expr(self, node)?,
|
||||
};
|
||||
let custom = match &expr.node {
|
||||
ast::ExprKind::Constant { value, .. } => Some(self.infer_constant(value)?),
|
||||
ast::ExprKind::Name { id, .. } => Some(self.infer_identifier(id)?),
|
||||
ast::ExprKind::List { elts, .. } => Some(self.infer_list(elts)?),
|
||||
ast::ExprKind::Tuple { elts, .. } => Some(self.infer_tuple(elts)?),
|
||||
ast::ExprKind::Attribute { value, attr, ctx: _ } => {
|
||||
Some(self.infer_attribute(value, attr)?)
|
||||
}
|
||||
ast::ExprKind::BoolOp { values, .. } => Some(self.infer_bool_ops(values)?),
|
||||
ast::ExprKind::BinOp { left, op, right } => Some(self.infer_bin_ops(left, op, right)?),
|
||||
ast::ExprKind::UnaryOp { op, operand } => Some(self.infer_unary_ops(op, operand)?),
|
||||
ast::ExprKind::Compare { left, ops, comparators } => {
|
||||
Some(self.infer_compare(left, ops, comparators)?)
|
||||
}
|
||||
ast::ExprKind::Subscript { value, slice, .. } => {
|
||||
Some(self.infer_subscript(value.as_ref(), slice.as_ref())?)
|
||||
}
|
||||
ast::ExprKind::IfExp { test, body, orelse } => {
|
||||
Some(self.infer_if_expr(test, body.as_ref(), orelse.as_ref())?)
|
||||
}
|
||||
ast::ExprKind::ListComp { .. }
|
||||
| ast::ExprKind::Lambda { .. }
|
||||
| ast::ExprKind::Call { .. } => expr.custom, // already computed
|
||||
ast::ExprKind::Slice { .. } => None, // we don't need it for slice
|
||||
_ => return Err("not supported yet".into()),
|
||||
};
|
||||
Ok(ast::Expr { custom, location: expr.location, node: expr.node })
|
||||
}
|
||||
}
|
||||
|
||||
type InferenceResult = Result<Type, String>;
|
||||
|
||||
impl<'a> Inferencer<'a> {
|
||||
/// Constrain a <: b
|
||||
/// Currently implemented as unification
|
||||
fn constrain(&mut self, a: Type, b: Type) -> Result<(), String> {
|
||||
self.unifier.unify(a, b)
|
||||
}
|
||||
|
||||
fn build_method_call(
|
||||
&mut self,
|
||||
location: Location,
|
||||
method: String,
|
||||
obj: Type,
|
||||
params: Vec<Type>,
|
||||
ret: Type,
|
||||
) -> InferenceResult {
|
||||
let call = self.unifier.add_call(Call {
|
||||
posargs: params,
|
||||
kwargs: HashMap::new(),
|
||||
ret,
|
||||
fun: RefCell::new(None),
|
||||
});
|
||||
self.calls.insert(location.into(), call);
|
||||
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
|
||||
let fields = once((method, call)).collect();
|
||||
let record = self.unifier.add_record(fields);
|
||||
self.constrain(obj, record)?;
|
||||
Ok(ret)
|
||||
}
|
||||
|
||||
fn fold_lambda(
|
||||
&mut self,
|
||||
location: Location,
|
||||
args: Arguments,
|
||||
body: ast::Expr<()>,
|
||||
) -> Result<ast::Expr<Option<Type>>, String> {
|
||||
if !args.posonlyargs.is_empty()
|
||||
|| args.vararg.is_some()
|
||||
|| !args.kwonlyargs.is_empty()
|
||||
|| args.kwarg.is_some()
|
||||
|| !args.defaults.is_empty()
|
||||
{
|
||||
// actually I'm not sure whether programs violating this is a valid python program.
|
||||
return Err(
|
||||
"We only support positional or keyword arguments without defaults for lambdas."
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
|
||||
let fn_args: Vec<_> = args
|
||||
.args
|
||||
.iter()
|
||||
.map(|v| (v.node.arg.clone(), self.unifier.get_fresh_var().0))
|
||||
.collect();
|
||||
let mut variable_mapping = self.variable_mapping.clone();
|
||||
variable_mapping.extend(fn_args.iter().cloned());
|
||||
let ret = self.unifier.get_fresh_var().0;
|
||||
let mut new_context = Inferencer {
|
||||
function_data: self.function_data,
|
||||
unifier: self.unifier,
|
||||
primitives: self.primitives,
|
||||
virtual_checks: self.virtual_checks,
|
||||
calls: self.calls,
|
||||
top_level: self.top_level,
|
||||
variable_mapping,
|
||||
};
|
||||
let fun = FunSignature {
|
||||
args: fn_args
|
||||
.iter()
|
||||
.map(|(k, ty)| FuncArg { name: k.clone(), ty: *ty, default_value: None })
|
||||
.collect(),
|
||||
ret,
|
||||
vars: Default::default(),
|
||||
};
|
||||
let body = new_context.fold_expr(body)?;
|
||||
new_context.unifier.unify(fun.ret, body.custom.unwrap())?;
|
||||
let mut args = new_context.fold_arguments(args)?;
|
||||
for (arg, (name, ty)) in args.args.iter_mut().zip(fn_args.iter()) {
|
||||
assert_eq!(&arg.node.arg, name);
|
||||
arg.custom = Some(*ty);
|
||||
}
|
||||
Ok(Located {
|
||||
location,
|
||||
node: ExprKind::Lambda { args: args.into(), body: body.into() },
|
||||
custom: Some(self.unifier.add_ty(TypeEnum::TFunc(fun.into()))),
|
||||
})
|
||||
}
|
||||
|
||||
fn fold_listcomp(
|
||||
&mut self,
|
||||
location: Location,
|
||||
elt: ast::Expr<()>,
|
||||
mut generators: Vec<Comprehension>,
|
||||
) -> Result<ast::Expr<Option<Type>>, String> {
|
||||
if generators.len() != 1 {
|
||||
return Err(
|
||||
"Only 1 generator statement for list comprehension is supported.".to_string()
|
||||
);
|
||||
}
|
||||
let variable_mapping = self.variable_mapping.clone();
|
||||
let mut new_context = Inferencer {
|
||||
function_data: self.function_data,
|
||||
unifier: self.unifier,
|
||||
virtual_checks: self.virtual_checks,
|
||||
top_level: self.top_level,
|
||||
variable_mapping,
|
||||
primitives: self.primitives,
|
||||
calls: self.calls,
|
||||
};
|
||||
let elt = new_context.fold_expr(elt)?;
|
||||
let generator = generators.pop().unwrap();
|
||||
if generator.is_async {
|
||||
return Err("Async iterator not supported.".to_string());
|
||||
}
|
||||
let target = new_context.fold_expr(*generator.target)?;
|
||||
let iter = new_context.fold_expr(*generator.iter)?;
|
||||
let ifs: Vec<_> = generator
|
||||
.ifs
|
||||
.into_iter()
|
||||
.map(|v| new_context.fold_expr(v))
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
// iter should be a list of targets...
|
||||
// actually it should be an iterator of targets, but we don't have iter type for now
|
||||
let list = new_context.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() });
|
||||
new_context.unifier.unify(iter.custom.unwrap(), list)?;
|
||||
// if conditions should be bool
|
||||
for v in ifs.iter() {
|
||||
new_context.unifier.unify(v.custom.unwrap(), new_context.primitives.bool)?;
|
||||
}
|
||||
|
||||
Ok(Located {
|
||||
location,
|
||||
custom: Some(new_context.unifier.add_ty(TypeEnum::TList { ty: elt.custom.unwrap() })),
|
||||
node: ExprKind::ListComp {
|
||||
elt: Box::new(elt),
|
||||
generators: vec![ast::Comprehension {
|
||||
target: Box::new(target),
|
||||
iter: Box::new(iter),
|
||||
ifs,
|
||||
is_async: false,
|
||||
}],
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn fold_call(
|
||||
&mut self,
|
||||
location: Location,
|
||||
func: ast::Expr<()>,
|
||||
mut args: Vec<ast::Expr<()>>,
|
||||
keywords: Vec<Located<ast::KeywordData>>,
|
||||
) -> Result<ast::Expr<Option<Type>>, String> {
|
||||
let func =
|
||||
if let Located { location: func_location, custom, node: ExprKind::Name { id, ctx } } =
|
||||
func
|
||||
{
|
||||
// handle special functions that cannot be typed in the usual way...
|
||||
if id == "virtual" {
|
||||
if args.is_empty() || args.len() > 2 || !keywords.is_empty() {
|
||||
return Err(
|
||||
"`virtual` can only accept 1/2 positional arguments.".to_string()
|
||||
);
|
||||
}
|
||||
let arg0 = self.fold_expr(args.remove(0))?;
|
||||
let ty = if let Some(arg) = args.pop() {
|
||||
self.function_data.resolver.parse_type_annotation(
|
||||
self.top_level,
|
||||
self.unifier,
|
||||
self.primitives,
|
||||
&arg,
|
||||
)?
|
||||
} else {
|
||||
self.unifier.get_fresh_var().0
|
||||
};
|
||||
self.virtual_checks.push((arg0.custom.unwrap(), ty));
|
||||
let custom = Some(self.unifier.add_ty(TypeEnum::TVirtual { ty }));
|
||||
return Ok(Located {
|
||||
location,
|
||||
custom,
|
||||
node: ExprKind::Call {
|
||||
func: Box::new(Located {
|
||||
custom: None,
|
||||
location: func.location,
|
||||
node: ExprKind::Name { id, ctx },
|
||||
}),
|
||||
args: vec![arg0],
|
||||
keywords: vec![],
|
||||
},
|
||||
});
|
||||
}
|
||||
// int64 is special because its argument can be a constant larger than int32
|
||||
if id == "int64" && args.len() == 1 {
|
||||
if let ExprKind::Constant { value: ast::Constant::Int(val), kind } =
|
||||
&args[0].node
|
||||
{
|
||||
let int64: Result<i64, _> = val.try_into();
|
||||
let custom;
|
||||
if int64.is_ok() {
|
||||
custom = Some(self.primitives.int64);
|
||||
} else {
|
||||
return Err("Integer out of bound".into());
|
||||
}
|
||||
return Ok(Located {
|
||||
location: args[0].location,
|
||||
custom,
|
||||
node: ExprKind::Constant {
|
||||
value: ast::Constant::Int(val.clone()),
|
||||
kind: kind.clone(),
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
Located { location: func_location, custom, node: ExprKind::Name { id, ctx } }
|
||||
} else {
|
||||
func
|
||||
};
|
||||
let func = Box::new(self.fold_expr(func)?);
|
||||
let args = args.into_iter().map(|v| self.fold_expr(v)).collect::<Result<Vec<_>, _>>()?;
|
||||
let keywords = keywords
|
||||
.into_iter()
|
||||
.map(|v| fold::fold_keyword(self, v))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let ret = self.unifier.get_fresh_var().0;
|
||||
let call = self.unifier.add_call(Call {
|
||||
posargs: args.iter().map(|v| v.custom.unwrap()).collect(),
|
||||
kwargs: keywords
|
||||
.iter()
|
||||
.map(|v| (v.node.arg.as_ref().unwrap().clone(), v.custom.unwrap()))
|
||||
.collect(),
|
||||
fun: RefCell::new(None),
|
||||
ret,
|
||||
});
|
||||
self.calls.insert(location.into(), call);
|
||||
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
|
||||
self.unifier.unify(func.custom.unwrap(), call)?;
|
||||
|
||||
Ok(Located { location, custom: Some(ret), node: ExprKind::Call { func, args, keywords } })
|
||||
}
|
||||
|
||||
fn infer_identifier(&mut self, id: &str) -> InferenceResult {
|
||||
if let Some(ty) = self.variable_mapping.get(id) {
|
||||
Ok(*ty)
|
||||
} else {
|
||||
Ok(self
|
||||
.function_data
|
||||
.resolver
|
||||
.get_symbol_type(self.unifier, self.primitives, id)
|
||||
.unwrap_or_else(|| {
|
||||
let ty = self.unifier.get_fresh_var().0;
|
||||
self.variable_mapping.insert(id.to_string(), ty);
|
||||
ty
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_constant(&mut self, constant: &ast::Constant) -> InferenceResult {
|
||||
match constant {
|
||||
ast::Constant::Bool(_) => Ok(self.primitives.bool),
|
||||
ast::Constant::Int(val) => {
|
||||
let int32: Result<i32, _> = val.try_into();
|
||||
// int64 would be handled separately in functions
|
||||
if int32.is_ok() {
|
||||
Ok(self.primitives.int32)
|
||||
} else {
|
||||
Err("Integer out of bound".into())
|
||||
}
|
||||
}
|
||||
ast::Constant::Float(_) => Ok(self.primitives.float),
|
||||
ast::Constant::Tuple(vals) => {
|
||||
let ty: Result<Vec<_>, _> = vals.iter().map(|x| self.infer_constant(x)).collect();
|
||||
Ok(self.unifier.add_ty(TypeEnum::TTuple { ty: ty? }))
|
||||
}
|
||||
_ => Err("not supported".into()),
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_list(&mut self, elts: &[ast::Expr<Option<Type>>]) -> InferenceResult {
|
||||
let (ty, _) = self.unifier.get_fresh_var();
|
||||
for t in elts.iter() {
|
||||
self.unifier.unify(ty, t.custom.unwrap())?;
|
||||
}
|
||||
Ok(self.unifier.add_ty(TypeEnum::TList { ty }))
|
||||
}
|
||||
|
||||
fn infer_tuple(&mut self, elts: &[ast::Expr<Option<Type>>]) -> InferenceResult {
|
||||
let ty = elts.iter().map(|x| x.custom.unwrap()).collect();
|
||||
Ok(self.unifier.add_ty(TypeEnum::TTuple { ty }))
|
||||
}
|
||||
|
||||
fn infer_attribute(&mut self, value: &ast::Expr<Option<Type>>, attr: &str) -> InferenceResult {
|
||||
let (attr_ty, _) = self.unifier.get_fresh_var();
|
||||
let fields = once((attr.to_string(), attr_ty)).collect();
|
||||
let record = self.unifier.add_record(fields);
|
||||
self.constrain(value.custom.unwrap(), record)?;
|
||||
Ok(attr_ty)
|
||||
}
|
||||
|
||||
fn infer_bool_ops(&mut self, values: &[ast::Expr<Option<Type>>]) -> InferenceResult {
|
||||
let b = self.primitives.bool;
|
||||
for v in values {
|
||||
self.constrain(v.custom.unwrap(), b)?;
|
||||
}
|
||||
Ok(b)
|
||||
}
|
||||
|
||||
fn infer_bin_ops(
|
||||
&mut self,
|
||||
left: &ast::Expr<Option<Type>>,
|
||||
op: &ast::Operator,
|
||||
right: &ast::Expr<Option<Type>>,
|
||||
) -> InferenceResult {
|
||||
let method = binop_name(op);
|
||||
let ret = self.unifier.get_fresh_var().0;
|
||||
self.build_method_call(
|
||||
left.location,
|
||||
method.to_string(),
|
||||
left.custom.unwrap(),
|
||||
vec![right.custom.unwrap()],
|
||||
ret,
|
||||
)
|
||||
}
|
||||
|
||||
fn infer_unary_ops(
|
||||
&mut self,
|
||||
op: &ast::Unaryop,
|
||||
operand: &ast::Expr<Option<Type>>,
|
||||
) -> InferenceResult {
|
||||
let method = unaryop_name(op);
|
||||
let ret = self.unifier.get_fresh_var().0;
|
||||
self.build_method_call(
|
||||
operand.location,
|
||||
method.to_string(),
|
||||
operand.custom.unwrap(),
|
||||
vec![],
|
||||
ret,
|
||||
)
|
||||
}
|
||||
|
||||
fn infer_compare(
|
||||
&mut self,
|
||||
left: &ast::Expr<Option<Type>>,
|
||||
ops: &[ast::Cmpop],
|
||||
comparators: &[ast::Expr<Option<Type>>],
|
||||
) -> InferenceResult {
|
||||
let boolean = self.primitives.bool;
|
||||
for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) {
|
||||
let method =
|
||||
comparison_name(c).ok_or_else(|| "unsupported comparator".to_string())?.to_string();
|
||||
self.build_method_call(
|
||||
a.location,
|
||||
method,
|
||||
a.custom.unwrap(),
|
||||
vec![b.custom.unwrap()],
|
||||
boolean,
|
||||
)?;
|
||||
}
|
||||
Ok(boolean)
|
||||
}
|
||||
|
||||
fn infer_subscript(
|
||||
&mut self,
|
||||
value: &ast::Expr<Option<Type>>,
|
||||
slice: &ast::Expr<Option<Type>>,
|
||||
) -> InferenceResult {
|
||||
let ty = self.unifier.get_fresh_var().0;
|
||||
match &slice.node {
|
||||
ast::ExprKind::Slice { lower, upper, step } => {
|
||||
for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() {
|
||||
self.constrain(v.custom.unwrap(), self.primitives.int32)?;
|
||||
}
|
||||
let list = self.unifier.add_ty(TypeEnum::TList { ty });
|
||||
self.constrain(value.custom.unwrap(), list)?;
|
||||
Ok(list)
|
||||
}
|
||||
ast::ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
|
||||
// the index is a constant, so value can be a sequence.
|
||||
let ind: i32 = val.try_into().map_err(|_| "Index must be int32".to_string())?;
|
||||
let map = once((ind, ty)).collect();
|
||||
let seq = self.unifier.add_sequence(map);
|
||||
self.constrain(value.custom.unwrap(), seq)?;
|
||||
Ok(ty)
|
||||
}
|
||||
_ => {
|
||||
// the index is not a constant, so value can only be a list
|
||||
self.constrain(slice.custom.unwrap(), self.primitives.int32)?;
|
||||
let list = self.unifier.add_ty(TypeEnum::TList { ty });
|
||||
self.constrain(value.custom.unwrap(), list)?;
|
||||
Ok(ty)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_if_expr(
|
||||
&mut self,
|
||||
test: &ast::Expr<Option<Type>>,
|
||||
body: &ast::Expr<Option<Type>>,
|
||||
orelse: &ast::Expr<Option<Type>>,
|
||||
) -> InferenceResult {
|
||||
self.constrain(test.custom.unwrap(), self.primitives.bool)?;
|
||||
let ty = self.unifier.get_fresh_var().0;
|
||||
self.constrain(body.custom.unwrap(), ty)?;
|
||||
self.constrain(orelse.custom.unwrap(), ty)?;
|
||||
Ok(ty)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,546 @@
|
|||
use super::super::typedef::*;
|
||||
use super::*;
|
||||
use crate::symbol_resolver::*;
|
||||
use crate::top_level::DefinitionId;
|
||||
use crate::{location::Location, top_level::TopLevelDef};
|
||||
use indoc::indoc;
|
||||
use itertools::zip;
|
||||
use parking_lot::RwLock;
|
||||
use rustpython_parser::parser::parse_program;
|
||||
use test_case::test_case;
|
||||
|
||||
struct Resolver {
|
||||
id_to_type: HashMap<String, Type>,
|
||||
id_to_def: HashMap<String, DefinitionId>,
|
||||
class_names: HashMap<String, Type>,
|
||||
}
|
||||
|
||||
impl SymbolResolver for Resolver {
|
||||
fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option<Type> {
|
||||
self.id_to_type.get(str).cloned()
|
||||
}
|
||||
|
||||
fn get_symbol_value(&self, _: &str) -> Option<SymbolValue> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn get_symbol_location(&self, _: &str) -> Option<Location> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn get_identifier_def(&self, id: &str) -> Option<DefinitionId> {
|
||||
self.id_to_def.get(id).cloned()
|
||||
}
|
||||
}
|
||||
|
||||
struct TestEnvironment {
|
||||
pub unifier: Unifier,
|
||||
pub function_data: FunctionData,
|
||||
pub primitives: PrimitiveStore,
|
||||
pub id_to_name: HashMap<usize, String>,
|
||||
pub identifier_mapping: HashMap<String, Type>,
|
||||
pub virtual_checks: Vec<(Type, Type)>,
|
||||
pub calls: HashMap<CodeLocation, CallId>,
|
||||
pub top_level: TopLevelContext,
|
||||
}
|
||||
|
||||
impl TestEnvironment {
|
||||
pub fn basic_test_env() -> TestEnvironment {
|
||||
let mut unifier = Unifier::new();
|
||||
|
||||
let int32 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(0),
|
||||
fields: HashMap::new().into(),
|
||||
params: HashMap::new().into(),
|
||||
});
|
||||
let int64 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(1),
|
||||
fields: HashMap::new().into(),
|
||||
params: HashMap::new().into(),
|
||||
});
|
||||
let float = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(2),
|
||||
fields: HashMap::new().into(),
|
||||
params: HashMap::new().into(),
|
||||
});
|
||||
let bool = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(3),
|
||||
fields: HashMap::new().into(),
|
||||
params: HashMap::new().into(),
|
||||
});
|
||||
let none = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(4),
|
||||
fields: HashMap::new().into(),
|
||||
params: HashMap::new().into(),
|
||||
});
|
||||
let primitives = PrimitiveStore { int32, int64, float, bool, none };
|
||||
set_primitives_magic_methods(&primitives, &mut unifier);
|
||||
|
||||
let id_to_name = [
|
||||
(0, "int32".to_string()),
|
||||
(1, "int64".to_string()),
|
||||
(2, "float".to_string()),
|
||||
(3, "bool".to_string()),
|
||||
(4, "none".to_string()),
|
||||
]
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
let mut identifier_mapping = HashMap::new();
|
||||
identifier_mapping.insert("None".into(), none);
|
||||
|
||||
let resolver = Arc::new(Resolver {
|
||||
id_to_type: identifier_mapping.clone(),
|
||||
id_to_def: Default::default(),
|
||||
class_names: Default::default(),
|
||||
}) as Arc<dyn SymbolResolver + Send + Sync>;
|
||||
|
||||
TestEnvironment {
|
||||
top_level: TopLevelContext {
|
||||
definitions: Default::default(),
|
||||
unifiers: Default::default(),
|
||||
},
|
||||
unifier,
|
||||
function_data: FunctionData {
|
||||
resolver,
|
||||
bound_variables: Vec::new(),
|
||||
return_type: None,
|
||||
},
|
||||
primitives,
|
||||
id_to_name,
|
||||
identifier_mapping,
|
||||
virtual_checks: Vec::new(),
|
||||
calls: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn new() -> TestEnvironment {
|
||||
let mut unifier = Unifier::new();
|
||||
let mut identifier_mapping = HashMap::new();
|
||||
let mut top_level_defs: Vec<Arc<RwLock<TopLevelDef>>> = Vec::new();
|
||||
let int32 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(0),
|
||||
fields: HashMap::new().into(),
|
||||
params: HashMap::new().into(),
|
||||
});
|
||||
let int64 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(1),
|
||||
fields: HashMap::new().into(),
|
||||
params: HashMap::new().into(),
|
||||
});
|
||||
let float = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(2),
|
||||
fields: HashMap::new().into(),
|
||||
params: HashMap::new().into(),
|
||||
});
|
||||
let bool = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(3),
|
||||
fields: HashMap::new().into(),
|
||||
params: HashMap::new().into(),
|
||||
});
|
||||
let none = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(4),
|
||||
fields: HashMap::new().into(),
|
||||
params: HashMap::new().into(),
|
||||
});
|
||||
identifier_mapping.insert("None".into(), none);
|
||||
for i in 0..5 {
|
||||
top_level_defs.push(
|
||||
RwLock::new(TopLevelDef::Class {
|
||||
object_id: DefinitionId(i),
|
||||
type_vars: Default::default(),
|
||||
fields: Default::default(),
|
||||
methods: Default::default(),
|
||||
ancestors: Default::default(),
|
||||
resolver: None,
|
||||
})
|
||||
.into(),
|
||||
);
|
||||
}
|
||||
|
||||
let primitives = PrimitiveStore { int32, int64, float, bool, none };
|
||||
|
||||
let (v0, id) = unifier.get_fresh_var();
|
||||
|
||||
let foo_ty = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(5),
|
||||
fields: [("a".into(), v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
|
||||
params: [(id, v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
|
||||
});
|
||||
top_level_defs.push(
|
||||
RwLock::new(TopLevelDef::Class {
|
||||
object_id: DefinitionId(5),
|
||||
type_vars: vec![v0],
|
||||
fields: [("a".into(), v0)].into(),
|
||||
methods: Default::default(),
|
||||
ancestors: Default::default(),
|
||||
resolver: None,
|
||||
})
|
||||
.into(),
|
||||
);
|
||||
|
||||
identifier_mapping.insert(
|
||||
"Foo".into(),
|
||||
unifier.add_ty(TypeEnum::TFunc(
|
||||
FunSignature {
|
||||
args: vec![],
|
||||
ret: foo_ty,
|
||||
vars: [(id, v0)].iter().cloned().collect(),
|
||||
}
|
||||
.into(),
|
||||
)),
|
||||
);
|
||||
|
||||
let fun = unifier.add_ty(TypeEnum::TFunc(
|
||||
FunSignature { args: vec![], ret: int32, vars: Default::default() }.into(),
|
||||
));
|
||||
let bar = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(6),
|
||||
fields: [("a".into(), int32), ("b".into(), fun)]
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect::<HashMap<_, _>>()
|
||||
.into(),
|
||||
params: Default::default(),
|
||||
});
|
||||
top_level_defs.push(
|
||||
RwLock::new(TopLevelDef::Class {
|
||||
object_id: DefinitionId(6),
|
||||
type_vars: Default::default(),
|
||||
fields: [("a".into(), int32), ("b".into(), fun)].into(),
|
||||
methods: Default::default(),
|
||||
ancestors: Default::default(),
|
||||
resolver: None,
|
||||
})
|
||||
.into(),
|
||||
);
|
||||
identifier_mapping.insert(
|
||||
"Bar".into(),
|
||||
unifier.add_ty(TypeEnum::TFunc(
|
||||
FunSignature { args: vec![], ret: bar, vars: Default::default() }.into(),
|
||||
)),
|
||||
);
|
||||
|
||||
let bar2 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(7),
|
||||
fields: [("a".into(), bool), ("b".into(), fun)]
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect::<HashMap<_, _>>()
|
||||
.into(),
|
||||
params: Default::default(),
|
||||
});
|
||||
top_level_defs.push(
|
||||
RwLock::new(TopLevelDef::Class {
|
||||
object_id: DefinitionId(7),
|
||||
type_vars: Default::default(),
|
||||
fields: [("a".into(), bool), ("b".into(), fun)].into(),
|
||||
methods: Default::default(),
|
||||
ancestors: Default::default(),
|
||||
resolver: None,
|
||||
})
|
||||
.into(),
|
||||
);
|
||||
identifier_mapping.insert(
|
||||
"Bar2".into(),
|
||||
unifier.add_ty(TypeEnum::TFunc(
|
||||
FunSignature { args: vec![], ret: bar2, vars: Default::default() }.into(),
|
||||
)),
|
||||
);
|
||||
let class_names = [("Bar".into(), bar), ("Bar2".into(), bar2)].iter().cloned().collect();
|
||||
|
||||
let id_to_name = [
|
||||
(0, "int32".to_string()),
|
||||
(1, "int64".to_string()),
|
||||
(2, "float".to_string()),
|
||||
(3, "bool".to_string()),
|
||||
(4, "none".to_string()),
|
||||
(5, "Foo".to_string()),
|
||||
(6, "Bar".to_string()),
|
||||
(7, "Bar2".to_string()),
|
||||
]
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
let top_level = TopLevelContext {
|
||||
definitions: Arc::new(RwLock::new(top_level_defs)),
|
||||
unifiers: Default::default(),
|
||||
};
|
||||
|
||||
let resolver = Arc::new(Resolver {
|
||||
id_to_type: identifier_mapping.clone(),
|
||||
id_to_def: [
|
||||
("Foo".into(), DefinitionId(5)),
|
||||
("Bar".into(), DefinitionId(6)),
|
||||
("Bar2".into(), DefinitionId(7)),
|
||||
]
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect(),
|
||||
class_names,
|
||||
}) as Arc<dyn SymbolResolver + Send + Sync>;
|
||||
|
||||
TestEnvironment {
|
||||
unifier,
|
||||
top_level,
|
||||
function_data: FunctionData {
|
||||
resolver,
|
||||
bound_variables: Vec::new(),
|
||||
return_type: None,
|
||||
},
|
||||
primitives,
|
||||
id_to_name,
|
||||
identifier_mapping,
|
||||
virtual_checks: Vec::new(),
|
||||
calls: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_inferencer(&mut self) -> Inferencer {
|
||||
Inferencer {
|
||||
top_level: &self.top_level,
|
||||
function_data: &mut self.function_data,
|
||||
unifier: &mut self.unifier,
|
||||
variable_mapping: Default::default(),
|
||||
primitives: &mut self.primitives,
|
||||
virtual_checks: &mut self.virtual_checks,
|
||||
calls: &mut self.calls,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test_case(indoc! {"
|
||||
a = 1234
|
||||
b = int64(2147483648)
|
||||
c = 1.234
|
||||
d = True
|
||||
"},
|
||||
[("a", "int32"), ("b", "int64"), ("c", "float"), ("d", "bool")].iter().cloned().collect(),
|
||||
&[]
|
||||
; "primitives test")]
|
||||
#[test_case(indoc! {"
|
||||
a = lambda x, y: x
|
||||
b = lambda x: a(x, x)
|
||||
c = 1.234
|
||||
d = b(c)
|
||||
"},
|
||||
[("a", "fn[[x=float, y=float], float]"), ("b", "fn[[x=float], float]"), ("c", "float"), ("d", "float")].iter().cloned().collect(),
|
||||
&[]
|
||||
; "lambda test")]
|
||||
#[test_case(indoc! {"
|
||||
a = lambda x: x
|
||||
b = lambda x: x
|
||||
|
||||
foo1 = Foo()
|
||||
foo2 = Foo()
|
||||
c = a(foo1.a)
|
||||
d = b(foo2.a)
|
||||
|
||||
a(True)
|
||||
b(123)
|
||||
|
||||
"},
|
||||
[("a", "fn[[x=bool], bool]"), ("b", "fn[[x=int32], int32]"), ("c", "bool"),
|
||||
("d", "int32"), ("foo1", "Foo[bool]"), ("foo2", "Foo[int32]")].iter().cloned().collect(),
|
||||
&[]
|
||||
; "obj test")]
|
||||
#[test_case(indoc! {"
|
||||
f = lambda x: True
|
||||
a = [1, 2, 3]
|
||||
b = [f(x) for x in a if f(x)]
|
||||
"},
|
||||
[("a", "list[int32]"), ("b", "list[bool]"), ("f", "fn[[x=int32], bool]")].iter().cloned().collect(),
|
||||
&[]
|
||||
; "listcomp test")]
|
||||
#[test_case(indoc! {"
|
||||
a = virtual(Bar(), Bar)
|
||||
b = a.b()
|
||||
a = virtual(Bar2())
|
||||
"},
|
||||
[("a", "virtual[Bar]"), ("b", "int32")].iter().cloned().collect(),
|
||||
&[("Bar", "Bar"), ("Bar2", "Bar")]
|
||||
; "virtual test")]
|
||||
#[test_case(indoc! {"
|
||||
a = [virtual(Bar(), Bar), virtual(Bar2())]
|
||||
b = [x.b() for x in a]
|
||||
"},
|
||||
[("a", "list[virtual[Bar]]"), ("b", "list[int32]")].iter().cloned().collect(),
|
||||
&[("Bar", "Bar"), ("Bar2", "Bar")]
|
||||
; "virtual list test")]
|
||||
fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &str)]) {
|
||||
println!("source:\n{}", source);
|
||||
let mut env = TestEnvironment::new();
|
||||
let id_to_name = std::mem::take(&mut env.id_to_name);
|
||||
let mut defined_identifiers: Vec<_> = env.identifier_mapping.keys().cloned().collect();
|
||||
defined_identifiers.push("virtual".to_string());
|
||||
let mut inferencer = env.get_inferencer();
|
||||
let statements = parse_program(source).unwrap();
|
||||
let statements = statements
|
||||
.into_iter()
|
||||
.map(|v| inferencer.fold_stmt(v))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.unwrap();
|
||||
|
||||
inferencer.check_block(&statements, &mut defined_identifiers).unwrap();
|
||||
|
||||
for (k, v) in inferencer.variable_mapping.iter() {
|
||||
let name = inferencer.unifier.stringify(
|
||||
*v,
|
||||
&mut |v| id_to_name.get(&v).unwrap().clone(),
|
||||
&mut |v| format!("v{}", v),
|
||||
);
|
||||
println!("{}: {}", k, name);
|
||||
}
|
||||
for (k, v) in mapping.iter() {
|
||||
let ty = inferencer.variable_mapping.get(*k).unwrap();
|
||||
let name = inferencer.unifier.stringify(
|
||||
*ty,
|
||||
&mut |v| id_to_name.get(&v).unwrap().clone(),
|
||||
&mut |v| format!("v{}", v),
|
||||
);
|
||||
assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name));
|
||||
}
|
||||
assert_eq!(inferencer.virtual_checks.len(), virtuals.len());
|
||||
for ((a, b), (x, y)) in zip(inferencer.virtual_checks.iter(), virtuals) {
|
||||
let a = inferencer.unifier.stringify(
|
||||
*a,
|
||||
&mut |v| id_to_name.get(&v).unwrap().clone(),
|
||||
&mut |v| format!("v{}", v),
|
||||
);
|
||||
let b = inferencer.unifier.stringify(
|
||||
*b,
|
||||
&mut |v| id_to_name.get(&v).unwrap().clone(),
|
||||
&mut |v| format!("v{}", v),
|
||||
);
|
||||
|
||||
assert_eq!(&a, x);
|
||||
assert_eq!(&b, y);
|
||||
}
|
||||
}
|
||||
|
||||
#[test_case(indoc! {"
|
||||
a = 2
|
||||
b = 2
|
||||
c = a + b
|
||||
d = a - b
|
||||
e = a * b
|
||||
f = a / b
|
||||
g = a // b
|
||||
h = a % b
|
||||
"},
|
||||
[("a", "int32"),
|
||||
("b", "int32"),
|
||||
("c", "int32"),
|
||||
("d", "int32"),
|
||||
("e", "int32"),
|
||||
("f", "float"),
|
||||
("g", "int32"),
|
||||
("h", "int32")].iter().cloned().collect()
|
||||
; "int32")]
|
||||
#[test_case(
|
||||
indoc! {"
|
||||
a = 2.4
|
||||
b = 3.6
|
||||
c = a + b
|
||||
d = a - b
|
||||
e = a * b
|
||||
f = a / b
|
||||
g = a // b
|
||||
h = a % b
|
||||
i = a ** b
|
||||
ii = 3
|
||||
j = a ** b
|
||||
"},
|
||||
[("a", "float"),
|
||||
("b", "float"),
|
||||
("c", "float"),
|
||||
("d", "float"),
|
||||
("e", "float"),
|
||||
("f", "float"),
|
||||
("g", "float"),
|
||||
("h", "float"),
|
||||
("i", "float"),
|
||||
("ii", "int32"),
|
||||
("j", "float")].iter().cloned().collect()
|
||||
; "float"
|
||||
)]
|
||||
#[test_case(
|
||||
indoc! {"
|
||||
a = int64(12312312312)
|
||||
b = int64(24242424424)
|
||||
c = a + b
|
||||
d = a - b
|
||||
e = a * b
|
||||
f = a / b
|
||||
g = a // b
|
||||
h = a % b
|
||||
i = a == b
|
||||
j = a > b
|
||||
k = a < b
|
||||
l = a != b
|
||||
"},
|
||||
[("a", "int64"),
|
||||
("b", "int64"),
|
||||
("c", "int64"),
|
||||
("d", "int64"),
|
||||
("e", "int64"),
|
||||
("f", "float"),
|
||||
("g", "int64"),
|
||||
("h", "int64"),
|
||||
("i", "bool"),
|
||||
("j", "bool"),
|
||||
("k", "bool"),
|
||||
("l", "bool")].iter().cloned().collect()
|
||||
; "int64"
|
||||
)]
|
||||
#[test_case(
|
||||
indoc! {"
|
||||
a = True
|
||||
b = False
|
||||
c = a == b
|
||||
d = not a
|
||||
e = a != b
|
||||
"},
|
||||
[("a", "bool"),
|
||||
("b", "bool"),
|
||||
("c", "bool"),
|
||||
("d", "bool"),
|
||||
("e", "bool")].iter().cloned().collect()
|
||||
; "boolean"
|
||||
)]
|
||||
fn test_primitive_magic_methods(source: &str, mapping: HashMap<&str, &str>) {
|
||||
println!("source:\n{}", source);
|
||||
let mut env = TestEnvironment::basic_test_env();
|
||||
let id_to_name = std::mem::take(&mut env.id_to_name);
|
||||
let mut defined_identifiers: Vec<_> = env.identifier_mapping.keys().cloned().collect();
|
||||
defined_identifiers.push("virtual".to_string());
|
||||
let mut inferencer = env.get_inferencer();
|
||||
let statements = parse_program(source).unwrap();
|
||||
let statements = statements
|
||||
.into_iter()
|
||||
.map(|v| inferencer.fold_stmt(v))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.unwrap();
|
||||
|
||||
inferencer.check_block(&statements, &mut defined_identifiers).unwrap();
|
||||
|
||||
for (k, v) in inferencer.variable_mapping.iter() {
|
||||
let name = inferencer.unifier.stringify(
|
||||
*v,
|
||||
&mut |v| id_to_name.get(&v).unwrap().clone(),
|
||||
&mut |v| format!("v{}", v),
|
||||
);
|
||||
println!("{}: {}", k, name);
|
||||
}
|
||||
for (k, v) in mapping.iter() {
|
||||
let ty = inferencer.variable_mapping.get(*k).unwrap();
|
||||
let name = inferencer.unifier.stringify(
|
||||
*ty,
|
||||
&mut |v| id_to_name.get(&v).unwrap().clone(),
|
||||
&mut |v| format!("v{}", v),
|
||||
);
|
||||
assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,947 @@
|
|||
use itertools::{chain, zip, Itertools};
|
||||
use std::borrow::Cow;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::iter::once;
|
||||
use std::rc::Rc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use super::unification_table::{UnificationKey, UnificationTable};
|
||||
use crate::symbol_resolver::SymbolValue;
|
||||
use crate::top_level::DefinitionId;
|
||||
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
|
||||
/// Handle for a type, implementated as a key in the unification table.
|
||||
pub type Type = UnificationKey;
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq)]
|
||||
pub struct CallId(usize);
|
||||
|
||||
pub type Mapping<K, V = Type> = HashMap<K, V>;
|
||||
type VarMap = Mapping<u32>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Call {
|
||||
pub posargs: Vec<Type>,
|
||||
pub kwargs: HashMap<String, Type>,
|
||||
pub ret: Type,
|
||||
pub fun: RefCell<Option<Type>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct FuncArg {
|
||||
pub name: String,
|
||||
pub ty: Type,
|
||||
pub default_value: Option<SymbolValue>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct FunSignature {
|
||||
pub args: Vec<FuncArg>,
|
||||
pub ret: Type,
|
||||
pub vars: VarMap,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum TypeVarMeta {
|
||||
Generic,
|
||||
Sequence(RefCell<Mapping<i32>>),
|
||||
Record(RefCell<Mapping<String>>),
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum TypeEnum {
|
||||
TRigidVar {
|
||||
id: u32,
|
||||
},
|
||||
TVar {
|
||||
id: u32,
|
||||
meta: TypeVarMeta,
|
||||
// empty indicates no restriction
|
||||
range: RefCell<Vec<Type>>,
|
||||
},
|
||||
TTuple {
|
||||
ty: Vec<Type>,
|
||||
},
|
||||
TList {
|
||||
ty: Type,
|
||||
},
|
||||
TObj {
|
||||
obj_id: DefinitionId,
|
||||
fields: RefCell<Mapping<String>>,
|
||||
params: RefCell<VarMap>,
|
||||
},
|
||||
TVirtual {
|
||||
ty: Type,
|
||||
},
|
||||
TCall(RefCell<Vec<CallId>>),
|
||||
TFunc(RefCell<FunSignature>),
|
||||
}
|
||||
|
||||
impl TypeEnum {
|
||||
pub fn get_type_name(&self) -> &'static str {
|
||||
match self {
|
||||
TypeEnum::TRigidVar { .. } => "TRigidVar",
|
||||
TypeEnum::TVar { .. } => "TVar",
|
||||
TypeEnum::TTuple { .. } => "TTuple",
|
||||
TypeEnum::TList { .. } => "TList",
|
||||
TypeEnum::TObj { .. } => "TObj",
|
||||
TypeEnum::TVirtual { .. } => "TVirtual",
|
||||
TypeEnum::TCall { .. } => "TCall",
|
||||
TypeEnum::TFunc { .. } => "TFunc",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type SharedUnifier = Arc<Mutex<(UnificationTable<TypeEnum>, u32, Vec<Call>)>>;
|
||||
|
||||
pub struct Unifier {
|
||||
unification_table: UnificationTable<Rc<TypeEnum>>,
|
||||
calls: Vec<Rc<Call>>,
|
||||
var_id: u32,
|
||||
}
|
||||
|
||||
impl Unifier {
|
||||
/// Get an empty unifier
|
||||
pub fn new() -> Unifier {
|
||||
Unifier { unification_table: UnificationTable::new(), var_id: 0, calls: Vec::new() }
|
||||
}
|
||||
|
||||
/// Determine if the two types are the same
|
||||
pub fn unioned(&mut self, a: Type, b: Type) -> bool {
|
||||
self.unification_table.unioned(a, b)
|
||||
}
|
||||
|
||||
pub fn from_shared_unifier(unifier: &SharedUnifier) -> Unifier {
|
||||
let lock = unifier.lock().unwrap();
|
||||
Unifier {
|
||||
unification_table: UnificationTable::from_send(&lock.0),
|
||||
var_id: lock.1,
|
||||
calls: lock.2.iter().map(|v| Rc::new(v.clone())).collect_vec(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_shared_unifier(&self) -> SharedUnifier {
|
||||
Arc::new(Mutex::new((
|
||||
self.unification_table.get_send(),
|
||||
self.var_id,
|
||||
self.calls.iter().map(|v| v.as_ref().clone()).collect_vec(),
|
||||
)))
|
||||
}
|
||||
|
||||
/// Register a type to the unifier.
|
||||
/// Returns a key in the unification_table.
|
||||
pub fn add_ty(&mut self, a: TypeEnum) -> Type {
|
||||
self.unification_table.new_key(Rc::new(a))
|
||||
}
|
||||
|
||||
pub fn add_record(&mut self, fields: Mapping<String>) -> Type {
|
||||
let id = self.var_id + 1;
|
||||
self.var_id += 1;
|
||||
self.add_ty(TypeEnum::TVar {
|
||||
id,
|
||||
range: vec![].into(),
|
||||
meta: TypeVarMeta::Record(fields.into()),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn add_call(&mut self, call: Call) -> CallId {
|
||||
let id = CallId(self.calls.len());
|
||||
self.calls.push(Rc::new(call));
|
||||
id
|
||||
}
|
||||
|
||||
pub fn get_representative(&mut self, ty: Type) -> Type {
|
||||
self.unification_table.get_representative(ty)
|
||||
}
|
||||
|
||||
pub fn add_sequence(&mut self, sequence: Mapping<i32>) -> Type {
|
||||
let id = self.var_id + 1;
|
||||
self.var_id += 1;
|
||||
self.add_ty(TypeEnum::TVar {
|
||||
id,
|
||||
range: vec![].into(),
|
||||
meta: TypeVarMeta::Sequence(sequence.into()),
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the TypeEnum of a type.
|
||||
pub fn get_ty(&mut self, a: Type) -> Rc<TypeEnum> {
|
||||
self.unification_table.probe_value(a).clone()
|
||||
}
|
||||
|
||||
pub fn get_fresh_rigid_var(&mut self) -> (Type, u32) {
|
||||
let id = self.var_id + 1;
|
||||
self.var_id += 1;
|
||||
(self.add_ty(TypeEnum::TRigidVar { id }), id)
|
||||
}
|
||||
|
||||
pub fn get_fresh_var(&mut self) -> (Type, u32) {
|
||||
self.get_fresh_var_with_range(&[])
|
||||
}
|
||||
|
||||
/// Get a fresh type variable.
|
||||
pub fn get_fresh_var_with_range(&mut self, range: &[Type]) -> (Type, u32) {
|
||||
let id = self.var_id + 1;
|
||||
self.var_id += 1;
|
||||
let range = range.to_vec().into();
|
||||
(self.add_ty(TypeEnum::TVar { id, range, meta: TypeVarMeta::Generic }), id)
|
||||
}
|
||||
|
||||
/// Unification would not unify rigid variables with other types, but we want to do this for
|
||||
/// function instantiations, so we make it explicit.
|
||||
pub fn replace_rigid_var(&mut self, rigid: Type, b: Type) {
|
||||
assert!(matches!(&*self.get_ty(rigid), TypeEnum::TRigidVar { .. }));
|
||||
self.set_a_to_b(rigid, b);
|
||||
}
|
||||
|
||||
pub fn get_instantiations(&mut self, ty: Type) -> Option<Vec<Type>> {
|
||||
match &*self.get_ty(ty) {
|
||||
TypeEnum::TVar { range, .. } => {
|
||||
let range = range.borrow();
|
||||
if range.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
range
|
||||
.iter()
|
||||
.map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty]))
|
||||
.flatten()
|
||||
.collect_vec(),
|
||||
)
|
||||
}
|
||||
}
|
||||
TypeEnum::TList { ty } => self
|
||||
.get_instantiations(*ty)
|
||||
.map(|ty| ty.iter().map(|&ty| self.add_ty(TypeEnum::TList { ty })).collect_vec()),
|
||||
TypeEnum::TVirtual { ty } => self.get_instantiations(*ty).map(|ty| {
|
||||
ty.iter().map(|&ty| self.add_ty(TypeEnum::TVirtual { ty })).collect_vec()
|
||||
}),
|
||||
TypeEnum::TTuple { ty } => {
|
||||
let tuples = ty
|
||||
.iter()
|
||||
.map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty]))
|
||||
.multi_cartesian_product()
|
||||
.collect_vec();
|
||||
if tuples.len() == 1 {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
tuples.into_iter().map(|ty| self.add_ty(TypeEnum::TTuple { ty })).collect(),
|
||||
)
|
||||
}
|
||||
}
|
||||
TypeEnum::TObj { params, .. } => {
|
||||
let params = params.borrow();
|
||||
let (keys, params): (Vec<&u32>, Vec<&Type>) = params.iter().unzip();
|
||||
let params = params
|
||||
.into_iter()
|
||||
.map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty]))
|
||||
.multi_cartesian_product()
|
||||
.collect_vec();
|
||||
if params.len() <= 1 {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
params
|
||||
.into_iter()
|
||||
.map(|params| {
|
||||
self.subst(
|
||||
ty,
|
||||
&zip(keys.iter().cloned().cloned(), params.iter().cloned())
|
||||
.collect(),
|
||||
)
|
||||
.unwrap_or(ty)
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool {
|
||||
use TypeEnum::*;
|
||||
match &*self.get_ty(a) {
|
||||
TRigidVar { .. } => true,
|
||||
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
|
||||
TCall { .. } => false,
|
||||
TList { ty } => self.is_concrete(*ty, allowed_typevars),
|
||||
TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)),
|
||||
TObj { params: vars, .. } => {
|
||||
vars.borrow().values().all(|ty| self.is_concrete(*ty, allowed_typevars))
|
||||
}
|
||||
// functions are instantiated for each call sites, so the function type can contain
|
||||
// type variables.
|
||||
TFunc { .. } => true,
|
||||
TVirtual { ty } => self.is_concrete(*ty, allowed_typevars),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn unify(&mut self, a: Type, b: Type) -> Result<(), String> {
|
||||
if self.unification_table.unioned(a, b) {
|
||||
Ok(())
|
||||
} else {
|
||||
self.unify_impl(a, b, false)
|
||||
}
|
||||
}
|
||||
|
||||
fn unify_impl(&mut self, a: Type, b: Type, swapped: bool) -> Result<(), String> {
|
||||
use TypeEnum::*;
|
||||
use TypeVarMeta::*;
|
||||
let (ty_a, ty_b) = {
|
||||
(
|
||||
self.unification_table.probe_value(a).clone(),
|
||||
self.unification_table.probe_value(b).clone(),
|
||||
)
|
||||
};
|
||||
match (&*ty_a, &*ty_b) {
|
||||
(TVar { meta: meta1, range: range1, .. }, TVar { meta: meta2, range: range2, .. }) => {
|
||||
self.occur_check(a, b)?;
|
||||
self.occur_check(b, a)?;
|
||||
match (meta1, meta2) {
|
||||
(Generic, _) => {}
|
||||
(_, Generic) => {
|
||||
return self.unify_impl(b, a, true);
|
||||
}
|
||||
(Record(fields1), Record(fields2)) => {
|
||||
let mut fields2 = fields2.borrow_mut();
|
||||
for (key, value) in fields1.borrow().iter() {
|
||||
if let Some(ty) = fields2.get(key) {
|
||||
self.unify(*ty, *value)?;
|
||||
} else {
|
||||
fields2.insert(key.clone(), *value);
|
||||
}
|
||||
}
|
||||
}
|
||||
(Sequence(map1), Sequence(map2)) => {
|
||||
let mut map2 = map2.borrow_mut();
|
||||
for (key, value) in map1.borrow().iter() {
|
||||
if let Some(ty) = map2.get(key) {
|
||||
self.unify(*ty, *value)?;
|
||||
} else {
|
||||
map2.insert(*key, *value);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return Err("Incompatible".to_string());
|
||||
}
|
||||
}
|
||||
let range1 = range1.borrow();
|
||||
// new range is the intersection of them
|
||||
// empty range indicates no constraint
|
||||
if !range1.is_empty() {
|
||||
let old_range2 = range2.take();
|
||||
let mut range2 = range2.borrow_mut();
|
||||
if old_range2.is_empty() {
|
||||
range2.extend_from_slice(&range1);
|
||||
}
|
||||
for v1 in old_range2.iter() {
|
||||
for v2 in range1.iter() {
|
||||
if let Ok(result) = self.get_intersection(*v1, *v2) {
|
||||
range2.push(result.unwrap_or(*v2));
|
||||
}
|
||||
}
|
||||
}
|
||||
if range2.is_empty() {
|
||||
return Err(
|
||||
"cannot unify type variables with incompatible value range".to_string()
|
||||
);
|
||||
}
|
||||
}
|
||||
self.set_a_to_b(a, b);
|
||||
}
|
||||
(TVar { meta: Generic, id, range, .. }, _) => {
|
||||
self.occur_check(a, b)?;
|
||||
// We check for the range of the type variable to see if unification is allowed.
|
||||
// Note that although b may be compatible with a, we may have to constrain type
|
||||
// variables in b to make sure that instantiations of b would always be compatible
|
||||
// with a.
|
||||
// The return value x of check_var_compatibility would be a new type that is
|
||||
// guaranteed to be compatible with a under all possible instantiations. So we
|
||||
// unify x with b to recursively apply the constrains, and then set a to x.
|
||||
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
|
||||
self.unify(x, b)?;
|
||||
self.set_a_to_b(a, x);
|
||||
}
|
||||
(TVar { meta: Sequence(map), id, range, .. }, TTuple { ty }) => {
|
||||
self.occur_check(a, b)?;
|
||||
let len = ty.len() as i32;
|
||||
for (k, v) in map.borrow().iter() {
|
||||
// handle negative index
|
||||
let ind = if *k < 0 { len + *k } else { *k };
|
||||
if ind >= len || ind < 0 {
|
||||
return Err(format!(
|
||||
"Tuple index out of range. (Length: {}, Index: {})",
|
||||
len, k
|
||||
));
|
||||
}
|
||||
self.unify(*v, ty[ind as usize])?;
|
||||
}
|
||||
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
|
||||
self.unify(x, b)?;
|
||||
self.set_a_to_b(a, x);
|
||||
}
|
||||
(TVar { meta: Sequence(map), id, range, .. }, TList { ty }) => {
|
||||
self.occur_check(a, b)?;
|
||||
for v in map.borrow().values() {
|
||||
self.unify(*v, *ty)?;
|
||||
}
|
||||
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
|
||||
self.unify(x, b)?;
|
||||
self.set_a_to_b(a, x);
|
||||
}
|
||||
(TTuple { ty: ty1 }, TTuple { ty: ty2 }) => {
|
||||
if ty1.len() != ty2.len() {
|
||||
return Err(format!(
|
||||
"Cannot unify tuples with length {} and {}",
|
||||
ty1.len(),
|
||||
ty2.len()
|
||||
));
|
||||
}
|
||||
for (x, y) in ty1.iter().zip(ty2.iter()) {
|
||||
self.unify(*x, *y)?;
|
||||
}
|
||||
self.set_a_to_b(a, b);
|
||||
}
|
||||
(TList { ty: ty1 }, TList { ty: ty2 }) => {
|
||||
self.unify(*ty1, *ty2)?;
|
||||
self.set_a_to_b(a, b);
|
||||
}
|
||||
(TVar { meta: Record(map), id, range, .. }, TObj { fields, .. }) => {
|
||||
self.occur_check(a, b)?;
|
||||
for (k, v) in map.borrow().iter() {
|
||||
let ty = fields
|
||||
.borrow()
|
||||
.get(k)
|
||||
.copied()
|
||||
.ok_or_else(|| format!("No such attribute {}", k))?;
|
||||
self.unify(ty, *v)?;
|
||||
}
|
||||
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
|
||||
self.unify(x, b)?;
|
||||
self.set_a_to_b(a, x);
|
||||
}
|
||||
(TVar { meta: Record(map), id, range, .. }, TVirtual { ty }) => {
|
||||
self.occur_check(a, b)?;
|
||||
let ty = self.get_ty(*ty);
|
||||
if let TObj { fields, .. } = ty.as_ref() {
|
||||
for (k, v) in map.borrow().iter() {
|
||||
let ty = fields
|
||||
.borrow()
|
||||
.get(k)
|
||||
.copied()
|
||||
.ok_or_else(|| format!("No such attribute {}", k))?;
|
||||
if !matches!(self.get_ty(ty).as_ref(), TFunc { .. }) {
|
||||
return Err(format!("Cannot access field {} for virtual type", k));
|
||||
}
|
||||
self.unify(*v, ty)?;
|
||||
}
|
||||
} else {
|
||||
// require annotation...
|
||||
return Err("Requires type annotation for virtual".to_string());
|
||||
}
|
||||
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
|
||||
self.unify(x, b)?;
|
||||
self.set_a_to_b(a, x);
|
||||
}
|
||||
(
|
||||
TObj { obj_id: id1, params: params1, .. },
|
||||
TObj { obj_id: id2, params: params2, .. },
|
||||
) => {
|
||||
if id1 != id2 {
|
||||
return Err(format!("Cannot unify objects with ID {} and {}", id1.0, id2.0));
|
||||
}
|
||||
for (x, y) in zip(params1.borrow().values(), params2.borrow().values()) {
|
||||
self.unify(*x, *y)?;
|
||||
}
|
||||
self.set_a_to_b(a, b);
|
||||
}
|
||||
(TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => {
|
||||
self.unify(*ty1, *ty2)?;
|
||||
self.set_a_to_b(a, b);
|
||||
}
|
||||
(TCall(calls1), TCall(calls2)) => {
|
||||
// we do not unify individual calls, instead we defer until the unification wtih a
|
||||
// function definition.
|
||||
calls2.borrow_mut().extend_from_slice(&calls1.borrow());
|
||||
}
|
||||
(TCall(calls), TFunc(signature)) => {
|
||||
self.occur_check(a, b)?;
|
||||
let required: Vec<String> = signature
|
||||
.borrow()
|
||||
.args
|
||||
.iter()
|
||||
.filter(|v| v.default_value.is_none())
|
||||
.map(|v| v.name.clone())
|
||||
.rev()
|
||||
.collect();
|
||||
// we unify every calls to the function signature.
|
||||
for c in calls.borrow().iter() {
|
||||
let Call { posargs, kwargs, ret, fun } = &*self.calls[c.0].clone();
|
||||
let instantiated = self.instantiate_fun(b, &*signature.borrow());
|
||||
let r = self.get_ty(instantiated);
|
||||
let r = r.as_ref();
|
||||
let signature;
|
||||
if let TypeEnum::TFunc(s) = &*r {
|
||||
signature = s;
|
||||
} else {
|
||||
unreachable!();
|
||||
}
|
||||
// we check to make sure that all required arguments (those without default
|
||||
// arguments) are provided, and do not provide the same argument twice.
|
||||
let mut required = required.clone();
|
||||
let mut all_names: Vec<_> = signature
|
||||
.borrow()
|
||||
.args
|
||||
.iter()
|
||||
.map(|v| (v.name.clone(), v.ty))
|
||||
.rev()
|
||||
.collect();
|
||||
for (i, t) in posargs.iter().enumerate() {
|
||||
if signature.borrow().args.len() <= i {
|
||||
return Err("Too many arguments.".to_string());
|
||||
}
|
||||
if !required.is_empty() {
|
||||
required.pop();
|
||||
}
|
||||
self.unify(all_names.pop().unwrap().1, *t)?;
|
||||
}
|
||||
for (k, t) in kwargs.iter() {
|
||||
if let Some(i) = required.iter().position(|v| v == k) {
|
||||
required.remove(i);
|
||||
}
|
||||
let i = all_names
|
||||
.iter()
|
||||
.position(|v| &v.0 == k)
|
||||
.ok_or_else(|| format!("Unknown keyword argument {}", k))?;
|
||||
self.unify(all_names.remove(i).1, *t)?;
|
||||
}
|
||||
if !required.is_empty() {
|
||||
return Err("Expected more arguments".to_string());
|
||||
}
|
||||
self.unify(*ret, signature.borrow().ret)?;
|
||||
*fun.borrow_mut() = Some(instantiated);
|
||||
}
|
||||
self.set_a_to_b(a, b);
|
||||
}
|
||||
(TFunc(sign1), TFunc(sign2)) => {
|
||||
let (sign1, sign2) = (&*sign1.borrow(), &*sign2.borrow());
|
||||
if !sign1.vars.is_empty() || !sign2.vars.is_empty() {
|
||||
return Err("Polymorphic function pointer is prohibited.".to_string());
|
||||
}
|
||||
if sign1.args.len() != sign2.args.len() {
|
||||
return Err("Functions differ in number of parameters.".to_string());
|
||||
}
|
||||
for (x, y) in sign1.args.iter().zip(sign2.args.iter()) {
|
||||
if x.name != y.name {
|
||||
return Err("Functions differ in parameter names.".to_string());
|
||||
}
|
||||
if x.default_value != y.default_value {
|
||||
return Err("Functions differ in optional parameters.".to_string());
|
||||
}
|
||||
self.unify(x.ty, y.ty)?;
|
||||
}
|
||||
self.unify(sign1.ret, sign2.ret)?;
|
||||
self.set_a_to_b(a, b);
|
||||
}
|
||||
_ => {
|
||||
if swapped {
|
||||
return self.incompatible_types(&*ty_a, &*ty_b);
|
||||
} else {
|
||||
self.unify_impl(b, a, true)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get string representation of the type
|
||||
pub fn stringify<F, G>(&mut self, ty: Type, obj_to_name: &mut F, var_to_name: &mut G) -> String
|
||||
where
|
||||
F: FnMut(usize) -> String,
|
||||
G: FnMut(u32) -> String,
|
||||
{
|
||||
use TypeVarMeta::*;
|
||||
let ty = self.unification_table.probe_value(ty).clone();
|
||||
match ty.as_ref() {
|
||||
TypeEnum::TRigidVar { id } => var_to_name(*id),
|
||||
TypeEnum::TVar { id, meta: Generic, .. } => var_to_name(*id),
|
||||
TypeEnum::TVar { meta: Sequence(map), .. } => {
|
||||
let fields = map
|
||||
.borrow()
|
||||
.iter()
|
||||
.map(|(k, v)| format!("{}={}", k, self.stringify(*v, obj_to_name, var_to_name)))
|
||||
.join(", ");
|
||||
format!("seq[{}]", fields)
|
||||
}
|
||||
TypeEnum::TVar { meta: Record(fields), .. } => {
|
||||
let fields = fields
|
||||
.borrow()
|
||||
.iter()
|
||||
.map(|(k, v)| format!("{}={}", k, self.stringify(*v, obj_to_name, var_to_name)))
|
||||
.join(", ");
|
||||
format!("record[{}]", fields)
|
||||
}
|
||||
TypeEnum::TTuple { ty } => {
|
||||
let mut fields = ty.iter().map(|v| self.stringify(*v, obj_to_name, var_to_name));
|
||||
format!("tuple[{}]", fields.join(", "))
|
||||
}
|
||||
TypeEnum::TList { ty } => {
|
||||
format!("list[{}]", self.stringify(*ty, obj_to_name, var_to_name))
|
||||
}
|
||||
TypeEnum::TVirtual { ty } => {
|
||||
format!("virtual[{}]", self.stringify(*ty, obj_to_name, var_to_name))
|
||||
}
|
||||
TypeEnum::TObj { obj_id, params, .. } => {
|
||||
let name = obj_to_name(obj_id.0);
|
||||
let params = params.borrow();
|
||||
if !params.is_empty() {
|
||||
let mut params =
|
||||
params.values().map(|v| self.stringify(*v, obj_to_name, var_to_name));
|
||||
format!("{}[{}]", name, params.join(", "))
|
||||
} else {
|
||||
name
|
||||
}
|
||||
}
|
||||
TypeEnum::TCall { .. } => "call".to_owned(),
|
||||
TypeEnum::TFunc(signature) => {
|
||||
let params = signature
|
||||
.borrow()
|
||||
.args
|
||||
.iter()
|
||||
.map(|arg| {
|
||||
format!("{}={}", arg.name, self.stringify(arg.ty, obj_to_name, var_to_name))
|
||||
})
|
||||
.join(", ");
|
||||
let ret = self.stringify(signature.borrow().ret, obj_to_name, var_to_name);
|
||||
format!("fn[[{}], {}]", params, ret)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn set_a_to_b(&mut self, a: Type, b: Type) {
|
||||
// unify a and b together, and set the value to b's value.
|
||||
let table = &mut self.unification_table;
|
||||
let ty_b = table.probe_value(b).clone();
|
||||
table.unify(a, b);
|
||||
table.set_value(a, ty_b)
|
||||
}
|
||||
|
||||
fn incompatible_types(&self, a: &TypeEnum, b: &TypeEnum) -> Result<(), String> {
|
||||
Err(format!("Cannot unify {} with {}", a.get_type_name(), b.get_type_name()))
|
||||
}
|
||||
|
||||
/// Instantiate a function if it hasn't been instantiated.
|
||||
/// Returns Some(T) where T is the instantiated type.
|
||||
/// Returns None if the function is already instantiated.
|
||||
fn instantiate_fun(&mut self, ty: Type, fun: &FunSignature) -> Type {
|
||||
let mut instantiated = false;
|
||||
let mut vars = Vec::new();
|
||||
for (k, v) in fun.vars.iter() {
|
||||
if let TypeEnum::TVar { id, range, .. } =
|
||||
self.unification_table.probe_value(*v).as_ref()
|
||||
{
|
||||
if k != id {
|
||||
instantiated = true;
|
||||
break;
|
||||
}
|
||||
// actually, if the first check succeeded, the function should be uninstatiated.
|
||||
// The cloned values must be used and would not be wasted.
|
||||
vars.push((*k, range.clone()));
|
||||
} else {
|
||||
instantiated = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if instantiated {
|
||||
ty
|
||||
} else {
|
||||
let mapping = vars
|
||||
.into_iter()
|
||||
.map(|(k, range)| (k, self.get_fresh_var_with_range(range.borrow().as_ref()).0))
|
||||
.collect();
|
||||
self.subst(ty, &mapping).unwrap_or(ty)
|
||||
}
|
||||
}
|
||||
|
||||
/// Substitute type variables within a type into other types.
|
||||
/// If this returns Some(T), T would be the substituted type.
|
||||
/// If this returns None, the result type would be the original type
|
||||
/// (no substitution has to be done).
|
||||
pub fn subst(&mut self, a: Type, mapping: &VarMap) -> Option<Type> {
|
||||
use TypeVarMeta::*;
|
||||
let ty = self.unification_table.probe_value(a).clone();
|
||||
// this function would only be called when we instantiate functions.
|
||||
// function type signature should ONLY contain concrete types and type
|
||||
// variables, i.e. things like TRecord, TCall should not occur, and we
|
||||
// should be safe to not implement the substitution for those variants.
|
||||
match &*ty {
|
||||
TypeEnum::TRigidVar { .. } => None,
|
||||
TypeEnum::TVar { id, meta: Generic, .. } => mapping.get(&id).cloned(),
|
||||
TypeEnum::TTuple { ty } => {
|
||||
let mut new_ty = Cow::from(ty);
|
||||
for (i, t) in ty.iter().enumerate() {
|
||||
if let Some(t1) = self.subst(*t, mapping) {
|
||||
new_ty.to_mut()[i] = t1;
|
||||
}
|
||||
}
|
||||
if matches!(new_ty, Cow::Owned(_)) {
|
||||
Some(self.add_ty(TypeEnum::TTuple { ty: new_ty.into_owned() }))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
TypeEnum::TList { ty } => {
|
||||
self.subst(*ty, mapping).map(|t| self.add_ty(TypeEnum::TList { ty: t }))
|
||||
}
|
||||
TypeEnum::TVirtual { ty } => {
|
||||
self.subst(*ty, mapping).map(|t| self.add_ty(TypeEnum::TVirtual { ty: t }))
|
||||
}
|
||||
TypeEnum::TObj { obj_id, fields, params } => {
|
||||
// Type variables in field types must be present in the type parameter.
|
||||
// If the mapping does not contain any type variables in the
|
||||
// parameter list, we don't need to substitute the fields.
|
||||
// This is also used to prevent infinite substitution...
|
||||
let params = params.borrow();
|
||||
let need_subst = params.values().any(|v| {
|
||||
let ty = self.unification_table.probe_value(*v);
|
||||
if let TypeEnum::TVar { id, .. } = ty.as_ref() {
|
||||
mapping.contains_key(&id)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
});
|
||||
if need_subst {
|
||||
let obj_id = *obj_id;
|
||||
let params = self.subst_map(¶ms, mapping).unwrap_or_else(|| params.clone());
|
||||
let fields = self
|
||||
.subst_map(&fields.borrow(), mapping)
|
||||
.unwrap_or_else(|| fields.borrow().clone());
|
||||
Some(self.add_ty(TypeEnum::TObj {
|
||||
obj_id,
|
||||
params: params.into(),
|
||||
fields: fields.into(),
|
||||
}))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
TypeEnum::TFunc(sig) => {
|
||||
let FunSignature { args, ret, vars: params } = &*sig.borrow();
|
||||
let new_params = self.subst_map(params, mapping);
|
||||
let new_ret = self.subst(*ret, mapping);
|
||||
let mut new_args = Cow::from(args);
|
||||
for (i, t) in args.iter().enumerate() {
|
||||
if let Some(t1) = self.subst(t.ty, mapping) {
|
||||
let mut t = t.clone();
|
||||
t.ty = t1;
|
||||
new_args.to_mut()[i] = t;
|
||||
}
|
||||
}
|
||||
if new_params.is_some() || new_ret.is_some() || matches!(new_args, Cow::Owned(..)) {
|
||||
let params = new_params.unwrap_or_else(|| params.clone());
|
||||
let ret = new_ret.unwrap_or_else(|| *ret);
|
||||
let args = new_args.into_owned();
|
||||
Some(
|
||||
self.add_ty(TypeEnum::TFunc(
|
||||
FunSignature { args, ret, vars: params }.into(),
|
||||
)),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn subst_map<K>(&mut self, map: &Mapping<K>, mapping: &VarMap) -> Option<Mapping<K>>
|
||||
where
|
||||
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
|
||||
{
|
||||
let mut map2 = None;
|
||||
for (k, v) in map.iter() {
|
||||
if let Some(v1) = self.subst(*v, mapping) {
|
||||
if map2.is_none() {
|
||||
map2 = Some(map.clone());
|
||||
}
|
||||
*map2.as_mut().unwrap().get_mut(k).unwrap() = v1;
|
||||
}
|
||||
}
|
||||
map2
|
||||
}
|
||||
|
||||
fn occur_check(&mut self, a: Type, b: Type) -> Result<(), String> {
|
||||
use TypeVarMeta::*;
|
||||
if self.unification_table.unioned(a, b) {
|
||||
return Err("Recursive type is prohibited.".to_owned());
|
||||
}
|
||||
let ty = self.unification_table.probe_value(b).clone();
|
||||
|
||||
match ty.as_ref() {
|
||||
TypeEnum::TRigidVar { .. } | TypeEnum::TVar { meta: Generic, .. } => {}
|
||||
TypeEnum::TVar { meta: Sequence(map), .. } => {
|
||||
for t in map.borrow().values() {
|
||||
self.occur_check(a, *t)?;
|
||||
}
|
||||
}
|
||||
TypeEnum::TVar { meta: Record(map), .. } => {
|
||||
for t in map.borrow().values() {
|
||||
self.occur_check(a, *t)?;
|
||||
}
|
||||
}
|
||||
TypeEnum::TCall(calls) => {
|
||||
let call_store = self.calls.clone();
|
||||
for t in calls
|
||||
.borrow()
|
||||
.iter()
|
||||
.map(|call| {
|
||||
let call = call_store[call.0].as_ref();
|
||||
chain!(call.posargs.iter(), call.kwargs.values(), once(&call.ret))
|
||||
})
|
||||
.flatten()
|
||||
{
|
||||
self.occur_check(a, *t)?;
|
||||
}
|
||||
}
|
||||
TypeEnum::TTuple { ty } => {
|
||||
for t in ty.iter() {
|
||||
self.occur_check(a, *t)?;
|
||||
}
|
||||
}
|
||||
TypeEnum::TList { ty } | TypeEnum::TVirtual { ty } => {
|
||||
self.occur_check(a, *ty)?;
|
||||
}
|
||||
TypeEnum::TObj { params: map, .. } => {
|
||||
for t in map.borrow().values() {
|
||||
self.occur_check(a, *t)?;
|
||||
}
|
||||
}
|
||||
TypeEnum::TFunc(sig) => {
|
||||
let FunSignature { args, ret, vars: params } = &*sig.borrow();
|
||||
for t in chain!(args.iter().map(|v| &v.ty), params.values(), once(ret)) {
|
||||
self.occur_check(a, *t)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_intersection(&mut self, a: Type, b: Type) -> Result<Option<Type>, ()> {
|
||||
use TypeEnum::*;
|
||||
let x = self.get_ty(a);
|
||||
let y = self.get_ty(b);
|
||||
match (x.as_ref(), y.as_ref()) {
|
||||
(TVar { range: range1, .. }, TVar { meta, range: range2, .. }) => {
|
||||
// we should restrict range2
|
||||
let range1 = range1.borrow();
|
||||
// new range is the intersection of them
|
||||
// empty range indicates no constraint
|
||||
if !range1.is_empty() {
|
||||
let range2 = range2.borrow();
|
||||
let mut range = Vec::new();
|
||||
if range2.is_empty() {
|
||||
range.extend_from_slice(&range1);
|
||||
}
|
||||
for v1 in range2.iter() {
|
||||
for v2 in range1.iter() {
|
||||
let result = self.get_intersection(*v1, *v2);
|
||||
if let Ok(result) = result {
|
||||
range.push(result.unwrap_or(*v2));
|
||||
}
|
||||
}
|
||||
}
|
||||
if range.is_empty() {
|
||||
Err(())
|
||||
} else {
|
||||
let id = self.var_id + 1;
|
||||
self.var_id += 1;
|
||||
let ty = TVar { id, meta: meta.clone(), range: range.into() };
|
||||
Ok(Some(self.unification_table.new_key(ty.into())))
|
||||
}
|
||||
} else {
|
||||
Ok(Some(b))
|
||||
}
|
||||
}
|
||||
(_, TVar { range, .. }) => {
|
||||
// range should be restricted to the left hand side
|
||||
let range = range.borrow();
|
||||
if range.is_empty() {
|
||||
Ok(Some(a))
|
||||
} else {
|
||||
for v in range.iter() {
|
||||
let result = self.get_intersection(a, *v);
|
||||
if let Ok(result) = result {
|
||||
return Ok(result.or(Some(a)));
|
||||
}
|
||||
}
|
||||
Err(())
|
||||
}
|
||||
}
|
||||
(TVar { id, range, .. }, _) => {
|
||||
self.check_var_compatibility(*id, b, &range.borrow()).or(Err(()))
|
||||
}
|
||||
(TTuple { ty: ty1 }, TTuple { ty: ty2 }) => {
|
||||
if ty1.len() != ty2.len() {
|
||||
return Err(());
|
||||
}
|
||||
let mut need_new = false;
|
||||
let mut ty = ty1.clone();
|
||||
for (a, b) in zip(ty1.iter(), ty2.iter()) {
|
||||
let result = self.get_intersection(*a, *b)?;
|
||||
ty.push(result.unwrap_or(*a));
|
||||
if result.is_some() {
|
||||
need_new = true;
|
||||
}
|
||||
}
|
||||
if need_new {
|
||||
Ok(Some(self.add_ty(TTuple { ty })))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
(TList { ty: ty1 }, TList { ty: ty2 }) => {
|
||||
Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TList { ty })))
|
||||
}
|
||||
(TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => {
|
||||
Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TVirtual { ty })))
|
||||
}
|
||||
(TObj { obj_id: id1, .. }, TObj { obj_id: id2, .. }) => {
|
||||
if id1 == id2 {
|
||||
Ok(None)
|
||||
} else {
|
||||
Err(())
|
||||
}
|
||||
}
|
||||
// don't deal with function shape for now
|
||||
_ => Err(()),
|
||||
}
|
||||
}
|
||||
|
||||
fn check_var_compatibility(
|
||||
&mut self,
|
||||
id: u32,
|
||||
b: Type,
|
||||
range: &[Type],
|
||||
) -> Result<Option<Type>, String> {
|
||||
if range.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
for t in range.iter() {
|
||||
let result = self.get_intersection(*t, b);
|
||||
if let Ok(result) = result {
|
||||
return Ok(result);
|
||||
}
|
||||
}
|
||||
return Err(format!(
|
||||
"Cannot unify type variable {} with {} due to incompatible value range",
|
||||
id,
|
||||
self.get_ty(b).get_type_name()
|
||||
));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,534 @@
|
|||
use super::*;
|
||||
use indoc::indoc;
|
||||
use itertools::Itertools;
|
||||
use std::collections::HashMap;
|
||||
use test_case::test_case;
|
||||
|
||||
impl Unifier {
|
||||
/// Check whether two types are equal.
|
||||
fn eq(&mut self, a: Type, b: Type) -> bool {
|
||||
use TypeVarMeta::*;
|
||||
if a == b {
|
||||
return true;
|
||||
}
|
||||
let (ty_a, ty_b) = {
|
||||
let table = &mut self.unification_table;
|
||||
if table.unioned(a, b) {
|
||||
return true;
|
||||
}
|
||||
(table.probe_value(a).clone(), table.probe_value(b).clone())
|
||||
};
|
||||
|
||||
match (&*ty_a, &*ty_b) {
|
||||
(
|
||||
TypeEnum::TVar { meta: Generic, id: id1, .. },
|
||||
TypeEnum::TVar { meta: Generic, id: id2, .. },
|
||||
) => id1 == id2,
|
||||
(
|
||||
TypeEnum::TVar { meta: Sequence(map1), .. },
|
||||
TypeEnum::TVar { meta: Sequence(map2), .. },
|
||||
) => self.map_eq(&map1.borrow(), &map2.borrow()),
|
||||
(TypeEnum::TTuple { ty: ty1 }, TypeEnum::TTuple { ty: ty2 }) => {
|
||||
ty1.len() == ty2.len()
|
||||
&& ty1.iter().zip(ty2.iter()).all(|(t1, t2)| self.eq(*t1, *t2))
|
||||
}
|
||||
(TypeEnum::TList { ty: ty1 }, TypeEnum::TList { ty: ty2 })
|
||||
| (TypeEnum::TVirtual { ty: ty1 }, TypeEnum::TVirtual { ty: ty2 }) => {
|
||||
self.eq(*ty1, *ty2)
|
||||
}
|
||||
(
|
||||
TypeEnum::TVar { meta: Record(fields1), .. },
|
||||
TypeEnum::TVar { meta: Record(fields2), .. },
|
||||
) => self.map_eq(&fields1.borrow(), &fields2.borrow()),
|
||||
(
|
||||
TypeEnum::TObj { obj_id: id1, params: params1, .. },
|
||||
TypeEnum::TObj { obj_id: id2, params: params2, .. },
|
||||
) => id1 == id2 && self.map_eq(¶ms1.borrow(), ¶ms2.borrow()),
|
||||
// TCall and TFunc are not yet implemented
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn map_eq<K>(&mut self, map1: &Mapping<K>, map2: &Mapping<K>) -> bool
|
||||
where
|
||||
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
|
||||
{
|
||||
if map1.len() != map2.len() {
|
||||
return false;
|
||||
}
|
||||
for (k, v) in map1.iter() {
|
||||
if !map2.get(k).map(|v1| self.eq(*v, *v1)).unwrap_or(false) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
struct TestEnvironment {
|
||||
pub unifier: Unifier,
|
||||
type_mapping: HashMap<String, Type>,
|
||||
}
|
||||
|
||||
impl TestEnvironment {
|
||||
fn new() -> TestEnvironment {
|
||||
let mut unifier = Unifier::new();
|
||||
let mut type_mapping = HashMap::new();
|
||||
|
||||
type_mapping.insert(
|
||||
"int".into(),
|
||||
unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(0),
|
||||
fields: HashMap::new().into(),
|
||||
params: HashMap::new().into(),
|
||||
}),
|
||||
);
|
||||
type_mapping.insert(
|
||||
"float".into(),
|
||||
unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(1),
|
||||
fields: HashMap::new().into(),
|
||||
params: HashMap::new().into(),
|
||||
}),
|
||||
);
|
||||
type_mapping.insert(
|
||||
"bool".into(),
|
||||
unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(2),
|
||||
fields: HashMap::new().into(),
|
||||
params: HashMap::new().into(),
|
||||
}),
|
||||
);
|
||||
let (v0, id) = unifier.get_fresh_var();
|
||||
type_mapping.insert(
|
||||
"Foo".into(),
|
||||
unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(3),
|
||||
fields: [("a".into(), v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
|
||||
params: [(id, v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
|
||||
}),
|
||||
);
|
||||
|
||||
TestEnvironment { unifier, type_mapping }
|
||||
}
|
||||
|
||||
fn parse(&mut self, typ: &str, mapping: &Mapping<String>) -> Type {
|
||||
let result = self.internal_parse(typ, mapping);
|
||||
assert!(result.1.is_empty());
|
||||
result.0
|
||||
}
|
||||
|
||||
fn internal_parse<'a, 'b>(
|
||||
&'a mut self,
|
||||
typ: &'b str,
|
||||
mapping: &Mapping<String>,
|
||||
) -> (Type, &'b str) {
|
||||
// for testing only, so we can just panic when the input is malformed
|
||||
let end = typ.find(|c| ['[', ',', ']', '='].contains(&c)).unwrap_or_else(|| typ.len());
|
||||
match &typ[..end] {
|
||||
"Tuple" => {
|
||||
let mut s = &typ[end..];
|
||||
assert!(&s[0..1] == "[");
|
||||
let mut ty = Vec::new();
|
||||
while &s[0..1] != "]" {
|
||||
let result = self.internal_parse(&s[1..], mapping);
|
||||
ty.push(result.0);
|
||||
s = result.1;
|
||||
}
|
||||
(self.unifier.add_ty(TypeEnum::TTuple { ty }), &s[1..])
|
||||
}
|
||||
"List" => {
|
||||
assert!(&typ[end..end + 1] == "[");
|
||||
let (ty, s) = self.internal_parse(&typ[end + 1..], mapping);
|
||||
assert!(&s[0..1] == "]");
|
||||
(self.unifier.add_ty(TypeEnum::TList { ty }), &s[1..])
|
||||
}
|
||||
"Record" => {
|
||||
let mut s = &typ[end..];
|
||||
assert!(&s[0..1] == "[");
|
||||
let mut fields = HashMap::new();
|
||||
while &s[0..1] != "]" {
|
||||
let eq = s.find('=').unwrap();
|
||||
let key = s[1..eq].to_string();
|
||||
let result = self.internal_parse(&s[eq + 1..], mapping);
|
||||
fields.insert(key, result.0);
|
||||
s = result.1;
|
||||
}
|
||||
(self.unifier.add_record(fields), &s[1..])
|
||||
}
|
||||
x => {
|
||||
let mut s = &typ[end..];
|
||||
let ty = mapping.get(x).cloned().unwrap_or_else(|| {
|
||||
// mapping should be type variables, type_mapping should be concrete types
|
||||
// we should not resolve the type of type variables.
|
||||
let mut ty = *self.type_mapping.get(x).unwrap();
|
||||
let te = self.unifier.get_ty(ty);
|
||||
if let TypeEnum::TObj { params, .. } = &*te.as_ref() {
|
||||
let params = params.borrow();
|
||||
if !params.is_empty() {
|
||||
assert!(&s[0..1] == "[");
|
||||
let mut p = Vec::new();
|
||||
while &s[0..1] != "]" {
|
||||
let result = self.internal_parse(&s[1..], mapping);
|
||||
p.push(result.0);
|
||||
s = result.1;
|
||||
}
|
||||
s = &s[1..];
|
||||
ty = self
|
||||
.unifier
|
||||
.subst(ty, ¶ms.keys().cloned().zip(p.into_iter()).collect())
|
||||
.unwrap_or(ty);
|
||||
}
|
||||
}
|
||||
ty
|
||||
});
|
||||
(ty, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test_case(2,
|
||||
&[("v1", "v2"), ("v2", "float")],
|
||||
&[("v1", "float"), ("v2", "float")]
|
||||
; "simple variable"
|
||||
)]
|
||||
#[test_case(2,
|
||||
&[("v1", "List[v2]"), ("v1", "List[float]")],
|
||||
&[("v1", "List[float]"), ("v2", "float")]
|
||||
; "list element"
|
||||
)]
|
||||
#[test_case(3,
|
||||
&[
|
||||
("v1", "Record[a=v3,b=v3]"),
|
||||
("v2", "Record[b=float,c=v3]"),
|
||||
("v1", "v2")
|
||||
],
|
||||
&[
|
||||
("v1", "Record[a=float,b=float,c=float]"),
|
||||
("v2", "Record[a=float,b=float,c=float]"),
|
||||
("v3", "float")
|
||||
]
|
||||
; "record merge"
|
||||
)]
|
||||
#[test_case(3,
|
||||
&[
|
||||
("v1", "Record[a=float]"),
|
||||
("v2", "Foo[v3]"),
|
||||
("v1", "v2")
|
||||
],
|
||||
&[
|
||||
("v1", "Foo[float]"),
|
||||
("v3", "float")
|
||||
]
|
||||
; "record obj merge"
|
||||
)]
|
||||
/// Test cases for valid unifications.
|
||||
fn test_unify(
|
||||
variable_count: u32,
|
||||
unify_pairs: &[(&'static str, &'static str)],
|
||||
verify_pairs: &[(&'static str, &'static str)],
|
||||
) {
|
||||
let unify_count = unify_pairs.len();
|
||||
// test all permutations...
|
||||
for perm in unify_pairs.iter().permutations(unify_count) {
|
||||
let mut env = TestEnvironment::new();
|
||||
let mut mapping = HashMap::new();
|
||||
for i in 1..=variable_count {
|
||||
let v = env.unifier.get_fresh_var();
|
||||
mapping.insert(format!("v{}", i), v.0);
|
||||
}
|
||||
// unification may have side effect when we do type resolution, so freeze the types
|
||||
// before doing unification.
|
||||
let mut pairs = Vec::new();
|
||||
for (a, b) in perm.iter() {
|
||||
let t1 = env.parse(a, &mapping);
|
||||
let t2 = env.parse(b, &mapping);
|
||||
pairs.push((t1, t2));
|
||||
}
|
||||
for (t1, t2) in pairs {
|
||||
env.unifier.unify(t1, t2).unwrap();
|
||||
}
|
||||
for (a, b) in verify_pairs.iter() {
|
||||
println!("{} = {}", a, b);
|
||||
let t1 = env.parse(a, &mapping);
|
||||
let t2 = env.parse(b, &mapping);
|
||||
assert!(env.unifier.eq(t1, t2));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test_case(2,
|
||||
&[
|
||||
("v1", "Tuple[int]"),
|
||||
("v2", "List[int]"),
|
||||
],
|
||||
(("v1", "v2"), "Cannot unify TList with TTuple")
|
||||
; "type mismatch"
|
||||
)]
|
||||
#[test_case(2,
|
||||
&[
|
||||
("v1", "Tuple[int]"),
|
||||
("v2", "Tuple[float]"),
|
||||
],
|
||||
(("v1", "v2"), "Cannot unify objects with ID 0 and 1")
|
||||
; "tuple parameter mismatch"
|
||||
)]
|
||||
#[test_case(2,
|
||||
&[
|
||||
("v1", "Tuple[int,int]"),
|
||||
("v2", "Tuple[int]"),
|
||||
],
|
||||
(("v1", "v2"), "Cannot unify tuples with length 2 and 1")
|
||||
; "tuple length mismatch"
|
||||
)]
|
||||
#[test_case(3,
|
||||
&[
|
||||
("v1", "Record[a=float,b=int]"),
|
||||
("v2", "Foo[v3]"),
|
||||
],
|
||||
(("v1", "v2"), "No such attribute b")
|
||||
; "record obj merge"
|
||||
)]
|
||||
#[test_case(2,
|
||||
&[
|
||||
("v1", "List[v2]"),
|
||||
],
|
||||
(("v1", "v2"), "Recursive type is prohibited.")
|
||||
; "recursive type for lists"
|
||||
)]
|
||||
/// Test cases for invalid unifications.
|
||||
fn test_invalid_unification(
|
||||
variable_count: u32,
|
||||
unify_pairs: &[(&'static str, &'static str)],
|
||||
errornous_pair: ((&'static str, &'static str), &'static str),
|
||||
) {
|
||||
let mut env = TestEnvironment::new();
|
||||
let mut mapping = HashMap::new();
|
||||
for i in 1..=variable_count {
|
||||
let v = env.unifier.get_fresh_var();
|
||||
mapping.insert(format!("v{}", i), v.0);
|
||||
}
|
||||
// unification may have side effect when we do type resolution, so freeze the types
|
||||
// before doing unification.
|
||||
let mut pairs = Vec::new();
|
||||
for (a, b) in unify_pairs.iter() {
|
||||
let t1 = env.parse(a, &mapping);
|
||||
let t2 = env.parse(b, &mapping);
|
||||
pairs.push((t1, t2));
|
||||
}
|
||||
let (t1, t2) =
|
||||
(env.parse(errornous_pair.0 .0, &mapping), env.parse(errornous_pair.0 .1, &mapping));
|
||||
for (a, b) in pairs {
|
||||
env.unifier.unify(a, b).unwrap();
|
||||
}
|
||||
assert_eq!(env.unifier.unify(t1, t2), Err(errornous_pair.1.to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_virtual() {
|
||||
let mut env = TestEnvironment::new();
|
||||
let int = env.parse("int", &HashMap::new());
|
||||
let fun = env.unifier.add_ty(TypeEnum::TFunc(
|
||||
FunSignature { args: vec![], ret: int, vars: HashMap::new() }.into(),
|
||||
));
|
||||
let bar = env.unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(5),
|
||||
fields: [("f".to_string(), fun), ("a".to_string(), int)]
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect::<HashMap<_, _>>()
|
||||
.into(),
|
||||
params: HashMap::new().into(),
|
||||
});
|
||||
let v0 = env.unifier.get_fresh_var().0;
|
||||
let v1 = env.unifier.get_fresh_var().0;
|
||||
|
||||
let a = env.unifier.add_ty(TypeEnum::TVirtual { ty: bar });
|
||||
let b = env.unifier.add_ty(TypeEnum::TVirtual { ty: v0 });
|
||||
let c = env.unifier.add_record([("f".to_string(), v1)].iter().cloned().collect());
|
||||
env.unifier.unify(a, b).unwrap();
|
||||
env.unifier.unify(b, c).unwrap();
|
||||
assert!(env.unifier.eq(v1, fun));
|
||||
|
||||
let d = env.unifier.add_record([("a".to_string(), v1)].iter().cloned().collect());
|
||||
assert_eq!(env.unifier.unify(b, d), Err("Cannot access field a for virtual type".to_string()));
|
||||
|
||||
let d = env.unifier.add_record([("b".to_string(), v1)].iter().cloned().collect());
|
||||
assert_eq!(env.unifier.unify(b, d), Err("No such attribute b".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_typevar_range() {
|
||||
let mut env = TestEnvironment::new();
|
||||
let int = env.parse("int", &HashMap::new());
|
||||
let boolean = env.parse("bool", &HashMap::new());
|
||||
let float = env.parse("float", &HashMap::new());
|
||||
let int_list = env.parse("List[int]", &HashMap::new());
|
||||
let float_list = env.parse("List[float]", &HashMap::new());
|
||||
|
||||
// unification between v and int
|
||||
// where v in (int, bool)
|
||||
let v = env.unifier.get_fresh_var_with_range(&[int, boolean]).0;
|
||||
env.unifier.unify(int, v).unwrap();
|
||||
|
||||
// unification between v and List[int]
|
||||
// where v in (int, bool)
|
||||
let v = env.unifier.get_fresh_var_with_range(&[int, boolean]).0;
|
||||
assert_eq!(
|
||||
env.unifier.unify(int_list, v),
|
||||
Err("Cannot unify type variable 3 with TList due to incompatible value range".to_string())
|
||||
);
|
||||
|
||||
// unification between v and float
|
||||
// where v in (int, bool)
|
||||
let v = env.unifier.get_fresh_var_with_range(&[int, boolean]).0;
|
||||
assert_eq!(
|
||||
env.unifier.unify(float, v),
|
||||
Err("Cannot unify type variable 4 with TObj due to incompatible value range".to_string())
|
||||
);
|
||||
|
||||
let v1 = env.unifier.get_fresh_var_with_range(&[int, boolean]).0;
|
||||
let v1_list = env.unifier.add_ty(TypeEnum::TList { ty: v1 });
|
||||
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list]).0;
|
||||
// unification between v and int
|
||||
// where v in (int, List[v1]), v1 in (int, bool)
|
||||
env.unifier.unify(int, v).unwrap();
|
||||
|
||||
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list]).0;
|
||||
// unification between v and List[int]
|
||||
// where v in (int, List[v1]), v1 in (int, bool)
|
||||
env.unifier.unify(int_list, v).unwrap();
|
||||
|
||||
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list]).0;
|
||||
// unification between v and List[float]
|
||||
// where v in (int, List[v1]), v1 in (int, bool)
|
||||
assert_eq!(
|
||||
env.unifier.unify(float_list, v),
|
||||
Err("Cannot unify type variable 8 with TList due to incompatible value range".to_string())
|
||||
);
|
||||
|
||||
let a = env.unifier.get_fresh_var_with_range(&[int, float]).0;
|
||||
let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0;
|
||||
env.unifier.unify(a, b).unwrap();
|
||||
env.unifier.unify(a, float).unwrap();
|
||||
|
||||
let a = env.unifier.get_fresh_var_with_range(&[int, float]).0;
|
||||
let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0;
|
||||
env.unifier.unify(a, b).unwrap();
|
||||
assert_eq!(
|
||||
env.unifier.unify(a, int),
|
||||
Err("Cannot unify type variable 12 with TObj due to incompatible value range".into())
|
||||
);
|
||||
|
||||
let a = env.unifier.get_fresh_var_with_range(&[int, float]).0;
|
||||
let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0;
|
||||
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a });
|
||||
let a_list = env.unifier.get_fresh_var_with_range(&[a_list]).0;
|
||||
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b });
|
||||
let b_list = env.unifier.get_fresh_var_with_range(&[b_list]).0;
|
||||
env.unifier.unify(a_list, b_list).unwrap();
|
||||
let float_list = env.unifier.add_ty(TypeEnum::TList { ty: float });
|
||||
env.unifier.unify(a_list, float_list).unwrap();
|
||||
// previous unifications should not affect a and b
|
||||
env.unifier.unify(a, int).unwrap();
|
||||
|
||||
let a = env.unifier.get_fresh_var_with_range(&[int, float]).0;
|
||||
let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0;
|
||||
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a });
|
||||
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b });
|
||||
env.unifier.unify(a_list, b_list).unwrap();
|
||||
let int_list = env.unifier.add_ty(TypeEnum::TList { ty: int });
|
||||
assert_eq!(
|
||||
env.unifier.unify(a_list, int_list),
|
||||
Err("Cannot unify type variable 19 with TObj due to incompatible value range".into())
|
||||
);
|
||||
|
||||
let a = env.unifier.get_fresh_var_with_range(&[int, float]).0;
|
||||
let b = env.unifier.get_fresh_var().0;
|
||||
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a });
|
||||
let a_list = env.unifier.get_fresh_var_with_range(&[a_list]).0;
|
||||
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b });
|
||||
env.unifier.unify(a_list, b_list).unwrap();
|
||||
assert_eq!(
|
||||
env.unifier.unify(b, boolean),
|
||||
Err("Cannot unify type variable 21 with TObj due to incompatible value range".into())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rigid_var() {
|
||||
let mut env = TestEnvironment::new();
|
||||
let a = env.unifier.get_fresh_rigid_var().0;
|
||||
let b = env.unifier.get_fresh_rigid_var().0;
|
||||
let x = env.unifier.get_fresh_var().0;
|
||||
let list_a = env.unifier.add_ty(TypeEnum::TList { ty: a });
|
||||
let list_x = env.unifier.add_ty(TypeEnum::TList { ty: x });
|
||||
let int = env.parse("int", &HashMap::new());
|
||||
let list_int = env.parse("List[int]", &HashMap::new());
|
||||
|
||||
assert_eq!(env.unifier.unify(a, b), Err("Cannot unify TRigidVar with TRigidVar".to_string()));
|
||||
env.unifier.unify(list_a, list_x).unwrap();
|
||||
assert_eq!(
|
||||
env.unifier.unify(list_x, list_int),
|
||||
Err("Cannot unify TObj with TRigidVar".to_string())
|
||||
);
|
||||
|
||||
env.unifier.replace_rigid_var(a, int);
|
||||
env.unifier.unify(list_x, list_int).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_instantiation() {
|
||||
let mut env = TestEnvironment::new();
|
||||
let int = env.parse("int", &HashMap::new());
|
||||
let boolean = env.parse("bool", &HashMap::new());
|
||||
let float = env.parse("float", &HashMap::new());
|
||||
let list_int = env.parse("List[int]", &HashMap::new());
|
||||
|
||||
let obj_map: HashMap<_, _> =
|
||||
[(0usize, "int"), (1, "float"), (2, "bool")].iter().cloned().collect();
|
||||
|
||||
let v = env.unifier.get_fresh_var_with_range(&[int, boolean]).0;
|
||||
let list_v = env.unifier.add_ty(TypeEnum::TList { ty: v });
|
||||
let v1 = env.unifier.get_fresh_var_with_range(&[list_v, int]).0;
|
||||
let v2 = env.unifier.get_fresh_var_with_range(&[list_int, float]).0;
|
||||
let t = env.unifier.get_fresh_rigid_var().0;
|
||||
let tuple = env.unifier.add_ty(TypeEnum::TTuple { ty: vec![v, v1, v2] });
|
||||
let v3 = env.unifier.get_fresh_var_with_range(&[tuple, t]).0;
|
||||
// t = TypeVar('t')
|
||||
// v = TypeVar('v', int, bool)
|
||||
// v1 = TypeVar('v1', 'list[v]', int)
|
||||
// v2 = TypeVar('v2', 'list[int]', float)
|
||||
// v3 = TypeVar('v3', tuple[v, v1, v2], t)
|
||||
// what values can v3 take?
|
||||
|
||||
let types = env.unifier.get_instantiations(v3).unwrap();
|
||||
let expected_types = indoc! {"
|
||||
tuple[bool, int, float]
|
||||
tuple[bool, int, list[int]]
|
||||
tuple[bool, list[bool], float]
|
||||
tuple[bool, list[bool], list[int]]
|
||||
tuple[bool, list[int], float]
|
||||
tuple[bool, list[int], list[int]]
|
||||
tuple[int, int, float]
|
||||
tuple[int, int, list[int]]
|
||||
tuple[int, list[bool], float]
|
||||
tuple[int, list[bool], list[int]]
|
||||
tuple[int, list[int], float]
|
||||
tuple[int, list[int], list[int]]
|
||||
v5"
|
||||
}
|
||||
.split('\n')
|
||||
.collect_vec();
|
||||
let types = types
|
||||
.iter()
|
||||
.map(|ty| {
|
||||
env.unifier.stringify(*ty, &mut |i| obj_map.get(&i).unwrap().to_string(), &mut |i| {
|
||||
format!("v{}", i)
|
||||
})
|
||||
})
|
||||
.sorted()
|
||||
.collect_vec();
|
||||
assert_eq!(expected_types, types);
|
||||
}
|
|
@ -0,0 +1,87 @@
|
|||
use std::rc::Rc;
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
|
||||
pub struct UnificationKey(usize);
|
||||
|
||||
pub struct UnificationTable<V> {
|
||||
parents: Vec<usize>,
|
||||
ranks: Vec<u32>,
|
||||
values: Vec<V>,
|
||||
}
|
||||
|
||||
impl<V> UnificationTable<V> {
|
||||
pub fn new() -> UnificationTable<V> {
|
||||
UnificationTable { parents: Vec::new(), ranks: Vec::new(), values: Vec::new() }
|
||||
}
|
||||
|
||||
pub fn new_key(&mut self, v: V) -> UnificationKey {
|
||||
let index = self.parents.len();
|
||||
self.parents.push(index);
|
||||
self.ranks.push(0);
|
||||
self.values.push(v);
|
||||
UnificationKey(index)
|
||||
}
|
||||
|
||||
pub fn unify(&mut self, a: UnificationKey, b: UnificationKey) {
|
||||
let mut a = self.find(a);
|
||||
let mut b = self.find(b);
|
||||
if a == b {
|
||||
return;
|
||||
}
|
||||
if self.ranks[a] < self.ranks[b] {
|
||||
std::mem::swap(&mut a, &mut b);
|
||||
}
|
||||
self.parents[b] = a;
|
||||
if self.ranks[a] == self.ranks[b] {
|
||||
self.ranks[a] += 1;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn probe_value(&mut self, a: UnificationKey) -> &V {
|
||||
let index = self.find(a);
|
||||
&self.values[index]
|
||||
}
|
||||
|
||||
pub fn set_value(&mut self, a: UnificationKey, v: V) {
|
||||
let index = self.find(a);
|
||||
self.values[index] = v;
|
||||
}
|
||||
|
||||
pub fn unioned(&mut self, a: UnificationKey, b: UnificationKey) -> bool {
|
||||
self.find(a) == self.find(b)
|
||||
}
|
||||
|
||||
pub fn get_representative(&mut self, key: UnificationKey) -> UnificationKey {
|
||||
UnificationKey(self.find(key))
|
||||
}
|
||||
|
||||
fn find(&mut self, key: UnificationKey) -> usize {
|
||||
let mut root = key.0;
|
||||
let mut parent = self.parents[root];
|
||||
while root != parent {
|
||||
// a = parent.parent
|
||||
let a = self.parents[parent];
|
||||
// root.parent = parent.parent
|
||||
self.parents[root] = a;
|
||||
root = parent;
|
||||
// parent = root.parent
|
||||
parent = a;
|
||||
}
|
||||
parent
|
||||
}
|
||||
}
|
||||
|
||||
impl<V> UnificationTable<Rc<V>>
|
||||
where
|
||||
V: Clone,
|
||||
{
|
||||
pub fn get_send(&self) -> UnificationTable<V> {
|
||||
let values = self.values.iter().map(|v| v.as_ref().clone()).collect();
|
||||
UnificationTable { parents: self.parents.clone(), ranks: self.ranks.clone(), values }
|
||||
}
|
||||
|
||||
pub fn from_send(table: &UnificationTable<V>) -> UnificationTable<Rc<V>> {
|
||||
let values = table.values.iter().cloned().map(Rc::new).collect();
|
||||
UnificationTable { parents: table.parents.clone(), ranks: table.ranks.clone(), values }
|
||||
}
|
||||
}
|
|
@ -1,60 +0,0 @@
|
|||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
|
||||
#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
|
||||
pub struct PrimitiveId(pub(crate) usize);
|
||||
|
||||
#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
|
||||
pub struct ClassId(pub(crate) usize);
|
||||
|
||||
#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
|
||||
pub struct ParamId(pub(crate) usize);
|
||||
|
||||
#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
|
||||
pub struct VariableId(pub(crate) usize);
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Hash, Debug)]
|
||||
pub enum TypeEnum {
|
||||
BotType,
|
||||
SelfType,
|
||||
PrimitiveType(PrimitiveId),
|
||||
ClassType(ClassId),
|
||||
VirtualClassType(ClassId),
|
||||
ParametricType(ParamId, Vec<Rc<TypeEnum>>),
|
||||
TypeVariable(VariableId),
|
||||
}
|
||||
|
||||
pub type Type = Rc<TypeEnum>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct FnDef {
|
||||
// we assume methods first argument to be SelfType,
|
||||
// so the first argument is not contained here
|
||||
pub args: Vec<Type>,
|
||||
pub result: Option<Type>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct TypeDef<'a> {
|
||||
pub name: &'a str,
|
||||
pub fields: HashMap<&'a str, Type>,
|
||||
pub methods: HashMap<&'a str, FnDef>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ClassDef<'a> {
|
||||
pub base: TypeDef<'a>,
|
||||
pub parents: Vec<ClassId>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ParametricDef<'a> {
|
||||
pub base: TypeDef<'a>,
|
||||
pub params: Vec<VariableId>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct VarDef<'a> {
|
||||
pub name: &'a str,
|
||||
pub bound: Vec<Type>,
|
||||
}
|
Loading…
Reference in New Issue