2
0
mirror of https://github.com/m-labs/artiq.git synced 2024-12-24 19:04:02 +08:00

py2llvm: reorganize, split 'values' module, factor LLVM module/pass management

This commit is contained in:
Sebastien Bourdeauducq 2014-09-07 14:09:03 +08:00
parent 58465e49fa
commit 3c8b541939
11 changed files with 452 additions and 460 deletions

View File

@ -1,7 +1,7 @@
from llvm import core as lc
from llvm import target as lt
from artiq.py2llvm import values
from artiq.py2llvm import base_types
lt.initialize_all()
@ -21,9 +21,9 @@ _chr_to_type = {
}
_chr_to_value = {
"n": lambda: values.VNone(),
"i": lambda: values.VInt(),
"I": lambda: values.VInt(64)
"n": lambda: base_types.VNone(),
"i": lambda: base_types.VInt(),
"I": lambda: base_types.VInt(64)
}

View File

@ -1,20 +1,6 @@
from llvm import core as lc
from llvm import passes as lp
from artiq.py2llvm import values
from artiq.py2llvm.functions import compile_function
from artiq.py2llvm.tools import add_common_passes
from artiq.py2llvm.module import Module
def get_runtime_binary(env, funcdef):
module = lc.Module.new("main")
env.init_module(module)
values.init_module(module)
compile_function(module, env, funcdef, dict())
pass_manager = lp.PassManager.new()
add_common_passes(pass_manager)
pass_manager.run(module)
return env.emit_object()
module = Module(env)
module.compile_function(funcdef, dict())
return module.emit_object()

View File

@ -1,6 +1,6 @@
import ast
from artiq.py2llvm import values
from artiq.py2llvm import values, base_types, fractions
from artiq.py2llvm.tools import is_terminated
@ -30,9 +30,9 @@ class Visitor:
def _visit_expr_NameConstant(self, node):
v = node.value
if v is None:
r = values.VNone()
r = base_types.VNone()
elif isinstance(v, bool):
r = values.VBool()
r = base_types.VBool()
else:
raise NotImplementedError
if self.builder is not None:
@ -43,9 +43,9 @@ class Visitor:
n = node.n
if isinstance(n, int):
if abs(n) < 2**31:
r = values.VInt()
r = base_types.VInt()
else:
r = values.VInt(64)
r = base_types.VInt(64)
else:
raise NotImplementedError
if self.builder is not None:
@ -116,7 +116,7 @@ class Visitor:
return ast_unfuns[fn](self.visit_expression(node.args[0]),
self.builder)
elif fn == "Fraction":
r = values.VFraction()
r = fractions.VFraction()
if self.builder is not None:
numerator = self.visit_expression(node.args[0])
denominator = self.visit_expression(node.args[1])
@ -213,10 +213,10 @@ class Visitor:
def _visit_stmt_Return(self, node):
if node.value is None:
val = values.VNone()
val = base_types.VNone()
else:
val = self.visit_expression(node.value)
if isinstance(val, values.VNone):
if isinstance(val, base_types.VNone):
self.builder.ret_void()
else:
self.builder.ret(val.get_ssa_value(self.builder))

165
artiq/py2llvm/base_types.py Normal file
View File

