From 2f73c96e980e7b12be7583f426e4a84f89dd8aac Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 25 Mar 2024 16:44:06 +0800 Subject: [PATCH] core/magic_methods: Allow unknown return types These types can be later inferred by the type inferencer. --- nac3core/src/typecheck/magic_methods.rs | 34 +++++++++++++------------ 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index d25fff5..6dcf6f9 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -90,7 +90,7 @@ pub fn impl_binop( _store: &PrimitiveStore, ty: Type, other_ty: &[Type], - ret_ty: Type, + ret_ty: Option, ops: &[Operator], ) { with_fields(unifier, ty, |unifier, fields| { @@ -107,6 +107,8 @@ pub fn impl_binop( VarMap::new() }; + let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).0); + for op in ops { fields.insert(binop_name(op).into(), { ( @@ -193,7 +195,7 @@ pub fn impl_basic_arithmetic( store: &PrimitiveStore, ty: Type, other_ty: &[Type], - ret_ty: Type, + ret_ty: Option, ) { impl_binop( unifier, @@ -211,7 +213,7 @@ pub fn impl_pow( store: &PrimitiveStore, ty: Type, other_ty: &[Type], - ret_ty: Type, + ret_ty: Option, ) { impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Pow]); } @@ -223,19 +225,19 @@ pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty store, ty, &[ty], - ty, + Some(ty), &[Operator::BitAnd, Operator::BitOr, Operator::BitXor], ); } /// `LShift`, `RShift` pub fn impl_bitwise_shift(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { - impl_binop(unifier, store, ty, &[store.int32, store.uint32], ty, &[Operator::LShift, Operator::RShift]); + impl_binop(unifier, store, ty, &[store.int32, store.uint32], Some(ty), &[Operator::LShift, Operator::RShift]); } /// `Div` pub fn impl_div(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type]) { - impl_binop(unifier, store, ty, other_ty, store.float, &[Operator::Div]); + impl_binop(unifier, store, ty, other_ty, Some(store.float), &[Operator::Div]); } /// `FloorDiv` @@ -244,7 +246,7 @@ pub fn impl_floordiv( store: &PrimitiveStore, ty: Type, other_ty: &[Type], - ret_ty: Type, + ret_ty: Option, ) { impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::FloorDiv]); } @@ -255,7 +257,7 @@ pub fn impl_mod( store: &PrimitiveStore, ty: Type, other_ty: &[Type], - ret_ty: Type, + ret_ty: Option, ) { impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Mod]); } @@ -304,13 +306,13 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie /* int ======== */ for t in [int32_t, int64_t, uint32_t, uint64_t] { - impl_basic_arithmetic(unifier, store, t, &[t], t); - impl_pow(unifier, store, t, &[t], t); + impl_basic_arithmetic(unifier, store, t, &[t], Some(t)); + impl_pow(unifier, store, t, &[t], Some(t)); impl_bitwise_arithmetic(unifier, store, t); impl_bitwise_shift(unifier, store, t); impl_div(unifier, store, t, &[t]); - impl_floordiv(unifier, store, t, &[t], t); - impl_mod(unifier, store, t, &[t], t); + impl_floordiv(unifier, store, t, &[t], Some(t)); + impl_mod(unifier, store, t, &[t], Some(t)); impl_invert(unifier, store, t); impl_not(unifier, store, t); impl_comparison(unifier, store, t, t); @@ -321,11 +323,11 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie } /* float ======== */ - impl_basic_arithmetic(unifier, store, float_t, &[float_t], float_t); - impl_pow(unifier, store, float_t, &[int32_t, float_t], float_t); + impl_basic_arithmetic(unifier, store, float_t, &[float_t], Some(float_t)); + impl_pow(unifier, store, float_t, &[int32_t, float_t], Some(float_t)); impl_div(unifier, store, float_t, &[float_t]); - impl_floordiv(unifier, store, float_t, &[float_t], float_t); - impl_mod(unifier, store, float_t, &[float_t], float_t); + impl_floordiv(unifier, store, float_t, &[float_t], Some(float_t)); + impl_mod(unifier, store, float_t, &[float_t], Some(float_t)); impl_sign(unifier, store, float_t); impl_not(unifier, store, float_t); impl_comparison(unifier, store, float_t, float_t);