forked from M-Labs/nac3
Merge pull request 'hm-inference' (#6) from hm-inference into master
Reviewed-on: M-Labs/nac3#6
This commit is contained in:
commit
f205a8282a
668
Cargo.lock
generated
668
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
20
README.md
20
README.md
@ -15,20 +15,20 @@ caller to specify which methods should be compiled). After type checking, the
|
|||||||
compiler would analyse the set of functions/classes that are used and perform
|
compiler would analyse the set of functions/classes that are used and perform
|
||||||
code generation.
|
code generation.
|
||||||
|
|
||||||
|
|
||||||
Symbol resolver:
|
|
||||||
- Str -> Nac3Type
|
|
||||||
- Str -> Value
|
|
||||||
|
|
||||||
value could be integer values, boolean values, bytes (for memcpy), function ID
|
value could be integer values, boolean values, bytes (for memcpy), function ID
|
||||||
(full name + concrete type)
|
(full name + concrete type)
|
||||||
|
|
||||||
## Current Plan
|
## Current Plan
|
||||||
|
|
||||||
1. Write out the syntax-directed type checking/inferencing rules. Fix the rule
|
Type checking:
|
||||||
for type variable instantiation.
|
|
||||||
2. Update the library dependencies and rewrite some of the type checking code.
|
- [x] Basic interface for symbol resolver.
|
||||||
3. Design the symbol resolver API.
|
- [x] Track location information in context object (for diagnostics).
|
||||||
4. Move tests from code to external files to cleanup the code.
|
- [ ] Refactor old expression and statement type inference code. (anto)
|
||||||
|
- [ ] Error diagnostics utilities. (pca)
|
||||||
|
- [ ] Move tests to external files, write scripts for testing. (pca)
|
||||||
|
- [ ] Implement function type checking (instantiate bounded type parameters),
|
||||||
|
loop unrolling, type inference for lists with virtual objects. (pca)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,5 +7,13 @@ edition = "2018"
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
num-bigint = "0.3"
|
num-bigint = "0.3"
|
||||||
num-traits = "0.2"
|
num-traits = "0.2"
|
||||||
inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm10-0"] }
|
inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm11-0"] }
|
||||||
rustpython-parser = { git = "https://github.com/RustPython/RustPython", branch = "master" }
|
rustpython-parser = { git = "https://github.com/RustPython/RustPython", branch = "master" }
|
||||||
|
itertools = "0.10.1"
|
||||||
|
crossbeam = "0.8.1"
|
||||||
|
parking_lot = "0.11.1"
|
||||||
|
rayon = "1.5.1"
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
test-case = "1.2.0"
|
||||||
|
indoc = "1.0"
|
||||||
|
1
nac3core/rustfmt.toml
Normal file
1
nac3core/rustfmt.toml
Normal file
@ -0,0 +1 @@
|
|||||||
|
use_small_heuristics = "Max"
|
527
nac3core/src/codegen/expr.rs
Normal file
527
nac3core/src/codegen/expr.rs
Normal file
@ -0,0 +1,527 @@
|
|||||||
|
use std::{collections::HashMap, convert::TryInto, iter::once};
|
||||||
|
|
||||||
|
use super::{get_llvm_type, CodeGenContext};
|
||||||
|
use crate::{
|
||||||
|
symbol_resolver::SymbolValue,
|
||||||
|
top_level::{DefinitionId, TopLevelDef},
|
||||||
|
typecheck::typedef::{FunSignature, Type, TypeEnum},
|
||||||
|
};
|
||||||
|
use inkwell::{
|
||||||
|
types::{BasicType, BasicTypeEnum},
|
||||||
|
values::BasicValueEnum,
|
||||||
|
AddressSpace,
|
||||||
|
};
|
||||||
|
use itertools::{chain, izip, zip, Itertools};
|
||||||
|
use rustpython_parser::ast::{self, Boolop, Constant, Expr, ExprKind, Operator};
|
||||||
|
|
||||||
|
impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||||
|
fn get_subst_key(&mut self, obj: Option<Type>, fun: &FunSignature) -> String {
|
||||||
|
let mut vars = obj
|
||||||
|
.map(|ty| {
|
||||||
|
if let TypeEnum::TObj { params, .. } = &*self.unifier.get_ty(ty) {
|
||||||
|
params.borrow().clone()
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.unwrap_or_default();
|
||||||
|
vars.extend(fun.vars.iter());
|
||||||
|
let sorted = vars.keys().sorted();
|
||||||
|
sorted
|
||||||
|
.map(|id| {
|
||||||
|
self.unifier.stringify(vars[id], &mut |id| id.to_string(), &mut |id| id.to_string())
|
||||||
|
})
|
||||||
|
.join(", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_attr_index(&mut self, ty: Type, attr: &str) -> usize {
|
||||||
|
let obj_id = match &*self.unifier.get_ty(ty) {
|
||||||
|
TypeEnum::TObj { obj_id, .. } => *obj_id,
|
||||||
|
// we cannot have other types, virtual type should be handled by function calls
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
let def = &self.top_level.definitions.read()[obj_id.0];
|
||||||
|
let index = if let TopLevelDef::Class { fields, .. } = &*def.read() {
|
||||||
|
fields.iter().find_position(|x| x.0 == attr).unwrap().0
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
index
|
||||||
|
}
|
||||||
|
|
||||||
|
fn gen_symbol_val(&mut self, val: &SymbolValue) -> BasicValueEnum<'ctx> {
|
||||||
|
match val {
|
||||||
|
SymbolValue::I32(v) => self.ctx.i32_type().const_int(*v as u64, true).into(),
|
||||||
|
SymbolValue::I64(v) => self.ctx.i64_type().const_int(*v as u64, true).into(),
|
||||||
|
SymbolValue::Bool(v) => self.ctx.bool_type().const_int(*v as u64, true).into(),
|
||||||
|
SymbolValue::Double(v) => self.ctx.f64_type().const_float(*v).into(),
|
||||||
|
SymbolValue::Tuple(ls) => {
|
||||||
|
let vals = ls.iter().map(|v| self.gen_symbol_val(v)).collect_vec();
|
||||||
|
let fields = vals.iter().map(|v| v.get_type()).collect_vec();
|
||||||
|
let ty = self.ctx.struct_type(&fields, false);
|
||||||
|
let ptr = self.builder.build_alloca(ty, "tuple");
|
||||||
|
let zero = self.ctx.i32_type().const_zero();
|
||||||
|
unsafe {
|
||||||
|
for (i, val) in vals.into_iter().enumerate() {
|
||||||
|
let p = ptr.const_in_bounds_gep(&[
|
||||||
|
zero,
|
||||||
|
self.ctx.i32_type().const_int(i as u64, false),
|
||||||
|
]);
|
||||||
|
self.builder.build_store(p, val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ptr.into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_llvm_type(&mut self, ty: Type) -> BasicTypeEnum<'ctx> {
|
||||||
|
get_llvm_type(self.ctx, &mut self.unifier, self.top_level, &mut self.type_cache, ty)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn gen_call(
|
||||||
|
&mut self,
|
||||||
|
obj: Option<(Type, BasicValueEnum<'ctx>)>,
|
||||||
|
fun: (&FunSignature, DefinitionId),
|
||||||
|
params: Vec<(Option<String>, BasicValueEnum<'ctx>)>,
|
||||||
|
ret: Type,
|
||||||
|
) -> Option<BasicValueEnum<'ctx>> {
|
||||||
|
let key = self.get_subst_key(obj.map(|(a, _)| a), fun.0);
|
||||||
|
let defs = self.top_level.definitions.read();
|
||||||
|
let definition = defs.get(fun.1 .0).unwrap();
|
||||||
|
let val = if let TopLevelDef::Function { instance_to_symbol, .. } = &*definition.read() {
|
||||||
|
let symbol = instance_to_symbol.get(&key).unwrap_or_else(|| {
|
||||||
|
// TODO: codegen for function that are not yet generated
|
||||||
|
unimplemented!()
|
||||||
|
});
|
||||||
|
let fun_val = self.module.get_function(symbol).unwrap_or_else(|| {
|
||||||
|
let params = fun.0.args.iter().map(|arg| self.get_llvm_type(arg.ty)).collect_vec();
|
||||||
|
let fun_ty = if self.unifier.unioned(ret, self.primitives.none) {
|
||||||
|
self.ctx.void_type().fn_type(¶ms, false)
|
||||||
|
} else {
|
||||||
|
self.get_llvm_type(ret).fn_type(¶ms, false)
|
||||||
|
};
|
||||||
|
self.module.add_function(symbol, fun_ty, None)
|
||||||
|
});
|
||||||
|
let mut keys = fun.0.args.clone();
|
||||||
|
let mut mapping = HashMap::new();
|
||||||
|
for (key, value) in params.into_iter() {
|
||||||
|
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value);
|
||||||
|
}
|
||||||
|
// default value handling
|
||||||
|
for k in keys.into_iter() {
|
||||||
|
mapping.insert(k.name, self.gen_symbol_val(&k.default_value.unwrap()));
|
||||||
|
}
|
||||||
|
// reorder the parameters
|
||||||
|
let params =
|
||||||
|
fun.0.args.iter().map(|arg| mapping.remove(&arg.name).unwrap()).collect_vec();
|
||||||
|
self.builder.build_call(fun_val, ¶ms, "call").try_as_basic_value().left()
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
val
|
||||||
|
}
|
||||||
|
|
||||||
|
fn gen_const(&mut self, value: &Constant, ty: Type) -> BasicValueEnum<'ctx> {
|
||||||
|
match value {
|
||||||
|
Constant::Bool(v) => {
|
||||||
|
assert!(self.unifier.unioned(ty, self.primitives.bool));
|
||||||
|
let ty = self.ctx.bool_type();
|
||||||
|
ty.const_int(if *v { 1 } else { 0 }, false).into()
|
||||||
|
}
|
||||||
|
Constant::Int(v) => {
|
||||||
|
let ty = if self.unifier.unioned(ty, self.primitives.int32) {
|
||||||
|
self.ctx.i32_type()
|
||||||
|
} else if self.unifier.unioned(ty, self.primitives.int64) {
|
||||||
|
self.ctx.i64_type()
|
||||||
|
} else {
|
||||||
|
unreachable!();
|
||||||
|
};
|
||||||
|
ty.const_int(v.try_into().unwrap(), false).into()
|
||||||
|
}
|
||||||
|
Constant::Float(v) => {
|
||||||
|
assert!(self.unifier.unioned(ty, self.primitives.float));
|
||||||
|
let ty = self.ctx.f64_type();
|
||||||
|
ty.const_float(*v).into()
|
||||||
|
}
|
||||||
|
Constant::Tuple(v) => {
|
||||||
|
let ty = self.unifier.get_ty(ty);
|
||||||
|
let types =
|
||||||
|
if let TypeEnum::TTuple { ty } = &*ty { ty.clone() } else { unreachable!() };
|
||||||
|
let values = zip(types.into_iter(), v.iter())
|
||||||
|
.map(|(ty, v)| self.gen_const(v, ty))
|
||||||
|
.collect_vec();
|
||||||
|
let types = values.iter().map(BasicValueEnum::get_type).collect_vec();
|
||||||
|
let ty = self.ctx.struct_type(&types, false);
|
||||||
|
ty.const_named_struct(&values).into()
|
||||||
|
}
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn gen_int_ops(
|
||||||
|
&mut self,
|
||||||
|
op: &Operator,
|
||||||
|
lhs: BasicValueEnum<'ctx>,
|
||||||
|
rhs: BasicValueEnum<'ctx>,
|
||||||
|
) -> BasicValueEnum<'ctx> {
|
||||||
|
let (lhs, rhs) =
|
||||||
|
if let (BasicValueEnum::IntValue(lhs), BasicValueEnum::IntValue(rhs)) = (lhs, rhs) {
|
||||||
|
(lhs, rhs)
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
match op {
|
||||||
|
Operator::Add => self.builder.build_int_add(lhs, rhs, "add").into(),
|
||||||
|
Operator::Sub => self.builder.build_int_sub(lhs, rhs, "sub").into(),
|
||||||
|
Operator::Mult => self.builder.build_int_mul(lhs, rhs, "mul").into(),
|
||||||
|
Operator::Div => {
|
||||||
|
let float = self.ctx.f64_type();
|
||||||
|
let left = self.builder.build_signed_int_to_float(lhs, float, "i2f");
|
||||||
|
let right = self.builder.build_signed_int_to_float(rhs, float, "i2f");
|
||||||
|
self.builder.build_float_div(left, right, "fdiv").into()
|
||||||
|
}
|
||||||
|
Operator::Mod => self.builder.build_int_signed_rem(lhs, rhs, "mod").into(),
|
||||||
|
Operator::BitOr => self.builder.build_or(lhs, rhs, "or").into(),
|
||||||
|
Operator::BitXor => self.builder.build_xor(lhs, rhs, "xor").into(),
|
||||||
|
Operator::BitAnd => self.builder.build_and(lhs, rhs, "and").into(),
|
||||||
|
Operator::LShift => self.builder.build_left_shift(lhs, rhs, "lshift").into(),
|
||||||
|
Operator::RShift => self.builder.build_right_shift(lhs, rhs, true, "rshift").into(),
|
||||||
|
Operator::FloorDiv => self.builder.build_int_signed_div(lhs, rhs, "floordiv").into(),
|
||||||
|
// special implementation?
|
||||||
|
Operator::Pow => unimplemented!(),
|
||||||
|
Operator::MatMult => unreachable!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn gen_float_ops(
|
||||||
|
&mut self,
|
||||||
|
op: &Operator,
|
||||||
|
lhs: BasicValueEnum<'ctx>,
|
||||||
|
rhs: BasicValueEnum<'ctx>,
|
||||||
|
) -> BasicValueEnum<'ctx> {
|
||||||
|
let (lhs, rhs) = if let (BasicValueEnum::FloatValue(lhs), BasicValueEnum::FloatValue(rhs)) =
|
||||||
|
(lhs, rhs)
|
||||||
|
{
|
||||||
|
(lhs, rhs)
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
match op {
|
||||||
|
Operator::Add => self.builder.build_float_add(lhs, rhs, "fadd").into(),
|
||||||
|
Operator::Sub => self.builder.build_float_sub(lhs, rhs, "fsub").into(),
|
||||||
|
Operator::Mult => self.builder.build_float_mul(lhs, rhs, "fmul").into(),
|
||||||
|
Operator::Div => self.builder.build_float_div(lhs, rhs, "fdiv").into(),
|
||||||
|
Operator::Mod => self.builder.build_float_rem(lhs, rhs, "fmod").into(),
|
||||||
|
Operator::FloorDiv => {
|
||||||
|
let div = self.builder.build_float_div(lhs, rhs, "fdiv");
|
||||||
|
let floor_intrinsic =
|
||||||
|
self.module.get_function("llvm.floor.f64").unwrap_or_else(|| {
|
||||||
|
let float = self.ctx.f64_type();
|
||||||
|
let fn_type = float.fn_type(&[float.into()], false);
|
||||||
|
self.module.add_function("llvm.floor.f64", fn_type, None)
|
||||||
|
});
|
||||||
|
self.builder
|
||||||
|
.build_call(floor_intrinsic, &[div.into()], "floor")
|
||||||
|
.try_as_basic_value()
|
||||||
|
.left()
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
// special implementation?
|
||||||
|
_ => unimplemented!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn gen_expr(&mut self, expr: &Expr<Option<Type>>) -> BasicValueEnum<'ctx> {
|
||||||
|
let zero = self.ctx.i32_type().const_int(0, false);
|
||||||
|
match &expr.node {
|
||||||
|
ExprKind::Constant { value, .. } => {
|
||||||
|
let ty = expr.custom.unwrap();
|
||||||
|
self.gen_const(value, ty)
|
||||||
|
}
|
||||||
|
ExprKind::Name { id, .. } => {
|
||||||
|
let ptr = self.var_assignment.get(id).unwrap();
|
||||||
|
let primitives = &self.primitives;
|
||||||
|
// we should only dereference primitive types
|
||||||
|
if [primitives.int32, primitives.int64, primitives.float, primitives.bool]
|
||||||
|
.contains(&self.unifier.get_representative(expr.custom.unwrap()))
|
||||||
|
{
|
||||||
|
self.builder.build_load(*ptr, "load")
|
||||||
|
} else {
|
||||||
|
(*ptr).into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ExprKind::List { elts, .. } => {
|
||||||
|
// this shall be optimized later for constant primitive lists...
|
||||||
|
// we should use memcpy for that instead of generating thousands of stores
|
||||||
|
let elements = elts.iter().map(|x| self.gen_expr(x)).collect_vec();
|
||||||
|
let ty = if elements.is_empty() {
|
||||||
|
self.ctx.i32_type().into()
|
||||||
|
} else {
|
||||||
|
elements[0].get_type()
|
||||||
|
};
|
||||||
|
let arr_ptr = self.builder.build_array_alloca(
|
||||||
|
ty,
|
||||||
|
self.ctx.i32_type().const_int(elements.len() as u64, false),
|
||||||
|
"tmparr",
|
||||||
|
);
|
||||||
|
let arr_ty = self.ctx.struct_type(
|
||||||
|
&[self.ctx.i32_type().into(), ty.ptr_type(AddressSpace::Generic).into()],
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
let arr_str_ptr = self.builder.build_alloca(arr_ty, "tmparrstr");
|
||||||
|
unsafe {
|
||||||
|
self.builder.build_store(
|
||||||
|
arr_str_ptr.const_in_bounds_gep(&[zero, zero]),
|
||||||
|
self.ctx.i32_type().const_int(elements.len() as u64, false),
|
||||||
|
);
|
||||||
|
self.builder.build_store(
|
||||||
|
arr_str_ptr
|
||||||
|
.const_in_bounds_gep(&[zero, self.ctx.i32_type().const_int(1, false)]),
|
||||||
|
arr_ptr,
|
||||||
|
);
|
||||||
|
let arr_offset = self.ctx.i32_type().const_int(1, false);
|
||||||
|
for (i, v) in elements.iter().enumerate() {
|
||||||
|
let ptr = self.builder.build_in_bounds_gep(
|
||||||
|
arr_ptr,
|
||||||
|
&[zero, arr_offset, self.ctx.i32_type().const_int(i as u64, false)],
|
||||||
|
"arr_element",
|
||||||
|
);
|
||||||
|
self.builder.build_store(ptr, *v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
arr_str_ptr.into()
|
||||||
|
}
|
||||||
|
ExprKind::Tuple { elts, .. } => {
|
||||||
|
let element_val = elts.iter().map(|x| self.gen_expr(x)).collect_vec();
|
||||||
|
let element_ty = element_val.iter().map(BasicValueEnum::get_type).collect_vec();
|
||||||
|
let tuple_ty = self.ctx.struct_type(&element_ty, false);
|
||||||
|
let tuple_ptr = self.builder.build_alloca(tuple_ty, "tuple");
|
||||||
|
for (i, v) in element_val.into_iter().enumerate() {
|
||||||
|
unsafe {
|
||||||
|
let ptr = tuple_ptr.const_in_bounds_gep(&[
|
||||||
|
zero,
|
||||||
|
self.ctx.i32_type().const_int(i as u64, false),
|
||||||
|
]);
|
||||||
|
self.builder.build_store(ptr, v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tuple_ptr.into()
|
||||||
|
}
|
||||||
|
ExprKind::Attribute { value, attr, .. } => {
|
||||||
|
// note that we would handle class methods directly in calls
|
||||||
|
let index = self.get_attr_index(value.custom.unwrap(), attr);
|
||||||
|
let val = self.gen_expr(value);
|
||||||
|
let ptr = if let BasicValueEnum::PointerValue(v) = val {
|
||||||
|
v
|
||||||
|
} else {
|
||||||
|
unreachable!();
|
||||||
|
};
|
||||||
|
unsafe {
|
||||||
|
let ptr = ptr.const_in_bounds_gep(&[
|
||||||
|
zero,
|
||||||
|
self.ctx.i32_type().const_int(index as u64, false),
|
||||||
|
]);
|
||||||
|
self.builder.build_load(ptr, "field")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ExprKind::BoolOp { op, values } => {
|
||||||
|
// requires conditional branches for short-circuiting...
|
||||||
|
let left = if let BasicValueEnum::IntValue(left) = self.gen_expr(&values[0]) {
|
||||||
|
left
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
let current = self.builder.get_insert_block().unwrap().get_parent().unwrap();
|
||||||
|
let a_bb = self.ctx.append_basic_block(current, "a");
|
||||||
|
let b_bb = self.ctx.append_basic_block(current, "b");
|
||||||
|
let cont_bb = self.ctx.append_basic_block(current, "cont");
|
||||||
|
self.builder.build_conditional_branch(left, a_bb, b_bb);
|
||||||
|
let (a, b) = match op {
|
||||||
|
Boolop::Or => {
|
||||||
|
self.builder.position_at_end(a_bb);
|
||||||
|
let a = self.ctx.bool_type().const_int(1, false);
|
||||||
|
self.builder.build_unconditional_branch(cont_bb);
|
||||||
|
self.builder.position_at_end(b_bb);
|
||||||
|
let b = if let BasicValueEnum::IntValue(b) = self.gen_expr(&values[1]) {
|
||||||
|
b
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
self.builder.build_unconditional_branch(cont_bb);
|
||||||
|
(a, b)
|
||||||
|
}
|
||||||
|
Boolop::And => {
|
||||||
|
self.builder.position_at_end(a_bb);
|
||||||
|
let a = if let BasicValueEnum::IntValue(a) = self.gen_expr(&values[1]) {
|
||||||
|
a
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
self.builder.build_unconditional_branch(cont_bb);
|
||||||
|
self.builder.position_at_end(b_bb);
|
||||||
|
let b = self.ctx.bool_type().const_int(0, false);
|
||||||
|
self.builder.build_unconditional_branch(cont_bb);
|
||||||
|
(a, b)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
self.builder.position_at_end(cont_bb);
|
||||||
|
let phi = self.builder.build_phi(self.ctx.bool_type(), "phi");
|
||||||
|
phi.add_incoming(&[(&a, a_bb), (&b, b_bb)]);
|
||||||
|
phi.as_basic_value()
|
||||||
|
}
|
||||||
|
ExprKind::BinOp { op, left, right } => {
|
||||||
|
let ty1 = self.unifier.get_representative(left.custom.unwrap());
|
||||||
|
let ty2 = self.unifier.get_representative(right.custom.unwrap());
|
||||||
|
let left = self.gen_expr(left);
|
||||||
|
let right = self.gen_expr(right);
|
||||||
|
|
||||||
|
// we can directly compare the types, because we've got their representatives
|
||||||
|
// which would be unchanged until further unification, which we would never do
|
||||||
|
// when doing code generation for function instances
|
||||||
|
if ty1 == ty2 && [self.primitives.int32, self.primitives.int64].contains(&ty1) {
|
||||||
|
self.gen_int_ops(op, left, right)
|
||||||
|
} else if ty1 == ty2 && self.primitives.float == ty1 {
|
||||||
|
self.gen_float_ops(op, left, right)
|
||||||
|
} else {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ExprKind::UnaryOp { op, operand } => {
|
||||||
|
let ty = self.unifier.get_representative(operand.custom.unwrap());
|
||||||
|
let val = self.gen_expr(operand);
|
||||||
|
if ty == self.primitives.bool {
|
||||||
|
let val =
|
||||||
|
if let BasicValueEnum::IntValue(val) = val { val } else { unreachable!() };
|
||||||
|
match op {
|
||||||
|
ast::Unaryop::Invert | ast::Unaryop::Not => {
|
||||||
|
self.builder.build_not(val, "not").into()
|
||||||
|
}
|
||||||
|
_ => val.into(),
|
||||||
|
}
|
||||||
|
} else if [self.primitives.int32, self.primitives.int64].contains(&ty) {
|
||||||
|
let val =
|
||||||
|
if let BasicValueEnum::IntValue(val) = val { val } else { unreachable!() };
|
||||||
|
match op {
|
||||||
|
ast::Unaryop::USub => self.builder.build_int_neg(val, "neg").into(),
|
||||||
|
ast::Unaryop::Invert => self.builder.build_not(val, "not").into(),
|
||||||
|
ast::Unaryop::Not => self
|
||||||
|
.builder
|
||||||
|
.build_int_compare(
|
||||||
|
inkwell::IntPredicate::EQ,
|
||||||
|
val,
|
||||||
|
val.get_type().const_zero(),
|
||||||
|
"not",
|
||||||
|
)
|
||||||
|
.into(),
|
||||||
|
_ => val.into(),
|
||||||
|
}
|
||||||
|
} else if ty == self.primitives.float {
|
||||||
|
let val = if let BasicValueEnum::FloatValue(val) = val {
|
||||||
|
val
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
match op {
|
||||||
|
ast::Unaryop::USub => self.builder.build_float_neg(val, "neg").into(),
|
||||||
|
ast::Unaryop::Not => self
|
||||||
|
.builder
|
||||||
|
.build_float_compare(
|
||||||
|
inkwell::FloatPredicate::OEQ,
|
||||||
|
val,
|
||||||
|
val.get_type().const_zero(),
|
||||||
|
"not",
|
||||||
|
)
|
||||||
|
.into(),
|
||||||
|
_ => val.into(),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ExprKind::Compare { left, ops, comparators } => {
|
||||||
|
izip!(
|
||||||
|
chain(once(left.as_ref()), comparators.iter()),
|
||||||
|
comparators.iter(),
|
||||||
|
ops.iter(),
|
||||||
|
)
|
||||||
|
.fold(None, |prev, (lhs, rhs, op)| {
|
||||||
|
let ty = self.unifier.get_representative(lhs.custom.unwrap());
|
||||||
|
let current =
|
||||||
|
if [self.primitives.int32, self.primitives.int64, self.primitives.bool]
|
||||||
|
.contains(&ty)
|
||||||
|
{
|
||||||
|
let (lhs, rhs) = if let (
|
||||||
|
BasicValueEnum::IntValue(lhs),
|
||||||
|
BasicValueEnum::IntValue(rhs),
|
||||||
|
) = (self.gen_expr(lhs), self.gen_expr(rhs))
|
||||||
|
{
|
||||||
|
(lhs, rhs)
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
let op = match op {
|
||||||
|
ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::IntPredicate::EQ,
|
||||||
|
ast::Cmpop::NotEq => inkwell::IntPredicate::NE,
|
||||||
|
ast::Cmpop::Lt => inkwell::IntPredicate::SLT,
|
||||||
|
ast::Cmpop::LtE => inkwell::IntPredicate::SLE,
|
||||||
|
ast::Cmpop::Gt => inkwell::IntPredicate::SGT,
|
||||||
|
ast::Cmpop::GtE => inkwell::IntPredicate::SGE,
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
self.builder.build_int_compare(op, lhs, rhs, "cmp")
|
||||||
|
} else if ty == self.primitives.float {
|
||||||
|
let (lhs, rhs) = if let (
|
||||||
|
BasicValueEnum::FloatValue(lhs),
|
||||||
|
BasicValueEnum::FloatValue(rhs),
|
||||||
|
) = (self.gen_expr(lhs), self.gen_expr(rhs))
|
||||||
|
{
|
||||||
|
(lhs, rhs)
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
let op = match op {
|
||||||
|
ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::FloatPredicate::OEQ,
|
||||||
|
ast::Cmpop::NotEq => inkwell::FloatPredicate::ONE,
|
||||||
|
ast::Cmpop::Lt => inkwell::FloatPredicate::OLT,
|
||||||
|
ast::Cmpop::LtE => inkwell::FloatPredicate::OLE,
|
||||||
|
ast::Cmpop::Gt => inkwell::FloatPredicate::OGT,
|
||||||
|
ast::Cmpop::GtE => inkwell::FloatPredicate::OGE,
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
self.builder.build_float_compare(op, lhs, rhs, "cmp")
|
||||||
|
} else {
|
||||||
|
unimplemented!()
|
||||||
|
};
|
||||||
|
prev.map(|v| self.builder.build_and(v, current, "cmp")).or(Some(current))
|
||||||
|
})
|
||||||
|
.unwrap()
|
||||||
|
.into() // as there should be at least 1 element, it should never be none
|
||||||
|
}
|
||||||
|
ExprKind::IfExp { test, body, orelse } => {
|
||||||
|
let test = if let BasicValueEnum::IntValue(test) = self.gen_expr(test) {
|
||||||
|
test
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
|
||||||
|
let current = self.builder.get_insert_block().unwrap().get_parent().unwrap();
|
||||||
|
let then_bb = self.ctx.append_basic_block(current, "then");
|
||||||
|
let else_bb = self.ctx.append_basic_block(current, "else");
|
||||||
|
let cont_bb = self.ctx.append_basic_block(current, "cont");
|
||||||
|
self.builder.build_conditional_branch(test, then_bb, else_bb);
|
||||||
|
self.builder.position_at_end(then_bb);
|
||||||
|
let a = self.gen_expr(body);
|
||||||
|
self.builder.build_unconditional_branch(cont_bb);
|
||||||
|
self.builder.position_at_end(else_bb);
|
||||||
|
let b = self.gen_expr(orelse);
|
||||||
|
self.builder.build_unconditional_branch(cont_bb);
|
||||||
|
self.builder.position_at_end(cont_bb);
|
||||||
|
let phi = self.builder.build_phi(a.get_type(), "ifexpr");
|
||||||
|
phi.add_incoming(&[(&a, then_bb), (&b, else_bb)]);
|
||||||
|
phi.as_basic_value()
|
||||||
|
}
|
||||||
|
_ => unimplemented!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
343
nac3core/src/codegen/mod.rs
Normal file
343
nac3core/src/codegen/mod.rs
Normal file
@ -0,0 +1,343 @@
|
|||||||
|
use crate::{
|
||||||
|
symbol_resolver::SymbolResolver,
|
||||||
|
top_level::{TopLevelContext, TopLevelDef},
|
||||||
|
typecheck::{
|
||||||
|
type_inferencer::PrimitiveStore,
|
||||||
|
typedef::{FunSignature, Type, TypeEnum, Unifier},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
use crossbeam::channel::{unbounded, Receiver, Sender};
|
||||||
|
use inkwell::{
|
||||||
|
basic_block::BasicBlock,
|
||||||
|
builder::Builder,
|
||||||
|
context::Context,
|
||||||
|
module::Module,
|
||||||
|
types::{BasicType, BasicTypeEnum},
|
||||||
|
values::PointerValue,
|
||||||
|
AddressSpace,
|
||||||
|
};
|
||||||
|
use itertools::Itertools;
|
||||||
|
use parking_lot::{Condvar, Mutex};
|
||||||
|
use rustpython_parser::ast::Stmt;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::{
|
||||||
|
atomic::{AtomicBool, Ordering},
|
||||||
|
Arc,
|
||||||
|
};
|
||||||
|
use std::thread;
|
||||||
|
|
||||||
|
mod expr;
|
||||||
|
mod stmt;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod test;
|
||||||
|
|
||||||
|
pub struct CodeGenContext<'ctx, 'a> {
|
||||||
|
pub ctx: &'ctx Context,
|
||||||
|
pub builder: Builder<'ctx>,
|
||||||
|
pub module: Module<'ctx>,
|
||||||
|
pub top_level: &'a TopLevelContext,
|
||||||
|
pub unifier: Unifier,
|
||||||
|
pub resolver: Arc<dyn SymbolResolver>,
|
||||||
|
pub var_assignment: HashMap<String, PointerValue<'ctx>>,
|
||||||
|
pub type_cache: HashMap<Type, BasicTypeEnum<'ctx>>,
|
||||||
|
pub primitives: PrimitiveStore,
|
||||||
|
// stores the alloca for variables
|
||||||
|
pub init_bb: BasicBlock<'ctx>,
|
||||||
|
// where continue and break should go to respectively
|
||||||
|
// the first one is the test_bb, and the second one is bb after the loop
|
||||||
|
pub loop_bb: Option<(BasicBlock<'ctx>, BasicBlock<'ctx>)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
type Fp = Box<dyn Fn(&Module) + Send + Sync>;
|
||||||
|
|
||||||
|
pub struct WithCall {
|
||||||
|
fp: Fp,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WithCall {
|
||||||
|
pub fn new(fp: Fp) -> WithCall {
|
||||||
|
WithCall { fp }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn run<'ctx>(&self, m: &Module<'ctx>) {
|
||||||
|
(self.fp)(m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct WorkerRegistry {
|
||||||
|
sender: Arc<Sender<Option<CodeGenTask>>>,
|
||||||
|
receiver: Arc<Receiver<Option<CodeGenTask>>>,
|
||||||
|
panicked: AtomicBool,
|
||||||
|
task_count: Mutex<usize>,
|
||||||
|
thread_count: usize,
|
||||||
|
wait_condvar: Condvar,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WorkerRegistry {
|
||||||
|
pub fn create_workers(
|
||||||
|
names: &[&str],
|
||||||
|
top_level_ctx: Arc<TopLevelContext>,
|
||||||
|
f: Arc<WithCall>,
|
||||||
|
) -> (Arc<WorkerRegistry>, Vec<thread::JoinHandle<()>>) {
|
||||||
|
let (sender, receiver) = unbounded();
|
||||||
|
let task_count = Mutex::new(0);
|
||||||
|
let wait_condvar = Condvar::new();
|
||||||
|
|
||||||
|
let registry = Arc::new(WorkerRegistry {
|
||||||
|
sender: Arc::new(sender),
|
||||||
|
receiver: Arc::new(receiver),
|
||||||
|
thread_count: names.len(),
|
||||||
|
panicked: AtomicBool::new(false),
|
||||||
|
task_count,
|
||||||
|
wait_condvar,
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut handles = Vec::new();
|
||||||
|
for name in names.iter() {
|
||||||
|
let top_level_ctx = top_level_ctx.clone();
|
||||||
|
let registry = registry.clone();
|
||||||
|
let registry2 = registry.clone();
|
||||||
|
let name = name.to_string();
|
||||||
|
let f = f.clone();
|
||||||
|
let handle = thread::spawn(move || {
|
||||||
|
registry.worker_thread(name, top_level_ctx, f);
|
||||||
|
});
|
||||||
|
let handle = thread::spawn(move || {
|
||||||
|
if let Err(e) = handle.join() {
|
||||||
|
if let Some(e) = e.downcast_ref::<&'static str>() {
|
||||||
|
eprintln!("Got an error: {}", e);
|
||||||
|
} else {
|
||||||
|
eprintln!("Got an unknown error: {:?}", e);
|
||||||
|
}
|
||||||
|
registry2.panicked.store(true, Ordering::SeqCst);
|
||||||
|
registry2.wait_condvar.notify_all();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
handles.push(handle);
|
||||||
|
}
|
||||||
|
(registry, handles)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn wait_tasks_complete(&self, handles: Vec<thread::JoinHandle<()>>) {
|
||||||
|
{
|
||||||
|
let mut count = self.task_count.lock();
|
||||||
|
while *count != 0 {
|
||||||
|
if self.panicked.load(Ordering::SeqCst) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
self.wait_condvar.wait(&mut count);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _ in 0..self.thread_count {
|
||||||
|
self.sender.send(None).unwrap();
|
||||||
|
}
|
||||||
|
{
|
||||||
|
let mut count = self.task_count.lock();
|
||||||
|
while *count != self.thread_count {
|
||||||
|
if self.panicked.load(Ordering::SeqCst) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
self.wait_condvar.wait(&mut count);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for handle in handles {
|
||||||
|
handle.join().unwrap();
|
||||||
|
}
|
||||||
|
if self.panicked.load(Ordering::SeqCst) {
|
||||||
|
panic!("tasks panicked");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn add_task(&self, task: CodeGenTask) {
|
||||||
|
*self.task_count.lock() += 1;
|
||||||
|
self.sender.send(Some(task)).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn worker_thread(
|
||||||
|
&self,
|
||||||
|
module_name: String,
|
||||||
|
top_level_ctx: Arc<TopLevelContext>,
|
||||||
|
f: Arc<WithCall>,
|
||||||
|
) {
|
||||||
|
let context = Context::create();
|
||||||
|
let mut builder = context.create_builder();
|
||||||
|
let mut module = context.create_module(&module_name);
|
||||||
|
|
||||||
|
while let Some(task) = self.receiver.recv().unwrap() {
|
||||||
|
let result = gen_func(&context, builder, module, task, top_level_ctx.clone());
|
||||||
|
builder = result.0;
|
||||||
|
module = result.1;
|
||||||
|
*self.task_count.lock() -= 1;
|
||||||
|
self.wait_condvar.notify_all();
|
||||||
|
}
|
||||||
|
|
||||||
|
// do whatever...
|
||||||
|
let mut lock = self.task_count.lock();
|
||||||
|
module.verify().unwrap();
|
||||||
|
f.run(&module);
|
||||||
|
*lock += 1;
|
||||||
|
self.wait_condvar.notify_all();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct CodeGenTask {
|
||||||
|
pub subst: Vec<(Type, Type)>,
|
||||||
|
pub symbol_name: String,
|
||||||
|
pub signature: FunSignature,
|
||||||
|
pub body: Vec<Stmt<Option<Type>>>,
|
||||||
|
pub unifier_index: usize,
|
||||||
|
pub resolver: Arc<dyn SymbolResolver + Send + Sync>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_llvm_type<'ctx>(
|
||||||
|
ctx: &'ctx Context,
|
||||||
|
unifier: &mut Unifier,
|
||||||
|
top_level: &TopLevelContext,
|
||||||
|
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
|
||||||
|
ty: Type,
|
||||||
|
) -> BasicTypeEnum<'ctx> {
|
||||||
|
use TypeEnum::*;
|
||||||
|
// we assume the type cache should already contain primitive types,
|
||||||
|
// and they should be passed by value instead of passing as pointer.
|
||||||
|
type_cache.get(&unifier.get_representative(ty)).cloned().unwrap_or_else(|| {
|
||||||
|
match &*unifier.get_ty(ty) {
|
||||||
|
TObj { obj_id, fields, .. } => {
|
||||||
|
// a struct with fields in the order of declaration
|
||||||
|
let defs = top_level.definitions.read();
|
||||||
|
let definition = defs.get(obj_id.0).unwrap();
|
||||||
|
let ty = if let TopLevelDef::Class { fields: fields_list, .. } = &*definition.read()
|
||||||
|
{
|
||||||
|
let fields = fields.borrow();
|
||||||
|
let fields = fields_list
|
||||||
|
.iter()
|
||||||
|
.map(|f| get_llvm_type(ctx, unifier, top_level, type_cache, fields[&f.0]))
|
||||||
|
.collect_vec();
|
||||||
|
ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into()
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
ty
|
||||||
|
}
|
||||||
|
TTuple { ty } => {
|
||||||
|
// a struct with fields in the order present in the tuple
|
||||||
|
let fields = ty
|
||||||
|
.iter()
|
||||||
|
.map(|ty| get_llvm_type(ctx, unifier, top_level, type_cache, *ty))
|
||||||
|
.collect_vec();
|
||||||
|
ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into()
|
||||||
|
}
|
||||||
|
TList { ty } => {
|
||||||
|
// a struct with an integer and a pointer to an array
|
||||||
|
let element_type = get_llvm_type(ctx, unifier, top_level, type_cache, *ty);
|
||||||
|
let fields =
|
||||||
|
[ctx.i32_type().into(), element_type.ptr_type(AddressSpace::Generic).into()];
|
||||||
|
ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into()
|
||||||
|
}
|
||||||
|
TVirtual { .. } => unimplemented!(),
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn gen_func<'ctx>(
|
||||||
|
context: &'ctx Context,
|
||||||
|
builder: Builder<'ctx>,
|
||||||
|
module: Module<'ctx>,
|
||||||
|
task: CodeGenTask,
|
||||||
|
top_level_ctx: Arc<TopLevelContext>,
|
||||||
|
) -> (Builder<'ctx>, Module<'ctx>) {
|
||||||
|
// unwrap_or(0) is for unit tests without using rayon
|
||||||
|
let (mut unifier, primitives) = {
|
||||||
|
let unifiers = top_level_ctx.unifiers.read();
|
||||||
|
let (unifier, primitives) = &unifiers[task.unifier_index];
|
||||||
|
(Unifier::from_shared_unifier(unifier), *primitives)
|
||||||
|
};
|
||||||
|
|
||||||
|
for (a, b) in task.subst.iter() {
|
||||||
|
// this should be unification between variables and concrete types
|
||||||
|
// and should not cause any problem...
|
||||||
|
unifier.unify(*a, *b).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// rebuild primitive store with unique representatives
|
||||||
|
let primitives = PrimitiveStore {
|
||||||
|
int32: unifier.get_representative(primitives.int32),
|
||||||
|
int64: unifier.get_representative(primitives.int64),
|
||||||
|
float: unifier.get_representative(primitives.float),
|
||||||
|
bool: unifier.get_representative(primitives.bool),
|
||||||
|
none: unifier.get_representative(primitives.none),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut type_cache: HashMap<_, _> = [
|
||||||
|
(unifier.get_representative(primitives.int32), context.i32_type().into()),
|
||||||
|
(unifier.get_representative(primitives.int64), context.i64_type().into()),
|
||||||
|
(unifier.get_representative(primitives.float), context.f64_type().into()),
|
||||||
|
(unifier.get_representative(primitives.bool), context.bool_type().into()),
|
||||||
|
]
|
||||||
|
.iter()
|
||||||
|
.cloned()
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let params = task
|
||||||
|
.signature
|
||||||
|
.args
|
||||||
|
.iter()
|
||||||
|
.map(|arg| {
|
||||||
|
get_llvm_type(&context, &mut unifier, top_level_ctx.as_ref(), &mut type_cache, arg.ty)
|
||||||
|
})
|
||||||
|
.collect_vec();
|
||||||
|
|
||||||
|
let fn_type = if unifier.unioned(task.signature.ret, primitives.none) {
|
||||||
|
context.void_type().fn_type(¶ms, false)
|
||||||
|
} else {
|
||||||
|
get_llvm_type(
|
||||||
|
&context,
|
||||||
|
&mut unifier,
|
||||||
|
top_level_ctx.as_ref(),
|
||||||
|
&mut type_cache,
|
||||||
|
task.signature.ret,
|
||||||
|
)
|
||||||
|
.fn_type(¶ms, false)
|
||||||
|
};
|
||||||
|
|
||||||
|
let fn_val = module.add_function(&task.symbol_name, fn_type, None);
|
||||||
|
let init_bb = context.append_basic_block(fn_val, "init");
|
||||||
|
builder.position_at_end(init_bb);
|
||||||
|
let body_bb = context.append_basic_block(fn_val, "body");
|
||||||
|
|
||||||
|
let mut var_assignment = HashMap::new();
|
||||||
|
for (n, arg) in task.signature.args.iter().enumerate() {
|
||||||
|
let param = fn_val.get_nth_param(n as u32).unwrap();
|
||||||
|
let alloca = builder.build_alloca(
|
||||||
|
get_llvm_type(&context, &mut unifier, top_level_ctx.as_ref(), &mut type_cache, arg.ty),
|
||||||
|
&arg.name,
|
||||||
|
);
|
||||||
|
builder.build_store(alloca, param);
|
||||||
|
var_assignment.insert(arg.name.clone(), alloca);
|
||||||
|
}
|
||||||
|
builder.build_unconditional_branch(body_bb);
|
||||||
|
builder.position_at_end(body_bb);
|
||||||
|
|
||||||
|
let mut code_gen_context = CodeGenContext {
|
||||||
|
ctx: &context,
|
||||||
|
resolver: task.resolver,
|
||||||
|
top_level: top_level_ctx.as_ref(),
|
||||||
|
loop_bb: None,
|
||||||
|
var_assignment,
|
||||||
|
type_cache,
|
||||||
|
primitives,
|
||||||
|
init_bb,
|
||||||
|
builder,
|
||||||
|
module,
|
||||||
|
unifier,
|
||||||
|
};
|
||||||
|
|
||||||
|
for stmt in task.body.iter() {
|
||||||
|
code_gen_context.gen_stmt(stmt);
|
||||||
|
}
|
||||||
|
|
||||||
|
let CodeGenContext { builder, module, .. } = code_gen_context;
|
||||||
|
|
||||||
|
(builder, module)
|
||||||
|
}
|
138
nac3core/src/codegen/stmt.rs
Normal file
138
nac3core/src/codegen/stmt.rs
Normal file
@ -0,0 +1,138 @@
|
|||||||
|
use super::CodeGenContext;
|
||||||
|
use crate::typecheck::typedef::Type;
|
||||||
|
use inkwell::values::{BasicValue, BasicValueEnum, PointerValue};
|
||||||
|
use rustpython_parser::ast::{Expr, ExprKind, Stmt, StmtKind};
|
||||||
|
|
||||||
|
impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||||
|
fn gen_var(&mut self, ty: Type) -> PointerValue<'ctx> {
|
||||||
|
// put the alloca in init block
|
||||||
|
let current = self.builder.get_insert_block().unwrap();
|
||||||
|
// position before the last branching instruction...
|
||||||
|
self.builder.position_before(&self.init_bb.get_last_instruction().unwrap());
|
||||||
|
let ty = self.get_llvm_type(ty);
|
||||||
|
let ptr = self.builder.build_alloca(ty, "tmp");
|
||||||
|
self.builder.position_at_end(current);
|
||||||
|
ptr
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_pattern(&mut self, pattern: &Expr<Option<Type>>) -> PointerValue<'ctx> {
|
||||||
|
// very similar to gen_expr, but we don't do an extra load at the end
|
||||||
|
// and we flatten nested tuples
|
||||||
|
match &pattern.node {
|
||||||
|
ExprKind::Name { id, .. } => {
|
||||||
|
self.var_assignment.get(id).cloned().unwrap_or_else(|| {
|
||||||
|
let ptr = self.gen_var(pattern.custom.unwrap());
|
||||||
|
self.var_assignment.insert(id.clone(), ptr);
|
||||||
|
ptr
|
||||||
|
})
|
||||||
|
}
|
||||||
|
ExprKind::Attribute { value, attr, .. } => {
|
||||||
|
let index = self.get_attr_index(value.custom.unwrap(), attr);
|
||||||
|
let val = self.gen_expr(value);
|
||||||
|
let ptr = if let BasicValueEnum::PointerValue(v) = val {
|
||||||
|
v
|
||||||
|
} else {
|
||||||
|
unreachable!();
|
||||||
|
};
|
||||||
|
unsafe {
|
||||||
|
ptr.const_in_bounds_gep(&[
|
||||||
|
self.ctx.i32_type().const_zero(),
|
||||||
|
self.ctx.i32_type().const_int(index as u64, false),
|
||||||
|
])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ExprKind::Subscript { .. } => unimplemented!(),
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn gen_assignment(&mut self, target: &Expr<Option<Type>>, value: BasicValueEnum<'ctx>) {
|
||||||
|
if let ExprKind::Tuple { elts, .. } = &target.node {
|
||||||
|
if let BasicValueEnum::PointerValue(ptr) = value {
|
||||||
|
for (i, elt) in elts.iter().enumerate() {
|
||||||
|
unsafe {
|
||||||
|
let t = ptr.const_in_bounds_gep(&[
|
||||||
|
self.ctx.i32_type().const_zero(),
|
||||||
|
self.ctx.i32_type().const_int(i as u64, false),
|
||||||
|
]);
|
||||||
|
let v = self.builder.build_load(t, "tmpload");
|
||||||
|
self.gen_assignment(elt, v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let ptr = self.parse_pattern(target);
|
||||||
|
self.builder.build_store(ptr, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn gen_stmt(&mut self, stmt: &Stmt<Option<Type>>) {
|
||||||
|
match &stmt.node {
|
||||||
|
StmtKind::Expr { value } => {
|
||||||
|
self.gen_expr(&value);
|
||||||
|
}
|
||||||
|
StmtKind::Return { value } => {
|
||||||
|
let value = value.as_ref().map(|v| self.gen_expr(&v));
|
||||||
|
let value = value.as_ref().map(|v| v as &dyn BasicValue);
|
||||||
|
self.builder.build_return(value);
|
||||||
|
}
|
||||||
|
StmtKind::AnnAssign { target, value, .. } => {
|
||||||
|
if let Some(value) = value {
|
||||||
|
let value = self.gen_expr(&value);
|
||||||
|
self.gen_assignment(target, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
StmtKind::Assign { targets, value, .. } => {
|
||||||
|
let value = self.gen_expr(&value);
|
||||||
|
for target in targets.iter() {
|
||||||
|
self.gen_assignment(target, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
StmtKind::Continue => {
|
||||||
|
self.builder.build_unconditional_branch(self.loop_bb.unwrap().0);
|
||||||
|
}
|
||||||
|
StmtKind::Break => {
|
||||||
|
self.builder.build_unconditional_branch(self.loop_bb.unwrap().1);
|
||||||
|
}
|
||||||
|
StmtKind::While { test, body, orelse } => {
|
||||||
|
let current = self.builder.get_insert_block().unwrap().get_parent().unwrap();
|
||||||
|
let test_bb = self.ctx.append_basic_block(current, "test");
|
||||||
|
let body_bb = self.ctx.append_basic_block(current, "body");
|
||||||
|
let cont_bb = self.ctx.append_basic_block(current, "cont");
|
||||||
|
// if there is no orelse, we just go to cont_bb
|
||||||
|
let orelse_bb = if orelse.is_empty() {
|
||||||
|
cont_bb
|
||||||
|
} else {
|
||||||
|
self.ctx.append_basic_block(current, "orelse")
|
||||||
|
};
|
||||||
|
// store loop bb information and restore it later
|
||||||
|
let loop_bb = self.loop_bb.replace((test_bb, cont_bb));
|
||||||
|
self.builder.build_unconditional_branch(test_bb);
|
||||||
|
self.builder.position_at_end(test_bb);
|
||||||
|
let test = self.gen_expr(test);
|
||||||
|
if let BasicValueEnum::IntValue(test) = test {
|
||||||
|
self.builder.build_conditional_branch(test, body_bb, orelse_bb);
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
self.builder.position_at_end(body_bb);
|
||||||
|
for stmt in body.iter() {
|
||||||
|
self.gen_stmt(stmt);
|
||||||
|
}
|
||||||
|
self.builder.build_unconditional_branch(test_bb);
|
||||||
|
if !orelse.is_empty() {
|
||||||
|
self.builder.position_at_end(orelse_bb);
|
||||||
|
for stmt in orelse.iter() {
|
||||||
|
self.gen_stmt(stmt);
|
||||||
|
}
|
||||||
|
self.builder.build_unconditional_branch(cont_bb);
|
||||||
|
}
|
||||||
|
self.builder.position_at_end(cont_bb);
|
||||||
|
self.loop_bb = loop_bb;
|
||||||
|
}
|
||||||
|
_ => unimplemented!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
247
nac3core/src/codegen/test.rs
Normal file
247
nac3core/src/codegen/test.rs
Normal file
@ -0,0 +1,247 @@
|
|||||||
|
use super::{CodeGenTask, WorkerRegistry};
|
||||||
|
use crate::{
|
||||||
|
codegen::WithCall,
|
||||||
|
location::Location,
|
||||||
|
symbol_resolver::{SymbolResolver, SymbolValue},
|
||||||
|
top_level::{DefinitionId, TopLevelContext},
|
||||||
|
typecheck::{
|
||||||
|
magic_methods::set_primitives_magic_methods,
|
||||||
|
type_inferencer::{CodeLocation, FunctionData, Inferencer, PrimitiveStore},
|
||||||
|
typedef::{CallId, FunSignature, FuncArg, Type, TypeEnum, Unifier},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
use indoc::indoc;
|
||||||
|
use parking_lot::RwLock;
|
||||||
|
use rustpython_parser::{ast::fold::Fold, parser::parse_program};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct Resolver {
|
||||||
|
id_to_type: HashMap<String, Type>,
|
||||||
|
id_to_def: HashMap<String, DefinitionId>,
|
||||||
|
class_names: HashMap<String, Type>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SymbolResolver for Resolver {
|
||||||
|
fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option<Type> {
|
||||||
|
self.id_to_type.get(str).cloned()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_symbol_value(&self, _: &str) -> Option<SymbolValue> {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_symbol_location(&self, _: &str) -> Option<Location> {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_identifier_def(&self, id: &str) -> Option<DefinitionId> {
|
||||||
|
self.id_to_def.get(id).cloned()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct TestEnvironment {
|
||||||
|
pub unifier: Unifier,
|
||||||
|
pub function_data: FunctionData,
|
||||||
|
pub primitives: PrimitiveStore,
|
||||||
|
pub id_to_name: HashMap<usize, String>,
|
||||||
|
pub identifier_mapping: HashMap<String, Type>,
|
||||||
|
pub virtual_checks: Vec<(Type, Type)>,
|
||||||
|
pub calls: HashMap<CodeLocation, CallId>,
|
||||||
|
pub top_level: TopLevelContext,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TestEnvironment {
|
||||||
|
pub fn basic_test_env() -> TestEnvironment {
|
||||||
|
let mut unifier = Unifier::new();
|
||||||
|
|
||||||
|
let int32 = unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: DefinitionId(0),
|
||||||
|
fields: HashMap::new().into(),
|
||||||
|
params: HashMap::new().into(),
|
||||||
|
});
|
||||||
|
let int64 = unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: DefinitionId(1),
|
||||||
|
fields: HashMap::new().into(),
|
||||||
|
params: HashMap::new().into(),
|
||||||
|
});
|
||||||
|
let float = unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: DefinitionId(2),
|
||||||
|
fields: HashMap::new().into(),
|
||||||
|
params: HashMap::new().into(),
|
||||||
|
});
|
||||||
|
let bool = unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: DefinitionId(3),
|
||||||
|
fields: HashMap::new().into(),
|
||||||
|
params: HashMap::new().into(),
|
||||||
|
});
|
||||||
|
let none = unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: DefinitionId(4),
|
||||||
|
fields: HashMap::new().into(),
|
||||||
|
params: HashMap::new().into(),
|
||||||
|
});
|
||||||
|
let primitives = PrimitiveStore { int32, int64, float, bool, none };
|
||||||
|
set_primitives_magic_methods(&primitives, &mut unifier);
|
||||||
|
|
||||||
|
let id_to_name = [
|
||||||
|
(0, "int32".to_string()),
|
||||||
|
(1, "int64".to_string()),
|
||||||
|
(2, "float".to_string()),
|
||||||
|
(3, "bool".to_string()),
|
||||||
|
(4, "none".to_string()),
|
||||||
|
]
|
||||||
|
.iter()
|
||||||
|
.cloned()
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let mut identifier_mapping = HashMap::new();
|
||||||
|
identifier_mapping.insert("None".into(), none);
|
||||||
|
|
||||||
|
let resolver = Arc::new(Resolver {
|
||||||
|
id_to_type: identifier_mapping.clone(),
|
||||||
|
id_to_def: Default::default(),
|
||||||
|
class_names: Default::default(),
|
||||||
|
}) as Arc<dyn SymbolResolver + Send + Sync>;
|
||||||
|
|
||||||
|
TestEnvironment {
|
||||||
|
unifier,
|
||||||
|
top_level: TopLevelContext {
|
||||||
|
definitions: Default::default(),
|
||||||
|
unifiers: Default::default(),
|
||||||
|
// conetexts: Default::default(),
|
||||||
|
},
|
||||||
|
function_data: FunctionData {
|
||||||
|
resolver,
|
||||||
|
bound_variables: Vec::new(),
|
||||||
|
return_type: Some(primitives.int32),
|
||||||
|
},
|
||||||
|
primitives,
|
||||||
|
id_to_name,
|
||||||
|
identifier_mapping,
|
||||||
|
virtual_checks: Vec::new(),
|
||||||
|
calls: HashMap::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_inferencer(&mut self) -> Inferencer {
|
||||||
|
Inferencer {
|
||||||
|
top_level: &self.top_level,
|
||||||
|
function_data: &mut self.function_data,
|
||||||
|
unifier: &mut self.unifier,
|
||||||
|
variable_mapping: Default::default(),
|
||||||
|
primitives: &mut self.primitives,
|
||||||
|
virtual_checks: &mut self.virtual_checks,
|
||||||
|
calls: &mut self.calls,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_primitives() {
|
||||||
|
let mut env = TestEnvironment::basic_test_env();
|
||||||
|
let threads = ["test"];
|
||||||
|
let signature = FunSignature {
|
||||||
|
args: vec![
|
||||||
|
FuncArg { name: "a".to_string(), ty: env.primitives.int32, default_value: None },
|
||||||
|
FuncArg { name: "b".to_string(), ty: env.primitives.int32, default_value: None },
|
||||||
|
],
|
||||||
|
ret: env.primitives.int32,
|
||||||
|
vars: HashMap::new(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut inferencer = env.get_inferencer();
|
||||||
|
inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32);
|
||||||
|
inferencer.variable_mapping.insert("b".into(), inferencer.primitives.int32);
|
||||||
|
let source = indoc! { "
|
||||||
|
c = a + b
|
||||||
|
d = a if c == 1 else 0
|
||||||
|
return d
|
||||||
|
"};
|
||||||
|
let statements = parse_program(source).unwrap();
|
||||||
|
|
||||||
|
let statements = statements
|
||||||
|
.into_iter()
|
||||||
|
.map(|v| inferencer.fold_stmt(v))
|
||||||
|
.collect::<Result<Vec<_>, _>>()
|
||||||
|
.unwrap();
|
||||||
|
let mut identifiers = vec!["a".to_string(), "b".to_string()];
|
||||||
|
inferencer.check_block(&statements, &mut identifiers).unwrap();
|
||||||
|
|
||||||
|
let top_level = Arc::new(TopLevelContext {
|
||||||
|
definitions: Default::default(),
|
||||||
|
unifiers: Arc::new(RwLock::new(vec![(env.unifier.get_shared_unifier(), env.primitives)])),
|
||||||
|
// conetexts: Default::default(),
|
||||||
|
});
|
||||||
|
let task = CodeGenTask {
|
||||||
|
subst: Default::default(),
|
||||||
|
symbol_name: "testing".to_string(),
|
||||||
|
body: statements,
|
||||||
|
unifier_index: 0,
|
||||||
|
resolver: env.function_data.resolver.clone(),
|
||||||
|
signature,
|
||||||
|
};
|
||||||
|
|
||||||
|
let f = Arc::new(WithCall::new(Box::new(|module| {
|
||||||
|
// the following IR is equivalent to
|
||||||
|
// ```
|
||||||
|
// ; ModuleID = 'test.ll'
|
||||||
|
// source_filename = "test"
|
||||||
|
//
|
||||||
|
// ; Function Attrs: norecurse nounwind readnone
|
||||||
|
// define i32 @testing(i32 %0, i32 %1) local_unnamed_addr #0 {
|
||||||
|
// init:
|
||||||
|
// %add = add i32 %1, %0
|
||||||
|
// %cmp = icmp eq i32 %add, 1
|
||||||
|
// %ifexpr = select i1 %cmp, i32 %0, i32 0
|
||||||
|
// ret i32 %ifexpr
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// attributes #0 = { norecurse nounwind readnone }
|
||||||
|
// ```
|
||||||
|
// after O2 optimization
|
||||||
|
|
||||||
|
let expected = indoc! {"
|
||||||
|
; ModuleID = 'test'
|
||||||
|
source_filename = \"test\"
|
||||||
|
|
||||||
|
define i32 @testing(i32 %0, i32 %1) {
|
||||||
|
init:
|
||||||
|
%a = alloca i32, align 4
|
||||||
|
store i32 %0, i32* %a, align 4
|
||||||
|
%b = alloca i32, align 4
|
||||||
|
store i32 %1, i32* %b, align 4
|
||||||
|
%tmp = alloca i32, align 4
|
||||||
|
%tmp4 = alloca i32, align 4
|
||||||
|
br label %body
|
||||||
|
|
||||||
|
body: ; preds = %init
|
||||||
|
%load = load i32, i32* %a, align 4
|
||||||
|
%load1 = load i32, i32* %b, align 4
|
||||||
|
%add = add i32 %load, %load1
|
||||||
|
store i32 %add, i32* %tmp, align 4
|
||||||
|
%load2 = load i32, i32* %tmp, align 4
|
||||||
|
%cmp = icmp eq i32 %load2, 1
|
||||||
|
br i1 %cmp, label %then, label %else
|
||||||
|
|
||||||
|
then: ; preds = %body
|
||||||
|
%load3 = load i32, i32* %a, align 4
|
||||||
|
br label %cont
|
||||||
|
|
||||||
|
else: ; preds = %body
|
||||||
|
br label %cont
|
||||||
|
|
||||||
|
cont: ; preds = %else, %then
|
||||||
|
%ifexpr = phi i32 [ %load3, %then ], [ 0, %else ]
|
||||||
|
store i32 %ifexpr, i32* %tmp4, align 4
|
||||||
|
%load5 = load i32, i32* %tmp4, align 4
|
||||||
|
ret i32 %load5
|
||||||
|
}
|
||||||
|
"}
|
||||||
|
.trim();
|
||||||
|
assert_eq!(expected, module.print_to_string().to_str().unwrap().trim());
|
||||||
|
})));
|
||||||
|
let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f);
|
||||||
|
registry.add_task(task);
|
||||||
|
registry.wait_tasks_complete(handles);
|
||||||
|
}
|
@ -1,212 +0,0 @@
|
|||||||
use super::TopLevelContext;
|
|
||||||
use crate::typedef::*;
|
|
||||||
use std::boxed::Box;
|
|
||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
struct ContextStack<'a> {
|
|
||||||
/// stack level, starts from 0
|
|
||||||
level: u32,
|
|
||||||
/// stack of variable definitions containing (id, def, level) where `def` is the original
|
|
||||||
/// definition in `level-1`.
|
|
||||||
var_defs: Vec<(usize, VarDef<'a>, u32)>,
|
|
||||||
/// stack of symbol definitions containing (name, level) where `level` is the smallest level
|
|
||||||
/// where the name is assigned a value
|
|
||||||
sym_def: Vec<(&'a str, u32)>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct InferenceContext<'a> {
|
|
||||||
/// top level context
|
|
||||||
top_level: TopLevelContext<'a>,
|
|
||||||
|
|
||||||
/// list of primitive instances
|
|
||||||
primitives: Vec<Type>,
|
|
||||||
/// list of variable instances
|
|
||||||
variables: Vec<Type>,
|
|
||||||
/// identifier to (type, readable) mapping.
|
|
||||||
/// an identifier might be defined earlier but has no value (for some code path), thus not
|
|
||||||
/// readable.
|
|
||||||
sym_table: HashMap<&'a str, (Type, bool)>,
|
|
||||||
/// resolution function reference, that may resolve unbounded identifiers to some type
|
|
||||||
resolution_fn: Box<dyn FnMut(&str) -> Result<Type, String>>,
|
|
||||||
/// stack
|
|
||||||
stack: ContextStack<'a>,
|
|
||||||
}
|
|
||||||
|
|
||||||
// non-trivial implementations here
|
|
||||||
impl<'a> InferenceContext<'a> {
|
|
||||||
/// return a new `InferenceContext` from `TopLevelContext` and resolution function.
|
|
||||||
pub fn new(
|
|
||||||
top_level: TopLevelContext,
|
|
||||||
resolution_fn: Box<dyn FnMut(&str) -> Result<Type, String>>,
|
|
||||||
) -> InferenceContext {
|
|
||||||
let primitives = (0..top_level.primitive_defs.len())
|
|
||||||
.map(|v| TypeEnum::PrimitiveType(PrimitiveId(v)).into())
|
|
||||||
.collect();
|
|
||||||
let variables = (0..top_level.var_defs.len())
|
|
||||||
.map(|v| TypeEnum::TypeVariable(VariableId(v)).into())
|
|
||||||
.collect();
|
|
||||||
InferenceContext {
|
|
||||||
top_level,
|
|
||||||
primitives,
|
|
||||||
variables,
|
|
||||||
sym_table: HashMap::new(),
|
|
||||||
resolution_fn,
|
|
||||||
stack: ContextStack {
|
|
||||||
level: 0,
|
|
||||||
var_defs: Vec::new(),
|
|
||||||
sym_def: Vec::new(),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// execute the function with new scope.
|
|
||||||
/// variable assignment would be limited within the scope (not readable outside), and type
|
|
||||||
/// variable type guard would be limited within the scope.
|
|
||||||
/// returns the list of variables assigned within the scope, and the result of the function
|
|
||||||
pub fn with_scope<F, R>(&mut self, f: F) -> (Vec<&'a str>, R)
|
|
||||||
where
|
|
||||||
F: FnOnce(&mut Self) -> R,
|
|
||||||
{
|
|
||||||
self.stack.level += 1;
|
|
||||||
let result = f(self);
|
|
||||||
self.stack.level -= 1;
|
|
||||||
while !self.stack.var_defs.is_empty() {
|
|
||||||
let (_, _, level) = self.stack.var_defs.last().unwrap();
|
|
||||||
if *level > self.stack.level {
|
|
||||||
let (id, def, _) = self.stack.var_defs.pop().unwrap();
|
|
||||||
self.top_level.var_defs[id] = def;
|
|
||||||
} else {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let mut poped_names = Vec::new();
|
|
||||||
while !self.stack.sym_def.is_empty() {
|
|
||||||
let (_, level) = self.stack.sym_def.last().unwrap();
|
|
||||||
if *level > self.stack.level {
|
|
||||||
let (name, _) = self.stack.sym_def.pop().unwrap();
|
|
||||||
self.sym_table.remove(name).unwrap();
|
|
||||||
poped_names.push(name);
|
|
||||||
} else {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
(poped_names, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// assign a type to an identifier.
|
|
||||||
/// may return error if the identifier was defined but with different type
|
|
||||||
pub fn assign(&mut self, name: &'a str, ty: Type) -> Result<Type, String> {
|
|
||||||
if let Some((t, x)) = self.sym_table.get_mut(name) {
|
|
||||||
if t == &ty {
|
|
||||||
if !*x {
|
|
||||||
self.stack.sym_def.push((name, self.stack.level));
|
|
||||||
}
|
|
||||||
*x = true;
|
|
||||||
Ok(ty)
|
|
||||||
} else {
|
|
||||||
Err("different types".into())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
self.stack.sym_def.push((name, self.stack.level));
|
|
||||||
self.sym_table.insert(name, (ty.clone(), true));
|
|
||||||
Ok(ty)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// check if an identifier is already defined
|
|
||||||
pub fn defined(&self, name: &str) -> bool {
|
|
||||||
self.sym_table.get(name).is_some()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// get the type of an identifier
|
|
||||||
/// may return error if the identifier is not defined, and cannot be resolved with the
|
|
||||||
/// resolution function.
|
|
||||||
pub fn resolve(&mut self, name: &str) -> Result<Type, String> {
|
|
||||||
if let Some((t, x)) = self.sym_table.get(name) {
|
|
||||||
if *x {
|
|
||||||
Ok(t.clone())
|
|
||||||
} else {
|
|
||||||
Err("may not have value".into())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
self.resolution_fn.as_mut()(name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// restrict the bound of a type variable by replacing its definition.
|
|
||||||
/// used for implementing type guard
|
|
||||||
pub fn restrict(&mut self, id: VariableId, mut def: VarDef<'a>) {
|
|
||||||
std::mem::swap(self.top_level.var_defs.get_mut(id.0).unwrap(), &mut def);
|
|
||||||
self.stack.var_defs.push((id.0, def, self.stack.level));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// trivial getters:
|
|
||||||
impl<'a> InferenceContext<'a> {
|
|
||||||
pub fn get_primitive(&self, id: PrimitiveId) -> Type {
|
|
||||||
self.primitives.get(id.0).unwrap().clone()
|
|
||||||
}
|
|
||||||
pub fn get_variable(&self, id: VariableId) -> Type {
|
|
||||||
self.variables.get(id.0).unwrap().clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_fn_def(&self, name: &str) -> Option<&FnDef> {
|
|
||||||
self.top_level.fn_table.get(name)
|
|
||||||
}
|
|
||||||
pub fn get_primitive_def(&self, id: PrimitiveId) -> &TypeDef {
|
|
||||||
self.top_level.primitive_defs.get(id.0).unwrap()
|
|
||||||
}
|
|
||||||
pub fn get_class_def(&self, id: ClassId) -> &ClassDef {
|
|
||||||
self.top_level.class_defs.get(id.0).unwrap()
|
|
||||||
}
|
|
||||||
pub fn get_parametric_def(&self, id: ParamId) -> &ParametricDef {
|
|
||||||
self.top_level.parametric_defs.get(id.0).unwrap()
|
|
||||||
}
|
|
||||||
pub fn get_variable_def(&self, id: VariableId) -> &VarDef {
|
|
||||||
self.top_level.var_defs.get(id.0).unwrap()
|
|
||||||
}
|
|
||||||
pub fn get_type(&self, name: &str) -> Option<Type> {
|
|
||||||
self.top_level.get_type(name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TypeEnum {
|
|
||||||
pub fn subst(&self, map: &HashMap<VariableId, Type>) -> TypeEnum {
|
|
||||||
match self {
|
|
||||||
TypeEnum::TypeVariable(id) => map.get(id).map(|v| v.as_ref()).unwrap_or(self).clone(),
|
|
||||||
TypeEnum::ParametricType(id, params) => TypeEnum::ParametricType(
|
|
||||||
*id,
|
|
||||||
params
|
|
||||||
.iter()
|
|
||||||
.map(|v| v.as_ref().subst(map).into())
|
|
||||||
.collect(),
|
|
||||||
),
|
|
||||||
_ => self.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_subst(&self, ctx: &InferenceContext) -> HashMap<VariableId, Type> {
|
|
||||||
match self {
|
|
||||||
TypeEnum::ParametricType(id, params) => {
|
|
||||||
let vars = &ctx.get_parametric_def(*id).params;
|
|
||||||
vars.iter()
|
|
||||||
.zip(params)
|
|
||||||
.map(|(v, p)| (*v, p.as_ref().clone().into()))
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
// if this proves to be slow, we can use option type
|
|
||||||
_ => HashMap::new(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_base<'b: 'a, 'a>(&'a self, ctx: &'b InferenceContext) -> Option<&'b TypeDef> {
|
|
||||||
match self {
|
|
||||||
TypeEnum::PrimitiveType(id) => Some(ctx.get_primitive_def(*id)),
|
|
||||||
TypeEnum::ClassType(id) | TypeEnum::VirtualClassType(id) => {
|
|
||||||
Some(&ctx.get_class_def(*id).base)
|
|
||||||
}
|
|
||||||
TypeEnum::ParametricType(id, _) => Some(&ctx.get_parametric_def(*id).base),
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,4 +0,0 @@
|
|||||||
mod inference_context;
|
|
||||||
mod top_level_context;
|
|
||||||
pub use inference_context::InferenceContext;
|
|
||||||
pub use top_level_context::TopLevelContext;
|
|
@ -1,136 +0,0 @@
|
|||||||
use crate::typedef::*;
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::rc::Rc;
|
|
||||||
|
|
||||||
/// Structure for storing top-level type definitions.
|
|
||||||
/// Used for collecting type signature from source code.
|
|
||||||
/// Can be converted to `InferenceContext` for type inference in functions.
|
|
||||||
pub struct TopLevelContext<'a> {
|
|
||||||
/// List of primitive definitions.
|
|
||||||
pub(super) primitive_defs: Vec<TypeDef<'a>>,
|
|
||||||
/// List of class definitions.
|
|
||||||
pub(super) class_defs: Vec<ClassDef<'a>>,
|
|
||||||
/// List of parametric type definitions.
|
|
||||||
pub(super) parametric_defs: Vec<ParametricDef<'a>>,
|
|
||||||
/// List of type variable definitions.
|
|
||||||
pub(super) var_defs: Vec<VarDef<'a>>,
|
|
||||||
/// Function name to signature mapping.
|
|
||||||
pub(super) fn_table: HashMap<&'a str, FnDef>,
|
|
||||||
/// Type name to type mapping.
|
|
||||||
pub(super) sym_table: HashMap<&'a str, Type>,
|
|
||||||
|
|
||||||
primitives: Vec<Type>,
|
|
||||||
variables: Vec<Type>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a> TopLevelContext<'a> {
|
|
||||||
pub fn new(primitive_defs: Vec<TypeDef<'a>>) -> TopLevelContext {
|
|
||||||
let mut sym_table = HashMap::new();
|
|
||||||
let mut primitives = Vec::new();
|
|
||||||
for (i, t) in primitive_defs.iter().enumerate() {
|
|
||||||
primitives.push(TypeEnum::PrimitiveType(PrimitiveId(i)).into());
|
|
||||||
sym_table.insert(t.name, TypeEnum::PrimitiveType(PrimitiveId(i)).into());
|
|
||||||
}
|
|
||||||
TopLevelContext {
|
|
||||||
primitive_defs,
|
|
||||||
class_defs: Vec::new(),
|
|
||||||
parametric_defs: Vec::new(),
|
|
||||||
var_defs: Vec::new(),
|
|
||||||
fn_table: HashMap::new(),
|
|
||||||
sym_table,
|
|
||||||
primitives,
|
|
||||||
variables: Vec::new(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn add_class(&mut self, def: ClassDef<'a>) -> ClassId {
|
|
||||||
self.sym_table.insert(
|
|
||||||
def.base.name,
|
|
||||||
TypeEnum::ClassType(ClassId(self.class_defs.len())).into(),
|
|
||||||
);
|
|
||||||
self.class_defs.push(def);
|
|
||||||
ClassId(self.class_defs.len() - 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn add_parametric(&mut self, def: ParametricDef<'a>) -> ParamId {
|
|
||||||
let params = def
|
|
||||||
.params
|
|
||||||
.iter()
|
|
||||||
.map(|&v| Rc::new(TypeEnum::TypeVariable(v)))
|
|
||||||
.collect();
|
|
||||||
self.sym_table.insert(
|
|
||||||
def.base.name,
|
|
||||||
TypeEnum::ParametricType(ParamId(self.parametric_defs.len()), params).into(),
|
|
||||||
);
|
|
||||||
self.parametric_defs.push(def);
|
|
||||||
ParamId(self.parametric_defs.len() - 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn add_variable(&mut self, def: VarDef<'a>) -> VariableId {
|
|
||||||
self.sym_table.insert(
|
|
||||||
def.name,
|
|
||||||
TypeEnum::TypeVariable(VariableId(self.var_defs.len())).into(),
|
|
||||||
);
|
|
||||||
self.add_variable_private(def)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn add_variable_private(&mut self, def: VarDef<'a>) -> VariableId {
|
|
||||||
self.var_defs.push(def);
|
|
||||||
self.variables
|
|
||||||
.push(TypeEnum::TypeVariable(VariableId(self.var_defs.len() - 1)).into());
|
|
||||||
VariableId(self.var_defs.len() - 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn add_fn(&mut self, name: &'a str, def: FnDef) {
|
|
||||||
self.fn_table.insert(name, def);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_fn_def(&self, name: &str) -> Option<&FnDef> {
|
|
||||||
self.fn_table.get(name)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_primitive_def_mut(&mut self, id: PrimitiveId) -> &mut TypeDef<'a> {
|
|
||||||
self.primitive_defs.get_mut(id.0).unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_primitive_def(&self, id: PrimitiveId) -> &TypeDef {
|
|
||||||
self.primitive_defs.get(id.0).unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_class_def_mut(&mut self, id: ClassId) -> &mut ClassDef<'a> {
|
|
||||||
self.class_defs.get_mut(id.0).unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_class_def(&self, id: ClassId) -> &ClassDef {
|
|
||||||
self.class_defs.get(id.0).unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_parametric_def_mut(&mut self, id: ParamId) -> &mut ParametricDef<'a> {
|
|
||||||
self.parametric_defs.get_mut(id.0).unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_parametric_def(&self, id: ParamId) -> &ParametricDef {
|
|
||||||
self.parametric_defs.get(id.0).unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_variable_def_mut(&mut self, id: VariableId) -> &mut VarDef<'a> {
|
|
||||||
self.var_defs.get_mut(id.0).unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_variable_def(&self, id: VariableId) -> &VarDef {
|
|
||||||
self.var_defs.get(id.0).unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_primitive(&self, id: PrimitiveId) -> Type {
|
|
||||||
self.primitives.get(id.0).unwrap().clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_variable(&self, id: VariableId) -> Type {
|
|
||||||
self.variables.get(id.0).unwrap().clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_type(&self, name: &str) -> Option<Type> {
|
|
||||||
// TODO: handle parametric types
|
|
||||||
self.sym_table.get(name).cloned()
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,922 +0,0 @@
|
|||||||
use crate::context::InferenceContext;
|
|
||||||
use crate::inference_core::resolve_call;
|
|
||||||
use crate::magic_methods::*;
|
|
||||||
use crate::primitives::*;
|
|
||||||
use crate::typedef::{Type, TypeEnum::*};
|
|
||||||
use rustpython_parser::ast::{
|
|
||||||
Comparison, Comprehension, ComprehensionKind, Expression, ExpressionType, Operator,
|
|
||||||
UnaryOperator,
|
|
||||||
};
|
|
||||||
use std::convert::TryInto;
|
|
||||||
|
|
||||||
type ParserResult = Result<Option<Type>, String>;
|
|
||||||
|
|
||||||
pub fn infer_expr<'b: 'a, 'a>(
|
|
||||||
ctx: &mut InferenceContext<'a>,
|
|
||||||
expr: &'b Expression,
|
|
||||||
) -> ParserResult {
|
|
||||||
match &expr.node {
|
|
||||||
ExpressionType::Number { value } => infer_constant(ctx, value),
|
|
||||||
ExpressionType::Identifier { name } => infer_identifier(ctx, name),
|
|
||||||
ExpressionType::List { elements } => infer_list(ctx, elements),
|
|
||||||
ExpressionType::Tuple { elements } => infer_tuple(ctx, elements),
|
|
||||||
ExpressionType::Attribute { value, name } => infer_attribute(ctx, value, name),
|
|
||||||
ExpressionType::BoolOp { values, .. } => infer_bool_ops(ctx, values),
|
|
||||||
ExpressionType::Binop { a, b, op } => infer_bin_ops(ctx, op, a, b),
|
|
||||||
ExpressionType::Unop { op, a } => infer_unary_ops(ctx, op, a),
|
|
||||||
ExpressionType::Compare { vals, ops } => infer_compare(ctx, vals, ops),
|
|
||||||
ExpressionType::Call {
|
|
||||||
args,
|
|
||||||
function,
|
|
||||||
keywords,
|
|
||||||
} => {
|
|
||||||
if !keywords.is_empty() {
|
|
||||||
Err("keyword is not supported".into())
|
|
||||||
} else {
|
|
||||||
infer_call(ctx, &args, &function)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ExpressionType::Subscript { a, b } => infer_subscript(ctx, a, b),
|
|
||||||
ExpressionType::IfExpression { test, body, orelse } => {
|
|
||||||
infer_if_expr(ctx, &test, &body, orelse)
|
|
||||||
}
|
|
||||||
ExpressionType::Comprehension { kind, generators } => match kind.as_ref() {
|
|
||||||
ComprehensionKind::List { element } => {
|
|
||||||
if generators.len() == 1 {
|
|
||||||
infer_list_comprehension(ctx, element, &generators[0])
|
|
||||||
} else {
|
|
||||||
Err("only 1 generator statement is supported".into())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => Err("only list comprehension is supported".into()),
|
|
||||||
},
|
|
||||||
ExpressionType::True | ExpressionType::False => Ok(Some(ctx.get_primitive(BOOL_TYPE))),
|
|
||||||
_ => Err("not supported".into()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn infer_constant(
|
|
||||||
ctx: &mut InferenceContext,
|
|
||||||
value: &rustpython_parser::ast::Number,
|
|
||||||
) -> ParserResult {
|
|
||||||
use rustpython_parser::ast::Number;
|
|
||||||
match value {
|
|
||||||
Number::Integer { value } => {
|
|
||||||
let int32: Result<i32, _> = value.try_into();
|
|
||||||
if int32.is_ok() {
|
|
||||||
Ok(Some(ctx.get_primitive(INT32_TYPE)))
|
|
||||||
} else {
|
|
||||||
Err("integer out of range".into())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Number::Float { .. } => Ok(Some(ctx.get_primitive(FLOAT_TYPE))),
|
|
||||||
_ => Err("not supported".into()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn infer_identifier(ctx: &mut InferenceContext, name: &str) -> ParserResult {
|
|
||||||
Ok(Some(ctx.resolve(name)?))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn infer_list<'b: 'a, 'a>(
|
|
||||||
ctx: &mut InferenceContext<'a>,
|
|
||||||
elements: &'b [Expression],
|
|
||||||
) -> ParserResult {
|
|
||||||
if elements.is_empty() {
|
|
||||||
return Ok(Some(ParametricType(LIST_TYPE, vec![BotType.into()]).into()));
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut types = elements.iter().map(|v| infer_expr(ctx, v));
|
|
||||||
|
|
||||||
let head = types.next().unwrap()?;
|
|
||||||
if head.is_none() {
|
|
||||||
return Err("list elements must have some type".into());
|
|
||||||
}
|
|
||||||
for v in types {
|
|
||||||
// TODO: try virtual type...
|
|
||||||
if v? != head {
|
|
||||||
return Err("inhomogeneous list is not allowed".into());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(Some(ParametricType(LIST_TYPE, vec![head.unwrap()]).into()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn infer_tuple<'b: 'a, 'a>(
|
|
||||||
ctx: &mut InferenceContext<'a>,
|
|
||||||
elements: &'b [Expression],
|
|
||||||
) -> ParserResult {
|
|
||||||
let types: Result<Option<Vec<_>>, String> =
|
|
||||||
elements.iter().map(|v| infer_expr(ctx, v)).collect();
|
|
||||||
if let Some(t) = types? {
|
|
||||||
Ok(Some(ParametricType(TUPLE_TYPE, t).into()))
|
|
||||||
} else {
|
|
||||||
Err("tuple elements must have some type".into())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn infer_attribute<'a>(
|
|
||||||
ctx: &mut InferenceContext<'a>,
|
|
||||||
value: &'a Expression,
|
|
||||||
name: &str,
|
|
||||||
) -> ParserResult {
|
|
||||||
let value = infer_expr(ctx, value)?.ok_or_else(|| "no value".to_string())?;
|
|
||||||
if let TypeVariable(_) = value.as_ref() {
|
|
||||||
return Err("no fields for type variable".into());
|
|
||||||
}
|
|
||||||
|
|
||||||
value
|
|
||||||
.get_base(ctx)
|
|
||||||
.and_then(|b| b.fields.get(name).cloned())
|
|
||||||
.map_or_else(|| Err("no such field".to_string()), |v| Ok(Some(v)))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn infer_bool_ops<'a>(ctx: &mut InferenceContext<'a>, values: &'a [Expression]) -> ParserResult {
|
|
||||||
assert_eq!(values.len(), 2);
|
|
||||||
let left = infer_expr(ctx, &values[0])?.ok_or_else(|| "no value".to_string())?;
|
|
||||||
let right = infer_expr(ctx, &values[1])?.ok_or_else(|| "no value".to_string())?;
|
|
||||||
|
|
||||||
let b = ctx.get_primitive(BOOL_TYPE);
|
|
||||||
if left == b && right == b {
|
|
||||||
Ok(Some(b))
|
|
||||||
} else {
|
|
||||||
Err("bool operands must be bool".into())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn infer_bin_ops<'b: 'a, 'a>(
|
|
||||||
ctx: &mut InferenceContext<'a>,
|
|
||||||
op: &Operator,
|
|
||||||
left: &'b Expression,
|
|
||||||
right: &'b Expression,
|
|
||||||
) -> ParserResult {
|
|
||||||
let left = infer_expr(ctx, left)?.ok_or_else(|| "no value".to_string())?;
|
|
||||||
let right = infer_expr(ctx, right)?.ok_or_else(|| "no value".to_string())?;
|
|
||||||
let fun = binop_name(op);
|
|
||||||
resolve_call(ctx, Some(left), fun, &[right])
|
|
||||||
}
|
|
||||||
|
|
||||||
fn infer_unary_ops<'b: 'a, 'a>(
|
|
||||||
ctx: &mut InferenceContext<'a>,
|
|
||||||
op: &UnaryOperator,
|
|
||||||
obj: &'b Expression,
|
|
||||||
) -> ParserResult {
|
|
||||||
let ty = infer_expr(ctx, obj)?.ok_or_else(|| "no value".to_string())?;
|
|
||||||
if let UnaryOperator::Not = op {
|
|
||||||
if ty == ctx.get_primitive(BOOL_TYPE) {
|
|
||||||
Ok(Some(ty))
|
|
||||||
} else {
|
|
||||||
Err("logical not must be applied to bool".into())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
resolve_call(ctx, Some(ty), unaryop_name(op), &[])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn infer_compare<'b: 'a, 'a>(
|
|
||||||
ctx: &mut InferenceContext<'a>,
|
|
||||||
vals: &'b [Expression],
|
|
||||||
ops: &'b [Comparison],
|
|
||||||
) -> ParserResult {
|
|
||||||
let types: Result<Option<Vec<_>>, _> = vals.iter().map(|v| infer_expr(ctx, v)).collect();
|
|
||||||
let types = types?;
|
|
||||||
if types.is_none() {
|
|
||||||
return Err("comparison operands must have type".into());
|
|
||||||
}
|
|
||||||
let types = types.unwrap();
|
|
||||||
let boolean = ctx.get_primitive(BOOL_TYPE);
|
|
||||||
let left = &types[..types.len() - 1];
|
|
||||||
let right = &types[1..];
|
|
||||||
|
|
||||||
for ((a, b), op) in left.iter().zip(right.iter()).zip(ops.iter()) {
|
|
||||||
let fun = comparison_name(op).ok_or_else(|| "unsupported comparison".to_string())?;
|
|
||||||
let ty = resolve_call(ctx, Some(a.clone()), fun, &[b.clone()])?;
|
|
||||||
if ty.is_none() || ty.unwrap() != boolean {
|
|
||||||
return Err("comparison result must be boolean".into());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(Some(boolean))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn infer_call<'b: 'a, 'a>(
|
|
||||||
ctx: &mut InferenceContext<'a>,
|
|
||||||
args: &'b [Expression],
|
|
||||||
function: &'b Expression,
|
|
||||||
) -> ParserResult {
|
|
||||||
// TODO: special handling for int64 constant
|
|
||||||
let types: Result<Option<Vec<_>>, _> = args.iter().map(|v| infer_expr(ctx, v)).collect();
|
|
||||||
let types = types?;
|
|
||||||
if types.is_none() {
|
|
||||||
return Err("function params must have type".into());
|
|
||||||
}
|
|
||||||
|
|
||||||
let (obj, fun) = match &function.node {
|
|
||||||
ExpressionType::Identifier { name } => (None, name),
|
|
||||||
ExpressionType::Attribute { value, name } => (
|
|
||||||
Some(infer_expr(ctx, &value)?.ok_or_else(|| "no value".to_string())?),
|
|
||||||
name,
|
|
||||||
),
|
|
||||||
_ => return Err("not supported".into()),
|
|
||||||
};
|
|
||||||
resolve_call(ctx, obj, fun.as_str(), &types.unwrap())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn infer_subscript<'b: 'a, 'a>(
|
|
||||||
ctx: &mut InferenceContext<'a>,
|
|
||||||
a: &'b Expression,
|
|
||||||
b: &'b Expression,
|
|
||||||
) -> ParserResult {
|
|
||||||
let a = infer_expr(ctx, a)?.ok_or_else(|| "no value".to_string())?;
|
|
||||||
let t = if let ParametricType(LIST_TYPE, ls) = a.as_ref() {
|
|
||||||
ls[0].clone()
|
|
||||||
} else {
|
|
||||||
return Err("subscript is not supported for types other than list".into());
|
|
||||||
};
|
|
||||||
|
|
||||||
match &b.node {
|
|
||||||
ExpressionType::Slice { elements } => {
|
|
||||||
let int32 = ctx.get_primitive(INT32_TYPE);
|
|
||||||
let types: Result<Option<Vec<_>>, _> = elements
|
|
||||||
.iter()
|
|
||||||
.map(|v| {
|
|
||||||
if let ExpressionType::None = v.node {
|
|
||||||
Ok(Some(int32.clone()))
|
|
||||||
} else {
|
|
||||||
infer_expr(ctx, v)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
let types = types?.ok_or_else(|| "slice must have type".to_string())?;
|
|
||||||
if types.iter().all(|v| v == &int32) {
|
|
||||||
Ok(Some(a))
|
|
||||||
} else {
|
|
||||||
Err("slice must be int32 type".into())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
let b = infer_expr(ctx, b)?.ok_or_else(|| "no value".to_string())?;
|
|
||||||
if b == ctx.get_primitive(INT32_TYPE) {
|
|
||||||
Ok(Some(t))
|
|
||||||
} else {
|
|
||||||
Err("index must be either slice or int32".into())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn infer_if_expr<'b: 'a, 'a>(
|
|
||||||
ctx: &mut InferenceContext<'a>,
|
|
||||||
test: &'b Expression,
|
|
||||||
body: &'b Expression,
|
|
||||||
orelse: &'b Expression,
|
|
||||||
) -> ParserResult {
|
|
||||||
let test = infer_expr(ctx, test)?.ok_or_else(|| "no value".to_string())?;
|
|
||||||
if test != ctx.get_primitive(BOOL_TYPE) {
|
|
||||||
return Err("test should be bool".into());
|
|
||||||
}
|
|
||||||
|
|
||||||
let body = infer_expr(ctx, body)?;
|
|
||||||
let orelse = infer_expr(ctx, orelse)?;
|
|
||||||
if body.as_ref() == orelse.as_ref() {
|
|
||||||
Ok(body)
|
|
||||||
} else {
|
|
||||||
Err("divergent type".into())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn infer_simple_binding<'a: 'b, 'b>(
|
|
||||||
ctx: &mut InferenceContext<'b>,
|
|
||||||
name: &'a Expression,
|
|
||||||
ty: Type,
|
|
||||||
) -> Result<(), String> {
|
|
||||||
match &name.node {
|
|
||||||
ExpressionType::Identifier { name } => {
|
|
||||||
if name == "_" {
|
|
||||||
Ok(())
|
|
||||||
} else if ctx.defined(name.as_str()) {
|
|
||||||
Err("duplicated naming".into())
|
|
||||||
} else {
|
|
||||||
ctx.assign(name.as_str(), ty)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ExpressionType::Tuple { elements } => {
|
|
||||||
if let ParametricType(TUPLE_TYPE, ls) = ty.as_ref() {
|
|
||||||
if elements.len() == ls.len() {
|
|
||||||
for (a, b) in elements.iter().zip(ls.iter()) {
|
|
||||||
infer_simple_binding(ctx, a, b.clone())?;
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
Err("different length".into())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
Err("not supported".into())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => Err("not supported".into()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn infer_list_comprehension<'b: 'a, 'a>(
|
|
||||||
ctx: &mut InferenceContext<'a>,
|
|
||||||
element: &'b Expression,
|
|
||||||
comprehension: &'b Comprehension,
|
|
||||||
) -> ParserResult {
|
|
||||||
if comprehension.is_async {
|
|
||||||
return Err("async is not supported".into());
|
|
||||||
}
|
|
||||||
|
|
||||||
let iter = infer_expr(ctx, &comprehension.iter)?.ok_or_else(|| "no value".to_string())?;
|
|
||||||
if let ParametricType(LIST_TYPE, ls) = iter.as_ref() {
|
|
||||||
ctx.with_scope(|ctx| {
|
|
||||||
infer_simple_binding(ctx, &comprehension.target, ls[0].clone())?;
|
|
||||||
|
|
||||||
let boolean = ctx.get_primitive(BOOL_TYPE);
|
|
||||||
for test in comprehension.ifs.iter() {
|
|
||||||
let result =
|
|
||||||
infer_expr(ctx, test)?.ok_or_else(|| "no value in test".to_string())?;
|
|
||||||
if result != boolean {
|
|
||||||
return Err("test must be bool".into());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let result = infer_expr(ctx, element)?.ok_or_else(|| "no value")?;
|
|
||||||
Ok(Some(ParametricType(LIST_TYPE, vec![result]).into()))
|
|
||||||
})
|
|
||||||
.1
|
|
||||||
} else {
|
|
||||||
Err("iteration is supported for list only".into())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod test {
|
|
||||||
use super::*;
|
|
||||||
use crate::context::*;
|
|
||||||
use crate::typedef::*;
|
|
||||||
use rustpython_parser::parser::parse_expression;
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::rc::Rc;
|
|
||||||
|
|
||||||
fn get_inference_context(ctx: TopLevelContext) -> InferenceContext {
|
|
||||||
InferenceContext::new(ctx, Box::new(|_| Err("unbounded identifier".into())))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_constants() {
|
|
||||||
let ctx = basic_ctx();
|
|
||||||
let mut ctx = get_inference_context(ctx);
|
|
||||||
|
|
||||||
let ast = parse_expression("123").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
|
|
||||||
|
|
||||||
let ast = parse_expression("2147483647").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
|
|
||||||
|
|
||||||
let ast = parse_expression("2147483648").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("integer out of range".into()));
|
|
||||||
//
|
|
||||||
// let ast = parse_expression("2147483648").unwrap();
|
|
||||||
// let result = infer_expr(&mut ctx, &ast);
|
|
||||||
// assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT64_TYPE));
|
|
||||||
|
|
||||||
// let ast = parse_expression("9223372036854775807").unwrap();
|
|
||||||
// let result = infer_expr(&mut ctx, &ast);
|
|
||||||
// assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT64_TYPE));
|
|
||||||
|
|
||||||
// let ast = parse_expression("9223372036854775808").unwrap();
|
|
||||||
// let result = infer_expr(&mut ctx, &ast);
|
|
||||||
// assert_eq!(result, Err("integer out of range".into()));
|
|
||||||
|
|
||||||
let ast = parse_expression("123.456").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(FLOAT_TYPE));
|
|
||||||
|
|
||||||
let ast = parse_expression("True").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE));
|
|
||||||
|
|
||||||
let ast = parse_expression("False").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_identifier() {
|
|
||||||
let ctx = basic_ctx();
|
|
||||||
let mut ctx = get_inference_context(ctx);
|
|
||||||
ctx.assign("abc", ctx.get_primitive(INT32_TYPE)).unwrap();
|
|
||||||
|
|
||||||
let ast = parse_expression("abc").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
|
|
||||||
|
|
||||||
let ast = parse_expression("ab").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("unbounded identifier".into()));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_list() {
|
|
||||||
let mut ctx = basic_ctx();
|
|
||||||
ctx.add_fn(
|
|
||||||
"foo",
|
|
||||||
FnDef {
|
|
||||||
args: vec![],
|
|
||||||
result: None,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
let mut ctx = get_inference_context(ctx);
|
|
||||||
ctx.assign("abc", ctx.get_primitive(INT32_TYPE)).unwrap();
|
|
||||||
// def is reserved...
|
|
||||||
ctx.assign("efg", ctx.get_primitive(INT32_TYPE)).unwrap();
|
|
||||||
ctx.assign("xyz", ctx.get_primitive(FLOAT_TYPE)).unwrap();
|
|
||||||
|
|
||||||
let ast = parse_expression("[]").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(
|
|
||||||
result.unwrap().unwrap(),
|
|
||||||
ParametricType(LIST_TYPE, vec![BotType.into()]).into()
|
|
||||||
);
|
|
||||||
|
|
||||||
let ast = parse_expression("[abc]").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(
|
|
||||||
result.unwrap().unwrap(),
|
|
||||||
ParametricType(LIST_TYPE, vec![ctx.get_primitive(INT32_TYPE)]).into()
|
|
||||||
);
|
|
||||||
|
|
||||||
let ast = parse_expression("[abc, efg]").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(
|
|
||||||
result.unwrap().unwrap(),
|
|
||||||
ParametricType(LIST_TYPE, vec![ctx.get_primitive(INT32_TYPE)]).into()
|
|
||||||
);
|
|
||||||
|
|
||||||
let ast = parse_expression("[abc, efg, xyz]").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("inhomogeneous list is not allowed".into()));
|
|
||||||
|
|
||||||
let ast = parse_expression("[foo()]").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("list elements must have some type".into()));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_tuple() {
|
|
||||||
let mut ctx = basic_ctx();
|
|
||||||
ctx.add_fn(
|
|
||||||
"foo",
|
|
||||||
FnDef {
|
|
||||||
args: vec![],
|
|
||||||
result: None,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
let mut ctx = get_inference_context(ctx);
|
|
||||||
ctx.assign("abc", ctx.get_primitive(INT32_TYPE)).unwrap();
|
|
||||||
ctx.assign("efg", ctx.get_primitive(FLOAT_TYPE)).unwrap();
|
|
||||||
|
|
||||||
let ast = parse_expression("(abc, efg)").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(
|
|
||||||
result.unwrap().unwrap(),
|
|
||||||
ParametricType(
|
|
||||||
TUPLE_TYPE,
|
|
||||||
vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(FLOAT_TYPE)]
|
|
||||||
)
|
|
||||||
.into()
|
|
||||||
);
|
|
||||||
|
|
||||||
let ast = parse_expression("(abc, efg, foo())").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("tuple elements must have some type".into()));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_attribute() {
|
|
||||||
let mut ctx = basic_ctx();
|
|
||||||
ctx.add_fn(
|
|
||||||
"none",
|
|
||||||
FnDef {
|
|
||||||
args: vec![],
|
|
||||||
result: None,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
let int32 = ctx.get_primitive(INT32_TYPE);
|
|
||||||
let float = ctx.get_primitive(FLOAT_TYPE);
|
|
||||||
|
|
||||||
let foo = ctx.add_class(ClassDef {
|
|
||||||
base: TypeDef {
|
|
||||||
name: "Foo",
|
|
||||||
fields: HashMap::new(),
|
|
||||||
methods: HashMap::new(),
|
|
||||||
},
|
|
||||||
parents: vec![],
|
|
||||||
});
|
|
||||||
let foo_def = ctx.get_class_def_mut(foo);
|
|
||||||
foo_def.base.fields.insert("a", int32.clone());
|
|
||||||
foo_def.base.fields.insert("b", ClassType(foo).into());
|
|
||||||
foo_def.base.fields.insert("c", int32.clone());
|
|
||||||
|
|
||||||
let bar = ctx.add_class(ClassDef {
|
|
||||||
base: TypeDef {
|
|
||||||
name: "Bar",
|
|
||||||
fields: HashMap::new(),
|
|
||||||
methods: HashMap::new(),
|
|
||||||
},
|
|
||||||
parents: vec![],
|
|
||||||
});
|
|
||||||
let bar_def = ctx.get_class_def_mut(bar);
|
|
||||||
bar_def.base.fields.insert("a", int32);
|
|
||||||
bar_def.base.fields.insert("b", ClassType(bar).into());
|
|
||||||
bar_def.base.fields.insert("c", float);
|
|
||||||
|
|
||||||
let v0 = ctx.add_variable(VarDef {
|
|
||||||
name: "v0",
|
|
||||||
bound: vec![],
|
|
||||||
});
|
|
||||||
|
|
||||||
let v1 = ctx.add_variable(VarDef {
|
|
||||||
name: "v1",
|
|
||||||
bound: vec![ClassType(foo).into(), ClassType(bar).into()],
|
|
||||||
});
|
|
||||||
|
|
||||||
let mut ctx = get_inference_context(ctx);
|
|
||||||
ctx.assign("foo", Rc::new(ClassType(foo))).unwrap();
|
|
||||||
ctx.assign("bar", Rc::new(ClassType(bar))).unwrap();
|
|
||||||
ctx.assign("foobar", Rc::new(VirtualClassType(foo)))
|
|
||||||
.unwrap();
|
|
||||||
ctx.assign("v0", ctx.get_variable(v0)).unwrap();
|
|
||||||
ctx.assign("v1", ctx.get_variable(v1)).unwrap();
|
|
||||||
ctx.assign("bot", Rc::new(BotType)).unwrap();
|
|
||||||
|
|
||||||
let ast = parse_expression("foo.a").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
|
|
||||||
|
|
||||||
let ast = parse_expression("foo.d").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("no such field".into()));
|
|
||||||
|
|
||||||
let ast = parse_expression("foobar.a").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
|
|
||||||
|
|
||||||
let ast = parse_expression("v0.a").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("no fields for type variable".into()));
|
|
||||||
|
|
||||||
let ast = parse_expression("v1.a").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("no fields for type variable".into()));
|
|
||||||
|
|
||||||
let ast = parse_expression("none().a").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("no value".into()));
|
|
||||||
|
|
||||||
let ast = parse_expression("bot.a").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("no such field".into()));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_bool_ops() {
|
|
||||||
let mut ctx = basic_ctx();
|
|
||||||
ctx.add_fn(
|
|
||||||
"none",
|
|
||||||
FnDef {
|
|
||||||
args: vec![],
|
|
||||||
result: None,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
let mut ctx = get_inference_context(ctx);
|
|
||||||
|
|
||||||
let ast = parse_expression("True and False").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE));
|
|
||||||
|
|
||||||
let ast = parse_expression("True and none()").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("no value".into()));
|
|
||||||
|
|
||||||
let ast = parse_expression("True and 123").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("bool operands must be bool".into()));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_bin_ops() {
|
|
||||||
let mut ctx = basic_ctx();
|
|
||||||
let v0 = ctx.add_variable(VarDef {
|
|
||||||
name: "v0",
|
|
||||||
bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(INT64_TYPE)],
|
|
||||||
});
|
|
||||||
let mut ctx = get_inference_context(ctx);
|
|
||||||
ctx.assign("a", TypeVariable(v0).into()).unwrap();
|
|
||||||
|
|
||||||
let ast = parse_expression("1 + 2 + 3").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
|
|
||||||
|
|
||||||
let ast = parse_expression("a + a + a").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("not supported".into()));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_unary_ops() {
|
|
||||||
let mut ctx = basic_ctx();
|
|
||||||
let v0 = ctx.add_variable(VarDef {
|
|
||||||
name: "v0",
|
|
||||||
bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(INT64_TYPE)],
|
|
||||||
});
|
|
||||||
let mut ctx = get_inference_context(ctx);
|
|
||||||
ctx.assign("a", TypeVariable(v0).into()).unwrap();
|
|
||||||
|
|
||||||
let ast = parse_expression("-(123)").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
|
|
||||||
|
|
||||||
let ast = parse_expression("-a").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("not supported".into()));
|
|
||||||
|
|
||||||
let ast = parse_expression("not True").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE));
|
|
||||||
|
|
||||||
let ast = parse_expression("not (1)").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("logical not must be applied to bool".into()));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_compare() {
|
|
||||||
let mut ctx = basic_ctx();
|
|
||||||
let v0 = ctx.add_variable(VarDef {
|
|
||||||
name: "v0",
|
|
||||||
bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(INT64_TYPE)],
|
|
||||||
});
|
|
||||||
let mut ctx = get_inference_context(ctx);
|
|
||||||
ctx.assign("a", TypeVariable(v0).into()).unwrap();
|
|
||||||
|
|
||||||
let ast = parse_expression("a == a == a").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("not supported".into()));
|
|
||||||
|
|
||||||
let ast = parse_expression("a == a == 1").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("not supported".into()));
|
|
||||||
|
|
||||||
let ast = parse_expression("True > False").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("no such function".into()));
|
|
||||||
|
|
||||||
let ast = parse_expression("True in False").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("unsupported comparison".into()));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_call() {
|
|
||||||
let mut ctx = basic_ctx();
|
|
||||||
ctx.add_fn(
|
|
||||||
"none",
|
|
||||||
FnDef {
|
|
||||||
args: vec![],
|
|
||||||
result: None,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
let foo = ctx.add_class(ClassDef {
|
|
||||||
base: TypeDef {
|
|
||||||
name: "Foo",
|
|
||||||
fields: HashMap::new(),
|
|
||||||
methods: HashMap::new(),
|
|
||||||
},
|
|
||||||
parents: vec![],
|
|
||||||
});
|
|
||||||
let foo_def = ctx.get_class_def_mut(foo);
|
|
||||||
foo_def.base.methods.insert(
|
|
||||||
"a",
|
|
||||||
FnDef {
|
|
||||||
args: vec![],
|
|
||||||
result: Some(Rc::new(ClassType(foo))),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
let bar = ctx.add_class(ClassDef {
|
|
||||||
base: TypeDef {
|
|
||||||
name: "Bar",
|
|
||||||
fields: HashMap::new(),
|
|
||||||
methods: HashMap::new(),
|
|
||||||
},
|
|
||||||
parents: vec![],
|
|
||||||
});
|
|
||||||
let bar_def = ctx.get_class_def_mut(bar);
|
|
||||||
bar_def.base.methods.insert(
|
|
||||||
"a",
|
|
||||||
FnDef {
|
|
||||||
args: vec![],
|
|
||||||
result: Some(Rc::new(ClassType(bar))),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
let v0 = ctx.add_variable(VarDef {
|
|
||||||
name: "v0",
|
|
||||||
bound: vec![],
|
|
||||||
});
|
|
||||||
let v1 = ctx.add_variable(VarDef {
|
|
||||||
name: "v1",
|
|
||||||
bound: vec![ClassType(foo).into(), ClassType(bar).into()],
|
|
||||||
});
|
|
||||||
let v2 = ctx.add_variable(VarDef {
|
|
||||||
name: "v2",
|
|
||||||
bound: vec![
|
|
||||||
ClassType(foo).into(),
|
|
||||||
ClassType(bar).into(),
|
|
||||||
ctx.get_primitive(INT32_TYPE),
|
|
||||||
],
|
|
||||||
});
|
|
||||||
let mut ctx = get_inference_context(ctx);
|
|
||||||
ctx.assign("foo", Rc::new(ClassType(foo))).unwrap();
|
|
||||||
ctx.assign("bar", Rc::new(ClassType(bar))).unwrap();
|
|
||||||
ctx.assign("foobar", Rc::new(VirtualClassType(foo)))
|
|
||||||
.unwrap();
|
|
||||||
ctx.assign("v0", ctx.get_variable(v0)).unwrap();
|
|
||||||
ctx.assign("v1", ctx.get_variable(v1)).unwrap();
|
|
||||||
ctx.assign("v2", ctx.get_variable(v2)).unwrap();
|
|
||||||
ctx.assign("bot", Rc::new(BotType)).unwrap();
|
|
||||||
|
|
||||||
let ast = parse_expression("foo.a()").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result.unwrap().unwrap(), ClassType(foo).into());
|
|
||||||
|
|
||||||
let ast = parse_expression("v1.a()").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("not supported".into()));
|
|
||||||
|
|
||||||
let ast = parse_expression("foobar.a()").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result.unwrap().unwrap(), ClassType(foo).into());
|
|
||||||
|
|
||||||
let ast = parse_expression("none().a()").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("no value".into()));
|
|
||||||
|
|
||||||
let ast = parse_expression("bot.a()").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("not supported".into()));
|
|
||||||
|
|
||||||
let ast = parse_expression("[][0].a()").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("not supported".into()));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn infer_subscript() {
|
|
||||||
let mut ctx = basic_ctx();
|
|
||||||
ctx.add_fn(
|
|
||||||
"none",
|
|
||||||
FnDef {
|
|
||||||
args: vec![],
|
|
||||||
result: None,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
let mut ctx = get_inference_context(ctx);
|
|
||||||
|
|
||||||
let ast = parse_expression("[1, 2, 3][0]").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
|
|
||||||
|
|
||||||
let ast = parse_expression("[[1]][0][0]").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
|
|
||||||
|
|
||||||
let ast = parse_expression("[1, 2, 3][1:2]").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(
|
|
||||||
result.unwrap().unwrap(),
|
|
||||||
ParametricType(LIST_TYPE, vec![ctx.get_primitive(INT32_TYPE)]).into()
|
|
||||||
);
|
|
||||||
|
|
||||||
let ast = parse_expression("[1, 2, 3][1:2:2]").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(
|
|
||||||
result.unwrap().unwrap(),
|
|
||||||
ParametricType(LIST_TYPE, vec![ctx.get_primitive(INT32_TYPE)]).into()
|
|
||||||
);
|
|
||||||
|
|
||||||
let ast = parse_expression("[1, 2, 3][1:1.2]").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("slice must be int32 type".into()));
|
|
||||||
|
|
||||||
let ast = parse_expression("[1, 2, 3][1:none()]").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("slice must have type".into()));
|
|
||||||
|
|
||||||
let ast = parse_expression("[1, 2, 3][1.2]").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("index must be either slice or int32".into()));
|
|
||||||
|
|
||||||
let ast = parse_expression("[1, 2, 3][none()]").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("no value".into()));
|
|
||||||
|
|
||||||
let ast = parse_expression("none()[1.2]").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("no value".into()));
|
|
||||||
|
|
||||||
let ast = parse_expression("123[1]").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(
|
|
||||||
result,
|
|
||||||
Err("subscript is not supported for types other than list".into())
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_if_expr() {
|
|
||||||
let mut ctx = basic_ctx();
|
|
||||||
ctx.add_fn(
|
|
||||||
"none",
|
|
||||||
FnDef {
|
|
||||||
args: vec![],
|
|
||||||
result: None,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
let mut ctx = get_inference_context(ctx);
|
|
||||||
|
|
||||||
let ast = parse_expression("1 if True else 0").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
|
|
||||||
|
|
||||||
let ast = parse_expression("none() if True else none()").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result.unwrap(), None);
|
|
||||||
|
|
||||||
let ast = parse_expression("none() if 1 else none()").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("test should be bool".into()));
|
|
||||||
|
|
||||||
let ast = parse_expression("1 if True else none()").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("divergent type".into()));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_list_comp() {
|
|
||||||
let mut ctx = basic_ctx();
|
|
||||||
ctx.add_fn(
|
|
||||||
"none",
|
|
||||||
FnDef {
|
|
||||||
args: vec![],
|
|
||||||
result: None,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
let int32 = ctx.get_primitive(INT32_TYPE);
|
|
||||||
let mut ctx = get_inference_context(ctx);
|
|
||||||
ctx.assign("z", int32.clone()).unwrap();
|
|
||||||
|
|
||||||
let ast = parse_expression("[x for x in [(1, 2), (2, 3), (3, 4)]][0]").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(
|
|
||||||
result.unwrap().unwrap(),
|
|
||||||
ParametricType(TUPLE_TYPE, vec![int32.clone(), int32.clone()]).into()
|
|
||||||
);
|
|
||||||
|
|
||||||
let ast = parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)]][0]").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result.unwrap().unwrap(), int32);
|
|
||||||
|
|
||||||
let ast =
|
|
||||||
parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)] if x > 0][0]").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result.unwrap().unwrap(), int32);
|
|
||||||
|
|
||||||
let ast = parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)] if x][0]").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("test must be bool".into()));
|
|
||||||
|
|
||||||
let ast = parse_expression("[y for x in []][0]").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("unbounded identifier".into()));
|
|
||||||
|
|
||||||
let ast = parse_expression("[none() for x in []][0]").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("no value".into()));
|
|
||||||
|
|
||||||
let ast = parse_expression("[z for z in []][0]").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(result, Err("duplicated naming".into()));
|
|
||||||
|
|
||||||
let ast = parse_expression("[x for x in [] for y in []]").unwrap();
|
|
||||||
let result = infer_expr(&mut ctx, &ast);
|
|
||||||
assert_eq!(
|
|
||||||
result,
|
|
||||||
Err("only 1 generator statement is supported".into())
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,525 +0,0 @@
|
|||||||
use crate::context::InferenceContext;
|
|
||||||
use crate::typedef::{TypeEnum::*, *};
|
|
||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
fn find_subst(
|
|
||||||
ctx: &InferenceContext,
|
|
||||||
valuation: &Option<(VariableId, Type)>,
|
|
||||||
sub: &mut HashMap<VariableId, Type>,
|
|
||||||
mut a: Type,
|
|
||||||
mut b: Type,
|
|
||||||
) -> Result<(), String> {
|
|
||||||
// TODO: fix error messages later
|
|
||||||
if let TypeVariable(id) = a.as_ref() {
|
|
||||||
if let Some((assumption_id, t)) = valuation {
|
|
||||||
if assumption_id == id {
|
|
||||||
a = t.clone();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut substituted = false;
|
|
||||||
if let TypeVariable(id) = b.as_ref() {
|
|
||||||
if let Some(c) = sub.get(&id) {
|
|
||||||
b = c.clone();
|
|
||||||
substituted = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
match (a.as_ref(), b.as_ref()) {
|
|
||||||
(BotType, _) => Ok(()),
|
|
||||||
(TypeVariable(id_a), TypeVariable(id_b)) => {
|
|
||||||
if substituted {
|
|
||||||
return if id_a == id_b {
|
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
Err("different variables".to_string())
|
|
||||||
};
|
|
||||||
}
|
|
||||||
let v_a = ctx.get_variable_def(*id_a);
|
|
||||||
let v_b = ctx.get_variable_def(*id_b);
|
|
||||||
if !v_b.bound.is_empty() {
|
|
||||||
if v_a.bound.is_empty() {
|
|
||||||
return Err("unbounded a".to_string());
|
|
||||||
} else {
|
|
||||||
let diff: Vec<_> = v_a
|
|
||||||
.bound
|
|
||||||
.iter()
|
|
||||||
.filter(|x| !v_b.bound.contains(x))
|
|
||||||
.collect();
|
|
||||||
if !diff.is_empty() {
|
|
||||||
return Err("different domain".to_string());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
sub.insert(*id_b, a.clone());
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
(TypeVariable(id_a), _) => {
|
|
||||||
let v_a = ctx.get_variable_def(*id_a);
|
|
||||||
if v_a.bound.len() == 1 && v_a.bound[0].as_ref() == b.as_ref() {
|
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
Err("different domain".to_string())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
(_, TypeVariable(id_b)) => {
|
|
||||||
let v_b = ctx.get_variable_def(*id_b);
|
|
||||||
if v_b.bound.is_empty() || v_b.bound.contains(&a) {
|
|
||||||
sub.insert(*id_b, a.clone());
|
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
Err("different domain".to_string())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
(_, VirtualClassType(id_b)) => {
|
|
||||||
let mut parents;
|
|
||||||
match a.as_ref() {
|
|
||||||
ClassType(id_a) => {
|
|
||||||
parents = [*id_a].to_vec();
|
|
||||||
}
|
|
||||||
VirtualClassType(id_a) => {
|
|
||||||
parents = [*id_a].to_vec();
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
return Err("cannot substitute non-class type into virtual class".to_string());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
while !parents.is_empty() {
|
|
||||||
if *id_b == parents[0] {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
let c = ctx.get_class_def(parents.remove(0));
|
|
||||||
parents.extend_from_slice(&c.parents);
|
|
||||||
}
|
|
||||||
Err("not subtype".to_string())
|
|
||||||
}
|
|
||||||
(ParametricType(id_a, param_a), ParametricType(id_b, param_b)) => {
|
|
||||||
if id_a != id_b || param_a.len() != param_b.len() {
|
|
||||||
Err("different parametric types".to_string())
|
|
||||||
} else {
|
|
||||||
for (x, y) in param_a.iter().zip(param_b.iter()) {
|
|
||||||
find_subst(ctx, valuation, sub, x.clone(), y.clone())?;
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
(_, _) => {
|
|
||||||
if a == b {
|
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
Err("not equal".to_string())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn resolve_call_rec(
|
|
||||||
ctx: &InferenceContext,
|
|
||||||
valuation: &Option<(VariableId, Type)>,
|
|
||||||
obj: Option<Type>,
|
|
||||||
func: &str,
|
|
||||||
args: &[Type],
|
|
||||||
) -> Result<Option<Type>, String> {
|
|
||||||
let mut subst = obj
|
|
||||||
.as_ref()
|
|
||||||
.map(|v| v.get_subst(ctx))
|
|
||||||
.unwrap_or_else(HashMap::new);
|
|
||||||
|
|
||||||
let fun = match &obj {
|
|
||||||
Some(obj) => {
|
|
||||||
let base = match obj.as_ref() {
|
|
||||||
PrimitiveType(id) => &ctx.get_primitive_def(*id),
|
|
||||||
ClassType(id) | VirtualClassType(id) => &ctx.get_class_def(*id).base,
|
|
||||||
ParametricType(id, _) => &ctx.get_parametric_def(*id).base,
|
|
||||||
_ => return Err("not supported".to_string()),
|
|
||||||
};
|
|
||||||
base.methods.get(func)
|
|
||||||
}
|
|
||||||
None => ctx.get_fn_def(func),
|
|
||||||
}
|
|
||||||
.ok_or_else(|| "no such function".to_string())?;
|
|
||||||
|
|
||||||
if args.len() != fun.args.len() {
|
|
||||||
return Err("incorrect parameter number".to_string());
|
|
||||||
}
|
|
||||||
for (a, b) in args.iter().zip(fun.args.iter()) {
|
|
||||||
find_subst(ctx, valuation, &mut subst, a.clone(), b.clone())?;
|
|
||||||
}
|
|
||||||
let result = fun.result.as_ref().map(|v| v.subst(&subst));
|
|
||||||
Ok(result.map(|result| {
|
|
||||||
if let SelfType = result {
|
|
||||||
obj.unwrap()
|
|
||||||
} else {
|
|
||||||
result.into()
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn resolve_call(
|
|
||||||
ctx: &InferenceContext,
|
|
||||||
obj: Option<Type>,
|
|
||||||
func: &str,
|
|
||||||
args: &[Type],
|
|
||||||
) -> Result<Option<Type>, String> {
|
|
||||||
resolve_call_rec(ctx, &None, obj, func, args)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
use crate::context::TopLevelContext;
|
|
||||||
use crate::primitives::*;
|
|
||||||
use std::rc::Rc;
|
|
||||||
|
|
||||||
fn get_inference_context(ctx: TopLevelContext) -> InferenceContext {
|
|
||||||
InferenceContext::new(ctx, Box::new(|_| Err("unbounded identifier".into())))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_simple_generic() {
|
|
||||||
let mut ctx = basic_ctx();
|
|
||||||
let v1 = ctx.add_variable(VarDef {
|
|
||||||
name: "V1",
|
|
||||||
bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(FLOAT_TYPE)],
|
|
||||||
});
|
|
||||||
let v1 = ctx.get_variable(v1);
|
|
||||||
let v2 = ctx.add_variable(VarDef {
|
|
||||||
name: "V2",
|
|
||||||
bound: vec![
|
|
||||||
ctx.get_primitive(BOOL_TYPE),
|
|
||||||
ctx.get_primitive(INT32_TYPE),
|
|
||||||
ctx.get_primitive(FLOAT_TYPE),
|
|
||||||
],
|
|
||||||
});
|
|
||||||
let v2 = ctx.get_variable(v2);
|
|
||||||
let ctx = get_inference_context(ctx);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
resolve_call(&ctx, None, "int32", &[ctx.get_primitive(FLOAT_TYPE)]),
|
|
||||||
Ok(Some(ctx.get_primitive(INT32_TYPE)))
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
resolve_call(&ctx, None, "int32", &[ctx.get_primitive(INT32_TYPE)],),
|
|
||||||
Ok(Some(ctx.get_primitive(INT32_TYPE)))
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
resolve_call(&ctx, None, "float", &[ctx.get_primitive(INT32_TYPE)]),
|
|
||||||
Ok(Some(ctx.get_primitive(FLOAT_TYPE)))
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
resolve_call(&ctx, None, "float", &[ctx.get_primitive(BOOL_TYPE)]),
|
|
||||||
Err("different domain".to_string())
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
resolve_call(&ctx, None, "float", &[]),
|
|
||||||
Err("incorrect parameter number".to_string())
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
resolve_call(&ctx, None, "float", &[v1]),
|
|
||||||
Ok(Some(ctx.get_primitive(FLOAT_TYPE)))
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
resolve_call(&ctx, None, "float", &[v2]),
|
|
||||||
Err("different domain".to_string())
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_methods() {
|
|
||||||
let mut ctx = basic_ctx();
|
|
||||||
|
|
||||||
let v0 = ctx.add_variable(VarDef {
|
|
||||||
name: "V0",
|
|
||||||
bound: vec![],
|
|
||||||
});
|
|
||||||
let v0 = ctx.get_variable(v0);
|
|
||||||
|
|
||||||
let int32 = ctx.get_primitive(INT32_TYPE);
|
|
||||||
let int64 = ctx.get_primitive(INT64_TYPE);
|
|
||||||
let ctx = get_inference_context(ctx);
|
|
||||||
|
|
||||||
// simple cases
|
|
||||||
assert_eq!(
|
|
||||||
resolve_call(&ctx, Some(int32.clone()), "__add__", &[int32.clone()]),
|
|
||||||
Ok(Some(int32.clone()))
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_ne!(
|
|
||||||
resolve_call(&ctx, Some(int32.clone()), "__add__", &[int32.clone()]),
|
|
||||||
Ok(Some(int64.clone()))
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
resolve_call(&ctx, Some(int32), "__add__", &[int64]),
|
|
||||||
Err("not equal".to_string())
|
|
||||||
);
|
|
||||||
|
|
||||||
// with type variables
|
|
||||||
assert_eq!(
|
|
||||||
resolve_call(&ctx, Some(v0.clone()), "__add__", &[v0.clone()]),
|
|
||||||
Err("not supported".into())
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_multi_generic() {
|
|
||||||
let mut ctx = basic_ctx();
|
|
||||||
let v0 = ctx.add_variable(VarDef {
|
|
||||||
name: "V0",
|
|
||||||
bound: vec![],
|
|
||||||
});
|
|
||||||
let v0 = ctx.get_variable(v0);
|
|
||||||
let v1 = ctx.add_variable(VarDef {
|
|
||||||
name: "V1",
|
|
||||||
bound: vec![],
|
|
||||||
});
|
|
||||||
let v1 = ctx.get_variable(v1);
|
|
||||||
let v2 = ctx.add_variable(VarDef {
|
|
||||||
name: "V2",
|
|
||||||
bound: vec![],
|
|
||||||
});
|
|
||||||
let v2 = ctx.get_variable(v2);
|
|
||||||
let v3 = ctx.add_variable(VarDef {
|
|
||||||
name: "V3",
|
|
||||||
bound: vec![],
|
|
||||||
});
|
|
||||||
let v3 = ctx.get_variable(v3);
|
|
||||||
|
|
||||||
ctx.add_fn(
|
|
||||||
"foo",
|
|
||||||
FnDef {
|
|
||||||
args: vec![v0.clone(), v0.clone(), v1.clone()],
|
|
||||||
result: Some(v0.clone()),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
ctx.add_fn(
|
|
||||||
"foo1",
|
|
||||||
FnDef {
|
|
||||||
args: vec![ParametricType(TUPLE_TYPE, vec![v0.clone(), v0.clone(), v1]).into()],
|
|
||||||
result: Some(v0),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
let ctx = get_inference_context(ctx);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
resolve_call(&ctx, None, "foo", &[v2.clone(), v2.clone(), v2.clone()]),
|
|
||||||
Ok(Some(v2.clone()))
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
resolve_call(&ctx, None, "foo", &[v2.clone(), v2.clone(), v3.clone()]),
|
|
||||||
Ok(Some(v2.clone()))
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
resolve_call(&ctx, None, "foo", &[v2.clone(), v3.clone(), v3.clone()]),
|
|
||||||
Err("different variables".to_string())
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
resolve_call(
|
|
||||||
&ctx,
|
|
||||||
None,
|
|
||||||
"foo1",
|
|
||||||
&[ParametricType(TUPLE_TYPE, vec![v2.clone(), v2.clone(), v2.clone()]).into()]
|
|
||||||
),
|
|
||||||
Ok(Some(v2.clone()))
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
resolve_call(
|
|
||||||
&ctx,
|
|
||||||
None,
|
|
||||||
"foo1",
|
|
||||||
&[ParametricType(TUPLE_TYPE, vec![v2.clone(), v2.clone(), v3.clone()]).into()]
|
|
||||||
),
|
|
||||||
Ok(Some(v2.clone()))
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
resolve_call(
|
|
||||||
&ctx,
|
|
||||||
None,
|
|
||||||
"foo1",
|
|
||||||
&[ParametricType(TUPLE_TYPE, vec![v2, v3.clone(), v3]).into()]
|
|
||||||
),
|
|
||||||
Err("different variables".to_string())
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_class_generics() {
|
|
||||||
let mut ctx = basic_ctx();
|
|
||||||
|
|
||||||
let list = ctx.get_parametric_def_mut(LIST_TYPE);
|
|
||||||
let t = Rc::new(TypeVariable(list.params[0]));
|
|
||||||
list.base.methods.insert(
|
|
||||||
"head",
|
|
||||||
FnDef {
|
|
||||||
args: vec![],
|
|
||||||
result: Some(t.clone()),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
list.base.methods.insert(
|
|
||||||
"append",
|
|
||||||
FnDef {
|
|
||||||
args: vec![t],
|
|
||||||
result: None,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
let v0 = ctx.add_variable(VarDef {
|
|
||||||
name: "V0",
|
|
||||||
bound: vec![],
|
|
||||||
});
|
|
||||||
let v0 = ctx.get_variable(v0);
|
|
||||||
let v1 = ctx.add_variable(VarDef {
|
|
||||||
name: "V1",
|
|
||||||
bound: vec![],
|
|
||||||
});
|
|
||||||
let v1 = ctx.get_variable(v1);
|
|
||||||
let ctx = get_inference_context(ctx);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
resolve_call(
|
|
||||||
&ctx,
|
|
||||||
Some(ParametricType(LIST_TYPE, vec![v0.clone()]).into()),
|
|
||||||
"head",
|
|
||||||
&[]
|
|
||||||
),
|
|
||||||
Ok(Some(v0.clone()))
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
resolve_call(
|
|
||||||
&ctx,
|
|
||||||
Some(ParametricType(LIST_TYPE, vec![v0.clone()]).into()),
|
|
||||||
"append",
|
|
||||||
&[v0.clone()]
|
|
||||||
),
|
|
||||||
Ok(None)
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
resolve_call(
|
|
||||||
&ctx,
|
|
||||||
Some(ParametricType(LIST_TYPE, vec![v0]).into()),
|
|
||||||
"append",
|
|
||||||
&[v1]
|
|
||||||
),
|
|
||||||
Err("different variables".to_string())
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_virtual_class() {
|
|
||||||
let mut ctx = basic_ctx();
|
|
||||||
|
|
||||||
let foo = ctx.add_class(ClassDef {
|
|
||||||
base: TypeDef {
|
|
||||||
name: "Foo",
|
|
||||||
methods: HashMap::new(),
|
|
||||||
fields: HashMap::new(),
|
|
||||||
},
|
|
||||||
parents: vec![],
|
|
||||||
});
|
|
||||||
|
|
||||||
let foo1 = ctx.add_class(ClassDef {
|
|
||||||
base: TypeDef {
|
|
||||||
name: "Foo1",
|
|
||||||
methods: HashMap::new(),
|
|
||||||
fields: HashMap::new(),
|
|
||||||
},
|
|
||||||
parents: vec![foo],
|
|
||||||
});
|
|
||||||
|
|
||||||
let foo2 = ctx.add_class(ClassDef {
|
|
||||||
base: TypeDef {
|
|
||||||
name: "Foo2",
|
|
||||||
methods: HashMap::new(),
|
|
||||||
fields: HashMap::new(),
|
|
||||||
},
|
|
||||||
parents: vec![foo1],
|
|
||||||
});
|
|
||||||
|
|
||||||
let bar = ctx.add_class(ClassDef {
|
|
||||||
base: TypeDef {
|
|
||||||
name: "bar",
|
|
||||||
methods: HashMap::new(),
|
|
||||||
fields: HashMap::new(),
|
|
||||||
},
|
|
||||||
parents: vec![],
|
|
||||||
});
|
|
||||||
|
|
||||||
ctx.add_fn(
|
|
||||||
"foo",
|
|
||||||
FnDef {
|
|
||||||
args: vec![VirtualClassType(foo).into()],
|
|
||||||
result: None,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
ctx.add_fn(
|
|
||||||
"foo1",
|
|
||||||
FnDef {
|
|
||||||
args: vec![VirtualClassType(foo1).into()],
|
|
||||||
result: None,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
let ctx = get_inference_context(ctx);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
resolve_call(&ctx, None, "foo", &[ClassType(foo).into()]),
|
|
||||||
Ok(None)
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
resolve_call(&ctx, None, "foo", &[ClassType(foo1).into()]),
|
|
||||||
Ok(None)
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
resolve_call(&ctx, None, "foo", &[ClassType(foo2).into()]),
|
|
||||||
Ok(None)
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
resolve_call(&ctx, None, "foo", &[ClassType(bar).into()]),
|
|
||||||
Err("not subtype".to_string())
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
resolve_call(&ctx, None, "foo1", &[ClassType(foo1).into()]),
|
|
||||||
Ok(None)
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
resolve_call(&ctx, None, "foo1", &[ClassType(foo2).into()]),
|
|
||||||
Ok(None)
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
resolve_call(&ctx, None, "foo1", &[ClassType(foo).into()]),
|
|
||||||
Err("not subtype".to_string())
|
|
||||||
);
|
|
||||||
|
|
||||||
// virtual class substitution
|
|
||||||
assert_eq!(
|
|
||||||
resolve_call(&ctx, None, "foo", &[VirtualClassType(foo).into()]),
|
|
||||||
Ok(None)
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
resolve_call(&ctx, None, "foo", &[VirtualClassType(foo1).into()]),
|
|
||||||
Ok(None)
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
resolve_call(&ctx, None, "foo", &[VirtualClassType(foo2).into()]),
|
|
||||||
Ok(None)
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
resolve_call(&ctx, None, "foo", &[VirtualClassType(bar).into()]),
|
|
||||||
Err("not subtype".to_string())
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,593 +1,8 @@
|
|||||||
#![warn(clippy::all)]
|
#![warn(clippy::all)]
|
||||||
#![allow(clippy::clone_double_ref)]
|
#![allow(dead_code)]
|
||||||
|
|
||||||
extern crate num_bigint;
|
mod codegen;
|
||||||
extern crate inkwell;
|
mod location;
|
||||||
extern crate rustpython_parser;
|
mod symbol_resolver;
|
||||||
|
mod top_level;
|
||||||
pub mod expression_inference;
|
mod typecheck;
|
||||||
pub mod inference_core;
|
|
||||||
mod magic_methods;
|
|
||||||
pub mod primitives;
|
|
||||||
pub mod typedef;
|
|
||||||
pub mod context;
|
|
||||||
|
|
||||||
use std::error::Error;
|
|
||||||
use std::fmt;
|
|
||||||
use std::path::Path;
|
|
||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
use num_traits::cast::ToPrimitive;
|
|
||||||
|
|
||||||
use rustpython_parser::ast;
|
|
||||||
|
|
||||||
use inkwell::OptimizationLevel;
|
|
||||||
use inkwell::builder::Builder;
|
|
||||||
use inkwell::context::Context;
|
|
||||||
use inkwell::module::Module;
|
|
||||||
use inkwell::targets::*;
|
|
||||||
use inkwell::types;
|
|
||||||
use inkwell::types::BasicType;
|
|
||||||
use inkwell::values;
|
|
||||||
use inkwell::{IntPredicate, FloatPredicate};
|
|
||||||
use inkwell::basic_block;
|
|
||||||
use inkwell::passes;
|
|
||||||
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
enum CompileErrorKind {
|
|
||||||
Unsupported(&'static str),
|
|
||||||
MissingTypeAnnotation,
|
|
||||||
UnknownTypeAnnotation,
|
|
||||||
IncompatibleTypes,
|
|
||||||
UnboundIdentifier,
|
|
||||||
BreakOutsideLoop,
|
|
||||||
Internal(&'static str)
|
|
||||||
}
|
|
||||||
|
|
||||||
impl fmt::Display for CompileErrorKind {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
||||||
match self {
|
|
||||||
CompileErrorKind::Unsupported(feature)
|
|
||||||
=> write!(f, "The following Python feature is not supported by NAC3: {}", feature),
|
|
||||||
CompileErrorKind::MissingTypeAnnotation
|
|
||||||
=> write!(f, "Missing type annotation"),
|
|
||||||
CompileErrorKind::UnknownTypeAnnotation
|
|
||||||
=> write!(f, "Unknown type annotation"),
|
|
||||||
CompileErrorKind::IncompatibleTypes
|
|
||||||
=> write!(f, "Incompatible types"),
|
|
||||||
CompileErrorKind::UnboundIdentifier
|
|
||||||
=> write!(f, "Unbound identifier"),
|
|
||||||
CompileErrorKind::BreakOutsideLoop
|
|
||||||
=> write!(f, "Break outside loop"),
|
|
||||||
CompileErrorKind::Internal(details)
|
|
||||||
=> write!(f, "Internal compiler error: {}", details),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct CompileError {
|
|
||||||
location: ast::Location,
|
|
||||||
kind: CompileErrorKind,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl fmt::Display for CompileError {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
||||||
write!(f, "{}, at {}", self.kind, self.location)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Error for CompileError {}
|
|
||||||
|
|
||||||
type CompileResult<T> = Result<T, CompileError>;
|
|
||||||
|
|
||||||
pub struct CodeGen<'ctx> {
|
|
||||||
context: &'ctx Context,
|
|
||||||
module: Module<'ctx>,
|
|
||||||
pass_manager: passes::PassManager<values::FunctionValue<'ctx>>,
|
|
||||||
builder: Builder<'ctx>,
|
|
||||||
current_source_location: ast::Location,
|
|
||||||
namespace: HashMap<String, values::PointerValue<'ctx>>,
|
|
||||||
break_bb: Option<basic_block::BasicBlock<'ctx>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> CodeGen<'ctx> {
|
|
||||||
pub fn new(context: &'ctx Context) -> CodeGen<'ctx> {
|
|
||||||
let module = context.create_module("kernel");
|
|
||||||
|
|
||||||
let pass_manager = passes::PassManager::create(&module);
|
|
||||||
pass_manager.add_instruction_combining_pass();
|
|
||||||
pass_manager.add_reassociate_pass();
|
|
||||||
pass_manager.add_gvn_pass();
|
|
||||||
pass_manager.add_cfg_simplification_pass();
|
|
||||||
pass_manager.add_basic_alias_analysis_pass();
|
|
||||||
pass_manager.add_promote_memory_to_register_pass();
|
|
||||||
pass_manager.add_instruction_combining_pass();
|
|
||||||
pass_manager.add_reassociate_pass();
|
|
||||||
pass_manager.initialize();
|
|
||||||
|
|
||||||
let i32_type = context.i32_type();
|
|
||||||
let fn_type = i32_type.fn_type(&[i32_type.into()], false);
|
|
||||||
module.add_function("output", fn_type, None);
|
|
||||||
|
|
||||||
CodeGen {
|
|
||||||
context, module, pass_manager,
|
|
||||||
builder: context.create_builder(),
|
|
||||||
current_source_location: ast::Location::default(),
|
|
||||||
namespace: HashMap::new(),
|
|
||||||
break_bb: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn set_source_location(&mut self, location: ast::Location) {
|
|
||||||
self.current_source_location = location;
|
|
||||||
}
|
|
||||||
|
|
||||||
fn compile_error(&self, kind: CompileErrorKind) -> CompileError {
|
|
||||||
CompileError {
|
|
||||||
location: self.current_source_location,
|
|
||||||
kind
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_basic_type(&self, name: &str) -> CompileResult<types::BasicTypeEnum<'ctx>> {
|
|
||||||
match name {
|
|
||||||
"bool" => Ok(self.context.bool_type().into()),
|
|
||||||
"int32" => Ok(self.context.i32_type().into()),
|
|
||||||
"int64" => Ok(self.context.i64_type().into()),
|
|
||||||
"float32" => Ok(self.context.f32_type().into()),
|
|
||||||
"float64" => Ok(self.context.f64_type().into()),
|
|
||||||
_ => Err(self.compile_error(CompileErrorKind::UnknownTypeAnnotation))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn compile_function_def(
|
|
||||||
&mut self,
|
|
||||||
name: &str,
|
|
||||||
args: &ast::Parameters,
|
|
||||||
body: &ast::Suite,
|
|
||||||
decorator_list: &[ast::Expression],
|
|
||||||
returns: &Option<ast::Expression>,
|
|
||||||
is_async: bool,
|
|
||||||
) -> CompileResult<values::FunctionValue<'ctx>> {
|
|
||||||
if is_async {
|
|
||||||
return Err(self.compile_error(CompileErrorKind::Unsupported("async functions")))
|
|
||||||
}
|
|
||||||
for decorator in decorator_list.iter() {
|
|
||||||
self.set_source_location(decorator.location);
|
|
||||||
if let ast::ExpressionType::Identifier { name } = &decorator.node {
|
|
||||||
if name != "kernel" && name != "portable" {
|
|
||||||
return Err(self.compile_error(CompileErrorKind::Unsupported("custom decorators")))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return Err(self.compile_error(CompileErrorKind::Unsupported("decorator must be an identifier")))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let args_type = args.args.iter().map(|val| {
|
|
||||||
self.set_source_location(val.location);
|
|
||||||
if let Some(annotation) = &val.annotation {
|
|
||||||
if let ast::ExpressionType::Identifier { name } = &annotation.node {
|
|
||||||
Ok(self.get_basic_type(&name)?)
|
|
||||||
} else {
|
|
||||||
Err(self.compile_error(CompileErrorKind::Unsupported("type annotation must be an identifier")))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
Err(self.compile_error(CompileErrorKind::MissingTypeAnnotation))
|
|
||||||
}
|
|
||||||
}).collect::<CompileResult<Vec<types::BasicTypeEnum>>>()?;
|
|
||||||
let return_type = if let Some(returns) = returns {
|
|
||||||
self.set_source_location(returns.location);
|
|
||||||
if let ast::ExpressionType::Identifier { name } = &returns.node {
|
|
||||||
if name == "None" { None } else { Some(self.get_basic_type(name)?) }
|
|
||||||
} else {
|
|
||||||
return Err(self.compile_error(CompileErrorKind::Unsupported("type annotation must be an identifier")))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
let fn_type = match return_type {
|
|
||||||
Some(ty) => ty.fn_type(&args_type, false),
|
|
||||||
None => self.context.void_type().fn_type(&args_type, false)
|
|
||||||
};
|
|
||||||
|
|
||||||
let function = self.module.add_function(name, fn_type, None);
|
|
||||||
let basic_block = self.context.append_basic_block(function, "entry");
|
|
||||||
self.builder.position_at_end(basic_block);
|
|
||||||
|
|
||||||
for (n, arg) in args.args.iter().enumerate() {
|
|
||||||
let param = function.get_nth_param(n as u32).unwrap();
|
|
||||||
let alloca = self.builder.build_alloca(param.get_type(), &arg.arg);
|
|
||||||
self.builder.build_store(alloca, param);
|
|
||||||
self.namespace.insert(arg.arg.clone(), alloca);
|
|
||||||
}
|
|
||||||
|
|
||||||
self.compile_suite(body, return_type)?;
|
|
||||||
|
|
||||||
Ok(function)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn compile_expression(
|
|
||||||
&mut self,
|
|
||||||
expression: &ast::Expression
|
|
||||||
) -> CompileResult<values::BasicValueEnum<'ctx>> {
|
|
||||||
self.set_source_location(expression.location);
|
|
||||||
|
|
||||||
match &expression.node {
|
|
||||||
ast::ExpressionType::True => Ok(self.context.bool_type().const_int(1, false).into()),
|
|
||||||
ast::ExpressionType::False => Ok(self.context.bool_type().const_int(0, false).into()),
|
|
||||||
ast::ExpressionType::Number { value: ast::Number::Integer { value } } => {
|
|
||||||
let mut bits = value.bits();
|
|
||||||
if value.sign() == num_bigint::Sign::Minus {
|
|
||||||
bits += 1;
|
|
||||||
}
|
|
||||||
match bits {
|
|
||||||
0..=32 => Ok(self.context.i32_type().const_int(value.to_i32().unwrap() as _, true).into()),
|
|
||||||
33..=64 => Ok(self.context.i64_type().const_int(value.to_i64().unwrap() as _, true).into()),
|
|
||||||
_ => Err(self.compile_error(CompileErrorKind::Unsupported("integers larger than 64 bits")))
|
|
||||||
}
|
|
||||||
},
|
|
||||||
ast::ExpressionType::Number { value: ast::Number::Float { value } } => {
|
|
||||||
Ok(self.context.f64_type().const_float(*value).into())
|
|
||||||
},
|
|
||||||
ast::ExpressionType::Identifier { name } => {
|
|
||||||
match self.namespace.get(name) {
|
|
||||||
Some(value) => Ok(self.builder.build_load(*value, name).into()),
|
|
||||||
None => Err(self.compile_error(CompileErrorKind::UnboundIdentifier))
|
|
||||||
}
|
|
||||||
},
|
|
||||||
ast::ExpressionType::Unop { op, a } => {
|
|
||||||
let a = self.compile_expression(&a)?;
|
|
||||||
match (op, a) {
|
|
||||||
(ast::UnaryOperator::Pos, values::BasicValueEnum::IntValue(a))
|
|
||||||
=> Ok(a.into()),
|
|
||||||
(ast::UnaryOperator::Pos, values::BasicValueEnum::FloatValue(a))
|
|
||||||
=> Ok(a.into()),
|
|
||||||
(ast::UnaryOperator::Neg, values::BasicValueEnum::IntValue(a))
|
|
||||||
=> Ok(self.builder.build_int_neg(a, "tmpneg").into()),
|
|
||||||
(ast::UnaryOperator::Neg, values::BasicValueEnum::FloatValue(a))
|
|
||||||
=> Ok(self.builder.build_float_neg(a, "tmpneg").into()),
|
|
||||||
(ast::UnaryOperator::Inv, values::BasicValueEnum::IntValue(a))
|
|
||||||
=> Ok(self.builder.build_not(a, "tmpnot").into()),
|
|
||||||
(ast::UnaryOperator::Not, values::BasicValueEnum::IntValue(a)) => {
|
|
||||||
// boolean "not"
|
|
||||||
if a.get_type().get_bit_width() != 1 {
|
|
||||||
Err(self.compile_error(CompileErrorKind::Unsupported("unimplemented unary operation")))
|
|
||||||
} else {
|
|
||||||
Ok(self.builder.build_not(a, "tmpnot").into())
|
|
||||||
}
|
|
||||||
},
|
|
||||||
_ => Err(self.compile_error(CompileErrorKind::Unsupported("unimplemented unary operation"))),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
ast::ExpressionType::Binop { a, op, b } => {
|
|
||||||
let a = self.compile_expression(&a)?;
|
|
||||||
let b = self.compile_expression(&b)?;
|
|
||||||
if a.get_type() != b.get_type() {
|
|
||||||
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
|
|
||||||
}
|
|
||||||
use ast::Operator::*;
|
|
||||||
match (op, a, b) {
|
|
||||||
(Add, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b))
|
|
||||||
=> Ok(self.builder.build_int_add(a, b, "tmpadd").into()),
|
|
||||||
(Sub, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b))
|
|
||||||
=> Ok(self.builder.build_int_sub(a, b, "tmpsub").into()),
|
|
||||||
(Mult, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b))
|
|
||||||
=> Ok(self.builder.build_int_mul(a, b, "tmpmul").into()),
|
|
||||||
|
|
||||||
(Add, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b))
|
|
||||||
=> Ok(self.builder.build_float_add(a, b, "tmpadd").into()),
|
|
||||||
(Sub, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b))
|
|
||||||
=> Ok(self.builder.build_float_sub(a, b, "tmpsub").into()),
|
|
||||||
(Mult, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b))
|
|
||||||
=> Ok(self.builder.build_float_mul(a, b, "tmpmul").into()),
|
|
||||||
|
|
||||||
(Div, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b))
|
|
||||||
=> Ok(self.builder.build_float_div(a, b, "tmpdiv").into()),
|
|
||||||
(FloorDiv, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b))
|
|
||||||
=> Ok(self.builder.build_int_signed_div(a, b, "tmpdiv").into()),
|
|
||||||
_ => Err(self.compile_error(CompileErrorKind::Unsupported("unimplemented binary operation"))),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
ast::ExpressionType::Compare { vals, ops } => {
|
|
||||||
let mut vals = vals.iter();
|
|
||||||
let mut ops = ops.iter();
|
|
||||||
|
|
||||||
let mut result = None;
|
|
||||||
let mut a = self.compile_expression(vals.next().unwrap())?;
|
|
||||||
loop {
|
|
||||||
if let Some(op) = ops.next() {
|
|
||||||
let b = self.compile_expression(vals.next().unwrap())?;
|
|
||||||
if a.get_type() != b.get_type() {
|
|
||||||
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
|
|
||||||
}
|
|
||||||
let this_result = match (a, b) {
|
|
||||||
(values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b)) => {
|
|
||||||
match op {
|
|
||||||
ast::Comparison::Equal
|
|
||||||
=> self.builder.build_int_compare(IntPredicate::EQ, a, b, "tmpeq"),
|
|
||||||
ast::Comparison::NotEqual
|
|
||||||
=> self.builder.build_int_compare(IntPredicate::NE, a, b, "tmpne"),
|
|
||||||
ast::Comparison::Less
|
|
||||||
=> self.builder.build_int_compare(IntPredicate::SLT, a, b, "tmpslt"),
|
|
||||||
ast::Comparison::LessOrEqual
|
|
||||||
=> self.builder.build_int_compare(IntPredicate::SLE, a, b, "tmpsle"),
|
|
||||||
ast::Comparison::Greater
|
|
||||||
=> self.builder.build_int_compare(IntPredicate::SGT, a, b, "tmpsgt"),
|
|
||||||
ast::Comparison::GreaterOrEqual
|
|
||||||
=> self.builder.build_int_compare(IntPredicate::SGE, a, b, "tmpsge"),
|
|
||||||
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("special comparison"))),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
(values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b)) => {
|
|
||||||
match op {
|
|
||||||
ast::Comparison::Equal
|
|
||||||
=> self.builder.build_float_compare(FloatPredicate::OEQ, a, b, "tmpoeq"),
|
|
||||||
ast::Comparison::NotEqual
|
|
||||||
=> self.builder.build_float_compare(FloatPredicate::UNE, a, b, "tmpune"),
|
|
||||||
ast::Comparison::Less
|
|
||||||
=> self.builder.build_float_compare(FloatPredicate::OLT, a, b, "tmpolt"),
|
|
||||||
ast::Comparison::LessOrEqual
|
|
||||||
=> self.builder.build_float_compare(FloatPredicate::OLE, a, b, "tmpole"),
|
|
||||||
ast::Comparison::Greater
|
|
||||||
=> self.builder.build_float_compare(FloatPredicate::OGT, a, b, "tmpogt"),
|
|
||||||
ast::Comparison::GreaterOrEqual
|
|
||||||
=> self.builder.build_float_compare(FloatPredicate::OGE, a, b, "tmpoge"),
|
|
||||||
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("special comparison"))),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("comparison of non-numerical types"))),
|
|
||||||
};
|
|
||||||
match result {
|
|
||||||
Some(last) => {
|
|
||||||
result = Some(self.builder.build_and(last, this_result, "tmpand"));
|
|
||||||
}
|
|
||||||
None => {
|
|
||||||
result = Some(this_result);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
a = b;
|
|
||||||
} else {
|
|
||||||
return Ok(result.unwrap().into())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
ast::ExpressionType::Call { function, args, keywords } => {
|
|
||||||
if !keywords.is_empty() {
|
|
||||||
return Err(self.compile_error(CompileErrorKind::Unsupported("keyword arguments")))
|
|
||||||
}
|
|
||||||
let args = args.iter().map(|val| self.compile_expression(val))
|
|
||||||
.collect::<CompileResult<Vec<values::BasicValueEnum>>>()?;
|
|
||||||
self.set_source_location(expression.location);
|
|
||||||
if let ast::ExpressionType::Identifier { name } = &function.node {
|
|
||||||
match (name.as_str(), args[0]) {
|
|
||||||
("int32", values::BasicValueEnum::IntValue(a)) => {
|
|
||||||
let nbits = a.get_type().get_bit_width();
|
|
||||||
if nbits < 32 {
|
|
||||||
Ok(self.builder.build_int_s_extend(a, self.context.i32_type(), "tmpsext").into())
|
|
||||||
} else if nbits > 32 {
|
|
||||||
Ok(self.builder.build_int_truncate(a, self.context.i32_type(), "tmptrunc").into())
|
|
||||||
} else {
|
|
||||||
Ok(a.into())
|
|
||||||
}
|
|
||||||
},
|
|
||||||
("int64", values::BasicValueEnum::IntValue(a)) => {
|
|
||||||
let nbits = a.get_type().get_bit_width();
|
|
||||||
if nbits < 64 {
|
|
||||||
Ok(self.builder.build_int_s_extend(a, self.context.i64_type(), "tmpsext").into())
|
|
||||||
} else {
|
|
||||||
Ok(a.into())
|
|
||||||
}
|
|
||||||
},
|
|
||||||
("int32", values::BasicValueEnum::FloatValue(a)) => {
|
|
||||||
Ok(self.builder.build_float_to_signed_int(a, self.context.i32_type(), "tmpfptosi").into())
|
|
||||||
},
|
|
||||||
("int64", values::BasicValueEnum::FloatValue(a)) => {
|
|
||||||
Ok(self.builder.build_float_to_signed_int(a, self.context.i64_type(), "tmpfptosi").into())
|
|
||||||
},
|
|
||||||
("float32", values::BasicValueEnum::IntValue(a)) => {
|
|
||||||
Ok(self.builder.build_signed_int_to_float(a, self.context.f32_type(), "tmpsitofp").into())
|
|
||||||
},
|
|
||||||
("float64", values::BasicValueEnum::IntValue(a)) => {
|
|
||||||
Ok(self.builder.build_signed_int_to_float(a, self.context.f64_type(), "tmpsitofp").into())
|
|
||||||
},
|
|
||||||
("float32", values::BasicValueEnum::FloatValue(a)) => {
|
|
||||||
if a.get_type() == self.context.f64_type() {
|
|
||||||
Ok(self.builder.build_float_trunc(a, self.context.f32_type(), "tmptrunc").into())
|
|
||||||
} else {
|
|
||||||
Ok(a.into())
|
|
||||||
}
|
|
||||||
},
|
|
||||||
("float64", values::BasicValueEnum::FloatValue(a)) => {
|
|
||||||
if a.get_type() == self.context.f32_type() {
|
|
||||||
Ok(self.builder.build_float_ext(a, self.context.f64_type(), "tmpext").into())
|
|
||||||
} else {
|
|
||||||
Ok(a.into())
|
|
||||||
}
|
|
||||||
},
|
|
||||||
|
|
||||||
("output", values::BasicValueEnum::IntValue(a)) => {
|
|
||||||
let fn_value = self.module.get_function("output").unwrap();
|
|
||||||
Ok(self.builder.build_call(fn_value, &[a.into()], "call")
|
|
||||||
.try_as_basic_value().left().unwrap())
|
|
||||||
},
|
|
||||||
_ => Err(self.compile_error(CompileErrorKind::Unsupported("unrecognized call")))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return Err(self.compile_error(CompileErrorKind::Unsupported("function must be an identifier")))
|
|
||||||
}
|
|
||||||
},
|
|
||||||
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("unimplemented expression"))),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn compile_statement(
|
|
||||||
&mut self,
|
|
||||||
statement: &ast::Statement,
|
|
||||||
return_type: Option<types::BasicTypeEnum>
|
|
||||||
) -> CompileResult<()> {
|
|
||||||
self.set_source_location(statement.location);
|
|
||||||
|
|
||||||
use ast::StatementType::*;
|
|
||||||
match &statement.node {
|
|
||||||
Assign { targets, value } => {
|
|
||||||
let value = self.compile_expression(value)?;
|
|
||||||
for target in targets.iter() {
|
|
||||||
self.set_source_location(target.location);
|
|
||||||
if let ast::ExpressionType::Identifier { name } = &target.node {
|
|
||||||
let builder = &self.builder;
|
|
||||||
let target = self.namespace.entry(name.clone()).or_insert_with(
|
|
||||||
|| builder.build_alloca(value.get_type(), name));
|
|
||||||
if target.get_type() != value.get_type().ptr_type(inkwell::AddressSpace::Generic) {
|
|
||||||
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
|
|
||||||
}
|
|
||||||
builder.build_store(*target, value);
|
|
||||||
} else {
|
|
||||||
return Err(self.compile_error(CompileErrorKind::Unsupported("assignment target must be an identifier")))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
Expression { expression } => { self.compile_expression(expression)?; },
|
|
||||||
If { test, body, orelse } => {
|
|
||||||
let test = self.compile_expression(test)?;
|
|
||||||
if test.get_type() != self.context.bool_type().into() {
|
|
||||||
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
|
|
||||||
}
|
|
||||||
|
|
||||||
let parent = self.builder.get_insert_block().unwrap().get_parent().unwrap();
|
|
||||||
let then_bb = self.context.append_basic_block(parent, "then");
|
|
||||||
let else_bb = self.context.append_basic_block(parent, "else");
|
|
||||||
let cont_bb = self.context.append_basic_block(parent, "ifcont");
|
|
||||||
self.builder.build_conditional_branch(test.into_int_value(), then_bb, else_bb);
|
|
||||||
|
|
||||||
self.builder.position_at_end(then_bb);
|
|
||||||
self.compile_suite(body, return_type)?;
|
|
||||||
self.builder.build_unconditional_branch(cont_bb);
|
|
||||||
|
|
||||||
self.builder.position_at_end(else_bb);
|
|
||||||
if let Some(orelse) = orelse {
|
|
||||||
self.compile_suite(orelse, return_type)?;
|
|
||||||
}
|
|
||||||
self.builder.build_unconditional_branch(cont_bb);
|
|
||||||
self.builder.position_at_end(cont_bb);
|
|
||||||
},
|
|
||||||
While { test, body, orelse } => {
|
|
||||||
let parent = self.builder.get_insert_block().unwrap().get_parent().unwrap();
|
|
||||||
let test_bb = self.context.append_basic_block(parent, "test");
|
|
||||||
self.builder.build_unconditional_branch(test_bb);
|
|
||||||
self.builder.position_at_end(test_bb);
|
|
||||||
let test = self.compile_expression(test)?;
|
|
||||||
if test.get_type() != self.context.bool_type().into() {
|
|
||||||
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
|
|
||||||
}
|
|
||||||
|
|
||||||
let then_bb = self.context.append_basic_block(parent, "then");
|
|
||||||
let else_bb = self.context.append_basic_block(parent, "else");
|
|
||||||
let cont_bb = self.context.append_basic_block(parent, "ifcont");
|
|
||||||
self.builder.build_conditional_branch(test.into_int_value(), then_bb, else_bb);
|
|
||||||
|
|
||||||
self.break_bb = Some(cont_bb);
|
|
||||||
|
|
||||||
self.builder.position_at_end(then_bb);
|
|
||||||
self.compile_suite(body, return_type)?;
|
|
||||||
self.builder.build_unconditional_branch(test_bb);
|
|
||||||
|
|
||||||
self.builder.position_at_end(else_bb);
|
|
||||||
if let Some(orelse) = orelse {
|
|
||||||
self.compile_suite(orelse, return_type)?;
|
|
||||||
}
|
|
||||||
self.builder.build_unconditional_branch(cont_bb);
|
|
||||||
self.builder.position_at_end(cont_bb);
|
|
||||||
|
|
||||||
self.break_bb = None;
|
|
||||||
},
|
|
||||||
Break => {
|
|
||||||
if let Some(bb) = self.break_bb {
|
|
||||||
self.builder.build_unconditional_branch(bb);
|
|
||||||
let parent = self.builder.get_insert_block().unwrap().get_parent().unwrap();
|
|
||||||
let unreachable_bb = self.context.append_basic_block(parent, "unreachable");
|
|
||||||
self.builder.position_at_end(unreachable_bb);
|
|
||||||
} else {
|
|
||||||
return Err(self.compile_error(CompileErrorKind::BreakOutsideLoop));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Return { value: Some(value) } => {
|
|
||||||
if let Some(return_type) = return_type {
|
|
||||||
let value = self.compile_expression(value)?;
|
|
||||||
if value.get_type() != return_type {
|
|
||||||
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
|
|
||||||
}
|
|
||||||
self.builder.build_return(Some(&value));
|
|
||||||
} else {
|
|
||||||
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
|
|
||||||
}
|
|
||||||
},
|
|
||||||
Return { value: None } => {
|
|
||||||
if !return_type.is_none() {
|
|
||||||
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
|
|
||||||
}
|
|
||||||
self.builder.build_return(None);
|
|
||||||
},
|
|
||||||
Pass => (),
|
|
||||||
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("special statement"))),
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn compile_suite(
|
|
||||||
&mut self,
|
|
||||||
suite: &ast::Suite,
|
|
||||||
return_type: Option<types::BasicTypeEnum>
|
|
||||||
) -> CompileResult<()> {
|
|
||||||
for statement in suite.iter() {
|
|
||||||
self.compile_statement(statement, return_type)?;
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn compile_toplevel(&mut self, statement: &ast::Statement) -> CompileResult<()> {
|
|
||||||
self.set_source_location(statement.location);
|
|
||||||
if let ast::StatementType::FunctionDef {
|
|
||||||
is_async,
|
|
||||||
name,
|
|
||||||
args,
|
|
||||||
body,
|
|
||||||
decorator_list,
|
|
||||||
returns,
|
|
||||||
} = &statement.node {
|
|
||||||
let function = self.compile_function_def(name, args, body, decorator_list, returns, *is_async)?;
|
|
||||||
self.pass_manager.run_on(&function);
|
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
Err(self.compile_error(CompileErrorKind::Internal("top-level is not a function definition")))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn print_ir(&self) {
|
|
||||||
self.module.print_to_stderr();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn output(&self, filename: &str) {
|
|
||||||
//let triple = TargetTriple::create("riscv32-none-linux-gnu");
|
|
||||||
let triple = TargetMachine::get_default_triple();
|
|
||||||
let target = Target::from_triple(&triple)
|
|
||||||
.expect("couldn't create target from target triple");
|
|
||||||
|
|
||||||
let target_machine = target
|
|
||||||
.create_target_machine(
|
|
||||||
&triple,
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
OptimizationLevel::Default,
|
|
||||||
RelocMode::Default,
|
|
||||||
CodeModel::Default,
|
|
||||||
)
|
|
||||||
.expect("couldn't create target machine");
|
|
||||||
|
|
||||||
target_machine
|
|
||||||
.write_to_file(&self.module, FileType::Object, Path::new(filename))
|
|
||||||
.expect("couldn't write module to file");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
31
nac3core/src/location.rs
Normal file
31
nac3core/src/location.rs
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
use rustpython_parser::ast;
|
||||||
|
use std::vec::Vec;
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, PartialEq)]
|
||||||
|
pub struct FileID(u32);
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, PartialEq)]
|
||||||
|
pub enum Location {
|
||||||
|
CodeRange(FileID, ast::Location),
|
||||||
|
Builtin,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct FileRegistry {
|
||||||
|
files: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FileRegistry {
|
||||||
|
pub fn new() -> FileRegistry {
|
||||||
|
FileRegistry { files: Vec::new() }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn add_file(&mut self, path: &str) -> FileID {
|
||||||
|
let index = self.files.len() as u32;
|
||||||
|
self.files.push(path.to_owned());
|
||||||
|
FileID(index)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn query_file(&self, id: FileID) -> &str {
|
||||||
|
&self.files[id.0 as usize]
|
||||||
|
}
|
||||||
|
}
|
@ -1,58 +0,0 @@
|
|||||||
use rustpython_parser::ast::{Comparison, Operator, UnaryOperator};
|
|
||||||
|
|
||||||
pub fn binop_name(op: &Operator) -> &'static str {
|
|
||||||
match op {
|
|
||||||
Operator::Add => "__add__",
|
|
||||||
Operator::Sub => "__sub__",
|
|
||||||
Operator::Div => "__truediv__",
|
|
||||||
Operator::Mod => "__mod__",
|
|
||||||
Operator::Mult => "__mul__",
|
|
||||||
Operator::Pow => "__pow__",
|
|
||||||
Operator::BitOr => "__or__",
|
|
||||||
Operator::BitXor => "__xor__",
|
|
||||||
Operator::BitAnd => "__and__",
|
|
||||||
Operator::LShift => "__lshift__",
|
|
||||||
Operator::RShift => "__rshift__",
|
|
||||||
Operator::FloorDiv => "__floordiv__",
|
|
||||||
Operator::MatMult => "__matmul__",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn binop_assign_name(op: &Operator) -> &'static str {
|
|
||||||
match op {
|
|
||||||
Operator::Add => "__iadd__",
|
|
||||||
Operator::Sub => "__isub__",
|
|
||||||
Operator::Div => "__itruediv__",
|
|
||||||
Operator::Mod => "__imod__",
|
|
||||||
Operator::Mult => "__imul__",
|
|
||||||
Operator::Pow => "__ipow__",
|
|
||||||
Operator::BitOr => "__ior__",
|
|
||||||
Operator::BitXor => "__ixor__",
|
|
||||||
Operator::BitAnd => "__iand__",
|
|
||||||
Operator::LShift => "__ilshift__",
|
|
||||||
Operator::RShift => "__irshift__",
|
|
||||||
Operator::FloorDiv => "__ifloordiv__",
|
|
||||||
Operator::MatMult => "__imatmul__",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn unaryop_name(op: &UnaryOperator) -> &'static str {
|
|
||||||
match op {
|
|
||||||
UnaryOperator::Pos => "__pos__",
|
|
||||||
UnaryOperator::Neg => "__neg__",
|
|
||||||
UnaryOperator::Not => "__not__",
|
|
||||||
UnaryOperator::Inv => "__inv__",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn comparison_name(op: &Comparison) -> Option<&'static str> {
|
|
||||||
match op {
|
|
||||||
Comparison::Less => Some("__lt__"),
|
|
||||||
Comparison::LessOrEqual => Some("__le__"),
|
|
||||||
Comparison::Greater => Some("__gt__"),
|
|
||||||
Comparison::GreaterOrEqual => Some("__ge__"),
|
|
||||||
Comparison::Equal => Some("__eq__"),
|
|
||||||
Comparison::NotEqual => Some("__ne__"),
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,184 +0,0 @@
|
|||||||
use super::typedef::{TypeEnum::*, *};
|
|
||||||
use crate::context::*;
|
|
||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
pub const TUPLE_TYPE: ParamId = ParamId(0);
|
|
||||||
pub const LIST_TYPE: ParamId = ParamId(1);
|
|
||||||
|
|
||||||
pub const BOOL_TYPE: PrimitiveId = PrimitiveId(0);
|
|
||||||
pub const INT32_TYPE: PrimitiveId = PrimitiveId(1);
|
|
||||||
pub const INT64_TYPE: PrimitiveId = PrimitiveId(2);
|
|
||||||
pub const FLOAT_TYPE: PrimitiveId = PrimitiveId(3);
|
|
||||||
|
|
||||||
fn impl_math(def: &mut TypeDef, ty: &Type) {
|
|
||||||
let result = Some(ty.clone());
|
|
||||||
let fun = FnDef {
|
|
||||||
args: vec![ty.clone()],
|
|
||||||
result: result.clone(),
|
|
||||||
};
|
|
||||||
def.methods.insert("__add__", fun.clone());
|
|
||||||
def.methods.insert("__sub__", fun.clone());
|
|
||||||
def.methods.insert("__mul__", fun.clone());
|
|
||||||
def.methods.insert(
|
|
||||||
"__neg__",
|
|
||||||
FnDef {
|
|
||||||
args: vec![],
|
|
||||||
result,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
def.methods.insert(
|
|
||||||
"__truediv__",
|
|
||||||
FnDef {
|
|
||||||
args: vec![ty.clone()],
|
|
||||||
result: Some(PrimitiveType(FLOAT_TYPE).into()),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
def.methods.insert("__floordiv__", fun.clone());
|
|
||||||
def.methods.insert("__mod__", fun.clone());
|
|
||||||
def.methods.insert("__pow__", fun);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn impl_bits(def: &mut TypeDef, ty: &Type) {
|
|
||||||
let result = Some(ty.clone());
|
|
||||||
let fun = FnDef {
|
|
||||||
args: vec![PrimitiveType(INT32_TYPE).into()],
|
|
||||||
result,
|
|
||||||
};
|
|
||||||
|
|
||||||
def.methods.insert("__lshift__", fun.clone());
|
|
||||||
def.methods.insert("__rshift__", fun);
|
|
||||||
def.methods.insert(
|
|
||||||
"__xor__",
|
|
||||||
FnDef {
|
|
||||||
args: vec![ty.clone()],
|
|
||||||
result: Some(ty.clone()),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn impl_eq(def: &mut TypeDef, ty: &Type) {
|
|
||||||
let fun = FnDef {
|
|
||||||
args: vec![ty.clone()],
|
|
||||||
result: Some(PrimitiveType(BOOL_TYPE).into()),
|
|
||||||
};
|
|
||||||
|
|
||||||
def.methods.insert("__eq__", fun.clone());
|
|
||||||
def.methods.insert("__ne__", fun);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn impl_order(def: &mut TypeDef, ty: &Type) {
|
|
||||||
let fun = FnDef {
|
|
||||||
args: vec![ty.clone()],
|
|
||||||
result: Some(PrimitiveType(BOOL_TYPE).into()),
|
|
||||||
};
|
|
||||||
|
|
||||||
def.methods.insert("__lt__", fun.clone());
|
|
||||||
def.methods.insert("__gt__", fun.clone());
|
|
||||||
def.methods.insert("__le__", fun.clone());
|
|
||||||
def.methods.insert("__ge__", fun);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn basic_ctx() -> TopLevelContext<'static> {
|
|
||||||
let primitives = [
|
|
||||||
TypeDef {
|
|
||||||
name: "bool",
|
|
||||||
fields: HashMap::new(),
|
|
||||||
methods: HashMap::new(),
|
|
||||||
},
|
|
||||||
TypeDef {
|
|
||||||
name: "int32",
|
|
||||||
fields: HashMap::new(),
|
|
||||||
methods: HashMap::new(),
|
|
||||||
},
|
|
||||||
TypeDef {
|
|
||||||
name: "int64",
|
|
||||||
fields: HashMap::new(),
|
|
||||||
methods: HashMap::new(),
|
|
||||||
},
|
|
||||||
TypeDef {
|
|
||||||
name: "float",
|
|
||||||
fields: HashMap::new(),
|
|
||||||
methods: HashMap::new(),
|
|
||||||
},
|
|
||||||
]
|
|
||||||
.to_vec();
|
|
||||||
let mut ctx = TopLevelContext::new(primitives);
|
|
||||||
|
|
||||||
let b = ctx.get_primitive(BOOL_TYPE);
|
|
||||||
let b_def = ctx.get_primitive_def_mut(BOOL_TYPE);
|
|
||||||
impl_eq(b_def, &b);
|
|
||||||
let int32 = ctx.get_primitive(INT32_TYPE);
|
|
||||||
let int32_def = ctx.get_primitive_def_mut(INT32_TYPE);
|
|
||||||
impl_math(int32_def, &int32);
|
|
||||||
impl_bits(int32_def, &int32);
|
|
||||||
impl_order(int32_def, &int32);
|
|
||||||
impl_eq(int32_def, &int32);
|
|
||||||
let int64 = ctx.get_primitive(INT64_TYPE);
|
|
||||||
let int64_def = ctx.get_primitive_def_mut(INT64_TYPE);
|
|
||||||
impl_math(int64_def, &int64);
|
|
||||||
impl_bits(int64_def, &int64);
|
|
||||||
impl_order(int64_def, &int64);
|
|
||||||
impl_eq(int64_def, &int64);
|
|
||||||
let float = ctx.get_primitive(FLOAT_TYPE);
|
|
||||||
let float_def = ctx.get_primitive_def_mut(FLOAT_TYPE);
|
|
||||||
impl_math(float_def, &float);
|
|
||||||
impl_order(float_def, &float);
|
|
||||||
impl_eq(float_def, &float);
|
|
||||||
|
|
||||||
let t = ctx.add_variable_private(VarDef {
|
|
||||||
name: "T",
|
|
||||||
bound: vec![],
|
|
||||||
});
|
|
||||||
|
|
||||||
ctx.add_parametric(ParametricDef {
|
|
||||||
base: TypeDef {
|
|
||||||
name: "tuple",
|
|
||||||
fields: HashMap::new(),
|
|
||||||
methods: HashMap::new(),
|
|
||||||
},
|
|
||||||
// we have nothing for tuple, so no param def
|
|
||||||
params: vec![],
|
|
||||||
});
|
|
||||||
|
|
||||||
ctx.add_parametric(ParametricDef {
|
|
||||||
base: TypeDef {
|
|
||||||
name: "list",
|
|
||||||
fields: HashMap::new(),
|
|
||||||
methods: HashMap::new(),
|
|
||||||
},
|
|
||||||
params: vec![t],
|
|
||||||
});
|
|
||||||
|
|
||||||
let i = ctx.add_variable_private(VarDef {
|
|
||||||
name: "I",
|
|
||||||
bound: vec![
|
|
||||||
PrimitiveType(INT32_TYPE).into(),
|
|
||||||
PrimitiveType(INT64_TYPE).into(),
|
|
||||||
PrimitiveType(FLOAT_TYPE).into(),
|
|
||||||
],
|
|
||||||
});
|
|
||||||
let args = vec![TypeVariable(i).into()];
|
|
||||||
ctx.add_fn(
|
|
||||||
"int32",
|
|
||||||
FnDef {
|
|
||||||
args: args.clone(),
|
|
||||||
result: Some(PrimitiveType(INT32_TYPE).into()),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
ctx.add_fn(
|
|
||||||
"int64",
|
|
||||||
FnDef {
|
|
||||||
args: args.clone(),
|
|
||||||
result: Some(PrimitiveType(INT64_TYPE).into()),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
ctx.add_fn(
|
|
||||||
"float",
|
|
||||||
FnDef {
|
|
||||||
args,
|
|
||||||
result: Some(PrimitiveType(FLOAT_TYPE).into()),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
ctx
|
|
||||||
}
|
|
174
nac3core/src/symbol_resolver.rs
Normal file
174
nac3core/src/symbol_resolver.rs
Normal file
@ -0,0 +1,174 @@
|
|||||||
|
use std::cell::RefCell;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use crate::top_level::{DefinitionId, TopLevelContext, TopLevelDef};
|
||||||
|
use crate::typecheck::{
|
||||||
|
type_inferencer::PrimitiveStore,
|
||||||
|
typedef::{Type, Unifier},
|
||||||
|
};
|
||||||
|
use crate::{location::Location, typecheck::typedef::TypeEnum};
|
||||||
|
use itertools::{chain, izip};
|
||||||
|
use rustpython_parser::ast::Expr;
|
||||||
|
|
||||||
|
#[derive(Clone, PartialEq)]
|
||||||
|
pub enum SymbolValue {
|
||||||
|
I32(i32),
|
||||||
|
I64(i64),
|
||||||
|
Double(f64),
|
||||||
|
Bool(bool),
|
||||||
|
Tuple(Vec<SymbolValue>),
|
||||||
|
// we should think about how to implement bytes later...
|
||||||
|
// Bytes(&'a [u8]),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait SymbolResolver {
|
||||||
|
// get type of type variable identifier or top-level function type
|
||||||
|
fn get_symbol_type(
|
||||||
|
&self,
|
||||||
|
unifier: &mut Unifier,
|
||||||
|
primitives: &PrimitiveStore,
|
||||||
|
str: &str,
|
||||||
|
) -> Option<Type>;
|
||||||
|
// get the top-level definition of identifiers
|
||||||
|
fn get_identifier_def(&self, str: &str) -> Option<DefinitionId>;
|
||||||
|
fn get_symbol_value(&self, str: &str) -> Option<SymbolValue>;
|
||||||
|
fn get_symbol_location(&self, str: &str) -> Option<Location>;
|
||||||
|
// handle function call etc.
|
||||||
|
}
|
||||||
|
|
||||||
|
// convert type annotation into type
|
||||||
|
pub fn parse_type_annotation<T>(
|
||||||
|
resolver: &dyn SymbolResolver,
|
||||||
|
top_level: &TopLevelContext,
|
||||||
|
unifier: &mut Unifier,
|
||||||
|
primitives: &PrimitiveStore,
|
||||||
|
expr: &Expr<T>,
|
||||||
|
) -> Result<Type, String> {
|
||||||
|
use rustpython_parser::ast::ExprKind::*;
|
||||||
|
match &expr.node {
|
||||||
|
Name { id, .. } => match id.as_str() {
|
||||||
|
"int32" => Ok(primitives.int32),
|
||||||
|
"int64" => Ok(primitives.int64),
|
||||||
|
"float" => Ok(primitives.float),
|
||||||
|
"bool" => Ok(primitives.bool),
|
||||||
|
"None" => Ok(primitives.none),
|
||||||
|
x => {
|
||||||
|
let obj_id = resolver.get_identifier_def(x);
|
||||||
|
if let Some(obj_id) = obj_id {
|
||||||
|
let defs = top_level.definitions.read();
|
||||||
|
let def = defs[obj_id.0].read();
|
||||||
|
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
|
||||||
|
if !type_vars.is_empty() {
|
||||||
|
return Err(format!(
|
||||||
|
"Unexpected number of type parameters: expected {} but got 0",
|
||||||
|
type_vars.len()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
let fields = RefCell::new(
|
||||||
|
chain(
|
||||||
|
fields.iter().map(|(k, v)| (k.clone(), *v)),
|
||||||
|
methods.iter().map(|(k, v, _)| (k.clone(), *v)),
|
||||||
|
)
|
||||||
|
.collect(),
|
||||||
|
);
|
||||||
|
Ok(unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id,
|
||||||
|
fields,
|
||||||
|
params: Default::default(),
|
||||||
|
}))
|
||||||
|
} else {
|
||||||
|
Err("Cannot use function name as type".into())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// it could be a type variable
|
||||||
|
let ty = resolver
|
||||||
|
.get_symbol_type(unifier, primitives, x)
|
||||||
|
.ok_or_else(|| "Cannot use function name as type".to_owned())?;
|
||||||
|
if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) {
|
||||||
|
Ok(ty)
|
||||||
|
} else {
|
||||||
|
Err(format!("Unknown type annotation {}", x))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Subscript { value, slice, .. } => {
|
||||||
|
if let Name { id, .. } = &value.node {
|
||||||
|
if id == "virtual" {
|
||||||
|
let ty =
|
||||||
|
parse_type_annotation(resolver, top_level, unifier, primitives, slice)?;
|
||||||
|
Ok(unifier.add_ty(TypeEnum::TVirtual { ty }))
|
||||||
|
} else {
|
||||||
|
let types = if let Tuple { elts, .. } = &slice.node {
|
||||||
|
elts.iter()
|
||||||
|
.map(|v| {
|
||||||
|
parse_type_annotation(resolver, top_level, unifier, primitives, v)
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>, _>>()?
|
||||||
|
} else {
|
||||||
|
vec![parse_type_annotation(
|
||||||
|
resolver, top_level, unifier, primitives, slice,
|
||||||
|
)?]
|
||||||
|
};
|
||||||
|
|
||||||
|
let obj_id = resolver
|
||||||
|
.get_identifier_def(id)
|
||||||
|
.ok_or_else(|| format!("Unknown type annotation {}", id))?;
|
||||||
|
let defs = top_level.definitions.read();
|
||||||
|
let def = defs[obj_id.0].read();
|
||||||
|
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
|
||||||
|
if types.len() != type_vars.len() {
|
||||||
|
return Err(format!(
|
||||||
|
"Unexpected number of type parameters: expected {} but got {}",
|
||||||
|
type_vars.len(),
|
||||||
|
types.len()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
let mut subst = HashMap::new();
|
||||||
|
for (var, ty) in izip!(type_vars.iter(), types.iter()) {
|
||||||
|
let id = if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) {
|
||||||
|
*id
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
subst.insert(id, *ty);
|
||||||
|
}
|
||||||
|
let mut fields = fields
|
||||||
|
.iter()
|
||||||
|
.map(|(attr, ty)| {
|
||||||
|
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
|
||||||
|
(attr.clone(), ty)
|
||||||
|
})
|
||||||
|
.collect::<HashMap<_, _>>();
|
||||||
|
fields.extend(methods.iter().map(|(attr, ty, _)| {
|
||||||
|
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
|
||||||
|
(attr.clone(), ty)
|
||||||
|
}));
|
||||||
|
Ok(unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id,
|
||||||
|
fields: fields.into(),
|
||||||
|
params: subst.into(),
|
||||||
|
}))
|
||||||
|
} else {
|
||||||
|
Err("Cannot use function name as type".into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Err("unsupported type expression".into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => Err("unsupported type expression".into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl dyn SymbolResolver + Send + Sync {
|
||||||
|
pub fn parse_type_annotation<T>(
|
||||||
|
&self,
|
||||||
|
top_level: &TopLevelContext,
|
||||||
|
unifier: &mut Unifier,
|
||||||
|
primitives: &PrimitiveStore,
|
||||||
|
expr: &Expr<T>,
|
||||||
|
) -> Result<Type, String> {
|
||||||
|
parse_type_annotation(self, top_level, unifier, primitives, expr)
|
||||||
|
}
|
||||||
|
}
|
778
nac3core/src/top_level.rs
Normal file
778
nac3core/src/top_level.rs
Normal file
@ -0,0 +1,778 @@
|
|||||||
|
use std::borrow::BorrowMut;
|
||||||
|
use std::ops::{Deref, DerefMut};
|
||||||
|
use std::{collections::HashMap, collections::HashSet, sync::Arc};
|
||||||
|
|
||||||
|
use super::typecheck::type_inferencer::PrimitiveStore;
|
||||||
|
use super::typecheck::typedef::{SharedUnifier, Type, TypeEnum, Unifier};
|
||||||
|
use crate::typecheck::typedef::{FunSignature, FuncArg};
|
||||||
|
use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Mapping};
|
||||||
|
use itertools::Itertools;
|
||||||
|
use parking_lot::{Mutex, RwLock};
|
||||||
|
use rustpython_parser::ast::{self, Stmt};
|
||||||
|
|
||||||
|
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)]
|
||||||
|
pub struct DefinitionId(pub usize);
|
||||||
|
|
||||||
|
pub enum TopLevelDef {
|
||||||
|
Class {
|
||||||
|
// object ID used for TypeEnum
|
||||||
|
object_id: DefinitionId,
|
||||||
|
// type variables bounded to the class.
|
||||||
|
type_vars: Vec<Type>,
|
||||||
|
// class fields
|
||||||
|
fields: Vec<(String, Type)>,
|
||||||
|
// class methods, pointing to the corresponding function definition.
|
||||||
|
methods: Vec<(String, Type, DefinitionId)>,
|
||||||
|
// ancestor classes, including itself.
|
||||||
|
ancestors: Vec<DefinitionId>,
|
||||||
|
// symbol resolver of the module defined the class, none if it is built-in type
|
||||||
|
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>,
|
||||||
|
},
|
||||||
|
Function {
|
||||||
|
// prefix for symbol, should be unique globally, and not ending with numbers
|
||||||
|
name: String,
|
||||||
|
// function signature.
|
||||||
|
signature: Type,
|
||||||
|
/// Function instance to symbol mapping
|
||||||
|
/// Key: string representation of type variable values, sorted by variable ID in ascending
|
||||||
|
/// order, including type variables associated with the class.
|
||||||
|
/// Value: function symbol name.
|
||||||
|
instance_to_symbol: HashMap<String, String>,
|
||||||
|
/// Function instances to annotated AST mapping
|
||||||
|
/// Key: string representation of type variable values, sorted by variable ID in ascending
|
||||||
|
/// order, including type variables associated with the class. Excluding rigid type
|
||||||
|
/// variables.
|
||||||
|
/// Value: AST annotated with types together with a unification table index. Could contain
|
||||||
|
/// rigid type variables that would be substituted when the function is instantiated.
|
||||||
|
instance_to_stmt: HashMap<String, (Stmt<Option<Type>>, usize)>,
|
||||||
|
// symbol resolver of the module defined the class
|
||||||
|
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>,
|
||||||
|
},
|
||||||
|
Initializer {
|
||||||
|
class_id: DefinitionId,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TopLevelDef {
|
||||||
|
fn get_function_type(&self) -> Result<Type, String> {
|
||||||
|
if let Self::Function { signature, .. } = self {
|
||||||
|
Ok(*signature)
|
||||||
|
} else {
|
||||||
|
Err("only expect function def here".into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct TopLevelContext {
|
||||||
|
pub definitions: Arc<RwLock<Vec<Arc<RwLock<TopLevelDef>>>>>,
|
||||||
|
pub unifiers: Arc<RwLock<Vec<(SharedUnifier, PrimitiveStore)>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct TopLevelComposer {
|
||||||
|
// list of top level definitions, same as top level context
|
||||||
|
pub definition_ast_list: Arc<RwLock<Vec<(Arc<RwLock<TopLevelDef>>, Option<ast::Stmt<()>>)>>>,
|
||||||
|
// start as a primitive unifier, will add more top_level defs inside
|
||||||
|
pub unifier: Unifier,
|
||||||
|
// primitive store
|
||||||
|
pub primitives: PrimitiveStore,
|
||||||
|
// mangled class method name to def_id
|
||||||
|
pub class_method_to_def_id: HashMap<String, DefinitionId>,
|
||||||
|
// record the def id of the classes whoses fields and methods are to be analyzed
|
||||||
|
pub to_be_analyzed_class: Vec<DefinitionId>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TopLevelComposer {
|
||||||
|
pub fn to_top_level_context(&self) -> TopLevelContext {
|
||||||
|
let def_list =
|
||||||
|
self.definition_ast_list.read().iter().map(|(x, _)| x.clone()).collect::<Vec<_>>();
|
||||||
|
TopLevelContext {
|
||||||
|
definitions: RwLock::new(def_list).into(),
|
||||||
|
// FIXME: all the big unifier or?
|
||||||
|
unifiers: Default::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn name_mangling(mut class_name: String, method_name: &str) -> String {
|
||||||
|
class_name.push_str(method_name);
|
||||||
|
class_name
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn make_primitives() -> (PrimitiveStore, Unifier) {
|
||||||
|
let mut unifier = Unifier::new();
|
||||||
|
let int32 = unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: DefinitionId(0),
|
||||||
|
fields: HashMap::new().into(),
|
||||||
|
params: HashMap::new().into(),
|
||||||
|
});
|
||||||
|
let int64 = unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: DefinitionId(1),
|
||||||
|
fields: HashMap::new().into(),
|
||||||
|
params: HashMap::new().into(),
|
||||||
|
});
|
||||||
|
let float = unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: DefinitionId(2),
|
||||||
|
fields: HashMap::new().into(),
|
||||||
|
params: HashMap::new().into(),
|
||||||
|
});
|
||||||
|
let bool = unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: DefinitionId(3),
|
||||||
|
fields: HashMap::new().into(),
|
||||||
|
params: HashMap::new().into(),
|
||||||
|
});
|
||||||
|
let none = unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: DefinitionId(4),
|
||||||
|
fields: HashMap::new().into(),
|
||||||
|
params: HashMap::new().into(),
|
||||||
|
});
|
||||||
|
let primitives = PrimitiveStore { int32, int64, float, bool, none };
|
||||||
|
crate::typecheck::magic_methods::set_primitives_magic_methods(&primitives, &mut unifier);
|
||||||
|
(primitives, unifier)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// return a composer and things to make a "primitive" symbol resolver, so that the symbol
|
||||||
|
/// resolver can later figure out primitive type definitions when passed a primitive type name
|
||||||
|
pub fn new() -> (Vec<(String, DefinitionId, Type)>, Self) {
|
||||||
|
let primitives = Self::make_primitives();
|
||||||
|
|
||||||
|
let top_level_def_list = vec![
|
||||||
|
Arc::new(RwLock::new(Self::make_top_level_class_def(0, None))),
|
||||||
|
Arc::new(RwLock::new(Self::make_top_level_class_def(1, None))),
|
||||||
|
Arc::new(RwLock::new(Self::make_top_level_class_def(2, None))),
|
||||||
|
Arc::new(RwLock::new(Self::make_top_level_class_def(3, None))),
|
||||||
|
Arc::new(RwLock::new(Self::make_top_level_class_def(4, None))),
|
||||||
|
];
|
||||||
|
|
||||||
|
let ast_list: Vec<Option<ast::Stmt<()>>> = vec![None, None, None, None, None];
|
||||||
|
|
||||||
|
let composer = TopLevelComposer {
|
||||||
|
definition_ast_list: RwLock::new(
|
||||||
|
top_level_def_list.into_iter().zip(ast_list).collect_vec(),
|
||||||
|
)
|
||||||
|
.into(),
|
||||||
|
primitives: primitives.0,
|
||||||
|
unifier: primitives.1,
|
||||||
|
class_method_to_def_id: Default::default(),
|
||||||
|
to_be_analyzed_class: Default::default(),
|
||||||
|
};
|
||||||
|
(
|
||||||
|
vec![
|
||||||
|
("int32".into(), DefinitionId(0), composer.primitives.int32),
|
||||||
|
("int64".into(), DefinitionId(1), composer.primitives.int64),
|
||||||
|
("float".into(), DefinitionId(2), composer.primitives.float),
|
||||||
|
("bool".into(), DefinitionId(3), composer.primitives.bool),
|
||||||
|
("none".into(), DefinitionId(4), composer.primitives.none),
|
||||||
|
],
|
||||||
|
composer,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// already include the definition_id of itself inside the ancestors vector
|
||||||
|
/// when first regitering, the type_vars, fields, methods, ancestors are invalid
|
||||||
|
pub fn make_top_level_class_def(
|
||||||
|
index: usize,
|
||||||
|
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>,
|
||||||
|
) -> TopLevelDef {
|
||||||
|
TopLevelDef::Class {
|
||||||
|
object_id: DefinitionId(index),
|
||||||
|
type_vars: Default::default(),
|
||||||
|
fields: Default::default(),
|
||||||
|
methods: Default::default(),
|
||||||
|
ancestors: vec![DefinitionId(index)],
|
||||||
|
resolver,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// when first registering, the type is a invalid value
|
||||||
|
pub fn make_top_level_function_def(
|
||||||
|
name: String,
|
||||||
|
ty: Type,
|
||||||
|
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>,
|
||||||
|
) -> TopLevelDef {
|
||||||
|
TopLevelDef::Function {
|
||||||
|
name,
|
||||||
|
signature: ty,
|
||||||
|
instance_to_symbol: Default::default(),
|
||||||
|
instance_to_stmt: Default::default(),
|
||||||
|
resolver,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// step 0, register, just remeber the names of top level classes/function
|
||||||
|
pub fn register_top_level(
|
||||||
|
&mut self,
|
||||||
|
ast: ast::Stmt<()>,
|
||||||
|
resolver: Option<Arc<Mutex<dyn SymbolResolver + Send + Sync>>>,
|
||||||
|
) -> Result<(String, DefinitionId), String> {
|
||||||
|
let mut def_list = self.definition_ast_list.write();
|
||||||
|
match &ast.node {
|
||||||
|
ast::StmtKind::ClassDef { name, body, .. } => {
|
||||||
|
let class_name = name.to_string();
|
||||||
|
let class_def_id = def_list.len();
|
||||||
|
|
||||||
|
// add the class to the definition lists
|
||||||
|
// since later when registering class method, ast will still be used,
|
||||||
|
// here push None temporarly, later will move the ast inside
|
||||||
|
let mut class_def_ast = (
|
||||||
|
Arc::new(RwLock::new(Self::make_top_level_class_def(
|
||||||
|
class_def_id,
|
||||||
|
resolver.clone(),
|
||||||
|
))),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
|
// parse class def body and register class methods into the def list.
|
||||||
|
// module's symbol resolver would not know the name of the class methods,
|
||||||
|
// thus cannot return their definition_id
|
||||||
|
let mut class_method_name_def_ids: Vec<(
|
||||||
|
String,
|
||||||
|
Arc<RwLock<TopLevelDef>>,
|
||||||
|
DefinitionId,
|
||||||
|
)> = Vec::new();
|
||||||
|
let mut class_method_index_offset = 0;
|
||||||
|
for b in body {
|
||||||
|
if let ast::StmtKind::FunctionDef { name: method_name, .. } = &b.node {
|
||||||
|
let method_name = Self::name_mangling(class_name.clone(), method_name);
|
||||||
|
let method_def_id = def_list.len() + {
|
||||||
|
class_method_index_offset += 1;
|
||||||
|
class_method_index_offset
|
||||||
|
};
|
||||||
|
|
||||||
|
// dummy method define here
|
||||||
|
// the ast of class method is in the class, push None in to the list here
|
||||||
|
class_method_name_def_ids.push((
|
||||||
|
method_name.clone(),
|
||||||
|
RwLock::new(Self::make_top_level_function_def(
|
||||||
|
method_name.clone(),
|
||||||
|
self.primitives.none,
|
||||||
|
resolver.clone(),
|
||||||
|
))
|
||||||
|
.into(),
|
||||||
|
DefinitionId(method_def_id),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// move the ast to the entry of the class in the ast_list
|
||||||
|
class_def_ast.1 = Some(ast);
|
||||||
|
|
||||||
|
// now class_def_ast and class_method_def_ast_ids are ok, put them into actual def list in correct order
|
||||||
|
def_list.push(class_def_ast);
|
||||||
|
for (name, def, id) in class_method_name_def_ids {
|
||||||
|
def_list.push((def, None));
|
||||||
|
self.class_method_to_def_id.insert(name, id);
|
||||||
|
}
|
||||||
|
|
||||||
|
// put the constructor into the def_list
|
||||||
|
def_list.push((
|
||||||
|
RwLock::new(TopLevelDef::Initializer { class_id: DefinitionId(class_def_id) })
|
||||||
|
.into(),
|
||||||
|
None,
|
||||||
|
));
|
||||||
|
|
||||||
|
// class, put its def_id into the to be analyzed set
|
||||||
|
self.to_be_analyzed_class.push(DefinitionId(class_def_id));
|
||||||
|
|
||||||
|
Ok((class_name, DefinitionId(class_def_id)))
|
||||||
|
}
|
||||||
|
|
||||||
|
ast::StmtKind::FunctionDef { name, .. } => {
|
||||||
|
let fun_name = name.to_string();
|
||||||
|
|
||||||
|
// add to the definition list
|
||||||
|
def_list.push((
|
||||||
|
RwLock::new(Self::make_top_level_function_def(
|
||||||
|
name.into(),
|
||||||
|
self.primitives.none,
|
||||||
|
resolver,
|
||||||
|
))
|
||||||
|
.into(),
|
||||||
|
Some(ast),
|
||||||
|
));
|
||||||
|
|
||||||
|
// return
|
||||||
|
Ok((fun_name, DefinitionId(def_list.len() - 1)))
|
||||||
|
}
|
||||||
|
|
||||||
|
_ => Err("only registrations of top level classes/functions are supprted".into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// step 1, analyze the type vars associated with top level class
|
||||||
|
fn analyze_top_level_class_type_var(&mut self) -> Result<(), String> {
|
||||||
|
let mut def_list = self.definition_ast_list.write();
|
||||||
|
let converted_top_level = &self.to_top_level_context();
|
||||||
|
let primitives = &self.primitives;
|
||||||
|
let unifier = &mut self.unifier;
|
||||||
|
|
||||||
|
for (class_def, class_ast) in def_list.iter_mut() {
|
||||||
|
// only deal with class def here
|
||||||
|
let mut class_def = class_def.write();
|
||||||
|
let (class_bases_ast, class_def_type_vars, class_resolver) = {
|
||||||
|
if let TopLevelDef::Class { type_vars, resolver, .. } = class_def.deref_mut() {
|
||||||
|
if let Some(ast::Located {
|
||||||
|
node: ast::StmtKind::ClassDef { bases, .. }, ..
|
||||||
|
}) = class_ast
|
||||||
|
{
|
||||||
|
(bases, type_vars, resolver)
|
||||||
|
} else {
|
||||||
|
unreachable!("must be both class")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let class_resolver = class_resolver.as_ref().unwrap().lock();
|
||||||
|
|
||||||
|
let mut is_generic = false;
|
||||||
|
for b in class_bases_ast {
|
||||||
|
match &b.node {
|
||||||
|
// analyze typevars bounded to the class,
|
||||||
|
// only support things like `class A(Generic[T, V])`,
|
||||||
|
// things like `class A(Generic[T, V, ImportedModule.T])` is not supported
|
||||||
|
// i.e. only simple names are allowed in the subscript
|
||||||
|
// should update the TopLevelDef::Class.typevars and the TypeEnum::TObj.params
|
||||||
|
ast::ExprKind::Subscript { value, slice, .. } if matches!(&value.node, ast::ExprKind::Name { id, .. } if id == "Generic") =>
|
||||||
|
{
|
||||||
|
if !is_generic {
|
||||||
|
is_generic = true;
|
||||||
|
} else {
|
||||||
|
return Err("Only single Generic[...] can be in bases".into());
|
||||||
|
}
|
||||||
|
|
||||||
|
// if `class A(Generic[T, V, G])`
|
||||||
|
if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
|
||||||
|
// parse the type vars
|
||||||
|
let type_vars = elts
|
||||||
|
.iter()
|
||||||
|
.map(|e| {
|
||||||
|
class_resolver.parse_type_annotation(
|
||||||
|
converted_top_level,
|
||||||
|
unifier.borrow_mut(),
|
||||||
|
primitives,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
|
||||||
|
// check if all are unique type vars
|
||||||
|
let mut occured_type_var_id: HashSet<u32> = HashSet::new();
|
||||||
|
let all_unique_type_var = type_vars.iter().all(|x| {
|
||||||
|
let ty = unifier.get_ty(*x);
|
||||||
|
if let TypeEnum::TVar { id, .. } = ty.as_ref() {
|
||||||
|
occured_type_var_id.insert(*id)
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
if !all_unique_type_var {
|
||||||
|
return Err("expect unique type variables".into());
|
||||||
|
}
|
||||||
|
|
||||||
|
// add to TopLevelDef
|
||||||
|
class_def_type_vars.extend(type_vars);
|
||||||
|
|
||||||
|
// `class A(Generic[T])`
|
||||||
|
} else {
|
||||||
|
let ty = class_resolver.parse_type_annotation(
|
||||||
|
converted_top_level,
|
||||||
|
unifier.borrow_mut(),
|
||||||
|
primitives,
|
||||||
|
&slice,
|
||||||
|
)?;
|
||||||
|
// check if it is type var
|
||||||
|
let is_type_var =
|
||||||
|
matches!(unifier.get_ty(ty).as_ref(), &TypeEnum::TVar { .. });
|
||||||
|
if !is_type_var {
|
||||||
|
return Err("expect type variable here".into());
|
||||||
|
}
|
||||||
|
|
||||||
|
// add to TopLevelDef
|
||||||
|
class_def_type_vars.push(ty);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if others, do nothing in this function
|
||||||
|
_ => continue,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// step 2, base classes. Need to separate step1 and step2 for this reason:
|
||||||
|
/// `class B(Generic[T, V]);
|
||||||
|
/// class A(B[int, bool])`
|
||||||
|
/// if the type var associated with class `B` has not been handled properly,
|
||||||
|
/// the parse of type annotation of `B[int, bool]` will fail
|
||||||
|
fn analyze_top_level_class_bases(&mut self) -> Result<(), String> {
|
||||||
|
let mut def_list = self.definition_ast_list.write();
|
||||||
|
let converted_top_level = &self.to_top_level_context();
|
||||||
|
let primitives = &self.primitives;
|
||||||
|
let unifier = &mut self.unifier;
|
||||||
|
|
||||||
|
for (class_def, class_ast) in def_list.iter_mut() {
|
||||||
|
let mut class_def = class_def.write();
|
||||||
|
let (class_bases, class_ancestors, class_resolver) = {
|
||||||
|
if let TopLevelDef::Class { ancestors, resolver, .. } = class_def.deref_mut() {
|
||||||
|
if let Some(ast::Located {
|
||||||
|
node: ast::StmtKind::ClassDef { bases, .. }, ..
|
||||||
|
}) = class_ast
|
||||||
|
{
|
||||||
|
(bases, ancestors, resolver)
|
||||||
|
} else {
|
||||||
|
unreachable!("must be both class")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let class_resolver = class_resolver.as_ref().unwrap().lock();
|
||||||
|
for b in class_bases {
|
||||||
|
// type vars have already been handled, so skip on `Generic[...]`
|
||||||
|
if let ast::ExprKind::Subscript { value, .. } = &b.node {
|
||||||
|
if let ast::ExprKind::Name { id, .. } = &value.node {
|
||||||
|
if id == "Generic" {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// get the def id of the base class
|
||||||
|
let base_ty = class_resolver.parse_type_annotation(
|
||||||
|
converted_top_level,
|
||||||
|
unifier.borrow_mut(),
|
||||||
|
primitives,
|
||||||
|
b,
|
||||||
|
)?;
|
||||||
|
let base_id =
|
||||||
|
if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(base_ty).as_ref() {
|
||||||
|
*obj_id
|
||||||
|
} else {
|
||||||
|
return Err("expect concrete class/type to be base class".into());
|
||||||
|
};
|
||||||
|
|
||||||
|
// write to the class ancestors, make sure the uniqueness
|
||||||
|
if !class_ancestors.contains(&base_id) {
|
||||||
|
class_ancestors.push(base_id);
|
||||||
|
} else {
|
||||||
|
return Err("cannot specify the same base class twice".into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// step 3, class fields and methods
|
||||||
|
// FIXME: analyze base classes here
|
||||||
|
// FIXME: deal with self type
|
||||||
|
// NOTE: prevent cycles only roughly done
|
||||||
|
fn analyze_top_level_class_fields_methods(&mut self) -> Result<(), String> {
|
||||||
|
let mut def_ast_list = self.definition_ast_list.write();
|
||||||
|
let converted_top_level = &self.to_top_level_context();
|
||||||
|
let primitives = &self.primitives;
|
||||||
|
let to_be_analyzed_class = &mut self.to_be_analyzed_class;
|
||||||
|
let unifier = &mut self.unifier;
|
||||||
|
|
||||||
|
// NOTE: roughly prevent infinite loop
|
||||||
|
let mut max_iter = to_be_analyzed_class.len() * 4;
|
||||||
|
'class: loop {
|
||||||
|
if to_be_analyzed_class.is_empty() && {
|
||||||
|
max_iter -= 1;
|
||||||
|
max_iter > 0
|
||||||
|
} {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let class_ind = to_be_analyzed_class.remove(0).0;
|
||||||
|
let (class_name, class_body_ast, class_bases_ast, class_resolver, class_ancestors) = {
|
||||||
|
let (class_def, class_ast) = &mut def_ast_list[class_ind];
|
||||||
|
if let Some(ast::Located {
|
||||||
|
node: ast::StmtKind::ClassDef { name, body, bases, .. },
|
||||||
|
..
|
||||||
|
}) = class_ast.as_ref()
|
||||||
|
{
|
||||||
|
if let TopLevelDef::Class { resolver, ancestors, .. } =
|
||||||
|
class_def.write().deref()
|
||||||
|
{
|
||||||
|
(name, body, bases, resolver.as_ref().unwrap().clone(), ancestors.clone())
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
unreachable!("should be class def ast")
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let all_base_class_analyzed = {
|
||||||
|
let not_yet_analyzed =
|
||||||
|
to_be_analyzed_class.clone().into_iter().collect::<HashSet<_>>();
|
||||||
|
let base = class_ancestors.clone().into_iter().collect::<HashSet<_>>();
|
||||||
|
let intersection = not_yet_analyzed.intersection(&base).collect_vec();
|
||||||
|
intersection.is_empty()
|
||||||
|
};
|
||||||
|
if !all_base_class_analyzed {
|
||||||
|
to_be_analyzed_class.push(DefinitionId(class_ind));
|
||||||
|
continue 'class;
|
||||||
|
}
|
||||||
|
|
||||||
|
// get the bases type, can directly do this since it
|
||||||
|
// already pass the check in the previous stages
|
||||||
|
let class_bases_ty = class_bases_ast
|
||||||
|
.iter()
|
||||||
|
.filter_map(|x| {
|
||||||
|
class_resolver
|
||||||
|
.as_ref()
|
||||||
|
.lock()
|
||||||
|
.parse_type_annotation(
|
||||||
|
converted_top_level,
|
||||||
|
unifier.borrow_mut(),
|
||||||
|
primitives,
|
||||||
|
x,
|
||||||
|
)
|
||||||
|
.ok()
|
||||||
|
})
|
||||||
|
.collect_vec();
|
||||||
|
|
||||||
|
// need these vectors to check re-defining methods, class fields
|
||||||
|
// and store the parsed result in case some method cannot be typed for now
|
||||||
|
let mut class_methods_parsing_result: Vec<(String, Type, DefinitionId)> = vec![];
|
||||||
|
let mut class_fields_parsing_result: Vec<(String, Type)> = vec![];
|
||||||
|
for b in class_body_ast {
|
||||||
|
if let ast::StmtKind::FunctionDef {
|
||||||
|
args: method_args_ast,
|
||||||
|
body: method_body_ast,
|
||||||
|
name: method_name,
|
||||||
|
returns: method_returns_ast,
|
||||||
|
..
|
||||||
|
} = &b.node
|
||||||
|
{
|
||||||
|
let arg_name_tys: Vec<(String, Type)> = {
|
||||||
|
let mut result = vec![];
|
||||||
|
for a in &method_args_ast.args {
|
||||||
|
if a.node.arg != "self" {
|
||||||
|
let annotation = a
|
||||||
|
.node
|
||||||
|
.annotation
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| {
|
||||||
|
"type annotation for function parameter is needed"
|
||||||
|
.to_string()
|
||||||
|
})?
|
||||||
|
.as_ref();
|
||||||
|
|
||||||
|
let ty = class_resolver.as_ref().lock().parse_type_annotation(
|
||||||
|
converted_top_level,
|
||||||
|
unifier.borrow_mut(),
|
||||||
|
primitives,
|
||||||
|
annotation,
|
||||||
|
)?;
|
||||||
|
if !Self::check_ty_analyzed(ty, unifier, to_be_analyzed_class) {
|
||||||
|
to_be_analyzed_class.push(DefinitionId(class_ind));
|
||||||
|
continue 'class;
|
||||||
|
}
|
||||||
|
result.push((a.node.arg.to_string(), ty));
|
||||||
|
} else {
|
||||||
|
// TODO: handle self, how
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result
|
||||||
|
};
|
||||||
|
|
||||||
|
let method_type_var = arg_name_tys
|
||||||
|
.iter()
|
||||||
|
.filter_map(|(_, ty)| {
|
||||||
|
let ty_enum = unifier.get_ty(*ty);
|
||||||
|
if let TypeEnum::TVar { id, .. } = ty_enum.as_ref() {
|
||||||
|
Some((*id, *ty))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Mapping<u32>>();
|
||||||
|
|
||||||
|
let ret_ty = {
|
||||||
|
if method_name != "__init__" {
|
||||||
|
let ty = method_returns_ast
|
||||||
|
.as_ref()
|
||||||
|
.map(|x| {
|
||||||
|
class_resolver.as_ref().lock().parse_type_annotation(
|
||||||
|
converted_top_level,
|
||||||
|
unifier.borrow_mut(),
|
||||||
|
primitives,
|
||||||
|
x.as_ref(),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.ok_or_else(|| "return type annotation error".to_string())??;
|
||||||
|
if !Self::check_ty_analyzed(ty, unifier, to_be_analyzed_class) {
|
||||||
|
to_be_analyzed_class.push(DefinitionId(class_ind));
|
||||||
|
continue 'class;
|
||||||
|
} else {
|
||||||
|
ty
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// TODO: __init__ function, self type, how
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// handle fields
|
||||||
|
let class_field_name_tys: Option<Vec<(String, Type)>> = if method_name
|
||||||
|
== "__init__"
|
||||||
|
{
|
||||||
|
let mut result: Vec<(String, Type)> = vec![];
|
||||||
|
for body in method_body_ast {
|
||||||
|
match &body.node {
|
||||||
|
ast::StmtKind::AnnAssign { target, annotation, .. }
|
||||||
|
if {
|
||||||
|
if let ast::ExprKind::Attribute { value, .. } = &target.node
|
||||||
|
{
|
||||||
|
matches!(
|
||||||
|
&value.node,
|
||||||
|
ast::ExprKind::Name { id, .. } if id == "self")
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
} =>
|
||||||
|
{
|
||||||
|
let field_ty =
|
||||||
|
class_resolver.as_ref().lock().parse_type_annotation(
|
||||||
|
converted_top_level,
|
||||||
|
unifier.borrow_mut(),
|
||||||
|
primitives,
|
||||||
|
annotation.as_ref(),
|
||||||
|
)?;
|
||||||
|
if !Self::check_ty_analyzed(
|
||||||
|
field_ty,
|
||||||
|
unifier,
|
||||||
|
to_be_analyzed_class,
|
||||||
|
) {
|
||||||
|
to_be_analyzed_class.push(DefinitionId(class_ind));
|
||||||
|
continue 'class;
|
||||||
|
} else {
|
||||||
|
result.push((
|
||||||
|
if let ast::ExprKind::Attribute { attr, .. } =
|
||||||
|
&target.node
|
||||||
|
{
|
||||||
|
attr.to_string()
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
},
|
||||||
|
field_ty,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// exclude those without type annotation
|
||||||
|
ast::StmtKind::Assign { targets, .. }
|
||||||
|
if {
|
||||||
|
if let ast::ExprKind::Attribute { value, .. } =
|
||||||
|
&targets[0].node
|
||||||
|
{
|
||||||
|
matches!(
|
||||||
|
&value.node,
|
||||||
|
ast::ExprKind::Name {id, ..} if id == "self")
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
} =>
|
||||||
|
{
|
||||||
|
return Err("class fields type annotation needed".into())
|
||||||
|
}
|
||||||
|
|
||||||
|
// do nothing
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Some(result)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
// current method all type ok, put the current method into the list
|
||||||
|
if class_methods_parsing_result.iter().any(|(name, _, _)| name == method_name) {
|
||||||
|
return Err("duplicate method definition".into());
|
||||||
|
} else {
|
||||||
|
class_methods_parsing_result.push((
|
||||||
|
method_name.clone(),
|
||||||
|
unifier.add_ty(TypeEnum::TFunc(
|
||||||
|
FunSignature {
|
||||||
|
ret: ret_ty,
|
||||||
|
args: arg_name_tys
|
||||||
|
.into_iter()
|
||||||
|
.map(|(name, ty)| FuncArg { name, ty, default_value: None })
|
||||||
|
.collect_vec(),
|
||||||
|
vars: method_type_var,
|
||||||
|
}
|
||||||
|
.into(),
|
||||||
|
)),
|
||||||
|
*self
|
||||||
|
.class_method_to_def_id
|
||||||
|
.get(&Self::name_mangling(class_name.clone(), method_name))
|
||||||
|
.unwrap(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
// put the fiedlds inside
|
||||||
|
if let Some(class_field_name_tys) = class_field_name_tys {
|
||||||
|
assert!(class_fields_parsing_result.is_empty());
|
||||||
|
class_fields_parsing_result.extend(class_field_name_tys);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// what should we do with `class A: a = 3`?
|
||||||
|
// do nothing, continue the for loop to iterate class ast
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// now it should be confirmed that every
|
||||||
|
// methods and fields of the class can be correctly typed, put the results
|
||||||
|
// into the actual class def method and fields field
|
||||||
|
let (class_def, _) = &def_ast_list[class_ind];
|
||||||
|
let mut class_def = class_def.write();
|
||||||
|
if let TopLevelDef::Class { fields, methods, .. } = class_def.deref_mut() {
|
||||||
|
for (ref n, ref t) in class_fields_parsing_result {
|
||||||
|
fields.push((n.clone(), *t));
|
||||||
|
}
|
||||||
|
for (n, t, id) in &class_methods_parsing_result {
|
||||||
|
methods.push((n.clone(), *t, *id));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
}
|
||||||
|
|
||||||
|
// change the signature field of the class methods
|
||||||
|
for (_, ty, id) in &class_methods_parsing_result {
|
||||||
|
let (method_def, _) = &def_ast_list[id.0];
|
||||||
|
let mut method_def = method_def.write();
|
||||||
|
if let TopLevelDef::Function { signature, .. } = method_def.deref_mut() {
|
||||||
|
*signature = *ty;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn analyze_top_level_function(&mut self) -> Result<(), String> {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn analyze_top_level_field_instantiation(&mut self) -> Result<(), String> {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn check_ty_analyzed(ty: Type, unifier: &mut Unifier, to_be_analyzed: &[DefinitionId]) -> bool {
|
||||||
|
let type_enum = unifier.get_ty(ty);
|
||||||
|
match type_enum.as_ref() {
|
||||||
|
TypeEnum::TObj { obj_id, .. } => !to_be_analyzed.contains(obj_id),
|
||||||
|
TypeEnum::TVirtual { ty } => {
|
||||||
|
if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(*ty).as_ref() {
|
||||||
|
!to_be_analyzed.contains(obj_id)
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TypeEnum::TVar { .. } => true,
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
216
nac3core/src/typecheck/function_check.rs
Normal file
216
nac3core/src/typecheck/function_check.rs
Normal file
@ -0,0 +1,216 @@
|
|||||||
|
use super::type_inferencer::Inferencer;
|
||||||
|
use super::typedef::Type;
|
||||||
|
use rustpython_parser::ast::{self, Expr, ExprKind, Stmt, StmtKind};
|
||||||
|
use std::iter::once;
|
||||||
|
|
||||||
|
impl<'a> Inferencer<'a> {
|
||||||
|
fn check_pattern(
|
||||||
|
&mut self,
|
||||||
|
pattern: &Expr<Option<Type>>,
|
||||||
|
defined_identifiers: &mut Vec<String>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
match &pattern.node {
|
||||||
|
ExprKind::Name { id, .. } => {
|
||||||
|
if !defined_identifiers.contains(id) {
|
||||||
|
defined_identifiers.push(id.clone());
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
ExprKind::Tuple { elts, .. } => {
|
||||||
|
for elt in elts.iter() {
|
||||||
|
self.check_pattern(elt, defined_identifiers)?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
_ => self.check_expr(pattern, defined_identifiers),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn check_expr(
|
||||||
|
&mut self,
|
||||||
|
expr: &Expr<Option<Type>>,
|
||||||
|
defined_identifiers: &[String],
|
||||||
|
) -> Result<(), String> {
|
||||||
|
// there are some cases where the custom field is None
|
||||||
|
if let Some(ty) = &expr.custom {
|
||||||
|
if !self.unifier.is_concrete(*ty, &self.function_data.bound_variables) {
|
||||||
|
return Err(format!(
|
||||||
|
"expected concrete type at {} but got {}",
|
||||||
|
expr.location,
|
||||||
|
self.unifier.get_ty(*ty).get_type_name()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
match &expr.node {
|
||||||
|
ExprKind::Name { id, .. } => {
|
||||||
|
if !defined_identifiers.contains(id) {
|
||||||
|
return Err(format!(
|
||||||
|
"unknown identifier {} (use before def?) at {}",
|
||||||
|
id, expr.location
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ExprKind::List { elts, .. }
|
||||||
|
| ExprKind::Tuple { elts, .. }
|
||||||
|
| ExprKind::BoolOp { values: elts, .. } => {
|
||||||
|
for elt in elts.iter() {
|
||||||
|
self.check_expr(elt, defined_identifiers)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ExprKind::Attribute { value, .. } => {
|
||||||
|
self.check_expr(value, defined_identifiers)?;
|
||||||
|
}
|
||||||
|
ExprKind::BinOp { left, right, .. } => {
|
||||||
|
self.check_expr(left, defined_identifiers)?;
|
||||||
|
self.check_expr(right, defined_identifiers)?;
|
||||||
|
}
|
||||||
|
ExprKind::UnaryOp { operand, .. } => {
|
||||||
|
self.check_expr(operand, defined_identifiers)?;
|
||||||
|
}
|
||||||
|
ExprKind::Compare { left, comparators, .. } => {
|
||||||
|
for elt in once(left.as_ref()).chain(comparators.iter()) {
|
||||||
|
self.check_expr(elt, defined_identifiers)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ExprKind::Subscript { value, slice, .. } => {
|
||||||
|
self.check_expr(value, defined_identifiers)?;
|
||||||
|
self.check_expr(slice, defined_identifiers)?;
|
||||||
|
}
|
||||||
|
ExprKind::IfExp { test, body, orelse } => {
|
||||||
|
self.check_expr(test, defined_identifiers)?;
|
||||||
|
self.check_expr(body, defined_identifiers)?;
|
||||||
|
self.check_expr(orelse, defined_identifiers)?;
|
||||||
|
}
|
||||||
|
ExprKind::Slice { lower, upper, step } => {
|
||||||
|
for elt in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() {
|
||||||
|
self.check_expr(elt, defined_identifiers)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ExprKind::Lambda { args, body } => {
|
||||||
|
let mut defined_identifiers = defined_identifiers.to_vec();
|
||||||
|
for arg in args.args.iter() {
|
||||||
|
if !defined_identifiers.contains(&arg.node.arg) {
|
||||||
|
defined_identifiers.push(arg.node.arg.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.check_expr(body, &defined_identifiers)?;
|
||||||
|
}
|
||||||
|
ExprKind::ListComp { elt, generators, .. } => {
|
||||||
|
// in our type inference stage, we already make sure that there is only 1 generator
|
||||||
|
let ast::Comprehension { target, iter, ifs, .. } = &generators[0];
|
||||||
|
self.check_expr(iter, defined_identifiers)?;
|
||||||
|
let mut defined_identifiers = defined_identifiers.to_vec();
|
||||||
|
self.check_pattern(target, &mut defined_identifiers)?;
|
||||||
|
for term in once(elt.as_ref()).chain(ifs.iter()) {
|
||||||
|
self.check_expr(term, &defined_identifiers)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ExprKind::Call { func, args, keywords } => {
|
||||||
|
for expr in once(func.as_ref())
|
||||||
|
.chain(args.iter())
|
||||||
|
.chain(keywords.iter().map(|v| v.node.value.as_ref()))
|
||||||
|
{
|
||||||
|
self.check_expr(expr, defined_identifiers)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ExprKind::Constant { .. } => {}
|
||||||
|
_ => {
|
||||||
|
println!("{:?}", expr.node);
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// check statements for proper identifier def-use and return on all paths
|
||||||
|
fn check_stmt(
|
||||||
|
&mut self,
|
||||||
|
stmt: &Stmt<Option<Type>>,
|
||||||
|
defined_identifiers: &mut Vec<String>,
|
||||||
|
) -> Result<bool, String> {
|
||||||
|
match &stmt.node {
|
||||||
|
StmtKind::For { target, iter, body, orelse, .. } => {
|
||||||
|
self.check_expr(iter, defined_identifiers)?;
|
||||||
|
for stmt in orelse.iter() {
|
||||||
|
self.check_stmt(stmt, defined_identifiers)?;
|
||||||
|
}
|
||||||
|
let mut defined_identifiers = defined_identifiers.clone();
|
||||||
|
self.check_pattern(target, &mut defined_identifiers)?;
|
||||||
|
for stmt in body.iter() {
|
||||||
|
self.check_stmt(stmt, &mut defined_identifiers)?;
|
||||||
|
}
|
||||||
|
Ok(false)
|
||||||
|
}
|
||||||
|
StmtKind::If { test, body, orelse } => {
|
||||||
|
self.check_expr(test, defined_identifiers)?;
|
||||||
|
let mut body_identifiers = defined_identifiers.clone();
|
||||||
|
let mut orelse_identifiers = defined_identifiers.clone();
|
||||||
|
let body_returned = self.check_block(body, &mut body_identifiers)?;
|
||||||
|
let orelse_returned = self.check_block(orelse, &mut orelse_identifiers)?;
|
||||||
|
|
||||||
|
for ident in body_identifiers.iter() {
|
||||||
|
if !defined_identifiers.contains(ident) && orelse_identifiers.contains(ident) {
|
||||||
|
defined_identifiers.push(ident.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(body_returned && orelse_returned)
|
||||||
|
}
|
||||||
|
StmtKind::While { test, body, orelse } => {
|
||||||
|
self.check_expr(test, defined_identifiers)?;
|
||||||
|
let mut defined_identifiers = defined_identifiers.clone();
|
||||||
|
self.check_block(body, &mut defined_identifiers)?;
|
||||||
|
self.check_block(orelse, &mut defined_identifiers)?;
|
||||||
|
Ok(false)
|
||||||
|
}
|
||||||
|
StmtKind::Expr { value } => {
|
||||||
|
self.check_expr(value, defined_identifiers)?;
|
||||||
|
Ok(false)
|
||||||
|
}
|
||||||
|
StmtKind::Assign { targets, value, .. } => {
|
||||||
|
self.check_expr(value, defined_identifiers)?;
|
||||||
|
for target in targets {
|
||||||
|
self.check_pattern(target, defined_identifiers)?;
|
||||||
|
}
|
||||||
|
Ok(false)
|
||||||
|
}
|
||||||
|
StmtKind::AnnAssign { target, value, .. } => {
|
||||||
|
if let Some(value) = value {
|
||||||
|
self.check_expr(value, defined_identifiers)?;
|
||||||
|
self.check_pattern(target, defined_identifiers)?;
|
||||||
|
}
|
||||||
|
Ok(false)
|
||||||
|
}
|
||||||
|
StmtKind::Return { value } => {
|
||||||
|
if let Some(value) = value {
|
||||||
|
self.check_expr(value, defined_identifiers)?;
|
||||||
|
}
|
||||||
|
Ok(true)
|
||||||
|
}
|
||||||
|
StmtKind::Raise { exc, .. } => {
|
||||||
|
if let Some(value) = exc {
|
||||||
|
self.check_expr(value, defined_identifiers)?;
|
||||||
|
}
|
||||||
|
Ok(true)
|
||||||
|
}
|
||||||
|
// break, raise, etc.
|
||||||
|
_ => Ok(false),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn check_block(
|
||||||
|
&mut self,
|
||||||
|
block: &[Stmt<Option<Type>>],
|
||||||
|
defined_identifiers: &mut Vec<String>,
|
||||||
|
) -> Result<bool, String> {
|
||||||
|
let mut ret = false;
|
||||||
|
for stmt in block {
|
||||||
|
if ret {
|
||||||
|
return Err(format!("dead code at {:?}", stmt.location));
|
||||||
|
}
|
||||||
|
if self.check_stmt(stmt, defined_identifiers)? {
|
||||||
|
ret = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(ret)
|
||||||
|
}
|
||||||
|
}
|
322
nac3core/src/typecheck/magic_methods.rs
Normal file
322
nac3core/src/typecheck/magic_methods.rs
Normal file
@ -0,0 +1,322 @@
|
|||||||
|
use crate::typecheck::{
|
||||||
|
type_inferencer::*,
|
||||||
|
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier},
|
||||||
|
};
|
||||||
|
use rustpython_parser::ast;
|
||||||
|
use rustpython_parser::ast::{Cmpop, Operator, Unaryop};
|
||||||
|
use std::borrow::Borrow;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
pub fn binop_name(op: &Operator) -> &'static str {
|
||||||
|
match op {
|
||||||
|
Operator::Add => "__add__",
|
||||||
|
Operator::Sub => "__sub__",
|
||||||
|
Operator::Div => "__truediv__",
|
||||||
|
Operator::Mod => "__mod__",
|
||||||
|
Operator::Mult => "__mul__",
|
||||||
|
Operator::Pow => "__pow__",
|
||||||
|
Operator::BitOr => "__or__",
|
||||||
|
Operator::BitXor => "__xor__",
|
||||||
|
Operator::BitAnd => "__and__",
|
||||||
|
Operator::LShift => "__lshift__",
|
||||||
|
Operator::RShift => "__rshift__",
|
||||||
|
Operator::FloorDiv => "__floordiv__",
|
||||||
|
Operator::MatMult => "__matmul__",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn binop_assign_name(op: &Operator) -> &'static str {
|
||||||
|
match op {
|
||||||
|
Operator::Add => "__iadd__",
|
||||||
|
Operator::Sub => "__isub__",
|
||||||
|
Operator::Div => "__itruediv__",
|
||||||
|
Operator::Mod => "__imod__",
|
||||||
|
Operator::Mult => "__imul__",
|
||||||
|
Operator::Pow => "__ipow__",
|
||||||
|
Operator::BitOr => "__ior__",
|
||||||
|
Operator::BitXor => "__ixor__",
|
||||||
|
Operator::BitAnd => "__iand__",
|
||||||
|
Operator::LShift => "__ilshift__",
|
||||||
|
Operator::RShift => "__irshift__",
|
||||||
|
Operator::FloorDiv => "__ifloordiv__",
|
||||||
|
Operator::MatMult => "__imatmul__",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn unaryop_name(op: &Unaryop) -> &'static str {
|
||||||
|
match op {
|
||||||
|
Unaryop::UAdd => "__pos__",
|
||||||
|
Unaryop::USub => "__neg__",
|
||||||
|
Unaryop::Not => "__not__",
|
||||||
|
Unaryop::Invert => "__inv__",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn comparison_name(op: &Cmpop) -> Option<&'static str> {
|
||||||
|
match op {
|
||||||
|
Cmpop::Lt => Some("__lt__"),
|
||||||
|
Cmpop::LtE => Some("__le__"),
|
||||||
|
Cmpop::Gt => Some("__gt__"),
|
||||||
|
Cmpop::GtE => Some("__ge__"),
|
||||||
|
Cmpop::Eq => Some("__eq__"),
|
||||||
|
Cmpop::NotEq => Some("__ne__"),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn impl_binop(
|
||||||
|
unifier: &mut Unifier,
|
||||||
|
store: &PrimitiveStore,
|
||||||
|
ty: Type,
|
||||||
|
other_ty: &[Type],
|
||||||
|
ret_ty: Type,
|
||||||
|
ops: &[ast::Operator],
|
||||||
|
) {
|
||||||
|
if let TypeEnum::TObj { fields, .. } = unifier.get_ty(ty).borrow() {
|
||||||
|
let (other_ty, other_var_id) = if other_ty.len() == 1 {
|
||||||
|
(other_ty[0], None)
|
||||||
|
} else {
|
||||||
|
let (ty, var_id) = unifier.get_fresh_var_with_range(other_ty);
|
||||||
|
(ty, Some(var_id))
|
||||||
|
};
|
||||||
|
let function_vars = if let Some(var_id) = other_var_id {
|
||||||
|
vec![(var_id, other_ty)].into_iter().collect::<HashMap<_, _>>()
|
||||||
|
} else {
|
||||||
|
HashMap::new()
|
||||||
|
};
|
||||||
|
for op in ops {
|
||||||
|
fields.borrow_mut().insert(binop_name(op).into(), {
|
||||||
|
unifier.add_ty(TypeEnum::TFunc(
|
||||||
|
FunSignature {
|
||||||
|
ret: ret_ty,
|
||||||
|
vars: function_vars.clone(),
|
||||||
|
args: vec![FuncArg {
|
||||||
|
ty: other_ty,
|
||||||
|
default_value: None,
|
||||||
|
name: "other".into(),
|
||||||
|
}],
|
||||||
|
}
|
||||||
|
.into(),
|
||||||
|
))
|
||||||
|
});
|
||||||
|
|
||||||
|
fields.borrow_mut().insert(binop_assign_name(op).into(), {
|
||||||
|
unifier.add_ty(TypeEnum::TFunc(
|
||||||
|
FunSignature {
|
||||||
|
ret: store.none,
|
||||||
|
vars: function_vars.clone(),
|
||||||
|
args: vec![FuncArg {
|
||||||
|
ty: other_ty,
|
||||||
|
default_value: None,
|
||||||
|
name: "other".into(),
|
||||||
|
}],
|
||||||
|
}
|
||||||
|
.into(),
|
||||||
|
))
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
unreachable!("")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn impl_unaryop(
|
||||||
|
unifier: &mut Unifier,
|
||||||
|
_store: &PrimitiveStore,
|
||||||
|
ty: Type,
|
||||||
|
ret_ty: Type,
|
||||||
|
ops: &[ast::Unaryop],
|
||||||
|
) {
|
||||||
|
if let TypeEnum::TObj { fields, .. } = unifier.get_ty(ty).borrow() {
|
||||||
|
for op in ops {
|
||||||
|
fields.borrow_mut().insert(
|
||||||
|
unaryop_name(op).into(),
|
||||||
|
unifier.add_ty(TypeEnum::TFunc(
|
||||||
|
FunSignature { ret: ret_ty, vars: HashMap::new(), args: vec![] }.into(),
|
||||||
|
)),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn impl_cmpop(
|
||||||
|
unifier: &mut Unifier,
|
||||||
|
store: &PrimitiveStore,
|
||||||
|
ty: Type,
|
||||||
|
other_ty: Type,
|
||||||
|
ops: &[ast::Cmpop],
|
||||||
|
) {
|
||||||
|
if let TypeEnum::TObj { fields, .. } = unifier.get_ty(ty).borrow() {
|
||||||
|
for op in ops {
|
||||||
|
fields.borrow_mut().insert(
|
||||||
|
comparison_name(op).unwrap().into(),
|
||||||
|
unifier.add_ty(TypeEnum::TFunc(
|
||||||
|
FunSignature {
|
||||||
|
ret: store.bool,
|
||||||
|
vars: HashMap::new(),
|
||||||
|
args: vec![FuncArg {
|
||||||
|
ty: other_ty,
|
||||||
|
default_value: None,
|
||||||
|
name: "other".into(),
|
||||||
|
}],
|
||||||
|
}
|
||||||
|
.into(),
|
||||||
|
)),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add, Sub, Mult, Pow
|
||||||
|
pub fn impl_basic_arithmetic(
|
||||||
|
unifier: &mut Unifier,
|
||||||
|
store: &PrimitiveStore,
|
||||||
|
ty: Type,
|
||||||
|
other_ty: &[Type],
|
||||||
|
ret_ty: Type,
|
||||||
|
) {
|
||||||
|
impl_binop(
|
||||||
|
unifier,
|
||||||
|
store,
|
||||||
|
ty,
|
||||||
|
other_ty,
|
||||||
|
ret_ty,
|
||||||
|
&[ast::Operator::Add, ast::Operator::Sub, ast::Operator::Mult],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn impl_pow(
|
||||||
|
unifier: &mut Unifier,
|
||||||
|
store: &PrimitiveStore,
|
||||||
|
ty: Type,
|
||||||
|
other_ty: &[Type],
|
||||||
|
ret_ty: Type,
|
||||||
|
) {
|
||||||
|
impl_binop(unifier, store, ty, other_ty, ret_ty, &[ast::Operator::Pow])
|
||||||
|
}
|
||||||
|
|
||||||
|
/// BitOr, BitXor, BitAnd
|
||||||
|
pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
|
||||||
|
impl_binop(
|
||||||
|
unifier,
|
||||||
|
store,
|
||||||
|
ty,
|
||||||
|
&[ty],
|
||||||
|
ty,
|
||||||
|
&[ast::Operator::BitAnd, ast::Operator::BitOr, ast::Operator::BitXor],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// LShift, RShift
|
||||||
|
pub fn impl_bitwise_shift(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
|
||||||
|
impl_binop(unifier, store, ty, &[ty], ty, &[ast::Operator::LShift, ast::Operator::RShift])
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Div
|
||||||
|
pub fn impl_div(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type]) {
|
||||||
|
impl_binop(unifier, store, ty, other_ty, store.float, &[ast::Operator::Div])
|
||||||
|
}
|
||||||
|
|
||||||
|
/// FloorDiv
|
||||||
|
pub fn impl_floordiv(
|
||||||
|
unifier: &mut Unifier,
|
||||||
|
store: &PrimitiveStore,
|
||||||
|
ty: Type,
|
||||||
|
other_ty: &[Type],
|
||||||
|
ret_ty: Type,
|
||||||
|
) {
|
||||||
|
impl_binop(unifier, store, ty, other_ty, ret_ty, &[ast::Operator::FloorDiv])
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mod
|
||||||
|
pub fn impl_mod(
|
||||||
|
unifier: &mut Unifier,
|
||||||
|
store: &PrimitiveStore,
|
||||||
|
ty: Type,
|
||||||
|
other_ty: &[Type],
|
||||||
|
ret_ty: Type,
|
||||||
|
) {
|
||||||
|
impl_binop(unifier, store, ty, other_ty, ret_ty, &[ast::Operator::Mod])
|
||||||
|
}
|
||||||
|
|
||||||
|
/// UAdd, USub
|
||||||
|
pub fn impl_sign(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
|
||||||
|
impl_unaryop(unifier, store, ty, ty, &[ast::Unaryop::UAdd, ast::Unaryop::USub])
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Invert
|
||||||
|
pub fn impl_invert(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
|
||||||
|
impl_unaryop(unifier, store, ty, ty, &[ast::Unaryop::Invert])
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Not
|
||||||
|
pub fn impl_not(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
|
||||||
|
impl_unaryop(unifier, store, ty, store.bool, &[ast::Unaryop::Not])
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Lt, LtE, Gt, GtE
|
||||||
|
pub fn impl_comparison(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type) {
|
||||||
|
impl_cmpop(
|
||||||
|
unifier,
|
||||||
|
store,
|
||||||
|
ty,
|
||||||
|
other_ty,
|
||||||
|
&[ast::Cmpop::Lt, ast::Cmpop::Gt, ast::Cmpop::LtE, ast::Cmpop::GtE],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Eq, NotEq
|
||||||
|
pub fn impl_eq(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
|
||||||
|
impl_cmpop(unifier, store, ty, ty, &[ast::Cmpop::Eq, ast::Cmpop::NotEq])
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) {
|
||||||
|
let PrimitiveStore { int32: int32_t, int64: int64_t, float: float_t, bool: bool_t, .. } =
|
||||||
|
*store;
|
||||||
|
/* int32 ======== */
|
||||||
|
impl_basic_arithmetic(unifier, store, int32_t, &[int32_t], int32_t);
|
||||||
|
impl_pow(unifier, store, int32_t, &[int32_t], int32_t);
|
||||||
|
impl_bitwise_arithmetic(unifier, store, int32_t);
|
||||||
|
impl_bitwise_shift(unifier, store, int32_t);
|
||||||
|
impl_div(unifier, store, int32_t, &[int32_t]);
|
||||||
|
impl_floordiv(unifier, store, int32_t, &[int32_t], int32_t);
|
||||||
|
impl_mod(unifier, store, int32_t, &[int32_t], int32_t);
|
||||||
|
impl_sign(unifier, store, int32_t);
|
||||||
|
impl_invert(unifier, store, int32_t);
|
||||||
|
impl_not(unifier, store, int32_t);
|
||||||
|
impl_comparison(unifier, store, int32_t, int32_t);
|
||||||
|
impl_eq(unifier, store, int32_t);
|
||||||
|
|
||||||
|
/* int64 ======== */
|
||||||
|
impl_basic_arithmetic(unifier, store, int64_t, &[int64_t], int64_t);
|
||||||
|
impl_pow(unifier, store, int64_t, &[int64_t], int64_t);
|
||||||
|
impl_bitwise_arithmetic(unifier, store, int64_t);
|
||||||
|
impl_bitwise_shift(unifier, store, int64_t);
|
||||||
|
impl_div(unifier, store, int64_t, &[int64_t]);
|
||||||
|
impl_floordiv(unifier, store, int64_t, &[int64_t], int64_t);
|
||||||
|
impl_mod(unifier, store, int64_t, &[int64_t], int64_t);
|
||||||
|
impl_sign(unifier, store, int64_t);
|
||||||
|
impl_invert(unifier, store, int64_t);
|
||||||
|
impl_not(unifier, store, int64_t);
|
||||||
|
impl_comparison(unifier, store, int64_t, int64_t);
|
||||||
|
impl_eq(unifier, store, int64_t);
|
||||||
|
|
||||||
|
/* float ======== */
|
||||||
|
impl_basic_arithmetic(unifier, store, float_t, &[float_t], float_t);
|
||||||
|
impl_pow(unifier, store, float_t, &[int32_t, float_t], float_t);
|
||||||
|
impl_div(unifier, store, float_t, &[float_t]);
|
||||||
|
impl_floordiv(unifier, store, float_t, &[float_t], float_t);
|
||||||
|
impl_mod(unifier, store, float_t, &[float_t], float_t);
|
||||||
|
impl_sign(unifier, store, float_t);
|
||||||
|
impl_not(unifier, store, float_t);
|
||||||
|
impl_comparison(unifier, store, float_t, float_t);
|
||||||
|
impl_eq(unifier, store, float_t);
|
||||||
|
|
||||||
|
/* bool ======== */
|
||||||
|
impl_not(unifier, store, bool_t);
|
||||||
|
impl_eq(unifier, store, bool_t);
|
||||||
|
}
|
5
nac3core/src/typecheck/mod.rs
Normal file
5
nac3core/src/typecheck/mod.rs
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
mod function_check;
|
||||||
|
pub mod magic_methods;
|
||||||
|
pub mod type_inferencer;
|
||||||
|
pub mod typedef;
|
||||||
|
mod unification_table;
|
582
nac3core/src/typecheck/type_inferencer/mod.rs
Normal file
582
nac3core/src/typecheck/type_inferencer/mod.rs
Normal file
@ -0,0 +1,582 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use std::convert::{From, TryInto};
|
||||||
|
use std::iter::once;
|
||||||
|
use std::{cell::RefCell, sync::Arc};
|
||||||
|
|
||||||
|
use super::typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier};
|
||||||
|
use super::{magic_methods::*, typedef::CallId};
|
||||||
|
use crate::{symbol_resolver::SymbolResolver, top_level::TopLevelContext};
|
||||||
|
use itertools::izip;
|
||||||
|
use rustpython_parser::ast::{
|
||||||
|
self,
|
||||||
|
fold::{self, Fold},
|
||||||
|
Arguments, Comprehension, ExprKind, Located, Location,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod test;
|
||||||
|
|
||||||
|
#[derive(PartialEq, Eq, Hash, Copy, Clone, Debug)]
|
||||||
|
pub struct CodeLocation {
|
||||||
|
row: usize,
|
||||||
|
col: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Location> for CodeLocation {
|
||||||
|
fn from(loc: Location) -> CodeLocation {
|
||||||
|
CodeLocation { row: loc.row(), col: loc.column() }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
pub struct PrimitiveStore {
|
||||||
|
pub int32: Type,
|
||||||
|
pub int64: Type,
|
||||||
|
pub float: Type,
|
||||||
|
pub bool: Type,
|
||||||
|
pub none: Type,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct FunctionData {
|
||||||
|
pub resolver: Arc<dyn SymbolResolver + Send + Sync>,
|
||||||
|
pub return_type: Option<Type>,
|
||||||
|
pub bound_variables: Vec<Type>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Inferencer<'a> {
|
||||||
|
pub top_level: &'a TopLevelContext,
|
||||||
|
pub function_data: &'a mut FunctionData,
|
||||||
|
pub unifier: &'a mut Unifier,
|
||||||
|
pub primitives: &'a PrimitiveStore,
|
||||||
|
pub virtual_checks: &'a mut Vec<(Type, Type)>,
|
||||||
|
pub variable_mapping: HashMap<String, Type>,
|
||||||
|
pub calls: &'a mut HashMap<CodeLocation, CallId>,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct NaiveFolder();
|
||||||
|
impl fold::Fold<()> for NaiveFolder {
|
||||||
|
type TargetU = Option<Type>;
|
||||||
|
type Error = String;
|
||||||
|
fn map_user(&mut self, _: ()) -> Result<Self::TargetU, Self::Error> {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> fold::Fold<()> for Inferencer<'a> {
|
||||||
|
type TargetU = Option<Type>;
|
||||||
|
type Error = String;
|
||||||
|
|
||||||
|
fn map_user(&mut self, _: ()) -> Result<Self::TargetU, Self::Error> {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn fold_stmt(&mut self, node: ast::Stmt<()>) -> Result<ast::Stmt<Self::TargetU>, Self::Error> {
|
||||||
|
let stmt = match node.node {
|
||||||
|
// we don't want fold over type annotation
|
||||||
|
ast::StmtKind::AnnAssign { target, annotation, value, simple } => {
|
||||||
|
let target = Box::new(self.fold_expr(*target)?);
|
||||||
|
let value = if let Some(v) = value {
|
||||||
|
let ty = Box::new(self.fold_expr(*v)?);
|
||||||
|
self.unifier.unify(target.custom.unwrap(), ty.custom.unwrap())?;
|
||||||
|
Some(ty)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let annotation_type = self.function_data.resolver.parse_type_annotation(
|
||||||
|
self.top_level,
|
||||||
|
self.unifier,
|
||||||
|
&self.primitives,
|
||||||
|
annotation.as_ref(),
|
||||||
|
)?;
|
||||||
|
self.unifier.unify(annotation_type, target.custom.unwrap())?;
|
||||||
|
let annotation = Box::new(NaiveFolder().fold_expr(*annotation)?);
|
||||||
|
Located {
|
||||||
|
location: node.location,
|
||||||
|
custom: None,
|
||||||
|
node: ast::StmtKind::AnnAssign { target, annotation, value, simple },
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => fold::fold_stmt(self, node)?,
|
||||||
|
};
|
||||||
|
match &stmt.node {
|
||||||
|
ast::StmtKind::For { target, iter, .. } => {
|
||||||
|
let list = self.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() });
|
||||||
|
self.unifier.unify(list, iter.custom.unwrap())?;
|
||||||
|
}
|
||||||
|
ast::StmtKind::If { test, .. } | ast::StmtKind::While { test, .. } => {
|
||||||
|
self.unifier.unify(test.custom.unwrap(), self.primitives.bool)?;
|
||||||
|
}
|
||||||
|
ast::StmtKind::Assign { targets, value, .. } => {
|
||||||
|
for target in targets.iter() {
|
||||||
|
self.unifier.unify(target.custom.unwrap(), value.custom.unwrap())?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ast::StmtKind::AnnAssign { .. } | ast::StmtKind::Expr { .. } => {}
|
||||||
|
ast::StmtKind::Break | ast::StmtKind::Continue => {}
|
||||||
|
ast::StmtKind::Return { value } => match (value, self.function_data.return_type) {
|
||||||
|
(Some(v), Some(v1)) => {
|
||||||
|
self.unifier.unify(v.custom.unwrap(), v1)?;
|
||||||
|
}
|
||||||
|
(Some(_), None) => {
|
||||||
|
return Err("Unexpected return value".to_string());
|
||||||
|
}
|
||||||
|
(None, Some(_)) => {
|
||||||
|
return Err("Expected return value".to_string());
|
||||||
|
}
|
||||||
|
(None, None) => {}
|
||||||
|
},
|
||||||
|
_ => return Err("Unsupported statement type".to_string()),
|
||||||
|
};
|
||||||
|
Ok(stmt)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn fold_expr(&mut self, node: ast::Expr<()>) -> Result<ast::Expr<Self::TargetU>, Self::Error> {
|
||||||
|
let expr = match node.node {
|
||||||
|
ast::ExprKind::Call { func, args, keywords } => {
|
||||||
|
return self.fold_call(node.location, *func, args, keywords);
|
||||||
|
}
|
||||||
|
ast::ExprKind::Lambda { args, body } => {
|
||||||
|
return self.fold_lambda(node.location, *args, *body);
|
||||||
|
}
|
||||||
|
ast::ExprKind::ListComp { elt, generators } => {
|
||||||
|
return self.fold_listcomp(node.location, *elt, generators);
|
||||||
|
}
|
||||||
|
_ => fold::fold_expr(self, node)?,
|
||||||
|
};
|
||||||
|
let custom = match &expr.node {
|
||||||
|
ast::ExprKind::Constant { value, .. } => Some(self.infer_constant(value)?),
|
||||||
|
ast::ExprKind::Name { id, .. } => Some(self.infer_identifier(id)?),
|
||||||
|
ast::ExprKind::List { elts, .. } => Some(self.infer_list(elts)?),
|
||||||
|
ast::ExprKind::Tuple { elts, .. } => Some(self.infer_tuple(elts)?),
|
||||||
|
ast::ExprKind::Attribute { value, attr, ctx: _ } => {
|
||||||
|
Some(self.infer_attribute(value, attr)?)
|
||||||
|
}
|
||||||
|
ast::ExprKind::BoolOp { values, .. } => Some(self.infer_bool_ops(values)?),
|
||||||
|
ast::ExprKind::BinOp { left, op, right } => Some(self.infer_bin_ops(left, op, right)?),
|
||||||
|
ast::ExprKind::UnaryOp { op, operand } => Some(self.infer_unary_ops(op, operand)?),
|
||||||
|
ast::ExprKind::Compare { left, ops, comparators } => {
|
||||||
|
Some(self.infer_compare(left, ops, comparators)?)
|
||||||
|
}
|
||||||
|
ast::ExprKind::Subscript { value, slice, .. } => {
|
||||||
|
Some(self.infer_subscript(value.as_ref(), slice.as_ref())?)
|
||||||
|
}
|
||||||
|
ast::ExprKind::IfExp { test, body, orelse } => {
|
||||||
|
Some(self.infer_if_expr(test, body.as_ref(), orelse.as_ref())?)
|
||||||
|
}
|
||||||
|
ast::ExprKind::ListComp { .. }
|
||||||
|
| ast::ExprKind::Lambda { .. }
|
||||||
|
| ast::ExprKind::Call { .. } => expr.custom, // already computed
|
||||||
|
ast::ExprKind::Slice { .. } => None, // we don't need it for slice
|
||||||
|
_ => return Err("not supported yet".into()),
|
||||||
|
};
|
||||||
|
Ok(ast::Expr { custom, location: expr.location, node: expr.node })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type InferenceResult = Result<Type, String>;
|
||||||
|
|
||||||
|
impl<'a> Inferencer<'a> {
|
||||||
|
/// Constrain a <: b
|
||||||
|
/// Currently implemented as unification
|
||||||
|
fn constrain(&mut self, a: Type, b: Type) -> Result<(), String> {
|
||||||
|
self.unifier.unify(a, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_method_call(
|
||||||
|
&mut self,
|
||||||
|
location: Location,
|
||||||
|
method: String,
|
||||||
|
obj: Type,
|
||||||
|
params: Vec<Type>,
|
||||||
|
ret: Type,
|
||||||
|
) -> InferenceResult {
|
||||||
|
let call = self.unifier.add_call(Call {
|
||||||
|
posargs: params,
|
||||||
|
kwargs: HashMap::new(),
|
||||||
|
ret,
|
||||||
|
fun: RefCell::new(None),
|
||||||
|
});
|
||||||
|
self.calls.insert(location.into(), call);
|
||||||
|
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
|
||||||
|
let fields = once((method, call)).collect();
|
||||||
|
let record = self.unifier.add_record(fields);
|
||||||
|
self.constrain(obj, record)?;
|
||||||
|
Ok(ret)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn fold_lambda(
|
||||||
|
&mut self,
|
||||||
|
location: Location,
|
||||||
|
args: Arguments,
|
||||||
|
body: ast::Expr<()>,
|
||||||
|
) -> Result<ast::Expr<Option<Type>>, String> {
|
||||||
|
if !args.posonlyargs.is_empty()
|
||||||
|
|| args.vararg.is_some()
|
||||||
|
|| !args.kwonlyargs.is_empty()
|
||||||
|
|| args.kwarg.is_some()
|
||||||
|
|| !args.defaults.is_empty()
|
||||||
|
{
|
||||||
|
// actually I'm not sure whether programs violating this is a valid python program.
|
||||||
|
return Err(
|
||||||
|
"We only support positional or keyword arguments without defaults for lambdas."
|
||||||
|
.to_string(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let fn_args: Vec<_> = args
|
||||||
|
.args
|
||||||
|
.iter()
|
||||||
|
.map(|v| (v.node.arg.clone(), self.unifier.get_fresh_var().0))
|
||||||
|
.collect();
|
||||||
|
let mut variable_mapping = self.variable_mapping.clone();
|
||||||
|
variable_mapping.extend(fn_args.iter().cloned());
|
||||||
|
let ret = self.unifier.get_fresh_var().0;
|
||||||
|
let mut new_context = Inferencer {
|
||||||
|
function_data: self.function_data,
|
||||||
|
unifier: self.unifier,
|
||||||
|
primitives: self.primitives,
|
||||||
|
virtual_checks: self.virtual_checks,
|
||||||
|
calls: self.calls,
|
||||||
|
top_level: self.top_level,
|
||||||
|
variable_mapping,
|
||||||
|
};
|
||||||
|
let fun = FunSignature {
|
||||||
|
args: fn_args
|
||||||
|
.iter()
|
||||||
|
.map(|(k, ty)| FuncArg { name: k.clone(), ty: *ty, default_value: None })
|
||||||
|
.collect(),
|
||||||
|
ret,
|
||||||
|
vars: Default::default(),
|
||||||
|
};
|
||||||
|
let body = new_context.fold_expr(body)?;
|
||||||
|
new_context.unifier.unify(fun.ret, body.custom.unwrap())?;
|
||||||
|
let mut args = new_context.fold_arguments(args)?;
|
||||||
|
for (arg, (name, ty)) in args.args.iter_mut().zip(fn_args.iter()) {
|
||||||
|
assert_eq!(&arg.node.arg, name);
|
||||||
|
arg.custom = Some(*ty);
|
||||||
|
}
|
||||||
|
Ok(Located {
|
||||||
|
location,
|
||||||
|
node: ExprKind::Lambda { args: args.into(), body: body.into() },
|
||||||
|
custom: Some(self.unifier.add_ty(TypeEnum::TFunc(fun.into()))),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn fold_listcomp(
|
||||||
|
&mut self,
|
||||||
|
location: Location,
|
||||||
|
elt: ast::Expr<()>,
|
||||||
|
mut generators: Vec<Comprehension>,
|
||||||
|
) -> Result<ast::Expr<Option<Type>>, String> {
|
||||||
|
if generators.len() != 1 {
|
||||||
|
return Err(
|
||||||
|
"Only 1 generator statement for list comprehension is supported.".to_string()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let variable_mapping = self.variable_mapping.clone();
|
||||||
|
let mut new_context = Inferencer {
|
||||||
|
function_data: self.function_data,
|
||||||
|
unifier: self.unifier,
|
||||||
|
virtual_checks: self.virtual_checks,
|
||||||
|
top_level: self.top_level,
|
||||||
|
variable_mapping,
|
||||||
|
primitives: self.primitives,
|
||||||
|
calls: self.calls,
|
||||||
|
};
|
||||||
|
let elt = new_context.fold_expr(elt)?;
|
||||||
|
let generator = generators.pop().unwrap();
|
||||||
|
if generator.is_async {
|
||||||
|
return Err("Async iterator not supported.".to_string());
|
||||||
|
}
|
||||||
|
let target = new_context.fold_expr(*generator.target)?;
|
||||||
|
let iter = new_context.fold_expr(*generator.iter)?;
|
||||||
|
let ifs: Vec<_> = generator
|
||||||
|
.ifs
|
||||||
|
.into_iter()
|
||||||
|
.map(|v| new_context.fold_expr(v))
|
||||||
|
.collect::<Result<_, _>>()?;
|
||||||
|
|
||||||
|
// iter should be a list of targets...
|
||||||
|
// actually it should be an iterator of targets, but we don't have iter type for now
|
||||||
|
let list = new_context.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() });
|
||||||
|
new_context.unifier.unify(iter.custom.unwrap(), list)?;
|
||||||
|
// if conditions should be bool
|
||||||
|
for v in ifs.iter() {
|
||||||
|
new_context.unifier.unify(v.custom.unwrap(), new_context.primitives.bool)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Located {
|
||||||
|
location,
|
||||||
|
custom: Some(new_context.unifier.add_ty(TypeEnum::TList { ty: elt.custom.unwrap() })),
|
||||||
|
node: ExprKind::ListComp {
|
||||||
|
elt: Box::new(elt),
|
||||||
|
generators: vec![ast::Comprehension {
|
||||||
|
target: Box::new(target),
|
||||||
|
iter: Box::new(iter),
|
||||||
|
ifs,
|
||||||
|
is_async: false,
|
||||||
|
}],
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn fold_call(
|
||||||
|
&mut self,
|
||||||
|
location: Location,
|
||||||
|
func: ast::Expr<()>,
|
||||||
|
mut args: Vec<ast::Expr<()>>,
|
||||||
|
keywords: Vec<Located<ast::KeywordData>>,
|
||||||
|
) -> Result<ast::Expr<Option<Type>>, String> {
|
||||||
|
let func =
|
||||||
|
if let Located { location: func_location, custom, node: ExprKind::Name { id, ctx } } =
|
||||||
|
func
|
||||||
|
{
|
||||||
|
// handle special functions that cannot be typed in the usual way...
|
||||||
|
if id == "virtual" {
|
||||||
|
if args.is_empty() || args.len() > 2 || !keywords.is_empty() {
|
||||||
|
return Err(
|
||||||
|
"`virtual` can only accept 1/2 positional arguments.".to_string()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let arg0 = self.fold_expr(args.remove(0))?;
|
||||||
|
let ty = if let Some(arg) = args.pop() {
|
||||||
|
self.function_data.resolver.parse_type_annotation(
|
||||||
|
self.top_level,
|
||||||
|
self.unifier,
|
||||||
|
self.primitives,
|
||||||
|
&arg,
|
||||||
|
)?
|
||||||
|
} else {
|
||||||
|
self.unifier.get_fresh_var().0
|
||||||
|
};
|
||||||
|
self.virtual_checks.push((arg0.custom.unwrap(), ty));
|
||||||
|
let custom = Some(self.unifier.add_ty(TypeEnum::TVirtual { ty }));
|
||||||
|
return Ok(Located {
|
||||||
|
location,
|
||||||
|
custom,
|
||||||
|
node: ExprKind::Call {
|
||||||
|
func: Box::new(Located {
|
||||||
|
custom: None,
|
||||||
|
location: func.location,
|
||||||
|
node: ExprKind::Name { id, ctx },
|
||||||
|
}),
|
||||||
|
args: vec![arg0],
|
||||||
|
keywords: vec![],
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
// int64 is special because its argument can be a constant larger than int32
|
||||||
|
if id == "int64" && args.len() == 1 {
|
||||||
|
if let ExprKind::Constant { value: ast::Constant::Int(val), kind } =
|
||||||
|
&args[0].node
|
||||||
|
{
|
||||||
|
let int64: Result<i64, _> = val.try_into();
|
||||||
|
let custom;
|
||||||
|
if int64.is_ok() {
|
||||||
|
custom = Some(self.primitives.int64);
|
||||||
|
} else {
|
||||||
|
return Err("Integer out of bound".into());
|
||||||
|
}
|
||||||
|
return Ok(Located {
|
||||||
|
location: args[0].location,
|
||||||
|
custom,
|
||||||
|
node: ExprKind::Constant {
|
||||||
|
value: ast::Constant::Int(val.clone()),
|
||||||
|
kind: kind.clone(),
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Located { location: func_location, custom, node: ExprKind::Name { id, ctx } }
|
||||||
|
} else {
|
||||||
|
func
|
||||||
|
};
|
||||||
|
let func = Box::new(self.fold_expr(func)?);
|
||||||
|
let args = args.into_iter().map(|v| self.fold_expr(v)).collect::<Result<Vec<_>, _>>()?;
|
||||||
|
let keywords = keywords
|
||||||
|
.into_iter()
|
||||||
|
.map(|v| fold::fold_keyword(self, v))
|
||||||
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
let ret = self.unifier.get_fresh_var().0;
|
||||||
|
let call = self.unifier.add_call(Call {
|
||||||
|
posargs: args.iter().map(|v| v.custom.unwrap()).collect(),
|
||||||
|
kwargs: keywords
|
||||||
|
.iter()
|
||||||
|
.map(|v| (v.node.arg.as_ref().unwrap().clone(), v.custom.unwrap()))
|
||||||
|
.collect(),
|
||||||
|
fun: RefCell::new(None),
|
||||||
|
ret,
|
||||||
|
});
|
||||||
|
self.calls.insert(location.into(), call);
|
||||||
|
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
|
||||||
|
self.unifier.unify(func.custom.unwrap(), call)?;
|
||||||
|
|
||||||
|
Ok(Located { location, custom: Some(ret), node: ExprKind::Call { func, args, keywords } })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn infer_identifier(&mut self, id: &str) -> InferenceResult {
|
||||||
|
if let Some(ty) = self.variable_mapping.get(id) {
|
||||||
|
Ok(*ty)
|
||||||
|
} else {
|
||||||
|
Ok(self
|
||||||
|
.function_data
|
||||||
|
.resolver
|
||||||
|
.get_symbol_type(self.unifier, self.primitives, id)
|
||||||
|
.unwrap_or_else(|| {
|
||||||
|
let ty = self.unifier.get_fresh_var().0;
|
||||||
|
self.variable_mapping.insert(id.to_string(), ty);
|
||||||
|
ty
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn infer_constant(&mut self, constant: &ast::Constant) -> InferenceResult {
|
||||||
|
match constant {
|
||||||
|
ast::Constant::Bool(_) => Ok(self.primitives.bool),
|
||||||
|
ast::Constant::Int(val) => {
|
||||||
|
let int32: Result<i32, _> = val.try_into();
|
||||||
|
// int64 would be handled separately in functions
|
||||||
|
if int32.is_ok() {
|
||||||
|
Ok(self.primitives.int32)
|
||||||
|
} else {
|
||||||
|
Err("Integer out of bound".into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ast::Constant::Float(_) => Ok(self.primitives.float),
|
||||||
|
ast::Constant::Tuple(vals) => {
|
||||||
|
let ty: Result<Vec<_>, _> = vals.iter().map(|x| self.infer_constant(x)).collect();
|
||||||
|
Ok(self.unifier.add_ty(TypeEnum::TTuple { ty: ty? }))
|
||||||
|
}
|
||||||
|
_ => Err("not supported".into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn infer_list(&mut self, elts: &[ast::Expr<Option<Type>>]) -> InferenceResult {
|
||||||
|
let (ty, _) = self.unifier.get_fresh_var();
|
||||||
|
for t in elts.iter() {
|
||||||
|
self.unifier.unify(ty, t.custom.unwrap())?;
|
||||||
|
}
|
||||||
|
Ok(self.unifier.add_ty(TypeEnum::TList { ty }))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn infer_tuple(&mut self, elts: &[ast::Expr<Option<Type>>]) -> InferenceResult {
|
||||||
|
let ty = elts.iter().map(|x| x.custom.unwrap()).collect();
|
||||||
|
Ok(self.unifier.add_ty(TypeEnum::TTuple { ty }))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn infer_attribute(&mut self, value: &ast::Expr<Option<Type>>, attr: &str) -> InferenceResult {
|
||||||
|
let (attr_ty, _) = self.unifier.get_fresh_var();
|
||||||
|
let fields = once((attr.to_string(), attr_ty)).collect();
|
||||||
|
let record = self.unifier.add_record(fields);
|
||||||
|
self.constrain(value.custom.unwrap(), record)?;
|
||||||
|
Ok(attr_ty)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn infer_bool_ops(&mut self, values: &[ast::Expr<Option<Type>>]) -> InferenceResult {
|
||||||
|
let b = self.primitives.bool;
|
||||||
|
for v in values {
|
||||||
|
self.constrain(v.custom.unwrap(), b)?;
|
||||||
|
}
|
||||||
|
Ok(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn infer_bin_ops(
|
||||||
|
&mut self,
|
||||||
|
left: &ast::Expr<Option<Type>>,
|
||||||
|
op: &ast::Operator,
|
||||||
|
right: &ast::Expr<Option<Type>>,
|
||||||
|
) -> InferenceResult {
|
||||||
|
let method = binop_name(op);
|
||||||
|
let ret = self.unifier.get_fresh_var().0;
|
||||||
|
self.build_method_call(
|
||||||
|
left.location,
|
||||||
|
method.to_string(),
|
||||||
|
left.custom.unwrap(),
|
||||||
|
vec![right.custom.unwrap()],
|
||||||
|
ret,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn infer_unary_ops(
|
||||||
|
&mut self,
|
||||||
|
op: &ast::Unaryop,
|
||||||
|
operand: &ast::Expr<Option<Type>>,
|
||||||
|
) -> InferenceResult {
|
||||||
|
let method = unaryop_name(op);
|
||||||
|
let ret = self.unifier.get_fresh_var().0;
|
||||||
|
self.build_method_call(
|
||||||
|
operand.location,
|
||||||
|
method.to_string(),
|
||||||
|
operand.custom.unwrap(),
|
||||||
|
vec![],
|
||||||
|
ret,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn infer_compare(
|
||||||
|
&mut self,
|
||||||
|
left: &ast::Expr<Option<Type>>,
|
||||||
|
ops: &[ast::Cmpop],
|
||||||
|
comparators: &[ast::Expr<Option<Type>>],
|
||||||
|
) -> InferenceResult {
|
||||||
|
let boolean = self.primitives.bool;
|
||||||
|
for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) {
|
||||||
|
let method =
|
||||||
|
comparison_name(c).ok_or_else(|| "unsupported comparator".to_string())?.to_string();
|
||||||
|
self.build_method_call(
|
||||||
|
a.location,
|
||||||
|
method,
|
||||||
|
a.custom.unwrap(),
|
||||||
|
vec![b.custom.unwrap()],
|
||||||
|
boolean,
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
Ok(boolean)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn infer_subscript(
|
||||||
|
&mut self,
|
||||||
|
value: &ast::Expr<Option<Type>>,
|
||||||
|
slice: &ast::Expr<Option<Type>>,
|
||||||
|
) -> InferenceResult {
|
||||||
|
let ty = self.unifier.get_fresh_var().0;
|
||||||
|
match &slice.node {
|
||||||
|
ast::ExprKind::Slice { lower, upper, step } => {
|
||||||
|
for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() {
|
||||||
|
self.constrain(v.custom.unwrap(), self.primitives.int32)?;
|
||||||
|
}
|
||||||
|
let list = self.unifier.add_ty(TypeEnum::TList { ty });
|
||||||
|
self.constrain(value.custom.unwrap(), list)?;
|
||||||
|
Ok(list)
|
||||||
|
}
|
||||||
|
ast::ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
|
||||||
|
// the index is a constant, so value can be a sequence.
|
||||||
|
let ind: i32 = val.try_into().map_err(|_| "Index must be int32".to_string())?;
|
||||||
|
let map = once((ind, ty)).collect();
|
||||||
|
let seq = self.unifier.add_sequence(map);
|
||||||
|
self.constrain(value.custom.unwrap(), seq)?;
|
||||||
|
Ok(ty)
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
// the index is not a constant, so value can only be a list
|
||||||
|
self.constrain(slice.custom.unwrap(), self.primitives.int32)?;
|
||||||
|
let list = self.unifier.add_ty(TypeEnum::TList { ty });
|
||||||
|
self.constrain(value.custom.unwrap(), list)?;
|
||||||
|
Ok(ty)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn infer_if_expr(
|
||||||
|
&mut self,
|
||||||
|
test: &ast::Expr<Option<Type>>,
|
||||||
|
body: &ast::Expr<Option<Type>>,
|
||||||
|
orelse: &ast::Expr<Option<Type>>,
|
||||||
|
) -> InferenceResult {
|
||||||
|
self.constrain(test.custom.unwrap(), self.primitives.bool)?;
|
||||||
|
let ty = self.unifier.get_fresh_var().0;
|
||||||
|
self.constrain(body.custom.unwrap(), ty)?;
|
||||||
|
self.constrain(orelse.custom.unwrap(), ty)?;
|
||||||
|
Ok(ty)
|
||||||
|
}
|
||||||
|
}
|
546
nac3core/src/typecheck/type_inferencer/test.rs
Normal file
546
nac3core/src/typecheck/type_inferencer/test.rs
Normal file
@ -0,0 +1,546 @@
|
|||||||
|
use super::super::typedef::*;
|
||||||
|
use super::*;
|
||||||
|
use crate::symbol_resolver::*;
|
||||||
|
use crate::top_level::DefinitionId;
|
||||||
|
use crate::{location::Location, top_level::TopLevelDef};
|
||||||
|
use indoc::indoc;
|
||||||
|
use itertools::zip;
|
||||||
|
use parking_lot::RwLock;
|
||||||
|
use rustpython_parser::parser::parse_program;
|
||||||
|
use test_case::test_case;
|
||||||
|
|
||||||
|
struct Resolver {
|
||||||
|
id_to_type: HashMap<String, Type>,
|
||||||
|
id_to_def: HashMap<String, DefinitionId>,
|
||||||
|
class_names: HashMap<String, Type>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SymbolResolver for Resolver {
|
||||||
|
fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option<Type> {
|
||||||
|
self.id_to_type.get(str).cloned()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_symbol_value(&self, _: &str) -> Option<SymbolValue> {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_symbol_location(&self, _: &str) -> Option<Location> {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_identifier_def(&self, id: &str) -> Option<DefinitionId> {
|
||||||
|
self.id_to_def.get(id).cloned()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct TestEnvironment {
|
||||||
|
pub unifier: Unifier,
|
||||||
|
pub function_data: FunctionData,
|
||||||
|
pub primitives: PrimitiveStore,
|
||||||
|
pub id_to_name: HashMap<usize, String>,
|
||||||
|
pub identifier_mapping: HashMap<String, Type>,
|
||||||
|
pub virtual_checks: Vec<(Type, Type)>,
|
||||||
|
pub calls: HashMap<CodeLocation, CallId>,
|
||||||
|
pub top_level: TopLevelContext,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TestEnvironment {
|
||||||
|
pub fn basic_test_env() -> TestEnvironment {
|
||||||
|
let mut unifier = Unifier::new();
|
||||||
|
|
||||||
|
let int32 = unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: DefinitionId(0),
|
||||||
|
fields: HashMap::new().into(),
|
||||||
|
params: HashMap::new().into(),
|
||||||
|
});
|
||||||
|
let int64 = unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: DefinitionId(1),
|
||||||
|
fields: HashMap::new().into(),
|
||||||
|
params: HashMap::new().into(),
|
||||||
|
});
|
||||||
|
let float = unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: DefinitionId(2),
|
||||||
|
fields: HashMap::new().into(),
|
||||||
|
params: HashMap::new().into(),
|
||||||
|
});
|
||||||
|
let bool = unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: DefinitionId(3),
|
||||||
|
fields: HashMap::new().into(),
|
||||||
|
params: HashMap::new().into(),
|
||||||
|
});
|
||||||
|
let none = unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: DefinitionId(4),
|
||||||
|
fields: HashMap::new().into(),
|
||||||
|
params: HashMap::new().into(),
|
||||||
|
});
|
||||||
|
let primitives = PrimitiveStore { int32, int64, float, bool, none };
|
||||||
|
set_primitives_magic_methods(&primitives, &mut unifier);
|
||||||
|
|
||||||
|
let id_to_name = [
|
||||||
|
(0, "int32".to_string()),
|
||||||
|
(1, "int64".to_string()),
|
||||||
|
(2, "float".to_string()),
|
||||||
|
(3, "bool".to_string()),
|
||||||
|
(4, "none".to_string()),
|
||||||
|
]
|
||||||
|
.iter()
|
||||||
|
.cloned()
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let mut identifier_mapping = HashMap::new();
|
||||||
|
identifier_mapping.insert("None".into(), none);
|
||||||
|
|
||||||
|
let resolver = Arc::new(Resolver {
|
||||||
|
id_to_type: identifier_mapping.clone(),
|
||||||
|
id_to_def: Default::default(),
|
||||||
|
class_names: Default::default(),
|
||||||
|
}) as Arc<dyn SymbolResolver + Send + Sync>;
|
||||||
|
|
||||||
|
TestEnvironment {
|
||||||
|
top_level: TopLevelContext {
|
||||||
|
definitions: Default::default(),
|
||||||
|
unifiers: Default::default(),
|
||||||
|
},
|
||||||
|
unifier,
|
||||||
|
function_data: FunctionData {
|
||||||
|
resolver,
|
||||||
|
bound_variables: Vec::new(),
|
||||||
|
return_type: None,
|
||||||
|
},
|
||||||
|
primitives,
|
||||||
|
id_to_name,
|
||||||
|
identifier_mapping,
|
||||||
|
virtual_checks: Vec::new(),
|
||||||
|
calls: HashMap::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new() -> TestEnvironment {
|
||||||
|
let mut unifier = Unifier::new();
|
||||||
|
let mut identifier_mapping = HashMap::new();
|
||||||
|
let mut top_level_defs: Vec<Arc<RwLock<TopLevelDef>>> = Vec::new();
|
||||||
|
let int32 = unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: DefinitionId(0),
|
||||||
|
fields: HashMap::new().into(),
|
||||||
|
params: HashMap::new().into(),
|
||||||
|
});
|
||||||
|
let int64 = unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: DefinitionId(1),
|
||||||
|
fields: HashMap::new().into(),
|
||||||
|
params: HashMap::new().into(),
|
||||||
|
});
|
||||||
|
let float = unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: DefinitionId(2),
|
||||||
|
fields: HashMap::new().into(),
|
||||||
|
params: HashMap::new().into(),
|
||||||
|
});
|
||||||
|
let bool = unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: DefinitionId(3),
|
||||||
|
fields: HashMap::new().into(),
|
||||||
|
params: HashMap::new().into(),
|
||||||
|
});
|
||||||
|
let none = unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: DefinitionId(4),
|
||||||
|
fields: HashMap::new().into(),
|
||||||
|
params: HashMap::new().into(),
|
||||||
|
});
|
||||||
|
identifier_mapping.insert("None".into(), none);
|
||||||
|
for i in 0..5 {
|
||||||
|
top_level_defs.push(
|
||||||
|
RwLock::new(TopLevelDef::Class {
|
||||||
|
object_id: DefinitionId(i),
|
||||||
|
type_vars: Default::default(),
|
||||||
|
fields: Default::default(),
|
||||||
|
methods: Default::default(),
|
||||||
|
ancestors: Default::default(),
|
||||||
|
resolver: None,
|
||||||
|
})
|
||||||
|
.into(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let primitives = PrimitiveStore { int32, int64, float, bool, none };
|
||||||
|
|
||||||
|
let (v0, id) = unifier.get_fresh_var();
|
||||||
|
|
||||||
|
let foo_ty = unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: DefinitionId(5),
|
||||||
|
fields: [("a".into(), v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
|
||||||
|
params: [(id, v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
|
||||||
|
});
|
||||||
|
top_level_defs.push(
|
||||||
|
RwLock::new(TopLevelDef::Class {
|
||||||
|
object_id: DefinitionId(5),
|
||||||
|
type_vars: vec![v0],
|
||||||
|
fields: [("a".into(), v0)].into(),
|
||||||
|
methods: Default::default(),
|
||||||
|
ancestors: Default::default(),
|
||||||
|
resolver: None,
|
||||||
|
})
|
||||||
|
.into(),
|
||||||
|
);
|
||||||
|
|
||||||
|
identifier_mapping.insert(
|
||||||
|
"Foo".into(),
|
||||||
|
unifier.add_ty(TypeEnum::TFunc(
|
||||||
|
FunSignature {
|
||||||
|
args: vec![],
|
||||||
|
ret: foo_ty,
|
||||||
|
vars: [(id, v0)].iter().cloned().collect(),
|
||||||
|
}
|
||||||
|
.into(),
|
||||||
|
)),
|
||||||
|
);
|
||||||
|
|
||||||
|
let fun = unifier.add_ty(TypeEnum::TFunc(
|
||||||
|
FunSignature { args: vec![], ret: int32, vars: Default::default() }.into(),
|
||||||
|
));
|
||||||
|
let bar = unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: DefinitionId(6),
|
||||||
|
fields: [("a".into(), int32), ("b".into(), fun)]
|
||||||
|
.iter()
|
||||||
|
.cloned()
|
||||||
|
.collect::<HashMap<_, _>>()
|
||||||
|
.into(),
|
||||||
|
params: Default::default(),
|
||||||
|
});
|
||||||
|
top_level_defs.push(
|
||||||
|
RwLock::new(TopLevelDef::Class {
|
||||||
|
object_id: DefinitionId(6),
|
||||||
|
type_vars: Default::default(),
|
||||||
|
fields: [("a".into(), int32), ("b".into(), fun)].into(),
|
||||||
|
methods: Default::default(),
|
||||||
|
ancestors: Default::default(),
|
||||||
|
resolver: None,
|
||||||
|
})
|
||||||
|
.into(),
|
||||||
|
);
|
||||||
|
identifier_mapping.insert(
|
||||||
|
"Bar".into(),
|
||||||
|
unifier.add_ty(TypeEnum::TFunc(
|
||||||
|
FunSignature { args: vec![], ret: bar, vars: Default::default() }.into(),
|
||||||
|
)),
|
||||||
|
);
|
||||||
|
|
||||||
|
let bar2 = unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: DefinitionId(7),
|
||||||
|
fields: [("a".into(), bool), ("b".into(), fun)]
|
||||||
|
.iter()
|
||||||
|
.cloned()
|
||||||
|
.collect::<HashMap<_, _>>()
|
||||||
|
.into(),
|
||||||
|
params: Default::default(),
|
||||||
|
});
|
||||||
|
top_level_defs.push(
|
||||||
|
RwLock::new(TopLevelDef::Class {
|
||||||
|
object_id: DefinitionId(7),
|
||||||
|
type_vars: Default::default(),
|
||||||
|
fields: [("a".into(), bool), ("b".into(), fun)].into(),
|
||||||
|
methods: Default::default(),
|
||||||
|
ancestors: Default::default(),
|
||||||
|
resolver: None,
|
||||||
|
})
|
||||||
|
.into(),
|
||||||
|
);
|
||||||
|
identifier_mapping.insert(
|
||||||
|
"Bar2".into(),
|
||||||
|
unifier.add_ty(TypeEnum::TFunc(
|
||||||
|
FunSignature { args: vec![], ret: bar2, vars: Default::default() }.into(),
|
||||||
|
)),
|
||||||
|
);
|
||||||
|
let class_names = [("Bar".into(), bar), ("Bar2".into(), bar2)].iter().cloned().collect();
|
||||||
|
|
||||||
|
let id_to_name = [
|
||||||
|
(0, "int32".to_string()),
|
||||||
|
(1, "int64".to_string()),
|
||||||
|
(2, "float".to_string()),
|
||||||
|
(3, "bool".to_string()),
|
||||||
|
(4, "none".to_string()),
|
||||||
|
(5, "Foo".to_string()),
|
||||||
|
(6, "Bar".to_string()),
|
||||||
|
(7, "Bar2".to_string()),
|
||||||
|
]
|
||||||
|
.iter()
|
||||||
|
.cloned()
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let top_level = TopLevelContext {
|
||||||
|
definitions: Arc::new(RwLock::new(top_level_defs)),
|
||||||
|
unifiers: Default::default(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let resolver = Arc::new(Resolver {
|
||||||
|
id_to_type: identifier_mapping.clone(),
|
||||||
|
id_to_def: [
|
||||||
|
("Foo".into(), DefinitionId(5)),
|
||||||
|
("Bar".into(), DefinitionId(6)),
|
||||||
|
("Bar2".into(), DefinitionId(7)),
|
||||||
|
]
|
||||||
|
.iter()
|
||||||
|
.cloned()
|
||||||
|
.collect(),
|
||||||
|
class_names,
|
||||||
|
}) as Arc<dyn SymbolResolver + Send + Sync>;
|
||||||
|
|
||||||
|
TestEnvironment {
|
||||||
|
unifier,
|
||||||
|
top_level,
|
||||||
|
function_data: FunctionData {
|
||||||
|
resolver,
|
||||||
|
bound_variables: Vec::new(),
|
||||||
|
return_type: None,
|
||||||
|
},
|
||||||
|
primitives,
|
||||||
|
id_to_name,
|
||||||
|
identifier_mapping,
|
||||||
|
virtual_checks: Vec::new(),
|
||||||
|
calls: HashMap::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_inferencer(&mut self) -> Inferencer {
|
||||||
|
Inferencer {
|
||||||
|
top_level: &self.top_level,
|
||||||
|
function_data: &mut self.function_data,
|
||||||
|
unifier: &mut self.unifier,
|
||||||
|
variable_mapping: Default::default(),
|
||||||
|
primitives: &mut self.primitives,
|
||||||
|
virtual_checks: &mut self.virtual_checks,
|
||||||
|
calls: &mut self.calls,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test_case(indoc! {"
|
||||||
|
a = 1234
|
||||||
|
b = int64(2147483648)
|
||||||
|
c = 1.234
|
||||||
|
d = True
|
||||||
|
"},
|
||||||
|
[("a", "int32"), ("b", "int64"), ("c", "float"), ("d", "bool")].iter().cloned().collect(),
|
||||||
|
&[]
|
||||||
|
; "primitives test")]
|
||||||
|
#[test_case(indoc! {"
|
||||||
|
a = lambda x, y: x
|
||||||
|
b = lambda x: a(x, x)
|
||||||
|
c = 1.234
|
||||||
|
d = b(c)
|
||||||
|
"},
|
||||||
|
[("a", "fn[[x=float, y=float], float]"), ("b", "fn[[x=float], float]"), ("c", "float"), ("d", "float")].iter().cloned().collect(),
|
||||||
|
&[]
|
||||||
|
; "lambda test")]
|
||||||
|
#[test_case(indoc! {"
|
||||||
|
a = lambda x: x
|
||||||
|
b = lambda x: x
|
||||||
|
|
||||||
|
foo1 = Foo()
|
||||||
|
foo2 = Foo()
|
||||||
|
c = a(foo1.a)
|
||||||
|
d = b(foo2.a)
|
||||||
|
|
||||||
|
a(True)
|
||||||
|
b(123)
|
||||||
|
|
||||||
|
"},
|
||||||
|
[("a", "fn[[x=bool], bool]"), ("b", "fn[[x=int32], int32]"), ("c", "bool"),
|
||||||
|
("d", "int32"), ("foo1", "Foo[bool]"), ("foo2", "Foo[int32]")].iter().cloned().collect(),
|
||||||
|
&[]
|
||||||
|
; "obj test")]
|
||||||
|
#[test_case(indoc! {"
|
||||||
|
f = lambda x: True
|
||||||
|
a = [1, 2, 3]
|
||||||
|
b = [f(x) for x in a if f(x)]
|
||||||
|
"},
|
||||||
|
[("a", "list[int32]"), ("b", "list[bool]"), ("f", "fn[[x=int32], bool]")].iter().cloned().collect(),
|
||||||
|
&[]
|
||||||
|
; "listcomp test")]
|
||||||
|
#[test_case(indoc! {"
|
||||||
|
a = virtual(Bar(), Bar)
|
||||||
|
b = a.b()
|
||||||
|
a = virtual(Bar2())
|
||||||
|
"},
|
||||||
|
[("a", "virtual[Bar]"), ("b", "int32")].iter().cloned().collect(),
|
||||||
|
&[("Bar", "Bar"), ("Bar2", "Bar")]
|
||||||
|
; "virtual test")]
|
||||||
|
#[test_case(indoc! {"
|
||||||
|
a = [virtual(Bar(), Bar), virtual(Bar2())]
|
||||||
|
b = [x.b() for x in a]
|
||||||
|
"},
|
||||||
|
[("a", "list[virtual[Bar]]"), ("b", "list[int32]")].iter().cloned().collect(),
|
||||||
|
&[("Bar", "Bar"), ("Bar2", "Bar")]
|
||||||
|
; "virtual list test")]
|
||||||
|
fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &str)]) {
|
||||||
|
println!("source:\n{}", source);
|
||||||
|
let mut env = TestEnvironment::new();
|
||||||
|
let id_to_name = std::mem::take(&mut env.id_to_name);
|
||||||
|
let mut defined_identifiers: Vec<_> = env.identifier_mapping.keys().cloned().collect();
|
||||||
|
defined_identifiers.push("virtual".to_string());
|
||||||
|
let mut inferencer = env.get_inferencer();
|
||||||
|
let statements = parse_program(source).unwrap();
|
||||||
|
let statements = statements
|
||||||
|
.into_iter()
|
||||||
|
.map(|v| inferencer.fold_stmt(v))
|
||||||
|
.collect::<Result<Vec<_>, _>>()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
inferencer.check_block(&statements, &mut defined_identifiers).unwrap();
|
||||||
|
|
||||||
|
for (k, v) in inferencer.variable_mapping.iter() {
|
||||||
|
let name = inferencer.unifier.stringify(
|
||||||
|
*v,
|
||||||
|
&mut |v| id_to_name.get(&v).unwrap().clone(),
|
||||||
|
&mut |v| format!("v{}", v),
|
||||||
|
);
|
||||||
|
println!("{}: {}", k, name);
|
||||||
|
}
|
||||||
|
for (k, v) in mapping.iter() {
|
||||||
|
let ty = inferencer.variable_mapping.get(*k).unwrap();
|
||||||
|
let name = inferencer.unifier.stringify(
|
||||||
|
*ty,
|
||||||
|
&mut |v| id_to_name.get(&v).unwrap().clone(),
|
||||||
|
&mut |v| format!("v{}", v),
|
||||||
|
);
|
||||||
|
assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name));
|
||||||
|
}
|
||||||
|
assert_eq!(inferencer.virtual_checks.len(), virtuals.len());
|
||||||
|
for ((a, b), (x, y)) in zip(inferencer.virtual_checks.iter(), virtuals) {
|
||||||
|
let a = inferencer.unifier.stringify(
|
||||||
|
*a,
|
||||||
|
&mut |v| id_to_name.get(&v).unwrap().clone(),
|
||||||
|
&mut |v| format!("v{}", v),
|
||||||
|
);
|
||||||
|
let b = inferencer.unifier.stringify(
|
||||||
|
*b,
|
||||||
|
&mut |v| id_to_name.get(&v).unwrap().clone(),
|
||||||
|
&mut |v| format!("v{}", v),
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(&a, x);
|
||||||
|
assert_eq!(&b, y);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test_case(indoc! {"
|
||||||
|
a = 2
|
||||||
|
b = 2
|
||||||
|
c = a + b
|
||||||
|
d = a - b
|
||||||
|
e = a * b
|
||||||
|
f = a / b
|
||||||
|
g = a // b
|
||||||
|
h = a % b
|
||||||
|
"},
|
||||||
|
[("a", "int32"),
|
||||||
|
("b", "int32"),
|
||||||
|
("c", "int32"),
|
||||||
|
("d", "int32"),
|
||||||
|
("e", "int32"),
|
||||||
|
("f", "float"),
|
||||||
|
("g", "int32"),
|
||||||
|
("h", "int32")].iter().cloned().collect()
|
||||||
|
; "int32")]
|
||||||
|
#[test_case(
|
||||||
|
indoc! {"
|
||||||
|
a = 2.4
|
||||||
|
b = 3.6
|
||||||
|
c = a + b
|
||||||
|
d = a - b
|
||||||
|
e = a * b
|
||||||
|
f = a / b
|
||||||
|
g = a // b
|
||||||
|
h = a % b
|
||||||
|
i = a ** b
|
||||||
|
ii = 3
|
||||||
|
j = a ** b
|
||||||
|
"},
|
||||||
|
[("a", "float"),
|
||||||
|
("b", "float"),
|
||||||
|
("c", "float"),
|
||||||
|
("d", "float"),
|
||||||
|
("e", "float"),
|
||||||
|
("f", "float"),
|
||||||
|
("g", "float"),
|
||||||
|
("h", "float"),
|
||||||
|
("i", "float"),
|
||||||
|
("ii", "int32"),
|
||||||
|
("j", "float")].iter().cloned().collect()
|
||||||
|
; "float"
|
||||||
|
)]
|
||||||
|
#[test_case(
|
||||||
|
indoc! {"
|
||||||
|
a = int64(12312312312)
|
||||||
|
b = int64(24242424424)
|
||||||
|
c = a + b
|
||||||
|
d = a - b
|
||||||
|
e = a * b
|
||||||
|
f = a / b
|
||||||
|
g = a // b
|
||||||
|
h = a % b
|
||||||
|
i = a == b
|
||||||
|
j = a > b
|
||||||
|
k = a < b
|
||||||
|
l = a != b
|
||||||
|
"},
|
||||||
|
[("a", "int64"),
|
||||||
|
("b", "int64"),
|
||||||
|
("c", "int64"),
|
||||||
|
("d", "int64"),
|
||||||
|
("e", "int64"),
|
||||||
|
("f", "float"),
|
||||||
|
("g", "int64"),
|
||||||
|
("h", "int64"),
|
||||||
|
("i", "bool"),
|
||||||
|
("j", "bool"),
|
||||||
|
("k", "bool"),
|
||||||
|
("l", "bool")].iter().cloned().collect()
|
||||||
|
; "int64"
|
||||||
|
)]
|
||||||
|
#[test_case(
|
||||||
|
indoc! {"
|
||||||
|
a = True
|
||||||
|
b = False
|
||||||
|
c = a == b
|
||||||
|
d = not a
|
||||||
|
e = a != b
|
||||||
|
"},
|
||||||
|
[("a", "bool"),
|
||||||
|
("b", "bool"),
|
||||||
|
("c", "bool"),
|
||||||
|
("d", "bool"),
|
||||||
|
("e", "bool")].iter().cloned().collect()
|
||||||
|
; "boolean"
|
||||||
|
)]
|
||||||
|
fn test_primitive_magic_methods(source: &str, mapping: HashMap<&str, &str>) {
|
||||||
|
println!("source:\n{}", source);
|
||||||
|
let mut env = TestEnvironment::basic_test_env();
|
||||||
|
let id_to_name = std::mem::take(&mut env.id_to_name);
|
||||||
|
let mut defined_identifiers: Vec<_> = env.identifier_mapping.keys().cloned().collect();
|
||||||
|
defined_identifiers.push("virtual".to_string());
|
||||||
|
let mut inferencer = env.get_inferencer();
|
||||||
|
let statements = parse_program(source).unwrap();
|
||||||
|
let statements = statements
|
||||||
|
.into_iter()
|
||||||
|
.map(|v| inferencer.fold_stmt(v))
|
||||||
|
.collect::<Result<Vec<_>, _>>()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
inferencer.check_block(&statements, &mut defined_identifiers).unwrap();
|
||||||
|
|
||||||
|
for (k, v) in inferencer.variable_mapping.iter() {
|
||||||
|
let name = inferencer.unifier.stringify(
|
||||||
|
*v,
|
||||||
|
&mut |v| id_to_name.get(&v).unwrap().clone(),
|
||||||
|
&mut |v| format!("v{}", v),
|
||||||
|
);
|
||||||
|
println!("{}: {}", k, name);
|
||||||
|
}
|
||||||
|
for (k, v) in mapping.iter() {
|
||||||
|
let ty = inferencer.variable_mapping.get(*k).unwrap();
|
||||||
|
let name = inferencer.unifier.stringify(
|
||||||
|
*ty,
|
||||||
|
&mut |v| id_to_name.get(&v).unwrap().clone(),
|
||||||
|
&mut |v| format!("v{}", v),
|
||||||
|
);
|
||||||
|
assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name));
|
||||||
|
}
|
||||||
|
}
|
947
nac3core/src/typecheck/typedef/mod.rs
Normal file
947
nac3core/src/typecheck/typedef/mod.rs
Normal file
@ -0,0 +1,947 @@
|
|||||||
|
use itertools::{chain, zip, Itertools};
|
||||||
|
use std::borrow::Cow;
|
||||||
|
use std::cell::RefCell;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::iter::once;
|
||||||
|
use std::rc::Rc;
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
|
use super::unification_table::{UnificationKey, UnificationTable};
|
||||||
|
use crate::symbol_resolver::SymbolValue;
|
||||||
|
use crate::top_level::DefinitionId;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod test;
|
||||||
|
|
||||||
|
/// Handle for a type, implementated as a key in the unification table.
|
||||||
|
pub type Type = UnificationKey;
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub struct CallId(usize);
|
||||||
|
|
||||||
|
pub type Mapping<K, V = Type> = HashMap<K, V>;
|
||||||
|
type VarMap = Mapping<u32>;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct Call {
|
||||||
|
pub posargs: Vec<Type>,
|
||||||
|
pub kwargs: HashMap<String, Type>,
|
||||||
|
pub ret: Type,
|
||||||
|
pub fun: RefCell<Option<Type>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct FuncArg {
|
||||||
|
pub name: String,
|
||||||
|
pub ty: Type,
|
||||||
|
pub default_value: Option<SymbolValue>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct FunSignature {
|
||||||
|
pub args: Vec<FuncArg>,
|
||||||
|
pub ret: Type,
|
||||||
|
pub vars: VarMap,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub enum TypeVarMeta {
|
||||||
|
Generic,
|
||||||
|
Sequence(RefCell<Mapping<i32>>),
|
||||||
|
Record(RefCell<Mapping<String>>),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub enum TypeEnum {
|
||||||
|
TRigidVar {
|
||||||
|
id: u32,
|
||||||
|
},
|
||||||
|
TVar {
|
||||||
|
id: u32,
|
||||||
|
meta: TypeVarMeta,
|
||||||
|
// empty indicates no restriction
|
||||||
|
range: RefCell<Vec<Type>>,
|
||||||
|
},
|
||||||
|
TTuple {
|
||||||
|
ty: Vec<Type>,
|
||||||
|
},
|
||||||
|
TList {
|
||||||
|
ty: Type,
|
||||||
|
},
|
||||||
|
TObj {
|
||||||
|
obj_id: DefinitionId,
|
||||||
|
fields: RefCell<Mapping<String>>,
|
||||||
|
params: RefCell<VarMap>,
|
||||||
|
},
|
||||||
|
TVirtual {
|
||||||
|
ty: Type,
|
||||||
|
},
|
||||||
|
TCall(RefCell<Vec<CallId>>),
|
||||||
|
TFunc(RefCell<FunSignature>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TypeEnum {
|
||||||
|
pub fn get_type_name(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
TypeEnum::TRigidVar { .. } => "TRigidVar",
|
||||||
|
TypeEnum::TVar { .. } => "TVar",
|
||||||
|
TypeEnum::TTuple { .. } => "TTuple",
|
||||||
|
TypeEnum::TList { .. } => "TList",
|
||||||
|
TypeEnum::TObj { .. } => "TObj",
|
||||||
|
TypeEnum::TVirtual { .. } => "TVirtual",
|
||||||
|
TypeEnum::TCall { .. } => "TCall",
|
||||||
|
TypeEnum::TFunc { .. } => "TFunc",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub type SharedUnifier = Arc<Mutex<(UnificationTable<TypeEnum>, u32, Vec<Call>)>>;
|
||||||
|
|
||||||
|
pub struct Unifier {
|
||||||
|
unification_table: UnificationTable<Rc<TypeEnum>>,
|
||||||
|
calls: Vec<Rc<Call>>,
|
||||||
|
var_id: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Unifier {
|
||||||
|
/// Get an empty unifier
|
||||||
|
pub fn new() -> Unifier {
|
||||||
|
Unifier { unification_table: UnificationTable::new(), var_id: 0, calls: Vec::new() }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Determine if the two types are the same
|
||||||
|
pub fn unioned(&mut self, a: Type, b: Type) -> bool {
|
||||||
|
self.unification_table.unioned(a, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_shared_unifier(unifier: &SharedUnifier) -> Unifier {
|
||||||
|
let lock = unifier.lock().unwrap();
|
||||||
|
Unifier {
|
||||||
|
unification_table: UnificationTable::from_send(&lock.0),
|
||||||
|
var_id: lock.1,
|
||||||
|
calls: lock.2.iter().map(|v| Rc::new(v.clone())).collect_vec(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_shared_unifier(&self) -> SharedUnifier {
|
||||||
|
Arc::new(Mutex::new((
|
||||||
|
self.unification_table.get_send(),
|
||||||
|
self.var_id,
|
||||||
|
self.calls.iter().map(|v| v.as_ref().clone()).collect_vec(),
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Register a type to the unifier.
|
||||||
|
/// Returns a key in the unification_table.
|
||||||
|
pub fn add_ty(&mut self, a: TypeEnum) -> Type {
|
||||||
|
self.unification_table.new_key(Rc::new(a))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn add_record(&mut self, fields: Mapping<String>) -> Type {
|
||||||
|
let id = self.var_id + 1;
|
||||||
|
self.var_id += 1;
|
||||||
|
self.add_ty(TypeEnum::TVar {
|
||||||
|
id,
|
||||||
|
range: vec![].into(),
|
||||||
|
meta: TypeVarMeta::Record(fields.into()),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn add_call(&mut self, call: Call) -> CallId {
|
||||||
|
let id = CallId(self.calls.len());
|
||||||
|
self.calls.push(Rc::new(call));
|
||||||
|
id
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_representative(&mut self, ty: Type) -> Type {
|
||||||
|
self.unification_table.get_representative(ty)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn add_sequence(&mut self, sequence: Mapping<i32>) -> Type {
|
||||||
|
let id = self.var_id + 1;
|
||||||
|
self.var_id += 1;
|
||||||
|
self.add_ty(TypeEnum::TVar {
|
||||||
|
id,
|
||||||
|
range: vec![].into(),
|
||||||
|
meta: TypeVarMeta::Sequence(sequence.into()),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the TypeEnum of a type.
|
||||||
|
pub fn get_ty(&mut self, a: Type) -> Rc<TypeEnum> {
|
||||||
|
self.unification_table.probe_value(a).clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_fresh_rigid_var(&mut self) -> (Type, u32) {
|
||||||
|
let id = self.var_id + 1;
|
||||||
|
self.var_id += 1;
|
||||||
|
(self.add_ty(TypeEnum::TRigidVar { id }), id)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_fresh_var(&mut self) -> (Type, u32) {
|
||||||
|
self.get_fresh_var_with_range(&[])
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a fresh type variable.
|
||||||
|
pub fn get_fresh_var_with_range(&mut self, range: &[Type]) -> (Type, u32) {
|
||||||
|
let id = self.var_id + 1;
|
||||||
|
self.var_id += 1;
|
||||||
|
let range = range.to_vec().into();
|
||||||
|
(self.add_ty(TypeEnum::TVar { id, range, meta: TypeVarMeta::Generic }), id)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Unification would not unify rigid variables with other types, but we want to do this for
|
||||||
|
/// function instantiations, so we make it explicit.
|
||||||
|
pub fn replace_rigid_var(&mut self, rigid: Type, b: Type) {
|
||||||
|
assert!(matches!(&*self.get_ty(rigid), TypeEnum::TRigidVar { .. }));
|
||||||
|
self.set_a_to_b(rigid, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_instantiations(&mut self, ty: Type) -> Option<Vec<Type>> {
|
||||||
|
match &*self.get_ty(ty) {
|
||||||
|
TypeEnum::TVar { range, .. } => {
|
||||||
|
let range = range.borrow();
|
||||||
|
if range.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(
|
||||||
|
range
|
||||||
|
.iter()
|
||||||
|
.map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty]))
|
||||||
|
.flatten()
|
||||||
|
.collect_vec(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TypeEnum::TList { ty } => self
|
||||||
|
.get_instantiations(*ty)
|
||||||
|
.map(|ty| ty.iter().map(|&ty| self.add_ty(TypeEnum::TList { ty })).collect_vec()),
|
||||||
|
TypeEnum::TVirtual { ty } => self.get_instantiations(*ty).map(|ty| {
|
||||||
|
ty.iter().map(|&ty| self.add_ty(TypeEnum::TVirtual { ty })).collect_vec()
|
||||||
|
}),
|
||||||
|
TypeEnum::TTuple { ty } => {
|
||||||
|
let tuples = ty
|
||||||
|
.iter()
|
||||||
|
.map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty]))
|
||||||
|
.multi_cartesian_product()
|
||||||
|
.collect_vec();
|
||||||
|
if tuples.len() == 1 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(
|
||||||
|
tuples.into_iter().map(|ty| self.add_ty(TypeEnum::TTuple { ty })).collect(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TypeEnum::TObj { params, .. } => {
|
||||||
|
let params = params.borrow();
|
||||||
|
let (keys, params): (Vec<&u32>, Vec<&Type>) = params.iter().unzip();
|
||||||
|
let params = params
|
||||||
|
.into_iter()
|
||||||
|
.map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty]))
|
||||||
|
.multi_cartesian_product()
|
||||||
|
.collect_vec();
|
||||||
|
if params.len() <= 1 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(
|
||||||
|
params
|
||||||
|
.into_iter()
|
||||||
|
.map(|params| {
|
||||||
|
self.subst(
|
||||||
|
ty,
|
||||||
|
&zip(keys.iter().cloned().cloned(), params.iter().cloned())
|
||||||
|
.collect(),
|
||||||
|
)
|
||||||
|
.unwrap_or(ty)
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool {
|
||||||
|
use TypeEnum::*;
|
||||||
|
match &*self.get_ty(a) {
|
||||||
|
TRigidVar { .. } => true,
|
||||||
|
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
|
||||||
|
TCall { .. } => false,
|
||||||
|
TList { ty } => self.is_concrete(*ty, allowed_typevars),
|
||||||
|
TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)),
|
||||||
|
TObj { params: vars, .. } => {
|
||||||
|
vars.borrow().values().all(|ty| self.is_concrete(*ty, allowed_typevars))
|
||||||
|
}
|
||||||
|
// functions are instantiated for each call sites, so the function type can contain
|
||||||
|
// type variables.
|
||||||
|
TFunc { .. } => true,
|
||||||
|
TVirtual { ty } => self.is_concrete(*ty, allowed_typevars),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn unify(&mut self, a: Type, b: Type) -> Result<(), String> {
|
||||||
|
if self.unification_table.unioned(a, b) {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
self.unify_impl(a, b, false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unify_impl(&mut self, a: Type, b: Type, swapped: bool) -> Result<(), String> {
|
||||||
|
use TypeEnum::*;
|
||||||
|
use TypeVarMeta::*;
|
||||||
|
let (ty_a, ty_b) = {
|
||||||
|
(
|
||||||
|
self.unification_table.probe_value(a).clone(),
|
||||||
|
self.unification_table.probe_value(b).clone(),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
match (&*ty_a, &*ty_b) {
|
||||||
|
(TVar { meta: meta1, range: range1, .. }, TVar { meta: meta2, range: range2, .. }) => {
|
||||||
|
self.occur_check(a, b)?;
|
||||||
|
self.occur_check(b, a)?;
|
||||||
|
match (meta1, meta2) {
|
||||||
|
(Generic, _) => {}
|
||||||
|
(_, Generic) => {
|
||||||
|
return self.unify_impl(b, a, true);
|
||||||
|
}
|
||||||
|
(Record(fields1), Record(fields2)) => {
|
||||||
|
let mut fields2 = fields2.borrow_mut();
|
||||||
|
for (key, value) in fields1.borrow().iter() {
|
||||||
|
if let Some(ty) = fields2.get(key) {
|
||||||
|
self.unify(*ty, *value)?;
|
||||||
|
} else {
|
||||||
|
fields2.insert(key.clone(), *value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(Sequence(map1), Sequence(map2)) => {
|
||||||
|
let mut map2 = map2.borrow_mut();
|
||||||
|
for (key, value) in map1.borrow().iter() {
|
||||||
|
if let Some(ty) = map2.get(key) {
|
||||||
|
self.unify(*ty, *value)?;
|
||||||
|
} else {
|
||||||
|
map2.insert(*key, *value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
return Err("Incompatible".to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let range1 = range1.borrow();
|
||||||
|
// new range is the intersection of them
|
||||||
|
// empty range indicates no constraint
|
||||||
|
if !range1.is_empty() {
|
||||||
|
let old_range2 = range2.take();
|
||||||
|
let mut range2 = range2.borrow_mut();
|
||||||
|
if old_range2.is_empty() {
|
||||||
|
range2.extend_from_slice(&range1);
|
||||||
|
}
|
||||||
|
for v1 in old_range2.iter() {
|
||||||
|
for v2 in range1.iter() {
|
||||||
|
if let Ok(result) = self.get_intersection(*v1, *v2) {
|
||||||
|
range2.push(result.unwrap_or(*v2));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if range2.is_empty() {
|
||||||
|
return Err(
|
||||||
|
"cannot unify type variables with incompatible value range".to_string()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.set_a_to_b(a, b);
|
||||||
|
}
|
||||||
|
(TVar { meta: Generic, id, range, .. }, _) => {
|
||||||
|
self.occur_check(a, b)?;
|
||||||
|
// We check for the range of the type variable to see if unification is allowed.
|
||||||
|
// Note that although b may be compatible with a, we may have to constrain type
|
||||||
|
// variables in b to make sure that instantiations of b would always be compatible
|
||||||
|
// with a.
|
||||||
|
// The return value x of check_var_compatibility would be a new type that is
|
||||||
|
// guaranteed to be compatible with a under all possible instantiations. So we
|
||||||
|
// unify x with b to recursively apply the constrains, and then set a to x.
|
||||||
|
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
|
||||||
|
self.unify(x, b)?;
|
||||||
|
self.set_a_to_b(a, x);
|
||||||
|
}
|
||||||
|
(TVar { meta: Sequence(map), id, range, .. }, TTuple { ty }) => {
|
||||||
|
self.occur_check(a, b)?;
|
||||||
|
let len = ty.len() as i32;
|
||||||
|
for (k, v) in map.borrow().iter() {
|
||||||
|
// handle negative index
|
||||||
|
let ind = if *k < 0 { len + *k } else { *k };
|
||||||
|
if ind >= len || ind < 0 {
|
||||||
|
return Err(format!(
|
||||||
|
"Tuple index out of range. (Length: {}, Index: {})",
|
||||||
|
len, k
|
||||||
|
));
|
||||||
|
}
|
||||||
|
self.unify(*v, ty[ind as usize])?;
|
||||||
|
}
|
||||||
|
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
|
||||||
|
self.unify(x, b)?;
|
||||||
|
self.set_a_to_b(a, x);
|
||||||
|
}
|
||||||
|
(TVar { meta: Sequence(map), id, range, .. }, TList { ty }) => {
|
||||||
|
self.occur_check(a, b)?;
|
||||||
|
for v in map.borrow().values() {
|
||||||
|
self.unify(*v, *ty)?;
|
||||||
|
}
|
||||||
|
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
|
||||||
|
self.unify(x, b)?;
|
||||||
|
self.set_a_to_b(a, x);
|
||||||
|
}
|
||||||
|
(TTuple { ty: ty1 }, TTuple { ty: ty2 }) => {
|
||||||
|
if ty1.len() != ty2.len() {
|
||||||
|
return Err(format!(
|
||||||
|
"Cannot unify tuples with length {} and {}",
|
||||||
|
ty1.len(),
|
||||||
|
ty2.len()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
for (x, y) in ty1.iter().zip(ty2.iter()) {
|
||||||
|
self.unify(*x, *y)?;
|
||||||
|
}
|
||||||
|
self.set_a_to_b(a, b);
|
||||||
|
}
|
||||||
|
(TList { ty: ty1 }, TList { ty: ty2 }) => {
|
||||||
|
self.unify(*ty1, *ty2)?;
|
||||||
|
self.set_a_to_b(a, b);
|
||||||
|
}
|
||||||
|
(TVar { meta: Record(map), id, range, .. }, TObj { fields, .. }) => {
|
||||||
|
self.occur_check(a, b)?;
|
||||||
|
for (k, v) in map.borrow().iter() {
|
||||||
|
let ty = fields
|
||||||
|
.borrow()
|
||||||
|
.get(k)
|
||||||
|
.copied()
|
||||||
|
.ok_or_else(|| format!("No such attribute {}", k))?;
|
||||||
|
self.unify(ty, *v)?;
|
||||||
|
}
|
||||||
|
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
|
||||||
|
self.unify(x, b)?;
|
||||||
|
self.set_a_to_b(a, x);
|
||||||
|
}
|
||||||
|
(TVar { meta: Record(map), id, range, .. }, TVirtual { ty }) => {
|
||||||
|
self.occur_check(a, b)?;
|
||||||
|
let ty = self.get_ty(*ty);
|
||||||
|
if let TObj { fields, .. } = ty.as_ref() {
|
||||||
|
for (k, v) in map.borrow().iter() {
|
||||||
|
let ty = fields
|
||||||
|
.borrow()
|
||||||
|
.get(k)
|
||||||
|
.copied()
|
||||||
|
.ok_or_else(|| format!("No such attribute {}", k))?;
|
||||||
|
if !matches!(self.get_ty(ty).as_ref(), TFunc { .. }) {
|
||||||
|
return Err(format!("Cannot access field {} for virtual type", k));
|
||||||
|
}
|
||||||
|
self.unify(*v, ty)?;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// require annotation...
|
||||||
|
return Err("Requires type annotation for virtual".to_string());
|
||||||
|
}
|
||||||
|
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
|
||||||
|
self.unify(x, b)?;
|
||||||
|
self.set_a_to_b(a, x);
|
||||||
|
}
|
||||||
|
(
|
||||||
|
TObj { obj_id: id1, params: params1, .. },
|
||||||
|
TObj { obj_id: id2, params: params2, .. },
|
||||||
|
) => {
|
||||||
|
if id1 != id2 {
|
||||||
|
return Err(format!("Cannot unify objects with ID {} and {}", id1.0, id2.0));
|
||||||
|
}
|
||||||
|
for (x, y) in zip(params1.borrow().values(), params2.borrow().values()) {
|
||||||
|
self.unify(*x, *y)?;
|
||||||
|
}
|
||||||
|
self.set_a_to_b(a, b);
|
||||||
|
}
|
||||||
|
(TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => {
|
||||||
|
self.unify(*ty1, *ty2)?;
|
||||||
|
self.set_a_to_b(a, b);
|
||||||
|
}
|
||||||
|
(TCall(calls1), TCall(calls2)) => {
|
||||||
|
// we do not unify individual calls, instead we defer until the unification wtih a
|
||||||
|
// function definition.
|
||||||
|
calls2.borrow_mut().extend_from_slice(&calls1.borrow());
|
||||||
|
}
|
||||||
|
(TCall(calls), TFunc(signature)) => {
|
||||||
|
self.occur_check(a, b)?;
|
||||||
|
let required: Vec<String> = signature
|
||||||
|
.borrow()
|
||||||
|
.args
|
||||||
|
.iter()
|
||||||
|
.filter(|v| v.default_value.is_none())
|
||||||
|
.map(|v| v.name.clone())
|
||||||
|
.rev()
|
||||||
|
.collect();
|
||||||
|
// we unify every calls to the function signature.
|
||||||
|
for c in calls.borrow().iter() {
|
||||||
|
let Call { posargs, kwargs, ret, fun } = &*self.calls[c.0].clone();
|
||||||
|
let instantiated = self.instantiate_fun(b, &*signature.borrow());
|
||||||
|
let r = self.get_ty(instantiated);
|
||||||
|
let r = r.as_ref();
|
||||||
|
let signature;
|
||||||
|
if let TypeEnum::TFunc(s) = &*r {
|
||||||
|
signature = s;
|
||||||
|
} else {
|
||||||
|
unreachable!();
|
||||||
|
}
|
||||||
|
// we check to make sure that all required arguments (those without default
|
||||||
|
// arguments) are provided, and do not provide the same argument twice.
|
||||||
|
let mut required = required.clone();
|
||||||
|
let mut all_names: Vec<_> = signature
|
||||||
|
.borrow()
|
||||||
|
.args
|
||||||
|
.iter()
|
||||||
|
.map(|v| (v.name.clone(), v.ty))
|
||||||
|
.rev()
|
||||||
|
.collect();
|
||||||
|
for (i, t) in posargs.iter().enumerate() {
|
||||||
|
if signature.borrow().args.len() <= i {
|
||||||
|
return Err("Too many arguments.".to_string());
|
||||||
|
}
|
||||||
|
if !required.is_empty() {
|
||||||
|
required.pop();
|
||||||
|
}
|
||||||
|
self.unify(all_names.pop().unwrap().1, *t)?;
|
||||||
|
}
|
||||||
|
for (k, t) in kwargs.iter() {
|
||||||
|
if let Some(i) = required.iter().position(|v| v == k) {
|
||||||
|
required.remove(i);
|
||||||
|
}
|
||||||
|
let i = all_names
|
||||||
|
.iter()
|
||||||
|
.position(|v| &v.0 == k)
|
||||||
|
.ok_or_else(|| format!("Unknown keyword argument {}", k))?;
|
||||||
|
self.unify(all_names.remove(i).1, *t)?;
|
||||||
|
}
|
||||||
|
if !required.is_empty() {
|
||||||
|
return Err("Expected more arguments".to_string());
|
||||||
|
}
|
||||||
|
self.unify(*ret, signature.borrow().ret)?;
|
||||||
|
*fun.borrow_mut() = Some(instantiated);
|
||||||
|
}
|
||||||
|
self.set_a_to_b(a, b);
|
||||||
|
}
|
||||||
|
(TFunc(sign1), TFunc(sign2)) => {
|
||||||
|
let (sign1, sign2) = (&*sign1.borrow(), &*sign2.borrow());
|
||||||
|
if !sign1.vars.is_empty() || !sign2.vars.is_empty() {
|
||||||
|
return Err("Polymorphic function pointer is prohibited.".to_string());
|
||||||
|
}
|
||||||
|
if sign1.args.len() != sign2.args.len() {
|
||||||
|
return Err("Functions differ in number of parameters.".to_string());
|
||||||
|
}
|
||||||
|
for (x, y) in sign1.args.iter().zip(sign2.args.iter()) {
|
||||||
|
if x.name != y.name {
|
||||||
|
return Err("Functions differ in parameter names.".to_string());
|
||||||
|
}
|
||||||
|
if x.default_value != y.default_value {
|
||||||
|
return Err("Functions differ in optional parameters.".to_string());
|
||||||
|
}
|
||||||
|
self.unify(x.ty, y.ty)?;
|
||||||
|
}
|
||||||
|
self.unify(sign1.ret, sign2.ret)?;
|
||||||
|
self.set_a_to_b(a, b);
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
if swapped {
|
||||||
|
return self.incompatible_types(&*ty_a, &*ty_b);
|
||||||
|
} else {
|
||||||
|
self.unify_impl(b, a, true)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get string representation of the type
|
||||||
|
pub fn stringify<F, G>(&mut self, ty: Type, obj_to_name: &mut F, var_to_name: &mut G) -> String
|
||||||
|
where
|
||||||
|
F: FnMut(usize) -> String,
|
||||||
|
G: FnMut(u32) -> String,
|
||||||
|
{
|
||||||
|
use TypeVarMeta::*;
|
||||||
|
let ty = self.unification_table.probe_value(ty).clone();
|
||||||
|
match ty.as_ref() {
|
||||||
|
TypeEnum::TRigidVar { id } => var_to_name(*id),
|
||||||
|
TypeEnum::TVar { id, meta: Generic, .. } => var_to_name(*id),
|
||||||
|
TypeEnum::TVar { meta: Sequence(map), .. } => {
|
||||||
|
let fields = map
|
||||||
|
.borrow()
|
||||||
|
.iter()
|
||||||
|
.map(|(k, v)| format!("{}={}", k, self.stringify(*v, obj_to_name, var_to_name)))
|
||||||
|
.join(", ");
|
||||||
|
format!("seq[{}]", fields)
|
||||||
|
}
|
||||||
|
TypeEnum::TVar { meta: Record(fields), .. } => {
|
||||||
|
let fields = fields
|
||||||
|
.borrow()
|
||||||
|
.iter()
|
||||||
|
.map(|(k, v)| format!("{}={}", k, self.stringify(*v, obj_to_name, var_to_name)))
|
||||||
|
.join(", ");
|
||||||
|
format!("record[{}]", fields)
|
||||||
|
}
|
||||||
|
TypeEnum::TTuple { ty } => {
|
||||||
|
let mut fields = ty.iter().map(|v| self.stringify(*v, obj_to_name, var_to_name));
|
||||||
|
format!("tuple[{}]", fields.join(", "))
|
||||||
|
}
|
||||||
|
TypeEnum::TList { ty } => {
|
||||||
|
format!("list[{}]", self.stringify(*ty, obj_to_name, var_to_name))
|
||||||
|
}
|
||||||
|
TypeEnum::TVirtual { ty } => {
|
||||||
|
format!("virtual[{}]", self.stringify(*ty, obj_to_name, var_to_name))
|
||||||
|
}
|
||||||
|
TypeEnum::TObj { obj_id, params, .. } => {
|
||||||
|
let name = obj_to_name(obj_id.0);
|
||||||
|
let params = params.borrow();
|
||||||
|
if !params.is_empty() {
|
||||||
|
let mut params =
|
||||||
|
params.values().map(|v| self.stringify(*v, obj_to_name, var_to_name));
|
||||||
|
format!("{}[{}]", name, params.join(", "))
|
||||||
|
} else {
|
||||||
|
name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TypeEnum::TCall { .. } => "call".to_owned(),
|
||||||
|
TypeEnum::TFunc(signature) => {
|
||||||
|
let params = signature
|
||||||
|
.borrow()
|
||||||
|
.args
|
||||||
|
.iter()
|
||||||
|
.map(|arg| {
|
||||||
|
format!("{}={}", arg.name, self.stringify(arg.ty, obj_to_name, var_to_name))
|
||||||
|
})
|
||||||
|
.join(", ");
|
||||||
|
let ret = self.stringify(signature.borrow().ret, obj_to_name, var_to_name);
|
||||||
|
format!("fn[[{}], {}]", params, ret)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_a_to_b(&mut self, a: Type, b: Type) {
|
||||||
|
// unify a and b together, and set the value to b's value.
|
||||||
|
let table = &mut self.unification_table;
|
||||||
|
let ty_b = table.probe_value(b).clone();
|
||||||
|
table.unify(a, b);
|
||||||
|
table.set_value(a, ty_b)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn incompatible_types(&self, a: &TypeEnum, b: &TypeEnum) -> Result<(), String> {
|
||||||
|
Err(format!("Cannot unify {} with {}", a.get_type_name(), b.get_type_name()))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Instantiate a function if it hasn't been instantiated.
|
||||||
|
/// Returns Some(T) where T is the instantiated type.
|
||||||
|
/// Returns None if the function is already instantiated.
|
||||||
|
fn instantiate_fun(&mut self, ty: Type, fun: &FunSignature) -> Type {
|
||||||
|
let mut instantiated = false;
|
||||||
|
let mut vars = Vec::new();
|
||||||
|
for (k, v) in fun.vars.iter() {
|
||||||
|
if let TypeEnum::TVar { id, range, .. } =
|
||||||
|
self.unification_table.probe_value(*v).as_ref()
|
||||||
|
{
|
||||||
|
if k != id {
|
||||||
|
instantiated = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
// actually, if the first check succeeded, the function should be uninstatiated.
|
||||||
|
// The cloned values must be used and would not be wasted.
|
||||||
|
vars.push((*k, range.clone()));
|
||||||
|
} else {
|
||||||
|
instantiated = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if instantiated {
|
||||||
|
ty
|
||||||
|
} else {
|
||||||
|
let mapping = vars
|
||||||
|
.into_iter()
|
||||||
|
.map(|(k, range)| (k, self.get_fresh_var_with_range(range.borrow().as_ref()).0))
|
||||||
|
.collect();
|
||||||
|
self.subst(ty, &mapping).unwrap_or(ty)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Substitute type variables within a type into other types.
|
||||||
|
/// If this returns Some(T), T would be the substituted type.
|
||||||
|
/// If this returns None, the result type would be the original type
|
||||||
|
/// (no substitution has to be done).
|
||||||
|
pub fn subst(&mut self, a: Type, mapping: &VarMap) -> Option<Type> {
|
||||||
|
use TypeVarMeta::*;
|
||||||
|
let ty = self.unification_table.probe_value(a).clone();
|
||||||
|
// this function would only be called when we instantiate functions.
|
||||||
|
// function type signature should ONLY contain concrete types and type
|
||||||
|
// variables, i.e. things like TRecord, TCall should not occur, and we
|
||||||
|
// should be safe to not implement the substitution for those variants.
|
||||||
|
match &*ty {
|
||||||
|
TypeEnum::TRigidVar { .. } => None,
|
||||||
|
TypeEnum::TVar { id, meta: Generic, .. } => mapping.get(&id).cloned(),
|
||||||
|
TypeEnum::TTuple { ty } => {
|
||||||
|
let mut new_ty = Cow::from(ty);
|
||||||
|
for (i, t) in ty.iter().enumerate() {
|
||||||
|
if let Some(t1) = self.subst(*t, mapping) {
|
||||||
|
new_ty.to_mut()[i] = t1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if matches!(new_ty, Cow::Owned(_)) {
|
||||||
|
Some(self.add_ty(TypeEnum::TTuple { ty: new_ty.into_owned() }))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TypeEnum::TList { ty } => {
|
||||||
|
self.subst(*ty, mapping).map(|t| self.add_ty(TypeEnum::TList { ty: t }))
|
||||||
|
}
|
||||||
|
TypeEnum::TVirtual { ty } => {
|
||||||
|
self.subst(*ty, mapping).map(|t| self.add_ty(TypeEnum::TVirtual { ty: t }))
|
||||||
|
}
|
||||||
|
TypeEnum::TObj { obj_id, fields, params } => {
|
||||||
|
// Type variables in field types must be present in the type parameter.
|
||||||
|
// If the mapping does not contain any type variables in the
|
||||||
|
// parameter list, we don't need to substitute the fields.
|
||||||
|
// This is also used to prevent infinite substitution...
|
||||||
|
let params = params.borrow();
|
||||||
|
let need_subst = params.values().any(|v| {
|
||||||
|
let ty = self.unification_table.probe_value(*v);
|
||||||
|
if let TypeEnum::TVar { id, .. } = ty.as_ref() {
|
||||||
|
mapping.contains_key(&id)
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
});
|
||||||
|
if need_subst {
|
||||||
|
let obj_id = *obj_id;
|
||||||
|
let params = self.subst_map(¶ms, mapping).unwrap_or_else(|| params.clone());
|
||||||
|
let fields = self
|
||||||
|
.subst_map(&fields.borrow(), mapping)
|
||||||
|
.unwrap_or_else(|| fields.borrow().clone());
|
||||||
|
Some(self.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id,
|
||||||
|
params: params.into(),
|
||||||
|
fields: fields.into(),
|
||||||
|
}))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TypeEnum::TFunc(sig) => {
|
||||||
|
let FunSignature { args, ret, vars: params } = &*sig.borrow();
|
||||||
|
let new_params = self.subst_map(params, mapping);
|
||||||
|
let new_ret = self.subst(*ret, mapping);
|
||||||
|
let mut new_args = Cow::from(args);
|
||||||
|
for (i, t) in args.iter().enumerate() {
|
||||||
|
if let Some(t1) = self.subst(t.ty, mapping) {
|
||||||
|
let mut t = t.clone();
|
||||||
|
t.ty = t1;
|
||||||
|
new_args.to_mut()[i] = t;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if new_params.is_some() || new_ret.is_some() || matches!(new_args, Cow::Owned(..)) {
|
||||||
|
let params = new_params.unwrap_or_else(|| params.clone());
|
||||||
|
let ret = new_ret.unwrap_or_else(|| *ret);
|
||||||
|
let args = new_args.into_owned();
|
||||||
|
Some(
|
||||||
|
self.add_ty(TypeEnum::TFunc(
|
||||||
|
FunSignature { args, ret, vars: params }.into(),
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => unimplemented!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn subst_map<K>(&mut self, map: &Mapping<K>, mapping: &VarMap) -> Option<Mapping<K>>
|
||||||
|
where
|
||||||
|
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
|
||||||
|
{
|
||||||
|
let mut map2 = None;
|
||||||
|
for (k, v) in map.iter() {
|
||||||
|
if let Some(v1) = self.subst(*v, mapping) {
|
||||||
|
if map2.is_none() {
|
||||||
|
map2 = Some(map.clone());
|
||||||
|
}
|
||||||
|
*map2.as_mut().unwrap().get_mut(k).unwrap() = v1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
map2
|
||||||
|
}
|
||||||
|
|
||||||
|
fn occur_check(&mut self, a: Type, b: Type) -> Result<(), String> {
|
||||||
|
use TypeVarMeta::*;
|
||||||
|
if self.unification_table.unioned(a, b) {
|
||||||
|
return Err("Recursive type is prohibited.".to_owned());
|
||||||
|
}
|
||||||
|
let ty = self.unification_table.probe_value(b).clone();
|
||||||
|
|
||||||
|
match ty.as_ref() {
|
||||||
|
TypeEnum::TRigidVar { .. } | TypeEnum::TVar { meta: Generic, .. } => {}
|
||||||
|
TypeEnum::TVar { meta: Sequence(map), .. } => {
|
||||||
|
for t in map.borrow().values() {
|
||||||
|
self.occur_check(a, *t)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TypeEnum::TVar { meta: Record(map), .. } => {
|
||||||
|
for t in map.borrow().values() {
|
||||||
|
self.occur_check(a, *t)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TypeEnum::TCall(calls) => {
|
||||||
|
let call_store = self.calls.clone();
|
||||||
|
for t in calls
|
||||||
|
.borrow()
|
||||||
|
.iter()
|
||||||
|
.map(|call| {
|
||||||
|
let call = call_store[call.0].as_ref();
|
||||||
|
chain!(call.posargs.iter(), call.kwargs.values(), once(&call.ret))
|
||||||
|
})
|
||||||
|
.flatten()
|
||||||
|
{
|
||||||
|
self.occur_check(a, *t)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TypeEnum::TTuple { ty } => {
|
||||||
|
for t in ty.iter() {
|
||||||
|
self.occur_check(a, *t)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TypeEnum::TList { ty } | TypeEnum::TVirtual { ty } => {
|
||||||
|
self.occur_check(a, *ty)?;
|
||||||
|
}
|
||||||
|
TypeEnum::TObj { params: map, .. } => {
|
||||||
|
for t in map.borrow().values() {
|
||||||
|
self.occur_check(a, *t)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TypeEnum::TFunc(sig) => {
|
||||||
|
let FunSignature { args, ret, vars: params } = &*sig.borrow();
|
||||||
|
for t in chain!(args.iter().map(|v| &v.ty), params.values(), once(ret)) {
|
||||||
|
self.occur_check(a, *t)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_intersection(&mut self, a: Type, b: Type) -> Result<Option<Type>, ()> {
|
||||||
|
use TypeEnum::*;
|
||||||
|
let x = self.get_ty(a);
|
||||||
|
let y = self.get_ty(b);
|
||||||
|
match (x.as_ref(), y.as_ref()) {
|
||||||
|
(TVar { range: range1, .. }, TVar { meta, range: range2, .. }) => {
|
||||||
|
// we should restrict range2
|
||||||
|
let range1 = range1.borrow();
|
||||||
|
// new range is the intersection of them
|
||||||
|
// empty range indicates no constraint
|
||||||
|
if !range1.is_empty() {
|
||||||
|
let range2 = range2.borrow();
|
||||||
|
let mut range = Vec::new();
|
||||||
|
if range2.is_empty() {
|
||||||
|
range.extend_from_slice(&range1);
|
||||||
|
}
|
||||||
|
for v1 in range2.iter() {
|
||||||
|
for v2 in range1.iter() {
|
||||||
|
let result = self.get_intersection(*v1, *v2);
|
||||||
|
if let Ok(result) = result {
|
||||||
|
range.push(result.unwrap_or(*v2));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if range.is_empty() {
|
||||||
|
Err(())
|
||||||
|
} else {
|
||||||
|
let id = self.var_id + 1;
|
||||||
|
self.var_id += 1;
|
||||||
|
let ty = TVar { id, meta: meta.clone(), range: range.into() };
|
||||||
|
Ok(Some(self.unification_table.new_key(ty.into())))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Ok(Some(b))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(_, TVar { range, .. }) => {
|
||||||
|
// range should be restricted to the left hand side
|
||||||
|
let range = range.borrow();
|
||||||
|
if range.is_empty() {
|
||||||
|
Ok(Some(a))
|
||||||
|
} else {
|
||||||
|
for v in range.iter() {
|
||||||
|
let result = self.get_intersection(a, *v);
|
||||||
|
if let Ok(result) = result {
|
||||||
|
return Ok(result.or(Some(a)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(TVar { id, range, .. }, _) => {
|
||||||
|
self.check_var_compatibility(*id, b, &range.borrow()).or(Err(()))
|
||||||
|
}
|
||||||
|
(TTuple { ty: ty1 }, TTuple { ty: ty2 }) => {
|
||||||
|
if ty1.len() != ty2.len() {
|
||||||
|
return Err(());
|
||||||
|
}
|
||||||
|
let mut need_new = false;
|
||||||
|
let mut ty = ty1.clone();
|
||||||
|
for (a, b) in zip(ty1.iter(), ty2.iter()) {
|
||||||
|
let result = self.get_intersection(*a, *b)?;
|
||||||
|
ty.push(result.unwrap_or(*a));
|
||||||
|
if result.is_some() {
|
||||||
|
need_new = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if need_new {
|
||||||
|
Ok(Some(self.add_ty(TTuple { ty })))
|
||||||
|
} else {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(TList { ty: ty1 }, TList { ty: ty2 }) => {
|
||||||
|
Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TList { ty })))
|
||||||
|
}
|
||||||
|
(TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => {
|
||||||
|
Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TVirtual { ty })))
|
||||||
|
}
|
||||||
|
(TObj { obj_id: id1, .. }, TObj { obj_id: id2, .. }) => {
|
||||||
|
if id1 == id2 {
|
||||||
|
Ok(None)
|
||||||
|
} else {
|
||||||
|
Err(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// don't deal with function shape for now
|
||||||
|
_ => Err(()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn check_var_compatibility(
|
||||||
|
&mut self,
|
||||||
|
id: u32,
|
||||||
|
b: Type,
|
||||||
|
range: &[Type],
|
||||||
|
) -> Result<Option<Type>, String> {
|
||||||
|
if range.is_empty() {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
for t in range.iter() {
|
||||||
|
let result = self.get_intersection(*t, b);
|
||||||
|
if let Ok(result) = result {
|
||||||
|
return Ok(result);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Err(format!(
|
||||||
|
"Cannot unify type variable {} with {} due to incompatible value range",
|
||||||
|
id,
|
||||||
|
self.get_ty(b).get_type_name()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
534
nac3core/src/typecheck/typedef/test.rs
Normal file
534
nac3core/src/typecheck/typedef/test.rs
Normal file
@ -0,0 +1,534 @@
|
|||||||
|
use super::*;
|
||||||
|
use indoc::indoc;
|
||||||
|
use itertools::Itertools;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use test_case::test_case;
|
||||||
|
|
||||||
|
impl Unifier {
|
||||||
|
/// Check whether two types are equal.
|
||||||
|
fn eq(&mut self, a: Type, b: Type) -> bool {
|
||||||
|
use TypeVarMeta::*;
|
||||||
|
if a == b {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
let (ty_a, ty_b) = {
|
||||||
|
let table = &mut self.unification_table;
|
||||||
|
if table.unioned(a, b) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
(table.probe_value(a).clone(), table.probe_value(b).clone())
|
||||||
|
};
|
||||||
|
|
||||||
|
match (&*ty_a, &*ty_b) {
|
||||||
|
(
|
||||||
|
TypeEnum::TVar { meta: Generic, id: id1, .. },
|
||||||
|
TypeEnum::TVar { meta: Generic, id: id2, .. },
|
||||||
|
) => id1 == id2,
|
||||||
|
(
|
||||||
|
TypeEnum::TVar { meta: Sequence(map1), .. },
|
||||||
|
TypeEnum::TVar { meta: Sequence(map2), .. },
|
||||||
|
) => self.map_eq(&map1.borrow(), &map2.borrow()),
|
||||||
|
(TypeEnum::TTuple { ty: ty1 }, TypeEnum::TTuple { ty: ty2 }) => {
|
||||||
|
ty1.len() == ty2.len()
|
||||||
|
&& ty1.iter().zip(ty2.iter()).all(|(t1, t2)| self.eq(*t1, *t2))
|
||||||
|
}
|
||||||
|
(TypeEnum::TList { ty: ty1 }, TypeEnum::TList { ty: ty2 })
|
||||||
|
| (TypeEnum::TVirtual { ty: ty1 }, TypeEnum::TVirtual { ty: ty2 }) => {
|
||||||
|
self.eq(*ty1, *ty2)
|
||||||
|
}
|
||||||
|
(
|
||||||
|
TypeEnum::TVar { meta: Record(fields1), .. },
|
||||||
|
TypeEnum::TVar { meta: Record(fields2), .. },
|
||||||
|
) => self.map_eq(&fields1.borrow(), &fields2.borrow()),
|
||||||
|
(
|
||||||
|
TypeEnum::TObj { obj_id: id1, params: params1, .. },
|
||||||
|
TypeEnum::TObj { obj_id: id2, params: params2, .. },
|
||||||
|
) => id1 == id2 && self.map_eq(¶ms1.borrow(), ¶ms2.borrow()),
|
||||||
|
// TCall and TFunc are not yet implemented
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn map_eq<K>(&mut self, map1: &Mapping<K>, map2: &Mapping<K>) -> bool
|
||||||
|
where
|
||||||
|
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
|
||||||
|
{
|
||||||
|
if map1.len() != map2.len() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (k, v) in map1.iter() {
|
||||||
|
if !map2.get(k).map(|v1| self.eq(*v, *v1)).unwrap_or(false) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct TestEnvironment {
|
||||||
|
pub unifier: Unifier,
|
||||||
|
type_mapping: HashMap<String, Type>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TestEnvironment {
|
||||||
|
fn new() -> TestEnvironment {
|
||||||
|
let mut unifier = Unifier::new();
|
||||||
|
let mut type_mapping = HashMap::new();
|
||||||
|
|
||||||
|
type_mapping.insert(
|
||||||
|
"int".into(),
|
||||||
|
unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: DefinitionId(0),
|
||||||
|
fields: HashMap::new().into(),
|
||||||
|
params: HashMap::new().into(),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
type_mapping.insert(
|
||||||
|
"float".into(),
|
||||||
|
unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: DefinitionId(1),
|
||||||
|
fields: HashMap::new().into(),
|
||||||
|
params: HashMap::new().into(),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
type_mapping.insert(
|
||||||
|
"bool".into(),
|
||||||
|
unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: DefinitionId(2),
|
||||||
|
fields: HashMap::new().into(),
|
||||||
|
params: HashMap::new().into(),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
let (v0, id) = unifier.get_fresh_var();
|
||||||
|
type_mapping.insert(
|
||||||
|
"Foo".into(),
|
||||||
|
unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: DefinitionId(3),
|
||||||
|
fields: [("a".into(), v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
|
||||||
|
params: [(id, v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
TestEnvironment { unifier, type_mapping }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse(&mut self, typ: &str, mapping: &Mapping<String>) -> Type {
|
||||||
|
let result = self.internal_parse(typ, mapping);
|
||||||
|
assert!(result.1.is_empty());
|
||||||
|
result.0
|
||||||
|
}
|
||||||
|
|
||||||
|
fn internal_parse<'a, 'b>(
|
||||||
|
&'a mut self,
|
||||||
|
typ: &'b str,
|
||||||
|
mapping: &Mapping<String>,
|
||||||
|
) -> (Type, &'b str) {
|
||||||
|
// for testing only, so we can just panic when the input is malformed
|
||||||
|
let end = typ.find(|c| ['[', ',', ']', '='].contains(&c)).unwrap_or_else(|| typ.len());
|
||||||
|
match &typ[..end] {
|
||||||
|
"Tuple" => {
|
||||||
|
let mut s = &typ[end..];
|
||||||
|
assert!(&s[0..1] == "[");
|
||||||
|
let mut ty = Vec::new();
|
||||||
|
while &s[0..1] != "]" {
|
||||||
|
let result = self.internal_parse(&s[1..], mapping);
|
||||||
|
ty.push(result.0);
|
||||||
|
s = result.1;
|
||||||
|
}
|
||||||
|
(self.unifier.add_ty(TypeEnum::TTuple { ty }), &s[1..])
|
||||||
|
}
|
||||||
|
"List" => {
|
||||||
|
assert!(&typ[end..end + 1] == "[");
|
||||||
|
let (ty, s) = self.internal_parse(&typ[end + 1..], mapping);
|
||||||
|
assert!(&s[0..1] == "]");
|
||||||
|
(self.unifier.add_ty(TypeEnum::TList { ty }), &s[1..])
|
||||||
|
}
|
||||||
|
"Record" => {
|
||||||
|
let mut s = &typ[end..];
|
||||||
|
assert!(&s[0..1] == "[");
|
||||||
|
let mut fields = HashMap::new();
|
||||||
|
while &s[0..1] != "]" {
|
||||||
|
let eq = s.find('=').unwrap();
|
||||||
|
let key = s[1..eq].to_string();
|
||||||
|
let result = self.internal_parse(&s[eq + 1..], mapping);
|
||||||
|
fields.insert(key, result.0);
|
||||||
|
s = result.1;
|
||||||
|
}
|
||||||
|
(self.unifier.add_record(fields), &s[1..])
|
||||||
|
}
|
||||||
|
x => {
|
||||||
|
let mut s = &typ[end..];
|
||||||
|
let ty = mapping.get(x).cloned().unwrap_or_else(|| {
|
||||||
|
// mapping should be type variables, type_mapping should be concrete types
|
||||||
|
// we should not resolve the type of type variables.
|
||||||
|
let mut ty = *self.type_mapping.get(x).unwrap();
|
||||||
|
let te = self.unifier.get_ty(ty);
|
||||||
|
if let TypeEnum::TObj { params, .. } = &*te.as_ref() {
|
||||||
|
let params = params.borrow();
|
||||||
|
if !params.is_empty() {
|
||||||
|
assert!(&s[0..1] == "[");
|
||||||
|
let mut p = Vec::new();
|
||||||
|
while &s[0..1] != "]" {
|
||||||
|
let result = self.internal_parse(&s[1..], mapping);
|
||||||
|
p.push(result.0);
|
||||||
|
s = result.1;
|
||||||
|
}
|
||||||
|
s = &s[1..];
|
||||||
|
ty = self
|
||||||
|
.unifier
|
||||||
|
.subst(ty, ¶ms.keys().cloned().zip(p.into_iter()).collect())
|
||||||
|
.unwrap_or(ty);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ty
|
||||||
|
});
|
||||||
|
(ty, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test_case(2,
|
||||||
|
&[("v1", "v2"), ("v2", "float")],
|
||||||
|
&[("v1", "float"), ("v2", "float")]
|
||||||
|
; "simple variable"
|
||||||
|
)]
|
||||||
|
#[test_case(2,
|
||||||
|
&[("v1", "List[v2]"), ("v1", "List[float]")],
|
||||||
|
&[("v1", "List[float]"), ("v2", "float")]
|
||||||
|
; "list element"
|
||||||
|
)]
|
||||||
|
#[test_case(3,
|
||||||
|
&[
|
||||||
|
("v1", "Record[a=v3,b=v3]"),
|
||||||
|
("v2", "Record[b=float,c=v3]"),
|
||||||
|
("v1", "v2")
|
||||||
|
],
|
||||||
|
&[
|
||||||
|
("v1", "Record[a=float,b=float,c=float]"),
|
||||||
|
("v2", "Record[a=float,b=float,c=float]"),
|
||||||
|
("v3", "float")
|
||||||
|
]
|
||||||
|
; "record merge"
|
||||||
|
)]
|
||||||
|
#[test_case(3,
|
||||||
|
&[
|
||||||
|
("v1", "Record[a=float]"),
|
||||||
|
("v2", "Foo[v3]"),
|
||||||
|
("v1", "v2")
|
||||||
|
],
|
||||||
|
&[
|
||||||
|
("v1", "Foo[float]"),
|
||||||
|
("v3", "float")
|
||||||
|
]
|
||||||
|
; "record obj merge"
|
||||||
|
)]
|
||||||
|
/// Test cases for valid unifications.
|
||||||
|
fn test_unify(
|
||||||
|
variable_count: u32,
|
||||||
|
unify_pairs: &[(&'static str, &'static str)],
|
||||||
|
verify_pairs: &[(&'static str, &'static str)],
|
||||||
|
) {
|
||||||
|
let unify_count = unify_pairs.len();
|
||||||
|
// test all permutations...
|
||||||
|
for perm in unify_pairs.iter().permutations(unify_count) {
|
||||||
|
let mut env = TestEnvironment::new();
|
||||||
|
let mut mapping = HashMap::new();
|
||||||
|
for i in 1..=variable_count {
|
||||||
|
let v = env.unifier.get_fresh_var();
|
||||||
|
mapping.insert(format!("v{}", i), v.0);
|
||||||
|
}
|
||||||
|
// unification may have side effect when we do type resolution, so freeze the types
|
||||||
|
// before doing unification.
|
||||||
|
let mut pairs = Vec::new();
|
||||||
|
for (a, b) in perm.iter() {
|
||||||
|
let t1 = env.parse(a, &mapping);
|
||||||
|
let t2 = env.parse(b, &mapping);
|
||||||
|
pairs.push((t1, t2));
|
||||||
|
}
|
||||||
|
for (t1, t2) in pairs {
|
||||||
|
env.unifier.unify(t1, t2).unwrap();
|
||||||
|
}
|
||||||
|
for (a, b) in verify_pairs.iter() {
|
||||||
|
println!("{} = {}", a, b);
|
||||||
|
let t1 = env.parse(a, &mapping);
|
||||||
|
let t2 = env.parse(b, &mapping);
|
||||||
|
assert!(env.unifier.eq(t1, t2));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test_case(2,
|
||||||
|
&[
|
||||||
|
("v1", "Tuple[int]"),
|
||||||
|
("v2", "List[int]"),
|
||||||
|
],
|
||||||
|
(("v1", "v2"), "Cannot unify TList with TTuple")
|
||||||
|
; "type mismatch"
|
||||||
|
)]
|
||||||
|
#[test_case(2,
|
||||||
|
&[
|
||||||
|
("v1", "Tuple[int]"),
|
||||||
|
("v2", "Tuple[float]"),
|
||||||
|
],
|
||||||
|
(("v1", "v2"), "Cannot unify objects with ID 0 and 1")
|
||||||
|
; "tuple parameter mismatch"
|
||||||
|
)]
|
||||||
|
#[test_case(2,
|
||||||
|
&[
|
||||||
|
("v1", "Tuple[int,int]"),
|
||||||
|
("v2", "Tuple[int]"),
|
||||||
|
],
|
||||||
|
(("v1", "v2"), "Cannot unify tuples with length 2 and 1")
|
||||||
|
; "tuple length mismatch"
|
||||||
|
)]
|
||||||
|
#[test_case(3,
|
||||||
|
&[
|
||||||
|
("v1", "Record[a=float,b=int]"),
|
||||||
|
("v2", "Foo[v3]"),
|
||||||
|
],
|
||||||
|
(("v1", "v2"), "No such attribute b")
|
||||||
|
; "record obj merge"
|
||||||
|
)]
|
||||||
|
#[test_case(2,
|
||||||
|
&[
|
||||||
|
("v1", "List[v2]"),
|
||||||
|
],
|
||||||
|
(("v1", "v2"), "Recursive type is prohibited.")
|
||||||
|
; "recursive type for lists"
|
||||||
|
)]
|
||||||
|
/// Test cases for invalid unifications.
|
||||||
|
fn test_invalid_unification(
|
||||||
|
variable_count: u32,
|
||||||
|
unify_pairs: &[(&'static str, &'static str)],
|
||||||
|
errornous_pair: ((&'static str, &'static str), &'static str),
|
||||||
|
) {
|
||||||
|
let mut env = TestEnvironment::new();
|
||||||
|
let mut mapping = HashMap::new();
|
||||||
|
for i in 1..=variable_count {
|
||||||
|
let v = env.unifier.get_fresh_var();
|
||||||
|
mapping.insert(format!("v{}", i), v.0);
|
||||||
|
}
|
||||||
|
// unification may have side effect when we do type resolution, so freeze the types
|
||||||
|
// before doing unification.
|
||||||
|
let mut pairs = Vec::new();
|
||||||
|
for (a, b) in unify_pairs.iter() {
|
||||||
|
let t1 = env.parse(a, &mapping);
|
||||||
|
let t2 = env.parse(b, &mapping);
|
||||||
|
pairs.push((t1, t2));
|
||||||
|
}
|
||||||
|
let (t1, t2) =
|
||||||
|
(env.parse(errornous_pair.0 .0, &mapping), env.parse(errornous_pair.0 .1, &mapping));
|
||||||
|
for (a, b) in pairs {
|
||||||
|
env.unifier.unify(a, b).unwrap();
|
||||||
|
}
|
||||||
|
assert_eq!(env.unifier.unify(t1, t2), Err(errornous_pair.1.to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_virtual() {
|
||||||
|
let mut env = TestEnvironment::new();
|
||||||
|
let int = env.parse("int", &HashMap::new());
|
||||||
|
let fun = env.unifier.add_ty(TypeEnum::TFunc(
|
||||||
|
FunSignature { args: vec![], ret: int, vars: HashMap::new() }.into(),
|
||||||
|
));
|
||||||
|
let bar = env.unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: DefinitionId(5),
|
||||||
|
fields: [("f".to_string(), fun), ("a".to_string(), int)]
|
||||||
|
.iter()
|
||||||
|
.cloned()
|
||||||
|
.collect::<HashMap<_, _>>()
|
||||||
|
.into(),
|
||||||
|
params: HashMap::new().into(),
|
||||||
|
});
|
||||||
|
let v0 = env.unifier.get_fresh_var().0;
|
||||||
|
let v1 = env.unifier.get_fresh_var().0;
|
||||||
|
|
||||||
|
let a = env.unifier.add_ty(TypeEnum::TVirtual { ty: bar });
|
||||||
|
let b = env.unifier.add_ty(TypeEnum::TVirtual { ty: v0 });
|
||||||
|
let c = env.unifier.add_record([("f".to_string(), v1)].iter().cloned().collect());
|
||||||
|
env.unifier.unify(a, b).unwrap();
|
||||||
|
env.unifier.unify(b, c).unwrap();
|
||||||
|
assert!(env.unifier.eq(v1, fun));
|
||||||
|
|
||||||
|
let d = env.unifier.add_record([("a".to_string(), v1)].iter().cloned().collect());
|
||||||
|
assert_eq!(env.unifier.unify(b, d), Err("Cannot access field a for virtual type".to_string()));
|
||||||
|
|
||||||
|
let d = env.unifier.add_record([("b".to_string(), v1)].iter().cloned().collect());
|
||||||
|
assert_eq!(env.unifier.unify(b, d), Err("No such attribute b".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_typevar_range() {
|
||||||
|
let mut env = TestEnvironment::new();
|
||||||
|
let int = env.parse("int", &HashMap::new());
|
||||||
|
let boolean = env.parse("bool", &HashMap::new());
|
||||||
|
let float = env.parse("float", &HashMap::new());
|
||||||
|
let int_list = env.parse("List[int]", &HashMap::new());
|
||||||
|
let float_list = env.parse("List[float]", &HashMap::new());
|
||||||
|
|
||||||
|
// unification between v and int
|
||||||
|
// where v in (int, bool)
|
||||||
|
let v = env.unifier.get_fresh_var_with_range(&[int, boolean]).0;
|
||||||
|
env.unifier.unify(int, v).unwrap();
|
||||||
|
|
||||||
|
// unification between v and List[int]
|
||||||
|
// where v in (int, bool)
|
||||||
|
let v = env.unifier.get_fresh_var_with_range(&[int, boolean]).0;
|
||||||
|
assert_eq!(
|
||||||
|
env.unifier.unify(int_list, v),
|
||||||
|
Err("Cannot unify type variable 3 with TList due to incompatible value range".to_string())
|
||||||
|
);
|
||||||
|
|
||||||
|
// unification between v and float
|
||||||
|
// where v in (int, bool)
|
||||||
|
let v = env.unifier.get_fresh_var_with_range(&[int, boolean]).0;
|
||||||
|
assert_eq!(
|
||||||
|
env.unifier.unify(float, v),
|
||||||
|
Err("Cannot unify type variable 4 with TObj due to incompatible value range".to_string())
|
||||||
|
);
|
||||||
|
|
||||||
|
let v1 = env.unifier.get_fresh_var_with_range(&[int, boolean]).0;
|
||||||
|
let v1_list = env.unifier.add_ty(TypeEnum::TList { ty: v1 });
|
||||||
|
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list]).0;
|
||||||
|
// unification between v and int
|
||||||
|
// where v in (int, List[v1]), v1 in (int, bool)
|
||||||
|
env.unifier.unify(int, v).unwrap();
|
||||||
|
|
||||||
|
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list]).0;
|
||||||
|
// unification between v and List[int]
|
||||||
|
// where v in (int, List[v1]), v1 in (int, bool)
|
||||||
|
env.unifier.unify(int_list, v).unwrap();
|
||||||
|
|
||||||
|
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list]).0;
|
||||||
|
// unification between v and List[float]
|
||||||
|
// where v in (int, List[v1]), v1 in (int, bool)
|
||||||
|
assert_eq!(
|
||||||
|
env.unifier.unify(float_list, v),
|
||||||
|
Err("Cannot unify type variable 8 with TList due to incompatible value range".to_string())
|
||||||
|
);
|
||||||
|
|
||||||
|
let a = env.unifier.get_fresh_var_with_range(&[int, float]).0;
|
||||||
|
let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0;
|
||||||
|
env.unifier.unify(a, b).unwrap();
|
||||||
|
env.unifier.unify(a, float).unwrap();
|
||||||
|
|
||||||
|
let a = env.unifier.get_fresh_var_with_range(&[int, float]).0;
|
||||||
|
let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0;
|
||||||
|
env.unifier.unify(a, b).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
env.unifier.unify(a, int),
|
||||||
|
Err("Cannot unify type variable 12 with TObj due to incompatible value range".into())
|
||||||
|
);
|
||||||
|
|
||||||
|
let a = env.unifier.get_fresh_var_with_range(&[int, float]).0;
|
||||||
|
let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0;
|
||||||
|
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a });
|
||||||
|
let a_list = env.unifier.get_fresh_var_with_range(&[a_list]).0;
|
||||||
|
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b });
|
||||||
|
let b_list = env.unifier.get_fresh_var_with_range(&[b_list]).0;
|
||||||
|
env.unifier.unify(a_list, b_list).unwrap();
|
||||||
|
let float_list = env.unifier.add_ty(TypeEnum::TList { ty: float });
|
||||||
|
env.unifier.unify(a_list, float_list).unwrap();
|
||||||
|
// previous unifications should not affect a and b
|
||||||
|
env.unifier.unify(a, int).unwrap();
|
||||||
|
|
||||||
|
let a = env.unifier.get_fresh_var_with_range(&[int, float]).0;
|
||||||
|
let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0;
|
||||||
|
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a });
|
||||||
|
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b });
|
||||||
|
env.unifier.unify(a_list, b_list).unwrap();
|
||||||
|
let int_list = env.unifier.add_ty(TypeEnum::TList { ty: int });
|
||||||
|
assert_eq!(
|
||||||
|
env.unifier.unify(a_list, int_list),
|
||||||
|
Err("Cannot unify type variable 19 with TObj due to incompatible value range".into())
|
||||||
|
);
|
||||||
|
|
||||||
|
let a = env.unifier.get_fresh_var_with_range(&[int, float]).0;
|
||||||
|
let b = env.unifier.get_fresh_var().0;
|
||||||
|
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a });
|
||||||
|
let a_list = env.unifier.get_fresh_var_with_range(&[a_list]).0;
|
||||||
|
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b });
|
||||||
|
env.unifier.unify(a_list, b_list).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
env.unifier.unify(b, boolean),
|
||||||
|
Err("Cannot unify type variable 21 with TObj due to incompatible value range".into())
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_rigid_var() {
|
||||||
|
let mut env = TestEnvironment::new();
|
||||||
|
let a = env.unifier.get_fresh_rigid_var().0;
|
||||||
|
let b = env.unifier.get_fresh_rigid_var().0;
|
||||||
|
let x = env.unifier.get_fresh_var().0;
|
||||||
|
let list_a = env.unifier.add_ty(TypeEnum::TList { ty: a });
|
||||||
|
let list_x = env.unifier.add_ty(TypeEnum::TList { ty: x });
|
||||||
|
let int = env.parse("int", &HashMap::new());
|
||||||
|
let list_int = env.parse("List[int]", &HashMap::new());
|
||||||
|
|
||||||
|
assert_eq!(env.unifier.unify(a, b), Err("Cannot unify TRigidVar with TRigidVar".to_string()));
|
||||||
|
env.unifier.unify(list_a, list_x).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
env.unifier.unify(list_x, list_int),
|
||||||
|
Err("Cannot unify TObj with TRigidVar".to_string())
|
||||||
|
);
|
||||||
|
|
||||||
|
env.unifier.replace_rigid_var(a, int);
|
||||||
|
env.unifier.unify(list_x, list_int).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_instantiation() {
|
||||||
|
let mut env = TestEnvironment::new();
|
||||||
|
let int = env.parse("int", &HashMap::new());
|
||||||
|
let boolean = env.parse("bool", &HashMap::new());
|
||||||
|
let float = env.parse("float", &HashMap::new());
|
||||||
|
let list_int = env.parse("List[int]", &HashMap::new());
|
||||||
|
|
||||||
|
let obj_map: HashMap<_, _> =
|
||||||
|
[(0usize, "int"), (1, "float"), (2, "bool")].iter().cloned().collect();
|
||||||
|
|
||||||
|
let v = env.unifier.get_fresh_var_with_range(&[int, boolean]).0;
|
||||||
|
let list_v = env.unifier.add_ty(TypeEnum::TList { ty: v });
|
||||||
|
let v1 = env.unifier.get_fresh_var_with_range(&[list_v, int]).0;
|
||||||
|
let v2 = env.unifier.get_fresh_var_with_range(&[list_int, float]).0;
|
||||||
|
let t = env.unifier.get_fresh_rigid_var().0;
|
||||||
|
let tuple = env.unifier.add_ty(TypeEnum::TTuple { ty: vec![v, v1, v2] });
|
||||||
|
let v3 = env.unifier.get_fresh_var_with_range(&[tuple, t]).0;
|
||||||
|
// t = TypeVar('t')
|
||||||
|
// v = TypeVar('v', int, bool)
|
||||||
|
// v1 = TypeVar('v1', 'list[v]', int)
|
||||||
|
// v2 = TypeVar('v2', 'list[int]', float)
|
||||||
|
// v3 = TypeVar('v3', tuple[v, v1, v2], t)
|
||||||
|
// what values can v3 take?
|
||||||
|
|
||||||
|
let types = env.unifier.get_instantiations(v3).unwrap();
|
||||||
|
let expected_types = indoc! {"
|
||||||
|
tuple[bool, int, float]
|
||||||
|
tuple[bool, int, list[int]]
|
||||||
|
tuple[bool, list[bool], float]
|
||||||
|
tuple[bool, list[bool], list[int]]
|
||||||
|
tuple[bool, list[int], float]
|
||||||
|
tuple[bool, list[int], list[int]]
|
||||||
|
tuple[int, int, float]
|
||||||
|
tuple[int, int, list[int]]
|
||||||
|
tuple[int, list[bool], float]
|
||||||
|
tuple[int, list[bool], list[int]]
|
||||||
|
tuple[int, list[int], float]
|
||||||
|
tuple[int, list[int], list[int]]
|
||||||
|
v5"
|
||||||
|
}
|
||||||
|
.split('\n')
|
||||||
|
.collect_vec();
|
||||||
|
let types = types
|
||||||
|
.iter()
|
||||||
|
.map(|ty| {
|
||||||
|
env.unifier.stringify(*ty, &mut |i| obj_map.get(&i).unwrap().to_string(), &mut |i| {
|
||||||
|
format!("v{}", i)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.sorted()
|
||||||
|
.collect_vec();
|
||||||
|
assert_eq!(expected_types, types);
|
||||||
|
}
|
87
nac3core/src/typecheck/unification_table.rs
Normal file
87
nac3core/src/typecheck/unification_table.rs
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
use std::rc::Rc;
|
||||||
|
|
||||||
|
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
|
||||||
|
pub struct UnificationKey(usize);
|
||||||
|
|
||||||
|
pub struct UnificationTable<V> {
|
||||||
|
parents: Vec<usize>,
|
||||||
|
ranks: Vec<u32>,
|
||||||
|
values: Vec<V>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<V> UnificationTable<V> {
|
||||||
|
pub fn new() -> UnificationTable<V> {
|
||||||
|
UnificationTable { parents: Vec::new(), ranks: Vec::new(), values: Vec::new() }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_key(&mut self, v: V) -> UnificationKey {
|
||||||
|
let index = self.parents.len();
|
||||||
|
self.parents.push(index);
|
||||||
|
self.ranks.push(0);
|
||||||
|
self.values.push(v);
|
||||||
|
UnificationKey(index)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn unify(&mut self, a: UnificationKey, b: UnificationKey) {
|
||||||
|
let mut a = self.find(a);
|
||||||
|
let mut b = self.find(b);
|
||||||
|
if a == b {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if self.ranks[a] < self.ranks[b] {
|
||||||
|
std::mem::swap(&mut a, &mut b);
|
||||||
|
}
|
||||||
|
self.parents[b] = a;
|
||||||
|
if self.ranks[a] == self.ranks[b] {
|
||||||
|
self.ranks[a] += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn probe_value(&mut self, a: UnificationKey) -> &V {
|
||||||
|
let index = self.find(a);
|
||||||
|
&self.values[index]
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_value(&mut self, a: UnificationKey, v: V) {
|
||||||
|
let index = self.find(a);
|
||||||
|
self.values[index] = v;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn unioned(&mut self, a: UnificationKey, b: UnificationKey) -> bool {
|
||||||
|
self.find(a) == self.find(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_representative(&mut self, key: UnificationKey) -> UnificationKey {
|
||||||
|
UnificationKey(self.find(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn find(&mut self, key: UnificationKey) -> usize {
|
||||||
|
let mut root = key.0;
|
||||||
|
let mut parent = self.parents[root];
|
||||||
|
while root != parent {
|
||||||
|
// a = parent.parent
|
||||||
|
let a = self.parents[parent];
|
||||||
|
// root.parent = parent.parent
|
||||||
|
self.parents[root] = a;
|
||||||
|
root = parent;
|
||||||
|
// parent = root.parent
|
||||||
|
parent = a;
|
||||||
|
}
|
||||||
|
parent
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<V> UnificationTable<Rc<V>>
|
||||||
|
where
|
||||||
|
V: Clone,
|
||||||
|
{
|
||||||
|
pub fn get_send(&self) -> UnificationTable<V> {
|
||||||
|
let values = self.values.iter().map(|v| v.as_ref().clone()).collect();
|
||||||
|
UnificationTable { parents: self.parents.clone(), ranks: self.ranks.clone(), values }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_send(table: &UnificationTable<V>) -> UnificationTable<Rc<V>> {
|
||||||
|
let values = table.values.iter().cloned().map(Rc::new).collect();
|
||||||
|
UnificationTable { parents: table.parents.clone(), ranks: table.ranks.clone(), values }
|
||||||
|
}
|
||||||
|
}
|
@ -1,60 +0,0 @@
|
|||||||
use std::collections::HashMap;
|
|
||||||
use std::rc::Rc;
|
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
|
|
||||||
pub struct PrimitiveId(pub(crate) usize);
|
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
|
|
||||||
pub struct ClassId(pub(crate) usize);
|
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
|
|
||||||
pub struct ParamId(pub(crate) usize);
|
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
|
|
||||||
pub struct VariableId(pub(crate) usize);
|
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Clone, Hash, Debug)]
|
|
||||||
pub enum TypeEnum {
|
|
||||||
BotType,
|
|
||||||
SelfType,
|
|
||||||
PrimitiveType(PrimitiveId),
|
|
||||||
ClassType(ClassId),
|
|
||||||
VirtualClassType(ClassId),
|
|
||||||
ParametricType(ParamId, Vec<Rc<TypeEnum>>),
|
|
||||||
TypeVariable(VariableId),
|
|
||||||
}
|
|
||||||
|
|
||||||
pub type Type = Rc<TypeEnum>;
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct FnDef {
|
|
||||||
// we assume methods first argument to be SelfType,
|
|
||||||
// so the first argument is not contained here
|
|
||||||
pub args: Vec<Type>,
|
|
||||||
pub result: Option<Type>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct TypeDef<'a> {
|
|
||||||
pub name: &'a str,
|
|
||||||
pub fields: HashMap<&'a str, Type>,
|
|
||||||
pub methods: HashMap<&'a str, FnDef>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct ClassDef<'a> {
|
|
||||||
pub base: TypeDef<'a>,
|
|
||||||
pub parents: Vec<ClassId>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct ParametricDef<'a> {
|
|
||||||
pub base: TypeDef<'a>,
|
|
||||||
pub params: Vec<VariableId>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct VarDef<'a> {
|
|
||||||
pub name: &'a str,
|
|
||||||
pub bound: Vec<Type>,
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user