mirror of
https://github.com/m-labs/artiq.git
synced 2024-12-28 12:48:26 +08:00
compiler/ir: refactor SSA/alloca management
This commit is contained in:
parent
62b872ad94
commit
0832507c26
@ -11,7 +11,7 @@ def compile_function(module, env, funcdef):
|
||||
|
||||
ns = ir_infer_types.infer_types(env, funcdef)
|
||||
for k, v in ns.items():
|
||||
v.create_alloca(builder, k)
|
||||
v.alloca(builder, k)
|
||||
visitor = ir_ast_body.Visitor(env, ns, builder)
|
||||
visitor.visit_statements(funcdef.body)
|
||||
builder.ret_void()
|
||||
|
@ -1,7 +1,4 @@
|
||||
import ast
|
||||
from copy import copy
|
||||
|
||||
from llvm import core as lc
|
||||
|
||||
from artiq.compiler import ir_values
|
||||
|
||||
@ -25,12 +22,6 @@ class Visitor:
|
||||
r = self.ns[node.id]
|
||||
except KeyError:
|
||||
raise NameError("Name '{}' is not defined".format(node.id))
|
||||
r = copy(r)
|
||||
if self.builder is None:
|
||||
r.llvm_value = None
|
||||
else:
|
||||
if isinstance(r.llvm_value, lc.AllocaInstruction):
|
||||
r.llvm_value = self.builder.load(r.llvm_value)
|
||||
return r
|
||||
|
||||
def _visit_expr_NameConstant(self, node):
|
||||
@ -42,7 +33,7 @@ class Visitor:
|
||||
else:
|
||||
raise NotImplementedError
|
||||
if self.builder is not None:
|
||||
r.create_constant(v)
|
||||
r.set_const_value(self.builder, v)
|
||||
return r
|
||||
|
||||
def _visit_expr_Num(self, node):
|
||||
@ -55,7 +46,7 @@ class Visitor:
|
||||
else:
|
||||
raise NotImplementedError
|
||||
if self.builder is not None:
|
||||
r.create_constant(n)
|
||||
r.set_const_value(self.builder, n)
|
||||
return r
|
||||
|
||||
def _visit_expr_UnaryOp(self, node):
|
||||
@ -136,14 +127,14 @@ class Visitor:
|
||||
val = self.visit_expression(node.value)
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name):
|
||||
self.builder.store(val.llvm_value, self.ns[target.id].llvm_value)
|
||||
self.ns[target.id].set_value(self.builder, val)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def _visit_stmt_AugAssign(self, node):
|
||||
val = self.visit_expression(ast.BinOp(op=node.op, left=node.target, right=node.value))
|
||||
if isinstance(node.target, ast.Name):
|
||||
self.builder.store(val.llvm_value, self.ns[node.target.id].llvm_value)
|
||||
self.ns[node.target.id].set_value(self.builder, val)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -157,7 +148,7 @@ class Visitor:
|
||||
merge_block = function.append_basic_block("i_merge")
|
||||
|
||||
condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder)
|
||||
self.builder.cbranch(condition.llvm_value, then_block, else_block)
|
||||
self.builder.cbranch(condition.get_ssa_value(self.builder), then_block, else_block)
|
||||
|
||||
self.builder.position_at_end(then_block)
|
||||
self.visit_statements(node.body)
|
||||
@ -176,12 +167,12 @@ class Visitor:
|
||||
merge_block = function.append_basic_block("w_merge")
|
||||
|
||||
condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder)
|
||||
self.builder.cbranch(condition.llvm_value, body_block, else_block)
|
||||
self.builder.cbranch(condition.get_ssa_value(self.builder), body_block, else_block)
|
||||
|
||||
self.builder.position_at_end(body_block)
|
||||
self.visit_statements(node.body)
|
||||
condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder)
|
||||
self.builder.cbranch(condition.llvm_value, body_block, merge_block)
|
||||
self.builder.cbranch(condition.get_ssa_value(self.builder), body_block, merge_block)
|
||||
|
||||
self.builder.position_at_end(else_block)
|
||||
self.visit_statements(node.orelse)
|
||||
|
@ -2,12 +2,50 @@ from types import SimpleNamespace
|
||||
|
||||
from llvm import core as lc
|
||||
|
||||
class _Value:
|
||||
def __init__(self):
|
||||
self._llvm_value = None
|
||||
|
||||
def get_ssa_value(self, builder):
|
||||
if isinstance(self._llvm_value, lc.AllocaInstruction):
|
||||
return builder.load(self._llvm_value)
|
||||
else:
|
||||
return self._llvm_value
|
||||
|
||||
def set_ssa_value(self, builder, value):
|
||||
if self._llvm_value is None:
|
||||
self._llvm_value = value
|
||||
elif isinstance(self._llvm_value, lc.AllocaInstruction):
|
||||
builder.store(value, self._llvm_value)
|
||||
else:
|
||||
raise RuntimeError("Attempted to set LLVM SSA value multiple times")
|
||||
|
||||
def alloca(self, builder, name):
|
||||
if self._llvm_value is not None:
|
||||
raise RuntimeError("Attempted to alloca existing LLVM value")
|
||||
self._llvm_value = builder.alloca(self.get_llvm_type(), name=name)
|
||||
|
||||
def o_int(self, builder):
|
||||
return self.o_intx(32, builder)
|
||||
|
||||
def o_int64(self, builder):
|
||||
return self.o_intx(64, builder)
|
||||
|
||||
def o_round(self, builder):
|
||||
return self.o_roundx(32, builder)
|
||||
|
||||
def o_round64(self, builder):
|
||||
return self.o_roundx(64, builder)
|
||||
|
||||
# None type
|
||||
|
||||
class VNone:
|
||||
class VNone(_Value):
|
||||
def __repr__(self):
|
||||
return "<VNone>"
|
||||
|
||||
def get_llvm_type(self):
|
||||
return lc.Type.void()
|
||||
|
||||
def same_type(self, other):
|
||||
return isinstance(other, VNone)
|
||||
|
||||
@ -15,21 +53,24 @@ class VNone:
|
||||
if not isinstance(other, VNone):
|
||||
raise TypeError
|
||||
|
||||
def create_alloca(self, builder, name):
|
||||
def alloca(self, builder, name):
|
||||
pass
|
||||
|
||||
def o_bool(self, builder):
|
||||
r = VBool()
|
||||
if builder is not None:
|
||||
r.create_constant(False)
|
||||
r.set_const_value(builder, False)
|
||||
return r
|
||||
|
||||
# Integer type
|
||||
|
||||
class VInt:
|
||||
def __init__(self, nbits=32, llvm_value=None):
|
||||
class VInt(_Value):
|
||||
def __init__(self, nbits=32):
|
||||
_Value.__init__(self)
|
||||
self.nbits = nbits
|
||||
self.llvm_value = llvm_value
|
||||
|
||||
def get_llvm_type(self):
|
||||
return lc.Type.int(self.nbits)
|
||||
|
||||
def __repr__(self):
|
||||
return "<VInt:{}>".format(self.nbits)
|
||||
@ -44,47 +85,43 @@ class VInt:
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def create_constant(self, n):
|
||||
self.llvm_value = lc.Constant.int(lc.Type.int(self.nbits), n)
|
||||
def set_value(self, builder, n):
|
||||
self.set_ssa_value(builder, n.o_intx(self.nbits, builder).get_ssa_value(builder))
|
||||
|
||||
def create_alloca(self, builder, name):
|
||||
self.llvm_value = builder.alloca(lc.Type.int(self.nbits), name=name)
|
||||
def set_const_value(self, builder, n):
|
||||
self.set_ssa_value(builder, lc.Constant.int(self.get_llvm_type(), n))
|
||||
|
||||
def o_bool(self, builder):
|
||||
if builder is None:
|
||||
return VBool()
|
||||
else:
|
||||
zero = lc.Constant.int(lc.Type.int(self.nbits), 0)
|
||||
return VBool(llvm_value=builder.icmp(lc.ICMP_NE, self.llvm_value, zero))
|
||||
r = VBool()
|
||||
if builder is not None:
|
||||
r.set_ssa_value(builder, builder.icmp(lc.ICMP_NE,
|
||||
self.get_ssa_value(builder), lc.Constant.int(self.get_llvm_type(), 0)))
|
||||
return r
|
||||
|
||||
def _o_intx(self, target_bits, builder):
|
||||
if builder is None:
|
||||
return VInt(target_bits)
|
||||
else:
|
||||
def o_intx(self, target_bits, builder):
|
||||
r = VInt(target_bits)
|
||||
if builder is not None:
|
||||
if self.nbits == target_bits:
|
||||
return self
|
||||
r.set_ssa_value(builder, self.get_ssa_value(builder))
|
||||
if self.nbits > target_bits:
|
||||
return VInt(target_bits, llvm_value=builder.trunc(self.llvm_value, lc.Type.int(target_bits)))
|
||||
r.set_ssa_value(builder, builder.trunc(self.get_ssa_value(builder), r.get_llvm_type()))
|
||||
if self.nbits < target_bits:
|
||||
return VInt(target_bits, llvm_value=builder.sext(self.llvm_value, lc.Type.int(target_bits)))
|
||||
|
||||
def o_int(self, builder):
|
||||
return self._o_intx(32, builder)
|
||||
o_round = o_int
|
||||
|
||||
def o_int64(self, builder):
|
||||
return self._o_intx(64, builder)
|
||||
o_round64 = o_int64
|
||||
r.set_ssa_value(builder, builder.sext(self.get_ssa_value(builder), r.get_llvm_type()))
|
||||
return r
|
||||
o_roundx = o_intx
|
||||
|
||||
def _make_vint_binop_method(builder_name):
|
||||
def binop_method(self, other, builder):
|
||||
if isinstance(other, VInt):
|
||||
nbits = max(self.nbits, other.nbits)
|
||||
if builder is None:
|
||||
return VInt(nbits)
|
||||
else:
|
||||
target_bits = max(self.nbits, other.nbits)
|
||||
r = VInt(target_bits)
|
||||
if builder is not None:
|
||||
left = self.o_intx(target_bits, builder)
|
||||
right = other.o_intx(target_bits, builder)
|
||||
bf = getattr(builder, builder_name)
|
||||
return VInt(nbits, llvm_value=bf(self.llvm_value, other.llvm_value))
|
||||
r.set_ssa_value(builder,
|
||||
bf(left.get_ssa_value(builder), right.get_ssa_value(builder)))
|
||||
return r
|
||||
else:
|
||||
return NotImplemented
|
||||
return binop_method
|
||||
@ -103,10 +140,14 @@ for _method_name, _builder_name in (
|
||||
def _make_vint_cmp_method(icmp_val):
|
||||
def cmp_method(self, other, builder):
|
||||
if isinstance(other, VInt):
|
||||
if builder is None:
|
||||
return VBool()
|
||||
else:
|
||||
return VBool(llvm_value=builder.icmp(icmp_val, self.llvm_value, other.llvm_value))
|
||||
r = VBool()
|
||||
if builder is not None:
|
||||
target_bits = max(self.nbits, other.nbits)
|
||||
left = self.o_intx(target_bits, builder)
|
||||
right = other.o_intx(target_bits, builder)
|
||||
r.set_ssa_value(builder,
|
||||
builder.icmp(icmp_val, left.get_ssa_value(builder), right.get_ssa_value(builder)))
|
||||
return r
|
||||
else:
|
||||
return NotImplemented
|
||||
return cmp_method
|
||||
@ -123,24 +164,27 @@ for _method_name, _icmp_val in (
|
||||
# Boolean type
|
||||
|
||||
class VBool(VInt):
|
||||
def __init__(self, llvm_value=None):
|
||||
VInt.__init__(self, 1, llvm_value)
|
||||
def __init__(self):
|
||||
VInt.__init__(self, 1)
|
||||
|
||||
def __repr__(self):
|
||||
return "<VBool>"
|
||||
|
||||
def same_type(self, other):
|
||||
return isinstance(other, VBool)
|
||||
|
||||
def merge(self, other):
|
||||
if not isinstance(other, VBool):
|
||||
raise TypeError
|
||||
|
||||
def create_constant(self, b):
|
||||
VInt.create_constant(self, int(b))
|
||||
def set_const_value(self, builder, b):
|
||||
VInt.set_const_value(self, builder, int(b))
|
||||
|
||||
def o_bool(self, builder):
|
||||
if builder is None:
|
||||
return VBool()
|
||||
else:
|
||||
return self
|
||||
r = VBool()
|
||||
if builder is not None:
|
||||
r.set_ssa_value(builder, self.get_ssa_value(builder))
|
||||
return r
|
||||
|
||||
# Operators
|
||||
|
||||
|
@ -52,14 +52,14 @@ class LinkInterface:
|
||||
def syscall(self, syscall_name, args, builder):
|
||||
r = _chr_to_value[_syscalls[syscall_name][-1]]()
|
||||
if builder is not None:
|
||||
args = [arg.llvm_value for arg in args]
|
||||
args = [arg.get_ssa_value(builder) for arg in args]
|
||||
if syscall_name in self.var_arg_fixcount:
|
||||
fixcount = self.var_arg_fixcount[syscall_name]
|
||||
args = args[:fixcount] \
|
||||
+ [lc.Constant.int(lc.Type.int(), len(args) - fixcount)] \
|
||||
+ args[fixcount:]
|
||||
llvm_function = self.module.get_function_named("__syscall_"+syscall_name)
|
||||
r.llvm_value = builder.call(llvm_function, args)
|
||||
r.set_ssa_value(builder, builder.call(llvm_function, args))
|
||||
return r
|
||||
|
||||
class Environment(LinkInterface):
|
||||
|
Loading…
Reference in New Issue
Block a user