diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 1518b2dc..052e483b 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -298,8 +298,40 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { (Operator::BitOr, _) => self.builder.build_or(lhs, rhs, "or").into(), (Operator::BitXor, _) => self.builder.build_xor(lhs, rhs, "xor").into(), (Operator::BitAnd, _) => self.builder.build_and(lhs, rhs, "and").into(), - (Operator::LShift, _) => self.builder.build_left_shift(lhs, rhs, "lshift").into(), - (Operator::RShift, _) => self.builder.build_right_shift(lhs, rhs, true, "rshift").into(), + + // Sign-ness of bitshift operators are always determined by the left operand + (Operator::LShift, signed) | (Operator::RShift, signed) => { + // RHS operand is always 32 bits + assert_eq!(rhs.get_type().get_bit_width(), 32); + + let common_type = lhs.get_type(); + let rhs = if common_type.get_bit_width() > 32 { + if signed { + self.builder.build_int_s_extend(rhs, common_type, "") + } else { + self.builder.build_int_z_extend(rhs, common_type, "") + } + } else { + rhs + }; + + let rhs_gez = self.builder.build_int_compare(IntPredicate::SGE, rhs, common_type.const_zero(), ""); + self.make_assert( + generator, + rhs_gez, + "ValueError", + "negative shift count", + [None, None, None], + self.current_loc + ); + + match *op { + Operator::LShift => self.builder.build_left_shift(lhs, rhs, "lshift").into(), + Operator::RShift => self.builder.build_right_shift(lhs, rhs, signed, "rshift").into(), + _ => unreachable!() + } + } + (Operator::FloorDiv, true) => self.builder.build_int_signed_div(lhs, rhs, "floordiv").into(), (Operator::FloorDiv, false) => self.builder.build_int_unsigned_div(lhs, rhs, "floordiv").into(), (Operator::Pow, s) => integer_power(generator, self, lhs, rhs, s).into(), @@ -1085,6 +1117,9 @@ pub fn gen_binop_expr<'ctx, 'a, G: CodeGenerator>( Ok(Some(ctx.gen_int_ops(generator, op, left_val, right_val, true).into())) } else if ty1 == ty2 && [ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty1) { Ok(Some(ctx.gen_int_ops(generator, op, left_val, right_val, false).into())) + } else if [Operator::LShift, Operator::RShift].contains(op) { + let signed = [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1); + Ok(Some(ctx.gen_int_ops(generator, op, left_val, right_val, signed).into())) } else if ty1 == ty2 && ctx.primitives.float == ty1 { Ok(Some(ctx.gen_float_ops(op, left_val, right_val).into())) } else if ty1 == ctx.primitives.float && ty2 == ctx.primitives.int32 { diff --git a/nac3core/src/typecheck/function_check.rs b/nac3core/src/typecheck/function_check.rs index 2c4b28b3..7ed07358 100644 --- a/nac3core/src/typecheck/function_check.rs +++ b/nac3core/src/typecheck/function_check.rs @@ -2,7 +2,7 @@ use crate::typecheck::typedef::TypeEnum; use super::type_inferencer::Inferencer; use super::typedef::Type; -use nac3parser::ast::{self, Expr, ExprKind, Stmt, StmtKind, StrRef}; +use nac3parser::ast::{self, Constant, Expr, ExprKind, Operator::{LShift, RShift}, Stmt, StmtKind, StrRef}; use std::{collections::HashSet, iter::once}; impl<'a> Inferencer<'a> { @@ -107,11 +107,28 @@ impl<'a> Inferencer<'a> { self.check_expr(value, defined_identifiers)?; self.should_have_value(value)?; } - ExprKind::BinOp { left, right, .. } => { + ExprKind::BinOp { left, op, right } => { self.check_expr(left, defined_identifiers)?; self.check_expr(right, defined_identifiers)?; self.should_have_value(left)?; self.should_have_value(right)?; + + // Check whether a bitwise shift has a negative RHS constant value + if *op == LShift || *op == RShift { + if let ExprKind::Constant { value, .. } = &right.node { + let rhs_val = match value { + Constant::Int(v) => v, + _ => unreachable!(), + }; + + if *rhs_val < 0 { + return Err(format!( + "shift count is negative at {}", + right.location + )); + } + } + } } ExprKind::UnaryOp { operand, .. } => { self.check_expr(operand, defined_identifiers)?; diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index 0a023b19..f8e1fcb6 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -96,11 +96,13 @@ pub fn impl_binop( let (ty, var_id) = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None); (ty, Some(var_id)) }; + let function_vars = if let Some(var_id) = other_var_id { vec![(var_id, other_ty)].into_iter().collect::>() } else { HashMap::new() }; + for op in ops { fields.insert(binop_name(op).into(), { ( @@ -224,7 +226,7 @@ pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty /// LShift, RShift pub fn impl_bitwise_shift(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { - impl_binop(unifier, store, ty, &[ty], ty, &[ast::Operator::LShift, ast::Operator::RShift]) + impl_binop(unifier, store, ty, &[store.int32, store.uint32], ty, &[ast::Operator::LShift, ast::Operator::RShift]); } /// Div @@ -295,6 +297,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie uint64: uint64_t, .. } = *store; + /* int ======== */ for t in [int32_t, int64_t, uint32_t, uint64_t] { impl_basic_arithmetic(unifier, store, t, &[t], t); diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 702fe944..393de140 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -116,6 +116,7 @@ impl RecordField { } } +/// Category of variable and value types. #[derive(Clone)] pub enum TypeEnum { TRigidVar { @@ -123,6 +124,8 @@ pub enum TypeEnum { name: Option, loc: Option, }, + + /// A type variable. TVar { id: u32, // empty indicates this is not a struct/tuple/list @@ -132,21 +135,41 @@ pub enum TypeEnum { name: Option, loc: Option, }, + + /// A tuple type. TTuple { + /// The types of elements present in this tuple. ty: Vec, }, + + /// A list type. TList { + /// The type of elements present in this list. ty: Type, }, + + /// An object type. TObj { + /// The [DefintionId] of this object type. obj_id: DefinitionId, + + /// The fields present in this object type. + /// + /// The key of the [Mapping] is the identifier of the field, while the value is a tuple + /// containing the [Type] of the field, and a `bool` indicating whether the field is a + /// variable (as opposed to a function). fields: Mapping, + + /// Mapping between the ID of type variables and the [Type] representing the type variables + /// of this object type. params: VarMap, }, TVirtual { ty: Type, }, TCall(Vec), + + /// A function type. TFunc(FunSignature), } @@ -294,11 +317,16 @@ impl Unifier { self.get_fresh_var_with_range(&[], None, None) } + /// Returns a fresh [type variable][TypeEnum::TVar] with no associated range. + /// + /// This type variable can be instantiated by any type. pub fn get_fresh_var(&mut self, name: Option, loc: Option) -> (Type, u32) { self.get_fresh_var_with_range(&[], name, loc) } - /// Get a fresh type variable. + /// Returns a fresh [type variable][TypeEnum::TVar] with the range specified by `range`. + /// + /// This type variable can be instantiated by any type present in `range`. pub fn get_fresh_var_with_range( &mut self, range: &[Type], diff --git a/nac3standalone/demo/src/numeric_primitives.py b/nac3standalone/demo/src/numeric_primitives.py index e19d552d..77a641f8 100644 --- a/nac3standalone/demo/src/numeric_primitives.py +++ b/nac3standalone/demo/src/numeric_primitives.py @@ -41,10 +41,10 @@ def u64_max() -> uint64: return ~uint64(0) def i64_min() -> int64: - return int64(1) << int64(63) + return int64(1) << 63 def i64_max() -> int64: - return ~(int64(1) << int64(63)) + return ~(int64(1) << 63) def test_u32_bnot(): output_uint32(~uint32(0)) diff --git a/nac3standalone/demo/src/operators.py b/nac3standalone/demo/src/operators.py index 0470b969..5556bcd2 100644 --- a/nac3standalone/demo/src/operators.py +++ b/nac3standalone/demo/src/operators.py @@ -37,7 +37,9 @@ def test_int32(): output_int32(a ^ b) output_int32(a & b) output_int32(a << b) + output_int32(a << uint32(b)) output_int32(a >> b) + output_int32(a >> uint32(b)) output_float64(a / b) a += b output_int32(a) @@ -74,7 +76,9 @@ def test_uint32(): output_uint32(a ^ b) output_uint32(a & b) output_uint32(a << b) + output_uint32(a << int32(b)) output_uint32(a >> b) + output_uint32(a >> int32(b)) output_float64(a / b) a += b output_uint32(a) @@ -108,8 +112,10 @@ def test_int64(): output_int64(a | b) output_int64(a ^ b) output_int64(a & b) - output_int64(a << b) - output_int64(a >> b) + output_int64(a << int32(b)) + output_int64(a << uint32(b)) + output_int64(a >> int32(b)) + output_int64(a >> uint32(b)) output_float64(a / b) a += b output_int64(a) @@ -127,9 +133,9 @@ def test_int64(): output_int64(a) a &= b output_int64(a) - a <<= b + a <<= int32(b) output_int64(a) - a >>= b + a >>= int32(b) output_int64(a) def test_uint64(): @@ -143,8 +149,8 @@ def test_uint64(): output_uint64(a | b) output_uint64(a ^ b) output_uint64(a & b) - output_uint64(a << b) - output_uint64(a >> b) + output_uint64(a << uint32(b)) + output_uint64(a >> uint32(b)) output_float64(a / b) a += b output_uint64(a) @@ -162,9 +168,9 @@ def test_uint64(): output_uint64(a) a &= b output_uint64(a) - a <<= b + a <<= uint32(b) output_uint64(a) - a >>= b + a >>= uint32(b) output_uint64(a) class A: