diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index 0d17417d9..502ad651a 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -900,8 +900,14 @@ class ARTIQIRGenerator(algorithm.Visitor): def visit_BinOpT(self, node): if builtins.is_numeric(node.type): # TODO: check for division by zero - # TODO: check for shift by too many bits - return self.append(ir.Arith(node.op, self.visit(node.left), self.visit(node.right))) + rhs = self.visit(node.right) + if isinstance(node.op, (ast.LShift, ast.RShift)): + # Check for negative shift amount. + self._make_check(self.append(ir.Compare(ast.GtE(loc=None), rhs, + ir.Constant(0, rhs.type))), + lambda: self.append(ir.Alloc([], builtins.TValueError()))) + + return self.append(ir.Arith(node.op, self.visit(node.left), rhs)) elif isinstance(node.op, ast.Add): # list + list, tuple + tuple lhs, rhs = self.visit(node.left), self.visit(node.right) if types.is_tuple(node.left.type) and types.is_tuple(node.right.type): diff --git a/artiq/compiler/transforms/llvm_ir_generator.py b/artiq/compiler/transforms/llvm_ir_generator.py index 5c9c78fe4..5d4075f8b 100644 --- a/artiq/compiler/transforms/llvm_ir_generator.py +++ b/artiq/compiler/transforms/llvm_ir_generator.py @@ -368,11 +368,21 @@ class LLVMIRGenerator: return self.llbuilder.fptosi(llvalue, self.llty_of_type(insn.type), name=insn.name) elif isinstance(insn.op, ast.LShift): - return self.llbuilder.shl(self.map(insn.lhs()), self.map(insn.rhs()), - name=insn.name) + lllhs, llrhs = map(self.map, (insn.lhs(), insn.rhs())) + llrhs_max = ll.Constant(llrhs.type, builtins.get_int_width(insn.lhs().type)) + llrhs_overflow = self.llbuilder.icmp_signed('>=', llrhs, llrhs_max) + llvalue_zero = ll.Constant(lllhs.type, 0) + llvalue = self.llbuilder.shl(lllhs, llrhs) + return self.llbuilder.select(llrhs_overflow, llvalue_zero, llvalue, + name=insn.name) elif isinstance(insn.op, ast.RShift): - return self.llbuilder.ashr(self.map(insn.lhs()), self.map(insn.rhs()), - name=insn.name) + lllhs, llrhs = map(self.map, (insn.lhs(), insn.rhs())) + llrhs_max = ll.Constant(llrhs.type, builtins.get_int_width(insn.lhs().type) - 1) + llrhs_overflow = self.llbuilder.icmp_signed('>', llrhs, llrhs_max) + llvalue = self.llbuilder.ashr(lllhs, llrhs) + llvalue_max = self.llbuilder.ashr(lllhs, llrhs_max) # preserve sign bit + return self.llbuilder.select(llrhs_overflow, llvalue_max, llvalue, + name=insn.name) elif isinstance(insn.op, ast.BitAnd): return self.llbuilder.and_(self.map(insn.lhs()), self.map(insn.rhs()), name=insn.name) diff --git a/lit-test/compiler/integration/arithmetics.py b/lit-test/compiler/integration/arithmetics.py index a93dd5945..8a278aff6 100644 --- a/lit-test/compiler/integration/arithmetics.py +++ b/lit-test/compiler/integration/arithmetics.py @@ -28,6 +28,8 @@ assert 9.0 ** 0.5 == 3.0 assert 1 << 1 == 2 assert 2 >> 1 == 1 assert -2 >> 1 == -1 +assert 1 << 32 == 0 +assert -1 >> 32 == -1 assert 0x18 & 0x0f == 0x08 assert 0x18 | 0x0f == 0x1f assert 0x18 ^ 0x0f == 0x17