diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index d25fff5..a11705f 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -90,7 +90,7 @@ pub fn impl_binop( _store: &PrimitiveStore, ty: Type, other_ty: &[Type], - ret_ty: Type, + ret_ty: Option, ops: &[Operator], ) { with_fields(unifier, ty, |unifier, fields| { @@ -107,6 +107,8 @@ pub fn impl_binop( VarMap::new() }; + let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).0); + for op in ops { fields.insert(binop_name(op).into(), { ( @@ -141,8 +143,10 @@ pub fn impl_binop( }); } -pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Type, ops: &[Unaryop]) { +pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Option, ops: &[Unaryop]) { with_fields(unifier, ty, |unifier, fields| { + let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).0); + for op in ops { fields.insert( unaryop_name(op).into(), @@ -161,19 +165,35 @@ pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Type, ops: &[Unaryo pub fn impl_cmpop( unifier: &mut Unifier, - store: &PrimitiveStore, + _store: &PrimitiveStore, ty: Type, - other_ty: Type, + other_ty: &[Type], ops: &[Cmpop], + ret_ty: Option, ) { with_fields(unifier, ty, |unifier, fields| { + let (other_ty, other_var_id) = if other_ty.len() == 1 { + (other_ty[0], None) + } else { + let (ty, var_id) = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None); + (ty, Some(var_id)) + }; + + let function_vars = if let Some(var_id) = other_var_id { + vec![(var_id, other_ty)].into_iter().collect::() + } else { + VarMap::new() + }; + + let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).0); + for op in ops { fields.insert( comparison_name(op).unwrap().into(), ( unifier.add_ty(TypeEnum::TFunc(FunSignature { - ret: store.bool, - vars: VarMap::new(), + ret: ret_ty, + vars: function_vars.clone(), args: vec![FuncArg { ty: other_ty, default_value: None, @@ -193,7 +213,7 @@ pub fn impl_basic_arithmetic( store: &PrimitiveStore, ty: Type, other_ty: &[Type], - ret_ty: Type, + ret_ty: Option, ) { impl_binop( unifier, @@ -211,7 +231,7 @@ pub fn impl_pow( store: &PrimitiveStore, ty: Type, other_ty: &[Type], - ret_ty: Type, + ret_ty: Option, ) { impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Pow]); } @@ -223,19 +243,25 @@ pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty store, ty, &[ty], - ty, + Some(ty), &[Operator::BitAnd, Operator::BitOr, Operator::BitXor], ); } /// `LShift`, `RShift` pub fn impl_bitwise_shift(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { - impl_binop(unifier, store, ty, &[store.int32, store.uint32], ty, &[Operator::LShift, Operator::RShift]); + impl_binop(unifier, store, ty, &[store.int32, store.uint32], Some(ty), &[Operator::LShift, Operator::RShift]); } /// `Div` -pub fn impl_div(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type]) { - impl_binop(unifier, store, ty, other_ty, store.float, &[Operator::Div]); +pub fn impl_div( + unifier: &mut Unifier, + store: &PrimitiveStore, + ty: Type, + other_ty: &[Type], + ret_ty: Option, +) { + impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Div]); } /// `FloorDiv` @@ -244,7 +270,7 @@ pub fn impl_floordiv( store: &PrimitiveStore, ty: Type, other_ty: &[Type], - ret_ty: Type, + ret_ty: Option, ) { impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::FloorDiv]); } @@ -255,40 +281,53 @@ pub fn impl_mod( store: &PrimitiveStore, ty: Type, other_ty: &[Type], - ret_ty: Type, + ret_ty: Option, ) { impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Mod]); } /// `UAdd`, `USub` -pub fn impl_sign(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) { - impl_unaryop(unifier, ty, ty, &[Unaryop::UAdd, Unaryop::USub]); +pub fn impl_sign(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret_ty: Option) { + impl_unaryop(unifier, ty, ret_ty, &[Unaryop::UAdd, Unaryop::USub]); } /// `Invert` -pub fn impl_invert(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) { - impl_unaryop(unifier, ty, ty, &[Unaryop::Invert]); +pub fn impl_invert(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret_ty: Option) { + impl_unaryop(unifier, ty, ret_ty, &[Unaryop::Invert]); } /// `Not` -pub fn impl_not(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { - impl_unaryop(unifier, ty, store.bool, &[Unaryop::Not]); +pub fn impl_not(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret_ty: Option) { + impl_unaryop(unifier, ty, ret_ty, &[Unaryop::Not]); } /// `Lt`, `LtE`, `Gt`, `GtE` -pub fn impl_comparison(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type) { +pub fn impl_comparison( + unifier: &mut Unifier, + store: &PrimitiveStore, + ty: Type, + other_ty: &[Type], + ret_ty: Option, +) { impl_cmpop( unifier, store, ty, other_ty, &[Cmpop::Lt, Cmpop::Gt, Cmpop::LtE, Cmpop::GtE], + ret_ty, ); } /// `Eq`, `NotEq` -pub fn impl_eq(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { - impl_cmpop(unifier, store, ty, ty, &[Cmpop::Eq, Cmpop::NotEq]); +pub fn impl_eq( + unifier: &mut Unifier, + store: &PrimitiveStore, + ty: Type, + other_ty: &[Type], + ret_ty: Option, +) { + impl_cmpop(unifier, store, ty, other_ty, &[Cmpop::Eq, Cmpop::NotEq], ret_ty); } pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) { @@ -304,34 +343,34 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie /* int ======== */ for t in [int32_t, int64_t, uint32_t, uint64_t] { - impl_basic_arithmetic(unifier, store, t, &[t], t); - impl_pow(unifier, store, t, &[t], t); + impl_basic_arithmetic(unifier, store, t, &[t], Some(t)); + impl_pow(unifier, store, t, &[t], Some(t)); impl_bitwise_arithmetic(unifier, store, t); impl_bitwise_shift(unifier, store, t); - impl_div(unifier, store, t, &[t]); - impl_floordiv(unifier, store, t, &[t], t); - impl_mod(unifier, store, t, &[t], t); - impl_invert(unifier, store, t); - impl_not(unifier, store, t); - impl_comparison(unifier, store, t, t); - impl_eq(unifier, store, t); + impl_div(unifier, store, t, &[t], Some(float_t)); + impl_floordiv(unifier, store, t, &[t], Some(t)); + impl_mod(unifier, store, t, &[t], Some(t)); + impl_invert(unifier, store, t, Some(t)); + impl_not(unifier, store, t, Some(bool_t)); + impl_comparison(unifier, store, t, &[t], Some(bool_t)); + impl_eq(unifier, store, t, &[t], Some(bool_t)); } for t in [int32_t, int64_t] { - impl_sign(unifier, store, t); + impl_sign(unifier, store, t, Some(t)); } /* float ======== */ - impl_basic_arithmetic(unifier, store, float_t, &[float_t], float_t); - impl_pow(unifier, store, float_t, &[int32_t, float_t], float_t); - impl_div(unifier, store, float_t, &[float_t]); - impl_floordiv(unifier, store, float_t, &[float_t], float_t); - impl_mod(unifier, store, float_t, &[float_t], float_t); - impl_sign(unifier, store, float_t); - impl_not(unifier, store, float_t); - impl_comparison(unifier, store, float_t, float_t); - impl_eq(unifier, store, float_t); + impl_basic_arithmetic(unifier, store, float_t, &[float_t], Some(float_t)); + impl_pow(unifier, store, float_t, &[int32_t, float_t], Some(float_t)); + impl_div(unifier, store, float_t, &[float_t], Some(float_t)); + impl_floordiv(unifier, store, float_t, &[float_t], Some(float_t)); + impl_mod(unifier, store, float_t, &[float_t], Some(float_t)); + impl_sign(unifier, store, float_t, Some(float_t)); + impl_not(unifier, store, float_t, Some(bool_t)); + impl_comparison(unifier, store, float_t, &[float_t], Some(bool_t)); + impl_eq(unifier, store, float_t, &[float_t], Some(bool_t)); /* bool ======== */ - impl_not(unifier, store, bool_t); - impl_eq(unifier, store, bool_t); + impl_not(unifier, store, bool_t, Some(bool_t)); + impl_eq(unifier, store, bool_t, &[bool_t], Some(bool_t)); }