py2llvm: reorganize, split 'values' module, factor LLVM module/pass management

This commit is contained in:
Sebastien Bourdeauducq 2014-09-07 14:09:03 +08:00
parent 58465e49fa
commit 3c8b541939
11 changed files with 452 additions and 460 deletions

View File

@ -1,7 +1,7 @@
from llvm import core as lc from llvm import core as lc
from llvm import target as lt from llvm import target as lt
from artiq.py2llvm import values from artiq.py2llvm import base_types
lt.initialize_all() lt.initialize_all()
@ -21,9 +21,9 @@ _chr_to_type = {
} }
_chr_to_value = { _chr_to_value = {
"n": lambda: values.VNone(), "n": lambda: base_types.VNone(),
"i": lambda: values.VInt(), "i": lambda: base_types.VInt(),
"I": lambda: values.VInt(64) "I": lambda: base_types.VInt(64)
} }

View File

@ -1,20 +1,6 @@
from llvm import core as lc from artiq.py2llvm.module import Module
from llvm import passes as lp
from artiq.py2llvm import values
from artiq.py2llvm.functions import compile_function
from artiq.py2llvm.tools import add_common_passes
def get_runtime_binary(env, funcdef): def get_runtime_binary(env, funcdef):
module = lc.Module.new("main") module = Module(env)
env.init_module(module) module.compile_function(funcdef, dict())
values.init_module(module) return module.emit_object()
compile_function(module, env, funcdef, dict())
pass_manager = lp.PassManager.new()
add_common_passes(pass_manager)
pass_manager.run(module)
return env.emit_object()

View File

@ -1,6 +1,6 @@
import ast import ast
from artiq.py2llvm import values from artiq.py2llvm import values, base_types, fractions
from artiq.py2llvm.tools import is_terminated from artiq.py2llvm.tools import is_terminated
@ -30,9 +30,9 @@ class Visitor:
def _visit_expr_NameConstant(self, node): def _visit_expr_NameConstant(self, node):
v = node.value v = node.value
if v is None: if v is None:
r = values.VNone() r = base_types.VNone()
elif isinstance(v, bool): elif isinstance(v, bool):
r = values.VBool() r = base_types.VBool()
else: else:
raise NotImplementedError raise NotImplementedError
if self.builder is not None: if self.builder is not None:
@ -43,9 +43,9 @@ class Visitor:
n = node.n n = node.n
if isinstance(n, int): if isinstance(n, int):
if abs(n) < 2**31: if abs(n) < 2**31:
r = values.VInt() r = base_types.VInt()
else: else:
r = values.VInt(64) r = base_types.VInt(64)
else: else:
raise NotImplementedError raise NotImplementedError
if self.builder is not None: if self.builder is not None:
@ -116,7 +116,7 @@ class Visitor:
return ast_unfuns[fn](self.visit_expression(node.args[0]), return ast_unfuns[fn](self.visit_expression(node.args[0]),
self.builder) self.builder)
elif fn == "Fraction": elif fn == "Fraction":
r = values.VFraction() r = fractions.VFraction()
if self.builder is not None: if self.builder is not None:
numerator = self.visit_expression(node.args[0]) numerator = self.visit_expression(node.args[0])
denominator = self.visit_expression(node.args[1]) denominator = self.visit_expression(node.args[1])
@ -213,10 +213,10 @@ class Visitor:
def _visit_stmt_Return(self, node): def _visit_stmt_Return(self, node):
if node.value is None: if node.value is None:
val = values.VNone() val = base_types.VNone()
else: else:
val = self.visit_expression(node.value) val = self.visit_expression(node.value)
if isinstance(val, values.VNone): if isinstance(val, base_types.VNone):
self.builder.ret_void() self.builder.ret_void()
else: else:
self.builder.ret(val.get_ssa_value(self.builder)) self.builder.ret(val.get_ssa_value(self.builder))

165
artiq/py2llvm/base_types.py Normal file
View File