@ -0,0 +1,165 @@
from llvm import core as lc
from artiq.py2llvm.values import VGeneric
class VNone(VGeneric):
def __repr__(self):
return "<VNone>"
def get_llvm_type(self):
return lc.Type.void()
def same_type(self, other):
return isinstance(other, VNone)
def merge(self, other):
if not isinstance(other, VNone):
raise TypeError
def alloca(self, builder, name):
pass
def o_bool(self, builder):
r = VBool()
if builder is not None:
r.set_const_value(builder, False)
return r
class VInt(VGeneric):
def __init__(self, nbits=32):
VGeneric.__init__(self)
self.nbits = nbits
def get_llvm_type(self):
return lc.Type.int(self.nbits)
def __repr__(self):
return "<VInt:{}>".format(self.nbits)
def same_type(self, other):
return isinstance(other, VInt) and other.nbits == self.nbits
def merge(self, other):
if isinstance(other, VInt) and not isinstance(other, VBool):
if other.nbits > self.nbits:
self.nbits = other.nbits
else:
raise TypeError
def set_value(self, builder, n):
self.set_ssa_value(
builder, n.o_intx(self.nbits, builder).get_ssa_value(builder))
def set_const_value(self, builder, n):
self.set_ssa_value(builder, lc.Constant.int(self.get_llvm_type(), n))
def o_bool(self, builder, inv=False):
r = VBool()
if builder is not None:
r.set_ssa_value(
builder, builder.icmp(
lc.ICMP_EQ if inv else lc.ICMP_NE,
self.get_ssa_value(builder),
lc.Constant.int(self.get_llvm_type(), 0)))
return r
def o_not(self, builder):
return self.o_bool(builder, True)
def o_intx(self, target_bits, builder):
r = VInt(target_bits)
if builder is not None:
if self.nbits == target_bits:
r.set_ssa_value(
builder, self.get_ssa_value(builder))
if self.nbits > target_bits:
r.set_ssa_value(
builder, builder.trunc(self.get_ssa_value(builder),
r.get_llvm_type()))
if self.nbits < target_bits:
r.set_ssa_value(
builder, builder.sext(self.get_ssa_value(builder),
r.get_llvm_type()))
return r
o_roundx = o_intx
def _make_vint_binop_method(builder_name):
def binop_method(self, other, builder):
if isinstance(other, VInt):
target_bits = max(self.nbits, other.nbits)
r = VInt(target_bits)
if builder is not None:
left = self.o_intx(target_bits, builder)
right = other.o_intx(target_bits, builder)
bf = getattr(builder, builder_name)
r.set_ssa_value(
builder, bf(left.get_ssa_value(builder),
right.get_ssa_value(builder)))
return r
else:
return NotImplemented
return binop_method
for _method_name, _builder_name in (("o_add", "add"),
("o_sub", "sub"),
("o_mul", "mul"),
("o_floordiv", "sdiv"),
("o_mod", "srem"),
("o_and", "and_"),
("o_xor", "xor"),
("o_or", "or_")):
setattr(VInt, _method_name, _make_vint_binop_method(_builder_name))
def _make_vint_cmp_method(icmp_val):
def cmp_method(self, other, builder):
if isinstance(other, VInt):
r = VBool()
if builder is not None:
target_bits = max(self.nbits, other.nbits)
left = self.o_intx(target_bits, builder)
right = other.o_intx(target_bits, builder)
r.set_ssa_value(
builder,
builder.icmp(
icmp_val, left.get_ssa_value(builder),
right.get_ssa_value(builder)))
return r
else:
return NotImplemented
return cmp_method
for _method_name, _icmp_val in (("o_eq", lc.ICMP_EQ),
("o_ne", lc.ICMP_NE),
("o_lt", lc.ICMP_SLT),
("o_le", lc.ICMP_SLE),
("o_gt", lc.ICMP_SGT),
("o_ge", lc.ICMP_SGE)):
setattr(VInt, _method_name, _make_vint_cmp_method(_icmp_val))
class VBool(VInt):
def __init__(self):
VInt.__init__(self, 1)
def __repr__(self):
return "<VBool>"
def same_type(self, other):
return isinstance(other, VBool)
def merge(self, other):
if not isinstance(other, VBool):
raise TypeError
def set_const_value(self, builder, b):
VInt.set_const_value(self, builder, int(b))
def o_bool(self, builder):
r = VBool()
if builder is not None:
r.set_ssa_value(builder, self.get_ssa_value(builder))
return r

192
artiq/py2llvm/fractions.py Normal file
View File

