forked from M-Labs/artiq
1
0
Fork 0

compiler/ir: refactor SSA/alloca management

This commit is contained in:
Sebastien Bourdeauducq 2014-08-19 17:52:05 +08:00
parent 62b872ad94
commit 0832507c26
4 changed files with 101 additions and 66 deletions

View File

@ -11,7 +11,7 @@ def compile_function(module, env, funcdef):
ns = ir_infer_types.infer_types(env, funcdef) ns = ir_infer_types.infer_types(env, funcdef)
for k, v in ns.items(): for k, v in ns.items():
v.create_alloca(builder, k) v.alloca(builder, k)
visitor = ir_ast_body.Visitor(env, ns, builder) visitor = ir_ast_body.Visitor(env, ns, builder)
visitor.visit_statements(funcdef.body) visitor.visit_statements(funcdef.body)
builder.ret_void() builder.ret_void()

View File

@ -1,7 +1,4 @@
import ast import ast
from copy import copy
from llvm import core as lc
from artiq.compiler import ir_values from artiq.compiler import ir_values
@ -25,12 +22,6 @@ class Visitor:
r = self.ns[node.id] r = self.ns[node.id]
except KeyError: except KeyError:
raise NameError("Name '{}' is not defined".format(node.id)) raise NameError("Name '{}' is not defined".format(node.id))
r = copy(r)
if self.builder is None:
r.llvm_value = None
else:
if isinstance(r.llvm_value, lc.AllocaInstruction):
r.llvm_value = self.builder.load(r.llvm_value)
return r return r
def _visit_expr_NameConstant(self, node): def _visit_expr_NameConstant(self, node):
@ -42,7 +33,7 @@ class Visitor:
else: else:
raise NotImplementedError raise NotImplementedError
if self.builder is not None: if self.builder is not None:
r.create_constant(v) r.set_const_value(self.builder, v)
return r return r
def _visit_expr_Num(self, node): def _visit_expr_Num(self, node):
@ -55,7 +46,7 @@ class Visitor:
else: else:
raise NotImplementedError raise NotImplementedError
if self.builder is not None: if self.builder is not None:
r.create_constant(n) r.set_const_value(self.builder, n)
return r return r
def _visit_expr_UnaryOp(self, node): def _visit_expr_UnaryOp(self, node):
@ -136,14 +127,14 @@ class Visitor:
val = self.visit_expression(node.value) val = self.visit_expression(node.value)
for target in node.targets: for target in node.targets:
if isinstance(target, ast.Name): if isinstance(target, ast.Name):
self.builder.store(val.llvm_value, self.ns[target.id].llvm_value) self.ns[target.id].set_value(self.builder, val)
else: else:
raise NotImplementedError raise NotImplementedError
def _visit_stmt_AugAssign(self, node): def _visit_stmt_AugAssign(self, node):
val = self.visit_expression(ast.BinOp(op=node.op, left=node.target, right=node.value)) val = self.visit_expression(ast.BinOp(op=node.op, left=node.target, right=node.value))
if isinstance(node.target, ast.Name): if isinstance(node.target, ast.Name):
self.builder.store(val.llvm_value, self.ns[node.target.id].llvm_value) self.ns[node.target.id].set_value(self.builder, val)
else: else:
raise NotImplementedError raise NotImplementedError
@ -157,7 +148,7 @@ class Visitor:
merge_block = function.append_basic_block("i_merge") merge_block = function.append_basic_block("i_merge")
condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder) condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder)
self.builder.cbranch(condition.llvm_value, then_block, else_block) self.builder.cbranch(condition.get_ssa_value(self.builder), then_block, else_block)
self.builder.position_at_end(then_block) self.builder.position_at_end(then_block)
self.visit_statements(node.body) self.visit_statements(node.body)
@ -176,12 +167,12 @@ class Visitor:
merge_block = function.append_basic_block("w_merge") merge_block = function.append_basic_block("w_merge")
condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder) condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder)
self.builder.cbranch(condition.llvm_value, body_block, else_block) self.builder.cbranch(condition.get_ssa_value(self.builder), body_block, else_block)
self.builder.position_at_end(body_block) self.builder.position_at_end(body_block)
self.visit_statements(node.body) self.visit_statements(node.body)
condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder) condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder)
self.builder.cbranch(condition.llvm_value, body_block, merge_block) self.builder.cbranch(condition.get_ssa_value(self.builder), body_block, merge_block)
self.builder.position_at_end(else_block) self.builder.position_at_end(else_block)
self.visit_statements(node.orelse) self.visit_statements(node.orelse)

