compiler.algorithms.inline: implement.

This commit is contained in:
whitequark 2015-11-23 23:58:37 +08:00
parent a4525b21cf
commit f0fd6cd0ca
3 changed files with 158 additions and 0 deletions

View File

@ -0,0 +1 @@
from .inline import inline

View File

@ -0,0 +1,80 @@
"""
:func:`inline` inlines a call instruction in ARTIQ IR.
The call instruction must have a statically known callee,
it must be second to last in the basic block, and the basic
block must have exactly one successor.
"""
from .. import types, builtins, iodelay, ir
def inline(call_insn):
assert isinstance(call_insn, ir.Call)
assert call_insn.static_target_function is not None
assert len(call_insn.basic_block.successors()) == 1
assert call_insn.basic_block.index(call_insn) == \
len(call_insn.basic_block.instructions) - 2
value_map = {}
source_function = call_insn.static_target_function
target_function = call_insn.basic_block.function
target_predecessor = call_insn.basic_block
target_successor = call_insn.basic_block.successors()[0]
if builtins.is_none(source_function.type.ret):
target_return_phi = None
else:
target_return_phi = target_successor.prepend(ir.Phi(source_function.type.ret))
closure = target_predecessor.insert(call_insn,
ir.GetAttr(call_insn.target_function(), '__closure__'))
for actual_arg, formal_arg in zip([closure] + call_insn.arguments(),
source_function.arguments):
value_map[formal_arg] = actual_arg
for source_block in source_function.basic_blocks:
target_block = ir.BasicBlock([], "i." + source_block.name)
target_function.add(target_block)
value_map[source_block] = target_block
def mapper(value):
if isinstance(value, ir.Constant):
return value
else:
return value_map[value]
for source_insn in source_function.instructions():
target_block = value_map[source_insn.basic_block]
if isinstance(source_insn, ir.Return):
if target_return_phi is not None:
target_return_phi.add_incoming(mapper(source_insn.value()), target_block)
target_insn = ir.Branch(target_successor)
elif isinstance(source_insn, ir.Phi):
target_insn = ir.Phi()
elif isinstance(source_insn, ir.Delay):
substs = source_insn.substs()
mapped_substs = {var: value_map[substs[var]] for var in substs}
const_substs = {var: iodelay.Const(mapped_substs[var].value)
for var in mapped_substs
if isinstance(mapped_substs[var], ir.Constant)}
other_substs = {var: mapped_substs[var]
for var in mapped_substs
if not isinstance(mapped_substs[var], ir.Constant)}
target_insn = ir.Delay(source_insn.expr.fold(const_substs), other_substs,
value_map[source_insn.decomposition()],
value_map[source_insn.target()])
else:
target_insn = source_insn.copy(mapper)
target_insn.name = "i." + source_insn.name
value_map[source_insn] = target_insn
target_block.append(target_insn)
for source_insn in source_function.instructions():
if isinstance(source_insn, ir.Phi):
target_insn = value_map[source_insn]
for block, value in source_insn.incoming():
target_insn.add_incoming(value_map[value], value_map[block])
target_predecessor.terminator().replace_with(ir.Branch(value_map[source_function.entry()]))
if target_return_phi is not None:
call_insn.replace_all_uses_with(target_return_phi)
call_insn.erase()

View File

