artiq/test/py2llvm.py

370 lines
10 KiB
Python
Raw Normal View History

import unittest
import ast
import inspect
from fractions import Fraction
2014-12-05 17:05:43 +08:00
from ctypes import CFUNCTYPE, c_int, c_int32, c_int64, c_double
2014-12-05 17:05:43 +08:00
import llvmlite.binding as llvm
2014-09-09 17:13:48 +08:00
from artiq.language.core import int64, array
from artiq.py2llvm.infer_types import infer_function_types
2014-09-09 17:13:48 +08:00
from artiq.py2llvm import base_types, arrays
from artiq.py2llvm.module import Module
2014-12-05 17:05:43 +08:00
# Set up LLVM once, at import time, before any test below builds a Module:
# the core library first, then the native (host) target and its assembly
# printer, which llvmlite needs before it can generate native code.
for _llvm_init_step in (llvm.initialize,
                        llvm.initialize_native_target,
                        llvm.initialize_native_asmprinter):
    _llvm_init_step()
del _llvm_init_step
def _base_types(choice):
    # Type-inference fixture. FunctionBaseTypesCase parses this function's
    # source and checks the inferred types. It expects foo: VBool,
    # bar: VNone, d: 32-bit int, x/a/b/c: 64-bit int, myf/myf2: VFloat,
    # and a 64-bit int return type.
    # The statements are chosen to exercise specific inference rules, so
    # the body should not be restyled. Only "#" comments are used here
    # because a docstring would add an extra statement to the analysed body.
    a = 2  # promoted later to int64
    b = a + 1  # initially int32, becomes int64 after a is promoted
    c = b//2  # initially int32, becomes int64 after b is promoted
    d = 4 and 5  # stays int32
    x = int64(7)
    a += x  # promotes a to int64
    foo = True | True or False
    bar = None
    myf = 4.5
    myf2 = myf + x  # float + int64 is expected to infer as VFloat
    # The branches return int32 (d) and int64 values; the test expects the
    # merged "return" type to be a 64-bit int, while d itself stays int32.
    if choice and foo and not bar:
        return d
    elif myf2:
        return x + c
    else:
        return int64(8)
2014-09-09 17:13:48 +08:00
def _build_function_types(f):
    """Run py2llvm type inference over the source code of function *f*.

    The source is fetched with ``inspect.getsource`` and parsed into a fresh
    AST, which is handed to ``infer_function_types`` with an empty starting
    namespace. Whatever namespace that returns is passed back; the test
    cases index it by local variable name and by ``"return"``.
    """
    tree = ast.parse(inspect.getsource(f))
    return infer_function_types(None, tree, {})
class FunctionBaseTypesCase(unittest.TestCase):
    """Check the types inferred for the locals and return value of _base_types."""

    def setUp(self):
        self.ns = _build_function_types(_base_types)

    def _assert_int(self, name, nbits):
        # Shared check: the inferred type of `name` is a VInt of width nbits.
        inferred = self.ns[name]
        self.assertIsInstance(inferred, base_types.VInt)
        self.assertEqual(inferred.nbits, nbits)

    def test_simple_types(self):
        ns = self.ns
        self.assertIsInstance(ns["foo"], base_types.VBool)
        self.assertIsInstance(ns["bar"], base_types.VNone)
        self._assert_int("d", 32)
        self._assert_int("x", 64)
        for float_name in ("myf", "myf2"):
            self.assertIsInstance(ns[float_name], base_types.VFloat)

    def test_promotion(self):
        # a, b and c all start as 32-bit ints and must end up 64-bit.
        for name in "abc":
            self._assert_int(name, 64)

    def test_return(self):
        self._assert_int("return", 64)