@ -0,0 +1,165 @@
from llvm import core as lc
from artiq.py2llvm.values import VGeneric
class VNone(VGeneric):
def __repr__(self):
return "<VNone>"
def get_llvm_type(self):
return lc.Type.void()
def same_type(self, other):
return isinstance(other, VNone)
def merge(self, other):
if not isinstance(other, VNone):
raise TypeError
def alloca(self, builder, name):
pass
def o_bool(self, builder):
r = VBool()
if builder is not None:
r.set_const_value(builder, False)
return r
class VInt(VGeneric):
def __init__(self, nbits=32):
VGeneric.__init__(self)
self.nbits = nbits
def get_llvm_type(self):
return lc.Type.int(self.nbits)
def __repr__(self):
return "<VInt:{}>".format(self.nbits)
def same_type(self, other):
return isinstance(other, VInt) and other.nbits == self.nbits
def merge(self, other):
if isinstance(other, VInt) and not isinstance(other, VBool):
if other.nbits > self.nbits:
self.nbits = other.nbits
else:
raise TypeError
def set_value(self, builder, n):
self.set_ssa_value(
builder, n.o_intx(self.nbits, builder).get_ssa_value(builder))
def set_const_value(self, builder, n):
self.set_ssa_value(builder, lc.Constant.int(self.get_llvm_type(), n))
def o_bool(self, builder, inv=False):
r = VBool()
if builder is not None:
r.set_ssa_value(
builder, builder.icmp(
lc.ICMP_EQ if inv else lc.ICMP_NE,
self.get_ssa_value(builder),
lc.Constant.int(self.get_llvm_type(), 0)))
return r
def o_not(self, builder):
return self.o_bool(builder, True)
def o_intx(self, target_bits, builder):
r = VInt(target_bits)
if builder is not None:
if self.nbits == target_bits:
r.set_ssa_value(
builder, self.get_ssa_value(builder))
if self.nbits > target_bits:
r.set_ssa_value(
builder, builder.trunc(self.get_ssa_value(builder),
r.get_llvm_type()))
if self.nbits < target_bits:
r.set_ssa_value(
builder, builder.sext(self.get_ssa_value(builder),
r.get_llvm_type()))
return r
o_roundx = o_intx
def _make_vint_binop_method(builder_name):
def binop_method(self, other, builder):
if isinstance(other, VInt):
target_bits = max(self.nbits, other.nbits)
r = VInt(target_bits)
if builder is not None:
left = self.o_intx(target_bits, builder)
right = other.o_intx(target_bits, builder)
bf = getattr(builder, builder_name)
r.set_ssa_value(
builder, bf(left.get_ssa_value(builder),
right.get_ssa_value(builder)))
return r
else:
return NotImplemented
return binop_method
for _method_name, _builder_name in (("o_add", "add"),
("o_sub", "sub"),
("o_mul", "mul"),
("o_floordiv", "sdiv"),
("o_mod", "srem"),
("o_and", "and_"),
("o_xor", "xor"),
("o_or", "or_")):
setattr(VInt, _method_name, _make_vint_binop_method(_builder_name))
def _make_vint_cmp_method(icmp_val):
def cmp_method(self, other, builder):
if isinstance(other, VInt):
r = VBool()
if builder is not None:
target_bits = max(self.nbits, other.nbits)
left = self.o_intx(target_bits, builder)
right = other.o_intx(target_bits, builder)
r.set_ssa_value(
builder,
builder.icmp(
icmp_val, left.get_ssa_value(builder),
right.get_ssa_value(builder)))
return r
else:
return NotImplemented
return cmp_method
for _method_name, _icmp_val in (("o_eq", lc.ICMP_EQ),
("o_ne", lc.ICMP_NE),
("o_lt", lc.ICMP_SLT),
("o_le", lc.ICMP_SLE),
("o_gt", lc.ICMP_SGT),
("o_ge", lc.ICMP_SGE)):
setattr(VInt, _method_name, _make_vint_cmp_method(_icmp_val))
class VBool(VInt):
def __init__(self):
VInt.__init__(self, 1)
def __repr__(self):
return "<VBool>"
def same_type(self, other):
return isinstance(other, VBool)
def merge(self, other):
if not isinstance(other, VBool):
raise TypeError
def set_const_value(self, builder, b):
VInt.set_const_value(self, builder, int(b))
def o_bool(self, builder):
r = VBool()
if builder is not None:
r.set_ssa_value(builder, self.get_ssa_value(builder))
return r

192
artiq/py2llvm/fractions.py Normal file
View File

