forked from M-Labs/nac3
1
0
Fork 0

core/magic_methods: Allow unknown return types

These types can be later inferred by the type inferencer.
This commit is contained in:
David Mak 2024-03-25 16:44:06 +08:00
parent 8f1497df83
commit a77fd213e0
1 changed files with 83 additions and 44 deletions

View File

@ -90,7 +90,7 @@ pub fn impl_binop(
_store: &PrimitiveStore, _store: &PrimitiveStore,
ty: Type, ty: Type,
other_ty: &[Type], other_ty: &[Type],
ret_ty: Type, ret_ty: Option<Type>,
ops: &[Operator], ops: &[Operator],
) { ) {
with_fields(unifier, ty, |unifier, fields| { with_fields(unifier, ty, |unifier, fields| {
@ -107,6 +107,8 @@ pub fn impl_binop(
VarMap::new() VarMap::new()
}; };
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).0);
for op in ops { for op in ops {
fields.insert(binop_name(op).into(), { fields.insert(binop_name(op).into(), {
( (
@ -141,8 +143,10 @@ pub fn impl_binop(
}); });
} }
pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Type, ops: &[Unaryop]) { pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Option<Type>, ops: &[Unaryop]) {
with_fields(unifier, ty, |unifier, fields| { with_fields(unifier, ty, |unifier, fields| {
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).0);
for op in ops { for op in ops {
fields.insert( fields.insert(
unaryop_name(op).into(), unaryop_name(op).into(),
@ -161,19 +165,35 @@ pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Type, ops: &[Unaryo
pub fn impl_cmpop( pub fn impl_cmpop(
unifier: &mut Unifier, unifier: &mut Unifier,
store: &PrimitiveStore, _store: &PrimitiveStore,
ty: Type, ty: Type,
other_ty: Type, other_ty: &[Type],
ops: &[Cmpop], ops: &[Cmpop],
ret_ty: Option<Type>,
) { ) {
with_fields(unifier, ty, |unifier, fields| { with_fields(unifier, ty, |unifier, fields| {
let (other_ty, other_var_id) = if other_ty.len() == 1 {
(other_ty[0], None)
} else {
let (ty, var_id) = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None);
(ty, Some(var_id))
};
let function_vars = if let Some(var_id) = other_var_id {
vec![(var_id, other_ty)].into_iter().collect::<VarMap>()
} else {
VarMap::new()
};
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).0);
for op in ops { for op in ops {
fields.insert( fields.insert(
comparison_name(op).unwrap().into(), comparison_name(op).unwrap().into(),
( (
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: store.bool, ret: ret_ty,
vars: VarMap::new(), vars: function_vars.clone(),
args: vec![FuncArg { args: vec![FuncArg {
ty: other_ty, ty: other_ty,
default_value: None, default_value: None,
@ -193,7 +213,7 @@ pub fn impl_basic_arithmetic(
store: &PrimitiveStore, store: &PrimitiveStore,
ty: Type, ty: Type,
other_ty: &[Type], other_ty: &[Type],
ret_ty: Type, ret_ty: Option<Type>,
) { ) {
impl_binop( impl_binop(
unifier, unifier,
@ -211,7 +231,7 @@ pub fn impl_pow(
store: &PrimitiveStore, store: &PrimitiveStore,
ty: Type, ty: Type,
other_ty: &[Type], other_ty: &[Type],
ret_ty: Type, ret_ty: Option<Type>,
) { ) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Pow]); impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Pow]);
} }
@ -223,19 +243,25 @@ pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty
store, store,
ty, ty,
&[ty], &[ty],
ty, Some(ty),
&[Operator::BitAnd, Operator::BitOr, Operator::BitXor], &[Operator::BitAnd, Operator::BitOr, Operator::BitXor],
); );
} }
/// `LShift`, `RShift` /// `LShift`, `RShift`
pub fn impl_bitwise_shift(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { pub fn impl_bitwise_shift(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_binop(unifier, store, ty, &[store.int32, store.uint32], ty, &[Operator::LShift, Operator::RShift]); impl_binop(unifier, store, ty, &[store.int32, store.uint32], Some(ty), &[Operator::LShift, Operator::RShift]);
} }
/// `Div` /// `Div`
pub fn impl_div(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type]) { pub fn impl_div(
impl_binop(unifier, store, ty, other_ty, store.float, &[Operator::Div]); unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Div]);
} }
/// `FloorDiv` /// `FloorDiv`
@ -244,7 +270,7 @@ pub fn impl_floordiv(
store: &PrimitiveStore, store: &PrimitiveStore,
ty: Type, ty: Type,
other_ty: &[Type], other_ty: &[Type],
ret_ty: Type, ret_ty: Option<Type>,
) { ) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::FloorDiv]); impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::FloorDiv]);
} }
@ -255,40 +281,53 @@ pub fn impl_mod(
store: &PrimitiveStore, store: &PrimitiveStore,
ty: Type, ty: Type,
other_ty: &[Type], other_ty: &[Type],
ret_ty: Type, ret_ty: Option<Type>,
) { ) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Mod]); impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Mod]);
} }
/// `UAdd`, `USub` /// `UAdd`, `USub`
pub fn impl_sign(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) { pub fn impl_sign(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret_ty: Option<Type>) {
impl_unaryop(unifier, ty, ty, &[Unaryop::UAdd, Unaryop::USub]); impl_unaryop(unifier, ty, ret_ty, &[Unaryop::UAdd, Unaryop::USub]);
} }
/// `Invert` /// `Invert`
pub fn impl_invert(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) { pub fn impl_invert(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret_ty: Option<Type>) {
impl_unaryop(unifier, ty, ty, &[Unaryop::Invert]); impl_unaryop(unifier, ty, ret_ty, &[Unaryop::Invert]);
} }
/// `Not` /// `Not`
pub fn impl_not(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { pub fn impl_not(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret_ty: Option<Type>) {
impl_unaryop(unifier, ty, store.bool, &[Unaryop::Not]); impl_unaryop(unifier, ty, ret_ty, &[Unaryop::Not]);
} }
/// `Lt`, `LtE`, `Gt`, `GtE` /// `Lt`, `LtE`, `Gt`, `GtE`
pub fn impl_comparison(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type) { pub fn impl_comparison(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
) {
impl_cmpop( impl_cmpop(
unifier, unifier,
store, store,
ty, ty,
other_ty, other_ty,
&[Cmpop::Lt, Cmpop::Gt, Cmpop::LtE, Cmpop::GtE], &[Cmpop::Lt, Cmpop::Gt, Cmpop::LtE, Cmpop::GtE],
ret_ty,
); );
} }
/// `Eq`, `NotEq` /// `Eq`, `NotEq`
pub fn impl_eq(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { pub fn impl_eq(
impl_cmpop(unifier, store, ty, ty, &[Cmpop::Eq, Cmpop::NotEq]); unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
) {
impl_cmpop(unifier, store, ty, other_ty, &[Cmpop::Eq, Cmpop::NotEq], ret_ty);
} }
pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) { pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) {
@ -304,34 +343,34 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
/* int ======== */ /* int ======== */
for t in [int32_t, int64_t, uint32_t, uint64_t] { for t in [int32_t, int64_t, uint32_t, uint64_t] {
impl_basic_arithmetic(unifier, store, t, &[t], t); impl_basic_arithmetic(unifier, store, t, &[t], Some(t));
impl_pow(unifier, store, t, &[t], t); impl_pow(unifier, store, t, &[t], Some(t));
impl_bitwise_arithmetic(unifier, store, t); impl_bitwise_arithmetic(unifier, store, t);
impl_bitwise_shift(unifier, store, t); impl_bitwise_shift(unifier, store, t);
impl_div(unifier, store, t, &[t]); impl_div(unifier, store, t, &[t], Some(float_t));
impl_floordiv(unifier, store, t, &[t], t); impl_floordiv(unifier, store, t, &[t], Some(t));
impl_mod(unifier, store, t, &[t], t); impl_mod(unifier, store, t, &[t], Some(t));
impl_invert(unifier, store, t); impl_invert(unifier, store, t, Some(t));
impl_not(unifier, store, t); impl_not(unifier, store, t, Some(bool_t));
impl_comparison(unifier, store, t, t); impl_comparison(unifier, store, t, &[t], Some(bool_t));
impl_eq(unifier, store, t); impl_eq(unifier, store, t, &[t], Some(bool_t));
} }
for t in [int32_t, int64_t] { for t in [int32_t, int64_t] {
impl_sign(unifier, store, t); impl_sign(unifier, store, t, Some(t));
} }
/* float ======== */ /* float ======== */
impl_basic_arithmetic(unifier, store, float_t, &[float_t], float_t); impl_basic_arithmetic(unifier, store, float_t, &[float_t], Some(float_t));
impl_pow(unifier, store, float_t, &[int32_t, float_t], float_t); impl_pow(unifier, store, float_t, &[int32_t, float_t], Some(float_t));
impl_div(unifier, store, float_t, &[float_t]); impl_div(unifier, store, float_t, &[float_t], Some(float_t));
impl_floordiv(unifier, store, float_t, &[float_t], float_t); impl_floordiv(unifier, store, float_t, &[float_t], Some(float_t));
impl_mod(unifier, store, float_t, &[float_t], float_t); impl_mod(unifier, store, float_t, &[float_t], Some(float_t));
impl_sign(unifier, store, float_t); impl_sign(unifier, store, float_t, Some(float_t));
impl_not(unifier, store, float_t); impl_not(unifier, store, float_t, Some(bool_t));
impl_comparison(unifier, store, float_t, float_t); impl_comparison(unifier, store, float_t, &[float_t], Some(bool_t));
impl_eq(unifier, store, float_t); impl_eq(unifier, store, float_t, &[float_t], Some(bool_t));
/* bool ======== */ /* bool ======== */
impl_not(unifier, store, bool_t); impl_not(unifier, store, bool_t, Some(bool_t));
impl_eq(unifier, store, bool_t); impl_eq(unifier, store, bool_t, &[bool_t], Some(bool_t));
} }