2014-09-09 17:13:48 +08:00
def test_array_types():
    # Type-inference fixture for FunctionArrayTypesCase, which analyses this
    # function's source. The test expects "a" to be a VArray with count 5
    # whose element type is a 64-bit int, i.e. the int64(8) stores are
    # expected to widen the 0 initializer. It expects the loop index "i" to
    # stay a 32-bit int.
    # NOTE(review): the "test_" prefix may make some test runners (e.g.
    # pytest, if it ever collects this module) call this as a test
    # function; renaming it would also require updating
    # FunctionArrayTypesCase.setUp.
    a = array(0, 5)
    for i in range(2):
        a[i] = int64(8)
    return a
class FunctionArrayTypesCase(unittest.TestCase):
    """Check the types inferred for the array fixture test_array_types."""

    def setUp(self):
        self.ns = _build_function_types(test_array_types)

    def test_array_types(self):
        # The array itself: a 5-element VArray of 64-bit ints.
        arr = self.ns["a"]
        self.assertIsInstance(arr, arrays.VArray)
        elem = arr.el_init
        self.assertIsInstance(elem, base_types.VInt)
        self.assertEqual(elem.nbits, 64)
        self.assertEqual(arr.count, 5)

        # The range() loop index stays a plain 32-bit int.
        index = self.ns["i"]
        self.assertIsInstance(index, base_types.VInt)
        self.assertEqual(index.nbits, 32)
2014-09-09 17:13:48 +08:00
2014-12-05 17:05:43 +08:00
def _value_to_ctype(v):
    """Return the ctypes type used to pass or return a py2llvm value type.

    VBool maps to c_int, 32- and 64-bit VInt to c_int32 / c_int64, and
    VFloat to c_double. Any other type, or a VInt of another width, raises
    NotImplementedError(str(v)).
    """
    if isinstance(v, base_types.VBool):
        return c_int
    if isinstance(v, base_types.VInt):
        ctype = {32: c_int32, 64: c_int64}.get(v.nbits)
        if ctype is None:
            raise NotImplementedError(str(v))
        return ctype
    if isinstance(v, base_types.VFloat):
        return c_double
    raise NotImplementedError(str(v))
class CompiledFunction:
    """Compile a Python function with py2llvm and call it through ctypes.

    ``function`` must be a function whose source, as returned by
    ``inspect.getsource``, parses on its own with the ``def`` as the first
    statement. ``param_types`` maps each parameter name to its py2llvm value
    type. Calling the instance forwards the positional arguments to the
    compiled code and returns its result.
    """

    def __init__(self, function, param_types):
        module = Module()
        func_def = ast.parse(inspect.getsource(function)).body[0]
        llvm_func, retval = module.compile_function(func_def, param_types)
        # Parameter value types in declaration (= positional call) order.
        param_value_types = [param_types[arg.arg]
                             for arg in func_def.args.args]

        ee = module.get_ee()
        cfptr = ee.get_pointer_to_global(
            module.llvm_module.get_function(llvm_func.name))

        prototype = CFUNCTYPE(
            _value_to_ctype(retval),
            *[_value_to_ctype(t) for t in param_value_types])
        self.cfunc = prototype(cfptr)
        # HACK: cfptr was obtained from ee, so keep a reference to ee for
        # the lifetime of this object to prevent garbage collection of
        # self.cfunc internals.
        self.ee = ee

    def __call__(self, *args):
        return self.cfunc(*args)
2014-09-16 23:11:30 +08:00
def arith(op, a, b):
    # Test subject for CodeGenCase._test_float_arith, which compiles this
    # function with float a and b and compares the result against CPython.
    # op selects the operation: 0 add, 1 subtract, 2 multiply,
    # anything else divide.
    # "#" comments only: the body is compiled by py2llvm.
    if op == 0:
        return a + b
    elif op == 1:
        return a - b
    elif op == 2:
        return a * b
    else:
        return a / b
def is_prime(x):
    # Return True if x is prime, using trial division by d while d*d <= x.
    # Also compiled by py2llvm in CodeGenCase.test_is_prime, so the body
    # stays simple and uses "#" comments instead of a docstring.
    # Values below 2 (0, 1 and negatives) are not prime. Without this
    # guard the loop below never runs for them and they were reported as
    # prime.
    if x < 2:
        return False
    d = 2
    while d*d <= x:
        if not x % d:
            return False
        d += 1
    return True