@ -0,0 +1,192 @@
from llvm import core as lc
from artiq.py2llvm.values import VGeneric
from artiq.py2llvm.base_types import VBool, VInt
def _gcd64(builder, a, b):
gcd_f = builder.basic_block.function.module.get_function_named("__gcd64")
return builder.call(gcd_f, [a, b])
def init_module(module):
func_type = lc.Type.function(
lc.Type.int(64), [lc.Type.int(64), lc.Type.int(64)])
module.add_function(func_type, "__gcd64")
def _frac_normalize(builder, numerator, denominator):
gcd = _gcd64(builder, numerator, denominator)
numerator = builder.sdiv(numerator, gcd)
denominator = builder.sdiv(denominator, gcd)
return numerator, denominator
def _frac_make_ssa(builder, numerator, denominator):
value = lc.Constant.undef(lc.Type.vector(lc.Type.int(64), 2))
value = builder.insert_element(
value, numerator, lc.Constant.int(lc.Type.int(), 0))
value = builder.insert_element(
value, denominator, lc.Constant.int(lc.Type.int(), 1))
return value
class VFraction(VGeneric):
def get_llvm_type(self):
return lc.Type.vector(lc.Type.int(64), 2)
def __repr__(self):
return "<VFraction>"
def same_type(self, other):
return isinstance(other, VFraction)
def merge(self, other):
if not isinstance(other, VFraction):
raise TypeError
def _nd(self, builder, invert=False):
ssa_value = self.get_ssa_value(builder)
numerator = builder.extract_element(
ssa_value, lc.Constant.int(lc.Type.int(), 0))
denominator = builder.extract_element(
ssa_value, lc.Constant.int(lc.Type.int(), 1))
if invert:
return denominator, numerator
else:
return numerator, denominator
def set_value_nd(self, builder, numerator, denominator):
numerator = numerator.o_int64(builder).get_ssa_value(builder)
denominator = denominator.o_int64(builder).get_ssa_value(builder)
numerator, denominator = _frac_normalize(
builder, numerator, denominator)
self.set_ssa_value(
builder, _frac_make_ssa(builder, numerator, denominator))
def set_value(self, builder, n):
if not isinstance(n, VFraction):
raise TypeError
self.set_ssa_value(builder, n.get_ssa_value(builder))
def o_bool(self, builder):
r = VBool()
if builder is not None:
zero = lc.Constant.int(lc.Type.int(64), 0)
numerator = builder.extract_element(
self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 0))
r.set_ssa_value(builder, builder.icmp(lc.ICMP_NE, numerator, zero))
return r
def o_intx(self, target_bits, builder):
if builder is None:
return VInt(target_bits)
else:
r = VInt(64)
numerator, denominator = self._nd(builder)
r.set_ssa_value(builder, builder.sdiv(numerator, denominator))
return r.o_intx(target_bits, builder)
def o_roundx(self, target_bits, builder):
if builder is None:
return VInt(target_bits)
else:
r = VInt(64)
numerator, denominator = self._nd(builder)
h_denominator = builder.ashr(denominator,
lc.Constant.int(lc.Type.int(), 1))
r_numerator = builder.add(numerator, h_denominator)
r.set_ssa_value(builder, builder.sdiv(r_numerator, denominator))
return r.o_intx(target_bits, builder)
def _o_eq_inv(self, other, builder, ne):
if isinstance(other, VFraction):
r = VBool()
if builder is not None:
ee = []
for i in range(2):
es = builder.extract_element(
self.get_ssa_value(builder),
lc.Constant.int(lc.Type.int(), i))
eo = builder.extract_element(
other.get_ssa_value(builder),
lc.Constant.int(lc.Type.int(), i))
ee.append(builder.icmp(lc.ICMP_EQ, es, eo))
ssa_r = builder.and_(ee[0], ee[1])
if ne:
ssa_r = builder.xor(ssa_r,
lc.Constant.int(lc.Type.int(1), 1))
r.set_ssa_value(builder, ssa_r)
return r
else:
return NotImplemented
def o_eq(self, other, builder):
return self._o_eq_inv(other, builder, False)
def o_ne(self, other, builder):
return self._o_eq_inv(other, builder, True)
def _o_muldiv(self, other, builder, div, invert=False):
r = VFraction()
if isinstance(other, VInt):
if builder is None:
return r
else:
numerator, denominator = self._nd(builder, invert)
i = other.get_ssa_value(builder)
if div:
gcd = _gcd64(i, numerator)
i = builder.sdiv(i, gcd)
numerator = builder.sdiv(numerator, gcd)
denominator = builder.mul(denominator, i)
else:
gcd = _gcd64(i, denominator)
i = builder.sdiv(i, gcd)
denominator = builder.sdiv(denominator, gcd)
numerator = builder.mul(numerator, i)
self.set_ssa_value(builder, _frac_make_ssa(builder, numerator,
denominator))
elif isinstance(other, VFraction):
if builder is None:
return r
else:
numerator, denominator = self._nd(builder, invert)
onumerator, odenominator = other._nd(builder)
if div:
numerator = builder.mul(numerator, odenominator)
denominator = builder.mul(denominator, onumerator)
else:
numerator = builder.mul(numerator, onumerator)
denominator = builder.mul(denominator, odenominator)
numerator, denominator = _frac_normalize(builder, numerator,
denominator)
self.set_ssa_value(
builder, _frac_make_ssa(builder, numerator, denominator))
else:
return NotImplemented
def o_mul(self, other, builder):
return self._o_muldiv(other, builder, False)
def o_truediv(self, other, builder):
return self._o_muldiv(other, builder, True)
def or_mul(self, other, builder):
return self._o_muldiv(other, builder, False)
def or_truediv(self, other, builder):
return self._o_muldiv(other, builder, False, True)
def o_floordiv(self, other, builder):
r = self.o_truediv(other, builder)
if r is NotImplemented:
return r
else:
return r.o_int(builder)
def or_floordiv(self, other, builder):
r = self.or_truediv(other, builder)
if r is NotImplemented:
return r
else:
return r.o_int(builder)

