diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index 0a023b1..f8e1fcb 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -96,11 +96,13 @@ pub fn impl_binop( let (ty, var_id) = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None); (ty, Some(var_id)) }; + let function_vars = if let Some(var_id) = other_var_id { vec![(var_id, other_ty)].into_iter().collect::>() } else { HashMap::new() }; + for op in ops { fields.insert(binop_name(op).into(), { ( @@ -224,7 +226,7 @@ pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty /// LShift, RShift pub fn impl_bitwise_shift(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { - impl_binop(unifier, store, ty, &[ty], ty, &[ast::Operator::LShift, ast::Operator::RShift]) + impl_binop(unifier, store, ty, &[store.int32, store.uint32], ty, &[ast::Operator::LShift, ast::Operator::RShift]); } /// Div @@ -295,6 +297,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie uint64: uint64_t, .. } = *store; + /* int ======== */ for t in [int32_t, int64_t, uint32_t, uint64_t] { impl_basic_arithmetic(unifier, store, t, &[t], t); diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 702fe94..393de14 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -116,6 +116,7 @@ impl RecordField { } } +/// Category of variable and value types. #[derive(Clone)] pub enum TypeEnum { TRigidVar { @@ -123,6 +124,8 @@ pub enum TypeEnum { name: Option, loc: Option, }, + + /// A type variable. TVar { id: u32, // empty indicates this is not a struct/tuple/list @@ -132,21 +135,41 @@ pub enum TypeEnum { name: Option, loc: Option, }, + + /// A tuple type. TTuple { + /// The types of elements present in this tuple. ty: Vec, }, + + /// A list type. TList { + /// The type of elements present in this list. ty: Type, }, + + /// An object type. TObj { + /// The [DefintionId] of this object type. obj_id: DefinitionId, + + /// The fields present in this object type. + /// + /// The key of the [Mapping] is the identifier of the field, while the value is a tuple + /// containing the [Type] of the field, and a `bool` indicating whether the field is a + /// variable (as opposed to a function). fields: Mapping, + + /// Mapping between the ID of type variables and the [Type] representing the type variables + /// of this object type. params: VarMap, }, TVirtual { ty: Type, }, TCall(Vec), + + /// A function type. TFunc(FunSignature), } @@ -294,11 +317,16 @@ impl Unifier { self.get_fresh_var_with_range(&[], None, None) } + /// Returns a fresh [type variable][TypeEnum::TVar] with no associated range. + /// + /// This type variable can be instantiated by any type. pub fn get_fresh_var(&mut self, name: Option, loc: Option) -> (Type, u32) { self.get_fresh_var_with_range(&[], name, loc) } - /// Get a fresh type variable. + /// Returns a fresh [type variable][TypeEnum::TVar] with the range specified by `range`. + /// + /// This type variable can be instantiated by any type present in `range`. pub fn get_fresh_var_with_range( &mut self, range: &[Type], diff --git a/nac3standalone/demo/src/numeric_primitives.py b/nac3standalone/demo/src/numeric_primitives.py index e19d552..77a641f 100644 --- a/nac3standalone/demo/src/numeric_primitives.py +++ b/nac3standalone/demo/src/numeric_primitives.py @@ -41,10 +41,10 @@ def u64_max() -> uint64: return ~uint64(0) def i64_min() -> int64: - return int64(1) << int64(63) + return int64(1) << 63 def i64_max() -> int64: - return ~(int64(1) << int64(63)) + return ~(int64(1) << 63) def test_u32_bnot(): output_uint32(~uint32(0)) diff --git a/nac3standalone/demo/src/operators.py b/nac3standalone/demo/src/operators.py index 0470b96..5556bcd 100644 --- a/nac3standalone/demo/src/operators.py +++ b/nac3standalone/demo/src/operators.py @@ -37,7 +37,9 @@ def test_int32(): output_int32(a ^ b) output_int32(a & b) output_int32(a << b) + output_int32(a << uint32(b)) output_int32(a >> b) + output_int32(a >> uint32(b)) output_float64(a / b) a += b output_int32(a) @@ -74,7 +76,9 @@ def test_uint32(): output_uint32(a ^ b) output_uint32(a & b) output_uint32(a << b) + output_uint32(a << int32(b)) output_uint32(a >> b) + output_uint32(a >> int32(b)) output_float64(a / b) a += b output_uint32(a) @@ -108,8 +112,10 @@ def test_int64(): output_int64(a | b) output_int64(a ^ b) output_int64(a & b) - output_int64(a << b) - output_int64(a >> b) + output_int64(a << int32(b)) + output_int64(a << uint32(b)) + output_int64(a >> int32(b)) + output_int64(a >> uint32(b)) output_float64(a / b) a += b output_int64(a) @@ -127,9 +133,9 @@ def test_int64(): output_int64(a) a &= b output_int64(a) - a <<= b + a <<= int32(b) output_int64(a) - a >>= b + a >>= int32(b) output_int64(a) def test_uint64(): @@ -143,8 +149,8 @@ def test_uint64(): output_uint64(a | b) output_uint64(a ^ b) output_uint64(a & b) - output_uint64(a << b) - output_uint64(a >> b) + output_uint64(a << uint32(b)) + output_uint64(a >> uint32(b)) output_float64(a / b) a += b output_uint64(a) @@ -162,9 +168,9 @@ def test_uint64(): output_uint64(a) a &= b output_uint64(a) - a <<= b + a <<= uint32(b) output_uint64(a) - a >>= b + a >>= uint32(b) output_uint64(a) class A: