py2llvm: support operations between fractions and floats

This commit is contained in:
Sebastien Bourdeauducq 2014-11-27 18:52:45 +08:00
parent f12389cdd4
commit 6e219469fe
2 changed files with 163 additions and 65 deletions

View File

@ -3,8 +3,8 @@ import ast
from llvm import core as lc from llvm import core as lc
from artiq.py2llvm.values import VGeneric from artiq.py2llvm.values import VGeneric, operators
from artiq.py2llvm.base_types import VBool, VInt from artiq.py2llvm.base_types import VBool, VInt, VFloat
def _gcd(a, b): def _gcd(a, b):
@ -214,30 +214,62 @@ class VFraction(VGeneric):
return self._o_cmp(other, lc.ICMP_SGE, builder) return self._o_cmp(other, lc.ICMP_SGE, builder)
def _o_addsub(self, other, builder, sub, invert=False): def _o_addsub(self, other, builder, sub, invert=False):
if not isinstance(other, (VInt, VFraction)): if isinstance(other, VFloat):
return NotImplemented a = self.o_getattr("numerator", builder)
r = VFraction() b = self.o_getattr("denominator", builder)
if builder is not None:
if isinstance(other, VInt):
i = other.o_int64(builder).auto_load(builder)
x, rd = self._nd(builder)
y = builder.mul(rd, i)
else:
a, b = self._nd(builder)
c, d = other._nd(builder)
rd = builder.mul(b, d)
x = builder.mul(a, d)
y = builder.mul(c, b)
if sub: if sub:
if invert: if invert:
rn = builder.sub(y, x) return operators.truediv(
operators.sub(operators.mul(other,
b,
builder),
a,
builder),
b,
builder)
else: else:
rn = builder.sub(x, y) return operators.truediv(
operators.sub(a,
operators.mul(other,
b,
builder),
builder),
b,
builder)
else: else:
rn = builder.add(x, y) return operators.truediv(
rn, rd = _reduce(builder, rn, rd) # rd is already > 0 operators.add(operators.mul(other,
r.auto_store(builder, _make_ssa(builder, rn, rd)) b,
return r builder),
a,
builder),
b,
builder)
else:
if not isinstance(other, (VFraction, VInt)):
return NotImplemented
r = VFraction()
if builder is not None:
if isinstance(other, VInt):
i = other.o_int64(builder).auto_load(builder)
x, rd = self._nd(builder)
y = builder.mul(rd, i)
else:
a, b = self._nd(builder)
c, d = other._nd(builder)
rd = builder.mul(b, d)
x = builder.mul(a, d)
y = builder.mul(c, b)
if sub:
if invert:
rn = builder.sub(y, x)
else:
rn = builder.sub(x, y)
else:
rn = builder.add(x, y)
rn, rd = _reduce(builder, rn, rd) # rd is already > 0
r.auto_store(builder, _make_ssa(builder, rn, rd))
return r
def o_add(self, other, builder): def o_add(self, other, builder):
return self._o_addsub(other, builder, False) return self._o_addsub(other, builder, False)
@ -252,32 +284,46 @@ class VFraction(VGeneric):
return self._o_addsub(other, builder, True, True) return self._o_addsub(other, builder, True, True)
def _o_muldiv(self, other, builder, div, invert=False): def _o_muldiv(self, other, builder, div, invert=False):
if not isinstance(other, (VFraction, VInt)): if isinstance(other, VFloat):
return NotImplemented a = self.o_getattr("numerator", builder)
r = VFraction() b = self.o_getattr("denominator", builder)
if builder is not None:
a, b = self._nd(builder)
if invert: if invert:
a, b = b, a a, b = b, a
if isinstance(other, VInt): if div:
i = other.o_int64(builder).auto_load(builder) return operators.truediv(a,
if div: operators.mul(b, other, builder),
b = builder.mul(b, i) builder)
else:
a = builder.mul(a, i)
else: else:
c, d = other._nd(builder) return operators.truediv(operators.mul(a, other, builder),
if div: b,
a = builder.mul(a, d) builder)
b = builder.mul(b, c) else:
if not isinstance(other, (VFraction, VInt)):
return NotImplemented
r = VFraction()
if builder is not None:
a, b = self._nd(builder)
if invert:
a, b = b, a
if isinstance(other, VInt):
i = other.o_int64(builder).auto_load(builder)
if div:
b = builder.mul(b, i)
else:
a = builder.mul(a, i)
else: else:
a = builder.mul(a, c) c, d = other._nd(builder)
b = builder.mul(b, d) if div:
if div or invert: a = builder.mul(a, d)
a, b = _signnum(builder, a, b) b = builder.mul(b, c)
a, b = _reduce(builder, a, b) else:
r.auto_store(builder, _make_ssa(builder, a, b)) a = builder.mul(a, c)
return r b = builder.mul(b, d)
if div or invert:
a, b = _signnum(builder, a, b)
a, b = _reduce(builder, a, b)
r.auto_store(builder, _make_ssa(builder, a, b))
return r
def o_mul(self, other, builder): def o_mul(self, other, builder):
return self._o_muldiv(other, builder, False) return self._o_muldiv(other, builder, False)
@ -289,6 +335,7 @@ class VFraction(VGeneric):
return self._o_muldiv(other, builder, False) return self._o_muldiv(other, builder, False)
def or_truediv(self, other, builder): def or_truediv(self, other, builder):
# multiply by the inverse
return self._o_muldiv(other, builder, False, True) return self._o_muldiv(other, builder, False, True)
def o_floordiv(self, other, builder): def o_floordiv(self, other, builder):

View File

