mirror of https://github.com/m-labs/artiq.git
py2llvm: reorganize, split 'values' module, factor LLVM module/pass management
This commit is contained in:
parent
58465e49fa
commit
3c8b541939
|
@ -1,7 +1,7 @@
|
|||
from llvm import core as lc
|
||||
from llvm import target as lt
|
||||
|
||||
from artiq.py2llvm import values
|
||||
from artiq.py2llvm import base_types
|
||||
|
||||
|
||||
lt.initialize_all()
|
||||
|
@ -21,9 +21,9 @@ _chr_to_type = {
|
|||
}
|
||||
|
||||
_chr_to_value = {
|
||||
"n": lambda: values.VNone(),
|
||||
"i": lambda: values.VInt(),
|
||||
"I": lambda: values.VInt(64)
|
||||
"n": lambda: base_types.VNone(),
|
||||
"i": lambda: base_types.VInt(),
|
||||
"I": lambda: base_types.VInt(64)
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -1,20 +1,6 @@
|
|||
from llvm import core as lc
|
||||
from llvm import passes as lp
|
||||
|
||||
from artiq.py2llvm import values
|
||||
from artiq.py2llvm.functions import compile_function
|
||||
from artiq.py2llvm.tools import add_common_passes
|
||||
|
||||
from artiq.py2llvm.module import Module
|
||||
|
||||
def get_runtime_binary(env, funcdef):
|
||||
module = lc.Module.new("main")
|
||||
env.init_module(module)
|
||||
values.init_module(module)
|
||||
|
||||
compile_function(module, env, funcdef, dict())
|
||||
|
||||
pass_manager = lp.PassManager.new()
|
||||
add_common_passes(pass_manager)
|
||||
pass_manager.run(module)
|
||||
|
||||
return env.emit_object()
|
||||
module = Module(env)
|
||||
module.compile_function(funcdef, dict())
|
||||
return module.emit_object()
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import ast
|
||||
|
||||
from artiq.py2llvm import values
|
||||
from artiq.py2llvm import values, base_types, fractions
|
||||
from artiq.py2llvm.tools import is_terminated
|
||||
|
||||
|
||||
|
@ -30,9 +30,9 @@ class Visitor:
|
|||
def _visit_expr_NameConstant(self, node):
|
||||
v = node.value
|
||||
if v is None:
|
||||
r = values.VNone()
|
||||
r = base_types.VNone()
|
||||
elif isinstance(v, bool):
|
||||
r = values.VBool()
|
||||
r = base_types.VBool()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
if self.builder is not None:
|
||||
|
@ -43,9 +43,9 @@ class Visitor:
|
|||
n = node.n
|
||||
if isinstance(n, int):
|
||||
if abs(n) < 2**31:
|
||||
r = values.VInt()
|
||||
r = base_types.VInt()
|
||||
else:
|
||||
r = values.VInt(64)
|
||||
r = base_types.VInt(64)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
if self.builder is not None:
|
||||
|
@ -116,7 +116,7 @@ class Visitor:
|
|||
return ast_unfuns[fn](self.visit_expression(node.args[0]),
|
||||
self.builder)
|
||||
elif fn == "Fraction":
|
||||
r = values.VFraction()
|
||||
r = fractions.VFraction()
|
||||
if self.builder is not None:
|
||||
numerator = self.visit_expression(node.args[0])
|
||||
denominator = self.visit_expression(node.args[1])
|
||||
|
@ -213,10 +213,10 @@ class Visitor:
|
|||
|
||||
def _visit_stmt_Return(self, node):
|
||||
if node.value is None:
|
||||
val = values.VNone()
|
||||
val = base_types.VNone()
|
||||
else:
|
||||
val = self.visit_expression(node.value)
|
||||
if isinstance(val, values.VNone):
|
||||
if isinstance(val, base_types.VNone):
|
||||
self.builder.ret_void()
|
||||
else:
|
||||
self.builder.ret(val.get_ssa_value(self.builder))
|
||||
|
|
|
@ -0,0 +1,165 @@
|
|||
from llvm import core as lc
|
||||
|
||||
from artiq.py2llvm.values import VGeneric
|
||||
|
||||
|
||||
class VNone(VGeneric):
|
||||
def __repr__(self):
|
||||
return "<VNone>"
|
||||
|
||||
def get_llvm_type(self):
|
||||
return lc.Type.void()
|
||||
|
||||
def same_type(self, other):
|
||||
return isinstance(other, VNone)
|
||||
|
||||
def merge(self, other):
|
||||
if not isinstance(other, VNone):
|
||||
raise TypeError
|
||||
|
||||
def alloca(self, builder, name):
|
||||
pass
|
||||
|
||||
def o_bool(self, builder):
|
||||
r = VBool()
|
||||
if builder is not None:
|
||||
r.set_const_value(builder, False)
|
||||
return r
|
||||
|
||||
|
||||
class VInt(VGeneric):
|
||||
def __init__(self, nbits=32):
|
||||
VGeneric.__init__(self)
|
||||
self.nbits = nbits
|
||||
|
||||
def get_llvm_type(self):
|
||||
return lc.Type.int(self.nbits)
|
||||
|
||||
def __repr__(self):
|
||||
return "<VInt:{}>".format(self.nbits)
|
||||
|
||||
def same_type(self, other):
|
||||
return isinstance(other, VInt) and other.nbits == self.nbits
|
||||
|
||||
def merge(self, other):
|
||||
if isinstance(other, VInt) and not isinstance(other, VBool):
|
||||
if other.nbits > self.nbits:
|
||||
self.nbits = other.nbits
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def set_value(self, builder, n):
|
||||
self.set_ssa_value(
|
||||
builder, n.o_intx(self.nbits, builder).get_ssa_value(builder))
|
||||
|
||||
def set_const_value(self, builder, n):
|
||||
self.set_ssa_value(builder, lc.Constant.int(self.get_llvm_type(), n))
|
||||
|
||||
def o_bool(self, builder, inv=False):
|
||||
r = VBool()
|
||||
if builder is not None:
|
||||
r.set_ssa_value(
|
||||
builder, builder.icmp(
|
||||
lc.ICMP_EQ if inv else lc.ICMP_NE,
|
||||
self.get_ssa_value(builder),
|
||||
lc.Constant.int(self.get_llvm_type(), 0)))
|
||||
return r
|
||||
|
||||
def o_not(self, builder):
|
||||
return self.o_bool(builder, True)
|
||||
|
||||
def o_intx(self, target_bits, builder):
|
||||
r = VInt(target_bits)
|
||||
if builder is not None:
|
||||
if self.nbits == target_bits:
|
||||
r.set_ssa_value(
|
||||
builder, self.get_ssa_value(builder))
|
||||
if self.nbits > target_bits:
|
||||
r.set_ssa_value(
|
||||
builder, builder.trunc(self.get_ssa_value(builder),
|
||||
r.get_llvm_type()))
|
||||
if self.nbits < target_bits:
|
||||
r.set_ssa_value(
|
||||
builder, builder.sext(self.get_ssa_value(builder),
|
||||
r.get_llvm_type()))
|
||||
return r
|
||||
o_roundx = o_intx
|
||||
|
||||
|
||||
def _make_vint_binop_method(builder_name):
|
||||
def binop_method(self, other, builder):
|
||||
if isinstance(other, VInt):
|
||||
target_bits = max(self.nbits, other.nbits)
|
||||
r = VInt(target_bits)
|
||||
if builder is not None:
|
||||
left = self.o_intx(target_bits, builder)
|
||||
right = other.o_intx(target_bits, builder)
|
||||
bf = getattr(builder, builder_name)
|
||||
r.set_ssa_value(
|
||||
builder, bf(left.get_ssa_value(builder),
|
||||
right.get_ssa_value(builder)))
|
||||
return r
|
||||
else:
|
||||
return NotImplemented
|
||||
return binop_method
|
||||
|
||||
for _method_name, _builder_name in (("o_add", "add"),
|
||||
("o_sub", "sub"),
|
||||
("o_mul", "mul"),
|
||||
("o_floordiv", "sdiv"),
|
||||
("o_mod", "srem"),
|
||||
("o_and", "and_"),
|
||||
("o_xor", "xor"),
|
||||
("o_or", "or_")):
|
||||
setattr(VInt, _method_name, _make_vint_binop_method(_builder_name))
|
||||
|
||||
|
||||
def _make_vint_cmp_method(icmp_val):
|
||||
def cmp_method(self, other, builder):
|
||||
if isinstance(other, VInt):
|
||||
r = VBool()
|
||||
if builder is not None:
|
||||
target_bits = max(self.nbits, other.nbits)
|
||||
left = self.o_intx(target_bits, builder)
|
||||
right = other.o_intx(target_bits, builder)
|
||||
r.set_ssa_value(
|
||||
builder,
|
||||
builder.icmp(
|
||||
icmp_val, left.get_ssa_value(builder),
|
||||
right.get_ssa_value(builder)))
|
||||
return r
|
||||
else:
|
||||
return NotImplemented
|
||||
return cmp_method
|
||||
|
||||
for _method_name, _icmp_val in (("o_eq", lc.ICMP_EQ),
|
||||
("o_ne", lc.ICMP_NE),
|
||||
("o_lt", lc.ICMP_SLT),
|
||||
("o_le", lc.ICMP_SLE),
|
||||
("o_gt", lc.ICMP_SGT),
|
||||
("o_ge", lc.ICMP_SGE)):
|
||||
setattr(VInt, _method_name, _make_vint_cmp_method(_icmp_val))
|
||||
|
||||
|
||||
class VBool(VInt):
|
||||
def __init__(self):
|
||||
VInt.__init__(self, 1)
|
||||
|
||||
def __repr__(self):
|
||||
return "<VBool>"
|
||||
|
||||
def same_type(self, other):
|
||||
return isinstance(other, VBool)
|
||||
|
||||
def merge(self, other):
|
||||
if not isinstance(other, VBool):
|
||||
raise TypeError
|
||||
|
||||
def set_const_value(self, builder, b):
|
||||
VInt.set_const_value(self, builder, int(b))
|
||||
|
||||
def o_bool(self, builder):
|
||||
r = VBool()
|
||||
if builder is not None:
|
||||
r.set_ssa_value(builder, self.get_ssa_value(builder))
|
||||
return r
|
|
@ -0,0 +1,192 @@
|
|||
from llvm import core as lc
|
||||
|
||||
from artiq.py2llvm.values import VGeneric
|
||||
from artiq.py2llvm.base_types import VBool, VInt
|
||||
|
||||
|
||||
def _gcd64(builder, a, b):
|
||||
gcd_f = builder.basic_block.function.module.get_function_named("__gcd64")
|
||||
return builder.call(gcd_f, [a, b])
|
||||
|
||||
def init_module(module):
|
||||
func_type = lc.Type.function(
|
||||
lc.Type.int(64), [lc.Type.int(64), lc.Type.int(64)])
|
||||
module.add_function(func_type, "__gcd64")
|
||||
|
||||
|
||||
def _frac_normalize(builder, numerator, denominator):
|
||||
gcd = _gcd64(builder, numerator, denominator)
|
||||
numerator = builder.sdiv(numerator, gcd)
|
||||
denominator = builder.sdiv(denominator, gcd)
|
||||
return numerator, denominator
|
||||
|
||||
|
||||
def _frac_make_ssa(builder, numerator, denominator):
|
||||
value = lc.Constant.undef(lc.Type.vector(lc.Type.int(64), 2))
|
||||
value = builder.insert_element(
|
||||
value, numerator, lc.Constant.int(lc.Type.int(), 0))
|
||||
value = builder.insert_element(
|
||||
value, denominator, lc.Constant.int(lc.Type.int(), 1))
|
||||
return value
|
||||
|
||||
|
||||
class VFraction(VGeneric):
|
||||
def get_llvm_type(self):
|
||||
return lc.Type.vector(lc.Type.int(64), 2)
|
||||
|
||||
def __repr__(self):
|
||||
return "<VFraction>"
|
||||
|
||||
def same_type(self, other):
|
||||
return isinstance(other, VFraction)
|
||||
|
||||
def merge(self, other):
|
||||
if not isinstance(other, VFraction):
|
||||
raise TypeError
|
||||
|
||||
def _nd(self, builder, invert=False):
|
||||
ssa_value = self.get_ssa_value(builder)
|
||||
numerator = builder.extract_element(
|
||||
ssa_value, lc.Constant.int(lc.Type.int(), 0))
|
||||
denominator = builder.extract_element(
|
||||
ssa_value, lc.Constant.int(lc.Type.int(), 1))
|
||||
if invert:
|
||||
return denominator, numerator
|
||||
else:
|
||||
return numerator, denominator
|
||||
|
||||
def set_value_nd(self, builder, numerator, denominator):
|
||||
numerator = numerator.o_int64(builder).get_ssa_value(builder)
|
||||
denominator = denominator.o_int64(builder).get_ssa_value(builder)
|
||||
numerator, denominator = _frac_normalize(
|
||||
builder, numerator, denominator)
|
||||
self.set_ssa_value(
|
||||
builder, _frac_make_ssa(builder, numerator, denominator))
|
||||
|
||||
def set_value(self, builder, n):
|
||||
if not isinstance(n, VFraction):
|
||||
raise TypeError
|
||||
self.set_ssa_value(builder, n.get_ssa_value(builder))
|
||||
|
||||
def o_bool(self, builder):
|
||||
r = VBool()
|
||||
if builder is not None:
|
||||
zero = lc.Constant.int(lc.Type.int(64), 0)
|
||||
numerator = builder.extract_element(
|
||||
self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 0))
|
||||
r.set_ssa_value(builder, builder.icmp(lc.ICMP_NE, numerator, zero))
|
||||
return r
|
||||
|
||||
def o_intx(self, target_bits, builder):
|
||||
if builder is None:
|
||||
return VInt(target_bits)
|
||||
else:
|
||||
r = VInt(64)
|
||||
numerator, denominator = self._nd(builder)
|
||||
r.set_ssa_value(builder, builder.sdiv(numerator, denominator))
|
||||
return r.o_intx(target_bits, builder)
|
||||
|
||||
def o_roundx(self, target_bits, builder):
|
||||
if builder is None:
|
||||
return VInt(target_bits)
|
||||
else:
|
||||
r = VInt(64)
|
||||
numerator, denominator = self._nd(builder)
|
||||
h_denominator = builder.ashr(denominator,
|
||||
lc.Constant.int(lc.Type.int(), 1))
|
||||
r_numerator = builder.add(numerator, h_denominator)
|
||||
r.set_ssa_value(builder, builder.sdiv(r_numerator, denominator))
|
||||
return r.o_intx(target_bits, builder)
|
||||
|
||||
def _o_eq_inv(self, other, builder, ne):
|
||||
if isinstance(other, VFraction):
|
||||
r = VBool()
|
||||
if builder is not None:
|
||||
ee = []
|
||||
for i in range(2):
|
||||
es = builder.extract_element(
|
||||
self.get_ssa_value(builder),
|
||||
lc.Constant.int(lc.Type.int(), i))
|
||||
eo = builder.extract_element(
|
||||
other.get_ssa_value(builder),
|
||||
lc.Constant.int(lc.Type.int(), i))
|
||||
ee.append(builder.icmp(lc.ICMP_EQ, es, eo))
|
||||
ssa_r = builder.and_(ee[0], ee[1])
|
||||
if ne:
|
||||
ssa_r = builder.xor(ssa_r,
|
||||
lc.Constant.int(lc.Type.int(1), 1))
|
||||
r.set_ssa_value(builder, ssa_r)
|
||||
return r
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def o_eq(self, other, builder):
|
||||
return self._o_eq_inv(other, builder, False)
|
||||
|
||||
def o_ne(self, other, builder):
|
||||
return self._o_eq_inv(other, builder, True)
|
||||
|
||||
def _o_muldiv(self, other, builder, div, invert=False):
|
||||
r = VFraction()
|
||||
if isinstance(other, VInt):
|
||||
if builder is None:
|
||||
return r
|
||||
else:
|
||||
numerator, denominator = self._nd(builder, invert)
|
||||
i = other.get_ssa_value(builder)
|
||||
if div:
|
||||
gcd = _gcd64(i, numerator)
|
||||
i = builder.sdiv(i, gcd)
|
||||
numerator = builder.sdiv(numerator, gcd)
|
||||
denominator = builder.mul(denominator, i)
|
||||
else:
|
||||
gcd = _gcd64(i, denominator)
|
||||
i = builder.sdiv(i, gcd)
|
||||
denominator = builder.sdiv(denominator, gcd)
|
||||
numerator = builder.mul(numerator, i)
|
||||
self.set_ssa_value(builder, _frac_make_ssa(builder, numerator,
|
||||
denominator))
|
||||
elif isinstance(other, VFraction):
|
||||
if builder is None:
|
||||
return r
|
||||
else:
|
||||
numerator, denominator = self._nd(builder, invert)
|
||||
onumerator, odenominator = other._nd(builder)
|
||||
if div:
|
||||
numerator = builder.mul(numerator, odenominator)
|
||||
denominator = builder.mul(denominator, onumerator)
|
||||
else:
|
||||
numerator = builder.mul(numerator, onumerator)
|
||||
denominator = builder.mul(denominator, odenominator)
|
||||
numerator, denominator = _frac_normalize(builder, numerator,
|
||||
denominator)
|
||||
self.set_ssa_value(
|
||||
builder, _frac_make_ssa(builder, numerator, denominator))
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def o_mul(self, other, builder):
|
||||
return self._o_muldiv(other, builder, False)
|
||||
|
||||
def o_truediv(self, other, builder):
|
||||
return self._o_muldiv(other, builder, True)
|
||||
|
||||
def or_mul(self, other, builder):
|
||||
return self._o_muldiv(other, builder, False)
|
||||
|
||||
def or_truediv(self, other, builder):
|
||||
return self._o_muldiv(other, builder, False, True)
|
||||
|
||||
def o_floordiv(self, other, builder):
|
||||
r = self.o_truediv(other, builder)
|
||||
if r is NotImplemented:
|
||||
return r
|
||||
else:
|
||||
return r.o_int(builder)
|
||||
|
||||
def or_floordiv(self, other, builder):
|
||||
r = self.or_truediv(other, builder)
|
||||
if r is NotImplemented:
|
||||
return r
|
||||
else:
|
||||
return r.o_int(builder)
|
|
@ -1,31 +0,0 @@
|
|||
from llvm import core as lc
|
||||
|
||||
from artiq.py2llvm import infer_types, ast_body, values, tools
|
||||
|
||||
def compile_function(module, env, funcdef, param_types):
|
||||
ns = infer_types.infer_function_types(env, funcdef, param_types)
|
||||
retval = ns["return"]
|
||||
|
||||
function_type = lc.Type.function(retval.get_llvm_type(),
|
||||
[ns[arg.arg].get_llvm_type() for arg in funcdef.args.args])
|
||||
function = module.add_function(function_type, funcdef.name)
|
||||
bb = function.append_basic_block("entry")
|
||||
builder = lc.Builder.new(bb)
|
||||
|
||||
for arg_ast, arg_llvm in zip(funcdef.args.args, function.args):
|
||||
arg_llvm.name = arg_ast.arg
|
||||
for k, v in ns.items():
|
||||
v.alloca(builder, k)
|
||||
for arg_ast, arg_llvm in zip(funcdef.args.args, function.args):
|
||||
ns[arg_ast.arg].set_ssa_value(builder, arg_llvm)
|
||||
|
||||
visitor = ast_body.Visitor(env, ns, builder)
|
||||
visitor.visit_statements(funcdef.body)
|
||||
|
||||
if not tools.is_terminated(builder.basic_block):
|
||||
if isinstance(retval, values.VNone):
|
||||
builder.ret_void()
|
||||
else:
|
||||
builder.ret(retval.get_ssa_value(builder))
|
||||
|
||||
return function, retval
|
|
@ -2,7 +2,7 @@ import ast
|
|||
from copy import deepcopy
|
||||
|
||||
from artiq.py2llvm.ast_body import Visitor
|
||||
from artiq.py2llvm import values
|
||||
from artiq.py2llvm import base_types
|
||||
|
||||
|
||||
class _TypeScanner(ast.NodeVisitor):
|
||||
|
@ -31,7 +31,7 @@ class _TypeScanner(ast.NodeVisitor):
|
|||
|
||||
def visit_Return(self, node):
|
||||
if node.value is None:
|
||||
val = values.VNone()
|
||||
val = base_types.VNone()
|
||||
else:
|
||||
val = self.exprv.visit_expression(node.value)
|
||||
ns = self.exprv.ns
|
||||
|
@ -51,5 +51,5 @@ def infer_function_types(env, node, param_types):
|
|||
if all(v.same_type(prev_ns[k]) for k, v in ns.items()):
|
||||
# no more promotions - completed
|
||||
if "return" not in ns:
|
||||
ns["return"] = values.VNone()
|
||||
ns["return"] = base_types.VNone()
|
||||
return ns
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
from llvm import core as lc
|
||||
from llvm import passes as lp
|
||||
from llvm import ee as le
|
||||
|
||||
from artiq.py2llvm import infer_types, ast_body, base_types, fractions, tools
|
||||
|
||||
|
||||
class Module:
|
||||
def __init__(self, env=None):
|
||||
self.llvm_module = lc.Module.new("main")
|
||||
self.env = env
|
||||
|
||||
if self.env is not None:
|
||||
self.env.init_module(self.llvm_module)
|
||||
fractions.init_module(self.llvm_module)
|
||||
|
||||
def finalize(self):
|
||||
pass_manager = lp.PassManager.new()
|
||||
pass_manager.add(lp.PASS_MEM2REG)
|
||||
pass_manager.add(lp.PASS_INSTCOMBINE)
|
||||
pass_manager.add(lp.PASS_REASSOCIATE)
|
||||
pass_manager.add(lp.PASS_GVN)
|
||||
pass_manager.add(lp.PASS_SIMPLIFYCFG)
|
||||
pass_manager.run(self.llvm_module)
|
||||
|
||||
def get_ee(self):
|
||||
return le.ExecutionEngine.new(self.llvm_module)
|
||||
|
||||
def emit_object(self):
|
||||
self.finalize()
|
||||
return self.env.emit_object()
|
||||
|
||||
def compile_function(self, funcdef, param_types):
|
||||
ns = infer_types.infer_function_types(self.env, funcdef, param_types)
|
||||
retval = ns["return"]
|
||||
|
||||
function_type = lc.Type.function(retval.get_llvm_type(),
|
||||
[ns[arg.arg].get_llvm_type() for arg in funcdef.args.args])
|
||||
function = self.llvm_module.add_function(function_type, funcdef.name)
|
||||
bb = function.append_basic_block("entry")
|
||||
builder = lc.Builder.new(bb)
|
||||
|
||||
for arg_ast, arg_llvm in zip(funcdef.args.args, function.args):
|
||||
arg_llvm.name = arg_ast.arg
|
||||
for k, v in ns.items():
|
||||
v.alloca(builder, k)
|
||||
for arg_ast, arg_llvm in zip(funcdef.args.args, function.args):
|
||||
ns[arg_ast.arg].set_ssa_value(builder, arg_llvm)
|
||||
|
||||
visitor = ast_body.Visitor(self.env, ns, builder)
|
||||
visitor.visit_statements(funcdef.body)
|
||||
|
||||
if not tools.is_terminated(builder.basic_block):
|
||||
if isinstance(retval, base_types.VNone):
|
||||
builder.ret_void()
|
||||
else:
|
||||
builder.ret(retval.get_ssa_value(builder))
|
||||
|
||||
return function, retval
|
|
@ -1,11 +1,2 @@
|
|||
from llvm import passes as lp
|
||||
|
||||
def is_terminated(basic_block):
|
||||
return basic_block.instructions and basic_block.instructions[-1].is_terminator
|
||||
|
||||
def add_common_passes(pass_manager):
|
||||
pass_manager.add(lp.PASS_MEM2REG)
|
||||
pass_manager.add(lp.PASS_INSTCOMBINE)
|
||||
pass_manager.add(lp.PASS_REASSOCIATE)
|
||||
pass_manager.add(lp.PASS_GVN)
|
||||
pass_manager.add(lp.PASS_SIMPLIFYCFG)
|
||||
|
|
|
@ -3,7 +3,7 @@ from types import SimpleNamespace
|
|||
from llvm import core as lc
|
||||
|
||||
|
||||
class _Value:
|
||||
class VGeneric:
|
||||
def __init__(self):
|
||||
self._llvm_value = None
|
||||
|
||||
|
@ -40,361 +40,6 @@ class _Value:
|
|||
return self.o_roundx(64, builder)
|
||||
|
||||
|
||||
# None type
|
||||
|
||||
class VNone(_Value):
|
||||
def __repr__(self):
|
||||
return "<VNone>"
|
||||
|
||||
def get_llvm_type(self):
|
||||
return lc.Type.void()
|
||||
|
||||
def same_type(self, other):
|
||||
return isinstance(other, VNone)
|
||||
|
||||
def merge(self, other):
|
||||
if not isinstance(other, VNone):
|
||||
raise TypeError
|
||||
|
||||
def alloca(self, builder, name):
|
||||
pass
|
||||
|
||||
def o_bool(self, builder):
|
||||
r = VBool()
|
||||
if builder is not None:
|
||||
r.set_const_value(builder, False)
|
||||
return r
|
||||
|
||||
|
||||
# Integer type
|
||||
|
||||
class VInt(_Value):
|
||||
def __init__(self, nbits=32):
|
||||
_Value.__init__(self)
|
||||
self.nbits = nbits
|
||||
|
||||
def get_llvm_type(self):
|
||||
return lc.Type.int(self.nbits)
|
||||
|
||||
def __repr__(self):
|
||||
return "<VInt:{}>".format(self.nbits)
|
||||
|
||||
def same_type(self, other):
|
||||
return isinstance(other, VInt) and other.nbits == self.nbits
|
||||
|
||||
def merge(self, other):
|
||||
if isinstance(other, VInt) and not isinstance(other, VBool):
|
||||
if other.nbits > self.nbits:
|
||||
self.nbits = other.nbits
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def set_value(self, builder, n):
|
||||
self.set_ssa_value(
|
||||
builder, n.o_intx(self.nbits, builder).get_ssa_value(builder))
|
||||
|
||||
def set_const_value(self, builder, n):
|
||||
self.set_ssa_value(builder, lc.Constant.int(self.get_llvm_type(), n))
|
||||
|
||||
def o_bool(self, builder, inv=False):
|
||||
r = VBool()
|
||||
if builder is not None:
|
||||
r.set_ssa_value(
|
||||
builder, builder.icmp(
|
||||
lc.ICMP_EQ if inv else lc.ICMP_NE,
|
||||
self.get_ssa_value(builder),
|
||||
lc.Constant.int(self.get_llvm_type(), 0)))
|
||||
return r
|
||||
|
||||
def o_not(self, builder):
|
||||
return self.o_bool(builder, True)
|
||||
|
||||
def o_intx(self, target_bits, builder):
|
||||
r = VInt(target_bits)
|
||||
if builder is not None:
|
||||
if self.nbits == target_bits:
|
||||
r.set_ssa_value(
|
||||
builder, self.get_ssa_value(builder))
|
||||
if self.nbits > target_bits:
|
||||
r.set_ssa_value(
|
||||
builder, builder.trunc(self.get_ssa_value(builder),
|
||||
r.get_llvm_type()))
|
||||
if self.nbits < target_bits:
|
||||
r.set_ssa_value(
|
||||
builder, builder.sext(self.get_ssa_value(builder),
|
||||
r.get_llvm_type()))
|
||||
return r
|
||||
o_roundx = o_intx
|
||||
|
||||
|
||||
def _make_vint_binop_method(builder_name):
|
||||
def binop_method(self, other, builder):
|
||||
if isinstance(other, VInt):
|
||||
target_bits = max(self.nbits, other.nbits)
|
||||
r = VInt(target_bits)
|
||||
if builder is not None:
|
||||
left = self.o_intx(target_bits, builder)
|
||||
right = other.o_intx(target_bits, builder)
|
||||
bf = getattr(builder, builder_name)
|
||||
r.set_ssa_value(
|
||||
builder, bf(left.get_ssa_value(builder),
|
||||
right.get_ssa_value(builder)))
|
||||
return r
|
||||
else:
|
||||
return NotImplemented
|
||||
return binop_method
|
||||
|
||||
for _method_name, _builder_name in (("o_add", "add"),
|
||||
("o_sub", "sub"),
|
||||
("o_mul", "mul"),
|
||||
("o_floordiv", "sdiv"),
|
||||
("o_mod", "srem"),
|
||||
("o_and", "and_"),
|
||||
("o_xor", "xor"),
|
||||
("o_or", "or_")):
|
||||
setattr(VInt, _method_name, _make_vint_binop_method(_builder_name))
|
||||
|
||||
|
||||
def _make_vint_cmp_method(icmp_val):
|
||||
def cmp_method(self, other, builder):
|
||||
if isinstance(other, VInt):
|
||||
r = VBool()
|
||||
if builder is not None:
|
||||
target_bits = max(self.nbits, other.nbits)
|
||||
left = self.o_intx(target_bits, builder)
|
||||
right = other.o_intx(target_bits, builder)
|
||||
r.set_ssa_value(
|
||||
builder,
|
||||
builder.icmp(
|
||||
icmp_val, left.get_ssa_value(builder),
|
||||
right.get_ssa_value(builder)))
|
||||
return r
|
||||
else:
|
||||
return NotImplemented
|
||||
return cmp_method
|
||||
|
||||
for _method_name, _icmp_val in (("o_eq", lc.ICMP_EQ),
|
||||
("o_ne", lc.ICMP_NE),
|
||||
("o_lt", lc.ICMP_SLT),
|
||||
("o_le", lc.ICMP_SLE),
|
||||
("o_gt", lc.ICMP_SGT),
|
||||
("o_ge", lc.ICMP_SGE)):
|
||||
setattr(VInt, _method_name, _make_vint_cmp_method(_icmp_val))
|
||||
|
||||
|
||||
# Boolean type
|
||||
|
||||
class VBool(VInt):
|
||||
def __init__(self):
|
||||
VInt.__init__(self, 1)
|
||||
|
||||
def __repr__(self):
|
||||
return "<VBool>"
|
||||
|
||||
def same_type(self, other):
|
||||
return isinstance(other, VBool)
|
||||
|
||||
def merge(self, other):
|
||||
if not isinstance(other, VBool):
|
||||
raise TypeError
|
||||
|
||||
def set_const_value(self, builder, b):
|
||||
VInt.set_const_value(self, builder, int(b))
|
||||
|
||||
def o_bool(self, builder):
|
||||
r = VBool()
|
||||
if builder is not None:
|
||||
r.set_ssa_value(builder, self.get_ssa_value(builder))
|
||||
return r
|
||||
|
||||
|
||||
# Fraction type
|
||||
|
||||
def _gcd64(builder, a, b):
|
||||
gcd_f = builder.basic_block.function.module.get_function_named("__gcd64")
|
||||
return builder.call(gcd_f, [a, b])
|
||||
|
||||
|
||||
def _frac_normalize(builder, numerator, denominator):
|
||||
gcd = _gcd64(builder, numerator, denominator)
|
||||
numerator = builder.sdiv(numerator, gcd)
|
||||
denominator = builder.sdiv(denominator, gcd)
|
||||
return numerator, denominator
|
||||
|
||||
|
||||
def _frac_make_ssa(builder, numerator, denominator):
|
||||
value = lc.Constant.undef(lc.Type.vector(lc.Type.int(64), 2))
|
||||
value = builder.insert_element(
|
||||
value, numerator, lc.Constant.int(lc.Type.int(), 0))
|
||||
value = builder.insert_element(
|
||||
value, denominator, lc.Constant.int(lc.Type.int(), 1))
|
||||
return value
|
||||
|
||||
|
||||
class VFraction(_Value):
|
||||
def get_llvm_type(self):
|
||||
return lc.Type.vector(lc.Type.int(64), 2)
|
||||
|
||||
def __repr__(self):
|
||||
return "<VFraction>"
|
||||
|
||||
def same_type(self, other):
|
||||
return isinstance(other, VFraction)
|
||||
|
||||
def merge(self, other):
|
||||
if not isinstance(other, VFraction):
|
||||
raise TypeError
|
||||
|
||||
def _nd(self, builder, invert=False):
|
||||
ssa_value = self.get_ssa_value(builder)
|
||||
numerator = builder.extract_element(
|
||||
ssa_value, lc.Constant.int(lc.Type.int(), 0))
|
||||
denominator = builder.extract_element(
|
||||
ssa_value, lc.Constant.int(lc.Type.int(), 1))
|
||||
if invert:
|
||||
return denominator, numerator
|
||||
else:
|
||||
return numerator, denominator
|
||||
|
||||
def set_value_nd(self, builder, numerator, denominator):
|
||||
numerator = numerator.o_int64(builder).get_ssa_value(builder)
|
||||
denominator = denominator.o_int64(builder).get_ssa_value(builder)
|
||||
numerator, denominator = _frac_normalize(
|
||||
builder, numerator, denominator)
|
||||
self.set_ssa_value(
|
||||
builder, _frac_make_ssa(builder, numerator, denominator))
|
||||
|
||||
def set_value(self, builder, n):
|
||||
if not isinstance(n, VFraction):
|
||||
raise TypeError
|
||||
self.set_ssa_value(builder, n.get_ssa_value(builder))
|
||||
|
||||
def o_bool(self, builder):
|
||||
r = VBool()
|
||||
if builder is not None:
|
||||
zero = lc.Constant.int(lc.Type.int(64), 0)
|
||||
numerator = builder.extract_element(
|
||||
self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 0))
|
||||
r.set_ssa_value(builder, builder.icmp(lc.ICMP_NE, numerator, zero))
|
||||
return r
|
||||
|
||||
def o_intx(self, target_bits, builder):
|
||||
if builder is None:
|
||||
return VInt(target_bits)
|
||||
else:
|
||||
r = VInt(64)
|
||||
numerator, denominator = self._nd(builder)
|
||||
r.set_ssa_value(builder, builder.sdiv(numerator, denominator))
|
||||
return r.o_intx(target_bits, builder)
|
||||
|
||||
def o_roundx(self, target_bits, builder):
|
||||
if builder is None:
|
||||
return VInt(target_bits)
|
||||
else:
|
||||
r = VInt(64)
|
||||
numerator, denominator = self._nd(builder)
|
||||
h_denominator = builder.ashr(denominator,
|
||||
lc.Constant.int(lc.Type.int(), 1))
|
||||
r_numerator = builder.add(numerator, h_denominator)
|
||||
r.set_ssa_value(builder, builder.sdiv(r_numerator, denominator))
|
||||
return r.o_intx(target_bits, builder)
|
||||
|
||||
def _o_eq_inv(self, other, builder, ne):
|
||||
if isinstance(other, VFraction):
|
||||
r = VBool()
|
||||
if builder is not None:
|
||||
ee = []
|
||||
for i in range(2):
|
||||
es = builder.extract_element(
|
||||
self.get_ssa_value(builder),
|
||||
lc.Constant.int(lc.Type.int(), i))
|
||||
eo = builder.extract_element(
|
||||
other.get_ssa_value(builder),
|
||||
lc.Constant.int(lc.Type.int(), i))
|
||||
ee.append(builder.icmp(lc.ICMP_EQ, es, eo))
|
||||
ssa_r = builder.and_(ee[0], ee[1])
|
||||
if ne:
|
||||
ssa_r = builder.xor(ssa_r,
|
||||
lc.Constant.int(lc.Type.int(1), 1))
|
||||
r.set_ssa_value(builder, ssa_r)
|
||||
return r
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def o_eq(self, other, builder):
|
||||
return self._o_eq_inv(other, builder, False)
|
||||
|
||||
def o_ne(self, other, builder):
|
||||
return self._o_eq_inv(other, builder, True)
|
||||
|
||||
def _o_muldiv(self, other, builder, div, invert=False):
|
||||
r = VFraction()
|
||||
if isinstance(other, VInt):
|
||||
if builder is None:
|
||||
return r
|
||||
else:
|
||||
numerator, denominator = self._nd(builder, invert)
|
||||
i = other.get_ssa_value(builder)
|
||||
if div:
|
||||
gcd = _gcd64(i, numerator)
|
||||
i = builder.sdiv(i, gcd)
|
||||
numerator = builder.sdiv(numerator, gcd)
|
||||
denominator = builder.mul(denominator, i)
|
||||
else:
|
||||
gcd = _gcd64(i, denominator)
|
||||
i = builder.sdiv(i, gcd)
|
||||
denominator = builder.sdiv(denominator, gcd)
|
||||
numerator = builder.mul(numerator, i)
|
||||
self.set_ssa_value(builder, _frac_make_ssa(builder, numerator,
|
||||
denominator))
|
||||
elif isinstance(other, VFraction):
|
||||
if builder is None:
|
||||
return r
|
||||
else:
|
||||
numerator, denominator = self._nd(builder, invert)
|
||||
onumerator, odenominator = other._nd(builder)
|
||||
if div:
|
||||
numerator = builder.mul(numerator, odenominator)
|
||||
denominator = builder.mul(denominator, onumerator)
|
||||
else:
|
||||
numerator = builder.mul(numerator, onumerator)
|
||||
denominator = builder.mul(denominator, odenominator)
|
||||
numerator, denominator = _frac_normalize(builder, numerator,
|
||||
denominator)
|
||||
self.set_ssa_value(
|
||||
builder, _frac_make_ssa(builder, numerator, denominator))
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def o_mul(self, other, builder):
|
||||
return self._o_muldiv(other, builder, False)
|
||||
|
||||
def o_truediv(self, other, builder):
|
||||
return self._o_muldiv(other, builder, True)
|
||||
|
||||
def or_mul(self, other, builder):
|
||||
return self._o_muldiv(other, builder, False)
|
||||
|
||||
def or_truediv(self, other, builder):
|
||||
return self._o_muldiv(other, builder, False, True)
|
||||
|
||||
def o_floordiv(self, other, builder):
|
||||
r = self.o_truediv(other, builder)
|
||||
if r is NotImplemented:
|
||||
return r
|
||||
else:
|
||||
return r.o_int(builder)
|
||||
|
||||
def or_floordiv(self, other, builder):
|
||||
r = self.or_truediv(other, builder)
|
||||
if r is NotImplemented:
|
||||
return r
|
||||
else:
|
||||
return r.o_int(builder)
|
||||
|
||||
|
||||
# Operators
|
||||
|
||||
def _make_unary_operator(op_name):
|
||||
def op(x, builder):
|
||||
try:
|
||||
|
@ -446,9 +91,3 @@ def _make_operators():
|
|||
return SimpleNamespace(**d)
|
||||
|
||||
operators = _make_operators()
|
||||
|
||||
|
||||
def init_module(module):
|
||||
func_type = lc.Type.function(
|
||||
lc.Type.int(64), [lc.Type.int(64), lc.Type.int(64)])
|
||||
module.add_function(func_type, "__gcd64")
|
||||
|
|
|
@ -2,15 +2,12 @@ import unittest
|
|||
import ast
|
||||
import inspect
|
||||
|
||||
from llvm import core as lc
|
||||
from llvm import passes as lp
|
||||
from llvm import ee as le
|
||||
|
||||
from artiq.language.core import int64
|
||||
from artiq.py2llvm.infer_types import infer_function_types
|
||||
from artiq.py2llvm import values
|
||||
from artiq.py2llvm import compile_function
|
||||
from artiq.py2llvm.tools import add_common_passes
|
||||
from artiq.py2llvm import base_types
|
||||
from artiq.py2llvm.module import Module
|
||||
|
||||
|
||||
def test_types(choice):
|
||||
|
@ -35,46 +32,40 @@ class FunctionTypesCase(unittest.TestCase):
|
|||
dict())
|
||||
|
||||
def test_base_types(self):
|
||||
self.assertIsInstance(self.ns["foo"], values.VBool)
|
||||
self.assertIsInstance(self.ns["bar"], values.VNone)
|
||||
self.assertIsInstance(self.ns["d"], values.VInt)
|
||||
self.assertIsInstance(self.ns["foo"], base_types.VBool)
|
||||
self.assertIsInstance(self.ns["bar"], base_types.VNone)
|
||||
self.assertIsInstance(self.ns["d"], base_types.VInt)
|
||||
self.assertEqual(self.ns["d"].nbits, 32)
|
||||
self.assertIsInstance(self.ns["x"], values.VInt)
|
||||
self.assertIsInstance(self.ns["x"], base_types.VInt)
|
||||
self.assertEqual(self.ns["x"].nbits, 64)
|
||||
|
||||
def test_promotion(self):
|
||||
for v in "abc":
|
||||
self.assertIsInstance(self.ns[v], values.VInt)
|
||||
self.assertIsInstance(self.ns[v], base_types.VInt)
|
||||
self.assertEqual(self.ns[v].nbits, 64)
|
||||
|
||||
def test_return(self):
|
||||
self.assertIsInstance(self.ns["return"], values.VInt)
|
||||
self.assertIsInstance(self.ns["return"], base_types.VInt)
|
||||
self.assertEqual(self.ns["return"].nbits, 64)
|
||||
|
||||
|
||||
class CompiledFunction:
|
||||
def __init__(self, function, param_types):
|
||||
module = lc.Module.new("main")
|
||||
values.init_module(module)
|
||||
|
||||
module = Module()
|
||||
funcdef = ast.parse(inspect.getsource(function)).body[0]
|
||||
self.function, self.retval = compile_function(
|
||||
module, None, funcdef, param_types)
|
||||
self.function, self.retval = module.compile_function(
|
||||
funcdef, param_types)
|
||||
self.argval = [param_types[arg.arg] for arg in funcdef.args.args]
|
||||
|
||||
self.executor = le.ExecutionEngine.new(module)
|
||||
pass_manager = lp.PassManager.new()
|
||||
add_common_passes(pass_manager)
|
||||
pass_manager.run(module)
|
||||
self.ee = module.get_ee()
|
||||
|
||||
def __call__(self, *args):
|
||||
args_llvm = [
|
||||
le.GenericValue.int(av.get_llvm_type(), a)
|
||||
for av, a in zip(self.argval, args)]
|
||||
result = self.executor.run_function(self.function, args_llvm)
|
||||
if isinstance(self.retval, values.VBool):
|
||||
result = self.ee.run_function(self.function, args_llvm)
|
||||
if isinstance(self.retval, base_types.VBool):
|
||||
return bool(result.as_int())
|
||||
elif isinstance(self.retval, values.VInt):
|
||||
elif isinstance(self.retval, base_types.VInt):
|
||||
return result.as_int_signed()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
@ -90,6 +81,6 @@ def is_prime(x):
|
|||
|
||||
class CodeGenCase(unittest.TestCase):
|
||||
def test_is_prime(self):
|
||||
is_prime_c = CompiledFunction(is_prime, {"x": values.VInt(32)})
|
||||
is_prime_c = CompiledFunction(is_prime, {"x": base_types.VInt(32)})
|
||||
for i in range(200):
|
||||
self.assertEqual(is_prime_c(i), is_prime(i))
|
||||
|
|
Loading…
Reference in New Issue