@ -0,0 +1,192 @@
from llvm import core as lc
from artiq.py2llvm.values import VGeneric
from artiq.py2llvm.base_types import VBool, VInt
def _gcd64(builder, a, b):
gcd_f = builder.basic_block.function.module.get_function_named("__gcd64")
return builder.call(gcd_f, [a, b])
def init_module(module):
func_type = lc.Type.function(
lc.Type.int(64), [lc.Type.int(64), lc.Type.int(64)])
module.add_function(func_type, "__gcd64")
def _frac_normalize(builder, numerator, denominator):
gcd = _gcd64(builder, numerator, denominator)
numerator = builder.sdiv(numerator, gcd)
denominator = builder.sdiv(denominator, gcd)
return numerator, denominator
def _frac_make_ssa(builder, numerator, denominator):
value = lc.Constant.undef(lc.Type.vector(lc.Type.int(64), 2))
value = builder.insert_element(
value, numerator, lc.Constant.int(lc.Type.int(), 0))
value = builder.insert_element(
value, denominator, lc.Constant.int(lc.Type.int(), 1))
return value
class VFraction(VGeneric):
def get_llvm_type(self):
return lc.Type.vector(lc.Type.int(64), 2)
def __repr__(self):
return "<VFraction>"
def same_type(self, other):
return isinstance(other, VFraction)
def merge(self, other):
if not isinstance(other, VFraction):
raise TypeError
def _nd(self, builder, invert=False):
ssa_value = self.get_ssa_value(builder)
numerator = builder.extract_element(
ssa_value, lc.Constant.int(lc.Type.int(), 0))
denominator = builder.extract_element(
ssa_value, lc.Constant.int(lc.Type.int(), 1))
if invert:
return denominator, numerator
else:
return numerator, denominator
def set_value_nd(self, builder, numerator, denominator):
numerator = numerator.o_int64(builder).get_ssa_value(builder)
denominator = denominator.o_int64(builder).get_ssa_value(builder)
numerator, denominator = _frac_normalize(
builder, numerator, denominator)
self.set_ssa_value(
builder, _frac_make_ssa(builder, numerator, denominator))
def set_value(self, builder, n):
if not isinstance(n, VFraction):
raise TypeError
self.set_ssa_value(builder, n.get_ssa_value(builder))
def o_bool(self, builder):
r = VBool()
if builder is not None:
zero = lc.Constant.int(lc.Type.int(64), 0)
numerator = builder.extract_element(
self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 0))
r.set_ssa_value(builder, builder.icmp(lc.ICMP_NE, numerator, zero))
return r
def o_intx(self, target_bits, builder):
if builder is None:
return VInt(target_bits)
else:
r = VInt(64)
numerator, denominator = self._nd(builder)
r.set_ssa_value(builder, builder.sdiv(numerator, denominator))
return r.o_intx(target_bits, builder)
def o_roundx(self, target_bits, builder):
if builder is None:
return VInt(target_bits)
else:
r = VInt(64)
numerator, denominator = self._nd(builder)
h_denominator = builder.ashr(denominator,
lc.Constant.int(lc.Type.int(), 1))
r_numerator = builder.add(numerator, h_denominator)
r.set_ssa_value(builder, builder.sdiv(r_numerator, denominator))
return r.o_intx(target_bits, builder)
def _o_eq_inv(self, other, builder, ne):
if isinstance(other, VFraction):
r = VBool()
if builder is not None:
ee = []
for i in range(2):
es = builder.extract_element(
self.get_ssa_value(builder),
lc.Constant.int(lc.Type.int(), i))
eo = builder.extract_element(
other.get_ssa_value(builder),
lc.Constant.int(lc.Type.int(), i))
ee.append(builder.icmp(lc.ICMP_EQ, es, eo))
ssa_r = builder.and_(ee[0], ee[1])
if ne:
ssa_r = builder.xor(ssa_r,
lc.Constant.int(lc.Type.int(1), 1))
r.set_ssa_value(builder, ssa_r)
return r
else:
return NotImplemented
def o_eq(self, other, builder):
return self._o_eq_inv(other, builder, False)
def o_ne(self, other, builder):
return self._o_eq_inv(other, builder, True)
def _o_muldiv(self, other, builder, div, invert=False):
r = VFraction()
if isinstance(other, VInt):
if builder is None:
return r
else:
numerator, denominator = self._nd(builder, invert)
i = other.get_ssa_value(builder)
if div:
gcd = _gcd64(i, numerator)
i = builder.sdiv(i, gcd)
numerator = builder.sdiv(numerator, gcd)
denominator = builder.mul(denominator, i)
else:
gcd = _gcd64(i, denominator)
i = builder.sdiv(i, gcd)
denominator = builder.sdiv(denominator, gcd)
numerator = builder.mul(numerator, i)
self.set_ssa_value(builder, _frac_make_ssa(builder, numerator,
denominator))
elif isinstance(other, VFraction):
if builder is None:
return r
else:
numerator, denominator = self._nd(builder, invert)
onumerator, odenominator = other._nd(builder)
if div:
numerator = builder.mul(numerator, odenominator)
denominator = builder.mul(denominator, onumerator)
else:
numerator = builder.mul(numerator, onumerator)
denominator = builder.mul(denominator, odenominator)
numerator, denominator = _frac_normalize(builder, numerator,
denominator)
self.set_ssa_value(
builder, _frac_make_ssa(builder, numerator, denominator))
else:
return NotImplemented
def o_mul(self, other, builder):
return self._o_muldiv(other, builder, False)
def o_truediv(self, other, builder):
return self._o_muldiv(other, builder, True)
def or_mul(self, other, builder):
return self._o_muldiv(other, builder, False)
def or_truediv(self, other, builder):
return self._o_muldiv(other, builder, False, True)
def o_floordiv(self, other, builder):
r = self.o_truediv(other, builder)
if r is NotImplemented:
return r
else:
return r.o_int(builder)
def or_floordiv(self, other, builder):
r = self.or_truediv(other, builder)
if r is NotImplemented:
return r
else:
return r.o_int(builder)