View File

@ -2,12 +2,50 @@ from types import SimpleNamespace
from llvm import core as lc from llvm import core as lc
class _Value:
def __init__(self):
self._llvm_value = None
def get_ssa_value(self, builder):
if isinstance(self._llvm_value, lc.AllocaInstruction):
return builder.load(self._llvm_value)
else:
return self._llvm_value
def set_ssa_value(self, builder, value):
if self._llvm_value is None:
self._llvm_value = value
elif isinstance(self._llvm_value, lc.AllocaInstruction):
builder.store(value, self._llvm_value)
else:
raise RuntimeError("Attempted to set LLVM SSA value multiple times")
def alloca(self, builder, name):
if self._llvm_value is not None:
raise RuntimeError("Attempted to alloca existing LLVM value")
self._llvm_value = builder.alloca(self.get_llvm_type(), name=name)
def o_int(self, builder):
return self.o_intx(32, builder)
def o_int64(self, builder):
return self.o_intx(64, builder)
def o_round(self, builder):
return self.o_roundx(32, builder)
def o_round64(self, builder):
return self.o_roundx(64, builder)
# None type # None type
class VNone: class VNone(_Value):
def __repr__(self): def __repr__(self):
return "<VNone>" return "<VNone>"
def get_llvm_type(self):
return lc.Type.void()
def same_type(self, other): def same_type(self, other):
return isinstance(other, VNone) return isinstance(other, VNone)
@ -15,21 +53,24 @@ class VNone:
if not isinstance(other, VNone): if not isinstance(other, VNone):
raise TypeError raise TypeError
def create_alloca(self, builder, name): def alloca(self, builder, name):
pass pass
def o_bool(self, builder): def o_bool(self, builder):
r = VBool() r = VBool()
if builder is not None: if builder is not None:
r.create_constant(False) r.set_const_value(builder, False)
return r return r
# Integer type # Integer type
class VInt: class VInt(_Value):
def __init__(self, nbits=32, llvm_value=None): def __init__(self, nbits=32):
_Value.__init__(self)
self.nbits = nbits self.nbits = nbits
self.llvm_value = llvm_value
def get_llvm_type(self):
return lc.Type.int(self.nbits)
def __repr__(self): def __repr__(self):
return "<VInt:{}>".format(self.nbits) return "<VInt:{}>".format(self.nbits)
@ -44,47 +85,43 @@ class VInt:
else: else:
raise TypeError raise TypeError
def create_constant(self, n): def set_value(self, builder, n):
self.llvm_value = lc.Constant.int(lc.Type.int(self.nbits), n) self.set_ssa_value(builder, n.o_intx(self.nbits, builder).get_ssa_value(builder))
def create_alloca(self, builder, name): def set_const_value(self, builder, n):
self.llvm_value = builder.alloca(lc.Type.int(self.nbits), name=name) self.set_ssa_value(builder, lc.Constant.int(self.get_llvm_type(), n))
def o_bool(self, builder): def o_bool(self, builder):
if builder is None: r = VBool()
return VBool() if builder is not None:
else: r.set_ssa_value(builder, builder.icmp(lc.ICMP_NE,
zero = lc.Constant.int(lc.Type.int(self.nbits), 0) self.get_ssa_value(builder), lc.Constant.int(self.get_llvm_type(), 0)))
return VBool(llvm_value=builder.icmp(lc.ICMP_NE, self.llvm_value, zero)) return r
def _o_intx(self, target_bits, builder): def o_intx(self, target_bits, builder):
if builder is None: r = VInt(target_bits)
return VInt(target_bits) if builder is not None:
else:
if self.nbits == target_bits: if self.nbits == target_bits:
return self r.set_ssa_value(builder, self.get_ssa_value(builder))
if self.nbits > target_bits: if self.nbits > target_bits:
return VInt(target_bits, llvm_value=builder.trunc(self.llvm_value, lc.Type.int(target_bits))) r.set_ssa_value(builder, builder.trunc(self.get_ssa_value(builder), r.get_llvm_type()))
if self.nbits < target_bits: if self.nbits < target_bits:
return VInt(target_bits, llvm_value=builder.sext(self.llvm_value, lc.Type.int(target_bits))) r.set_ssa_value(builder, builder.sext(self.get_ssa_value(builder), r.get_llvm_type()))
return r
def o_int(self, builder): o_roundx = o_intx
return self._o_intx(32, builder)
o_round = o_int
def o_int64(self, builder):
return self._o_intx(64, builder)
o_round64 = o_int64
def _make_vint_binop_method(builder_name): def _make_vint_binop_method(builder_name):
def binop_method(self, other, builder): def binop_method(self, other, builder):
if isinstance(other, VInt): if isinstance(other, VInt):
nbits = max(self.nbits, other.nbits) target_bits = max(self.nbits, other.nbits)
if builder is None: r = VInt(target_bits)
return VInt(nbits) if builder is not None:
else: left = self.o_intx(target_bits, builder)
right = other.o_intx(target_bits, builder)
bf = getattr(builder, builder_name) bf = getattr(builder, builder_name)
return VInt(nbits, llvm_value=bf(self.llvm_value, other.llvm_value)) r.set_ssa_value(builder,
bf(left.get_ssa_value(builder), right.get_ssa_value(builder)))
return r
else: else:
return NotImplemented return NotImplemented
return binop_method return binop_method
@ -103,10 +140,14 @@ for _method_name, _builder_name in (
def _make_vint_cmp_method(icmp_val): def _make_vint_cmp_method(icmp_val):
def cmp_method(self, other, builder): def cmp_method(self, other, builder):
if isinstance(other, VInt): if isinstance(other, VInt):
if builder is None: r = VBool()
return VBool() if builder is not None:
else: target_bits = max(self.nbits, other.nbits)
return VBool(llvm_value=builder.icmp(icmp_val, self.llvm_value, other.llvm_value)) left = self.o_intx(target_bits, builder)
right = other.o_intx(target_bits, builder)
r.set_ssa_value(builder,
builder.icmp(icmp_val, left.get_ssa_value(builder), right.get_ssa_value(builder)))
return r
else: else:
return NotImplemented return NotImplemented
return cmp_method return cmp_method
@ -123,24 +164,27 @@ for _method_name, _icmp_val in (
# Boolean type # Boolean type
class VBool(VInt): class VBool(VInt):
def __init__(self, llvm_value=None): def __init__(self):
VInt.__init__(self, 1, llvm_value) VInt.__init__(self, 1)
def __repr__(self): def __repr__(self):
return "<VBool>" return "<VBool>"
def same_type(self, other):
return isinstance(other, VBool)
def merge(self, other): def merge(self, other):
if not isinstance(other, VBool): if not isinstance(other, VBool):
raise TypeError raise TypeError
def create_constant(self, b): def set_const_value(self, builder, b):
VInt.create_constant(self, int(b)) VInt.set_const_value(self, builder, int(b))
def o_bool(self, builder): def o_bool(self, builder):
if builder is None: r = VBool()
return VBool() if builder is not None:
else: r.set_ssa_value(builder, self.get_ssa_value(builder))
return self return r
# Operators # Operators

View File

@ -52,14 +52,14 @@ class LinkInterface:
def syscall(self, syscall_name, args, builder): def syscall(self, syscall_name, args, builder):
r = _chr_to_value[_syscalls[syscall_name][-1]]() r = _chr_to_value[_syscalls[syscall_name][-1]]()
if builder is not None: if builder is not None:
args = [arg.llvm_value for arg in args] args = [arg.get_ssa_value(builder) for arg in args]
if syscall_name in self.var_arg_fixcount: if syscall_name in self.var_arg_fixcount:
fixcount = self.var_arg_fixcount[syscall_name] fixcount = self.var_arg_fixcount[syscall_name]
args = args[:fixcount] \ args = args[:fixcount] \
+ [lc.Constant.int(lc.Type.int(), len(args) - fixcount)] \ + [lc.Constant.int(lc.Type.int(), len(args) - fixcount)] \
+ args[fixcount:] + args[fixcount:]
llvm_function = self.module.get_function_named("__syscall_"+syscall_name) llvm_function = self.module.get_function_named("__syscall_"+syscall_name)
r.llvm_value = builder.call(llvm_function, args) r.set_ssa_value(builder, builder.call(llvm_function, args))
return r return r
class Environment(LinkInterface): class Environment(LinkInterface):