diff --git a/artiq/py2llvm/fractions.py b/artiq/py2llvm/fractions.py index 297ce23c4..52b73b602 100644 --- a/artiq/py2llvm/fractions.py +++ b/artiq/py2llvm/fractions.py @@ -3,8 +3,8 @@ import ast from llvm import core as lc -from artiq.py2llvm.values import VGeneric -from artiq.py2llvm.base_types import VBool, VInt +from artiq.py2llvm.values import VGeneric, operators +from artiq.py2llvm.base_types import VBool, VInt, VFloat def _gcd(a, b): @@ -214,30 +214,62 @@ class VFraction(VGeneric): return self._o_cmp(other, lc.ICMP_SGE, builder) def _o_addsub(self, other, builder, sub, invert=False): - if not isinstance(other, (VInt, VFraction)): - return NotImplemented - r = VFraction() - if builder is not None: - if isinstance(other, VInt): - i = other.o_int64(builder).auto_load(builder) - x, rd = self._nd(builder) - y = builder.mul(rd, i) - else: - a, b = self._nd(builder) - c, d = other._nd(builder) - rd = builder.mul(b, d) - x = builder.mul(a, d) - y = builder.mul(c, b) + if isinstance(other, VFloat): + a = self.o_getattr("numerator", builder) + b = self.o_getattr("denominator", builder) if sub: if invert: - rn = builder.sub(y, x) + return operators.truediv( + operators.sub(operators.mul(other, + b, + builder), + a, + builder), + b, + builder) else: - rn = builder.sub(x, y) + return operators.truediv( + operators.sub(a, + operators.mul(other, + b, + builder), + builder), + b, + builder) else: - rn = builder.add(x, y) - rn, rd = _reduce(builder, rn, rd) # rd is already > 0 - r.auto_store(builder, _make_ssa(builder, rn, rd)) - return r + return operators.truediv( + operators.add(operators.mul(other, + b, + builder), + a, + builder), + b, + builder) + else: + if not isinstance(other, (VFraction, VInt)): + return NotImplemented + r = VFraction() + if builder is not None: + if isinstance(other, VInt): + i = other.o_int64(builder).auto_load(builder) + x, rd = self._nd(builder) + y = builder.mul(rd, i) + else: + a, b = self._nd(builder) + c, d = other._nd(builder) + rd = builder.mul(b, d) + x = builder.mul(a, d) + y = builder.mul(c, b) + if sub: + if invert: + rn = builder.sub(y, x) + else: + rn = builder.sub(x, y) + else: + rn = builder.add(x, y) + rn, rd = _reduce(builder, rn, rd) # rd is already > 0 + r.auto_store(builder, _make_ssa(builder, rn, rd)) + return r def o_add(self, other, builder): return self._o_addsub(other, builder, False) @@ -252,32 +284,46 @@ class VFraction(VGeneric): return self._o_addsub(other, builder, True, True) def _o_muldiv(self, other, builder, div, invert=False): - if not isinstance(other, (VFraction, VInt)): - return NotImplemented - r = VFraction() - if builder is not None: - a, b = self._nd(builder) + if isinstance(other, VFloat): + a = self.o_getattr("numerator", builder) + b = self.o_getattr("denominator", builder) if invert: a, b = b, a - if isinstance(other, VInt): - i = other.o_int64(builder).auto_load(builder) - if div: - b = builder.mul(b, i) - else: - a = builder.mul(a, i) + if div: + return operators.truediv(a, + operators.mul(b, other, builder), + builder) else: - c, d = other._nd(builder) - if div: - a = builder.mul(a, d) - b = builder.mul(b, c) + return operators.truediv(operators.mul(a, other, builder), + b, + builder) + else: + if not isinstance(other, (VFraction, VInt)): + return NotImplemented + r = VFraction() + if builder is not None: + a, b = self._nd(builder) + if invert: + a, b = b, a + if isinstance(other, VInt): + i = other.o_int64(builder).auto_load(builder) + if div: + b = builder.mul(b, i) + else: + a = builder.mul(a, i) else: - a = builder.mul(a, c) - b = builder.mul(b, d) - if div or invert: - a, b = _signnum(builder, a, b) - a, b = _reduce(builder, a, b) - r.auto_store(builder, _make_ssa(builder, a, b)) - return r + c, d = other._nd(builder) + if div: + a = builder.mul(a, d) + b = builder.mul(b, c) + else: + a = builder.mul(a, c) + b = builder.mul(b, d) + if div or invert: + a, b = _signnum(builder, a, b) + a, b = _reduce(builder, a, b) + r.auto_store(builder, _make_ssa(builder, a, b)) + return r def o_mul(self, other, builder): return self._o_muldiv(other, builder, False) @@ -289,6 +335,7 @@ class VFraction(VGeneric): return self._o_muldiv(other, builder, False) def or_truediv(self, other, builder): + # multiply by the inverse return self._o_muldiv(other, builder, False, True) def o_floordiv(self, other, builder): diff --git a/test/py2llvm.py b/test/py2llvm.py index 882a6d8b9..29c86ed89 100644 --- a/test/py2llvm.py +++ b/test/py2llvm.py @@ -112,11 +112,11 @@ class CompiledFunction: def arith(op, a, b): - if op == 1: + if op == 0: return a + b - elif op == 2: + elif op == 1: return a - b - elif op == 3: + elif op == 2: return a * b else: return a / b @@ -137,11 +137,11 @@ def simplify_encode(a, b): def frac_arith_encode(op, a, b, c, d): - if op == 1: + if op == 0: f = Fraction(a, b) - Fraction(c, d) - elif op == 2: + elif op == 1: f = Fraction(a, b) + Fraction(c, d) - elif op == 3: + elif op == 2: f = Fraction(a, b) * Fraction(c, d) else: f = Fraction(a, b) / Fraction(c, d) @@ -149,11 +149,11 @@ def frac_arith_encode(op, a, b, c, d): def frac_arith_encode_int(op, a, b, x): - if op == 1: + if op == 0: f = Fraction(a, b) - x - elif op == 2: + elif op == 1: f = Fraction(a, b) + x - elif op == 3: + elif op == 2: f = Fraction(a, b) * x else: f = Fraction(a, b) / x @@ -161,17 +161,39 @@ def frac_arith_encode_int(op, a, b, x): def frac_arith_encode_int_rev(op, a, b, x): - if op == 1: + if op == 0: f = x - Fraction(a, b) - elif op == 2: + elif op == 1: f = x + Fraction(a, b) - elif op == 3: + elif op == 2: f = x * Fraction(a, b) else: f = x / Fraction(a, b) return f.numerator*1000 + f.denominator +def frac_arith_float(op, a, b, x): + if op == 0: + return Fraction(a, b) - x + elif op == 1: + return Fraction(a, b) + x + elif op == 2: + return Fraction(a, b) * x + else: + return Fraction(a, b) / x + + +def frac_arith_float_rev(op, a, b, x): + if op == 0: + return x - Fraction(a, b) + elif op == 1: + return x + Fraction(a, b) + elif op == 2: + return x * Fraction(a, b) + else: + return x / Fraction(a, b) + + def array_test(): a = array(array(2, 5), 5) a[3][2] = 11 @@ -266,7 +288,7 @@ class CodeGenCase(unittest.TestCase): def test_frac_div(self): self._test_frac_arith(3) - def _test_frac_frac_arith_int(self, op, rev): + def _test_frac_arith_int(self, op, rev): f = frac_arith_encode_int_rev if rev else frac_arith_encode_int f_c = CompiledFunction(f, { "op": base_types.VInt(), @@ -280,20 +302,49 @@ class CodeGenCase(unittest.TestCase): f(op, a, b, x)) def test_frac_add_int(self): - self._test_frac_frac_arith_int(0, False) - self._test_frac_frac_arith_int(0, True) + self._test_frac_arith_int(0, False) + self._test_frac_arith_int(0, True) def test_frac_sub_int(self): - self._test_frac_frac_arith_int(1, False) - self._test_frac_frac_arith_int(1, True) + self._test_frac_arith_int(1, False) + self._test_frac_arith_int(1, True) def test_frac_mul_int(self): - self._test_frac_frac_arith_int(2, False) - self._test_frac_frac_arith_int(2, True) + self._test_frac_arith_int(2, False) + self._test_frac_arith_int(2, True) def test_frac_div_int(self): - self._test_frac_frac_arith_int(3, False) - self._test_frac_frac_arith_int(3, True) + self._test_frac_arith_int(3, False) + self._test_frac_arith_int(3, True) + + def _test_frac_arith_float(self, op, rev): + f = frac_arith_float_rev if rev else frac_arith_float + f_c = CompiledFunction(f, { + "op": base_types.VInt(), + "a": base_types.VInt(), "b": base_types.VInt(), + "x": base_types.VFloat()}) + for a in _test_range(): + for b in _test_range(): + for x in _test_range(): + self.assertAlmostEqual( + f_c(op, a, b, x/2), + f(op, a, b, x/2)) + + def test_frac_add_float(self): + self._test_frac_arith_float(0, False) + self._test_frac_arith_float(0, True) + + def test_frac_sub_float(self): + self._test_frac_arith_float(1, False) + self._test_frac_arith_float(1, True) + + def test_frac_mul_float(self): + self._test_frac_arith_float(2, False) + self._test_frac_arith_float(2, True) + + def test_frac_div_float(self): + self._test_frac_arith_float(3, False) + self._test_frac_arith_float(3, True) def test_array(self): array_test_c = CompiledFunction(array_test, dict())