@ -163,6 +163,12 @@ class Instruction(User):
self.basic_block = None self.basic_block = None
self.loc = None self.loc = None
def copy(self, mapper):
self_copy = self.__class__.__new__(self.__class__)
Instruction.__init__(self_copy, list(map(mapper, self.operands)),
self.type, self.name)
return self_copy
def set_basic_block(self, new_basic_block): def set_basic_block(self, new_basic_block):
self.basic_block = new_basic_block self.basic_block = new_basic_block
if self.basic_block is not None: if self.basic_block is not None:
@ -585,6 +591,11 @@ class GetLocal(Instruction):
super().__init__([env], env.type.type_of(var_name), name) super().__init__([env], env.type.type_of(var_name), name)
self.var_name = var_name self.var_name = var_name
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.var_name = self.var_name
return self_copy
def opcode(self): def opcode(self):
return "getlocal({})".format(repr(self.var_name)) return "getlocal({})".format(repr(self.var_name))
@ -613,6 +624,11 @@ class SetLocal(Instruction):
super().__init__([env, value], builtins.TNone(), name) super().__init__([env, value], builtins.TNone(), name)
self.var_name = var_name self.var_name = var_name
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.var_name = self.var_name
return self_copy
def opcode(self): def opcode(self):
return "setlocal({})".format(repr(self.var_name)) return "setlocal({})".format(repr(self.var_name))
@ -643,6 +659,11 @@ class GetConstructor(Instruction):
super().__init__([env], var_type, name) super().__init__([env], var_type, name)
self.var_name = var_name self.var_name = var_name
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.var_name = self.var_name
return self_copy
def opcode(self): def opcode(self):
return "getconstructor({})".format(repr(self.var_name)) return "getconstructor({})".format(repr(self.var_name))
@ -672,6 +693,11 @@ class GetAttr(Instruction):
super().__init__([obj], typ, name) super().__init__([obj], typ, name)
self.attr = attr self.attr = attr
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.attr = self.attr
return self_copy
def opcode(self): def opcode(self):
return "getattr({})".format(repr(self.attr)) return "getattr({})".format(repr(self.attr))
@ -701,6 +727,11 @@ class SetAttr(Instruction):
super().__init__([obj, value], builtins.TNone(), name) super().__init__([obj, value], builtins.TNone(), name)
self.attr = attr self.attr = attr
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.attr = self.attr
return self_copy
def opcode(self): def opcode(self):
return "setattr({})".format(repr(self.attr)) return "setattr({})".format(repr(self.attr))
@ -798,6 +829,11 @@ class Arith(Instruction):
super().__init__([lhs, rhs], lhs.type, name) super().__init__([lhs, rhs], lhs.type, name)
self.op = op self.op = op
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.op = self.op
return self_copy
def opcode(self): def opcode(self):
return "arith({})".format(type(self.op).__name__) return "arith({})".format(type(self.op).__name__)
@ -827,6 +863,11 @@ class Compare(Instruction):
super().__init__([lhs, rhs], builtins.TBool(), name) super().__init__([lhs, rhs], builtins.TBool(), name)
self.op = op self.op = op
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.op = self.op
return self_copy
def opcode(self): def opcode(self):
return "compare({})".format(type(self.op).__name__) return "compare({})".format(type(self.op).__name__)
@ -853,6 +894,11 @@ class Builtin(Instruction):
super().__init__(operands, typ, name) super().__init__(operands, typ, name)
self.op = op self.op = op
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.op = self.op
return self_copy
def opcode(self): def opcode(self):
return "builtin({})".format(self.op) return "builtin({})".format(self.op)
@ -874,6 +920,11 @@ class Closure(Instruction):
super().__init__([env], func.type, name) super().__init__([env], func.type, name)
self.target_function = func self.target_function = func
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.target_function = self.target_function
return self_copy
def opcode(self): def opcode(self):
return "closure({})".format(self.target_function.name) return "closure({})".format(self.target_function.name)
@ -898,6 +949,11 @@ class Call(Instruction):
super().__init__([func] + args, func.type.ret, name) super().__init__([func] + args, func.type.ret, name)
self.static_target_function = None self.static_target_function = None
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.static_target_function = self.static_target_function
return self_copy
def opcode(self): def opcode(self):
return "call" return "call"
@ -957,6 +1013,11 @@ class Quote(Instruction):
super().__init__([], typ, name) super().__init__([], typ, name)
self.value = value self.value = value
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.value = self.value
return self_copy
def opcode(self): def opcode(self):
return "quote({})".format(repr(self.value)) return "quote({})".format(repr(self.value))
@ -1148,6 +1209,11 @@ class Invoke(Terminator):
super().__init__([func] + args + [normal, exn], func.type.ret, name) super().__init__([func] + args + [normal, exn], func.type.ret, name)
self.static_target_function = None self.static_target_function = None
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.static_target_function = self.static_target_function
return self_copy
def opcode(self): def opcode(self):
return "invoke" return "invoke"
@ -1191,6 +1257,11 @@ class LandingPad(Terminator):
super().__init__([cleanup], builtins.TException(), name) super().__init__([cleanup], builtins.TException(), name)
self.types = [] self.types = []
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.types = list(self.types)
return self_copy
def opcode(self): def opcode(self):
return "landingpad" return "landingpad"
@ -1245,6 +1316,12 @@ class Delay(Terminator):
self.expr = expr self.expr = expr
self.var_names = list(substs.keys()) self.var_names = list(substs.keys())
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.expr = self.expr
self_copy.var_names = list(self.var_names)
return self_copy
def decomposition(self): def decomposition(self):
return self.operands[0] return self.operands[0]