From 3231eb0d78444ce299cf5662b8ad9391e997acbb Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 2 Nov 2023 16:04:55 +0800 Subject: [PATCH] core: Add compile-time error and runtime assertion for negative shifts --- nac3core/src/codegen/expr.rs | 39 ++++++++++++++++++++++-- nac3core/src/typecheck/function_check.rs | 21 +++++++++++-- 2 files changed, 56 insertions(+), 4 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 1518b2d..052e483 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -298,8 +298,40 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { (Operator::BitOr, _) => self.builder.build_or(lhs, rhs, "or").into(), (Operator::BitXor, _) => self.builder.build_xor(lhs, rhs, "xor").into(), (Operator::BitAnd, _) => self.builder.build_and(lhs, rhs, "and").into(), - (Operator::LShift, _) => self.builder.build_left_shift(lhs, rhs, "lshift").into(), - (Operator::RShift, _) => self.builder.build_right_shift(lhs, rhs, true, "rshift").into(), + + // Sign-ness of bitshift operators are always determined by the left operand + (Operator::LShift, signed) | (Operator::RShift, signed) => { + // RHS operand is always 32 bits + assert_eq!(rhs.get_type().get_bit_width(), 32); + + let common_type = lhs.get_type(); + let rhs = if common_type.get_bit_width() > 32 { + if signed { + self.builder.build_int_s_extend(rhs, common_type, "") + } else { + self.builder.build_int_z_extend(rhs, common_type, "") + } + } else { + rhs + }; + + let rhs_gez = self.builder.build_int_compare(IntPredicate::SGE, rhs, common_type.const_zero(), ""); + self.make_assert( + generator, + rhs_gez, + "ValueError", + "negative shift count", + [None, None, None], + self.current_loc + ); + + match *op { + Operator::LShift => self.builder.build_left_shift(lhs, rhs, "lshift").into(), + Operator::RShift => self.builder.build_right_shift(lhs, rhs, signed, "rshift").into(), + _ => unreachable!() + } + } + (Operator::FloorDiv, true) => self.builder.build_int_signed_div(lhs, rhs, "floordiv").into(), (Operator::FloorDiv, false) => self.builder.build_int_unsigned_div(lhs, rhs, "floordiv").into(), (Operator::Pow, s) => integer_power(generator, self, lhs, rhs, s).into(), @@ -1085,6 +1117,9 @@ pub fn gen_binop_expr<'ctx, 'a, G: CodeGenerator>( Ok(Some(ctx.gen_int_ops(generator, op, left_val, right_val, true).into())) } else if ty1 == ty2 && [ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty1) { Ok(Some(ctx.gen_int_ops(generator, op, left_val, right_val, false).into())) + } else if [Operator::LShift, Operator::RShift].contains(op) { + let signed = [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1); + Ok(Some(ctx.gen_int_ops(generator, op, left_val, right_val, signed).into())) } else if ty1 == ty2 && ctx.primitives.float == ty1 { Ok(Some(ctx.gen_float_ops(op, left_val, right_val).into())) } else if ty1 == ctx.primitives.float && ty2 == ctx.primitives.int32 { diff --git a/nac3core/src/typecheck/function_check.rs b/nac3core/src/typecheck/function_check.rs index 2c4b28b..7ed0735 100644 --- a/nac3core/src/typecheck/function_check.rs +++ b/nac3core/src/typecheck/function_check.rs @@ -2,7 +2,7 @@ use crate::typecheck::typedef::TypeEnum; use super::type_inferencer::Inferencer; use super::typedef::Type; -use nac3parser::ast::{self, Expr, ExprKind, Stmt, StmtKind, StrRef}; +use nac3parser::ast::{self, Constant, Expr, ExprKind, Operator::{LShift, RShift}, Stmt, StmtKind, StrRef}; use std::{collections::HashSet, iter::once}; impl<'a> Inferencer<'a> { @@ -107,11 +107,28 @@ impl<'a> Inferencer<'a> { self.check_expr(value, defined_identifiers)?; self.should_have_value(value)?; } - ExprKind::BinOp { left, right, .. } => { + ExprKind::BinOp { left, op, right } => { self.check_expr(left, defined_identifiers)?; self.check_expr(right, defined_identifiers)?; self.should_have_value(left)?; self.should_have_value(right)?; + + // Check whether a bitwise shift has a negative RHS constant value + if *op == LShift || *op == RShift { + if let ExprKind::Constant { value, .. } = &right.node { + let rhs_val = match value { + Constant::Int(v) => v, + _ => unreachable!(), + }; + + if *rhs_val < 0 { + return Err(format!( + "shift count is negative at {}", + right.location + )); + } + } + } } ExprKind::UnaryOp { operand, .. } => { self.check_expr(operand, defined_identifiers)?;