hm-inference #6

Merged
sb10q merged 136 commits from hm-inference into master 2021-08-19 11:46:50 +08:00
4 changed files with 208 additions and 83 deletions
Showing only changes of commit a7e3eeea0d - Show all commits

View File

@ -1,5 +1,5 @@
use std::borrow::Borrow;
use std::collections::HashMap; use std::collections::HashMap;
use std::rc::Rc;
use rustpython_parser::ast::{Cmpop, Operator, Unaryop}; use rustpython_parser::ast::{Cmpop, Operator, Unaryop};
pub fn binop_name(op: &Operator) -> &'static str { pub fn binop_name(op: &Operator) -> &'static str {
@ -64,14 +64,14 @@ use rustpython_parser::ast;
/// Add, Sub, Mult, Pow /// Add, Sub, Mult, Pow
pub fn impl_basic_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type, ret_ty: Type) { pub fn impl_basic_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type, ret_ty: Type) {
if let Some(TypeEnum::TObj {fields, .. }) = Rc::get_mut(&mut unifier.get_ty(ty)) { if let TypeEnum::TObj {fields, .. } = unifier.get_ty(ty).borrow() {
for op in &[ for op in &[
ast::Operator::Add, ast::Operator::Add,
ast::Operator::Sub, ast::Operator::Sub,
ast::Operator::Mult, ast::Operator::Mult,
ast::Operator::Pow, ast::Operator::Pow,
] { ] {
fields.insert( fields.borrow_mut().insert(
binop_name(op).into(), binop_name(op).into(),
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty, ret: ret_ty,
@ -84,7 +84,7 @@ pub fn impl_basic_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty:
})) }))
); );
fields.insert( fields.borrow_mut().insert(
binop_assign_name(op).into(), binop_assign_name(op).into(),
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: store.none, ret: store.none,
@ -102,7 +102,7 @@ pub fn impl_basic_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty:
/// LShift, RShift, BitOr, BitXor, BitAnd /// LShift, RShift, BitOr, BitXor, BitAnd
pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
if let Some(TypeEnum::TObj {fields, ..}) = Rc::get_mut(&mut unifier.get_ty(ty)) { if let TypeEnum::TObj {fields, .. } = unifier.get_ty(ty).borrow() {
for op in &[ for op in &[
ast::Operator::LShift, ast::Operator::LShift,
ast::Operator::RShift, ast::Operator::RShift,
@ -110,7 +110,7 @@ pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty
ast::Operator::BitXor, ast::Operator::BitXor,
ast::Operator::BitAnd, ast::Operator::BitAnd,
] { ] {
fields.insert( fields.borrow_mut().insert(
binop_name(op).into(), binop_name(op).into(),
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ty, ret: ty,
@ -123,7 +123,7 @@ pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty
})) }))
); );
fields.insert( fields.borrow_mut().insert(
binop_assign_name(op).into(), binop_assign_name(op).into(),
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: store.none, ret: store.none,
@ -141,8 +141,8 @@ pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty
/// Div /// Div
pub fn impl_div(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type) { pub fn impl_div(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type) {
if let Some(TypeEnum::TObj {fields, ..}) = Rc::get_mut(&mut unifier.get_ty(ty)) { if let TypeEnum::TObj {fields, .. } = unifier.get_ty(ty).borrow() {
fields.insert( fields.borrow_mut().insert(
binop_name(&ast::Operator::Div).into(), binop_name(&ast::Operator::Div).into(),
unifier.add_ty(TypeEnum::TFunc(FunSignature{ unifier.add_ty(TypeEnum::TFunc(FunSignature{
ret: store.float, ret: store.float,
@ -155,7 +155,7 @@ pub fn impl_div(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_t
})) }))
); );
fields.insert( fields.borrow_mut().insert(
binop_assign_name(&ast::Operator::Div).into(), binop_assign_name(&ast::Operator::Div).into(),
unifier.add_ty(TypeEnum::TFunc(FunSignature{ unifier.add_ty(TypeEnum::TFunc(FunSignature{
ret: store.none, ret: store.none,
@ -172,8 +172,8 @@ pub fn impl_div(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_t
/// FloorDiv /// FloorDiv
pub fn impl_floordiv(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type, ret_ty: Type) { pub fn impl_floordiv(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type, ret_ty: Type) {
if let Some(TypeEnum::TObj {fields, ..}) = Rc::get_mut(&mut unifier.get_ty(ty)) { if let TypeEnum::TObj {fields, .. } = unifier.get_ty(ty).borrow() {
fields.insert( fields.borrow_mut().insert(
binop_name(&ast::Operator::FloorDiv).into(), binop_name(&ast::Operator::FloorDiv).into(),
unifier.add_ty(TypeEnum::TFunc(FunSignature{ unifier.add_ty(TypeEnum::TFunc(FunSignature{
ret: ret_ty, ret: ret_ty,
@ -186,7 +186,7 @@ pub fn impl_floordiv(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, ot
})) }))
); );
fields.insert( fields.borrow_mut().insert(
binop_assign_name(&ast::Operator::FloorDiv).into(), binop_assign_name(&ast::Operator::FloorDiv).into(),
unifier.add_ty(TypeEnum::TFunc(FunSignature{ unifier.add_ty(TypeEnum::TFunc(FunSignature{
ret: store.none, ret: store.none,
@ -203,8 +203,8 @@ pub fn impl_floordiv(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, ot
/// Mod /// Mod
pub fn impl_mod(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type, ret_ty: Type) { pub fn impl_mod(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type, ret_ty: Type) {
if let Some(TypeEnum::TObj {fields, .. }) = Rc::get_mut(&mut unifier.get_ty(ty)) { if let TypeEnum::TObj {fields, .. } = unifier.get_ty(ty).borrow() {
fields.insert( fields.borrow_mut().insert(
binop_name(&ast::Operator::Mod).into(), binop_name(&ast::Operator::Mod).into(),
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty, ret: ret_ty,
@ -217,7 +217,7 @@ pub fn impl_mod(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_t
})) }))
); );
fields.insert( fields.borrow_mut().insert(
binop_assign_name(&ast::Operator::Mod).into(), binop_assign_name(&ast::Operator::Mod).into(),
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: store.none, ret: store.none,
@ -234,12 +234,12 @@ pub fn impl_mod(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_t
/// UAdd, USub /// UAdd, USub
pub fn impl_unary_op(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) { pub fn impl_unary_op(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) {
if let Some(TypeEnum::TObj {fields, ..}) = Rc::get_mut(&mut unifier.get_ty(ty)) { if let TypeEnum::TObj {fields, .. } = unifier.get_ty(ty).borrow() {
for op in &[ for op in &[
ast::Unaryop::UAdd, ast::Unaryop::UAdd,
ast::Unaryop::USub ast::Unaryop::USub
] { ] {
fields.insert( fields.borrow_mut().insert(
unaryop_name(op).into(), unaryop_name(op).into(),
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ty, ret: ty,
@ -253,8 +253,8 @@ pub fn impl_unary_op(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) {
/// Invert /// Invert
pub fn impl_invert(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) { pub fn impl_invert(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) {
if let Some(TypeEnum::TObj {fields, .. }) = Rc::get_mut(&mut unifier.get_ty(ty)) { if let TypeEnum::TObj {fields, .. } = unifier.get_ty(ty).borrow() {
fields.insert( fields.borrow_mut().insert(
unaryop_name(&ast::Unaryop::Invert).into(), unaryop_name(&ast::Unaryop::Invert).into(),
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ty, ret: ty,
@ -267,8 +267,8 @@ pub fn impl_invert(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) {
/// Not /// Not
pub fn impl_not(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { pub fn impl_not(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
if let Some(TypeEnum::TObj {fields, ..}) = Rc::get_mut(&mut unifier.get_ty(ty)) { if let TypeEnum::TObj {fields, .. } = unifier.get_ty(ty).borrow() {
fields.insert( fields.borrow_mut().insert(
unaryop_name(&ast::Unaryop::Not).into(), unaryop_name(&ast::Unaryop::Not).into(),
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: store.bool, ret: store.bool,
@ -281,14 +281,14 @@ pub fn impl_not(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
/// Lt, LtE, Gt, GtE /// Lt, LtE, Gt, GtE
pub fn impl_comparison(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type) { pub fn impl_comparison(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type) {
if let Some(TypeEnum::TObj {fields, ..}) = Rc::get_mut(&mut unifier.get_ty(ty)) { if let TypeEnum::TObj {fields, .. } = unifier.get_ty(ty).borrow() {
for op in &[ for op in &[
ast::Cmpop::Lt, ast::Cmpop::Lt,
ast::Cmpop::LtE, ast::Cmpop::LtE,
ast::Cmpop::Gt, ast::Cmpop::Gt,
ast::Cmpop::GtE, ast::Cmpop::GtE,
] { ] {
fields.insert( fields.borrow_mut().insert(
comparison_name(op).unwrap().into(), comparison_name(op).unwrap().into(),
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: store.bool, ret: store.bool,
@ -306,12 +306,12 @@ pub fn impl_comparison(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type,
/// Eq, NotEq /// Eq, NotEq
pub fn impl_eq(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { pub fn impl_eq(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
if let Some(TypeEnum::TObj {fields, ..}) = Rc::get_mut(&mut unifier.get_ty(ty)) { if let TypeEnum::TObj {fields, .. } = unifier.get_ty(ty).borrow() {
for op in &[ for op in &[
ast::Cmpop::Eq, ast::Cmpop::Eq,
ast::Cmpop::NotEq, ast::Cmpop::NotEq,
] { ] {
fields.insert( fields.borrow_mut().insert(
comparison_name(op).unwrap().into(), comparison_name(op).unwrap().into(),
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: store.bool, ret: store.bool,
@ -335,68 +335,71 @@ pub fn set_primirives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
bool: bool_t, bool: bool_t,
none: _none_t none: _none_t
} = *store; } = *store;
// int32 -------- /* int32 ======== */
impl_basic_arithmetic(unifier, store, int32_t, int32_t, int32_t); impl_basic_arithmetic(unifier, store, int32_t, int32_t, int32_t);
impl_basic_arithmetic(unifier, store, int32_t, int64_t, int64_t); // impl_basic_arithmetic(unifier, store, int32_t, int64_t, int64_t);
impl_basic_arithmetic(unifier, store, int32_t, float_t, float_t); // impl_basic_arithmetic(unifier, store, int32_t, float_t, float_t);
impl_bitwise_arithmetic(unifier, store, int32_t); impl_bitwise_arithmetic(unifier, store, int32_t);
// impl_div(unifier, store, int32_t, int32_t);
// impl_div(unifier, store, int32_t, int64_t);
impl_div(unifier, store, int32_t, int32_t); impl_div(unifier, store, int32_t, int32_t);
impl_div(unifier, store, int32_t, int64_t);
impl_div(unifier, store, int32_t, float_t);
impl_floordiv(unifier, store, int32_t, int32_t, int32_t); impl_floordiv(unifier, store, int32_t, int32_t, int32_t);
impl_floordiv(unifier, store, int32_t, int64_t, int32_t); // impl_floordiv(unifier, store, int32_t, int64_t, int32_t);
impl_floordiv(unifier, store, int32_t, float_t, float_t); // impl_floordiv(unifier, store, int32_t, float_t, float_t);
impl_mod(unifier, store, int32_t, int32_t, int32_t); impl_mod(unifier, store, int32_t, int32_t, int32_t);
impl_mod(unifier, store, int32_t, int64_t, int32_t); // impl_mod(unifier, store, int32_t, int64_t, int32_t);
impl_mod(unifier, store, int32_t, float_t, float_t); // impl_mod(unifier, store, int32_t, float_t, float_t);
impl_unary_op(unifier, store, int32_t); impl_unary_op(unifier, store, int32_t);
impl_invert(unifier, store, int32_t); impl_invert(unifier, store, int32_t);
impl_not(unifier, store, int32_t); impl_not(unifier, store, int32_t);
impl_comparison(unifier, store, int32_t, int32_t); impl_comparison(unifier, store, int32_t, int32_t);
impl_comparison(unifier, store, int32_t, int64_t); // impl_comparison(unifier, store, int32_t, int64_t);
impl_comparison(unifier, store, int32_t, float_t); // impl_comparison(unifier, store, int32_t, float_t);
impl_eq(unifier, store, int32_t); impl_eq(unifier, store, int32_t);
// int64 --------
impl_basic_arithmetic(unifier, store, int64_t, int32_t, int64_t); /* int64 ======== */
// impl_basic_arithmetic(unifier, store, int64_t, int32_t, int64_t);
impl_basic_arithmetic(unifier, store, int64_t, int64_t, int64_t); impl_basic_arithmetic(unifier, store, int64_t, int64_t, int64_t);
impl_basic_arithmetic(unifier, store, int64_t, float_t, float_t); // impl_basic_arithmetic(unifier, store, int64_t, float_t, float_t);
impl_bitwise_arithmetic(unifier, store, int64_t); impl_bitwise_arithmetic(unifier, store, int64_t);
impl_div(unifier, store, int64_t, int32_t); // impl_div(unifier, store, int64_t, int32_t);
impl_div(unifier, store, int64_t, int64_t); impl_div(unifier, store, int64_t, int64_t);
impl_div(unifier, store, int64_t, float_t); // impl_div(unifier, store, int64_t, float_t);
impl_floordiv(unifier, store, int64_t, int32_t, int64_t); // impl_floordiv(unifier, store, int64_t, int32_t, int64_t);
impl_floordiv(unifier, store, int64_t, int64_t, int64_t); impl_floordiv(unifier, store, int64_t, int64_t, int64_t);
impl_floordiv(unifier, store, int64_t, float_t, float_t); // impl_floordiv(unifier, store, int64_t, float_t, float_t);
impl_mod(unifier, store, int64_t, int32_t, int64_t); // impl_mod(unifier, store, int64_t, int32_t, int64_t);
impl_mod(unifier, store, int64_t, int64_t, int64_t); impl_mod(unifier, store, int64_t, int64_t, int64_t);
impl_mod(unifier, store, int64_t, float_t, float_t); // impl_mod(unifier, store, int64_t, float_t, float_t);
impl_unary_op(unifier, store, int64_t); impl_unary_op(unifier, store, int64_t);
impl_invert(unifier, store, int64_t); impl_invert(unifier, store, int64_t);
impl_not(unifier, store, int64_t); impl_not(unifier, store, int64_t);
impl_comparison(unifier, store, int64_t, int32_t); // impl_comparison(unifier, store, int64_t, int32_t);
impl_comparison(unifier, store, int64_t, int64_t); impl_comparison(unifier, store, int64_t, int64_t);
impl_comparison(unifier, store, int64_t, float_t); // impl_comparison(unifier, store, int64_t, float_t);
impl_eq(unifier, store, int64_t); impl_eq(unifier, store, int64_t);
// float --------
impl_basic_arithmetic(unifier, store, float_t, int32_t, float_t); /* float ======== */
impl_basic_arithmetic(unifier, store, float_t, int64_t, float_t); // impl_basic_arithmetic(unifier, store, float_t, int32_t, float_t);
// impl_basic_arithmetic(unifier, store, float_t, int64_t, float_t);
impl_basic_arithmetic(unifier, store, float_t, float_t, float_t); impl_basic_arithmetic(unifier, store, float_t, float_t, float_t);
impl_div(unifier, store, float_t, int32_t); // impl_div(unifier, store, float_t, int32_t);
impl_div(unifier, store, float_t, int64_t); // impl_div(unifier, store, float_t, int64_t);
impl_div(unifier, store, float_t, float_t); impl_div(unifier, store, float_t, float_t);
impl_floordiv(unifier, store, float_t, int32_t, float_t); // impl_floordiv(unifier, store, float_t, int32_t, float_t);
impl_floordiv(unifier, store, float_t, int64_t, float_t); // impl_floordiv(unifier, store, float_t, int64_t, float_t);
impl_floordiv(unifier, store, float_t, float_t, float_t); impl_floordiv(unifier, store, float_t, float_t, float_t);
impl_mod(unifier, store, float_t, int32_t, float_t); // impl_mod(unifier, store, float_t, int32_t, float_t);
impl_mod(unifier, store, float_t, int64_t, float_t); // impl_mod(unifier, store, float_t, int64_t, float_t);
impl_mod(unifier, store, float_t, float_t, float_t); impl_mod(unifier, store, float_t, float_t, float_t);
impl_unary_op(unifier, store, float_t); impl_unary_op(unifier, store, float_t);
impl_not(unifier, store, float_t); impl_not(unifier, store, float_t);
impl_comparison(unifier, store, float_t, int32_t); // impl_comparison(unifier, store, float_t, int32_t);
impl_comparison(unifier, store, float_t, int64_t); // impl_comparison(unifier, store, float_t, int64_t);
impl_comparison(unifier, store, float_t, float_t); impl_comparison(unifier, store, float_t, float_t);
impl_eq(unifier, store, float_t); impl_eq(unifier, store, float_t);
// bool ---------
/* bool ======== */
impl_not(unifier, store, bool_t); impl_not(unifier, store, bool_t);
impl_eq(unifier, store, bool_t); impl_eq(unifier, store, bool_t);
} }

View File

@ -51,31 +51,32 @@ impl TestEnvironment {
let int32 = unifier.add_ty(TypeEnum::TObj { let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: 0, obj_id: 0,
fields: HashMap::new(), fields: HashMap::new().into(),
params: HashMap::new(), params: HashMap::new(),
}); });
let int64 = unifier.add_ty(TypeEnum::TObj { let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: 1, obj_id: 1,
fields: HashMap::new(), fields: HashMap::new().into(),
params: HashMap::new(), params: HashMap::new(),
}); });
let float = unifier.add_ty(TypeEnum::TObj { let float = unifier.add_ty(TypeEnum::TObj {
obj_id: 2, obj_id: 2,
fields: HashMap::new(), fields: HashMap::new().into(),
params: HashMap::new(), params: HashMap::new(),
}); });
let bool = unifier.add_ty(TypeEnum::TObj { let bool = unifier.add_ty(TypeEnum::TObj {
obj_id: 3, obj_id: 3,
fields: HashMap::new(), fields: HashMap::new().into(),
params: HashMap::new(), params: HashMap::new(),
}); });
let none = unifier.add_ty(TypeEnum::TObj { let none = unifier.add_ty(TypeEnum::TObj {
obj_id: 4, obj_id: 4,
fields: HashMap::new(), fields: HashMap::new().into(),
params: HashMap::new(), params: HashMap::new(),
}); });
// identifier_mapping.insert("None".into(), none); // identifier_mapping.insert("None".into(), none);
let primitives = PrimitiveStore { int32, int64, float, bool, none }; let primitives = PrimitiveStore { int32, int64, float, bool, none };
set_primirives_magic_methods(&primitives, &mut unifier); set_primirives_magic_methods(&primitives, &mut unifier);
let id_to_name = [ let id_to_name = [
@ -119,27 +120,27 @@ impl TestEnvironment {
let mut identifier_mapping = HashMap::new(); let mut identifier_mapping = HashMap::new();
let int32 = unifier.add_ty(TypeEnum::TObj { let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: 0, obj_id: 0,
fields: HashMap::new(), fields: HashMap::new().into(),
params: HashMap::new(), params: HashMap::new(),
}); });
let int64 = unifier.add_ty(TypeEnum::TObj { let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: 1, obj_id: 1,
fields: HashMap::new(), fields: HashMap::new().into(),
params: HashMap::new(), params: HashMap::new(),
}); });
let float = unifier.add_ty(TypeEnum::TObj { let float = unifier.add_ty(TypeEnum::TObj {
obj_id: 2, obj_id: 2,
fields: HashMap::new(), fields: HashMap::new().into(),
params: HashMap::new(), params: HashMap::new(),
}); });
let bool = unifier.add_ty(TypeEnum::TObj { let bool = unifier.add_ty(TypeEnum::TObj {
obj_id: 3, obj_id: 3,
fields: HashMap::new(), fields: HashMap::new().into(),
params: HashMap::new(), params: HashMap::new(),
}); });
let none = unifier.add_ty(TypeEnum::TObj { let none = unifier.add_ty(TypeEnum::TObj {
obj_id: 4, obj_id: 4,
fields: HashMap::new(), fields: HashMap::new().into(),
params: HashMap::new(), params: HashMap::new(),
}); });
identifier_mapping.insert("None".into(), none); identifier_mapping.insert("None".into(), none);
@ -150,7 +151,7 @@ impl TestEnvironment {
let foo_ty = unifier.add_ty(TypeEnum::TObj { let foo_ty = unifier.add_ty(TypeEnum::TObj {
obj_id: 5, obj_id: 5,
fields: [("a".into(), v0)].iter().cloned().collect(), fields: [("a".into(), v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
params: [(id, v0)].iter().cloned().collect(), params: [(id, v0)].iter().cloned().collect(),
}); });
@ -170,7 +171,7 @@ impl TestEnvironment {
})); }));
let bar = unifier.add_ty(TypeEnum::TObj { let bar = unifier.add_ty(TypeEnum::TObj {
obj_id: 6, obj_id: 6,
fields: [("a".into(), int32), ("b".into(), fun)].iter().cloned().collect(), fields: [("a".into(), int32), ("b".into(), fun)].iter().cloned().collect::<HashMap<_, _>>().into(),
params: Default::default(), params: Default::default(),
}); });
identifier_mapping.insert( identifier_mapping.insert(
@ -184,7 +185,7 @@ impl TestEnvironment {
let bar2 = unifier.add_ty(TypeEnum::TObj { let bar2 = unifier.add_ty(TypeEnum::TObj {
obj_id: 7, obj_id: 7,
fields: [("a".into(), bool), ("b".into(), fun)].iter().cloned().collect(), fields: [("a".into(), bool), ("b".into(), fun)].iter().cloned().collect::<HashMap<_, _>>().into(),
params: Default::default(), params: Default::default(),
}); });
identifier_mapping.insert( identifier_mapping.insert(
@ -350,3 +351,122 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
assert_eq!(&b, y); assert_eq!(&b, y);
} }
} }
#[test_case(indoc! {"
a = 2
b = 2
c = a + b
d = a - b
e = a * b
f = a / b
g = a // b
h = a % b
"},
[("a", "int32"),
("b", "int32"),
("c", "int32"),
("d", "int32"),
("e", "int32"),
("f", "float"),
("g", "int32"),
("h", "int32")].iter().cloned().collect()
; "int32")]
#[test_case(
indoc! {"
a = 2.4
b = 3.6
c = a + b
d = a - b
e = a * b
f = a / b
g = a // b
h = a % b
"},
[("a", "float"),
("b", "float"),
("c", "float"),
("d", "float"),
("e", "float"),
("f", "float"),
("g", "float"),
("h", "float")].iter().cloned().collect()
; "float"
)]
#[test_case(
indoc! {"
a = int64(12312312312)
b = int64(24242424424)
c = a + b
d = a - b
e = a * b
f = a / b
g = a // b
h = a % b
i = a == b
j = a > b
k = a < b
l = a != b
"},
[("a", "int64"),
("b", "int64"),
("c", "int64"),
("d", "int64"),
("e", "int64"),
("f", "float"),
("g", "int64"),
("h", "int64"),
("i", "bool"),
("j", "bool"),
("k", "bool"),
("l", "bool")].iter().cloned().collect()
; "int64"
)]
#[test_case(
indoc! {"
a = True
b = False
c = a == b
d = not a
e = a != b
"},
[("a", "bool"),
("b", "bool"),
("c", "bool"),
("d", "bool"),
("e", "bool")].iter().cloned().collect()
; "boolean"
)]
fn test_primitive_magic_methods(source: &str, mapping: HashMap<&str, &str>) {
println!("source:\n{}", source);
let mut env = TestEnvironment::basic_test_env();
let id_to_name = std::mem::take(&mut env.id_to_name);
let mut defined_identifiers: Vec<_> = env.identifier_mapping.keys().cloned().collect();
defined_identifiers.push("virtual".to_string());
let mut inferencer = env.get_inferencer();
let statements = parse_program(source).unwrap();
let statements = statements
.into_iter()
.map(|v| inferencer.fold_stmt(v))
.collect::<Result<Vec<_>, _>>()
.unwrap();
inferencer.check_block(&statements, &mut defined_identifiers).unwrap();
for (k, v) in inferencer.variable_mapping.iter() {
let name = inferencer.unifier.stringify(
*v,
&mut |v| id_to_name.get(&v).unwrap().clone(),
&mut |v| format!("v{}", v),
);
println!("{}: {}", k, name);
}
for (k, v) in mapping.iter() {
let ty = inferencer.variable_mapping.get(*k).unwrap();
let name = inferencer.unifier.stringify(
*ty,
&mut |v| id_to_name.get(&v).unwrap().clone(),
&mut |v| format!("v{}", v),
);
assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name));
}
}

View File

@ -64,7 +64,7 @@ pub enum TypeEnum {
}, },
TObj { TObj {
obj_id: usize, obj_id: usize,
fields: Mapping<String>, fields: RefCell<Mapping<String>>,
params: VarMap, params: VarMap,
}, },
TVirtual { TVirtual {
@ -373,7 +373,8 @@ impl Unifier {
(TVar { meta: Record(map), id, range, .. }, TObj { fields, .. }) => { (TVar { meta: Record(map), id, range, .. }, TObj { fields, .. }) => {
self.occur_check(a, b)?; self.occur_check(a, b)?;
for (k, v) in map.borrow().iter() { for (k, v) in map.borrow().iter() {
let ty = fields.get(k).ok_or_else(|| format!("No such attribute {}", k))?; let temp = fields.borrow();
let ty = temp.get(k).ok_or_else(|| format!("No such attribute {}", k))?;
self.unify(*ty, *v)?; self.unify(*ty, *v)?;
} }
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b); let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
@ -385,7 +386,8 @@ impl Unifier {
let ty = self.get_ty(*ty); let ty = self.get_ty(*ty);
if let TObj { fields, .. } = ty.as_ref() { if let TObj { fields, .. } = ty.as_ref() {
for (k, v) in map.borrow().iter() { for (k, v) in map.borrow().iter() {
let ty = fields.get(k).ok_or_else(|| format!("No such attribute {}", k))?; let temp = fields.borrow();
let ty = temp.get(k).ok_or_else(|| format!("No such attribute {}", k))?;
if !matches!(self.get_ty(*ty).as_ref(), TFunc { .. }) { if !matches!(self.get_ty(*ty).as_ref(), TFunc { .. }) {
return Err(format!("Cannot access field {} for virtual type", k)); return Err(format!("Cannot access field {} for virtual type", k));
} }
@ -659,8 +661,8 @@ impl Unifier {
if need_subst { if need_subst {
let obj_id = *obj_id; let obj_id = *obj_id;
let params = self.subst_map(&params, mapping).unwrap_or_else(|| params.clone()); let params = self.subst_map(&params, mapping).unwrap_or_else(|| params.clone());
let fields = self.subst_map(&fields, mapping).unwrap_or_else(|| fields.clone()); let fields = self.subst_map(&fields.borrow(), mapping).unwrap_or_else(|| fields.borrow().clone());
Some(self.add_ty(TypeEnum::TObj { obj_id, params, fields })) Some(self.add_ty(TypeEnum::TObj { obj_id, params, fields: fields.into() }))
} else { } else {
None None
} }

View File

@ -79,7 +79,7 @@ impl TestEnvironment {
"int".into(), "int".into(),
unifier.add_ty(TypeEnum::TObj { unifier.add_ty(TypeEnum::TObj {
obj_id: 0, obj_id: 0,
fields: HashMap::new(), fields: HashMap::new().into(),
params: HashMap::new(), params: HashMap::new(),
}), }),
); );
@ -87,7 +87,7 @@ impl TestEnvironment {
"float".into(), "float".into(),
unifier.add_ty(TypeEnum::TObj { unifier.add_ty(TypeEnum::TObj {
obj_id: 1, obj_id: 1,
fields: HashMap::new(), fields: HashMap::new().into(),
params: HashMap::new(), params: HashMap::new(),
}), }),
); );
@ -95,7 +95,7 @@ impl TestEnvironment {
"bool".into(), "bool".into(),
unifier.add_ty(TypeEnum::TObj { unifier.add_ty(TypeEnum::TObj {
obj_id: 2, obj_id: 2,
fields: HashMap::new(), fields: HashMap::new().into(),
params: HashMap::new(), params: HashMap::new(),
}), }),
); );
@ -104,7 +104,7 @@ impl TestEnvironment {
"Foo".into(), "Foo".into(),
unifier.add_ty(TypeEnum::TObj { unifier.add_ty(TypeEnum::TObj {
obj_id: 3, obj_id: 3,
fields: [("a".into(), v0)].iter().cloned().collect(), fields: [("a".into(), v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
params: [(id, v0)].iter().cloned().collect(), params: [(id, v0)].iter().cloned().collect(),
}), }),
); );
@ -335,7 +335,7 @@ fn test_virtual() {
})); }));
let bar = env.unifier.add_ty(TypeEnum::TObj { let bar = env.unifier.add_ty(TypeEnum::TObj {
obj_id: 5, obj_id: 5,
fields: [("f".to_string(), fun), ("a".to_string(), int)].iter().cloned().collect(), fields: [("f".to_string(), fun), ("a".to_string(), int)].iter().cloned().collect::<HashMap<_, _>>().into(),
params: HashMap::new(), params: HashMap::new(),
}); });
let v0 = env.unifier.get_fresh_var().0; let v0 = env.unifier.get_fresh_var().0;