def simplify_encode(a, b):
    # Test subject for CodeGenCase.test_frac_simplify. It reduces a/b with
    # Fraction and packs the result into one integer as
    # numerator*1000 + denominator, so it can be returned through the
    # int-only return types supported by _value_to_ctype.
    # The packing is only unambiguous while the reduced denominator stays
    # below 1000, which holds for the small inputs from _test_range.
    # "#" comments only: the body is compiled by py2llvm.
    f = Fraction(a, b)
    return f.numerator*1000 + f.denominator
2014-09-16 23:11:30 +08:00
def frac_arith_encode(op, a, b, c, d):
    # Test subject for CodeGenCase._test_frac_arith. It computes
    # Fraction(a, b) <op> Fraction(c, d) and packs the reduced result as
    # numerator*1000 + denominator.
    # op uses the same numbering as arith() and the CodeGenCase test names
    # (test_frac_add passes 0): 0 add, 1 subtract, 2 multiply, anything
    # else divide. Previously 0 and 1 were swapped.
    # "#" comments only: the body is compiled by py2llvm, and a docstring
    # would add a statement it may not support.
    if op == 0:
        f = Fraction(a, b) + Fraction(c, d)
    elif op == 1:
        f = Fraction(a, b) - Fraction(c, d)
    elif op == 2:
        f = Fraction(a, b) * Fraction(c, d)
    else:
        f = Fraction(a, b) / Fraction(c, d)
    return f.numerator*1000 + f.denominator
2014-09-16 23:11:30 +08:00
def frac_arith_encode_int(op, a, b, x):
    # Test subject for CodeGenCase._test_frac_arith_int (rev=False). It
    # computes Fraction(a, b) <op> x for an int x and packs the reduced
    # result as numerator*1000 + denominator.
    # op uses the same numbering as arith() and the CodeGenCase test names
    # (test_frac_add_int passes 0): 0 add, 1 subtract, 2 multiply,
    # anything else divide. Previously 0 and 1 were swapped.
    # "#" comments only: the body is compiled by py2llvm.
    if op == 0:
        f = Fraction(a, b) + x
    elif op == 1:
        f = Fraction(a, b) - x
    elif op == 2:
        f = Fraction(a, b) * x
    else:
        f = Fraction(a, b) / x
    return f.numerator*1000 + f.denominator
2014-09-16 23:11:30 +08:00
def frac_arith_encode_int_rev(op, a, b, x):
    # Test subject for CodeGenCase._test_frac_arith_int (rev=True). Like
    # frac_arith_encode_int, but with the int on the left: it computes
    # x <op> Fraction(a, b) and packs the result as
    # numerator*1000 + denominator.
    # op uses the same numbering as arith() and the CodeGenCase test names:
    # 0 add, 1 subtract, 2 multiply, anything else divide. Previously 0 and
    # 1 were swapped.
    # "#" comments only: the body is compiled by py2llvm.
    if op == 0:
        f = x + Fraction(a, b)
    elif op == 1:
        f = x - Fraction(a, b)
    elif op == 2:
        f = x * Fraction(a, b)
    else:
        f = x / Fraction(a, b)
    return f.numerator*1000 + f.denominator
def frac_arith_float(op, a, b, x):
    # Test subject for CodeGenCase._test_frac_arith_float (rev=False). It
    # returns Fraction(a, b) <op> x for a float x; in CPython mixing a
    # Fraction with a float yields a float.
    # op uses the same numbering as arith() and the CodeGenCase test names:
    # 0 add, 1 subtract, 2 multiply, anything else divide. Previously 0 and
    # 1 were swapped.
    # "#" comments only: the body is compiled by py2llvm.
    if op == 0:
        return Fraction(a, b) + x
    elif op == 1:
        return Fraction(a, b) - x
    elif op == 2:
        return Fraction(a, b) * x
    else:
        return Fraction(a, b) / x
