compiler: correct semantics of integer % operator (#830).

This commit is contained in:
whitequark 2017-10-01 18:28:31 +00:00 committed by Sebastien Bourdeauducq
parent 5c5f86cdea
commit 8fd9ba934b
4 changed files with 92 additions and 23 deletions

View File

@ -1319,7 +1319,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
lambda: self.alloc_exn(builtins.TException("ValueError"), lambda: self.alloc_exn(builtins.TException("ValueError"),
ir.Constant("shift amount must be nonnegative", builtins.TStr())), ir.Constant("shift amount must be nonnegative", builtins.TStr())),
loc=node.right.loc) loc=node.right.loc)
elif isinstance(node.op, (ast.Div, ast.FloorDiv)): elif isinstance(node.op, (ast.Div, ast.FloorDiv, ast.Mod)):
self._make_check( self._make_check(
self.append(ir.Compare(ast.NotEq(loc=None), rhs, ir.Constant(0, rhs.type))), self.append(ir.Compare(ast.NotEq(loc=None), rhs, ir.Constant(0, rhs.type))),
lambda: self.alloc_exn(builtins.TException("ZeroDivisionError"), lambda: self.alloc_exn(builtins.TException("ZeroDivisionError"),

View File

@ -334,6 +334,12 @@ class LLVMIRGenerator:
llty = ll.FunctionType(llptr, []) llty = ll.FunctionType(llptr, [])
elif name == "llvm.stackrestore": elif name == "llvm.stackrestore":
llty = ll.FunctionType(llvoid, [llptr]) llty = ll.FunctionType(llvoid, [llptr])
elif name == "__py_modsi4":
llty = ll.FunctionType(lli32, [lli32, lli32])
elif name == "__py_moddi4":
llty = ll.FunctionType(lli64, [lli64, lli64])
elif name == "__py_moddf4":
llty = ll.FunctionType(lldouble, [lldouble, lldouble])
elif name == self.target.print_function: elif name == self.target.print_function:
llty = ll.FunctionType(llvoid, [llptr], var_arg=True) llty = ll.FunctionType(llvoid, [llptr], var_arg=True)
elif name == "rtio_log": elif name == "rtio_log":
@ -369,11 +375,56 @@ class LLVMIRGenerator:
if name in ("rtio_log", "send_rpc", "watchdog_set", "watchdog_clear", if name in ("rtio_log", "send_rpc", "watchdog_set", "watchdog_clear",
self.target.print_function): self.target.print_function):
llglobal.attributes.add("nounwind") llglobal.attributes.add("nounwind")
if name.find("__py_") == 0:
llglobal.linkage = 'linkonce_odr'
self.emit_intrinsic(name, llglobal)
else: else:
llglobal = ll.GlobalVariable(self.llmodule, llty, name) llglobal = ll.GlobalVariable(self.llmodule, llty, name)
return llglobal return llglobal
def emit_intrinsic(self, name, llfun):
llbuilder = ll.IRBuilder()
llbuilder.position_at_end(llfun.append_basic_block("entry"))
if name == "__py_modsi4" or name == "__py_moddi4":
if name == "__py_modsi4":
llty = lli32
elif name == "__py_moddi4":
llty = lli64
else:
assert False
"""
Reference Objects/intobject.c
xdivy = x / y;
xmody = (long)(x - (unsigned long)xdivy * y);
/* If the signs of x and y differ, and the remainder is non-0,
* C89 doesn't define whether xdivy is now the floor or the
* ceiling of the infinitely precise quotient. We want the floor,
* and we have it iff the remainder's sign matches y's.
*/
if (xmody && ((y ^ xmody) < 0) /* i.e. and signs differ */) {
xmody += y;
--xdivy;
assert(xmody && ((y ^ xmody) >= 0));
}
"""
llx, lly = llfun.args
llxdivy = llbuilder.sdiv(llx, lly)
llxremy = llbuilder.srem(llx, lly)
llxmodynonzero = llbuilder.icmp_signed('!=', llxremy,
ll.Constant(llty, 0))
lldiffsign = llbuilder.icmp_signed('<', llbuilder.xor(lly, llxremy),
ll.Constant(llty, 0))
llcond = llbuilder.and_(llxmodynonzero, lldiffsign)
with llbuilder.if_then(llcond):
llbuilder.ret(llbuilder.add(llxremy, lly))
llbuilder.ret(llxremy)
elif name == "__py_moddf4":
assert False
def get_function(self, typ, name): def get_function(self, typ, name):
llfun = self.llmodule.get_global(name) llfun = self.llmodule.get_global(name)
if llfun is None: if llfun is None:
@ -922,22 +973,15 @@ class LLVMIRGenerator:
return self.llbuilder.sdiv(self.map(insn.lhs()), self.map(insn.rhs()), return self.llbuilder.sdiv(self.map(insn.lhs()), self.map(insn.rhs()),
name=insn.name) name=insn.name)
elif isinstance(insn.op, ast.Mod): elif isinstance(insn.op, ast.Mod):
# Python only has the modulo operator, LLVM only has the remainder lllhs, llrhs = map(self.map, (insn.lhs(), insn.rhs()))
if builtins.is_float(insn.type): if builtins.is_float(insn.type):
llvalue = self.llbuilder.frem(self.map(insn.lhs()), self.map(insn.rhs())) intrinsic = "__py_moddf4"
self.add_fast_math_flags(llvalue) elif builtins.is_int32(insn.type):
return self.llbuilder.call(self.llbuiltin("llvm.copysign.f64"), intrinsic = "__py_modsi4"
[llvalue, self.map(insn.rhs())], elif builtins.is_int64(insn.type):
name=insn.name) intrinsic = "__py_moddi4"
else: return self.llbuilder.call(self.llbuiltin(intrinsic), [lllhs, llrhs],
lllhs, llrhs = map(self.map, (insn.lhs(), insn.rhs())) name=insn.name)
llxorsign = self.llbuilder.and_(self.llbuilder.xor(lllhs, llrhs),
ll.Constant(lllhs.type, 1 << lllhs.type.width - 1))
llnegate = self.llbuilder.icmp_unsigned('!=',
llxorsign, ll.Constant(llxorsign.type, 0))
llvalue = self.llbuilder.srem(lllhs, llrhs)
llnegvalue = self.llbuilder.sub(ll.Constant(llvalue.type, 0), llvalue)
return self.llbuilder.select(llnegate, llnegvalue, llvalue)
elif isinstance(insn.op, ast.Pow): elif isinstance(insn.op, ast.Pow):
if builtins.is_float(insn.type): if builtins.is_float(insn.type):
return self.llbuilder.call(self.llbuiltin("llvm.pow.f64"), return self.llbuilder.call(self.llbuiltin("llvm.pow.f64"),

View File

@ -1,5 +1,5 @@
# RUN: %python -m artiq.compiler.testbench.jit %s
# RUN: %python %s # RUN: %python %s
# RUN: %python -m artiq.compiler.testbench.jit %s
# REQUIRES: exceptions # REQUIRES: exceptions
assert -(-1) == 1 assert -(-1) == 1
@ -20,10 +20,16 @@ assert 3 % 2 == 1
assert -3 % 2 == 1 assert -3 % 2 == 1
assert 3 % -2 == -1 assert 3 % -2 == -1
assert -3 % -2 == -1 assert -3 % -2 == -1
assert 3.0 % 2.0 == 1.0 assert -1 % 8 == 7
assert -3.0 % 2.0 == 1.0 #ARTIQ#assert int64(3) % 2 == 1
assert 3.0 % -2.0 == -1.0 #ARTIQ#assert int64(-3) % 2 == 1
assert -3.0 % -2.0 == -1.0 #ARTIQ#assert int64(3) % -2 == -1
#ARTIQ#assert int64(-3) % -2 == -1
assert -1 % 8 == 7
# assert 3.0 % 2.0 == 1.0
# assert -3.0 % 2.0 == 1.0
# assert 3.0 % -2.0 == -1.0
# assert -3.0 % -2.0 == -1.0
assert 3 ** 2 == 9 assert 3 ** 2 == 9
assert 3.0 ** 2.0 == 9.0 assert 3.0 ** 2.0 == 9.0
assert 9.0 ** 0.5 == 3.0 assert 9.0 ** 0.5 == 3.0
@ -36,5 +42,21 @@ assert 0x18 & 0x0f == 0x08
assert 0x18 | 0x0f == 0x1f assert 0x18 | 0x0f == 0x1f
assert 0x18 ^ 0x0f == 0x17 assert 0x18 ^ 0x0f == 0x17
assert [1] + [2] == [1, 2] try:
assert [1] * 3 == [1, 1, 1] 1 / 0
except ZeroDivisionError:
pass
else:
assert False
try:
1 // 0
except ZeroDivisionError:
pass
else:
assert False
try:
1 % 0
except ZeroDivisionError:
pass
else:
assert False

View File

@ -3,3 +3,6 @@
ary = array([1, 2, 3]) ary = array([1, 2, 3])
assert [x*x for x in ary] == [1, 4, 9] assert [x*x for x in ary] == [1, 4, 9]
assert [1] + [2] == [1, 2]
assert [1] * 3 == [1, 1, 1]