forked from M-Labs/artiq
compiler/ir: refactor SSA/alloca management
This commit is contained in:
parent
62b872ad94
commit
0832507c26
|
@ -11,7 +11,7 @@ def compile_function(module, env, funcdef):
|
||||||
|
|
||||||
ns = ir_infer_types.infer_types(env, funcdef)
|
ns = ir_infer_types.infer_types(env, funcdef)
|
||||||
for k, v in ns.items():
|
for k, v in ns.items():
|
||||||
v.create_alloca(builder, k)
|
v.alloca(builder, k)
|
||||||
visitor = ir_ast_body.Visitor(env, ns, builder)
|
visitor = ir_ast_body.Visitor(env, ns, builder)
|
||||||
visitor.visit_statements(funcdef.body)
|
visitor.visit_statements(funcdef.body)
|
||||||
builder.ret_void()
|
builder.ret_void()
|
||||||
|
|
|
@ -1,7 +1,4 @@
|
||||||
import ast
|
import ast
|
||||||
from copy import copy
|
|
||||||
|
|
||||||
from llvm import core as lc
|
|
||||||
|
|
||||||
from artiq.compiler import ir_values
|
from artiq.compiler import ir_values
|
||||||
|
|
||||||
|
@ -25,12 +22,6 @@ class Visitor:
|
||||||
r = self.ns[node.id]
|
r = self.ns[node.id]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise NameError("Name '{}' is not defined".format(node.id))
|
raise NameError("Name '{}' is not defined".format(node.id))
|
||||||
r = copy(r)
|
|
||||||
if self.builder is None:
|
|
||||||
r.llvm_value = None
|
|
||||||
else:
|
|
||||||
if isinstance(r.llvm_value, lc.AllocaInstruction):
|
|
||||||
r.llvm_value = self.builder.load(r.llvm_value)
|
|
||||||
return r
|
return r
|
||||||
|
|
||||||
def _visit_expr_NameConstant(self, node):
|
def _visit_expr_NameConstant(self, node):
|
||||||
|
@ -42,7 +33,7 @@ class Visitor:
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
if self.builder is not None:
|
if self.builder is not None:
|
||||||
r.create_constant(v)
|
r.set_const_value(self.builder, v)
|
||||||
return r
|
return r
|
||||||
|
|
||||||
def _visit_expr_Num(self, node):
|
def _visit_expr_Num(self, node):
|
||||||
|
@ -55,7 +46,7 @@ class Visitor:
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
if self.builder is not None:
|
if self.builder is not None:
|
||||||
r.create_constant(n)
|
r.set_const_value(self.builder, n)
|
||||||
return r
|
return r
|
||||||
|
|
||||||
def _visit_expr_UnaryOp(self, node):
|
def _visit_expr_UnaryOp(self, node):
|
||||||
|
@ -136,14 +127,14 @@ class Visitor:
|
||||||
val = self.visit_expression(node.value)
|
val = self.visit_expression(node.value)
|
||||||
for target in node.targets:
|
for target in node.targets:
|
||||||
if isinstance(target, ast.Name):
|
if isinstance(target, ast.Name):
|
||||||
self.builder.store(val.llvm_value, self.ns[target.id].llvm_value)
|
self.ns[target.id].set_value(self.builder, val)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def _visit_stmt_AugAssign(self, node):
|
def _visit_stmt_AugAssign(self, node):
|
||||||
val = self.visit_expression(ast.BinOp(op=node.op, left=node.target, right=node.value))
|
val = self.visit_expression(ast.BinOp(op=node.op, left=node.target, right=node.value))
|
||||||
if isinstance(node.target, ast.Name):
|
if isinstance(node.target, ast.Name):
|
||||||
self.builder.store(val.llvm_value, self.ns[node.target.id].llvm_value)
|
self.ns[node.target.id].set_value(self.builder, val)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@ -157,7 +148,7 @@ class Visitor:
|
||||||
merge_block = function.append_basic_block("i_merge")
|
merge_block = function.append_basic_block("i_merge")
|
||||||
|
|
||||||
condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder)
|
condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder)
|
||||||
self.builder.cbranch(condition.llvm_value, then_block, else_block)
|
self.builder.cbranch(condition.get_ssa_value(self.builder), then_block, else_block)
|
||||||
|
|
||||||
self.builder.position_at_end(then_block)
|
self.builder.position_at_end(then_block)
|
||||||
self.visit_statements(node.body)
|
self.visit_statements(node.body)
|
||||||
|
@ -176,12 +167,12 @@ class Visitor:
|
||||||
merge_block = function.append_basic_block("w_merge")
|
merge_block = function.append_basic_block("w_merge")
|
||||||
|
|
||||||
condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder)
|
condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder)
|
||||||
self.builder.cbranch(condition.llvm_value, body_block, else_block)
|
self.builder.cbranch(condition.get_ssa_value(self.builder), body_block, else_block)
|
||||||
|
|
||||||
self.builder.position_at_end(body_block)
|
self.builder.position_at_end(body_block)
|
||||||
self.visit_statements(node.body)
|
self.visit_statements(node.body)
|
||||||
condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder)
|
condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder)
|
||||||
self.builder.cbranch(condition.llvm_value, body_block, merge_block)
|
self.builder.cbranch(condition.get_ssa_value(self.builder), body_block, merge_block)
|
||||||
|
|
||||||
self.builder.position_at_end(else_block)
|
self.builder.position_at_end(else_block)
|
||||||
self.visit_statements(node.orelse)
|
self.visit_statements(node.orelse)
|
||||||
|
|
|
@ -2,12 +2,50 @@ from types import SimpleNamespace
|
||||||
|
|
||||||
from llvm import core as lc
|
from llvm import core as lc
|
||||||
|
|
||||||
|
class _Value:
|
||||||
|
def __init__(self):
|
||||||
|
self._llvm_value = None
|
||||||
|
|
||||||
|
def get_ssa_value(self, builder):
|
||||||
|
if isinstance(self._llvm_value, lc.AllocaInstruction):
|
||||||
|
return builder.load(self._llvm_value)
|
||||||
|
else:
|
||||||
|
return self._llvm_value
|
||||||
|
|
||||||
|
def set_ssa_value(self, builder, value):
|
||||||
|
if self._llvm_value is None:
|
||||||
|
self._llvm_value = value
|
||||||
|
elif isinstance(self._llvm_value, lc.AllocaInstruction):
|
||||||
|
builder.store(value, self._llvm_value)
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Attempted to set LLVM SSA value multiple times")
|
||||||
|
|
||||||
|
def alloca(self, builder, name):
|
||||||
|
if self._llvm_value is not None:
|
||||||
|
raise RuntimeError("Attempted to alloca existing LLVM value")
|
||||||
|
self._llvm_value = builder.alloca(self.get_llvm_type(), name=name)
|
||||||
|
|
||||||
|
def o_int(self, builder):
|
||||||
|
return self.o_intx(32, builder)
|
||||||
|
|
||||||
|
def o_int64(self, builder):
|
||||||
|
return self.o_intx(64, builder)
|
||||||
|
|
||||||
|
def o_round(self, builder):
|
||||||
|
return self.o_roundx(32, builder)
|
||||||
|
|
||||||
|
def o_round64(self, builder):
|
||||||
|
return self.o_roundx(64, builder)
|
||||||
|
|
||||||
# None type
|
# None type
|
||||||
|
|
||||||
class VNone:
|
class VNone(_Value):
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "<VNone>"
|
return "<VNone>"
|
||||||
|
|
||||||
|
def get_llvm_type(self):
|
||||||
|
return lc.Type.void()
|
||||||
|
|
||||||
def same_type(self, other):
|
def same_type(self, other):
|
||||||
return isinstance(other, VNone)
|
return isinstance(other, VNone)
|
||||||
|
|
||||||
|
@ -15,21 +53,24 @@ class VNone:
|
||||||
if not isinstance(other, VNone):
|
if not isinstance(other, VNone):
|
||||||
raise TypeError
|
raise TypeError
|
||||||
|
|
||||||
def create_alloca(self, builder, name):
|
def alloca(self, builder, name):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def o_bool(self, builder):
|
def o_bool(self, builder):
|
||||||
r = VBool()
|
r = VBool()
|
||||||
if builder is not None:
|
if builder is not None:
|
||||||
r.create_constant(False)
|
r.set_const_value(builder, False)
|
||||||
return r
|
return r
|
||||||
|
|
||||||
# Integer type
|
# Integer type
|
||||||
|
|
||||||
class VInt:
|
class VInt(_Value):
|
||||||
def __init__(self, nbits=32, llvm_value=None):
|
def __init__(self, nbits=32):
|
||||||
|
_Value.__init__(self)
|
||||||
self.nbits = nbits
|
self.nbits = nbits
|
||||||
self.llvm_value = llvm_value
|
|
||||||
|
def get_llvm_type(self):
|
||||||
|
return lc.Type.int(self.nbits)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "<VInt:{}>".format(self.nbits)
|
return "<VInt:{}>".format(self.nbits)
|
||||||
|
@ -44,47 +85,43 @@ class VInt:
|
||||||
else:
|
else:
|
||||||
raise TypeError
|
raise TypeError
|
||||||
|
|
||||||
def create_constant(self, n):
|
def set_value(self, builder, n):
|
||||||
self.llvm_value = lc.Constant.int(lc.Type.int(self.nbits), n)
|
self.set_ssa_value(builder, n.o_intx(self.nbits, builder).get_ssa_value(builder))
|
||||||
|
|
||||||
def create_alloca(self, builder, name):
|
def set_const_value(self, builder, n):
|
||||||
self.llvm_value = builder.alloca(lc.Type.int(self.nbits), name=name)
|
self.set_ssa_value(builder, lc.Constant.int(self.get_llvm_type(), n))
|
||||||
|
|
||||||
def o_bool(self, builder):
|
def o_bool(self, builder):
|
||||||
if builder is None:
|
r = VBool()
|
||||||
return VBool()
|
if builder is not None:
|
||||||
else:
|
r.set_ssa_value(builder, builder.icmp(lc.ICMP_NE,
|
||||||
zero = lc.Constant.int(lc.Type.int(self.nbits), 0)
|
self.get_ssa_value(builder), lc.Constant.int(self.get_llvm_type(), 0)))
|
||||||
return VBool(llvm_value=builder.icmp(lc.ICMP_NE, self.llvm_value, zero))
|
return r
|
||||||
|
|
||||||
def _o_intx(self, target_bits, builder):
|
def o_intx(self, target_bits, builder):
|
||||||
if builder is None:
|
r = VInt(target_bits)
|
||||||
return VInt(target_bits)
|
if builder is not None:
|
||||||
else:
|
|
||||||
if self.nbits == target_bits:
|
if self.nbits == target_bits:
|
||||||
return self
|
r.set_ssa_value(builder, self.get_ssa_value(builder))
|
||||||
if self.nbits > target_bits:
|
if self.nbits > target_bits:
|
||||||
return VInt(target_bits, llvm_value=builder.trunc(self.llvm_value, lc.Type.int(target_bits)))
|
r.set_ssa_value(builder, builder.trunc(self.get_ssa_value(builder), r.get_llvm_type()))
|
||||||
if self.nbits < target_bits:
|
if self.nbits < target_bits:
|
||||||
return VInt(target_bits, llvm_value=builder.sext(self.llvm_value, lc.Type.int(target_bits)))
|
r.set_ssa_value(builder, builder.sext(self.get_ssa_value(builder), r.get_llvm_type()))
|
||||||
|
return r
|
||||||
def o_int(self, builder):
|
o_roundx = o_intx
|
||||||
return self._o_intx(32, builder)
|
|
||||||
o_round = o_int
|
|
||||||
|
|
||||||
def o_int64(self, builder):
|
|
||||||
return self._o_intx(64, builder)
|
|
||||||
o_round64 = o_int64
|
|
||||||
|
|
||||||
def _make_vint_binop_method(builder_name):
|
def _make_vint_binop_method(builder_name):
|
||||||
def binop_method(self, other, builder):
|
def binop_method(self, other, builder):
|
||||||
if isinstance(other, VInt):
|
if isinstance(other, VInt):
|
||||||
nbits = max(self.nbits, other.nbits)
|
target_bits = max(self.nbits, other.nbits)
|
||||||
if builder is None:
|
r = VInt(target_bits)
|
||||||
return VInt(nbits)
|
if builder is not None:
|
||||||
else:
|
left = self.o_intx(target_bits, builder)
|
||||||
|
right = other.o_intx(target_bits, builder)
|
||||||
bf = getattr(builder, builder_name)
|
bf = getattr(builder, builder_name)
|
||||||
return VInt(nbits, llvm_value=bf(self.llvm_value, other.llvm_value))
|
r.set_ssa_value(builder,
|
||||||
|
bf(left.get_ssa_value(builder), right.get_ssa_value(builder)))
|
||||||
|
return r
|
||||||
else:
|
else:
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
return binop_method
|
return binop_method
|
||||||
|
@ -103,10 +140,14 @@ for _method_name, _builder_name in (
|
||||||
def _make_vint_cmp_method(icmp_val):
|
def _make_vint_cmp_method(icmp_val):
|
||||||
def cmp_method(self, other, builder):
|
def cmp_method(self, other, builder):
|
||||||
if isinstance(other, VInt):
|
if isinstance(other, VInt):
|
||||||
if builder is None:
|
r = VBool()
|
||||||
return VBool()
|
if builder is not None:
|
||||||
else:
|
target_bits = max(self.nbits, other.nbits)
|
||||||
return VBool(llvm_value=builder.icmp(icmp_val, self.llvm_value, other.llvm_value))
|
left = self.o_intx(target_bits, builder)
|
||||||
|
right = other.o_intx(target_bits, builder)
|
||||||
|
r.set_ssa_value(builder,
|
||||||
|
builder.icmp(icmp_val, left.get_ssa_value(builder), right.get_ssa_value(builder)))
|
||||||
|
return r
|
||||||
else:
|
else:
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
return cmp_method
|
return cmp_method
|
||||||
|
@ -123,24 +164,27 @@ for _method_name, _icmp_val in (
|
||||||
# Boolean type
|
# Boolean type
|
||||||
|
|
||||||
class VBool(VInt):
|
class VBool(VInt):
|
||||||
def __init__(self, llvm_value=None):
|
def __init__(self):
|
||||||
VInt.__init__(self, 1, llvm_value)
|
VInt.__init__(self, 1)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "<VBool>"
|
return "<VBool>"
|
||||||
|
|
||||||
|
def same_type(self, other):
|
||||||
|
return isinstance(other, VBool)
|
||||||
|
|
||||||
def merge(self, other):
|
def merge(self, other):
|
||||||
if not isinstance(other, VBool):
|
if not isinstance(other, VBool):
|
||||||
raise TypeError
|
raise TypeError
|
||||||
|
|
||||||
def create_constant(self, b):
|
def set_const_value(self, builder, b):
|
||||||
VInt.create_constant(self, int(b))
|
VInt.set_const_value(self, builder, int(b))
|
||||||
|
|
||||||
def o_bool(self, builder):
|
def o_bool(self, builder):
|
||||||
if builder is None:
|
r = VBool()
|
||||||
return VBool()
|
if builder is not None:
|
||||||
else:
|
r.set_ssa_value(builder, self.get_ssa_value(builder))
|
||||||
return self
|
return r
|
||||||
|
|
||||||
# Operators
|
# Operators
|
||||||
|
|
||||||
|
|
|
@ -52,14 +52,14 @@ class LinkInterface:
|
||||||
def syscall(self, syscall_name, args, builder):
|
def syscall(self, syscall_name, args, builder):
|
||||||
r = _chr_to_value[_syscalls[syscall_name][-1]]()
|
r = _chr_to_value[_syscalls[syscall_name][-1]]()
|
||||||
if builder is not None:
|
if builder is not None:
|
||||||
args = [arg.llvm_value for arg in args]
|
args = [arg.get_ssa_value(builder) for arg in args]
|
||||||
if syscall_name in self.var_arg_fixcount:
|
if syscall_name in self.var_arg_fixcount:
|
||||||
fixcount = self.var_arg_fixcount[syscall_name]
|
fixcount = self.var_arg_fixcount[syscall_name]
|
||||||
args = args[:fixcount] \
|
args = args[:fixcount] \
|
||||||
+ [lc.Constant.int(lc.Type.int(), len(args) - fixcount)] \
|
+ [lc.Constant.int(lc.Type.int(), len(args) - fixcount)] \
|
||||||
+ args[fixcount:]
|
+ args[fixcount:]
|
||||||
llvm_function = self.module.get_function_named("__syscall_"+syscall_name)
|
llvm_function = self.module.get_function_named("__syscall_"+syscall_name)
|
||||||
r.llvm_value = builder.call(llvm_function, args)
|
r.set_ssa_value(builder, builder.call(llvm_function, args))
|
||||||
return r
|
return r
|
||||||
|
|
||||||
class Environment(LinkInterface):
|
class Environment(LinkInterface):
|
||||||
|
|
Loading…
Reference in New Issue