def frac_arith_float_rev(op, a, b, x):
    # Test subject for CodeGenCase._test_frac_arith_float (rev=True). Like
    # frac_arith_float, but with the float on the left: it returns
    # x <op> Fraction(a, b).
    # op uses the same numbering as arith() and the CodeGenCase test names:
    # 0 add, 1 subtract, 2 multiply, anything else divide. Previously 0 and
    # 1 were swapped.
    # "#" comments only: the body is compiled by py2llvm.
    if op == 0:
        return x + Fraction(a, b)
    elif op == 1:
        return x - Fraction(a, b)
    elif op == 2:
        return x * Fraction(a, b)
    else:
        return x / Fraction(a, b)
2014-09-09 17:13:48 +08:00
def array_test():
    # Test subject for CodeGenCase.test_array, which compiles this function
    # and compares its result with the CPython run. It exercises nested
    # arrays, element stores, augmented assignment on an element, nested
    # range() loops, `continue`, and short-circuit `and` over array reads.
    # Assumes array(init, count) builds `count` elements initialised to
    # `init`, i.e. a 5x5 grid of 2s here -- TODO confirm against
    # artiq.language.core.array.
    # "#" comments only: the body is compiled by py2llvm.
    a = array(array(2, 5), 5)
    a[3][2] = 11
    a[4][1] = 42
    a[0][0] += 6
    acc = 0
    for i in range(5):
        for j in range(5):
            # Skip the cells on the anti-diagonals i + j == 1 and i + j == 2.
            if i + j == 2 or i + j == 1:
                continue
            # Count non-zero cells off row 0 and column 0, on top of the
            # running sum of cell values below.
            if i and j and a[i][j]:
                acc += 1
            acc += a[i][j]
    return acc
2014-09-17 16:25:14 +08:00
def corner_cases():
    # Test subject for CodeGenCase.test_corner_cases. It mixes bool, int,
    # float and Fraction operands to exercise implicit conversions, the
    # int/bool/round builtins, and //, / and * applied to bools. In CPython
    # every intermediate holds the value its name says, and the function
    # returns 10.0.
    # "#" comments only: the body is compiled by py2llvm.
    two = True + True - (not True)
    three = two + True//True - False*True
    two_float = three - True/True  # True/True is true division -> 1.0
    one_float = two_float - (1.0 == bool(0.1))  # bool(0.1) is True
    zero = int(one_float) + round(-0.6)  # round(-0.6) == -1
    eleven_float = zero + 5.5//0.5  # float floor division gives 11.0
    ten_float = eleven_float + round(Fraction(2, -3))  # round(-2/3) == -1
    return ten_float
