forked from M-Labs/artiq
py2llvm: complete rational arithmetic support
This commit is contained in:
parent
1133308dd5
commit
60368aa9e2
|
@ -8,33 +8,61 @@ from artiq.py2llvm.base_types import VBool, VInt
|
||||||
|
|
||||||
|
|
||||||
def _gcd(a, b):
|
def _gcd(a, b):
|
||||||
|
if a < 0:
|
||||||
|
a = -a
|
||||||
while a:
|
while a:
|
||||||
c = a
|
c = a
|
||||||
a = b % a
|
a = b % a
|
||||||
b = c
|
b = c
|
||||||
return b
|
return b
|
||||||
|
|
||||||
|
|
||||||
def init_module(module):
|
def init_module(module):
|
||||||
funcdef = ast.parse(inspect.getsource(_gcd)).body[0]
|
funcdef = ast.parse(inspect.getsource(_gcd)).body[0]
|
||||||
module.compile_function(funcdef, {"a": VInt(64), "b": VInt(64)})
|
module.compile_function(funcdef, {"a": VInt(64), "b": VInt(64)})
|
||||||
|
|
||||||
def _call_gcd(builder, a, b):
|
|
||||||
|
def _reduce(builder, a, b):
|
||||||
gcd_f = builder.basic_block.function.module.get_function_named("_gcd")
|
gcd_f = builder.basic_block.function.module.get_function_named("_gcd")
|
||||||
return builder.call(gcd_f, [a, b])
|
gcd = builder.call(gcd_f, [a, b])
|
||||||
|
a = builder.sdiv(a, gcd)
|
||||||
def _frac_normalize(builder, numerator, denominator):
|
b = builder.sdiv(b, gcd)
|
||||||
gcd = _call_gcd(builder, numerator, denominator)
|
return a, b
|
||||||
numerator = builder.sdiv(numerator, gcd)
|
|
||||||
denominator = builder.sdiv(denominator, gcd)
|
|
||||||
return numerator, denominator
|
|
||||||
|
|
||||||
|
|
||||||
def _frac_make_ssa(builder, numerator, denominator):
|
def _signnum(builder, a, b):
|
||||||
|
function = builder.basic_block.function
|
||||||
|
orig_block = builder.basic_block
|
||||||
|
swap_block = function.append_basic_block("sn_swap")
|
||||||
|
merge_block = function.append_basic_block("sn_merge")
|
||||||
|
|
||||||
|
condition = builder.icmp(
|
||||||
|
lc.ICMP_SLT, b, lc.Constant.int(lc.Type.int(64), 0))
|
||||||
|
builder.cbranch(condition, swap_block, merge_block)
|
||||||
|
|
||||||
|
builder.position_at_end(swap_block)
|
||||||
|
minusone = lc.Constant.int(lc.Type.int(64), -1)
|
||||||
|
a_swp = builder.mul(minusone, a)
|
||||||
|
b_swp = builder.mul(minusone, b)
|
||||||
|
builder.branch(merge_block)
|
||||||
|
|
||||||
|
builder.position_at_end(merge_block)
|
||||||
|
a_phi = builder.phi(lc.Type.int(64))
|
||||||
|
a_phi.add_incoming(a, orig_block)
|
||||||
|
a_phi.add_incoming(a_swp, swap_block)
|
||||||
|
b_phi = builder.phi(lc.Type.int(64))
|
||||||
|
b_phi.add_incoming(b, orig_block)
|
||||||
|
b_phi.add_incoming(b_swp, swap_block)
|
||||||
|
|
||||||
|
return a_phi, b_phi
|
||||||
|
|
||||||
|
|
||||||
|
def _make_ssa(builder, n, d):
|
||||||
value = lc.Constant.undef(lc.Type.vector(lc.Type.int(64), 2))
|
value = lc.Constant.undef(lc.Type.vector(lc.Type.int(64), 2))
|
||||||
value = builder.insert_element(
|
value = builder.insert_element(
|
||||||
value, numerator, lc.Constant.int(lc.Type.int(), 0))
|
value, n, lc.Constant.int(lc.Type.int(), 0))
|
||||||
value = builder.insert_element(
|
value = builder.insert_element(
|
||||||
value, denominator, lc.Constant.int(lc.Type.int(), 1))
|
value, d, lc.Constant.int(lc.Type.int(), 1))
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
@ -52,29 +80,25 @@ class VFraction(VGeneric):
|
||||||
if not isinstance(other, VFraction):
|
if not isinstance(other, VFraction):
|
||||||
raise TypeError
|
raise TypeError
|
||||||
|
|
||||||
def _nd(self, builder, invert=False):
|
def _nd(self, builder):
|
||||||
ssa_value = self.get_ssa_value(builder)
|
ssa_value = self.get_ssa_value(builder)
|
||||||
numerator = builder.extract_element(
|
a = builder.extract_element(
|
||||||
ssa_value, lc.Constant.int(lc.Type.int(), 0))
|
ssa_value, lc.Constant.int(lc.Type.int(), 0))
|
||||||
denominator = builder.extract_element(
|
b = builder.extract_element(
|
||||||
ssa_value, lc.Constant.int(lc.Type.int(), 1))
|
ssa_value, lc.Constant.int(lc.Type.int(), 1))
|
||||||
if invert:
|
return a, b
|
||||||
return denominator, numerator
|
|
||||||
else:
|
|
||||||
return numerator, denominator
|
|
||||||
|
|
||||||
def set_value_nd(self, builder, numerator, denominator):
|
def set_value_nd(self, builder, a, b):
|
||||||
numerator = numerator.o_int64(builder).get_ssa_value(builder)
|
a = a.o_int64(builder).get_ssa_value(builder)
|
||||||
denominator = denominator.o_int64(builder).get_ssa_value(builder)
|
b = b.o_int64(builder).get_ssa_value(builder)
|
||||||
numerator, denominator = _frac_normalize(
|
a, b = _reduce(builder, a, b)
|
||||||
builder, numerator, denominator)
|
a, b = _signnum(builder, a, b)
|
||||||
self.set_ssa_value(
|
self.set_ssa_value(builder, _make_ssa(builder, a, b))
|
||||||
builder, _frac_make_ssa(builder, numerator, denominator))
|
|
||||||
|
|
||||||
def set_value(self, builder, n):
|
def set_value(self, builder, v):
|
||||||
if not isinstance(n, VFraction):
|
if not isinstance(v, VFraction):
|
||||||
raise TypeError
|
raise TypeError
|
||||||
self.set_ssa_value(builder, n.get_ssa_value(builder))
|
self.set_ssa_value(builder, v.get_ssa_value(builder))
|
||||||
|
|
||||||
def o_getattr(self, attr, builder):
|
def o_getattr(self, attr, builder):
|
||||||
if attr == "numerator":
|
if attr == "numerator":
|
||||||
|
@ -86,7 +110,8 @@ class VFraction(VGeneric):
|
||||||
r = VInt(64)
|
r = VInt(64)
|
||||||
if builder is not None:
|
if builder is not None:
|
||||||
elt = builder.extract_element(
|
elt = builder.extract_element(
|
||||||
self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), idx))
|
self.get_ssa_value(builder),
|
||||||
|
lc.Constant.int(lc.Type.int(), idx))
|
||||||
r.set_ssa_value(builder, elt)
|
r.set_ssa_value(builder, elt)
|
||||||
return r
|
return r
|
||||||
|
|
||||||
|
@ -94,9 +119,9 @@ class VFraction(VGeneric):
|
||||||
r = VBool()
|
r = VBool()
|
||||||
if builder is not None:
|
if builder is not None:
|
||||||
zero = lc.Constant.int(lc.Type.int(64), 0)
|
zero = lc.Constant.int(lc.Type.int(64), 0)
|
||||||
numerator = builder.extract_element(
|
a = builder.extract_element(
|
||||||
self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 0))
|
self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 0))
|
||||||
r.set_ssa_value(builder, builder.icmp(lc.ICMP_NE, numerator, zero))
|
r.set_ssa_value(builder, builder.icmp(lc.ICMP_NE, a, zero))
|
||||||
return r
|
return r
|
||||||
|
|
||||||
def o_intx(self, target_bits, builder):
|
def o_intx(self, target_bits, builder):
|
||||||
|
@ -104,8 +129,8 @@ class VFraction(VGeneric):
|
||||||
return VInt(target_bits)
|
return VInt(target_bits)
|
||||||
else:
|
else:
|
||||||
r = VInt(64)
|
r = VInt(64)
|
||||||
numerator, denominator = self._nd(builder)
|
a, b = self._nd(builder)
|
||||||
r.set_ssa_value(builder, builder.sdiv(numerator, denominator))
|
r.set_ssa_value(builder, builder.sdiv(a, b))
|
||||||
return r.o_intx(target_bits, builder)
|
return r.o_intx(target_bits, builder)
|
||||||
|
|
||||||
def o_roundx(self, target_bits, builder):
|
def o_roundx(self, target_bits, builder):
|
||||||
|
@ -113,34 +138,36 @@ class VFraction(VGeneric):
|
||||||
return VInt(target_bits)
|
return VInt(target_bits)
|
||||||
else:
|
else:
|
||||||
r = VInt(64)
|
r = VInt(64)
|
||||||
numerator, denominator = self._nd(builder)
|
a, b = self._nd(builder)
|
||||||
h_denominator = builder.ashr(denominator,
|
h_b = builder.ashr(b, lc.Constant.int(lc.Type.int(), 1))
|
||||||
lc.Constant.int(lc.Type.int(), 1))
|
a = builder.add(a, h_b)
|
||||||
r_numerator = builder.add(numerator, h_denominator)
|
r.set_ssa_value(builder, builder.sdiv(a, b))
|
||||||
r.set_ssa_value(builder, builder.sdiv(r_numerator, denominator))
|
|
||||||
return r.o_intx(target_bits, builder)
|
return r.o_intx(target_bits, builder)
|
||||||
|
|
||||||
def _o_eq_inv(self, other, builder, ne):
|
def _o_eq_inv(self, other, builder, ne):
|
||||||
if isinstance(other, VFraction):
|
if not isinstance(other, (VInt, VFraction)):
|
||||||
r = VBool()
|
|
||||||
if builder is not None:
|
|
||||||
ee = []
|
|
||||||
for i in range(2):
|
|
||||||
es = builder.extract_element(
|
|
||||||
self.get_ssa_value(builder),
|
|
||||||
lc.Constant.int(lc.Type.int(), i))
|
|
||||||
eo = builder.extract_element(
|
|
||||||
other.get_ssa_value(builder),
|
|
||||||
lc.Constant.int(lc.Type.int(), i))
|
|
||||||
ee.append(builder.icmp(lc.ICMP_EQ, es, eo))
|
|
||||||
ssa_r = builder.and_(ee[0], ee[1])
|
|
||||||
if ne:
|
|
||||||
ssa_r = builder.xor(ssa_r,
|
|
||||||
lc.Constant.int(lc.Type.int(1), 1))
|
|
||||||
r.set_ssa_value(builder, ssa_r)
|
|
||||||
return r
|
|
||||||
else:
|
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
r = VBool()
|
||||||
|
if builder is not None:
|
||||||
|
if isinstance(other, VInt):
|
||||||
|
other = other.o_int64(builder)
|
||||||
|
a, b = self._nd(builder)
|
||||||
|
ssa_r = builder.and_(
|
||||||
|
builder.icmp(lc.ICMP_EQ, a,
|
||||||
|
other.get_ssa_value()),
|
||||||
|
builder.icmp(lc.ICMP_EQ, b,
|
||||||
|
lc.Constant.int(lc.Type.int(64), 1)))
|
||||||
|
else:
|
||||||
|
a, b = self._nd(builder)
|
||||||
|
c, d = other._nd(builder)
|
||||||
|
ssa_r = builder.and_(
|
||||||
|
builder.icmp(lc.ICMP_EQ, a, c),
|
||||||
|
builder.icmp(lc.ICMP_EQ, b, d))
|
||||||
|
if ne:
|
||||||
|
ssa_r = builder.xor(ssa_r,
|
||||||
|
lc.Constant.int(lc.Type.int(1), 1))
|
||||||
|
r.set_ssa_value(builder, ssa_r)
|
||||||
|
return r
|
||||||
|
|
||||||
def o_eq(self, other, builder):
|
def o_eq(self, other, builder):
|
||||||
return self._o_eq_inv(other, builder, False)
|
return self._o_eq_inv(other, builder, False)
|
||||||
|
@ -148,44 +175,71 @@ class VFraction(VGeneric):
|
||||||
def o_ne(self, other, builder):
|
def o_ne(self, other, builder):
|
||||||
return self._o_eq_inv(other, builder, True)
|
return self._o_eq_inv(other, builder, True)
|
||||||
|
|
||||||
def _o_muldiv(self, other, builder, div, invert=False):
|
def _o_addsub(self, other, builder, sub, invert=False):
|
||||||
r = VFraction()
|
if not isinstance(other, (VInt, VFraction)):
|
||||||
if isinstance(other, VInt):
|
|
||||||
if builder is None:
|
|
||||||
return r
|
|
||||||
else:
|
|
||||||
numerator, denominator = self._nd(builder, invert)
|
|
||||||
i = other.get_ssa_value(builder)
|
|
||||||
if div:
|
|
||||||
gcd = _call_gcd(builder, i, numerator)
|
|
||||||
i = builder.sdiv(i, gcd)
|
|
||||||
numerator = builder.sdiv(numerator, gcd)
|
|
||||||
denominator = builder.mul(denominator, i)
|
|
||||||
else:
|
|
||||||
gcd = _call_gcd(builder, i, denominator)
|
|
||||||
i = builder.sdiv(i, gcd)
|
|
||||||
denominator = builder.sdiv(denominator, gcd)
|
|
||||||
numerator = builder.mul(numerator, i)
|
|
||||||
self.set_ssa_value(builder, _frac_make_ssa(builder, numerator,
|
|
||||||
denominator))
|
|
||||||
elif isinstance(other, VFraction):
|
|
||||||
if builder is None:
|
|
||||||
return r
|
|
||||||
else:
|
|
||||||
numerator, denominator = self._nd(builder, invert)
|
|
||||||
onumerator, odenominator = other._nd(builder)
|
|
||||||
if div:
|
|
||||||
numerator = builder.mul(numerator, odenominator)
|
|
||||||
denominator = builder.mul(denominator, onumerator)
|
|
||||||
else:
|
|
||||||
numerator = builder.mul(numerator, onumerator)
|
|
||||||
denominator = builder.mul(denominator, odenominator)
|
|
||||||
numerator, denominator = _frac_normalize(builder, numerator,
|
|
||||||
denominator)
|
|
||||||
self.set_ssa_value(
|
|
||||||
builder, _frac_make_ssa(builder, numerator, denominator))
|
|
||||||
else:
|
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
r = VFraction()
|
||||||
|
if builder is not None:
|
||||||
|
if isinstance(other, VInt):
|
||||||
|
i = other.o_int64(builder).get_ssa_value()
|
||||||
|
x, rd = self._nd(builder)
|
||||||
|
y = builder.mul(rd, i)
|
||||||
|
else:
|
||||||
|
a, b = self._nd(builder)
|
||||||
|
c, d = other._nd(builder)
|
||||||
|
rd = builder.mul(b, d)
|
||||||
|
x = builder.mul(a, d)
|
||||||
|
y = builder.mul(c, b)
|
||||||
|
if sub:
|
||||||
|
if invert:
|
||||||
|
rn = builder.sub(y, x)
|
||||||
|
else:
|
||||||
|
rn = builder.sub(x, y)
|
||||||
|
else:
|
||||||
|
rn = builder.add(x, y)
|
||||||
|
rn, rd = _reduce(builder, rn, rd) # rd is already > 0
|
||||||
|
r.set_ssa_value(builder, _make_ssa(builder, rn, rd))
|
||||||
|
return r
|
||||||
|
|
||||||
|
def o_add(self, other, builder):
|
||||||
|
return self._o_addsub(other, builder, False)
|
||||||
|
|
||||||
|
def o_sub(self, other, builder):
|
||||||
|
return self._o_addsub(other, builder, True)
|
||||||
|
|
||||||
|
def or_add(self, other, builder):
|
||||||
|
return self._o_addsub(other, builder, False)
|
||||||
|
|
||||||
|
def or_sub(self, other, builder):
|
||||||
|
return self._o_addsub(other, builder, False, True)
|
||||||
|
|
||||||
|
def _o_muldiv(self, other, builder, div, invert=False):
|
||||||
|
if not isinstance(other, (VFraction, VInt)):
|
||||||
|
return NotImplemented
|
||||||
|
r = VFraction()
|
||||||
|
if builder is not None:
|
||||||
|
a, b = self._nd(builder)
|
||||||
|
if invert:
|
||||||
|
a, b = b, a
|
||||||
|
if isinstance(other, VInt):
|
||||||
|
i = other.o_int64(builder).get_ssa_value(builder)
|
||||||
|
if div:
|
||||||
|
b = builder.mul(b, i)
|
||||||
|
else:
|
||||||
|
a = builder.mul(a, i)
|
||||||
|
else:
|
||||||
|
c, d = other._nd(builder)
|
||||||
|
if div:
|
||||||
|
a = builder.mul(a, d)
|
||||||
|
b = builder.mul(b, c)
|
||||||
|
else:
|
||||||
|
a = builder.mul(a, c)
|
||||||
|
b = builder.mul(b, d)
|
||||||
|
if div or invert:
|
||||||
|
a, b = _signnum(builder, a, b)
|
||||||
|
a, b = _reduce(builder, a, b)
|
||||||
|
r.set_ssa_value(builder, _make_ssa(builder, a, b))
|
||||||
|
return r
|
||||||
|
|
||||||
def o_mul(self, other, builder):
|
def o_mul(self, other, builder):
|
||||||
return self._o_muldiv(other, builder, False)
|
return self._o_muldiv(other, builder, False)
|
||||||
|
|
|
@ -26,6 +26,7 @@ def test_types(choice):
|
||||||
else:
|
else:
|
||||||
return x + c
|
return x + c
|
||||||
|
|
||||||
|
|
||||||
class FunctionTypesCase(unittest.TestCase):
|
class FunctionTypesCase(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.ns = infer_function_types(
|
self.ns = infer_function_types(
|
||||||
|
@ -80,19 +81,62 @@ def is_prime(x):
|
||||||
d += 1
|
d += 1
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def simplify_encode(n, d):
|
|
||||||
f = Fraction(n, d)
|
def simplify_encode(a, b):
|
||||||
|
f = Fraction(a, b)
|
||||||
return f.numerator*1000 + f.denominator
|
return f.numerator*1000 + f.denominator
|
||||||
|
|
||||||
|
|
||||||
|
def arith_encode(op, a, b, c, d):
|
||||||
|
if op == 1:
|
||||||
|
f = Fraction(a, b) - Fraction(c, d)
|
||||||
|
elif op == 2:
|
||||||
|
f = Fraction(a, b) + Fraction(c, d)
|
||||||
|
elif op == 3:
|
||||||
|
f = Fraction(a, b) * Fraction(c, d)
|
||||||
|
else:
|
||||||
|
f = Fraction(a, b) / Fraction(c, d)
|
||||||
|
return f.numerator*1000 + f.denominator
|
||||||
|
|
||||||
|
|
||||||
|
is_prime_c = CompiledFunction(is_prime, {"x": base_types.VInt()})
|
||||||
|
simplify_encode_c = CompiledFunction(
|
||||||
|
simplify_encode, {"a": base_types.VInt(), "b": base_types.VInt()})
|
||||||
|
arith_encode_c = CompiledFunction(
|
||||||
|
arith_encode, {
|
||||||
|
"op": base_types.VInt(),
|
||||||
|
"a": base_types.VInt(), "b": base_types.VInt(),
|
||||||
|
"c": base_types.VInt(), "d": base_types.VInt()})
|
||||||
|
|
||||||
|
|
||||||
class CodeGenCase(unittest.TestCase):
|
class CodeGenCase(unittest.TestCase):
|
||||||
def test_is_prime(self):
|
def test_is_prime(self):
|
||||||
is_prime_c = CompiledFunction(is_prime, {"x": base_types.VInt()})
|
|
||||||
for i in range(200):
|
for i in range(200):
|
||||||
self.assertEqual(is_prime_c(i), is_prime(i))
|
self.assertEqual(is_prime_c(i), is_prime(i))
|
||||||
|
|
||||||
def test_frac_simplify(self):
|
def test_frac_simplify(self):
|
||||||
simplify_encode_c = CompiledFunction(
|
for a in range(5, 20):
|
||||||
simplify_encode, {"n": base_types.VInt(), "d": base_types.VInt()})
|
for b in range(5, 20):
|
||||||
for n in range(5, 20):
|
self.assertEqual(
|
||||||
for d in range(5, 20):
|
simplify_encode_c(a, b), simplify_encode(a, b))
|
||||||
self.assertEqual(simplify_encode_c(n, d), simplify_encode(n, d))
|
|
||||||
|
def _test_frac_arith(self, op):
|
||||||
|
for a in range(5, 10):
|
||||||
|
for b in range(5, 10):
|
||||||
|
for c in range(5, 10):
|
||||||
|
for d in range(5, 10):
|
||||||
|
self.assertEqual(
|
||||||
|
arith_encode_c(op, a, b, c, d),
|
||||||
|
arith_encode(op, a, b, c, d))
|
||||||
|
|
||||||
|
def test_frac_add(self):
|
||||||
|
self._test_frac_arith(0)
|
||||||
|
|
||||||
|
def test_frac_sub(self):
|
||||||
|
self._test_frac_arith(1)
|
||||||
|
|
||||||
|
def test_frac_mul(self):
|
||||||
|
self._test_frac_arith(2)
|
||||||
|
|
||||||
|
def test_frac_div(self):
|
||||||
|
self._test_frac_arith(3)
|
||||||
|
|
Loading…
Reference in New Issue