View File

@ -1,31 +0,0 @@
from llvm import core as lc
from artiq.py2llvm import infer_types, ast_body, values, tools
def compile_function(module, env, funcdef, param_types):
ns = infer_types.infer_function_types(env, funcdef, param_types)
retval = ns["return"]
function_type = lc.Type.function(retval.get_llvm_type(),
[ns[arg.arg].get_llvm_type() for arg in funcdef.args.args])
function = module.add_function(function_type, funcdef.name)
bb = function.append_basic_block("entry")
builder = lc.Builder.new(bb)
for arg_ast, arg_llvm in zip(funcdef.args.args, function.args):
arg_llvm.name = arg_ast.arg
for k, v in ns.items():
v.alloca(builder, k)
for arg_ast, arg_llvm in zip(funcdef.args.args, function.args):
ns[arg_ast.arg].set_ssa_value(builder, arg_llvm)
visitor = ast_body.Visitor(env, ns, builder)
visitor.visit_statements(funcdef.body)
if not tools.is_terminated(builder.basic_block):
if isinstance(retval, values.VNone):
builder.ret_void()
else:
builder.ret(retval.get_ssa_value(builder))
return function, retval

View File

@ -2,7 +2,7 @@ import ast
from copy import deepcopy from copy import deepcopy
from artiq.py2llvm.ast_body import Visitor from artiq.py2llvm.ast_body import Visitor
from artiq.py2llvm import values from artiq.py2llvm import base_types
class _TypeScanner(ast.NodeVisitor): class _TypeScanner(ast.NodeVisitor):
@ -31,7 +31,7 @@ class _TypeScanner(ast.NodeVisitor):
def visit_Return(self, node): def visit_Return(self, node):
if node.value is None: if node.value is None:
val = values.VNone() val = base_types.VNone()
else: else:
val = self.exprv.visit_expression(node.value) val = self.exprv.visit_expression(node.value)
ns = self.exprv.ns ns = self.exprv.ns
@ -51,5 +51,5 @@ def infer_function_types(env, node, param_types):
if all(v.same_type(prev_ns[k]) for k, v in ns.items()): if all(v.same_type(prev_ns[k]) for k, v in ns.items()):
# no more promotions - completed # no more promotions - completed
if "return" not in ns: if "return" not in ns:
ns["return"] = values.VNone() ns["return"] = base_types.VNone()
return ns return ns

59
artiq/py2llvm/module.py Normal file
View File

