diff --git a/test/py2llvm.py b/test/py2llvm.py index 63c98fa18..13703281b 100644 --- a/test/py2llvm.py +++ b/test/py2llvm.py @@ -123,6 +123,30 @@ def arith_encode(op, a, b, c, d): return f.numerator*1000 + f.denominator +def arith_encode_int(op, a, b, x): + if op == 1: + f = Fraction(a, b) - x + elif op == 2: + f = Fraction(a, b) + x + elif op == 3: + f = Fraction(a, b) * x + else: + f = Fraction(a, b) / x + return f.numerator*1000 + f.denominator + + +def arith_encode_int_rev(op, a, b, x): + if op == 1: + f = x - Fraction(a, b) + elif op == 2: + f = x + Fraction(a, b) + elif op == 3: + f = x * Fraction(a, b) + else: + f = x / Fraction(a, b) + return f.numerator*1000 + f.denominator + + def array_test(): a = array(array(2, 5), 5) a[3][2] = 11 @@ -136,6 +160,12 @@ def array_test(): return acc +def _frac_test_range(): + for i in range(5, 10): + yield i + yield -i + + class CodeGenCase(unittest.TestCase): def test_is_prime(self): is_prime_c = CompiledFunction(is_prime, {"x": base_types.VInt()}) @@ -145,8 +175,8 @@ class CodeGenCase(unittest.TestCase): def test_frac_simplify(self): simplify_encode_c = CompiledFunction( simplify_encode, {"a": base_types.VInt(), "b": base_types.VInt()}) - for a in range(5, 20): - for b in range(5, 20): + for a in _frac_test_range(): + for b in _frac_test_range(): self.assertEqual( simplify_encode_c(a, b), simplify_encode(a, b)) @@ -156,10 +186,10 @@ class CodeGenCase(unittest.TestCase): "op": base_types.VInt(), "a": base_types.VInt(), "b": base_types.VInt(), "c": base_types.VInt(), "d": base_types.VInt()}) - for a in range(5, 10): - for b in range(5, 10): - for c in range(5, 10): - for d in range(5, 10): + for a in _frac_test_range(): + for b in _frac_test_range(): + for c in _frac_test_range(): + for d in _frac_test_range(): self.assertEqual( arith_encode_c(op, a, b, c, d), arith_encode(op, a, b, c, d)) @@ -176,6 +206,35 @@ class CodeGenCase(unittest.TestCase): def test_frac_div(self): self._test_frac_arith(3) + def _test_frac_arith_int(self, op, rev): + f = arith_encode_int_rev if rev else arith_encode_int + f_c = CompiledFunction(f, { + "op": base_types.VInt(), + "a": base_types.VInt(), "b": base_types.VInt(), + "x": base_types.VInt()}) + for a in _frac_test_range(): + for b in _frac_test_range(): + for x in _frac_test_range(): + self.assertEqual( + f_c(op, a, b, x), + f(op, a, b, x)) + + def test_frac_add_int(self): + self._test_frac_arith_int(0, False) + self._test_frac_arith_int(0, True) + + def test_frac_sub_int(self): + self._test_frac_arith_int(1, False) + self._test_frac_arith_int(1, True) + + def test_frac_mul_int(self): + self._test_frac_arith_int(2, False) + self._test_frac_arith_int(2, True) + + def test_frac_div_int(self): + self._test_frac_arith_int(3, False) + self._test_frac_arith_int(3, True) + def test_array(self): array_test_c = CompiledFunction(array_test, dict()) self.assertEqual(array_test_c(), array_test())