From 60368aa9e268d3d088af226473b2e7c8792b5c5e Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Mon, 8 Sep 2014 18:45:46 +0800 Subject: [PATCH] py2llvm: complete rational arithmetic support --- artiq/py2llvm/fractions.py | 242 +++++++++++++++++++++++-------------- test/py2llvm.py | 62 ++++++++-- 2 files changed, 201 insertions(+), 103 deletions(-) diff --git a/artiq/py2llvm/fractions.py b/artiq/py2llvm/fractions.py index b12382ae0..7bf652c73 100644 --- a/artiq/py2llvm/fractions.py +++ b/artiq/py2llvm/fractions.py @@ -8,33 +8,61 @@ from artiq.py2llvm.base_types import VBool, VInt def _gcd(a, b): + if a < 0: + a = -a while a: c = a a = b % a b = c return b + def init_module(module): funcdef = ast.parse(inspect.getsource(_gcd)).body[0] module.compile_function(funcdef, {"a": VInt(64), "b": VInt(64)}) -def _call_gcd(builder, a, b): + +def _reduce(builder, a, b): gcd_f = builder.basic_block.function.module.get_function_named("_gcd") - return builder.call(gcd_f, [a, b]) - -def _frac_normalize(builder, numerator, denominator): - gcd = _call_gcd(builder, numerator, denominator) - numerator = builder.sdiv(numerator, gcd) - denominator = builder.sdiv(denominator, gcd) - return numerator, denominator + gcd = builder.call(gcd_f, [a, b]) + a = builder.sdiv(a, gcd) + b = builder.sdiv(b, gcd) + return a, b -def _frac_make_ssa(builder, numerator, denominator): +def _signnum(builder, a, b): + function = builder.basic_block.function + orig_block = builder.basic_block + swap_block = function.append_basic_block("sn_swap") + merge_block = function.append_basic_block("sn_merge") + + condition = builder.icmp( + lc.ICMP_SLT, b, lc.Constant.int(lc.Type.int(64), 0)) + builder.cbranch(condition, swap_block, merge_block) + + builder.position_at_end(swap_block) + minusone = lc.Constant.int(lc.Type.int(64), -1) + a_swp = builder.mul(minusone, a) + b_swp = builder.mul(minusone, b) + builder.branch(merge_block) + + builder.position_at_end(merge_block) + a_phi = builder.phi(lc.Type.int(64)) + a_phi.add_incoming(a, orig_block) + a_phi.add_incoming(a_swp, swap_block) + b_phi = builder.phi(lc.Type.int(64)) + b_phi.add_incoming(b, orig_block) + b_phi.add_incoming(b_swp, swap_block) + + return a_phi, b_phi + + +def _make_ssa(builder, n, d): value = lc.Constant.undef(lc.Type.vector(lc.Type.int(64), 2)) value = builder.insert_element( - value, numerator, lc.Constant.int(lc.Type.int(), 0)) + value, n, lc.Constant.int(lc.Type.int(), 0)) value = builder.insert_element( - value, denominator, lc.Constant.int(lc.Type.int(), 1)) + value, d, lc.Constant.int(lc.Type.int(), 1)) return value @@ -52,29 +80,25 @@ class VFraction(VGeneric): if not isinstance(other, VFraction): raise TypeError - def _nd(self, builder, invert=False): + def _nd(self, builder): ssa_value = self.get_ssa_value(builder) - numerator = builder.extract_element( + a = builder.extract_element( ssa_value, lc.Constant.int(lc.Type.int(), 0)) - denominator = builder.extract_element( + b = builder.extract_element( ssa_value, lc.Constant.int(lc.Type.int(), 1)) - if invert: - return denominator, numerator - else: - return numerator, denominator + return a, b - def set_value_nd(self, builder, numerator, denominator): - numerator = numerator.o_int64(builder).get_ssa_value(builder) - denominator = denominator.o_int64(builder).get_ssa_value(builder) - numerator, denominator = _frac_normalize( - builder, numerator, denominator) - self.set_ssa_value( - builder, _frac_make_ssa(builder, numerator, denominator)) + def set_value_nd(self, builder, a, b): + a = a.o_int64(builder).get_ssa_value(builder) + b = b.o_int64(builder).get_ssa_value(builder) + a, b = _reduce(builder, a, b) + a, b = _signnum(builder, a, b) + self.set_ssa_value(builder, _make_ssa(builder, a, b)) - def set_value(self, builder, n): - if not isinstance(n, VFraction): + def set_value(self, builder, v): + if not isinstance(v, VFraction): raise TypeError - self.set_ssa_value(builder, n.get_ssa_value(builder)) + self.set_ssa_value(builder, v.get_ssa_value(builder)) def o_getattr(self, attr, builder): if attr == "numerator": @@ -86,7 +110,8 @@ class VFraction(VGeneric): r = VInt(64) if builder is not None: elt = builder.extract_element( - self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), idx)) + self.get_ssa_value(builder), + lc.Constant.int(lc.Type.int(), idx)) r.set_ssa_value(builder, elt) return r @@ -94,9 +119,9 @@ class VFraction(VGeneric): r = VBool() if builder is not None: zero = lc.Constant.int(lc.Type.int(64), 0) - numerator = builder.extract_element( + a = builder.extract_element( self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 0)) - r.set_ssa_value(builder, builder.icmp(lc.ICMP_NE, numerator, zero)) + r.set_ssa_value(builder, builder.icmp(lc.ICMP_NE, a, zero)) return r def o_intx(self, target_bits, builder): @@ -104,8 +129,8 @@ class VFraction(VGeneric): return VInt(target_bits) else: r = VInt(64) - numerator, denominator = self._nd(builder) - r.set_ssa_value(builder, builder.sdiv(numerator, denominator)) + a, b = self._nd(builder) + r.set_ssa_value(builder, builder.sdiv(a, b)) return r.o_intx(target_bits, builder) def o_roundx(self, target_bits, builder): @@ -113,34 +138,36 @@ class VFraction(VGeneric): return VInt(target_bits) else: r = VInt(64) - numerator, denominator = self._nd(builder) - h_denominator = builder.ashr(denominator, - lc.Constant.int(lc.Type.int(), 1)) - r_numerator = builder.add(numerator, h_denominator) - r.set_ssa_value(builder, builder.sdiv(r_numerator, denominator)) + a, b = self._nd(builder) + h_b = builder.ashr(b, lc.Constant.int(lc.Type.int(), 1)) + a = builder.add(a, h_b) + r.set_ssa_value(builder, builder.sdiv(a, b)) return r.o_intx(target_bits, builder) def _o_eq_inv(self, other, builder, ne): - if isinstance(other, VFraction): - r = VBool() - if builder is not None: - ee = [] - for i in range(2): - es = builder.extract_element( - self.get_ssa_value(builder), - lc.Constant.int(lc.Type.int(), i)) - eo = builder.extract_element( - other.get_ssa_value(builder), - lc.Constant.int(lc.Type.int(), i)) - ee.append(builder.icmp(lc.ICMP_EQ, es, eo)) - ssa_r = builder.and_(ee[0], ee[1]) - if ne: - ssa_r = builder.xor(ssa_r, - lc.Constant.int(lc.Type.int(1), 1)) - r.set_ssa_value(builder, ssa_r) - return r - else: + if not isinstance(other, (VInt, VFraction)): return NotImplemented + r = VBool() + if builder is not None: + if isinstance(other, VInt): + other = other.o_int64(builder) + a, b = self._nd(builder) + ssa_r = builder.and_( + builder.icmp(lc.ICMP_EQ, a, + other.get_ssa_value()), + builder.icmp(lc.ICMP_EQ, b, + lc.Constant.int(lc.Type.int(64), 1))) + else: + a, b = self._nd(builder) + c, d = other._nd(builder) + ssa_r = builder.and_( + builder.icmp(lc.ICMP_EQ, a, c), + builder.icmp(lc.ICMP_EQ, b, d)) + if ne: + ssa_r = builder.xor(ssa_r, + lc.Constant.int(lc.Type.int(1), 1)) + r.set_ssa_value(builder, ssa_r) + return r def o_eq(self, other, builder): return self._o_eq_inv(other, builder, False) @@ -148,44 +175,71 @@ class VFraction(VGeneric): def o_ne(self, other, builder): return self._o_eq_inv(other, builder, True) - def _o_muldiv(self, other, builder, div, invert=False): - r = VFraction() - if isinstance(other, VInt): - if builder is None: - return r - else: - numerator, denominator = self._nd(builder, invert) - i = other.get_ssa_value(builder) - if div: - gcd = _call_gcd(builder, i, numerator) - i = builder.sdiv(i, gcd) - numerator = builder.sdiv(numerator, gcd) - denominator = builder.mul(denominator, i) - else: - gcd = _call_gcd(builder, i, denominator) - i = builder.sdiv(i, gcd) - denominator = builder.sdiv(denominator, gcd) - numerator = builder.mul(numerator, i) - self.set_ssa_value(builder, _frac_make_ssa(builder, numerator, - denominator)) - elif isinstance(other, VFraction): - if builder is None: - return r - else: - numerator, denominator = self._nd(builder, invert) - onumerator, odenominator = other._nd(builder) - if div: - numerator = builder.mul(numerator, odenominator) - denominator = builder.mul(denominator, onumerator) - else: - numerator = builder.mul(numerator, onumerator) - denominator = builder.mul(denominator, odenominator) - numerator, denominator = _frac_normalize(builder, numerator, - denominator) - self.set_ssa_value( - builder, _frac_make_ssa(builder, numerator, denominator)) - else: + def _o_addsub(self, other, builder, sub, invert=False): + if not isinstance(other, (VInt, VFraction)): return NotImplemented + r = VFraction() + if builder is not None: + if isinstance(other, VInt): + i = other.o_int64(builder).get_ssa_value() + x, rd = self._nd(builder) + y = builder.mul(rd, i) + else: + a, b = self._nd(builder) + c, d = other._nd(builder) + rd = builder.mul(b, d) + x = builder.mul(a, d) + y = builder.mul(c, b) + if sub: + if invert: + rn = builder.sub(y, x) + else: + rn = builder.sub(x, y) + else: + rn = builder.add(x, y) + rn, rd = _reduce(builder, rn, rd) # rd is already > 0 + r.set_ssa_value(builder, _make_ssa(builder, rn, rd)) + return r + + def o_add(self, other, builder): + return self._o_addsub(other, builder, False) + + def o_sub(self, other, builder): + return self._o_addsub(other, builder, True) + + def or_add(self, other, builder): + return self._o_addsub(other, builder, False) + + def or_sub(self, other, builder): + return self._o_addsub(other, builder, False, True) + + def _o_muldiv(self, other, builder, div, invert=False): + if not isinstance(other, (VFraction, VInt)): + return NotImplemented + r = VFraction() + if builder is not None: + a, b = self._nd(builder) + if invert: + a, b = b, a + if isinstance(other, VInt): + i = other.o_int64(builder).get_ssa_value(builder) + if div: + b = builder.mul(b, i) + else: + a = builder.mul(a, i) + else: + c, d = other._nd(builder) + if div: + a = builder.mul(a, d) + b = builder.mul(b, c) + else: + a = builder.mul(a, c) + b = builder.mul(b, d) + if div or invert: + a, b = _signnum(builder, a, b) + a, b = _reduce(builder, a, b) + r.set_ssa_value(builder, _make_ssa(builder, a, b)) + return r def o_mul(self, other, builder): return self._o_muldiv(other, builder, False) diff --git a/test/py2llvm.py b/test/py2llvm.py index 03c742777..c68ef7c24 100644 --- a/test/py2llvm.py +++ b/test/py2llvm.py @@ -26,6 +26,7 @@ def test_types(choice): else: return x + c + class FunctionTypesCase(unittest.TestCase): def setUp(self): self.ns = infer_function_types( @@ -39,7 +40,7 @@ class FunctionTypesCase(unittest.TestCase): self.assertEqual(self.ns["d"].nbits, 32) self.assertIsInstance(self.ns["x"], base_types.VInt) self.assertEqual(self.ns["x"].nbits, 64) - + def test_promotion(self): for v in "abc": self.assertIsInstance(self.ns[v], base_types.VInt) @@ -80,19 +81,62 @@ def is_prime(x): d += 1 return True -def simplify_encode(n, d): - f = Fraction(n, d) + +def simplify_encode(a, b): + f = Fraction(a, b) return f.numerator*1000 + f.denominator + +def arith_encode(op, a, b, c, d): + if op == 1: + f = Fraction(a, b) - Fraction(c, d) + elif op == 2: + f = Fraction(a, b) + Fraction(c, d) + elif op == 3: + f = Fraction(a, b) * Fraction(c, d) + else: + f = Fraction(a, b) / Fraction(c, d) + return f.numerator*1000 + f.denominator + + +is_prime_c = CompiledFunction(is_prime, {"x": base_types.VInt()}) +simplify_encode_c = CompiledFunction( + simplify_encode, {"a": base_types.VInt(), "b": base_types.VInt()}) +arith_encode_c = CompiledFunction( + arith_encode, { + "op": base_types.VInt(), + "a": base_types.VInt(), "b": base_types.VInt(), + "c": base_types.VInt(), "d": base_types.VInt()}) + + class CodeGenCase(unittest.TestCase): def test_is_prime(self): - is_prime_c = CompiledFunction(is_prime, {"x": base_types.VInt()}) for i in range(200): self.assertEqual(is_prime_c(i), is_prime(i)) def test_frac_simplify(self): - simplify_encode_c = CompiledFunction( - simplify_encode, {"n": base_types.VInt(), "d": base_types.VInt()}) - for n in range(5, 20): - for d in range(5, 20): - self.assertEqual(simplify_encode_c(n, d), simplify_encode(n, d)) + for a in range(5, 20): + for b in range(5, 20): + self.assertEqual( + simplify_encode_c(a, b), simplify_encode(a, b)) + + def _test_frac_arith(self, op): + for a in range(5, 10): + for b in range(5, 10): + for c in range(5, 10): + for d in range(5, 10): + self.assertEqual( + arith_encode_c(op, a, b, c, d), + arith_encode(op, a, b, c, d)) + + def test_frac_add(self): + self._test_frac_arith(0) + + def test_frac_sub(self): + self._test_frac_arith(1) + + def test_frac_mul(self): + self._test_frac_arith(2) + + def test_frac_div(self): + self._test_frac_arith(3)