From 3c8b5419395f1ae97c8bf9705b27714d0ea94a46 Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Sun, 7 Sep 2014 14:09:03 +0800 Subject: [PATCH] py2llvm: reorganize, split 'values' module, factor LLVM module/pass management --- artiq/devices/runtime.py | 8 +- artiq/py2llvm/__init__.py | 22 +-- artiq/py2llvm/ast_body.py | 16 +- artiq/py2llvm/base_types.py | 165 ++++++++++++++++ artiq/py2llvm/fractions.py | 192 ++++++++++++++++++ artiq/py2llvm/functions.py | 31 --- artiq/py2llvm/infer_types.py | 6 +- artiq/py2llvm/module.py | 59 ++++++ artiq/py2llvm/tools.py | 9 - artiq/py2llvm/values.py | 363 +---------------------------------- test/py2llvm.py | 41 ++-- 11 files changed, 452 insertions(+), 460 deletions(-) create mode 100644 artiq/py2llvm/base_types.py create mode 100644 artiq/py2llvm/fractions.py delete mode 100644 artiq/py2llvm/functions.py create mode 100644 artiq/py2llvm/module.py diff --git a/artiq/devices/runtime.py b/artiq/devices/runtime.py index cfc0beb71..2554fced9 100644 --- a/artiq/devices/runtime.py +++ b/artiq/devices/runtime.py @@ -1,7 +1,7 @@ from llvm import core as lc from llvm import target as lt -from artiq.py2llvm import values +from artiq.py2llvm import base_types lt.initialize_all() @@ -21,9 +21,9 @@ _chr_to_type = { } _chr_to_value = { - "n": lambda: values.VNone(), - "i": lambda: values.VInt(), - "I": lambda: values.VInt(64) + "n": lambda: base_types.VNone(), + "i": lambda: base_types.VInt(), + "I": lambda: base_types.VInt(64) } diff --git a/artiq/py2llvm/__init__.py b/artiq/py2llvm/__init__.py index 6f23f0fc0..88dd35808 100644 --- a/artiq/py2llvm/__init__.py +++ b/artiq/py2llvm/__init__.py @@ -1,20 +1,6 @@ -from llvm import core as lc -from llvm import passes as lp - -from artiq.py2llvm import values -from artiq.py2llvm.functions import compile_function -from artiq.py2llvm.tools import add_common_passes - +from artiq.py2llvm.module import Module def get_runtime_binary(env, funcdef): - module = lc.Module.new("main") - env.init_module(module) - values.init_module(module) - - compile_function(module, env, funcdef, dict()) - - pass_manager = lp.PassManager.new() - add_common_passes(pass_manager) - pass_manager.run(module) - - return env.emit_object() + module = Module(env) + module.compile_function(funcdef, dict()) + return module.emit_object() diff --git a/artiq/py2llvm/ast_body.py b/artiq/py2llvm/ast_body.py index 86d731251..c14b2dfdd 100644 --- a/artiq/py2llvm/ast_body.py +++ b/artiq/py2llvm/ast_body.py @@ -1,6 +1,6 @@ import ast -from artiq.py2llvm import values +from artiq.py2llvm import values, base_types, fractions from artiq.py2llvm.tools import is_terminated @@ -30,9 +30,9 @@ class Visitor: def _visit_expr_NameConstant(self, node): v = node.value if v is None: - r = values.VNone() + r = base_types.VNone() elif isinstance(v, bool): - r = values.VBool() + r = base_types.VBool() else: raise NotImplementedError if self.builder is not None: @@ -43,9 +43,9 @@ class Visitor: n = node.n if isinstance(n, int): if abs(n) < 2**31: - r = values.VInt() + r = base_types.VInt() else: - r = values.VInt(64) + r = base_types.VInt(64) else: raise NotImplementedError if self.builder is not None: @@ -116,7 +116,7 @@ class Visitor: return ast_unfuns[fn](self.visit_expression(node.args[0]), self.builder) elif fn == "Fraction": - r = values.VFraction() + r = fractions.VFraction() if self.builder is not None: numerator = self.visit_expression(node.args[0]) denominator = self.visit_expression(node.args[1]) @@ -213,10 +213,10 @@ class Visitor: def _visit_stmt_Return(self, node): if node.value is None: - val = values.VNone() + val = base_types.VNone() else: val = self.visit_expression(node.value) - if isinstance(val, values.VNone): + if isinstance(val, base_types.VNone): self.builder.ret_void() else: self.builder.ret(val.get_ssa_value(self.builder)) diff --git a/artiq/py2llvm/base_types.py b/artiq/py2llvm/base_types.py new file mode 100644 index 000000000..a534795ca --- /dev/null +++ b/artiq/py2llvm/base_types.py @@ -0,0 +1,165 @@ +from llvm import core as lc + +from artiq.py2llvm.values import VGeneric + + +class VNone(VGeneric): + def __repr__(self): + return "" + + def get_llvm_type(self): + return lc.Type.void() + + def same_type(self, other): + return isinstance(other, VNone) + + def merge(self, other): + if not isinstance(other, VNone): + raise TypeError + + def alloca(self, builder, name): + pass + + def o_bool(self, builder): + r = VBool() + if builder is not None: + r.set_const_value(builder, False) + return r + + +class VInt(VGeneric): + def __init__(self, nbits=32): + VGeneric.__init__(self) + self.nbits = nbits + + def get_llvm_type(self): + return lc.Type.int(self.nbits) + + def __repr__(self): + return "".format(self.nbits) + + def same_type(self, other): + return isinstance(other, VInt) and other.nbits == self.nbits + + def merge(self, other): + if isinstance(other, VInt) and not isinstance(other, VBool): + if other.nbits > self.nbits: + self.nbits = other.nbits + else: + raise TypeError + + def set_value(self, builder, n): + self.set_ssa_value( + builder, n.o_intx(self.nbits, builder).get_ssa_value(builder)) + + def set_const_value(self, builder, n): + self.set_ssa_value(builder, lc.Constant.int(self.get_llvm_type(), n)) + + def o_bool(self, builder, inv=False): + r = VBool() + if builder is not None: + r.set_ssa_value( + builder, builder.icmp( + lc.ICMP_EQ if inv else lc.ICMP_NE, + self.get_ssa_value(builder), + lc.Constant.int(self.get_llvm_type(), 0))) + return r + + def o_not(self, builder): + return self.o_bool(builder, True) + + def o_intx(self, target_bits, builder): + r = VInt(target_bits) + if builder is not None: + if self.nbits == target_bits: + r.set_ssa_value( + builder, self.get_ssa_value(builder)) + if self.nbits > target_bits: + r.set_ssa_value( + builder, builder.trunc(self.get_ssa_value(builder), + r.get_llvm_type())) + if self.nbits < target_bits: + r.set_ssa_value( + builder, builder.sext(self.get_ssa_value(builder), + r.get_llvm_type())) + return r + o_roundx = o_intx + + +def _make_vint_binop_method(builder_name): + def binop_method(self, other, builder): + if isinstance(other, VInt): + target_bits = max(self.nbits, other.nbits) + r = VInt(target_bits) + if builder is not None: + left = self.o_intx(target_bits, builder) + right = other.o_intx(target_bits, builder) + bf = getattr(builder, builder_name) + r.set_ssa_value( + builder, bf(left.get_ssa_value(builder), + right.get_ssa_value(builder))) + return r + else: + return NotImplemented + return binop_method + +for _method_name, _builder_name in (("o_add", "add"), + ("o_sub", "sub"), + ("o_mul", "mul"), + ("o_floordiv", "sdiv"), + ("o_mod", "srem"), + ("o_and", "and_"), + ("o_xor", "xor"), + ("o_or", "or_")): + setattr(VInt, _method_name, _make_vint_binop_method(_builder_name)) + + +def _make_vint_cmp_method(icmp_val): + def cmp_method(self, other, builder): + if isinstance(other, VInt): + r = VBool() + if builder is not None: + target_bits = max(self.nbits, other.nbits) + left = self.o_intx(target_bits, builder) + right = other.o_intx(target_bits, builder) + r.set_ssa_value( + builder, + builder.icmp( + icmp_val, left.get_ssa_value(builder), + right.get_ssa_value(builder))) + return r + else: + return NotImplemented + return cmp_method + +for _method_name, _icmp_val in (("o_eq", lc.ICMP_EQ), + ("o_ne", lc.ICMP_NE), + ("o_lt", lc.ICMP_SLT), + ("o_le", lc.ICMP_SLE), + ("o_gt", lc.ICMP_SGT), + ("o_ge", lc.ICMP_SGE)): + setattr(VInt, _method_name, _make_vint_cmp_method(_icmp_val)) + + +class VBool(VInt): + def __init__(self): + VInt.__init__(self, 1) + + def __repr__(self): + return "" + + def same_type(self, other): + return isinstance(other, VBool) + + def merge(self, other): + if not isinstance(other, VBool): + raise TypeError + + def set_const_value(self, builder, b): + VInt.set_const_value(self, builder, int(b)) + + def o_bool(self, builder): + r = VBool() + if builder is not None: + r.set_ssa_value(builder, self.get_ssa_value(builder)) + return r diff --git a/artiq/py2llvm/fractions.py b/artiq/py2llvm/fractions.py new file mode 100644 index 000000000..ff46e90ae --- /dev/null +++ b/artiq/py2llvm/fractions.py @@ -0,0 +1,192 @@ +from llvm import core as lc + +from artiq.py2llvm.values import VGeneric +from artiq.py2llvm.base_types import VBool, VInt + + +def _gcd64(builder, a, b): + gcd_f = builder.basic_block.function.module.get_function_named("__gcd64") + return builder.call(gcd_f, [a, b]) + +def init_module(module): + func_type = lc.Type.function( + lc.Type.int(64), [lc.Type.int(64), lc.Type.int(64)]) + module.add_function(func_type, "__gcd64") + + +def _frac_normalize(builder, numerator, denominator): + gcd = _gcd64(builder, numerator, denominator) + numerator = builder.sdiv(numerator, gcd) + denominator = builder.sdiv(denominator, gcd) + return numerator, denominator + + +def _frac_make_ssa(builder, numerator, denominator): + value = lc.Constant.undef(lc.Type.vector(lc.Type.int(64), 2)) + value = builder.insert_element( + value, numerator, lc.Constant.int(lc.Type.int(), 0)) + value = builder.insert_element( + value, denominator, lc.Constant.int(lc.Type.int(), 1)) + return value + + +class VFraction(VGeneric): + def get_llvm_type(self): + return lc.Type.vector(lc.Type.int(64), 2) + + def __repr__(self): + return "" + + def same_type(self, other): + return isinstance(other, VFraction) + + def merge(self, other): + if not isinstance(other, VFraction): + raise TypeError + + def _nd(self, builder, invert=False): + ssa_value = self.get_ssa_value(builder) + numerator = builder.extract_element( + ssa_value, lc.Constant.int(lc.Type.int(), 0)) + denominator = builder.extract_element( + ssa_value, lc.Constant.int(lc.Type.int(), 1)) + if invert: + return denominator, numerator + else: + return numerator, denominator + + def set_value_nd(self, builder, numerator, denominator): + numerator = numerator.o_int64(builder).get_ssa_value(builder) + denominator = denominator.o_int64(builder).get_ssa_value(builder) + numerator, denominator = _frac_normalize( + builder, numerator, denominator) + self.set_ssa_value( + builder, _frac_make_ssa(builder, numerator, denominator)) + + def set_value(self, builder, n): + if not isinstance(n, VFraction): + raise TypeError + self.set_ssa_value(builder, n.get_ssa_value(builder)) + + def o_bool(self, builder): + r = VBool() + if builder is not None: + zero = lc.Constant.int(lc.Type.int(64), 0) + numerator = builder.extract_element( + self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 0)) + r.set_ssa_value(builder, builder.icmp(lc.ICMP_NE, numerator, zero)) + return r + + def o_intx(self, target_bits, builder): + if builder is None: + return VInt(target_bits) + else: + r = VInt(64) + numerator, denominator = self._nd(builder) + r.set_ssa_value(builder, builder.sdiv(numerator, denominator)) + return r.o_intx(target_bits, builder) + + def o_roundx(self, target_bits, builder): + if builder is None: + return VInt(target_bits) + else: + r = VInt(64) + numerator, denominator = self._nd(builder) + h_denominator = builder.ashr(denominator, + lc.Constant.int(lc.Type.int(), 1)) + r_numerator = builder.add(numerator, h_denominator) + r.set_ssa_value(builder, builder.sdiv(r_numerator, denominator)) + return r.o_intx(target_bits, builder) + + def _o_eq_inv(self, other, builder, ne): + if isinstance(other, VFraction): + r = VBool() + if builder is not None: + ee = [] + for i in range(2): + es = builder.extract_element( + self.get_ssa_value(builder), + lc.Constant.int(lc.Type.int(), i)) + eo = builder.extract_element( + other.get_ssa_value(builder), + lc.Constant.int(lc.Type.int(), i)) + ee.append(builder.icmp(lc.ICMP_EQ, es, eo)) + ssa_r = builder.and_(ee[0], ee[1]) + if ne: + ssa_r = builder.xor(ssa_r, + lc.Constant.int(lc.Type.int(1), 1)) + r.set_ssa_value(builder, ssa_r) + return r + else: + return NotImplemented + + def o_eq(self, other, builder): + return self._o_eq_inv(other, builder, False) + + def o_ne(self, other, builder): + return self._o_eq_inv(other, builder, True) + + def _o_muldiv(self, other, builder, div, invert=False): + r = VFraction() + if isinstance(other, VInt): + if builder is None: + return r + else: + numerator, denominator = self._nd(builder, invert) + i = other.get_ssa_value(builder) + if div: + gcd = _gcd64(i, numerator) + i = builder.sdiv(i, gcd) + numerator = builder.sdiv(numerator, gcd) + denominator = builder.mul(denominator, i) + else: + gcd = _gcd64(i, denominator) + i = builder.sdiv(i, gcd) + denominator = builder.sdiv(denominator, gcd) + numerator = builder.mul(numerator, i) + self.set_ssa_value(builder, _frac_make_ssa(builder, numerator, + denominator)) + elif isinstance(other, VFraction): + if builder is None: + return r + else: + numerator, denominator = self._nd(builder, invert) + onumerator, odenominator = other._nd(builder) + if div: + numerator = builder.mul(numerator, odenominator) + denominator = builder.mul(denominator, onumerator) + else: + numerator = builder.mul(numerator, onumerator) + denominator = builder.mul(denominator, odenominator) + numerator, denominator = _frac_normalize(builder, numerator, + denominator) + self.set_ssa_value( + builder, _frac_make_ssa(builder, numerator, denominator)) + else: + return NotImplemented + + def o_mul(self, other, builder): + return self._o_muldiv(other, builder, False) + + def o_truediv(self, other, builder): + return self._o_muldiv(other, builder, True) + + def or_mul(self, other, builder): + return self._o_muldiv(other, builder, False) + + def or_truediv(self, other, builder): + return self._o_muldiv(other, builder, False, True) + + def o_floordiv(self, other, builder): + r = self.o_truediv(other, builder) + if r is NotImplemented: + return r + else: + return r.o_int(builder) + + def or_floordiv(self, other, builder): + r = self.or_truediv(other, builder) + if r is NotImplemented: + return r + else: + return r.o_int(builder) diff --git a/artiq/py2llvm/functions.py b/artiq/py2llvm/functions.py deleted file mode 100644 index cab173105..000000000 --- a/artiq/py2llvm/functions.py +++ /dev/null @@ -1,31 +0,0 @@ -from llvm import core as lc - -from artiq.py2llvm import infer_types, ast_body, values, tools - -def compile_function(module, env, funcdef, param_types): - ns = infer_types.infer_function_types(env, funcdef, param_types) - retval = ns["return"] - - function_type = lc.Type.function(retval.get_llvm_type(), - [ns[arg.arg].get_llvm_type() for arg in funcdef.args.args]) - function = module.add_function(function_type, funcdef.name) - bb = function.append_basic_block("entry") - builder = lc.Builder.new(bb) - - for arg_ast, arg_llvm in zip(funcdef.args.args, function.args): - arg_llvm.name = arg_ast.arg - for k, v in ns.items(): - v.alloca(builder, k) - for arg_ast, arg_llvm in zip(funcdef.args.args, function.args): - ns[arg_ast.arg].set_ssa_value(builder, arg_llvm) - - visitor = ast_body.Visitor(env, ns, builder) - visitor.visit_statements(funcdef.body) - - if not tools.is_terminated(builder.basic_block): - if isinstance(retval, values.VNone): - builder.ret_void() - else: - builder.ret(retval.get_ssa_value(builder)) - - return function, retval diff --git a/artiq/py2llvm/infer_types.py b/artiq/py2llvm/infer_types.py index 5e542db34..9cd830735 100644 --- a/artiq/py2llvm/infer_types.py +++ b/artiq/py2llvm/infer_types.py @@ -2,7 +2,7 @@ import ast from copy import deepcopy from artiq.py2llvm.ast_body import Visitor -from artiq.py2llvm import values +from artiq.py2llvm import base_types class _TypeScanner(ast.NodeVisitor): @@ -31,7 +31,7 @@ class _TypeScanner(ast.NodeVisitor): def visit_Return(self, node): if node.value is None: - val = values.VNone() + val = base_types.VNone() else: val = self.exprv.visit_expression(node.value) ns = self.exprv.ns @@ -51,5 +51,5 @@ def infer_function_types(env, node, param_types): if all(v.same_type(prev_ns[k]) for k, v in ns.items()): # no more promotions - completed if "return" not in ns: - ns["return"] = values.VNone() + ns["return"] = base_types.VNone() return ns diff --git a/artiq/py2llvm/module.py b/artiq/py2llvm/module.py new file mode 100644 index 000000000..4dfd646ba --- /dev/null +++ b/artiq/py2llvm/module.py @@ -0,0 +1,59 @@ +from llvm import core as lc +from llvm import passes as lp +from llvm import ee as le + +from artiq.py2llvm import infer_types, ast_body, base_types, fractions, tools + + +class Module: + def __init__(self, env=None): + self.llvm_module = lc.Module.new("main") + self.env = env + + if self.env is not None: + self.env.init_module(self.llvm_module) + fractions.init_module(self.llvm_module) + + def finalize(self): + pass_manager = lp.PassManager.new() + pass_manager.add(lp.PASS_MEM2REG) + pass_manager.add(lp.PASS_INSTCOMBINE) + pass_manager.add(lp.PASS_REASSOCIATE) + pass_manager.add(lp.PASS_GVN) + pass_manager.add(lp.PASS_SIMPLIFYCFG) + pass_manager.run(self.llvm_module) + + def get_ee(self): + return le.ExecutionEngine.new(self.llvm_module) + + def emit_object(self): + self.finalize() + return self.env.emit_object() + + def compile_function(self, funcdef, param_types): + ns = infer_types.infer_function_types(self.env, funcdef, param_types) + retval = ns["return"] + + function_type = lc.Type.function(retval.get_llvm_type(), + [ns[arg.arg].get_llvm_type() for arg in funcdef.args.args]) + function = self.llvm_module.add_function(function_type, funcdef.name) + bb = function.append_basic_block("entry") + builder = lc.Builder.new(bb) + + for arg_ast, arg_llvm in zip(funcdef.args.args, function.args): + arg_llvm.name = arg_ast.arg + for k, v in ns.items(): + v.alloca(builder, k) + for arg_ast, arg_llvm in zip(funcdef.args.args, function.args): + ns[arg_ast.arg].set_ssa_value(builder, arg_llvm) + + visitor = ast_body.Visitor(self.env, ns, builder) + visitor.visit_statements(funcdef.body) + + if not tools.is_terminated(builder.basic_block): + if isinstance(retval, base_types.VNone): + builder.ret_void() + else: + builder.ret(retval.get_ssa_value(builder)) + + return function, retval diff --git a/artiq/py2llvm/tools.py b/artiq/py2llvm/tools.py index bdc3a0791..067a55b61 100644 --- a/artiq/py2llvm/tools.py +++ b/artiq/py2llvm/tools.py @@ -1,11 +1,2 @@ -from llvm import passes as lp - def is_terminated(basic_block): return basic_block.instructions and basic_block.instructions[-1].is_terminator - -def add_common_passes(pass_manager): - pass_manager.add(lp.PASS_MEM2REG) - pass_manager.add(lp.PASS_INSTCOMBINE) - pass_manager.add(lp.PASS_REASSOCIATE) - pass_manager.add(lp.PASS_GVN) - pass_manager.add(lp.PASS_SIMPLIFYCFG) diff --git a/artiq/py2llvm/values.py b/artiq/py2llvm/values.py index 1760cb8b5..8adbd72a7 100644 --- a/artiq/py2llvm/values.py +++ b/artiq/py2llvm/values.py @@ -3,7 +3,7 @@ from types import SimpleNamespace from llvm import core as lc -class _Value: +class VGeneric: def __init__(self): self._llvm_value = None @@ -40,361 +40,6 @@ class _Value: return self.o_roundx(64, builder) -# None type - -class VNone(_Value): - def __repr__(self): - return "" - - def get_llvm_type(self): - return lc.Type.void() - - def same_type(self, other): - return isinstance(other, VNone) - - def merge(self, other): - if not isinstance(other, VNone): - raise TypeError - - def alloca(self, builder, name): - pass - - def o_bool(self, builder): - r = VBool() - if builder is not None: - r.set_const_value(builder, False) - return r - - -# Integer type - -class VInt(_Value): - def __init__(self, nbits=32): - _Value.__init__(self) - self.nbits = nbits - - def get_llvm_type(self): - return lc.Type.int(self.nbits) - - def __repr__(self): - return "".format(self.nbits) - - def same_type(self, other): - return isinstance(other, VInt) and other.nbits == self.nbits - - def merge(self, other): - if isinstance(other, VInt) and not isinstance(other, VBool): - if other.nbits > self.nbits: - self.nbits = other.nbits - else: - raise TypeError - - def set_value(self, builder, n): - self.set_ssa_value( - builder, n.o_intx(self.nbits, builder).get_ssa_value(builder)) - - def set_const_value(self, builder, n): - self.set_ssa_value(builder, lc.Constant.int(self.get_llvm_type(), n)) - - def o_bool(self, builder, inv=False): - r = VBool() - if builder is not None: - r.set_ssa_value( - builder, builder.icmp( - lc.ICMP_EQ if inv else lc.ICMP_NE, - self.get_ssa_value(builder), - lc.Constant.int(self.get_llvm_type(), 0))) - return r - - def o_not(self, builder): - return self.o_bool(builder, True) - - def o_intx(self, target_bits, builder): - r = VInt(target_bits) - if builder is not None: - if self.nbits == target_bits: - r.set_ssa_value( - builder, self.get_ssa_value(builder)) - if self.nbits > target_bits: - r.set_ssa_value( - builder, builder.trunc(self.get_ssa_value(builder), - r.get_llvm_type())) - if self.nbits < target_bits: - r.set_ssa_value( - builder, builder.sext(self.get_ssa_value(builder), - r.get_llvm_type())) - return r - o_roundx = o_intx - - -def _make_vint_binop_method(builder_name): - def binop_method(self, other, builder): - if isinstance(other, VInt): - target_bits = max(self.nbits, other.nbits) - r = VInt(target_bits) - if builder is not None: - left = self.o_intx(target_bits, builder) - right = other.o_intx(target_bits, builder) - bf = getattr(builder, builder_name) - r.set_ssa_value( - builder, bf(left.get_ssa_value(builder), - right.get_ssa_value(builder))) - return r - else: - return NotImplemented - return binop_method - -for _method_name, _builder_name in (("o_add", "add"), - ("o_sub", "sub"), - ("o_mul", "mul"), - ("o_floordiv", "sdiv"), - ("o_mod", "srem"), - ("o_and", "and_"), - ("o_xor", "xor"), - ("o_or", "or_")): - setattr(VInt, _method_name, _make_vint_binop_method(_builder_name)) - - -def _make_vint_cmp_method(icmp_val): - def cmp_method(self, other, builder): - if isinstance(other, VInt): - r = VBool() - if builder is not None: - target_bits = max(self.nbits, other.nbits) - left = self.o_intx(target_bits, builder) - right = other.o_intx(target_bits, builder) - r.set_ssa_value( - builder, - builder.icmp( - icmp_val, left.get_ssa_value(builder), - right.get_ssa_value(builder))) - return r - else: - return NotImplemented - return cmp_method - -for _method_name, _icmp_val in (("o_eq", lc.ICMP_EQ), - ("o_ne", lc.ICMP_NE), - ("o_lt", lc.ICMP_SLT), - ("o_le", lc.ICMP_SLE), - ("o_gt", lc.ICMP_SGT), - ("o_ge", lc.ICMP_SGE)): - setattr(VInt, _method_name, _make_vint_cmp_method(_icmp_val)) - - -# Boolean type - -class VBool(VInt): - def __init__(self): - VInt.__init__(self, 1) - - def __repr__(self): - return "" - - def same_type(self, other): - return isinstance(other, VBool) - - def merge(self, other): - if not isinstance(other, VBool): - raise TypeError - - def set_const_value(self, builder, b): - VInt.set_const_value(self, builder, int(b)) - - def o_bool(self, builder): - r = VBool() - if builder is not None: - r.set_ssa_value(builder, self.get_ssa_value(builder)) - return r - - -# Fraction type - -def _gcd64(builder, a, b): - gcd_f = builder.basic_block.function.module.get_function_named("__gcd64") - return builder.call(gcd_f, [a, b]) - - -def _frac_normalize(builder, numerator, denominator): - gcd = _gcd64(builder, numerator, denominator) - numerator = builder.sdiv(numerator, gcd) - denominator = builder.sdiv(denominator, gcd) - return numerator, denominator - - -def _frac_make_ssa(builder, numerator, denominator): - value = lc.Constant.undef(lc.Type.vector(lc.Type.int(64), 2)) - value = builder.insert_element( - value, numerator, lc.Constant.int(lc.Type.int(), 0)) - value = builder.insert_element( - value, denominator, lc.Constant.int(lc.Type.int(), 1)) - return value - - -class VFraction(_Value): - def get_llvm_type(self): - return lc.Type.vector(lc.Type.int(64), 2) - - def __repr__(self): - return "" - - def same_type(self, other): - return isinstance(other, VFraction) - - def merge(self, other): - if not isinstance(other, VFraction): - raise TypeError - - def _nd(self, builder, invert=False): - ssa_value = self.get_ssa_value(builder) - numerator = builder.extract_element( - ssa_value, lc.Constant.int(lc.Type.int(), 0)) - denominator = builder.extract_element( - ssa_value, lc.Constant.int(lc.Type.int(), 1)) - if invert: - return denominator, numerator - else: - return numerator, denominator - - def set_value_nd(self, builder, numerator, denominator): - numerator = numerator.o_int64(builder).get_ssa_value(builder) - denominator = denominator.o_int64(builder).get_ssa_value(builder) - numerator, denominator = _frac_normalize( - builder, numerator, denominator) - self.set_ssa_value( - builder, _frac_make_ssa(builder, numerator, denominator)) - - def set_value(self, builder, n): - if not isinstance(n, VFraction): - raise TypeError - self.set_ssa_value(builder, n.get_ssa_value(builder)) - - def o_bool(self, builder): - r = VBool() - if builder is not None: - zero = lc.Constant.int(lc.Type.int(64), 0) - numerator = builder.extract_element( - self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 0)) - r.set_ssa_value(builder, builder.icmp(lc.ICMP_NE, numerator, zero)) - return r - - def o_intx(self, target_bits, builder): - if builder is None: - return VInt(target_bits) - else: - r = VInt(64) - numerator, denominator = self._nd(builder) - r.set_ssa_value(builder, builder.sdiv(numerator, denominator)) - return r.o_intx(target_bits, builder) - - def o_roundx(self, target_bits, builder): - if builder is None: - return VInt(target_bits) - else: - r = VInt(64) - numerator, denominator = self._nd(builder) - h_denominator = builder.ashr(denominator, - lc.Constant.int(lc.Type.int(), 1)) - r_numerator = builder.add(numerator, h_denominator) - r.set_ssa_value(builder, builder.sdiv(r_numerator, denominator)) - return r.o_intx(target_bits, builder) - - def _o_eq_inv(self, other, builder, ne): - if isinstance(other, VFraction): - r = VBool() - if builder is not None: - ee = [] - for i in range(2): - es = builder.extract_element( - self.get_ssa_value(builder), - lc.Constant.int(lc.Type.int(), i)) - eo = builder.extract_element( - other.get_ssa_value(builder), - lc.Constant.int(lc.Type.int(), i)) - ee.append(builder.icmp(lc.ICMP_EQ, es, eo)) - ssa_r = builder.and_(ee[0], ee[1]) - if ne: - ssa_r = builder.xor(ssa_r, - lc.Constant.int(lc.Type.int(1), 1)) - r.set_ssa_value(builder, ssa_r) - return r - else: - return NotImplemented - - def o_eq(self, other, builder): - return self._o_eq_inv(other, builder, False) - - def o_ne(self, other, builder): - return self._o_eq_inv(other, builder, True) - - def _o_muldiv(self, other, builder, div, invert=False): - r = VFraction() - if isinstance(other, VInt): - if builder is None: - return r - else: - numerator, denominator = self._nd(builder, invert) - i = other.get_ssa_value(builder) - if div: - gcd = _gcd64(i, numerator) - i = builder.sdiv(i, gcd) - numerator = builder.sdiv(numerator, gcd) - denominator = builder.mul(denominator, i) - else: - gcd = _gcd64(i, denominator) - i = builder.sdiv(i, gcd) - denominator = builder.sdiv(denominator, gcd) - numerator = builder.mul(numerator, i) - self.set_ssa_value(builder, _frac_make_ssa(builder, numerator, - denominator)) - elif isinstance(other, VFraction): - if builder is None: - return r - else: - numerator, denominator = self._nd(builder, invert) - onumerator, odenominator = other._nd(builder) - if div: - numerator = builder.mul(numerator, odenominator) - denominator = builder.mul(denominator, onumerator) - else: - numerator = builder.mul(numerator, onumerator) - denominator = builder.mul(denominator, odenominator) - numerator, denominator = _frac_normalize(builder, numerator, - denominator) - self.set_ssa_value( - builder, _frac_make_ssa(builder, numerator, denominator)) - else: - return NotImplemented - - def o_mul(self, other, builder): - return self._o_muldiv(other, builder, False) - - def o_truediv(self, other, builder): - return self._o_muldiv(other, builder, True) - - def or_mul(self, other, builder): - return self._o_muldiv(other, builder, False) - - def or_truediv(self, other, builder): - return self._o_muldiv(other, builder, False, True) - - def o_floordiv(self, other, builder): - r = self.o_truediv(other, builder) - if r is NotImplemented: - return r - else: - return r.o_int(builder) - - def or_floordiv(self, other, builder): - r = self.or_truediv(other, builder) - if r is NotImplemented: - return r - else: - return r.o_int(builder) - - -# Operators - def _make_unary_operator(op_name): def op(x, builder): try: @@ -446,9 +91,3 @@ def _make_operators(): return SimpleNamespace(**d) operators = _make_operators() - - -def init_module(module): - func_type = lc.Type.function( - lc.Type.int(64), [lc.Type.int(64), lc.Type.int(64)]) - module.add_function(func_type, "__gcd64") diff --git a/test/py2llvm.py b/test/py2llvm.py index e33daac86..be81364fd 100644 --- a/test/py2llvm.py +++ b/test/py2llvm.py @@ -2,15 +2,12 @@ import unittest import ast import inspect -from llvm import core as lc -from llvm import passes as lp from llvm import ee as le from artiq.language.core import int64 from artiq.py2llvm.infer_types import infer_function_types -from artiq.py2llvm import values -from artiq.py2llvm import compile_function -from artiq.py2llvm.tools import add_common_passes +from artiq.py2llvm import base_types +from artiq.py2llvm.module import Module def test_types(choice): @@ -35,46 +32,40 @@ class FunctionTypesCase(unittest.TestCase): dict()) def test_base_types(self): - self.assertIsInstance(self.ns["foo"], values.VBool) - self.assertIsInstance(self.ns["bar"], values.VNone) - self.assertIsInstance(self.ns["d"], values.VInt) + self.assertIsInstance(self.ns["foo"], base_types.VBool) + self.assertIsInstance(self.ns["bar"], base_types.VNone) + self.assertIsInstance(self.ns["d"], base_types.VInt) self.assertEqual(self.ns["d"].nbits, 32) - self.assertIsInstance(self.ns["x"], values.VInt) + self.assertIsInstance(self.ns["x"], base_types.VInt) self.assertEqual(self.ns["x"].nbits, 64) def test_promotion(self): for v in "abc": - self.assertIsInstance(self.ns[v], values.VInt) + self.assertIsInstance(self.ns[v], base_types.VInt) self.assertEqual(self.ns[v].nbits, 64) def test_return(self): - self.assertIsInstance(self.ns["return"], values.VInt) + self.assertIsInstance(self.ns["return"], base_types.VInt) self.assertEqual(self.ns["return"].nbits, 64) class CompiledFunction: def __init__(self, function, param_types): - module = lc.Module.new("main") - values.init_module(module) - + module = Module() funcdef = ast.parse(inspect.getsource(function)).body[0] - self.function, self.retval = compile_function( - module, None, funcdef, param_types) + self.function, self.retval = module.compile_function( + funcdef, param_types) self.argval = [param_types[arg.arg] for arg in funcdef.args.args] - - self.executor = le.ExecutionEngine.new(module) - pass_manager = lp.PassManager.new() - add_common_passes(pass_manager) - pass_manager.run(module) + self.ee = module.get_ee() def __call__(self, *args): args_llvm = [ le.GenericValue.int(av.get_llvm_type(), a) for av, a in zip(self.argval, args)] - result = self.executor.run_function(self.function, args_llvm) - if isinstance(self.retval, values.VBool): + result = self.ee.run_function(self.function, args_llvm) + if isinstance(self.retval, base_types.VBool): return bool(result.as_int()) - elif isinstance(self.retval, values.VInt): + elif isinstance(self.retval, base_types.VInt): return result.as_int_signed() else: raise NotImplementedError @@ -90,6 +81,6 @@ def is_prime(x): class CodeGenCase(unittest.TestCase): def test_is_prime(self): - is_prime_c = CompiledFunction(is_prime, {"x": values.VInt(32)}) + is_prime_c = CompiledFunction(is_prime, {"x": base_types.VInt(32)}) for i in range(200): self.assertEqual(is_prime_c(i), is_prime(i))