@ -0,0 +1,59 @@
from llvm import core as lc
from llvm import passes as lp
from llvm import ee as le
from artiq.py2llvm import infer_types, ast_body, base_types, fractions, tools
class Module:
def __init__(self, env=None):
self.llvm_module = lc.Module.new("main")
self.env = env
if self.env is not None:
self.env.init_module(self.llvm_module)
fractions.init_module(self.llvm_module)
def finalize(self):
pass_manager = lp.PassManager.new()
pass_manager.add(lp.PASS_MEM2REG)
pass_manager.add(lp.PASS_INSTCOMBINE)
pass_manager.add(lp.PASS_REASSOCIATE)
pass_manager.add(lp.PASS_GVN)
pass_manager.add(lp.PASS_SIMPLIFYCFG)
pass_manager.run(self.llvm_module)
def get_ee(self):
return le.ExecutionEngine.new(self.llvm_module)
def emit_object(self):
self.finalize()
return self.env.emit_object()
def compile_function(self, funcdef, param_types):
ns = infer_types.infer_function_types(self.env, funcdef, param_types)
retval = ns["return"]
function_type = lc.Type.function(retval.get_llvm_type(),
[ns[arg.arg].get_llvm_type() for arg in funcdef.args.args])
function = self.llvm_module.add_function(function_type, funcdef.name)
bb = function.append_basic_block("entry")
builder = lc.Builder.new(bb)
for arg_ast, arg_llvm in zip(funcdef.args.args, function.args):
arg_llvm.name = arg_ast.arg
for k, v in ns.items():
v.alloca(builder, k)
for arg_ast, arg_llvm in zip(funcdef.args.args, function.args):
ns[arg_ast.arg].set_ssa_value(builder, arg_llvm)
visitor = ast_body.Visitor(self.env, ns, builder)
visitor.visit_statements(funcdef.body)
if not tools.is_terminated(builder.basic_block):
if isinstance(retval, base_types.VNone):
builder.ret_void()
else:
builder.ret(retval.get_ssa_value(builder))
return function, retval

View File

@ -1,11 +1,2 @@
from llvm import passes as lp
def is_terminated(basic_block): def is_terminated(basic_block):
return basic_block.instructions and basic_block.instructions[-1].is_terminator return basic_block.instructions and basic_block.instructions[-1].is_terminator
def add_common_passes(pass_manager):
pass_manager.add(lp.PASS_MEM2REG)
pass_manager.add(lp.PASS_INSTCOMBINE)
pass_manager.add(lp.PASS_REASSOCIATE)
pass_manager.add(lp.PASS_GVN)
pass_manager.add(lp.PASS_SIMPLIFYCFG)

View File