View File

@ -1,31 +0,0 @@
from llvm import core as lc
from artiq.py2llvm import infer_types, ast_body, values, tools
def compile_function(module, env, funcdef, param_types):
ns = infer_types.infer_function_types(env, funcdef, param_types)
retval = ns["return"]
function_type = lc.Type.function(retval.get_llvm_type(),
[ns[arg.arg].get_llvm_type() for arg in funcdef.args.args])
function = module.add_function(function_type, funcdef.name)
bb = function.append_basic_block("entry")
builder = lc.Builder.new(bb)
for arg_ast, arg_llvm in zip(funcdef.args.args, function.args):
arg_llvm.name = arg_ast.arg
for k, v in ns.items():
v.alloca(builder, k)
for arg_ast, arg_llvm in zip(funcdef.args.args, function.args):
ns[arg_ast.arg].set_ssa_value(builder, arg_llvm)
visitor = ast_body.Visitor(env, ns, builder)
visitor.visit_statements(funcdef.body)
if not tools.is_terminated(builder.basic_block):
if isinstance(retval, values.VNone):
builder.ret_void()
else:
builder.ret(retval.get_ssa_value(builder))
return function, retval

View File

@ -2,7 +2,7 @@ import ast
from copy import deepcopy
from artiq.py2llvm.ast_body import Visitor
from artiq.py2llvm import values
from artiq.py2llvm import base_types
class _TypeScanner(ast.NodeVisitor):
@ -31,7 +31,7 @@ class _TypeScanner(ast.NodeVisitor):
def visit_Return(self, node):
if node.value is None:
val = values.VNone()
val = base_types.VNone()
else:
val = self.exprv.visit_expression(node.value)
ns = self.exprv.ns
@ -51,5 +51,5 @@ def infer_function_types(env, node, param_types):
if all(v.same_type(prev_ns[k]) for k, v in ns.items()):
# no more promotions - completed
if "return" not in ns:
ns["return"] = values.VNone()
ns["return"] = base_types.VNone()
return ns

59
artiq/py2llvm/module.py Normal file
View File

