forked from M-Labs/artiq
1
0
Fork 0

py2llvm: add floating point support

This commit is contained in:
Sebastien Bourdeauducq 2014-09-16 23:11:30 +08:00
parent b923690d6f
commit f061b15994
3 changed files with 177 additions and 45 deletions

View File

@ -79,6 +79,8 @@ class Visitor:
r = base_types.VInt() r = base_types.VInt()
else: else:
r = base_types.VInt(64) r = base_types.VInt(64)
elif isinstance(n, float):
r = base_types.VFloat()
else: else:
raise NotImplementedError raise NotImplementedError
if self.builder is not None: if self.builder is not None:
@ -110,7 +112,7 @@ class Visitor:
def _visit_expr_Call(self, node): def _visit_expr_Call(self, node):
fn = node.func.id fn = node.func.id
if fn in {"bool", "int", "int64", "round", "round64"}: if fn in {"bool", "int", "int64", "round", "round64", "float"}:
value = self.visit_expression(node.args[0]) value = self.visit_expression(node.args[0])
return getattr(value, "o_"+fn)(self.builder) return getattr(value, "o_"+fn)(self.builder)
elif fn == "Fraction": elif fn == "Fraction":

View File

@ -62,6 +62,13 @@ class VInt(VGeneric):
lc.Constant.int(self.get_llvm_type(), 0))) lc.Constant.int(self.get_llvm_type(), 0)))
return r return r
def o_float(self, builder):
r = VFloat()
if builder is not None:
r.auto_store(builder, builder.sitofp(self.auto_load(builder),
r.get_llvm_type()))
return r
def o_not(self, builder): def o_not(self, builder):
return self.o_bool(builder, True) return self.o_bool(builder, True)
@ -91,22 +98,30 @@ class VInt(VGeneric):
return r return r
o_roundx = o_intx o_roundx = o_intx
def o_truediv(self, other, builder):
if isinstance(other, VInt):
left = self.o_float(builder)
right = other.o_float(builder)
return left.o_truediv(right, builder)
else:
return NotImplemented
def _make_vint_binop_method(builder_name): def _make_vint_binop_method(builder_name):
def binop_method(self, other, builder): def binop_method(self, other, builder):
if isinstance(other, VInt): if isinstance(other, VInt):
target_bits = max(self.nbits, other.nbits) target_bits = max(self.nbits, other.nbits)
r = VInt(target_bits) r = VInt(target_bits)
if builder is not None: if builder is not None:
left = self.o_intx(target_bits, builder) left = self.o_intx(target_bits, builder)
right = other.o_intx(target_bits, builder) right = other.o_intx(target_bits, builder)
bf = getattr(builder, builder_name) bf = getattr(builder, builder_name)
r.auto_store( r.auto_store(
builder, bf(left.auto_load(builder), builder, bf(left.auto_load(builder),
right.auto_load(builder))) right.auto_load(builder)))
return r return r
else: else:
return NotImplemented return NotImplemented
return binop_method return binop_method
for _method_name, _builder_name in (("o_add", "add"), for _method_name, _builder_name in (("o_add", "add"),
@ -163,3 +178,73 @@ class VBool(VInt):
if builder is not None: if builder is not None:
r.auto_store(builder, self.auto_load(builder)) r.auto_store(builder, self.auto_load(builder))
return r return r
class VFloat(VGeneric):
def get_llvm_type(self):
return lc.Type.double()
def set_value(self, builder, v):
if not isinstance(v, VFloat):
raise TypeError
self.auto_store(builder, v.auto_load(builder))
def set_const_value(self, builder, n):
self.auto_store(builder, lc.Constant.real(self.get_llvm_type(), n))
def o_float(self, builder):
r = VFloat()
if builder is not None:
r.auto_store(builder, self.auto_load(builder))
return r
def _make_vfloat_binop_method(builder_name, reverse):
def binop_method(self, other, builder):
if not hasattr(other, "o_float"):
return NotImplemented
r = VFloat()
if builder is not None:
left = self.o_float(builder)
right = other.o_float(builder)
if reverse:
left, right = right, left
bf = getattr(builder, builder_name)
r.auto_store(
builder, bf(left.auto_load(builder),
right.auto_load(builder)))
return r
return binop_method
for _method_name, _builder_name in (("add", "fadd"),
("sub", "fsub"),
("mul", "fmul"),
("truediv", "fdiv")):
setattr(VFloat, "o_" + _method_name,
_make_vfloat_binop_method(_builder_name, False))
setattr(VFloat, "or_" + _method_name,
_make_vfloat_binop_method(_builder_name, True))
def _make_vfloat_cmp_method(fcmp_val):
def cmp_method(self, other, builder):
if not hasattr(other, "o_float"):
return NotImplemented
r = VBool()
if builder is not None:
left = self.o_float(builder)
right = other.o_float(builder)
r.auto_store(
builder,
builder.fcmp(
fcmp_val, left.auto_load(builder),
right.auto_load(builder)))
return r
return cmp_method
for _method_name, _fcmp_val in (("o_eq", lc.FCMP_OEQ),
("o_ne", lc.FCMP_ONE),
("o_lt", lc.FCMP_OLT),
("o_le", lc.FCMP_OLE),
("o_gt", lc.FCMP_OGT),
("o_ge", lc.FCMP_OGE)):
setattr(VFloat, _method_name, _make_vfloat_cmp_method(_fcmp_val))

View File

@ -20,11 +20,15 @@ def test_base_types(choice):
a += x # promotes a to int64 a += x # promotes a to int64
foo = True foo = True
bar = None bar = None
myf = 4.5
myf2 = myf + x
if choice and foo and not bar: if choice and foo and not bar:
return d return d
else: elif myf:
return x + c return x + c
else:
return int64(8)
def _build_function_types(f): def _build_function_types(f):
@ -44,6 +48,8 @@ class FunctionBaseTypesCase(unittest.TestCase):
self.assertEqual(self.ns["d"].nbits, 32) self.assertEqual(self.ns["d"].nbits, 32)
self.assertIsInstance(self.ns["x"], base_types.VInt) self.assertIsInstance(self.ns["x"], base_types.VInt)
self.assertEqual(self.ns["x"].nbits, 64) self.assertEqual(self.ns["x"].nbits, 64)
self.assertIsInstance(self.ns["myf"], base_types.VFloat)
self.assertIsInstance(self.ns["myf2"], base_types.VFloat)
def test_promotion(self): def test_promotion(self):
for v in "abc": for v in "abc":
@ -85,18 +91,37 @@ class CompiledFunction:
self.ee = module.get_ee() self.ee = module.get_ee()
def __call__(self, *args): def __call__(self, *args):
args_llvm = [ args_llvm = []
le.GenericValue.int(av.get_llvm_type(), a) for av, a in zip(self.argval, args):
for av, a in zip(self.argval, args)] if isinstance(av, base_types.VInt):
al = le.GenericValue.int(av.get_llvm_type(), a)
elif isinstance(av, base_types.VFloat):
al = le.GenericValue.real(av.get_llvm_type(), a)
else:
raise NotImplementedError
args_llvm.append(al)
result = self.ee.run_function(self.function, args_llvm) result = self.ee.run_function(self.function, args_llvm)
if isinstance(self.retval, base_types.VBool): if isinstance(self.retval, base_types.VBool):
return bool(result.as_int()) return bool(result.as_int())
elif isinstance(self.retval, base_types.VInt): elif isinstance(self.retval, base_types.VInt):
return result.as_int_signed() return result.as_int_signed()
elif isinstance(self.retval, base_types.VFloat):
return result.as_real(self.retval.get_llvm_type())
else: else:
raise NotImplementedError raise NotImplementedError
def arith(op, a, b):
if op == 1:
return a + b
elif op == 2:
return a - b
elif op == 3:
return a * b
else:
return a / b
def is_prime(x): def is_prime(x):
d = 2 d = 2
while d*d <= x: while d*d <= x:
@ -111,7 +136,7 @@ def simplify_encode(a, b):
return f.numerator*1000 + f.denominator return f.numerator*1000 + f.denominator
def arith_encode(op, a, b, c, d): def frac_arith_encode(op, a, b, c, d):
if op == 1: if op == 1:
f = Fraction(a, b) - Fraction(c, d) f = Fraction(a, b) - Fraction(c, d)
elif op == 2: elif op == 2:
@ -123,7 +148,7 @@ def arith_encode(op, a, b, c, d):
return f.numerator*1000 + f.denominator return f.numerator*1000 + f.denominator
def arith_encode_int(op, a, b, x): def frac_arith_encode_int(op, a, b, x):
if op == 1: if op == 1:
f = Fraction(a, b) - x f = Fraction(a, b) - x
elif op == 2: elif op == 2:
@ -135,7 +160,7 @@ def arith_encode_int(op, a, b, x):
return f.numerator*1000 + f.denominator return f.numerator*1000 + f.denominator
def arith_encode_int_rev(op, a, b, x): def frac_arith_encode_int_rev(op, a, b, x):
if op == 1: if op == 1:
f = x - Fraction(a, b) f = x - Fraction(a, b)
elif op == 2: elif op == 2:
@ -160,13 +185,33 @@ def array_test():
return acc return acc
def _frac_test_range(): def _test_range():
for i in range(5, 10): for i in range(5, 10):
yield i yield i
yield -i yield -i
class CodeGenCase(unittest.TestCase): class CodeGenCase(unittest.TestCase):
def _test_float_arith(self, op):
arith_c = CompiledFunction(arith, {
"op": base_types.VInt(),
"a": base_types.VFloat(), "b": base_types.VFloat()})
for a in _test_range():
for b in _test_range():
self.assertEqual(arith_c(op, a/2, b/2), arith(op, a/2, b/2))
def test_float_add(self):
self._test_float_arith(0)
def test_float_sub(self):
self._test_float_arith(1)
def test_float_mul(self):
self._test_float_arith(2)
def test_float_div(self):
self._test_float_arith(3)
def test_is_prime(self): def test_is_prime(self):
is_prime_c = CompiledFunction(is_prime, {"x": base_types.VInt()}) is_prime_c = CompiledFunction(is_prime, {"x": base_types.VInt()})
for i in range(200): for i in range(200):
@ -175,24 +220,24 @@ class CodeGenCase(unittest.TestCase):
def test_frac_simplify(self): def test_frac_simplify(self):
simplify_encode_c = CompiledFunction( simplify_encode_c = CompiledFunction(
simplify_encode, {"a": base_types.VInt(), "b": base_types.VInt()}) simplify_encode, {"a": base_types.VInt(), "b": base_types.VInt()})
for a in _frac_test_range(): for a in _test_range():
for b in _frac_test_range(): for b in _test_range():
self.assertEqual( self.assertEqual(
simplify_encode_c(a, b), simplify_encode(a, b)) simplify_encode_c(a, b), simplify_encode(a, b))
def _test_frac_arith(self, op): def _test_frac_arith(self, op):
arith_encode_c = CompiledFunction( frac_arith_encode_c = CompiledFunction(
arith_encode, { frac_arith_encode, {
"op": base_types.VInt(), "op": base_types.VInt(),
"a": base_types.VInt(), "b": base_types.VInt(), "a": base_types.VInt(), "b": base_types.VInt(),
"c": base_types.VInt(), "d": base_types.VInt()}) "c": base_types.VInt(), "d": base_types.VInt()})
for a in _frac_test_range(): for a in _test_range():
for b in _frac_test_range(): for b in _test_range():
for c in _frac_test_range(): for c in _test_range():
for d in _frac_test_range(): for d in _test_range():
self.assertEqual( self.assertEqual(
arith_encode_c(op, a, b, c, d), frac_arith_encode_c(op, a, b, c, d),
arith_encode(op, a, b, c, d)) frac_arith_encode(op, a, b, c, d))
def test_frac_add(self): def test_frac_add(self):
self._test_frac_arith(0) self._test_frac_arith(0)
@ -206,34 +251,34 @@ class CodeGenCase(unittest.TestCase):
def test_frac_div(self): def test_frac_div(self):
self._test_frac_arith(3) self._test_frac_arith(3)
def _test_frac_arith_int(self, op, rev): def _test_frac_frac_arith_int(self, op, rev):
f = arith_encode_int_rev if rev else arith_encode_int f = frac_arith_encode_int_rev if rev else frac_arith_encode_int
f_c = CompiledFunction(f, { f_c = CompiledFunction(f, {
"op": base_types.VInt(), "op": base_types.VInt(),
"a": base_types.VInt(), "b": base_types.VInt(), "a": base_types.VInt(), "b": base_types.VInt(),
"x": base_types.VInt()}) "x": base_types.VInt()})
for a in _frac_test_range(): for a in _test_range():
for b in _frac_test_range(): for b in _test_range():
for x in _frac_test_range(): for x in _test_range():
self.assertEqual( self.assertEqual(
f_c(op, a, b, x), f_c(op, a, b, x),
f(op, a, b, x)) f(op, a, b, x))
def test_frac_add_int(self): def test_frac_add_int(self):
self._test_frac_arith_int(0, False) self._test_frac_frac_arith_int(0, False)
self._test_frac_arith_int(0, True) self._test_frac_frac_arith_int(0, True)
def test_frac_sub_int(self): def test_frac_sub_int(self):
self._test_frac_arith_int(1, False) self._test_frac_frac_arith_int(1, False)
self._test_frac_arith_int(1, True) self._test_frac_frac_arith_int(1, True)
def test_frac_mul_int(self): def test_frac_mul_int(self):
self._test_frac_arith_int(2, False) self._test_frac_frac_arith_int(2, False)
self._test_frac_arith_int(2, True) self._test_frac_frac_arith_int(2, True)
def test_frac_div_int(self): def test_frac_div_int(self):
self._test_frac_arith_int(3, False) self._test_frac_frac_arith_int(3, False)
self._test_frac_arith_int(3, True) self._test_frac_frac_arith_int(3, True)
def test_array(self): def test_array(self):
array_test_c = CompiledFunction(array_test, dict()) array_test_c = CompiledFunction(array_test, dict())