@ -3,7 +3,7 @@ from types import SimpleNamespace
from llvm import core as lc from llvm import core as lc
class _Value: class VGeneric:
def __init__(self): def __init__(self):
self._llvm_value = None self._llvm_value = None
@ -40,361 +40,6 @@ class _Value:
return self.o_roundx(64, builder) return self.o_roundx(64, builder)
# None type
class VNone(_Value):
def __repr__(self):
return "<VNone>"
def get_llvm_type(self):
return lc.Type.void()
def same_type(self, other):
return isinstance(other, VNone)
def merge(self, other):
if not isinstance(other, VNone):
raise TypeError
def alloca(self, builder, name):
pass
def o_bool(self, builder):
r = VBool()
if builder is not None:
r.set_const_value(builder, False)
return r
# Integer type
class VInt(_Value):
def __init__(self, nbits=32):
_Value.__init__(self)
self.nbits = nbits
def get_llvm_type(self):
return lc.Type.int(self.nbits)
def __repr__(self):
return "<VInt:{}>".format(self.nbits)
def same_type(self, other):
return isinstance(other, VInt) and other.nbits == self.nbits
def merge(self, other):
if isinstance(other, VInt) and not isinstance(other, VBool):
if other.nbits > self.nbits:
self.nbits = other.nbits
else:
raise TypeError
def set_value(self, builder, n):
self.set_ssa_value(
builder, n.o_intx(self.nbits, builder).get_ssa_value(builder))
def set_const_value(self, builder, n):
self.set_ssa_value(builder, lc.Constant.int(self.get_llvm_type(), n))
def o_bool(self, builder, inv=False):
r = VBool()
if builder is not None:
r.set_ssa_value(
builder, builder.icmp(
lc.ICMP_EQ if inv else lc.ICMP_NE,
self.get_ssa_value(builder),
lc.Constant.int(self.get_llvm_type(), 0)))
return r
def o_not(self, builder):
return self.o_bool(builder, True)
def o_intx(self, target_bits, builder):
r = VInt(target_bits)
if builder is not None:
if self.nbits == target_bits:
r.set_ssa_value(
builder, self.get_ssa_value(builder))
if self.nbits > target_bits:
r.set_ssa_value(
builder, builder.trunc(self.get_ssa_value(builder),
r.get_llvm_type()))
if self.nbits < target_bits:
r.set_ssa_value(
builder, builder.sext(self.get_ssa_value(builder),
r.get_llvm_type()))
return r
o_roundx = o_intx
def _make_vint_binop_method(builder_name):
def binop_method(self, other, builder):
if isinstance(other, VInt):
target_bits = max(self.nbits, other.nbits)
r = VInt(target_bits)
if builder is not None:
left = self.o_intx(target_bits, builder)
right = other.o_intx(target_bits, builder)
bf = getattr(builder, builder_name)
r.set_ssa_value(
builder, bf(left.get_ssa_value(builder),
right.get_ssa_value(builder)))
return r
else:
return NotImplemented
return binop_method
for _method_name, _builder_name in (("o_add", "add"),
("o_sub", "sub"),
("o_mul", "mul"),
("o_floordiv", "sdiv"),
("o_mod", "srem"),
("o_and", "and_"),
("o_xor", "xor"),
("o_or", "or_")):
setattr(VInt, _method_name, _make_vint_binop_method(_builder_name))
def _make_vint_cmp_method(icmp_val):
def cmp_method(self, other, builder):
if isinstance(other, VInt):
r = VBool()
if builder is not None:
target_bits = max(self.nbits, other.nbits)
left = self.o_intx(target_bits, builder)
right = other.o_intx(target_bits, builder)
r.set_ssa_value(
builder,
builder.icmp(
icmp_val, left.get_ssa_value(builder),
right.get_ssa_value(builder)))
return r
else:
return NotImplemented
return cmp_method
for _method_name, _icmp_val in (("o_eq", lc.ICMP_EQ),
("o_ne", lc.ICMP_NE),
("o_lt", lc.ICMP_SLT),
("o_le", lc.ICMP_SLE),
("o_gt", lc.ICMP_SGT),
("o_ge", lc.ICMP_SGE)):
setattr(VInt, _method_name, _make_vint_cmp_method(_icmp_val))
# Boolean type
class VBool(VInt):
def __init__(self):
VInt.__init__(self, 1)
def __repr__(self):
return "<VBool>"
def same_type(self, other):
return isinstance(other, VBool)
def merge(self, other):
if not isinstance(other, VBool):
raise TypeError
def set_const_value(self, builder, b):
VInt.set_const_value(self, builder, int(b))
def o_bool(self, builder):
r = VBool()
if builder is not None:
r.set_ssa_value(builder, self.get_ssa_value(builder))
return r
# Fraction type
def _gcd64(builder, a, b):
gcd_f = builder.basic_block.function.module.get_function_named("__gcd64")
return builder.call(gcd_f, [a, b])
def _frac_normalize(builder, numerator, denominator):
gcd = _gcd64(builder, numerator, denominator)
numerator = builder.sdiv(numerator, gcd)
denominator = builder.sdiv(denominator, gcd)
return numerator, denominator
def _frac_make_ssa(builder, numerator, denominator):
value = lc.Constant.undef(lc.Type.vector(lc.Type.int(64), 2))
value = builder.insert_element(
value, numerator, lc.Constant.int(lc.Type.int(), 0))
value = builder.insert_element(
value, denominator, lc.Constant.int(lc.Type.int(), 1))
return value
class VFraction(_Value):
def get_llvm_type(self):
return lc.Type.vector(lc.Type.int(64), 2)
def __repr__(self):
return "<VFraction>"
def same_type(self, other):
return isinstance(other, VFraction)
def merge(self, other):
if not isinstance(other, VFraction):
raise TypeError
def _nd(self, builder, invert=False):
ssa_value = self.get_ssa_value(builder)
numerator = builder.extract_element(
ssa_value, lc.Constant.int(lc.Type.int(), 0))
denominator = builder.extract_element(
ssa_value, lc.Constant.int(lc.Type.int(), 1))
if invert:
return denominator, numerator
else:
return numerator, denominator
def set_value_nd(self, builder, numerator, denominator):
numerator = numerator.o_int64(builder).get_ssa_value(builder)
denominator = denominator.o_int64(builder).get_ssa_value(builder)
numerator, denominator = _frac_normalize(
builder, numerator, denominator)
self.set_ssa_value(
builder, _frac_make_ssa(builder, numerator, denominator))
def set_value(self, builder, n):
if not isinstance(n, VFraction):
raise TypeError
self.set_ssa_value(builder, n.get_ssa_value(builder))
def o_bool(self, builder):
r = VBool()
if builder is not None:
zero = lc.Constant.int(lc.Type.int(64), 0)
numerator = builder.extract_element(
self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 0))
r.set_ssa_value(builder, builder.icmp(lc.ICMP_NE, numerator, zero))
return r
def o_intx(self, target_bits, builder):
if builder is None:
return VInt(target_bits)
else:
r = VInt(64)
numerator, denominator = self._nd(builder)
r.set_ssa_value(builder, builder.sdiv(numerator, denominator))
return r.o_intx(target_bits, builder)
def o_roundx(self, target_bits, builder):
if builder is None:
return VInt(target_bits)
else:
r = VInt(64)
numerator, denominator = self._nd(builder)
h_denominator = builder.ashr(denominator,
lc.Constant.int(lc.Type.int(), 1))
r_numerator = builder.add(numerator, h_denominator)
r.set_ssa_value(builder, builder.sdiv(r_numerator, denominator))
return r.o_intx(target_bits, builder)
def _o_eq_inv(self, other, builder, ne):
if isinstance(other, VFraction):
r = VBool()
if builder is not None:
ee = []
for i in range(2):
es = builder.extract_element(
self.get_ssa_value(builder),
lc.Constant.int(lc.Type.int(), i))
eo = builder.extract_element(
other.get_ssa_value(builder),
lc.Constant.int(lc.Type.int(), i))
ee.append(builder.icmp(lc.ICMP_EQ, es, eo))
ssa_r = builder.and_(ee[0], ee[1])
if ne:
ssa_r = builder.xor(ssa_r,
lc.Constant.int(lc.Type.int(1), 1))
r.set_ssa_value(builder, ssa_r)
return r
else:
return NotImplemented
def o_eq(self, other, builder):
return self._o_eq_inv(other, builder, False)
def o_ne(self, other, builder):
return self._o_eq_inv(other, builder, True)
def _o_muldiv(self, other, builder, div, invert=False):
r = VFraction()
if isinstance(other, VInt):
if builder is None:
return r
else:
numerator, denominator = self._nd(builder, invert)
i = other.get_ssa_value(builder)
if div:
gcd = _gcd64(i, numerator)
i = builder.sdiv(i, gcd)
numerator = builder.sdiv(numerator, gcd)
denominator = builder.mul(denominator, i)
else:
gcd = _gcd64(i, denominator)
i = builder.sdiv(i, gcd)
denominator = builder.sdiv(denominator, gcd)
numerator = builder.mul(numerator, i)
self.set_ssa_value(builder, _frac_make_ssa(builder, numerator,
denominator))
elif isinstance(other, VFraction):
if builder is None:
return r
else:
numerator, denominator = self._nd(builder, invert)
onumerator, odenominator = other._nd(builder)
if div:
numerator = builder.mul(numerator, odenominator)
denominator = builder.mul(denominator, onumerator)
else:
numerator = builder.mul(numerator, onumerator)
denominator = builder.mul(denominator, odenominator)
numerator, denominator = _frac_normalize(builder, numerator,
denominator)
self.set_ssa_value(
builder, _frac_make_ssa(builder, numerator, denominator))
else:
return NotImplemented
def o_mul(self, other, builder):
return self._o_muldiv(other, builder, False)
def o_truediv(self, other, builder):
return self._o_muldiv(other, builder, True)
def or_mul(self, other, builder):
return self._o_muldiv(other, builder, False)
def or_truediv(self, other, builder):
return self._o_muldiv(other, builder, False, True)
def o_floordiv(self, other, builder):
r = self.o_truediv(other, builder)
if r is NotImplemented:
return r
else:
return r.o_int(builder)
def or_floordiv(self, other, builder):
r = self.or_truediv(other, builder)
if r is NotImplemented:
return r
else:
return r.o_int(builder)
# Operators
def _make_unary_operator(op_name): def _make_unary_operator(op_name):
def op(x, builder): def op(x, builder):
try: try:
@ -446,9 +91,3 @@ def _make_operators():
return SimpleNamespace(**d) return SimpleNamespace(**d)
operators = _make_operators() operators = _make_operators()
def init_module(module):
func_type = lc.Type.function(
lc.Type.int(64), [lc.Type.int(64), lc.Type.int(64)])
module.add_function(func_type, "__gcd64")