@ -0,0 +1,59 @@
from llvm import core as lc
from llvm import passes as lp
from llvm import ee as le
from artiq.py2llvm import infer_types, ast_body, base_types, fractions, tools
class Module:
def __init__(self, env=None):
self.llvm_module = lc.Module.new("main")
self.env = env
if self.env is not None:
self.env.init_module(self.llvm_module)
fractions.init_module(self.llvm_module)
def finalize(self):
pass_manager = lp.PassManager.new()
pass_manager.add(lp.PASS_MEM2REG)
pass_manager.add(lp.PASS_INSTCOMBINE)
pass_manager.add(lp.PASS_REASSOCIATE)
pass_manager.add(lp.PASS_GVN)
pass_manager.add(lp.PASS_SIMPLIFYCFG)
pass_manager.run(self.llvm_module)
def get_ee(self):
return le.ExecutionEngine.new(self.llvm_module)
def emit_object(self):
self.finalize()
return self.env.emit_object()
def compile_function(self, funcdef, param_types):
ns = infer_types.infer_function_types(self.env, funcdef, param_types)
retval = ns["return"]
function_type = lc.Type.function(retval.get_llvm_type(),
[ns[arg.arg].get_llvm_type() for arg in funcdef.args.args])
function = self.llvm_module.add_function(function_type, funcdef.name)
bb = function.append_basic_block("entry")
builder = lc.Builder.new(bb)
for arg_ast, arg_llvm in zip(funcdef.args.args, function.args):
arg_llvm.name = arg_ast.arg
for k, v in ns.items():
v.alloca(builder, k)
for arg_ast, arg_llvm in zip(funcdef.args.args, function.args):
ns[arg_ast.arg].set_ssa_value(builder, arg_llvm)
visitor = ast_body.Visitor(self.env, ns, builder)
visitor.visit_statements(funcdef.body)
if not tools.is_terminated(builder.basic_block):
if isinstance(retval, base_types.VNone):
builder.ret_void()
else:
builder.ret(retval.get_ssa_value(builder))
return function, retval

View File

@ -1,11 +1,2 @@
from llvm import passes as lp
def is_terminated(basic_block):
return basic_block.instructions and basic_block.instructions[-1].is_terminator
def add_common_passes(pass_manager):
pass_manager.add(lp.PASS_MEM2REG)
pass_manager.add(lp.PASS_INSTCOMBINE)
pass_manager.add(lp.PASS_REASSOCIATE)
pass_manager.add(lp.PASS_GVN)
pass_manager.add(lp.PASS_SIMPLIFYCFG)

View File