@ -112,11 +112,11 @@ class CompiledFunction:
def arith(op, a, b): def arith(op, a, b):
if op == 1: if op == 0:
return a + b return a + b
elif op == 2: elif op == 1:
return a - b return a - b
elif op == 3: elif op == 2:
return a * b return a * b
else: else:
return a / b return a / b
@ -137,11 +137,11 @@ def simplify_encode(a, b):
def frac_arith_encode(op, a, b, c, d): def frac_arith_encode(op, a, b, c, d):
if op == 1: if op == 0:
f = Fraction(a, b) - Fraction(c, d) f = Fraction(a, b) - Fraction(c, d)
elif op == 2: elif op == 1:
f = Fraction(a, b) + Fraction(c, d) f = Fraction(a, b) + Fraction(c, d)
elif op == 3: elif op == 2:
f = Fraction(a, b) * Fraction(c, d) f = Fraction(a, b) * Fraction(c, d)
else: else:
f = Fraction(a, b) / Fraction(c, d) f = Fraction(a, b) / Fraction(c, d)
@ -149,11 +149,11 @@ def frac_arith_encode(op, a, b, c, d):
def frac_arith_encode_int(op, a, b, x): def frac_arith_encode_int(op, a, b, x):
if op == 1: if op == 0:
f = Fraction(a, b) - x f = Fraction(a, b) - x
elif op == 2: elif op == 1:
f = Fraction(a, b) + x f = Fraction(a, b) + x
elif op == 3: elif op == 2:
f = Fraction(a, b) * x f = Fraction(a, b) * x
else: else:
f = Fraction(a, b) / x f = Fraction(a, b) / x
@ -161,17 +161,39 @@ def frac_arith_encode_int(op, a, b, x):
def frac_arith_encode_int_rev(op, a, b, x): def frac_arith_encode_int_rev(op, a, b, x):
if op == 1: if op == 0:
f = x - Fraction(a, b) f = x - Fraction(a, b)
elif op == 2: elif op == 1:
f = x + Fraction(a, b) f = x + Fraction(a, b)
elif op == 3: elif op == 2:
f = x * Fraction(a, b) f = x * Fraction(a, b)
else: else:
f = x / Fraction(a, b) f = x / Fraction(a, b)
return f.numerator*1000 + f.denominator return f.numerator*1000 + f.denominator
def frac_arith_float(op, a, b, x):
if op == 0:
return Fraction(a, b) - x
elif op == 1:
return Fraction(a, b) + x
elif op == 2:
return Fraction(a, b) * x
else:
return Fraction(a, b) / x
def frac_arith_float_rev(op, a, b, x):
if op == 0:
return x - Fraction(a, b)
elif op == 1:
return x + Fraction(a, b)
elif op == 2:
return x * Fraction(a, b)
else:
return x / Fraction(a, b)
def array_test(): def array_test():
a = array(array(2, 5), 5) a = array(array(2, 5), 5)
a[3][2] = 11 a[3][2] = 11
@ -266,7 +288,7 @@ class CodeGenCase(unittest.TestCase):
def test_frac_div(self): def test_frac_div(self):
self._test_frac_arith(3) self._test_frac_arith(3)
def _test_frac_frac_arith_int(self, op, rev): def _test_frac_arith_int(self, op, rev):
f = frac_arith_encode_int_rev if rev else frac_arith_encode_int f = frac_arith_encode_int_rev if rev else frac_arith_encode_int
f_c = CompiledFunction(f, { f_c = CompiledFunction(f, {
"op": base_types.VInt(), "op": base_types.VInt(),
@ -280,20 +302,49 @@ class CodeGenCase(unittest.TestCase):
f(op, a, b, x)) f(op, a, b, x))
def test_frac_add_int(self): def test_frac_add_int(self):
self._test_frac_frac_arith_int(0, False) self._test_frac_arith_int(0, False)
self._test_frac_frac_arith_int(0, True) self._test_frac_arith_int(0, True)
def test_frac_sub_int(self): def test_frac_sub_int(self):
self._test_frac_frac_arith_int(1, False) self._test_frac_arith_int(1, False)
self._test_frac_frac_arith_int(1, True) self._test_frac_arith_int(1, True)
def test_frac_mul_int(self): def test_frac_mul_int(self):
self._test_frac_frac_arith_int(2, False) self._test_frac_arith_int(2, False)
self._test_frac_frac_arith_int(2, True) self._test_frac_arith_int(2, True)
def test_frac_div_int(self): def test_frac_div_int(self):
self._test_frac_frac_arith_int(3, False) self._test_frac_arith_int(3, False)
self._test_frac_frac_arith_int(3, True) self._test_frac_arith_int(3, True)
def _test_frac_arith_float(self, op, rev):
f = frac_arith_float_rev if rev else frac_arith_float
f_c = CompiledFunction(f, {
"op": base_types.VInt(),
"a": base_types.VInt(), "b": base_types.VInt(),
"x": base_types.VFloat()})
for a in _test_range():
for b in _test_range():
for x in _test_range():
self.assertAlmostEqual(
f_c(op, a, b, x/2),
f(op, a, b, x/2))
def test_frac_add_float(self):
self._test_frac_arith_float(0, False)
self._test_frac_arith_float(0, True)
def test_frac_sub_float(self):
self._test_frac_arith_float(1, False)
self._test_frac_arith_float(1, True)
def test_frac_mul_float(self):
self._test_frac_arith_float(2, False)
self._test_frac_arith_float(2, True)
def test_frac_div_float(self):
self._test_frac_arith_float(3, False)
self._test_frac_arith_float(3, True)
def test_array(self): def test_array(self):
array_test_c = CompiledFunction(array_test, dict()) array_test_c = CompiledFunction(array_test, dict())