From 0832507c26db927baf61e008dff478d4d859d436 Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Tue, 19 Aug 2014 17:52:05 +0800 Subject: [PATCH] compiler/ir: refactor SSA/alloca management --- artiq/compiler/ir.py | 2 +- artiq/compiler/ir_ast_body.py | 23 ++---- artiq/compiler/ir_values.py | 138 ++++++++++++++++++++++------------ artiq/devices/runtime.py | 4 +- 4 files changed, 101 insertions(+), 66 deletions(-) diff --git a/artiq/compiler/ir.py b/artiq/compiler/ir.py index d4e3e12ef..c2d69740c 100644 --- a/artiq/compiler/ir.py +++ b/artiq/compiler/ir.py @@ -11,7 +11,7 @@ def compile_function(module, env, funcdef): ns = ir_infer_types.infer_types(env, funcdef) for k, v in ns.items(): - v.create_alloca(builder, k) + v.alloca(builder, k) visitor = ir_ast_body.Visitor(env, ns, builder) visitor.visit_statements(funcdef.body) builder.ret_void() diff --git a/artiq/compiler/ir_ast_body.py b/artiq/compiler/ir_ast_body.py index 20e3b6c8f..57f3b1018 100644 --- a/artiq/compiler/ir_ast_body.py +++ b/artiq/compiler/ir_ast_body.py @@ -1,7 +1,4 @@ import ast -from copy import copy - -from llvm import core as lc from artiq.compiler import ir_values @@ -25,12 +22,6 @@ class Visitor: r = self.ns[node.id] except KeyError: raise NameError("Name '{}' is not defined".format(node.id)) - r = copy(r) - if self.builder is None: - r.llvm_value = None - else: - if isinstance(r.llvm_value, lc.AllocaInstruction): - r.llvm_value = self.builder.load(r.llvm_value) return r def _visit_expr_NameConstant(self, node): @@ -42,7 +33,7 @@ class Visitor: else: raise NotImplementedError if self.builder is not None: - r.create_constant(v) + r.set_const_value(self.builder, v) return r def _visit_expr_Num(self, node): @@ -55,7 +46,7 @@ class Visitor: else: raise NotImplementedError if self.builder is not None: - r.create_constant(n) + r.set_const_value(self.builder, n) return r def _visit_expr_UnaryOp(self, node): @@ -136,14 +127,14 @@ class Visitor: val = self.visit_expression(node.value) for target in node.targets: if isinstance(target, ast.Name): - self.builder.store(val.llvm_value, self.ns[target.id].llvm_value) + self.ns[target.id].set_value(self.builder, val) else: raise NotImplementedError def _visit_stmt_AugAssign(self, node): val = self.visit_expression(ast.BinOp(op=node.op, left=node.target, right=node.value)) if isinstance(node.target, ast.Name): - self.builder.store(val.llvm_value, self.ns[node.target.id].llvm_value) + self.ns[node.target.id].set_value(self.builder, val) else: raise NotImplementedError @@ -157,7 +148,7 @@ class Visitor: merge_block = function.append_basic_block("i_merge") condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder) - self.builder.cbranch(condition.llvm_value, then_block, else_block) + self.builder.cbranch(condition.get_ssa_value(self.builder), then_block, else_block) self.builder.position_at_end(then_block) self.visit_statements(node.body) @@ -176,12 +167,12 @@ class Visitor: merge_block = function.append_basic_block("w_merge") condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder) - self.builder.cbranch(condition.llvm_value, body_block, else_block) + self.builder.cbranch(condition.get_ssa_value(self.builder), body_block, else_block) self.builder.position_at_end(body_block) self.visit_statements(node.body) condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder) - self.builder.cbranch(condition.llvm_value, body_block, merge_block) + self.builder.cbranch(condition.get_ssa_value(self.builder), body_block, merge_block) self.builder.position_at_end(else_block) self.visit_statements(node.orelse) diff --git a/artiq/compiler/ir_values.py b/artiq/compiler/ir_values.py index 7f925f012..156b4aac9 100644 --- a/artiq/compiler/ir_values.py +++ b/artiq/compiler/ir_values.py @@ -2,12 +2,50 @@ from types import SimpleNamespace from llvm import core as lc +class _Value: + def __init__(self): + self._llvm_value = None + + def get_ssa_value(self, builder): + if isinstance(self._llvm_value, lc.AllocaInstruction): + return builder.load(self._llvm_value) + else: + return self._llvm_value + + def set_ssa_value(self, builder, value): + if self._llvm_value is None: + self._llvm_value = value + elif isinstance(self._llvm_value, lc.AllocaInstruction): + builder.store(value, self._llvm_value) + else: + raise RuntimeError("Attempted to set LLVM SSA value multiple times") + + def alloca(self, builder, name): + if self._llvm_value is not None: + raise RuntimeError("Attempted to alloca existing LLVM value") + self._llvm_value = builder.alloca(self.get_llvm_type(), name=name) + + def o_int(self, builder): + return self.o_intx(32, builder) + + def o_int64(self, builder): + return self.o_intx(64, builder) + + def o_round(self, builder): + return self.o_roundx(32, builder) + + def o_round64(self, builder): + return self.o_roundx(64, builder) + # None type -class VNone: +class VNone(_Value): def __repr__(self): return "" + def get_llvm_type(self): + return lc.Type.void() + def same_type(self, other): return isinstance(other, VNone) @@ -15,21 +53,24 @@ class VNone: if not isinstance(other, VNone): raise TypeError - def create_alloca(self, builder, name): + def alloca(self, builder, name): pass def o_bool(self, builder): r = VBool() if builder is not None: - r.create_constant(False) + r.set_const_value(builder, False) return r # Integer type -class VInt: - def __init__(self, nbits=32, llvm_value=None): +class VInt(_Value): + def __init__(self, nbits=32): + _Value.__init__(self) self.nbits = nbits - self.llvm_value = llvm_value + + def get_llvm_type(self): + return lc.Type.int(self.nbits) def __repr__(self): return "".format(self.nbits) @@ -44,47 +85,43 @@ class VInt: else: raise TypeError - def create_constant(self, n): - self.llvm_value = lc.Constant.int(lc.Type.int(self.nbits), n) + def set_value(self, builder, n): + self.set_ssa_value(builder, n.o_intx(self.nbits, builder).get_ssa_value(builder)) - def create_alloca(self, builder, name): - self.llvm_value = builder.alloca(lc.Type.int(self.nbits), name=name) + def set_const_value(self, builder, n): + self.set_ssa_value(builder, lc.Constant.int(self.get_llvm_type(), n)) def o_bool(self, builder): - if builder is None: - return VBool() - else: - zero = lc.Constant.int(lc.Type.int(self.nbits), 0) - return VBool(llvm_value=builder.icmp(lc.ICMP_NE, self.llvm_value, zero)) + r = VBool() + if builder is not None: + r.set_ssa_value(builder, builder.icmp(lc.ICMP_NE, + self.get_ssa_value(builder), lc.Constant.int(self.get_llvm_type(), 0))) + return r - def _o_intx(self, target_bits, builder): - if builder is None: - return VInt(target_bits) - else: + def o_intx(self, target_bits, builder): + r = VInt(target_bits) + if builder is not None: if self.nbits == target_bits: - return self + r.set_ssa_value(builder, self.get_ssa_value(builder)) if self.nbits > target_bits: - return VInt(target_bits, llvm_value=builder.trunc(self.llvm_value, lc.Type.int(target_bits))) + r.set_ssa_value(builder, builder.trunc(self.get_ssa_value(builder), r.get_llvm_type())) if self.nbits < target_bits: - return VInt(target_bits, llvm_value=builder.sext(self.llvm_value, lc.Type.int(target_bits))) - - def o_int(self, builder): - return self._o_intx(32, builder) - o_round = o_int - - def o_int64(self, builder): - return self._o_intx(64, builder) - o_round64 = o_int64 + r.set_ssa_value(builder, builder.sext(self.get_ssa_value(builder), r.get_llvm_type())) + return r + o_roundx = o_intx def _make_vint_binop_method(builder_name): def binop_method(self, other, builder): if isinstance(other, VInt): - nbits = max(self.nbits, other.nbits) - if builder is None: - return VInt(nbits) - else: + target_bits = max(self.nbits, other.nbits) + r = VInt(target_bits) + if builder is not None: + left = self.o_intx(target_bits, builder) + right = other.o_intx(target_bits, builder) bf = getattr(builder, builder_name) - return VInt(nbits, llvm_value=bf(self.llvm_value, other.llvm_value)) + r.set_ssa_value(builder, + bf(left.get_ssa_value(builder), right.get_ssa_value(builder))) + return r else: return NotImplemented return binop_method @@ -103,10 +140,14 @@ for _method_name, _builder_name in ( def _make_vint_cmp_method(icmp_val): def cmp_method(self, other, builder): if isinstance(other, VInt): - if builder is None: - return VBool() - else: - return VBool(llvm_value=builder.icmp(icmp_val, self.llvm_value, other.llvm_value)) + r = VBool() + if builder is not None: + target_bits = max(self.nbits, other.nbits) + left = self.o_intx(target_bits, builder) + right = other.o_intx(target_bits, builder) + r.set_ssa_value(builder, + builder.icmp(icmp_val, left.get_ssa_value(builder), right.get_ssa_value(builder))) + return r else: return NotImplemented return cmp_method @@ -123,24 +164,27 @@ for _method_name, _icmp_val in ( # Boolean type class VBool(VInt): - def __init__(self, llvm_value=None): - VInt.__init__(self, 1, llvm_value) + def __init__(self): + VInt.__init__(self, 1) def __repr__(self): return "" + def same_type(self, other): + return isinstance(other, VBool) + def merge(self, other): if not isinstance(other, VBool): raise TypeError - def create_constant(self, b): - VInt.create_constant(self, int(b)) + def set_const_value(self, builder, b): + VInt.set_const_value(self, builder, int(b)) def o_bool(self, builder): - if builder is None: - return VBool() - else: - return self + r = VBool() + if builder is not None: + r.set_ssa_value(builder, self.get_ssa_value(builder)) + return r # Operators diff --git a/artiq/devices/runtime.py b/artiq/devices/runtime.py index 8bc4b3ce8..da690e3ff 100644 --- a/artiq/devices/runtime.py +++ b/artiq/devices/runtime.py @@ -52,14 +52,14 @@ class LinkInterface: def syscall(self, syscall_name, args, builder): r = _chr_to_value[_syscalls[syscall_name][-1]]() if builder is not None: - args = [arg.llvm_value for arg in args] + args = [arg.get_ssa_value(builder) for arg in args] if syscall_name in self.var_arg_fixcount: fixcount = self.var_arg_fixcount[syscall_name] args = args[:fixcount] \ + [lc.Constant.int(lc.Type.int(), len(args) - fixcount)] \ + args[fixcount:] llvm_function = self.module.get_function_named("__syscall_"+syscall_name) - r.llvm_value = builder.call(llvm_function, args) + r.set_ssa_value(builder, builder.call(llvm_function, args)) return r class Environment(LinkInterface):