@ -3,7 +3,7 @@ from types import SimpleNamespace
from llvm import core as lc
class _Value:
class VGeneric:
def __init__(self):
self._llvm_value = None
@ -40,361 +40,6 @@ class _Value:
return self.o_roundx(64, builder)
# None type
class VNone(_Value):
def __repr__(self):
return "<VNone>"
def get_llvm_type(self):
return lc.Type.void()
def same_type(self, other):
return isinstance(other, VNone)
def merge(self, other):
if not isinstance(other, VNone):
raise TypeError
def alloca(self, builder, name):
pass
def o_bool(self, builder):
r = VBool()
if builder is not None:
r.set_const_value(builder, False)
return r
# Integer type
class VInt(_Value):
def __init__(self, nbits=32):
_Value.__init__(self)
self.nbits = nbits
def get_llvm_type(self):
return lc.Type.int(self.nbits)
def __repr__(self):
return "<VInt:{}>".format(self.nbits)
def same_type(self, other):
return isinstance(other, VInt) and other.nbits == self.nbits
def merge(self, other):
if isinstance(other, VInt) and not isinstance(other, VBool):
if other.nbits > self.nbits:
self.nbits = other.nbits
else:
raise TypeError
def set_value(self, builder, n):
self.set_ssa_value(
builder, n.o_intx(self.nbits, builder).get_ssa_value(builder))
def set_const_value(self, builder, n):
self.set_ssa_value(builder, lc.Constant.int(self.get_llvm_type(), n))
def o_bool(self, builder, inv=False):
r = VBool()
if builder is not None:
r.set_ssa_value(
builder, builder.icmp(
lc.ICMP_EQ if inv else lc.ICMP_NE,
self.get_ssa_value(builder),
lc.Constant.int(self.get_llvm_type(), 0)))
return r
def o_not(self, builder):
return self.o_bool(builder, True)
def o_intx(self, target_bits, builder):
r = VInt(target_bits)
if builder is not None:
if self.nbits == target_bits:
r.set_ssa_value(
builder, self.get_ssa_value(builder))
if self.nbits > target_bits:
r.set_ssa_value(
builder, builder.trunc(self.get_ssa_value(builder),
r.get_llvm_type()))
if self.nbits < target_bits:
r.set_ssa_value(
builder, builder.sext(self.get_ssa_value(builder),
r.get_llvm_type()))
return r
o_roundx = o_intx
def _make_vint_binop_method(builder_name):
def binop_method(self, other, builder):
if isinstance(other, VInt):
target_bits = max(self.nbits, other.nbits)
r = VInt(target_bits)
if builder is not None:
left = self.o_intx(target_bits, builder)
right = other.o_intx(target_bits, builder)
bf = getattr(builder, builder_name)
r.set_ssa_value(
builder, bf(left.get_ssa_value(builder),
right.get_ssa_value(builder)))
return r
else:
return NotImplemented
return binop_method
for _method_name, _builder_name in (("o_add", "add"),
("o_sub", "sub"),
("o_mul", "mul"),
("o_floordiv", "sdiv"),
("o_mod", "srem"),
("o_and", "and_"),
("o_xor", "xor"),
("o_or", "or_")):
setattr(VInt, _method_name, _make_vint_binop_method(_builder_name))
def _make_vint_cmp_method(icmp_val):
def cmp_method(self, other, builder):
if isinstance(other, VInt):
r = VBool()
if builder is not None:
target_bits = max(self.nbits, other.nbits)
left = self.o_intx(target_bits, builder)
right = other.o_intx(target_bits, builder)
r.set_ssa_value(
builder,
builder.icmp(
icmp_val, left.get_ssa_value(builder),
right.get_ssa_value(builder)))
return r
else:
return NotImplemented
return cmp_method
for _method_name, _icmp_val in (("o_eq", lc.ICMP_EQ),
("o_ne", lc.ICMP_NE),
("o_lt", lc.ICMP_SLT),
("o_le", lc.ICMP_SLE),
("o_gt", lc.ICMP_SGT),
("o_ge", lc.ICMP_SGE)):
setattr(VInt, _method_name, _make_vint_cmp_method(_icmp_val))
# Boolean type
class VBool(VInt):
def __init__(self):
VInt.__init__(self, 1)
def __repr__(self):
return "<VBool>"
def same_type(self, other):
return isinstance(other, VBool)
def merge(self, other):
if not isinstance(other, VBool):
raise TypeError
def set_const_value(self, builder, b):
VInt.set_const_value(self, builder, int(b))
def o_bool(self, builder):
r = VBool()
if builder is not None:
r.set_ssa_value(builder, self.get_ssa_value(builder))
return r
# Fraction type
def _gcd64(builder, a, b):
gcd_f = builder.basic_block.function.module.get_function_named("__gcd64")
return builder.call(gcd_f, [a, b])
def _frac_normalize(builder, numerator, denominator):
gcd = _gcd64(builder, numerator, denominator)
numerator = builder.sdiv(numerator, gcd)
denominator = builder.sdiv(denominator, gcd)
return numerator, denominator
def _frac_make_ssa(builder, numerator, denominator):
value = lc.Constant.undef(lc.Type.vector(lc.Type.int(64), 2))
value = builder.insert_element(
value, numerator, lc.Constant.int(lc.Type.int(), 0))
value = builder.insert_element(
value, denominator, lc.Constant.int(lc.Type.int(), 1))
return value
class VFraction(_Value):
def get_llvm_type(self):
return lc.Type.vector(lc.Type.int(64), 2)
def __repr__(self):
return "<VFraction>"
def same_type(self, other):
return isinstance(other, VFraction)
def merge(self, other):
if not isinstance(other, VFraction):
raise TypeError
def _nd(self, builder, invert=False):
ssa_value = self.get_ssa_value(builder)
numerator = builder.extract_element(
ssa_value, lc.Constant.int(lc.Type.int(), 0))
denominator = builder.extract_element(
ssa_value, lc.Constant.int(lc.Type.int(), 1))
if invert:
return denominator, numerator
else:
return numerator, denominator
def set_value_nd(self, builder, numerator, denominator):
numerator = numerator.o_int64(builder).get_ssa_value(builder)
denominator = denominator.o_int64(builder).get_ssa_value(builder)
numerator, denominator = _frac_normalize(
builder, numerator, denominator)
self.set_ssa_value(
builder, _frac_make_ssa(builder, numerator, denominator))
def set_value(self, builder, n):
if not isinstance(n, VFraction):
raise TypeError
self.set_ssa_value(builder, n.get_ssa_value(builder))
def o_bool(self, builder):
r = VBool()
if builder is not None:
zero = lc.Constant.int(lc.Type.int(64), 0)
numerator = builder.extract_element(
self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 0))
r.set_ssa_value(builder, builder.icmp(lc.ICMP_NE, numerator, zero))
return r
def o_intx(self, target_bits, builder):
if builder is None:
return VInt(target_bits)
else:
r = VInt(64)
numerator, denominator = self._nd(builder)
r.set_ssa_value(builder, builder.sdiv(numerator, denominator))
return r.o_intx(target_bits, builder)
def o_roundx(self, target_bits, builder):
if builder is None:
return VInt(target_bits)
else:
r = VInt(64)
numerator, denominator = self._nd(builder)
h_denominator = builder.ashr(denominator,
lc.Constant.int(lc.Type.int(), 1))
r_numerator = builder.add(numerator, h_denominator)
r.set_ssa_value(builder, builder.sdiv(r_numerator, denominator))
return r.o_intx(target_bits, builder)
def _o_eq_inv(self, other, builder, ne):
if isinstance(other, VFraction):
r = VBool()
if builder is not None:
ee = []
for i in range(2):
es = builder.extract_element(
self.get_ssa_value(builder),
lc.Constant.int(lc.Type.int(), i))
eo = builder.extract_element(
other.get_ssa_value(builder),
lc.Constant.int(lc.Type.int(), i))
ee.append(builder.icmp(lc.ICMP_EQ, es, eo))
ssa_r = builder.and_(ee[0], ee[1])
if ne:
ssa_r = builder.xor(ssa_r,
lc.Constant.int(lc.Type.int(1), 1))
r.set_ssa_value(builder, ssa_r)
return r
else:
return NotImplemented
def o_eq(self, other, builder):
return self._o_eq_inv(other, builder, False)
def o_ne(self, other, builder):
return self._o_eq_inv(other, builder, True)
def _o_muldiv(self, other, builder, div, invert=False):
r = VFraction()
if isinstance(other, VInt):
if builder is None:
return r
else:
numerator, denominator = self._nd(builder, invert)
i = other.get_ssa_value(builder)
if div:
gcd = _gcd64(i, numerator)
i = builder.sdiv(i, gcd)
numerator = builder.sdiv(numerator, gcd)
denominator = builder.mul(denominator, i)
else:
gcd = _gcd64(i, denominator)
i = builder.sdiv(i, gcd)
denominator = builder.sdiv(denominator, gcd)
numerator = builder.mul(numerator, i)
self.set_ssa_value(builder, _frac_make_ssa(builder, numerator,
denominator))
elif isinstance(other, VFraction):
if builder is None:
return r
else:
numerator, denominator = self._nd(builder, invert)
onumerator, odenominator = other._nd(builder)
if div:
numerator = builder.mul(numerator, odenominator)
denominator = builder.mul(denominator, onumerator)
else:
numerator = builder.mul(numerator, onumerator)
denominator = builder.mul(denominator, odenominator)
numerator, denominator = _frac_normalize(builder, numerator,
denominator)
self.set_ssa_value(
builder, _frac_make_ssa(builder, numerator, denominator))
else:
return NotImplemented
def o_mul(self, other, builder):
return self._o_muldiv(other, builder, False)
def o_truediv(self, other, builder):
return self._o_muldiv(other, builder, True)
def or_mul(self, other, builder):
return self._o_muldiv(other, builder, False)
def or_truediv(self, other, builder):
return self._o_muldiv(other, builder, False, True)
def o_floordiv(self, other, builder):
r = self.o_truediv(other, builder)
if r is NotImplemented:
return r
else:
return r.o_int(builder)
def or_floordiv(self, other, builder):
r = self.or_truediv(other, builder)
if r is NotImplemented:
return r
else:
return r.o_int(builder)
# Operators
def _make_unary_operator(op_name):
def op(x, builder):
try:
@ -446,9 +91,3 @@ def _make_operators():
return SimpleNamespace(**d)
operators = _make_operators()
def init_module(module):
func_type = lc.Type.function(
lc.Type.int(64), [lc.Type.int(64), lc.Type.int(64)])
module.add_function(func_type, "__gcd64")