2014-09-16 23:11:30 +08:00
def _test_range():
2014-09-16 16:43:19 +08:00
for i in range(5, 10):
yield i
yield -i
class CodeGenCase(unittest.TestCase):
    """Differential tests for the py2llvm code generator.

    Each test compiles one of the module-level subject functions with
    CompiledFunction, calls both the compiled code and the original CPython
    function on the same inputs, and requires the results to match. Every
    comparison made inside an input loop puts the call's arguments in the
    failure message, so a mismatch deep inside the nested loops can be
    reproduced directly.
    """

    def _test_float_arith(self, op):
        # Operands are the _test_range() values halved, so non-integral
        # floats such as 2.5 and -4.5 are exercised.
        arith_c = CompiledFunction(arith, {
            "op": base_types.VInt(),
            "a": base_types.VFloat(), "b": base_types.VFloat()})
        for a in _test_range():
            for b in _test_range():
                args = (op, a/2, b/2)
                self.assertEqual(arith_c(*args), arith(*args),
                                 msg="arith{}".format(args))

    def test_float_add(self):
        self._test_float_arith(0)

    def test_float_sub(self):
        self._test_float_arith(1)

    def test_float_mul(self):
        self._test_float_arith(2)

    def test_float_div(self):
        self._test_float_arith(3)

    def test_is_prime(self):
        is_prime_c = CompiledFunction(is_prime, {"x": base_types.VInt()})
        for i in range(200):
            self.assertEqual(is_prime_c(i), is_prime(i),
                             msg="is_prime({})".format(i))

    def test_frac_simplify(self):
        simplify_encode_c = CompiledFunction(
            simplify_encode, {"a": base_types.VInt(), "b": base_types.VInt()})
        for a in _test_range():
            for b in _test_range():
                self.assertEqual(
                    simplify_encode_c(a, b), simplify_encode(a, b),
                    msg="simplify_encode({}, {})".format(a, b))

    def _test_frac_arith(self, op):
        frac_arith_encode_c = CompiledFunction(
            frac_arith_encode, {
                "op": base_types.VInt(),
                "a": base_types.VInt(), "b": base_types.VInt(),
                "c": base_types.VInt(), "d": base_types.VInt()})
        for a in _test_range():
            for b in _test_range():
                for c in _test_range():
                    for d in _test_range():
                        args = (op, a, b, c, d)
                        self.assertEqual(
                            frac_arith_encode_c(*args),
                            frac_arith_encode(*args),
                            msg="frac_arith_encode{}".format(args))

    def test_frac_add(self):
        self._test_frac_arith(0)

    def test_frac_sub(self):
        self._test_frac_arith(1)

    def test_frac_mul(self):
        self._test_frac_arith(2)

    def test_frac_div(self):
        self._test_frac_arith(3)

    def _test_frac_arith_int(self, op, rev):
        # rev selects the variant with the int operand on the left.
        f = frac_arith_encode_int_rev if rev else frac_arith_encode_int
        f_c = CompiledFunction(f, {
            "op": base_types.VInt(),
            "a": base_types.VInt(), "b": base_types.VInt(),
            "x": base_types.VInt()})
        for a in _test_range():
            for b in _test_range():
                for x in _test_range():
                    args = (op, a, b, x)
                    self.assertEqual(
                        f_c(*args), f(*args),
                        msg="{}{}".format(f.__name__, args))

    def test_frac_add_int(self):
        self._test_frac_arith_int(0, False)
        self._test_frac_arith_int(0, True)

    def test_frac_sub_int(self):
        self._test_frac_arith_int(1, False)
        self._test_frac_arith_int(1, True)

    def test_frac_mul_int(self):
        self._test_frac_arith_int(2, False)
        self._test_frac_arith_int(2, True)

    def test_frac_div_int(self):
        self._test_frac_arith_int(3, False)
        self._test_frac_arith_int(3, True)

    def _test_frac_arith_float(self, op, rev):
        # rev selects the variant with the float operand on the left.
        # Results are compared with assertAlmostEqual because they are
        # floats produced by mixed Fraction/float arithmetic.
        f = frac_arith_float_rev if rev else frac_arith_float
        f_c = CompiledFunction(f, {
            "op": base_types.VInt(),
            "a": base_types.VInt(), "b": base_types.VInt(),
            "x": base_types.VFloat()})
        for a in _test_range():
            for b in _test_range():
                for x in _test_range():
                    args = (op, a, b, x/2)
                    self.assertAlmostEqual(
                        f_c(*args), f(*args),
                        msg="{}{}".format(f.__name__, args))

    def test_frac_add_float(self):
        self._test_frac_arith_float(0, False)
        self._test_frac_arith_float(0, True)

    def test_frac_sub_float(self):
        self._test_frac_arith_float(1, False)
        self._test_frac_arith_float(1, True)

    def test_frac_mul_float(self):
        self._test_frac_arith_float(2, False)
        self._test_frac_arith_float(2, True)

    def test_frac_div_float(self):
        self._test_frac_arith_float(3, False)
        self._test_frac_arith_float(3, True)

    def test_array(self):
        array_test_c = CompiledFunction(array_test, dict())
        self.assertEqual(array_test_c(), array_test())

    def test_corner_cases(self):
        corner_cases_c = CompiledFunction(corner_cases, dict())
        self.assertEqual(corner_cases_c(), corner_cases())