From f061b159940bfe719de80a6388fe834e322defb6 Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Tue, 16 Sep 2014 23:11:30 +0800 Subject: [PATCH] py2llvm: add floating point support --- artiq/py2llvm/ast_body.py | 4 +- artiq/py2llvm/base_types.py | 111 +++++++++++++++++++++++++++++++----- test/py2llvm.py | 107 ++++++++++++++++++++++++---------- 3 files changed, 177 insertions(+), 45 deletions(-) diff --git a/artiq/py2llvm/ast_body.py b/artiq/py2llvm/ast_body.py index 60981098a..575508b5a 100644 --- a/artiq/py2llvm/ast_body.py +++ b/artiq/py2llvm/ast_body.py @@ -79,6 +79,8 @@ class Visitor: r = base_types.VInt() else: r = base_types.VInt(64) + elif isinstance(n, float): + r = base_types.VFloat() else: raise NotImplementedError if self.builder is not None: @@ -110,7 +112,7 @@ class Visitor: def _visit_expr_Call(self, node): fn = node.func.id - if fn in {"bool", "int", "int64", "round", "round64"}: + if fn in {"bool", "int", "int64", "round", "round64", "float"}: value = self.visit_expression(node.args[0]) return getattr(value, "o_"+fn)(self.builder) elif fn == "Fraction": diff --git a/artiq/py2llvm/base_types.py b/artiq/py2llvm/base_types.py index 9e9a0aec5..dee76fa15 100644 --- a/artiq/py2llvm/base_types.py +++ b/artiq/py2llvm/base_types.py @@ -62,6 +62,13 @@ class VInt(VGeneric): lc.Constant.int(self.get_llvm_type(), 0))) return r + def o_float(self, builder): + r = VFloat() + if builder is not None: + r.auto_store(builder, builder.sitofp(self.auto_load(builder), + r.get_llvm_type())) + return r + def o_not(self, builder): return self.o_bool(builder, True) @@ -91,22 +98,30 @@ class VInt(VGeneric): return r o_roundx = o_intx + def o_truediv(self, other, builder): + if isinstance(other, VInt): + left = self.o_float(builder) + right = other.o_float(builder) + return left.o_truediv(right, builder) + else: + return NotImplemented + def _make_vint_binop_method(builder_name): def binop_method(self, other, builder): - if isinstance(other, VInt): - target_bits = max(self.nbits, other.nbits) - r = VInt(target_bits) - if builder is not None: - left = self.o_intx(target_bits, builder) - right = other.o_intx(target_bits, builder) - bf = getattr(builder, builder_name) - r.auto_store( - builder, bf(left.auto_load(builder), - right.auto_load(builder))) - return r - else: - return NotImplemented + if isinstance(other, VInt): + target_bits = max(self.nbits, other.nbits) + r = VInt(target_bits) + if builder is not None: + left = self.o_intx(target_bits, builder) + right = other.o_intx(target_bits, builder) + bf = getattr(builder, builder_name) + r.auto_store( + builder, bf(left.auto_load(builder), + right.auto_load(builder))) + return r + else: + return NotImplemented return binop_method for _method_name, _builder_name in (("o_add", "add"), @@ -163,3 +178,73 @@ class VBool(VInt): if builder is not None: r.auto_store(builder, self.auto_load(builder)) return r + + +class VFloat(VGeneric): + def get_llvm_type(self): + return lc.Type.double() + + def set_value(self, builder, v): + if not isinstance(v, VFloat): + raise TypeError + self.auto_store(builder, v.auto_load(builder)) + + def set_const_value(self, builder, n): + self.auto_store(builder, lc.Constant.real(self.get_llvm_type(), n)) + + def o_float(self, builder): + r = VFloat() + if builder is not None: + r.auto_store(builder, self.auto_load(builder)) + return r + +def _make_vfloat_binop_method(builder_name, reverse): + def binop_method(self, other, builder): + if not hasattr(other, "o_float"): + return NotImplemented + r = VFloat() + if builder is not None: + left = self.o_float(builder) + right = other.o_float(builder) + if reverse: + left, right = right, left + bf = getattr(builder, builder_name) + r.auto_store( + builder, bf(left.auto_load(builder), + right.auto_load(builder))) + return r + return binop_method + +for _method_name, _builder_name in (("add", "fadd"), + ("sub", "fsub"), + ("mul", "fmul"), + ("truediv", "fdiv")): + setattr(VFloat, "o_" + _method_name, + _make_vfloat_binop_method(_builder_name, False)) + setattr(VFloat, "or_" + _method_name, + _make_vfloat_binop_method(_builder_name, True)) + + +def _make_vfloat_cmp_method(fcmp_val): + def cmp_method(self, other, builder): + if not hasattr(other, "o_float"): + return NotImplemented + r = VBool() + if builder is not None: + left = self.o_float(builder) + right = other.o_float(builder) + r.auto_store( + builder, + builder.fcmp( + fcmp_val, left.auto_load(builder), + right.auto_load(builder))) + return r + return cmp_method + +for _method_name, _fcmp_val in (("o_eq", lc.FCMP_OEQ), + ("o_ne", lc.FCMP_ONE), + ("o_lt", lc.FCMP_OLT), + ("o_le", lc.FCMP_OLE), + ("o_gt", lc.FCMP_OGT), + ("o_ge", lc.FCMP_OGE)): + setattr(VFloat, _method_name, _make_vfloat_cmp_method(_fcmp_val)) diff --git a/test/py2llvm.py b/test/py2llvm.py index 13703281b..6523f47dd 100644 --- a/test/py2llvm.py +++ b/test/py2llvm.py @@ -20,11 +20,15 @@ def test_base_types(choice): a += x # promotes a to int64 foo = True bar = None + myf = 4.5 + myf2 = myf + x if choice and foo and not bar: return d - else: + elif myf: return x + c + else: + return int64(8) def _build_function_types(f): @@ -44,6 +48,8 @@ class FunctionBaseTypesCase(unittest.TestCase): self.assertEqual(self.ns["d"].nbits, 32) self.assertIsInstance(self.ns["x"], base_types.VInt) self.assertEqual(self.ns["x"].nbits, 64) + self.assertIsInstance(self.ns["myf"], base_types.VFloat) + self.assertIsInstance(self.ns["myf2"], base_types.VFloat) def test_promotion(self): for v in "abc": @@ -85,18 +91,37 @@ class CompiledFunction: self.ee = module.get_ee() def __call__(self, *args): - args_llvm = [ - le.GenericValue.int(av.get_llvm_type(), a) - for av, a in zip(self.argval, args)] + args_llvm = [] + for av, a in zip(self.argval, args): + if isinstance(av, base_types.VInt): + al = le.GenericValue.int(av.get_llvm_type(), a) + elif isinstance(av, base_types.VFloat): + al = le.GenericValue.real(av.get_llvm_type(), a) + else: + raise NotImplementedError + args_llvm.append(al) result = self.ee.run_function(self.function, args_llvm) if isinstance(self.retval, base_types.VBool): return bool(result.as_int()) elif isinstance(self.retval, base_types.VInt): return result.as_int_signed() + elif isinstance(self.retval, base_types.VFloat): + return result.as_real(self.retval.get_llvm_type()) else: raise NotImplementedError +def arith(op, a, b): + if op == 1: + return a + b + elif op == 2: + return a - b + elif op == 3: + return a * b + else: + return a / b + + def is_prime(x): d = 2 while d*d <= x: @@ -111,7 +136,7 @@ def simplify_encode(a, b): return f.numerator*1000 + f.denominator -def arith_encode(op, a, b, c, d): +def frac_arith_encode(op, a, b, c, d): if op == 1: f = Fraction(a, b) - Fraction(c, d) elif op == 2: @@ -123,7 +148,7 @@ def arith_encode(op, a, b, c, d): return f.numerator*1000 + f.denominator -def arith_encode_int(op, a, b, x): +def frac_arith_encode_int(op, a, b, x): if op == 1: f = Fraction(a, b) - x elif op == 2: @@ -135,7 +160,7 @@ def arith_encode_int(op, a, b, x): return f.numerator*1000 + f.denominator -def arith_encode_int_rev(op, a, b, x): +def frac_arith_encode_int_rev(op, a, b, x): if op == 1: f = x - Fraction(a, b) elif op == 2: @@ -160,13 +185,33 @@ def array_test(): return acc -def _frac_test_range(): +def _test_range(): for i in range(5, 10): yield i yield -i class CodeGenCase(unittest.TestCase): + def _test_float_arith(self, op): + arith_c = CompiledFunction(arith, { + "op": base_types.VInt(), + "a": base_types.VFloat(), "b": base_types.VFloat()}) + for a in _test_range(): + for b in _test_range(): + self.assertEqual(arith_c(op, a/2, b/2), arith(op, a/2, b/2)) + + def test_float_add(self): + self._test_float_arith(0) + + def test_float_sub(self): + self._test_float_arith(1) + + def test_float_mul(self): + self._test_float_arith(2) + + def test_float_div(self): + self._test_float_arith(3) + def test_is_prime(self): is_prime_c = CompiledFunction(is_prime, {"x": base_types.VInt()}) for i in range(200): @@ -175,24 +220,24 @@ class CodeGenCase(unittest.TestCase): def test_frac_simplify(self): simplify_encode_c = CompiledFunction( simplify_encode, {"a": base_types.VInt(), "b": base_types.VInt()}) - for a in _frac_test_range(): - for b in _frac_test_range(): + for a in _test_range(): + for b in _test_range(): self.assertEqual( simplify_encode_c(a, b), simplify_encode(a, b)) def _test_frac_arith(self, op): - arith_encode_c = CompiledFunction( - arith_encode, { + frac_arith_encode_c = CompiledFunction( + frac_arith_encode, { "op": base_types.VInt(), "a": base_types.VInt(), "b": base_types.VInt(), "c": base_types.VInt(), "d": base_types.VInt()}) - for a in _frac_test_range(): - for b in _frac_test_range(): - for c in _frac_test_range(): - for d in _frac_test_range(): + for a in _test_range(): + for b in _test_range(): + for c in _test_range(): + for d in _test_range(): self.assertEqual( - arith_encode_c(op, a, b, c, d), - arith_encode(op, a, b, c, d)) + frac_arith_encode_c(op, a, b, c, d), + frac_arith_encode(op, a, b, c, d)) def test_frac_add(self): self._test_frac_arith(0) @@ -206,34 +251,34 @@ class CodeGenCase(unittest.TestCase): def test_frac_div(self): self._test_frac_arith(3) - def _test_frac_arith_int(self, op, rev): - f = arith_encode_int_rev if rev else arith_encode_int + def _test_frac_frac_arith_int(self, op, rev): + f = frac_arith_encode_int_rev if rev else frac_arith_encode_int f_c = CompiledFunction(f, { "op": base_types.VInt(), "a": base_types.VInt(), "b": base_types.VInt(), "x": base_types.VInt()}) - for a in _frac_test_range(): - for b in _frac_test_range(): - for x in _frac_test_range(): + for a in _test_range(): + for b in _test_range(): + for x in _test_range(): self.assertEqual( f_c(op, a, b, x), f(op, a, b, x)) def test_frac_add_int(self): - self._test_frac_arith_int(0, False) - self._test_frac_arith_int(0, True) + self._test_frac_frac_arith_int(0, False) + self._test_frac_frac_arith_int(0, True) def test_frac_sub_int(self): - self._test_frac_arith_int(1, False) - self._test_frac_arith_int(1, True) + self._test_frac_frac_arith_int(1, False) + self._test_frac_frac_arith_int(1, True) def test_frac_mul_int(self): - self._test_frac_arith_int(2, False) - self._test_frac_arith_int(2, True) + self._test_frac_frac_arith_int(2, False) + self._test_frac_frac_arith_int(2, True) def test_frac_div_int(self): - self._test_frac_arith_int(3, False) - self._test_frac_arith_int(3, True) + self._test_frac_frac_arith_int(3, False) + self._test_frac_frac_arith_int(3, True) def test_array(self): array_test_c = CompiledFunction(array_test, dict())