forked from M-Labs/artiq
py2llvm: reorganize, split 'values' module, factor LLVM module/pass management
This commit is contained in:
parent
58465e49fa
commit
3c8b541939
|
@ -1,7 +1,7 @@
|
||||||
from llvm import core as lc
|
from llvm import core as lc
|
||||||
from llvm import target as lt
|
from llvm import target as lt
|
||||||
|
|
||||||
from artiq.py2llvm import values
|
from artiq.py2llvm import base_types
|
||||||
|
|
||||||
|
|
||||||
lt.initialize_all()
|
lt.initialize_all()
|
||||||
|
@ -21,9 +21,9 @@ _chr_to_type = {
|
||||||
}
|
}
|
||||||
|
|
||||||
_chr_to_value = {
|
_chr_to_value = {
|
||||||
"n": lambda: values.VNone(),
|
"n": lambda: base_types.VNone(),
|
||||||
"i": lambda: values.VInt(),
|
"i": lambda: base_types.VInt(),
|
||||||
"I": lambda: values.VInt(64)
|
"I": lambda: base_types.VInt(64)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,20 +1,6 @@
|
||||||
from llvm import core as lc
|
from artiq.py2llvm.module import Module
|
||||||
from llvm import passes as lp
|
|
||||||
|
|
||||||
from artiq.py2llvm import values
|
|
||||||
from artiq.py2llvm.functions import compile_function
|
|
||||||
from artiq.py2llvm.tools import add_common_passes
|
|
||||||
|
|
||||||
|
|
||||||
def get_runtime_binary(env, funcdef):
|
def get_runtime_binary(env, funcdef):
|
||||||
module = lc.Module.new("main")
|
module = Module(env)
|
||||||
env.init_module(module)
|
module.compile_function(funcdef, dict())
|
||||||
values.init_module(module)
|
return module.emit_object()
|
||||||
|
|
||||||
compile_function(module, env, funcdef, dict())
|
|
||||||
|
|
||||||
pass_manager = lp.PassManager.new()
|
|
||||||
add_common_passes(pass_manager)
|
|
||||||
pass_manager.run(module)
|
|
||||||
|
|
||||||
return env.emit_object()
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import ast
|
import ast
|
||||||
|
|
||||||
from artiq.py2llvm import values
|
from artiq.py2llvm import values, base_types, fractions
|
||||||
from artiq.py2llvm.tools import is_terminated
|
from artiq.py2llvm.tools import is_terminated
|
||||||
|
|
||||||
|
|
||||||
|
@ -30,9 +30,9 @@ class Visitor:
|
||||||
def _visit_expr_NameConstant(self, node):
|
def _visit_expr_NameConstant(self, node):
|
||||||
v = node.value
|
v = node.value
|
||||||
if v is None:
|
if v is None:
|
||||||
r = values.VNone()
|
r = base_types.VNone()
|
||||||
elif isinstance(v, bool):
|
elif isinstance(v, bool):
|
||||||
r = values.VBool()
|
r = base_types.VBool()
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
if self.builder is not None:
|
if self.builder is not None:
|
||||||
|
@ -43,9 +43,9 @@ class Visitor:
|
||||||
n = node.n
|
n = node.n
|
||||||
if isinstance(n, int):
|
if isinstance(n, int):
|
||||||
if abs(n) < 2**31:
|
if abs(n) < 2**31:
|
||||||
r = values.VInt()
|
r = base_types.VInt()
|
||||||
else:
|
else:
|
||||||
r = values.VInt(64)
|
r = base_types.VInt(64)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
if self.builder is not None:
|
if self.builder is not None:
|
||||||
|
@ -116,7 +116,7 @@ class Visitor:
|
||||||
return ast_unfuns[fn](self.visit_expression(node.args[0]),
|
return ast_unfuns[fn](self.visit_expression(node.args[0]),
|
||||||
self.builder)
|
self.builder)
|
||||||
elif fn == "Fraction":
|
elif fn == "Fraction":
|
||||||
r = values.VFraction()
|
r = fractions.VFraction()
|
||||||
if self.builder is not None:
|
if self.builder is not None:
|
||||||
numerator = self.visit_expression(node.args[0])
|
numerator = self.visit_expression(node.args[0])
|
||||||
denominator = self.visit_expression(node.args[1])
|
denominator = self.visit_expression(node.args[1])
|
||||||
|
@ -213,10 +213,10 @@ class Visitor:
|
||||||
|
|
||||||
def _visit_stmt_Return(self, node):
|
def _visit_stmt_Return(self, node):
|
||||||
if node.value is None:
|
if node.value is None:
|
||||||
val = values.VNone()
|
val = base_types.VNone()
|
||||||
else:
|
else:
|
||||||
val = self.visit_expression(node.value)
|
val = self.visit_expression(node.value)
|
||||||
if isinstance(val, values.VNone):
|
if isinstance(val, base_types.VNone):
|
||||||
self.builder.ret_void()
|
self.builder.ret_void()
|
||||||
else:
|
else:
|
||||||
self.builder.ret(val.get_ssa_value(self.builder))
|
self.builder.ret(val.get_ssa_value(self.builder))
|
||||||
|
|
|
@ -0,0 +1,165 @@
|
||||||
|
from llvm import core as lc
|
||||||
|
|
||||||
|
from artiq.py2llvm.values import VGeneric
|
||||||
|
|
||||||
|
|
||||||
|
class VNone(VGeneric):
|
||||||
|
def __repr__(self):
|
||||||
|
return "<VNone>"
|
||||||
|
|
||||||
|
def get_llvm_type(self):
|
||||||
|
return lc.Type.void()
|
||||||
|
|
||||||
|
def same_type(self, other):
|
||||||
|
return isinstance(other, VNone)
|
||||||
|
|
||||||
|
def merge(self, other):
|
||||||
|
if not isinstance(other, VNone):
|
||||||
|
raise TypeError
|
||||||
|
|
||||||
|
def alloca(self, builder, name):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def o_bool(self, builder):
|
||||||
|
r = VBool()
|
||||||
|
if builder is not None:
|
||||||
|
r.set_const_value(builder, False)
|
||||||
|
return r
|
||||||
|
|
||||||
|
|
||||||
|
class VInt(VGeneric):
|
||||||
|
def __init__(self, nbits=32):
|
||||||
|
VGeneric.__init__(self)
|
||||||
|
self.nbits = nbits
|
||||||
|
|
||||||
|
def get_llvm_type(self):
|
||||||
|
return lc.Type.int(self.nbits)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "<VInt:{}>".format(self.nbits)
|
||||||
|
|
||||||
|
def same_type(self, other):
|
||||||
|
return isinstance(other, VInt) and other.nbits == self.nbits
|
||||||
|
|
||||||
|
def merge(self, other):
|
||||||
|
if isinstance(other, VInt) and not isinstance(other, VBool):
|
||||||
|
if other.nbits > self.nbits:
|
||||||
|
self.nbits = other.nbits
|
||||||
|
else:
|
||||||
|
raise TypeError
|
||||||
|
|
||||||
|
def set_value(self, builder, n):
|
||||||
|
self.set_ssa_value(
|
||||||
|
builder, n.o_intx(self.nbits, builder).get_ssa_value(builder))
|
||||||
|
|
||||||
|
def set_const_value(self, builder, n):
|
||||||
|
self.set_ssa_value(builder, lc.Constant.int(self.get_llvm_type(), n))
|
||||||
|
|
||||||
|
def o_bool(self, builder, inv=False):
|
||||||
|
r = VBool()
|
||||||
|
if builder is not None:
|
||||||
|
r.set_ssa_value(
|
||||||
|
builder, builder.icmp(
|
||||||
|
lc.ICMP_EQ if inv else lc.ICMP_NE,
|
||||||
|
self.get_ssa_value(builder),
|
||||||
|
lc.Constant.int(self.get_llvm_type(), 0)))
|
||||||
|
return r
|
||||||
|
|
||||||
|
def o_not(self, builder):
|
||||||
|
return self.o_bool(builder, True)
|
||||||
|
|
||||||
|
def o_intx(self, target_bits, builder):
|
||||||
|
r = VInt(target_bits)
|
||||||
|
if builder is not None:
|
||||||
|
if self.nbits == target_bits:
|
||||||
|
r.set_ssa_value(
|
||||||
|
builder, self.get_ssa_value(builder))
|
||||||
|
if self.nbits > target_bits:
|
||||||
|
r.set_ssa_value(
|
||||||
|
builder, builder.trunc(self.get_ssa_value(builder),
|
||||||
|
r.get_llvm_type()))
|
||||||
|
if self.nbits < target_bits:
|
||||||
|
r.set_ssa_value(
|
||||||
|
builder, builder.sext(self.get_ssa_value(builder),
|
||||||
|
r.get_llvm_type()))
|
||||||
|
return r
|
||||||
|
o_roundx = o_intx
|
||||||
|
|
||||||
|
|
||||||
|
def _make_vint_binop_method(builder_name):
|
||||||
|
def binop_method(self, other, builder):
|
||||||
|
if isinstance(other, VInt):
|
||||||
|
target_bits = max(self.nbits, other.nbits)
|
||||||
|
r = VInt(target_bits)
|
||||||
|
if builder is not None:
|
||||||
|
left = self.o_intx(target_bits, builder)
|
||||||
|
right = other.o_intx(target_bits, builder)
|
||||||
|
bf = getattr(builder, builder_name)
|
||||||
|
r.set_ssa_value(
|
||||||
|
builder, bf(left.get_ssa_value(builder),
|
||||||
|
right.get_ssa_value(builder)))
|
||||||
|
return r
|
||||||
|
else:
|
||||||
|
return NotImplemented
|
||||||
|
return binop_method
|
||||||
|
|
||||||
|
for _method_name, _builder_name in (("o_add", "add"),
|
||||||
|
("o_sub", "sub"),
|
||||||
|
("o_mul", "mul"),
|
||||||
|
("o_floordiv", "sdiv"),
|
||||||
|
("o_mod", "srem"),
|
||||||
|
("o_and", "and_"),
|
||||||
|
("o_xor", "xor"),
|
||||||
|
("o_or", "or_")):
|
||||||
|
setattr(VInt, _method_name, _make_vint_binop_method(_builder_name))
|
||||||
|
|
||||||
|
|
||||||
|
def _make_vint_cmp_method(icmp_val):
|
||||||
|
def cmp_method(self, other, builder):
|
||||||
|
if isinstance(other, VInt):
|
||||||
|
r = VBool()
|
||||||
|
if builder is not None:
|
||||||
|
target_bits = max(self.nbits, other.nbits)
|
||||||
|
left = self.o_intx(target_bits, builder)
|
||||||
|
right = other.o_intx(target_bits, builder)
|
||||||
|
r.set_ssa_value(
|
||||||
|
builder,
|
||||||
|
builder.icmp(
|
||||||
|
icmp_val, left.get_ssa_value(builder),
|
||||||
|
right.get_ssa_value(builder)))
|
||||||
|
return r
|
||||||
|
else:
|
||||||
|
return NotImplemented
|
||||||
|
return cmp_method
|
||||||
|
|
||||||
|
for _method_name, _icmp_val in (("o_eq", lc.ICMP_EQ),
|
||||||
|
("o_ne", lc.ICMP_NE),
|
||||||
|
("o_lt", lc.ICMP_SLT),
|
||||||
|
("o_le", lc.ICMP_SLE),
|
||||||
|
("o_gt", lc.ICMP_SGT),
|
||||||
|
("o_ge", lc.ICMP_SGE)):
|
||||||
|
setattr(VInt, _method_name, _make_vint_cmp_method(_icmp_val))
|
||||||
|
|
||||||
|
|
||||||
|
class VBool(VInt):
|
||||||
|
def __init__(self):
|
||||||
|
VInt.__init__(self, 1)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "<VBool>"
|
||||||
|
|
||||||
|
def same_type(self, other):
|
||||||
|
return isinstance(other, VBool)
|
||||||
|
|
||||||
|
def merge(self, other):
|
||||||
|
if not isinstance(other, VBool):
|
||||||
|
raise TypeError
|
||||||
|
|
||||||
|
def set_const_value(self, builder, b):
|
||||||
|
VInt.set_const_value(self, builder, int(b))
|
||||||
|
|
||||||
|
def o_bool(self, builder):
|
||||||
|
r = VBool()
|
||||||
|
if builder is not None:
|
||||||
|
r.set_ssa_value(builder, self.get_ssa_value(builder))
|
||||||
|
return r
|
|
@ -0,0 +1,192 @@
|
||||||
|
from llvm import core as lc
|
||||||
|
|
||||||
|
from artiq.py2llvm.values import VGeneric
|
||||||
|
from artiq.py2llvm.base_types import VBool, VInt
|
||||||
|
|
||||||
|
|
||||||
|
def _gcd64(builder, a, b):
|
||||||
|
gcd_f = builder.basic_block.function.module.get_function_named("__gcd64")
|
||||||
|
return builder.call(gcd_f, [a, b])
|
||||||
|
|
||||||
|
def init_module(module):
|
||||||
|
func_type = lc.Type.function(
|
||||||
|
lc.Type.int(64), [lc.Type.int(64), lc.Type.int(64)])
|
||||||
|
module.add_function(func_type, "__gcd64")
|
||||||
|
|
||||||
|
|
||||||
|
def _frac_normalize(builder, numerator, denominator):
|
||||||
|
gcd = _gcd64(builder, numerator, denominator)
|
||||||
|
numerator = builder.sdiv(numerator, gcd)
|
||||||
|
denominator = builder.sdiv(denominator, gcd)
|
||||||
|
return numerator, denominator
|
||||||
|
|
||||||
|
|
||||||
|
def _frac_make_ssa(builder, numerator, denominator):
|
||||||
|
value = lc.Constant.undef(lc.Type.vector(lc.Type.int(64), 2))
|
||||||
|
value = builder.insert_element(
|
||||||
|
value, numerator, lc.Constant.int(lc.Type.int(), 0))
|
||||||
|
value = builder.insert_element(
|
||||||
|
value, denominator, lc.Constant.int(lc.Type.int(), 1))
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class VFraction(VGeneric):
|
||||||
|
def get_llvm_type(self):
|
||||||
|
return lc.Type.vector(lc.Type.int(64), 2)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "<VFraction>"
|
||||||
|
|
||||||
|
def same_type(self, other):
|
||||||
|
return isinstance(other, VFraction)
|
||||||
|
|
||||||
|
def merge(self, other):
|
||||||
|
if not isinstance(other, VFraction):
|
||||||
|
raise TypeError
|
||||||
|
|
||||||
|
def _nd(self, builder, invert=False):
|
||||||
|
ssa_value = self.get_ssa_value(builder)
|
||||||
|
numerator = builder.extract_element(
|
||||||
|
ssa_value, lc.Constant.int(lc.Type.int(), 0))
|
||||||
|
denominator = builder.extract_element(
|
||||||
|
ssa_value, lc.Constant.int(lc.Type.int(), 1))
|
||||||
|
if invert:
|
||||||
|
return denominator, numerator
|
||||||
|
else:
|
||||||
|
return numerator, denominator
|
||||||
|
|
||||||
|
def set_value_nd(self, builder, numerator, denominator):
|
||||||
|
numerator = numerator.o_int64(builder).get_ssa_value(builder)
|
||||||
|
denominator = denominator.o_int64(builder).get_ssa_value(builder)
|
||||||
|
numerator, denominator = _frac_normalize(
|
||||||
|
builder, numerator, denominator)
|
||||||
|
self.set_ssa_value(
|
||||||
|
builder, _frac_make_ssa(builder, numerator, denominator))
|
||||||
|
|
||||||
|
def set_value(self, builder, n):
|
||||||
|
if not isinstance(n, VFraction):
|
||||||
|
raise TypeError
|
||||||
|
self.set_ssa_value(builder, n.get_ssa_value(builder))
|
||||||
|
|
||||||
|
def o_bool(self, builder):
|
||||||
|
r = VBool()
|
||||||
|
if builder is not None:
|
||||||
|
zero = lc.Constant.int(lc.Type.int(64), 0)
|
||||||
|
numerator = builder.extract_element(
|
||||||
|
self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 0))
|
||||||
|
r.set_ssa_value(builder, builder.icmp(lc.ICMP_NE, numerator, zero))
|
||||||
|
return r
|
||||||
|
|
||||||
|
def o_intx(self, target_bits, builder):
|
||||||
|
if builder is None:
|
||||||
|
return VInt(target_bits)
|
||||||
|
else:
|
||||||
|
r = VInt(64)
|
||||||
|
numerator, denominator = self._nd(builder)
|
||||||
|
r.set_ssa_value(builder, builder.sdiv(numerator, denominator))
|
||||||
|
return r.o_intx(target_bits, builder)
|
||||||
|
|
||||||
|
def o_roundx(self, target_bits, builder):
|
||||||
|
if builder is None:
|
||||||
|
return VInt(target_bits)
|
||||||
|
else:
|
||||||
|
r = VInt(64)
|
||||||
|
numerator, denominator = self._nd(builder)
|
||||||
|
h_denominator = builder.ashr(denominator,
|
||||||
|
lc.Constant.int(lc.Type.int(), 1))
|
||||||
|
r_numerator = builder.add(numerator, h_denominator)
|
||||||
|
r.set_ssa_value(builder, builder.sdiv(r_numerator, denominator))
|
||||||
|
return r.o_intx(target_bits, builder)
|
||||||
|
|
||||||
|
def _o_eq_inv(self, other, builder, ne):
|
||||||
|
if isinstance(other, VFraction):
|
||||||
|
r = VBool()
|
||||||
|
if builder is not None:
|
||||||
|
ee = []
|
||||||
|
for i in range(2):
|
||||||
|
es = builder.extract_element(
|
||||||
|
self.get_ssa_value(builder),
|
||||||
|
lc.Constant.int(lc.Type.int(), i))
|
||||||
|
eo = builder.extract_element(
|
||||||
|
other.get_ssa_value(builder),
|
||||||
|
lc.Constant.int(lc.Type.int(), i))
|
||||||
|
ee.append(builder.icmp(lc.ICMP_EQ, es, eo))
|
||||||
|
ssa_r = builder.and_(ee[0], ee[1])
|
||||||
|
if ne:
|
||||||
|
ssa_r = builder.xor(ssa_r,
|
||||||
|
lc.Constant.int(lc.Type.int(1), 1))
|
||||||
|
r.set_ssa_value(builder, ssa_r)
|
||||||
|
return r
|
||||||
|
else:
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
def o_eq(self, other, builder):
|
||||||
|
return self._o_eq_inv(other, builder, False)
|
||||||
|
|
||||||
|
def o_ne(self, other, builder):
|
||||||
|
return self._o_eq_inv(other, builder, True)
|
||||||
|
|
||||||
|
def _o_muldiv(self, other, builder, div, invert=False):
|
||||||
|
r = VFraction()
|
||||||
|
if isinstance(other, VInt):
|
||||||
|
if builder is None:
|
||||||
|
return r
|
||||||
|
else:
|
||||||
|
numerator, denominator = self._nd(builder, invert)
|
||||||
|
i = other.get_ssa_value(builder)
|
||||||
|
if div:
|
||||||
|
gcd = _gcd64(i, numerator)
|
||||||
|
i = builder.sdiv(i, gcd)
|
||||||
|
numerator = builder.sdiv(numerator, gcd)
|
||||||
|
denominator = builder.mul(denominator, i)
|
||||||
|
else:
|
||||||
|
gcd = _gcd64(i, denominator)
|
||||||
|
i = builder.sdiv(i, gcd)
|
||||||
|
denominator = builder.sdiv(denominator, gcd)
|
||||||
|
numerator = builder.mul(numerator, i)
|
||||||
|
self.set_ssa_value(builder, _frac_make_ssa(builder, numerator,
|
||||||
|
denominator))
|
||||||
|
elif isinstance(other, VFraction):
|
||||||
|
if builder is None:
|
||||||
|
return r
|
||||||
|
else:
|
||||||
|
numerator, denominator = self._nd(builder, invert)
|
||||||
|
onumerator, odenominator = other._nd(builder)
|
||||||
|
if div:
|
||||||
|
numerator = builder.mul(numerator, odenominator)
|
||||||
|
denominator = builder.mul(denominator, onumerator)
|
||||||
|
else:
|
||||||
|
numerator = builder.mul(numerator, onumerator)
|
||||||
|
denominator = builder.mul(denominator, odenominator)
|
||||||
|
numerator, denominator = _frac_normalize(builder, numerator,
|
||||||
|
denominator)
|
||||||
|
self.set_ssa_value(
|
||||||
|
builder, _frac_make_ssa(builder, numerator, denominator))
|
||||||
|
else:
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
def o_mul(self, other, builder):
|
||||||
|
return self._o_muldiv(other, builder, False)
|
||||||
|
|
||||||
|
def o_truediv(self, other, builder):
|
||||||
|
return self._o_muldiv(other, builder, True)
|
||||||
|
|
||||||
|
def or_mul(self, other, builder):
|
||||||
|
return self._o_muldiv(other, builder, False)
|
||||||
|
|
||||||
|
def or_truediv(self, other, builder):
|
||||||
|
return self._o_muldiv(other, builder, False, True)
|
||||||
|
|
||||||
|
def o_floordiv(self, other, builder):
|
||||||
|
r = self.o_truediv(other, builder)
|
||||||
|
if r is NotImplemented:
|
||||||
|
return r
|
||||||
|
else:
|
||||||
|
return r.o_int(builder)
|
||||||
|
|
||||||
|
def or_floordiv(self, other, builder):
|
||||||
|
r = self.or_truediv(other, builder)
|
||||||
|
if r is NotImplemented:
|
||||||
|
return r
|
||||||
|
else:
|
||||||
|
return r.o_int(builder)
|
|
@ -1,31 +0,0 @@
|
||||||
from llvm import core as lc
|
|
||||||
|
|
||||||
from artiq.py2llvm import infer_types, ast_body, values, tools
|
|
||||||
|
|
||||||
def compile_function(module, env, funcdef, param_types):
|
|
||||||
ns = infer_types.infer_function_types(env, funcdef, param_types)
|
|
||||||
retval = ns["return"]
|
|
||||||
|
|
||||||
function_type = lc.Type.function(retval.get_llvm_type(),
|
|
||||||
[ns[arg.arg].get_llvm_type() for arg in funcdef.args.args])
|
|
||||||
function = module.add_function(function_type, funcdef.name)
|
|
||||||
bb = function.append_basic_block("entry")
|
|
||||||
builder = lc.Builder.new(bb)
|
|
||||||
|
|
||||||
for arg_ast, arg_llvm in zip(funcdef.args.args, function.args):
|
|
||||||
arg_llvm.name = arg_ast.arg
|
|
||||||
for k, v in ns.items():
|
|
||||||
v.alloca(builder, k)
|
|
||||||
for arg_ast, arg_llvm in zip(funcdef.args.args, function.args):
|
|
||||||
ns[arg_ast.arg].set_ssa_value(builder, arg_llvm)
|
|
||||||
|
|
||||||
visitor = ast_body.Visitor(env, ns, builder)
|
|
||||||
visitor.visit_statements(funcdef.body)
|
|
||||||
|
|
||||||
if not tools.is_terminated(builder.basic_block):
|
|
||||||
if isinstance(retval, values.VNone):
|
|
||||||
builder.ret_void()
|
|
||||||
else:
|
|
||||||
builder.ret(retval.get_ssa_value(builder))
|
|
||||||
|
|
||||||
return function, retval
|
|
|
@ -2,7 +2,7 @@ import ast
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
from artiq.py2llvm.ast_body import Visitor
|
from artiq.py2llvm.ast_body import Visitor
|
||||||
from artiq.py2llvm import values
|
from artiq.py2llvm import base_types
|
||||||
|
|
||||||
|
|
||||||
class _TypeScanner(ast.NodeVisitor):
|
class _TypeScanner(ast.NodeVisitor):
|
||||||
|
@ -31,7 +31,7 @@ class _TypeScanner(ast.NodeVisitor):
|
||||||
|
|
||||||
def visit_Return(self, node):
|
def visit_Return(self, node):
|
||||||
if node.value is None:
|
if node.value is None:
|
||||||
val = values.VNone()
|
val = base_types.VNone()
|
||||||
else:
|
else:
|
||||||
val = self.exprv.visit_expression(node.value)
|
val = self.exprv.visit_expression(node.value)
|
||||||
ns = self.exprv.ns
|
ns = self.exprv.ns
|
||||||
|
@ -51,5 +51,5 @@ def infer_function_types(env, node, param_types):
|
||||||
if all(v.same_type(prev_ns[k]) for k, v in ns.items()):
|
if all(v.same_type(prev_ns[k]) for k, v in ns.items()):
|
||||||
# no more promotions - completed
|
# no more promotions - completed
|
||||||
if "return" not in ns:
|
if "return" not in ns:
|
||||||
ns["return"] = values.VNone()
|
ns["return"] = base_types.VNone()
|
||||||
return ns
|
return ns
|
||||||
|
|
|
@ -0,0 +1,59 @@
|
||||||
|
from llvm import core as lc
|
||||||
|
from llvm import passes as lp
|
||||||
|
from llvm import ee as le
|
||||||
|
|
||||||
|
from artiq.py2llvm import infer_types, ast_body, base_types, fractions, tools
|
||||||
|
|
||||||
|
|
||||||
|
class Module:
|
||||||
|
def __init__(self, env=None):
|
||||||
|
self.llvm_module = lc.Module.new("main")
|
||||||
|
self.env = env
|
||||||
|
|
||||||
|
if self.env is not None:
|
||||||
|
self.env.init_module(self.llvm_module)
|
||||||
|
fractions.init_module(self.llvm_module)
|
||||||
|
|
||||||
|
def finalize(self):
|
||||||
|
pass_manager = lp.PassManager.new()
|
||||||
|
pass_manager.add(lp.PASS_MEM2REG)
|
||||||
|
pass_manager.add(lp.PASS_INSTCOMBINE)
|
||||||
|
pass_manager.add(lp.PASS_REASSOCIATE)
|
||||||
|
pass_manager.add(lp.PASS_GVN)
|
||||||
|
pass_manager.add(lp.PASS_SIMPLIFYCFG)
|
||||||
|
pass_manager.run(self.llvm_module)
|
||||||
|
|
||||||
|
def get_ee(self):
|
||||||
|
return le.ExecutionEngine.new(self.llvm_module)
|
||||||
|
|
||||||
|
def emit_object(self):
|
||||||
|
self.finalize()
|
||||||
|
return self.env.emit_object()
|
||||||
|
|
||||||
|
def compile_function(self, funcdef, param_types):
|
||||||
|
ns = infer_types.infer_function_types(self.env, funcdef, param_types)
|
||||||
|
retval = ns["return"]
|
||||||
|
|
||||||
|
function_type = lc.Type.function(retval.get_llvm_type(),
|
||||||
|
[ns[arg.arg].get_llvm_type() for arg in funcdef.args.args])
|
||||||
|
function = self.llvm_module.add_function(function_type, funcdef.name)
|
||||||
|
bb = function.append_basic_block("entry")
|
||||||
|
builder = lc.Builder.new(bb)
|
||||||
|
|
||||||
|
for arg_ast, arg_llvm in zip(funcdef.args.args, function.args):
|
||||||
|
arg_llvm.name = arg_ast.arg
|
||||||
|
for k, v in ns.items():
|
||||||
|
v.alloca(builder, k)
|
||||||
|
for arg_ast, arg_llvm in zip(funcdef.args.args, function.args):
|
||||||
|
ns[arg_ast.arg].set_ssa_value(builder, arg_llvm)
|
||||||
|
|
||||||
|
visitor = ast_body.Visitor(self.env, ns, builder)
|
||||||
|
visitor.visit_statements(funcdef.body)
|
||||||
|
|
||||||
|
if not tools.is_terminated(builder.basic_block):
|
||||||
|
if isinstance(retval, base_types.VNone):
|
||||||
|
builder.ret_void()
|
||||||
|
else:
|
||||||
|
builder.ret(retval.get_ssa_value(builder))
|
||||||
|
|
||||||
|
return function, retval
|
|
@ -1,11 +1,2 @@
|
||||||
from llvm import passes as lp
|
|
||||||
|
|
||||||
def is_terminated(basic_block):
|
def is_terminated(basic_block):
|
||||||
return basic_block.instructions and basic_block.instructions[-1].is_terminator
|
return basic_block.instructions and basic_block.instructions[-1].is_terminator
|
||||||
|
|
||||||
def add_common_passes(pass_manager):
|
|
||||||
pass_manager.add(lp.PASS_MEM2REG)
|
|
||||||
pass_manager.add(lp.PASS_INSTCOMBINE)
|
|
||||||
pass_manager.add(lp.PASS_REASSOCIATE)
|
|
||||||
pass_manager.add(lp.PASS_GVN)
|
|
||||||
pass_manager.add(lp.PASS_SIMPLIFYCFG)
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ from types import SimpleNamespace
|
||||||
from llvm import core as lc
|
from llvm import core as lc
|
||||||
|
|
||||||
|
|
||||||
class _Value:
|
class VGeneric:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._llvm_value = None
|
self._llvm_value = None
|
||||||
|
|
||||||
|
@ -40,361 +40,6 @@ class _Value:
|
||||||
return self.o_roundx(64, builder)
|
return self.o_roundx(64, builder)
|
||||||
|
|
||||||
|
|
||||||
# None type
|
|
||||||
|
|
||||||
class VNone(_Value):
|
|
||||||
def __repr__(self):
|
|
||||||
return "<VNone>"
|
|
||||||
|
|
||||||
def get_llvm_type(self):
|
|
||||||
return lc.Type.void()
|
|
||||||
|
|
||||||
def same_type(self, other):
|
|
||||||
return isinstance(other, VNone)
|
|
||||||
|
|
||||||
def merge(self, other):
|
|
||||||
if not isinstance(other, VNone):
|
|
||||||
raise TypeError
|
|
||||||
|
|
||||||
def alloca(self, builder, name):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def o_bool(self, builder):
|
|
||||||
r = VBool()
|
|
||||||
if builder is not None:
|
|
||||||
r.set_const_value(builder, False)
|
|
||||||
return r
|
|
||||||
|
|
||||||
|
|
||||||
# Integer type
|
|
||||||
|
|
||||||
class VInt(_Value):
|
|
||||||
def __init__(self, nbits=32):
|
|
||||||
_Value.__init__(self)
|
|
||||||
self.nbits = nbits
|
|
||||||
|
|
||||||
def get_llvm_type(self):
|
|
||||||
return lc.Type.int(self.nbits)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return "<VInt:{}>".format(self.nbits)
|
|
||||||
|
|
||||||
def same_type(self, other):
|
|
||||||
return isinstance(other, VInt) and other.nbits == self.nbits
|
|
||||||
|
|
||||||
def merge(self, other):
|
|
||||||
if isinstance(other, VInt) and not isinstance(other, VBool):
|
|
||||||
if other.nbits > self.nbits:
|
|
||||||
self.nbits = other.nbits
|
|
||||||
else:
|
|
||||||
raise TypeError
|
|
||||||
|
|
||||||
def set_value(self, builder, n):
|
|
||||||
self.set_ssa_value(
|
|
||||||
builder, n.o_intx(self.nbits, builder).get_ssa_value(builder))
|
|
||||||
|
|
||||||
def set_const_value(self, builder, n):
|
|
||||||
self.set_ssa_value(builder, lc.Constant.int(self.get_llvm_type(), n))
|
|
||||||
|
|
||||||
def o_bool(self, builder, inv=False):
|
|
||||||
r = VBool()
|
|
||||||
if builder is not None:
|
|
||||||
r.set_ssa_value(
|
|
||||||
builder, builder.icmp(
|
|
||||||
lc.ICMP_EQ if inv else lc.ICMP_NE,
|
|
||||||
self.get_ssa_value(builder),
|
|
||||||
lc.Constant.int(self.get_llvm_type(), 0)))
|
|
||||||
return r
|
|
||||||
|
|
||||||
def o_not(self, builder):
|
|
||||||
return self.o_bool(builder, True)
|
|
||||||
|
|
||||||
def o_intx(self, target_bits, builder):
|
|
||||||
r = VInt(target_bits)
|
|
||||||
if builder is not None:
|
|
||||||
if self.nbits == target_bits:
|
|
||||||
r.set_ssa_value(
|
|
||||||
builder, self.get_ssa_value(builder))
|
|
||||||
if self.nbits > target_bits:
|
|
||||||
r.set_ssa_value(
|
|
||||||
builder, builder.trunc(self.get_ssa_value(builder),
|
|
||||||
r.get_llvm_type()))
|
|
||||||
if self.nbits < target_bits:
|
|
||||||
r.set_ssa_value(
|
|
||||||
builder, builder.sext(self.get_ssa_value(builder),
|
|
||||||
r.get_llvm_type()))
|
|
||||||
return r
|
|
||||||
o_roundx = o_intx
|
|
||||||
|
|
||||||
|
|
||||||
def _make_vint_binop_method(builder_name):
|
|
||||||
def binop_method(self, other, builder):
|
|
||||||
if isinstance(other, VInt):
|
|
||||||
target_bits = max(self.nbits, other.nbits)
|
|
||||||
r = VInt(target_bits)
|
|
||||||
if builder is not None:
|
|
||||||
left = self.o_intx(target_bits, builder)
|
|
||||||
right = other.o_intx(target_bits, builder)
|
|
||||||
bf = getattr(builder, builder_name)
|
|
||||||
r.set_ssa_value(
|
|
||||||
builder, bf(left.get_ssa_value(builder),
|
|
||||||
right.get_ssa_value(builder)))
|
|
||||||
return r
|
|
||||||
else:
|
|
||||||
return NotImplemented
|
|
||||||
return binop_method
|
|
||||||
|
|
||||||
for _method_name, _builder_name in (("o_add", "add"),
|
|
||||||
("o_sub", "sub"),
|
|
||||||
("o_mul", "mul"),
|
|
||||||
("o_floordiv", "sdiv"),
|
|
||||||
("o_mod", "srem"),
|
|
||||||
("o_and", "and_"),
|
|
||||||
("o_xor", "xor"),
|
|
||||||
("o_or", "or_")):
|
|
||||||
setattr(VInt, _method_name, _make_vint_binop_method(_builder_name))
|
|
||||||
|
|
||||||
|
|
||||||
def _make_vint_cmp_method(icmp_val):
|
|
||||||
def cmp_method(self, other, builder):
|
|
||||||
if isinstance(other, VInt):
|
|
||||||
r = VBool()
|
|
||||||
if builder is not None:
|
|
||||||
target_bits = max(self.nbits, other.nbits)
|
|
||||||
left = self.o_intx(target_bits, builder)
|
|
||||||
right = other.o_intx(target_bits, builder)
|
|
||||||
r.set_ssa_value(
|
|
||||||
builder,
|
|
||||||
builder.icmp(
|
|
||||||
icmp_val, left.get_ssa_value(builder),
|
|
||||||
right.get_ssa_value(builder)))
|
|
||||||
return r
|
|
||||||
else:
|
|
||||||
return NotImplemented
|
|
||||||
return cmp_method
|
|
||||||
|
|
||||||
for _method_name, _icmp_val in (("o_eq", lc.ICMP_EQ),
|
|
||||||
("o_ne", lc.ICMP_NE),
|
|
||||||
("o_lt", lc.ICMP_SLT),
|
|
||||||
("o_le", lc.ICMP_SLE),
|
|
||||||
("o_gt", lc.ICMP_SGT),
|
|
||||||
("o_ge", lc.ICMP_SGE)):
|
|
||||||
setattr(VInt, _method_name, _make_vint_cmp_method(_icmp_val))
|
|
||||||
|
|
||||||
|
|
||||||
# Boolean type
|
|
||||||
|
|
||||||
class VBool(VInt):
|
|
||||||
def __init__(self):
|
|
||||||
VInt.__init__(self, 1)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return "<VBool>"
|
|
||||||
|
|
||||||
def same_type(self, other):
|
|
||||||
return isinstance(other, VBool)
|
|
||||||
|
|
||||||
def merge(self, other):
|
|
||||||
if not isinstance(other, VBool):
|
|
||||||
raise TypeError
|
|
||||||
|
|
||||||
def set_const_value(self, builder, b):
|
|
||||||
VInt.set_const_value(self, builder, int(b))
|
|
||||||
|
|
||||||
def o_bool(self, builder):
|
|
||||||
r = VBool()
|
|
||||||
if builder is not None:
|
|
||||||
r.set_ssa_value(builder, self.get_ssa_value(builder))
|
|
||||||
return r
|
|
||||||
|
|
||||||
|
|
||||||
# Fraction type
|
|
||||||
|
|
||||||
def _gcd64(builder, a, b):
|
|
||||||
gcd_f = builder.basic_block.function.module.get_function_named("__gcd64")
|
|
||||||
return builder.call(gcd_f, [a, b])
|
|
||||||
|
|
||||||
|
|
||||||
def _frac_normalize(builder, numerator, denominator):
|
|
||||||
gcd = _gcd64(builder, numerator, denominator)
|
|
||||||
numerator = builder.sdiv(numerator, gcd)
|
|
||||||
denominator = builder.sdiv(denominator, gcd)
|
|
||||||
return numerator, denominator
|
|
||||||
|
|
||||||
|
|
||||||
def _frac_make_ssa(builder, numerator, denominator):
|
|
||||||
value = lc.Constant.undef(lc.Type.vector(lc.Type.int(64), 2))
|
|
||||||
value = builder.insert_element(
|
|
||||||
value, numerator, lc.Constant.int(lc.Type.int(), 0))
|
|
||||||
value = builder.insert_element(
|
|
||||||
value, denominator, lc.Constant.int(lc.Type.int(), 1))
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
class VFraction(_Value):
|
|
||||||
def get_llvm_type(self):
|
|
||||||
return lc.Type.vector(lc.Type.int(64), 2)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return "<VFraction>"
|
|
||||||
|
|
||||||
def same_type(self, other):
|
|
||||||
return isinstance(other, VFraction)
|
|
||||||
|
|
||||||
def merge(self, other):
|
|
||||||
if not isinstance(other, VFraction):
|
|
||||||
raise TypeError
|
|
||||||
|
|
||||||
def _nd(self, builder, invert=False):
|
|
||||||
ssa_value = self.get_ssa_value(builder)
|
|
||||||
numerator = builder.extract_element(
|
|
||||||
ssa_value, lc.Constant.int(lc.Type.int(), 0))
|
|
||||||
denominator = builder.extract_element(
|
|
||||||
ssa_value, lc.Constant.int(lc.Type.int(), 1))
|
|
||||||
if invert:
|
|
||||||
return denominator, numerator
|
|
||||||
else:
|
|
||||||
return numerator, denominator
|
|
||||||
|
|
||||||
def set_value_nd(self, builder, numerator, denominator):
|
|
||||||
numerator = numerator.o_int64(builder).get_ssa_value(builder)
|
|
||||||
denominator = denominator.o_int64(builder).get_ssa_value(builder)
|
|
||||||
numerator, denominator = _frac_normalize(
|
|
||||||
builder, numerator, denominator)
|
|
||||||
self.set_ssa_value(
|
|
||||||
builder, _frac_make_ssa(builder, numerator, denominator))
|
|
||||||
|
|
||||||
def set_value(self, builder, n):
|
|
||||||
if not isinstance(n, VFraction):
|
|
||||||
raise TypeError
|
|
||||||
self.set_ssa_value(builder, n.get_ssa_value(builder))
|
|
||||||
|
|
||||||
def o_bool(self, builder):
|
|
||||||
r = VBool()
|
|
||||||
if builder is not None:
|
|
||||||
zero = lc.Constant.int(lc.Type.int(64), 0)
|
|
||||||
numerator = builder.extract_element(
|
|
||||||
self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 0))
|
|
||||||
r.set_ssa_value(builder, builder.icmp(lc.ICMP_NE, numerator, zero))
|
|
||||||
return r
|
|
||||||
|
|
||||||
def o_intx(self, target_bits, builder):
|
|
||||||
if builder is None:
|
|
||||||
return VInt(target_bits)
|
|
||||||
else:
|
|
||||||
r = VInt(64)
|
|
||||||
numerator, denominator = self._nd(builder)
|
|
||||||
r.set_ssa_value(builder, builder.sdiv(numerator, denominator))
|
|
||||||
return r.o_intx(target_bits, builder)
|
|
||||||
|
|
||||||
def o_roundx(self, target_bits, builder):
|
|
||||||
if builder is None:
|
|
||||||
return VInt(target_bits)
|
|
||||||
else:
|
|
||||||
r = VInt(64)
|
|
||||||
numerator, denominator = self._nd(builder)
|
|
||||||
h_denominator = builder.ashr(denominator,
|
|
||||||
lc.Constant.int(lc.Type.int(), 1))
|
|
||||||
r_numerator = builder.add(numerator, h_denominator)
|
|
||||||
r.set_ssa_value(builder, builder.sdiv(r_numerator, denominator))
|
|
||||||
return r.o_intx(target_bits, builder)
|
|
||||||
|
|
||||||
def _o_eq_inv(self, other, builder, ne):
|
|
||||||
if isinstance(other, VFraction):
|
|
||||||
r = VBool()
|
|
||||||
if builder is not None:
|
|
||||||
ee = []
|
|
||||||
for i in range(2):
|
|
||||||
es = builder.extract_element(
|
|
||||||
self.get_ssa_value(builder),
|
|
||||||
lc.Constant.int(lc.Type.int(), i))
|
|
||||||
eo = builder.extract_element(
|
|
||||||
other.get_ssa_value(builder),
|
|
||||||
lc.Constant.int(lc.Type.int(), i))
|
|
||||||
ee.append(builder.icmp(lc.ICMP_EQ, es, eo))
|
|
||||||
ssa_r = builder.and_(ee[0], ee[1])
|
|
||||||
if ne:
|
|
||||||
ssa_r = builder.xor(ssa_r,
|
|
||||||
lc.Constant.int(lc.Type.int(1), 1))
|
|
||||||
r.set_ssa_value(builder, ssa_r)
|
|
||||||
return r
|
|
||||||
else:
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
def o_eq(self, other, builder):
|
|
||||||
return self._o_eq_inv(other, builder, False)
|
|
||||||
|
|
||||||
def o_ne(self, other, builder):
|
|
||||||
return self._o_eq_inv(other, builder, True)
|
|
||||||
|
|
||||||
def _o_muldiv(self, other, builder, div, invert=False):
|
|
||||||
r = VFraction()
|
|
||||||
if isinstance(other, VInt):
|
|
||||||
if builder is None:
|
|
||||||
return r
|
|
||||||
else:
|
|
||||||
numerator, denominator = self._nd(builder, invert)
|
|
||||||
i = other.get_ssa_value(builder)
|
|
||||||
if div:
|
|
||||||
gcd = _gcd64(i, numerator)
|
|
||||||
i = builder.sdiv(i, gcd)
|
|
||||||
numerator = builder.sdiv(numerator, gcd)
|
|
||||||
denominator = builder.mul(denominator, i)
|
|
||||||
else:
|
|
||||||
gcd = _gcd64(i, denominator)
|
|
||||||
i = builder.sdiv(i, gcd)
|
|
||||||
denominator = builder.sdiv(denominator, gcd)
|
|
||||||
numerator = builder.mul(numerator, i)
|
|
||||||
self.set_ssa_value(builder, _frac_make_ssa(builder, numerator,
|
|
||||||
denominator))
|
|
||||||
elif isinstance(other, VFraction):
|
|
||||||
if builder is None:
|
|
||||||
return r
|
|
||||||
else:
|
|
||||||
numerator, denominator = self._nd(builder, invert)
|
|
||||||
onumerator, odenominator = other._nd(builder)
|
|
||||||
if div:
|
|
||||||
numerator = builder.mul(numerator, odenominator)
|
|
||||||
denominator = builder.mul(denominator, onumerator)
|
|
||||||
else:
|
|
||||||
numerator = builder.mul(numerator, onumerator)
|
|
||||||
denominator = builder.mul(denominator, odenominator)
|
|
||||||
numerator, denominator = _frac_normalize(builder, numerator,
|
|
||||||
denominator)
|
|
||||||
self.set_ssa_value(
|
|
||||||
builder, _frac_make_ssa(builder, numerator, denominator))
|
|
||||||
else:
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
def o_mul(self, other, builder):
|
|
||||||
return self._o_muldiv(other, builder, False)
|
|
||||||
|
|
||||||
def o_truediv(self, other, builder):
|
|
||||||
return self._o_muldiv(other, builder, True)
|
|
||||||
|
|
||||||
def or_mul(self, other, builder):
|
|
||||||
return self._o_muldiv(other, builder, False)
|
|
||||||
|
|
||||||
def or_truediv(self, other, builder):
|
|
||||||
return self._o_muldiv(other, builder, False, True)
|
|
||||||
|
|
||||||
def o_floordiv(self, other, builder):
|
|
||||||
r = self.o_truediv(other, builder)
|
|
||||||
if r is NotImplemented:
|
|
||||||
return r
|
|
||||||
else:
|
|
||||||
return r.o_int(builder)
|
|
||||||
|
|
||||||
def or_floordiv(self, other, builder):
|
|
||||||
r = self.or_truediv(other, builder)
|
|
||||||
if r is NotImplemented:
|
|
||||||
return r
|
|
||||||
else:
|
|
||||||
return r.o_int(builder)
|
|
||||||
|
|
||||||
|
|
||||||
# Operators
|
|
||||||
|
|
||||||
def _make_unary_operator(op_name):
|
def _make_unary_operator(op_name):
|
||||||
def op(x, builder):
|
def op(x, builder):
|
||||||
try:
|
try:
|
||||||
|
@ -446,9 +91,3 @@ def _make_operators():
|
||||||
return SimpleNamespace(**d)
|
return SimpleNamespace(**d)
|
||||||
|
|
||||||
operators = _make_operators()
|
operators = _make_operators()
|
||||||
|
|
||||||
|
|
||||||
def init_module(module):
|
|
||||||
func_type = lc.Type.function(
|
|
||||||
lc.Type.int(64), [lc.Type.int(64), lc.Type.int(64)])
|
|
||||||
module.add_function(func_type, "__gcd64")
|
|
||||||
|
|
|
@ -2,15 +2,12 @@ import unittest
|
||||||
import ast
|
import ast
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
from llvm import core as lc
|
|
||||||
from llvm import passes as lp
|
|
||||||
from llvm import ee as le
|
from llvm import ee as le
|
||||||
|
|
||||||
from artiq.language.core import int64
|
from artiq.language.core import int64
|
||||||
from artiq.py2llvm.infer_types import infer_function_types
|
from artiq.py2llvm.infer_types import infer_function_types
|
||||||
from artiq.py2llvm import values
|
from artiq.py2llvm import base_types
|
||||||
from artiq.py2llvm import compile_function
|
from artiq.py2llvm.module import Module
|
||||||
from artiq.py2llvm.tools import add_common_passes
|
|
||||||
|
|
||||||
|
|
||||||
def test_types(choice):
|
def test_types(choice):
|
||||||
|
@ -35,46 +32,40 @@ class FunctionTypesCase(unittest.TestCase):
|
||||||
dict())
|
dict())
|
||||||
|
|
||||||
def test_base_types(self):
|
def test_base_types(self):
|
||||||
self.assertIsInstance(self.ns["foo"], values.VBool)
|
self.assertIsInstance(self.ns["foo"], base_types.VBool)
|
||||||
self.assertIsInstance(self.ns["bar"], values.VNone)
|
self.assertIsInstance(self.ns["bar"], base_types.VNone)
|
||||||
self.assertIsInstance(self.ns["d"], values.VInt)
|
self.assertIsInstance(self.ns["d"], base_types.VInt)
|
||||||
self.assertEqual(self.ns["d"].nbits, 32)
|
self.assertEqual(self.ns["d"].nbits, 32)
|
||||||
self.assertIsInstance(self.ns["x"], values.VInt)
|
self.assertIsInstance(self.ns["x"], base_types.VInt)
|
||||||
self.assertEqual(self.ns["x"].nbits, 64)
|
self.assertEqual(self.ns["x"].nbits, 64)
|
||||||
|
|
||||||
def test_promotion(self):
|
def test_promotion(self):
|
||||||
for v in "abc":
|
for v in "abc":
|
||||||
self.assertIsInstance(self.ns[v], values.VInt)
|
self.assertIsInstance(self.ns[v], base_types.VInt)
|
||||||
self.assertEqual(self.ns[v].nbits, 64)
|
self.assertEqual(self.ns[v].nbits, 64)
|
||||||
|
|
||||||
def test_return(self):
|
def test_return(self):
|
||||||
self.assertIsInstance(self.ns["return"], values.VInt)
|
self.assertIsInstance(self.ns["return"], base_types.VInt)
|
||||||
self.assertEqual(self.ns["return"].nbits, 64)
|
self.assertEqual(self.ns["return"].nbits, 64)
|
||||||
|
|
||||||
|
|
||||||
class CompiledFunction:
|
class CompiledFunction:
|
||||||
def __init__(self, function, param_types):
|
def __init__(self, function, param_types):
|
||||||
module = lc.Module.new("main")
|
module = Module()
|
||||||
values.init_module(module)
|
|
||||||
|
|
||||||
funcdef = ast.parse(inspect.getsource(function)).body[0]
|
funcdef = ast.parse(inspect.getsource(function)).body[0]
|
||||||
self.function, self.retval = compile_function(
|
self.function, self.retval = module.compile_function(
|
||||||
module, None, funcdef, param_types)
|
funcdef, param_types)
|
||||||
self.argval = [param_types[arg.arg] for arg in funcdef.args.args]
|
self.argval = [param_types[arg.arg] for arg in funcdef.args.args]
|
||||||
|
self.ee = module.get_ee()
|
||||||
self.executor = le.ExecutionEngine.new(module)
|
|
||||||
pass_manager = lp.PassManager.new()
|
|
||||||
add_common_passes(pass_manager)
|
|
||||||
pass_manager.run(module)
|
|
||||||
|
|
||||||
def __call__(self, *args):
|
def __call__(self, *args):
|
||||||
args_llvm = [
|
args_llvm = [
|
||||||
le.GenericValue.int(av.get_llvm_type(), a)
|
le.GenericValue.int(av.get_llvm_type(), a)
|
||||||
for av, a in zip(self.argval, args)]
|
for av, a in zip(self.argval, args)]
|
||||||
result = self.executor.run_function(self.function, args_llvm)
|
result = self.ee.run_function(self.function, args_llvm)
|
||||||
if isinstance(self.retval, values.VBool):
|
if isinstance(self.retval, base_types.VBool):
|
||||||
return bool(result.as_int())
|
return bool(result.as_int())
|
||||||
elif isinstance(self.retval, values.VInt):
|
elif isinstance(self.retval, base_types.VInt):
|
||||||
return result.as_int_signed()
|
return result.as_int_signed()
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -90,6 +81,6 @@ def is_prime(x):
|
||||||
|
|
||||||
class CodeGenCase(unittest.TestCase):
|
class CodeGenCase(unittest.TestCase):
|
||||||
def test_is_prime(self):
|
def test_is_prime(self):
|
||||||
is_prime_c = CompiledFunction(is_prime, {"x": values.VInt(32)})
|
is_prime_c = CompiledFunction(is_prime, {"x": base_types.VInt(32)})
|
||||||
for i in range(200):
|
for i in range(200):
|
||||||
self.assertEqual(is_prime_c(i), is_prime(i))
|
self.assertEqual(is_prime_c(i), is_prime(i))
|
||||||
|
|
Loading…
Reference in New Issue