View File

@ -2,15 +2,12 @@ import unittest
import ast import ast
import inspect import inspect
from llvm import core as lc
from llvm import passes as lp
from llvm import ee as le from llvm import ee as le
from artiq.language.core import int64 from artiq.language.core import int64
from artiq.py2llvm.infer_types import infer_function_types from artiq.py2llvm.infer_types import infer_function_types
from artiq.py2llvm import values from artiq.py2llvm import base_types
from artiq.py2llvm import compile_function from artiq.py2llvm.module import Module
from artiq.py2llvm.tools import add_common_passes
def test_types(choice): def test_types(choice):
@ -35,46 +32,40 @@ class FunctionTypesCase(unittest.TestCase):
dict()) dict())
def test_base_types(self): def test_base_types(self):
self.assertIsInstance(self.ns["foo"], values.VBool) self.assertIsInstance(self.ns["foo"], base_types.VBool)
self.assertIsInstance(self.ns["bar"], values.VNone) self.assertIsInstance(self.ns["bar"], base_types.VNone)
self.assertIsInstance(self.ns["d"], values.VInt) self.assertIsInstance(self.ns["d"], base_types.VInt)
self.assertEqual(self.ns["d"].nbits, 32) self.assertEqual(self.ns["d"].nbits, 32)
self.assertIsInstance(self.ns["x"], values.VInt) self.assertIsInstance(self.ns["x"], base_types.VInt)
self.assertEqual(self.ns["x"].nbits, 64) self.assertEqual(self.ns["x"].nbits, 64)
def test_promotion(self): def test_promotion(self):
for v in "abc": for v in "abc":
self.assertIsInstance(self.ns[v], values.VInt) self.assertIsInstance(self.ns[v], base_types.VInt)
self.assertEqual(self.ns[v].nbits, 64) self.assertEqual(self.ns[v].nbits, 64)
def test_return(self): def test_return(self):
self.assertIsInstance(self.ns["return"], values.VInt) self.assertIsInstance(self.ns["return"], base_types.VInt)
self.assertEqual(self.ns["return"].nbits, 64) self.assertEqual(self.ns["return"].nbits, 64)
class CompiledFunction: class CompiledFunction:
def __init__(self, function, param_types): def __init__(self, function, param_types):
module = lc.Module.new("main") module = Module()
values.init_module(module)
funcdef = ast.parse(inspect.getsource(function)).body[0] funcdef = ast.parse(inspect.getsource(function)).body[0]
self.function, self.retval = compile_function( self.function, self.retval = module.compile_function(
module, None, funcdef, param_types) funcdef, param_types)
self.argval = [param_types[arg.arg] for arg in funcdef.args.args] self.argval = [param_types[arg.arg] for arg in funcdef.args.args]
self.ee = module.get_ee()
self.executor = le.ExecutionEngine.new(module)
pass_manager = lp.PassManager.new()
add_common_passes(pass_manager)
pass_manager.run(module)
def __call__(self, *args): def __call__(self, *args):
args_llvm = [ args_llvm = [
le.GenericValue.int(av.get_llvm_type(), a) le.GenericValue.int(av.get_llvm_type(), a)
for av, a in zip(self.argval, args)] for av, a in zip(self.argval, args)]
result = self.executor.run_function(self.function, args_llvm) result = self.ee.run_function(self.function, args_llvm)
if isinstance(self.retval, values.VBool): if isinstance(self.retval, base_types.VBool):
return bool(result.as_int()) return bool(result.as_int())
elif isinstance(self.retval, values.VInt): elif isinstance(self.retval, base_types.VInt):
return result.as_int_signed() return result.as_int_signed()
else: else:
raise NotImplementedError raise NotImplementedError
@ -90,6 +81,6 @@ def is_prime(x):
class CodeGenCase(unittest.TestCase): class CodeGenCase(unittest.TestCase):
def test_is_prime(self): def test_is_prime(self):
is_prime_c = CompiledFunction(is_prime, {"x": values.VInt(32)}) is_prime_c = CompiledFunction(is_prime, {"x": base_types.VInt(32)})
for i in range(200): for i in range(200):
self.assertEqual(is_prime_c(i), is_prime(i)) self.assertEqual(is_prime_c(i), is_prime(i))