View File

@ -2,15 +2,12 @@ import unittest
import ast
import inspect
from llvm import core as lc
from llvm import passes as lp
from llvm import ee as le
from artiq.language.core import int64
from artiq.py2llvm.infer_types import infer_function_types
from artiq.py2llvm import values
from artiq.py2llvm import compile_function
from artiq.py2llvm.tools import add_common_passes
from artiq.py2llvm import base_types
from artiq.py2llvm.module import Module
def test_types(choice):
@ -35,46 +32,40 @@ class FunctionTypesCase(unittest.TestCase):
dict())
def test_base_types(self):
self.assertIsInstance(self.ns["foo"], values.VBool)
self.assertIsInstance(self.ns["bar"], values.VNone)
self.assertIsInstance(self.ns["d"], values.VInt)
self.assertIsInstance(self.ns["foo"], base_types.VBool)
self.assertIsInstance(self.ns["bar"], base_types.VNone)
self.assertIsInstance(self.ns["d"], base_types.VInt)
self.assertEqual(self.ns["d"].nbits, 32)
self.assertIsInstance(self.ns["x"], values.VInt)
self.assertIsInstance(self.ns["x"], base_types.VInt)
self.assertEqual(self.ns["x"].nbits, 64)
def test_promotion(self):
for v in "abc":
self.assertIsInstance(self.ns[v], values.VInt)
self.assertIsInstance(self.ns[v], base_types.VInt)
self.assertEqual(self.ns[v].nbits, 64)
def test_return(self):
self.assertIsInstance(self.ns["return"], values.VInt)
self.assertIsInstance(self.ns["return"], base_types.VInt)
self.assertEqual(self.ns["return"].nbits, 64)
class CompiledFunction:
def __init__(self, function, param_types):
module = lc.Module.new("main")
values.init_module(module)
module = Module()
funcdef = ast.parse(inspect.getsource(function)).body[0]
self.function, self.retval = compile_function(
module, None, funcdef, param_types)
self.function, self.retval = module.compile_function(
funcdef, param_types)
self.argval = [param_types[arg.arg] for arg in funcdef.args.args]
self.executor = le.ExecutionEngine.new(module)
pass_manager = lp.PassManager.new()
add_common_passes(pass_manager)
pass_manager.run(module)
self.ee = module.get_ee()
def __call__(self, *args):
args_llvm = [
le.GenericValue.int(av.get_llvm_type(), a)
for av, a in zip(self.argval, args)]
result = self.executor.run_function(self.function, args_llvm)
if isinstance(self.retval, values.VBool):
result = self.ee.run_function(self.function, args_llvm)
if isinstance(self.retval, base_types.VBool):
return bool(result.as_int())
elif isinstance(self.retval, values.VInt):
elif isinstance(self.retval, base_types.VInt):
return result.as_int_signed()
else:
raise NotImplementedError
@ -90,6 +81,6 @@ def is_prime(x):
class CodeGenCase(unittest.TestCase):
def test_is_prime(self):
is_prime_c = CompiledFunction(is_prime, {"x": values.VInt(32)})
is_prime_c = CompiledFunction(is_prime, {"x": base_types.VInt(32)})
for i in range(200):
self.assertEqual(is_prime_c(i), is_prime(i))