test/py2llvm: more fraction tests

This commit is contained in:
Sebastien Bourdeauducq 2014-09-16 16:43:19 +08:00
parent dbca62c1d7
commit b923690d6f

View File

@ -123,6 +123,30 @@ def arith_encode(op, a, b, c, d):
return f.numerator*1000 + f.denominator
def arith_encode_int(op, a, b, x):
if op == 1:
f = Fraction(a, b) - x
elif op == 2:
f = Fraction(a, b) + x
elif op == 3:
f = Fraction(a, b) * x
else:
f = Fraction(a, b) / x
return f.numerator*1000 + f.denominator
def arith_encode_int_rev(op, a, b, x):
if op == 1:
f = x - Fraction(a, b)
elif op == 2:
f = x + Fraction(a, b)
elif op == 3:
f = x * Fraction(a, b)
else:
f = x / Fraction(a, b)
return f.numerator*1000 + f.denominator
def array_test():
a = array(array(2, 5), 5)
a[3][2] = 11
@ -136,6 +160,12 @@ def array_test():
return acc
def _frac_test_range():
for i in range(5, 10):
yield i
yield -i
class CodeGenCase(unittest.TestCase):
def test_is_prime(self):
is_prime_c = CompiledFunction(is_prime, {"x": base_types.VInt()})
@ -145,8 +175,8 @@ class CodeGenCase(unittest.TestCase):
def test_frac_simplify(self):
simplify_encode_c = CompiledFunction(
simplify_encode, {"a": base_types.VInt(), "b": base_types.VInt()})
for a in range(5, 20):
for b in range(5, 20):
for a in _frac_test_range():
for b in _frac_test_range():
self.assertEqual(
simplify_encode_c(a, b), simplify_encode(a, b))
@ -156,10 +186,10 @@ class CodeGenCase(unittest.TestCase):
"op": base_types.VInt(),
"a": base_types.VInt(), "b": base_types.VInt(),
"c": base_types.VInt(), "d": base_types.VInt()})
for a in range(5, 10):
for b in range(5, 10):
for c in range(5, 10):
for d in range(5, 10):
for a in _frac_test_range():
for b in _frac_test_range():
for c in _frac_test_range():
for d in _frac_test_range():
self.assertEqual(
arith_encode_c(op, a, b, c, d),
arith_encode(op, a, b, c, d))
@ -176,6 +206,35 @@ class CodeGenCase(unittest.TestCase):
def test_frac_div(self):
self._test_frac_arith(3)
def _test_frac_arith_int(self, op, rev):
f = arith_encode_int_rev if rev else arith_encode_int
f_c = CompiledFunction(f, {
"op": base_types.VInt(),
"a": base_types.VInt(), "b": base_types.VInt(),
"x": base_types.VInt()})
for a in _frac_test_range():
for b in _frac_test_range():
for x in _frac_test_range():
self.assertEqual(
f_c(op, a, b, x),
f(op, a, b, x))
def test_frac_add_int(self):
self._test_frac_arith_int(0, False)
self._test_frac_arith_int(0, True)
def test_frac_sub_int(self):
self._test_frac_arith_int(1, False)
self._test_frac_arith_int(1, True)
def test_frac_mul_int(self):
self._test_frac_arith_int(2, False)
self._test_frac_arith_int(2, True)
def test_frac_div_int(self):
self._test_frac_arith_int(3, False)
self._test_frac_arith_int(3, True)
def test_array(self):
array_test_c = CompiledFunction(array_test, dict())
self.assertEqual(array_test_c(), array_test())