forked from M-Labs/artiq
py2llvm: add floating point support
This commit is contained in:
parent
b923690d6f
commit
f061b15994
|
@ -79,6 +79,8 @@ class Visitor:
|
||||||
r = base_types.VInt()
|
r = base_types.VInt()
|
||||||
else:
|
else:
|
||||||
r = base_types.VInt(64)
|
r = base_types.VInt(64)
|
||||||
|
elif isinstance(n, float):
|
||||||
|
r = base_types.VFloat()
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
if self.builder is not None:
|
if self.builder is not None:
|
||||||
|
@ -110,7 +112,7 @@ class Visitor:
|
||||||
|
|
||||||
def _visit_expr_Call(self, node):
|
def _visit_expr_Call(self, node):
|
||||||
fn = node.func.id
|
fn = node.func.id
|
||||||
if fn in {"bool", "int", "int64", "round", "round64"}:
|
if fn in {"bool", "int", "int64", "round", "round64", "float"}:
|
||||||
value = self.visit_expression(node.args[0])
|
value = self.visit_expression(node.args[0])
|
||||||
return getattr(value, "o_"+fn)(self.builder)
|
return getattr(value, "o_"+fn)(self.builder)
|
||||||
elif fn == "Fraction":
|
elif fn == "Fraction":
|
||||||
|
|
|
@ -62,6 +62,13 @@ class VInt(VGeneric):
|
||||||
lc.Constant.int(self.get_llvm_type(), 0)))
|
lc.Constant.int(self.get_llvm_type(), 0)))
|
||||||
return r
|
return r
|
||||||
|
|
||||||
|
def o_float(self, builder):
|
||||||
|
r = VFloat()
|
||||||
|
if builder is not None:
|
||||||
|
r.auto_store(builder, builder.sitofp(self.auto_load(builder),
|
||||||
|
r.get_llvm_type()))
|
||||||
|
return r
|
||||||
|
|
||||||
def o_not(self, builder):
|
def o_not(self, builder):
|
||||||
return self.o_bool(builder, True)
|
return self.o_bool(builder, True)
|
||||||
|
|
||||||
|
@ -91,6 +98,14 @@ class VInt(VGeneric):
|
||||||
return r
|
return r
|
||||||
o_roundx = o_intx
|
o_roundx = o_intx
|
||||||
|
|
||||||
|
def o_truediv(self, other, builder):
|
||||||
|
if isinstance(other, VInt):
|
||||||
|
left = self.o_float(builder)
|
||||||
|
right = other.o_float(builder)
|
||||||
|
return left.o_truediv(right, builder)
|
||||||
|
else:
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
|
||||||
def _make_vint_binop_method(builder_name):
|
def _make_vint_binop_method(builder_name):
|
||||||
def binop_method(self, other, builder):
|
def binop_method(self, other, builder):
|
||||||
|
@ -163,3 +178,73 @@ class VBool(VInt):
|
||||||
if builder is not None:
|
if builder is not None:
|
||||||
r.auto_store(builder, self.auto_load(builder))
|
r.auto_store(builder, self.auto_load(builder))
|
||||||
return r
|
return r
|
||||||
|
|
||||||
|
|
||||||
|
class VFloat(VGeneric):
|
||||||
|
def get_llvm_type(self):
|
||||||
|
return lc.Type.double()
|
||||||
|
|
||||||
|
def set_value(self, builder, v):
|
||||||
|
if not isinstance(v, VFloat):
|
||||||
|
raise TypeError
|
||||||
|
self.auto_store(builder, v.auto_load(builder))
|
||||||
|
|
||||||
|
def set_const_value(self, builder, n):
|
||||||
|
self.auto_store(builder, lc.Constant.real(self.get_llvm_type(), n))
|
||||||
|
|
||||||
|
def o_float(self, builder):
|
||||||
|
r = VFloat()
|
||||||
|
if builder is not None:
|
||||||
|
r.auto_store(builder, self.auto_load(builder))
|
||||||
|
return r
|
||||||
|
|
||||||
|
def _make_vfloat_binop_method(builder_name, reverse):
|
||||||
|
def binop_method(self, other, builder):
|
||||||
|
if not hasattr(other, "o_float"):
|
||||||
|
return NotImplemented
|
||||||
|
r = VFloat()
|
||||||
|
if builder is not None:
|
||||||
|
left = self.o_float(builder)
|
||||||
|
right = other.o_float(builder)
|
||||||
|
if reverse:
|
||||||
|
left, right = right, left
|
||||||
|
bf = getattr(builder, builder_name)
|
||||||
|
r.auto_store(
|
||||||
|
builder, bf(left.auto_load(builder),
|
||||||
|
right.auto_load(builder)))
|
||||||
|
return r
|
||||||
|
return binop_method
|
||||||
|
|
||||||
|
for _method_name, _builder_name in (("add", "fadd"),
|
||||||
|
("sub", "fsub"),
|
||||||
|
("mul", "fmul"),
|
||||||
|
("truediv", "fdiv")):
|
||||||
|
setattr(VFloat, "o_" + _method_name,
|
||||||
|
_make_vfloat_binop_method(_builder_name, False))
|
||||||
|
setattr(VFloat, "or_" + _method_name,
|
||||||
|
_make_vfloat_binop_method(_builder_name, True))
|
||||||
|
|
||||||
|
|
||||||
|
def _make_vfloat_cmp_method(fcmp_val):
|
||||||
|
def cmp_method(self, other, builder):
|
||||||
|
if not hasattr(other, "o_float"):
|
||||||
|
return NotImplemented
|
||||||
|
r = VBool()
|
||||||
|
if builder is not None:
|
||||||
|
left = self.o_float(builder)
|
||||||
|
right = other.o_float(builder)
|
||||||
|
r.auto_store(
|
||||||
|
builder,
|
||||||
|
builder.fcmp(
|
||||||
|
fcmp_val, left.auto_load(builder),
|
||||||
|
right.auto_load(builder)))
|
||||||
|
return r
|
||||||
|
return cmp_method
|
||||||
|
|
||||||
|
for _method_name, _fcmp_val in (("o_eq", lc.FCMP_OEQ),
|
||||||
|
("o_ne", lc.FCMP_ONE),
|
||||||
|
("o_lt", lc.FCMP_OLT),
|
||||||
|
("o_le", lc.FCMP_OLE),
|
||||||
|
("o_gt", lc.FCMP_OGT),
|
||||||
|
("o_ge", lc.FCMP_OGE)):
|
||||||
|
setattr(VFloat, _method_name, _make_vfloat_cmp_method(_fcmp_val))
|
||||||
|
|
107
test/py2llvm.py
107
test/py2llvm.py
|
@ -20,11 +20,15 @@ def test_base_types(choice):
|
||||||
a += x # promotes a to int64
|
a += x # promotes a to int64
|
||||||
foo = True
|
foo = True
|
||||||
bar = None
|
bar = None
|
||||||
|
myf = 4.5
|
||||||
|
myf2 = myf + x
|
||||||
|
|
||||||
if choice and foo and not bar:
|
if choice and foo and not bar:
|
||||||
return d
|
return d
|
||||||
else:
|
elif myf:
|
||||||
return x + c
|
return x + c
|
||||||
|
else:
|
||||||
|
return int64(8)
|
||||||
|
|
||||||
|
|
||||||
def _build_function_types(f):
|
def _build_function_types(f):
|
||||||
|
@ -44,6 +48,8 @@ class FunctionBaseTypesCase(unittest.TestCase):
|
||||||
self.assertEqual(self.ns["d"].nbits, 32)
|
self.assertEqual(self.ns["d"].nbits, 32)
|
||||||
self.assertIsInstance(self.ns["x"], base_types.VInt)
|
self.assertIsInstance(self.ns["x"], base_types.VInt)
|
||||||
self.assertEqual(self.ns["x"].nbits, 64)
|
self.assertEqual(self.ns["x"].nbits, 64)
|
||||||
|
self.assertIsInstance(self.ns["myf"], base_types.VFloat)
|
||||||
|
self.assertIsInstance(self.ns["myf2"], base_types.VFloat)
|
||||||
|
|
||||||
def test_promotion(self):
|
def test_promotion(self):
|
||||||
for v in "abc":
|
for v in "abc":
|
||||||
|
@ -85,18 +91,37 @@ class CompiledFunction:
|
||||||
self.ee = module.get_ee()
|
self.ee = module.get_ee()
|
||||||
|
|
||||||
def __call__(self, *args):
|
def __call__(self, *args):
|
||||||
args_llvm = [
|
args_llvm = []
|
||||||
le.GenericValue.int(av.get_llvm_type(), a)
|
for av, a in zip(self.argval, args):
|
||||||
for av, a in zip(self.argval, args)]
|
if isinstance(av, base_types.VInt):
|
||||||
|
al = le.GenericValue.int(av.get_llvm_type(), a)
|
||||||
|
elif isinstance(av, base_types.VFloat):
|
||||||
|
al = le.GenericValue.real(av.get_llvm_type(), a)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
args_llvm.append(al)
|
||||||
result = self.ee.run_function(self.function, args_llvm)
|
result = self.ee.run_function(self.function, args_llvm)
|
||||||
if isinstance(self.retval, base_types.VBool):
|
if isinstance(self.retval, base_types.VBool):
|
||||||
return bool(result.as_int())
|
return bool(result.as_int())
|
||||||
elif isinstance(self.retval, base_types.VInt):
|
elif isinstance(self.retval, base_types.VInt):
|
||||||
return result.as_int_signed()
|
return result.as_int_signed()
|
||||||
|
elif isinstance(self.retval, base_types.VFloat):
|
||||||
|
return result.as_real(self.retval.get_llvm_type())
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def arith(op, a, b):
|
||||||
|
if op == 1:
|
||||||
|
return a + b
|
||||||
|
elif op == 2:
|
||||||
|
return a - b
|
||||||
|
elif op == 3:
|
||||||
|
return a * b
|
||||||
|
else:
|
||||||
|
return a / b
|
||||||
|
|
||||||
|
|
||||||
def is_prime(x):
|
def is_prime(x):
|
||||||
d = 2
|
d = 2
|
||||||
while d*d <= x:
|
while d*d <= x:
|
||||||
|
@ -111,7 +136,7 @@ def simplify_encode(a, b):
|
||||||
return f.numerator*1000 + f.denominator
|
return f.numerator*1000 + f.denominator
|
||||||
|
|
||||||
|
|
||||||
def arith_encode(op, a, b, c, d):
|
def frac_arith_encode(op, a, b, c, d):
|
||||||
if op == 1:
|
if op == 1:
|
||||||
f = Fraction(a, b) - Fraction(c, d)
|
f = Fraction(a, b) - Fraction(c, d)
|
||||||
elif op == 2:
|
elif op == 2:
|
||||||
|
@ -123,7 +148,7 @@ def arith_encode(op, a, b, c, d):
|
||||||
return f.numerator*1000 + f.denominator
|
return f.numerator*1000 + f.denominator
|
||||||
|
|
||||||
|
|
||||||
def arith_encode_int(op, a, b, x):
|
def frac_arith_encode_int(op, a, b, x):
|
||||||
if op == 1:
|
if op == 1:
|
||||||
f = Fraction(a, b) - x
|
f = Fraction(a, b) - x
|
||||||
elif op == 2:
|
elif op == 2:
|
||||||
|
@ -135,7 +160,7 @@ def arith_encode_int(op, a, b, x):
|
||||||
return f.numerator*1000 + f.denominator
|
return f.numerator*1000 + f.denominator
|
||||||
|
|
||||||
|
|
||||||
def arith_encode_int_rev(op, a, b, x):
|
def frac_arith_encode_int_rev(op, a, b, x):
|
||||||
if op == 1:
|
if op == 1:
|
||||||
f = x - Fraction(a, b)
|
f = x - Fraction(a, b)
|
||||||
elif op == 2:
|
elif op == 2:
|
||||||
|
@ -160,13 +185,33 @@ def array_test():
|
||||||
return acc
|
return acc
|
||||||
|
|
||||||
|
|
||||||
def _frac_test_range():
|
def _test_range():
|
||||||
for i in range(5, 10):
|
for i in range(5, 10):
|
||||||
yield i
|
yield i
|
||||||
yield -i
|
yield -i
|
||||||
|
|
||||||
|
|
||||||
class CodeGenCase(unittest.TestCase):
|
class CodeGenCase(unittest.TestCase):
|
||||||
|
def _test_float_arith(self, op):
|
||||||
|
arith_c = CompiledFunction(arith, {
|
||||||
|
"op": base_types.VInt(),
|
||||||
|
"a": base_types.VFloat(), "b": base_types.VFloat()})
|
||||||
|
for a in _test_range():
|
||||||
|
for b in _test_range():
|
||||||
|
self.assertEqual(arith_c(op, a/2, b/2), arith(op, a/2, b/2))
|
||||||
|
|
||||||
|
def test_float_add(self):
|
||||||
|
self._test_float_arith(0)
|
||||||
|
|
||||||
|
def test_float_sub(self):
|
||||||
|
self._test_float_arith(1)
|
||||||
|
|
||||||
|
def test_float_mul(self):
|
||||||
|
self._test_float_arith(2)
|
||||||
|
|
||||||
|
def test_float_div(self):
|
||||||
|
self._test_float_arith(3)
|
||||||
|
|
||||||
def test_is_prime(self):
|
def test_is_prime(self):
|
||||||
is_prime_c = CompiledFunction(is_prime, {"x": base_types.VInt()})
|
is_prime_c = CompiledFunction(is_prime, {"x": base_types.VInt()})
|
||||||
for i in range(200):
|
for i in range(200):
|
||||||
|
@ -175,24 +220,24 @@ class CodeGenCase(unittest.TestCase):
|
||||||
def test_frac_simplify(self):
|
def test_frac_simplify(self):
|
||||||
simplify_encode_c = CompiledFunction(
|
simplify_encode_c = CompiledFunction(
|
||||||
simplify_encode, {"a": base_types.VInt(), "b": base_types.VInt()})
|
simplify_encode, {"a": base_types.VInt(), "b": base_types.VInt()})
|
||||||
for a in _frac_test_range():
|
for a in _test_range():
|
||||||
for b in _frac_test_range():
|
for b in _test_range():
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
simplify_encode_c(a, b), simplify_encode(a, b))
|
simplify_encode_c(a, b), simplify_encode(a, b))
|
||||||
|
|
||||||
def _test_frac_arith(self, op):
|
def _test_frac_arith(self, op):
|
||||||
arith_encode_c = CompiledFunction(
|
frac_arith_encode_c = CompiledFunction(
|
||||||
arith_encode, {
|
frac_arith_encode, {
|
||||||
"op": base_types.VInt(),
|
"op": base_types.VInt(),
|
||||||
"a": base_types.VInt(), "b": base_types.VInt(),
|
"a": base_types.VInt(), "b": base_types.VInt(),
|
||||||
"c": base_types.VInt(), "d": base_types.VInt()})
|
"c": base_types.VInt(), "d": base_types.VInt()})
|
||||||
for a in _frac_test_range():
|
for a in _test_range():
|
||||||
for b in _frac_test_range():
|
for b in _test_range():
|
||||||
for c in _frac_test_range():
|
for c in _test_range():
|
||||||
for d in _frac_test_range():
|
for d in _test_range():
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
arith_encode_c(op, a, b, c, d),
|
frac_arith_encode_c(op, a, b, c, d),
|
||||||
arith_encode(op, a, b, c, d))
|
frac_arith_encode(op, a, b, c, d))
|
||||||
|
|
||||||
def test_frac_add(self):
|
def test_frac_add(self):
|
||||||
self._test_frac_arith(0)
|
self._test_frac_arith(0)
|
||||||
|
@ -206,34 +251,34 @@ class CodeGenCase(unittest.TestCase):
|
||||||
def test_frac_div(self):
|
def test_frac_div(self):
|
||||||
self._test_frac_arith(3)
|
self._test_frac_arith(3)
|
||||||
|
|
||||||
def _test_frac_arith_int(self, op, rev):
|
def _test_frac_frac_arith_int(self, op, rev):
|
||||||
f = arith_encode_int_rev if rev else arith_encode_int
|
f = frac_arith_encode_int_rev if rev else frac_arith_encode_int
|
||||||
f_c = CompiledFunction(f, {
|
f_c = CompiledFunction(f, {
|
||||||
"op": base_types.VInt(),
|
"op": base_types.VInt(),
|
||||||
"a": base_types.VInt(), "b": base_types.VInt(),
|
"a": base_types.VInt(), "b": base_types.VInt(),
|
||||||
"x": base_types.VInt()})
|
"x": base_types.VInt()})
|
||||||
for a in _frac_test_range():
|
for a in _test_range():
|
||||||
for b in _frac_test_range():
|
for b in _test_range():
|
||||||
for x in _frac_test_range():
|
for x in _test_range():
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
f_c(op, a, b, x),
|
f_c(op, a, b, x),
|
||||||
f(op, a, b, x))
|
f(op, a, b, x))
|
||||||
|
|
||||||
def test_frac_add_int(self):
|
def test_frac_add_int(self):
|
||||||
self._test_frac_arith_int(0, False)
|
self._test_frac_frac_arith_int(0, False)
|
||||||
self._test_frac_arith_int(0, True)
|
self._test_frac_frac_arith_int(0, True)
|
||||||
|
|
||||||
def test_frac_sub_int(self):
|
def test_frac_sub_int(self):
|
||||||
self._test_frac_arith_int(1, False)
|
self._test_frac_frac_arith_int(1, False)
|
||||||
self._test_frac_arith_int(1, True)
|
self._test_frac_frac_arith_int(1, True)
|
||||||
|
|
||||||
def test_frac_mul_int(self):
|
def test_frac_mul_int(self):
|
||||||
self._test_frac_arith_int(2, False)
|
self._test_frac_frac_arith_int(2, False)
|
||||||
self._test_frac_arith_int(2, True)
|
self._test_frac_frac_arith_int(2, True)
|
||||||
|
|
||||||
def test_frac_div_int(self):
|
def test_frac_div_int(self):
|
||||||
self._test_frac_arith_int(3, False)
|
self._test_frac_frac_arith_int(3, False)
|
||||||
self._test_frac_arith_int(3, True)
|
self._test_frac_frac_arith_int(3, True)
|
||||||
|
|
||||||
def test_array(self):
|
def test_array(self):
|
||||||
array_test_c = CompiledFunction(array_test, dict())
|
array_test_c = CompiledFunction(array_test, dict())
|
||||||
|
|
Loading…
Reference in New Issue