hm-inference #6

Merged
sb10q merged 136 commits from hm-inference into master 2021-08-19 11:46:50 +08:00
2 changed files with 66 additions and 38 deletions
Showing only changes of commit c0227210df - Show all commits

View File

@ -61,33 +61,47 @@ pub fn comparison_name(op: &Cmpop) -> Option<&'static str> {
} }
} }
pub fn impl_binop(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type, ret_ty: Type, ops: &[ast::Operator]) { pub fn impl_binop(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type], ret_ty: Type, ops: &[ast::Operator]) {
if let TypeEnum::TObj {fields, ..} = unifier.get_ty(ty).borrow() { if let TypeEnum::TObj {fields, ..} = unifier.get_ty(ty).borrow() {
for op in ops { for op in ops {
fields.borrow_mut().insert( fields.borrow_mut().insert(
binop_name(op).into(), binop_name(op).into(),
{
let other = if other_ty.len() == 1 {
other_ty[0]
} else {
unifier.get_fresh_var_with_range(other_ty).0
};
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty, ret: ret_ty,
vars: HashMap::new(), vars: HashMap::new(),
args: vec![FuncArg { args: vec![FuncArg {
ty: other_ty, ty: other,
is_optional: false, is_optional: false,
name: "other".into() name: "other".into()
}] }]
})) }))
}
); );
fields.borrow_mut().insert( fields.borrow_mut().insert(
binop_assign_name(op).into(), binop_assign_name(op).into(),
{
let other = if other_ty.len() == 1 {
other_ty[0]
} else {
unifier.get_fresh_var_with_range(other_ty).0
};
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: store.none, ret: ret_ty,
vars: HashMap::new(), vars: HashMap::new(),
args: vec![FuncArg { args: vec![FuncArg {
ty: other_ty, ty: other,
is_optional: false, is_optional: false,
name: "other".into() name: "other".into()
}] }]
})) }))
}
); );
} }
} else { unreachable!("") } } else { unreachable!("") }
@ -128,18 +142,23 @@ pub fn impl_cmpop(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other
} }
/// Add, Sub, Mult, Pow /// Add, Sub, Mult, Pow
pub fn impl_basic_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type, ret_ty: Type) { pub fn impl_basic_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type], ret_ty: Type) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[ impl_binop(unifier, store, ty, other_ty, ret_ty, &[
ast::Operator::Add, ast::Operator::Add,
ast::Operator::Sub, ast::Operator::Sub,
ast::Operator::Mult, ast::Operator::Mult,
])
}
pub fn impl_pow(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type], ret_ty: Type) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[
ast::Operator::Pow, ast::Operator::Pow,
]) ])
} }
/// BitOr, BitXor, BitAnd /// BitOr, BitXor, BitAnd
pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_binop(unifier, store, ty, ty, ty, &[ impl_binop(unifier, store, ty, &[ty], ty, &[
ast::Operator::BitAnd, ast::Operator::BitAnd,
ast::Operator::BitOr, ast::Operator::BitOr,
ast::Operator::BitXor, ast::Operator::BitXor,
@ -148,28 +167,28 @@ pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty
/// LShift, RShift /// LShift, RShift
pub fn impl_bitwise_shift(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { pub fn impl_bitwise_shift(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_binop(unifier, store, ty, store.int32, ty, &[ impl_binop(unifier, store, ty, &[ty], ty, &[
ast::Operator::LShift, ast::Operator::LShift,
ast::Operator::RShift, ast::Operator::RShift,
]) ])
} }
/// Div /// Div
pub fn impl_div(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type) { pub fn impl_div(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type]) {
impl_binop(unifier, store, ty, other_ty, store.float, &[ impl_binop(unifier, store, ty, other_ty, store.float, &[
ast::Operator::Div, ast::Operator::Div,
]) ])
} }
/// FloorDiv /// FloorDiv
pub fn impl_floordiv(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type, ret_ty: Type) { pub fn impl_floordiv(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type], ret_ty: Type) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[ impl_binop(unifier, store, ty, other_ty, ret_ty, &[
ast::Operator::FloorDiv, ast::Operator::FloorDiv,
]) ])
} }
/// Mod /// Mod
pub fn impl_mod(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type, ret_ty: Type) { pub fn impl_mod(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type], ret_ty: Type) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[ impl_binop(unifier, store, ty, other_ty, ret_ty, &[
ast::Operator::Mod, ast::Operator::Mod,
]) ])
@ -224,12 +243,13 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
.. ..
} = *store; } = *store;
/* int32 ======== */ /* int32 ======== */
impl_basic_arithmetic(unifier, store, int32_t, int32_t, int32_t); impl_basic_arithmetic(unifier, store, int32_t, &[int32_t], int32_t);
impl_pow(unifier, store, int32_t, &[int32_t], int32_t);
impl_bitwise_arithmetic(unifier, store, int32_t); impl_bitwise_arithmetic(unifier, store, int32_t);
impl_bitwise_shift(unifier, store, int32_t); impl_bitwise_shift(unifier, store, int32_t);
impl_div(unifier, store, int32_t, int32_t); impl_div(unifier, store, int32_t, &[int32_t]);
impl_floordiv(unifier, store, int32_t, int32_t, int32_t); impl_floordiv(unifier, store, int32_t, &[int32_t], int32_t);
impl_mod(unifier, store, int32_t, int32_t, int32_t); impl_mod(unifier, store, int32_t, &[int32_t], int32_t);
impl_sign(unifier, store, int32_t); impl_sign(unifier, store, int32_t);
impl_invert(unifier, store, int32_t); impl_invert(unifier, store, int32_t);
impl_not(unifier, store, int32_t); impl_not(unifier, store, int32_t);
@ -237,12 +257,13 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
impl_eq(unifier, store, int32_t); impl_eq(unifier, store, int32_t);
/* int64 ======== */ /* int64 ======== */
impl_basic_arithmetic(unifier, store, int64_t, int64_t, int64_t); impl_basic_arithmetic(unifier, store, int64_t, &[int64_t], int64_t);
impl_pow(unifier, store, int64_t, &[int64_t], int64_t);
impl_bitwise_arithmetic(unifier, store, int64_t); impl_bitwise_arithmetic(unifier, store, int64_t);
impl_bitwise_shift(unifier, store, int64_t); impl_bitwise_shift(unifier, store, int64_t);
impl_div(unifier, store, int64_t, int64_t); impl_div(unifier, store, int64_t, &[int64_t]);
impl_floordiv(unifier, store, int64_t, int64_t, int64_t); impl_floordiv(unifier, store, int64_t, &[int64_t], int64_t);
impl_mod(unifier, store, int64_t, int64_t, int64_t); impl_mod(unifier, store, int64_t, &[int64_t], int64_t);
impl_sign(unifier, store, int64_t); impl_sign(unifier, store, int64_t);
impl_invert(unifier, store, int64_t); impl_invert(unifier, store, int64_t);
impl_not(unifier, store, int64_t); impl_not(unifier, store, int64_t);
@ -250,10 +271,11 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
impl_eq(unifier, store, int64_t); impl_eq(unifier, store, int64_t);
/* float ======== */ /* float ======== */
impl_basic_arithmetic(unifier, store, float_t, float_t, float_t); impl_basic_arithmetic(unifier, store, float_t, &[float_t], float_t);
impl_div(unifier, store, float_t, float_t); impl_pow(unifier, store, float_t, &[int32_t, float_t], float_t);
impl_floordiv(unifier, store, float_t, float_t, float_t); impl_div(unifier, store, float_t, &[float_t]);
impl_mod(unifier, store, float_t, float_t, float_t); impl_floordiv(unifier, store, float_t, &[float_t], float_t);
impl_mod(unifier, store, float_t, &[float_t], float_t);
impl_sign(unifier, store, float_t); impl_sign(unifier, store, float_t);
impl_not(unifier, store, float_t); impl_not(unifier, store, float_t);
impl_comparison(unifier, store, float_t, float_t); impl_comparison(unifier, store, float_t, float_t);

View File

@ -395,6 +395,9 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
f = a / b f = a / b
g = a // b g = a // b
h = a % b h = a % b
i = a ** b
ii = 3
j = a ** b
"}, "},
[("a", "float"), [("a", "float"),
("b", "float"), ("b", "float"),
@ -403,7 +406,10 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
("e", "float"), ("e", "float"),
("f", "float"), ("f", "float"),
("g", "float"), ("g", "float"),
("h", "float")].iter().cloned().collect() ("h", "float"),
("i", "float"),
("ii", "int32"),
("j", "float")].iter().cloned().collect()
; "float" ; "float"
)] )]
#[test_case( #[test_case(