compiler.algorithms.inline: implement.

This commit is contained in:
whitequark 2015-11-23 23:58:37 +08:00
parent a4525b21cf
commit f0fd6cd0ca
3 changed files with 158 additions and 0 deletions

View File

@ -0,0 +1 @@
from .inline import inline

View File

@ -0,0 +1,80 @@
"""
:func:`inline` inlines a call instruction in ARTIQ IR.
The call instruction must have a statically known callee,
it must be second to last in the basic block, and the basic
block must have exactly one successor.
"""
from .. import types, builtins, iodelay, ir
def inline(call_insn):
assert isinstance(call_insn, ir.Call)
assert call_insn.static_target_function is not None
assert len(call_insn.basic_block.successors()) == 1
assert call_insn.basic_block.index(call_insn) == \
len(call_insn.basic_block.instructions) - 2
value_map = {}
source_function = call_insn.static_target_function
target_function = call_insn.basic_block.function
target_predecessor = call_insn.basic_block
target_successor = call_insn.basic_block.successors()[0]
if builtins.is_none(source_function.type.ret):
target_return_phi = None
else:
target_return_phi = target_successor.prepend(ir.Phi(source_function.type.ret))
closure = target_predecessor.insert(call_insn,
ir.GetAttr(call_insn.target_function(), '__closure__'))
for actual_arg, formal_arg in zip([closure] + call_insn.arguments(),
source_function.arguments):
value_map[formal_arg] = actual_arg
for source_block in source_function.basic_blocks:
target_block = ir.BasicBlock([], "i." + source_block.name)
target_function.add(target_block)
value_map[source_block] = target_block
def mapper(value):
if isinstance(value, ir.Constant):
return value
else:
return value_map[value]
for source_insn in source_function.instructions():
target_block = value_map[source_insn.basic_block]
if isinstance(source_insn, ir.Return):
if target_return_phi is not None:
target_return_phi.add_incoming(mapper(source_insn.value()), target_block)
target_insn = ir.Branch(target_successor)
elif isinstance(source_insn, ir.Phi):
target_insn = ir.Phi()
elif isinstance(source_insn, ir.Delay):
substs = source_insn.substs()
mapped_substs = {var: value_map[substs[var]] for var in substs}
const_substs = {var: iodelay.Const(mapped_substs[var].value)
for var in mapped_substs
if isinstance(mapped_substs[var], ir.Constant)}
other_substs = {var: mapped_substs[var]
for var in mapped_substs
if not isinstance(mapped_substs[var], ir.Constant)}
target_insn = ir.Delay(source_insn.expr.fold(const_substs), other_substs,
value_map[source_insn.decomposition()],
value_map[source_insn.target()])
else:
target_insn = source_insn.copy(mapper)
target_insn.name = "i." + source_insn.name
value_map[source_insn] = target_insn
target_block.append(target_insn)
for source_insn in source_function.instructions():
if isinstance(source_insn, ir.Phi):
target_insn = value_map[source_insn]
for block, value in source_insn.incoming():
target_insn.add_incoming(value_map[value], value_map[block])
target_predecessor.terminator().replace_with(ir.Branch(value_map[source_function.entry()]))
if target_return_phi is not None:
call_insn.replace_all_uses_with(target_return_phi)
call_insn.erase()

View File

@ -163,6 +163,12 @@ class Instruction(User):
self.basic_block = None
self.loc = None
def copy(self, mapper):
self_copy = self.__class__.__new__(self.__class__)
Instruction.__init__(self_copy, list(map(mapper, self.operands)),
self.type, self.name)
return self_copy
def set_basic_block(self, new_basic_block):
self.basic_block = new_basic_block
if self.basic_block is not None:
@ -585,6 +591,11 @@ class GetLocal(Instruction):
super().__init__([env], env.type.type_of(var_name), name)
self.var_name = var_name
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.var_name = self.var_name
return self_copy
def opcode(self):
return "getlocal({})".format(repr(self.var_name))
@ -613,6 +624,11 @@ class SetLocal(Instruction):
super().__init__([env, value], builtins.TNone(), name)
self.var_name = var_name
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.var_name = self.var_name
return self_copy
def opcode(self):
return "setlocal({})".format(repr(self.var_name))
@ -643,6 +659,11 @@ class GetConstructor(Instruction):
super().__init__([env], var_type, name)
self.var_name = var_name
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.var_name = self.var_name
return self_copy
def opcode(self):
return "getconstructor({})".format(repr(self.var_name))
@ -672,6 +693,11 @@ class GetAttr(Instruction):
super().__init__([obj], typ, name)
self.attr = attr
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.attr = self.attr
return self_copy
def opcode(self):
return "getattr({})".format(repr(self.attr))
@ -701,6 +727,11 @@ class SetAttr(Instruction):
super().__init__([obj, value], builtins.TNone(), name)
self.attr = attr
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.attr = self.attr
return self_copy
def opcode(self):
return "setattr({})".format(repr(self.attr))
@ -798,6 +829,11 @@ class Arith(Instruction):
super().__init__([lhs, rhs], lhs.type, name)
self.op = op
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.op = self.op
return self_copy
def opcode(self):
return "arith({})".format(type(self.op).__name__)
@ -827,6 +863,11 @@ class Compare(Instruction):
super().__init__([lhs, rhs], builtins.TBool(), name)
self.op = op
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.op = self.op
return self_copy
def opcode(self):
return "compare({})".format(type(self.op).__name__)
@ -853,6 +894,11 @@ class Builtin(Instruction):
super().__init__(operands, typ, name)
self.op = op
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.op = self.op
return self_copy
def opcode(self):
return "builtin({})".format(self.op)
@ -874,6 +920,11 @@ class Closure(Instruction):
super().__init__([env], func.type, name)
self.target_function = func
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.target_function = self.target_function
return self_copy
def opcode(self):
return "closure({})".format(self.target_function.name)
@ -898,6 +949,11 @@ class Call(Instruction):
super().__init__([func] + args, func.type.ret, name)
self.static_target_function = None
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.static_target_function = self.static_target_function
return self_copy
def opcode(self):
return "call"
@ -957,6 +1013,11 @@ class Quote(Instruction):
super().__init__([], typ, name)
self.value = value
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.value = self.value
return self_copy
def opcode(self):
return "quote({})".format(repr(self.value))
@ -1148,6 +1209,11 @@ class Invoke(Terminator):
super().__init__([func] + args + [normal, exn], func.type.ret, name)
self.static_target_function = None
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.static_target_function = self.static_target_function
return self_copy
def opcode(self):
return "invoke"
@ -1191,6 +1257,11 @@ class LandingPad(Terminator):
super().__init__([cleanup], builtins.TException(), name)
self.types = []
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.types = list(self.types)
return self_copy
def opcode(self):
return "landingpad"
@ -1245,6 +1316,12 @@ class Delay(Terminator):
self.expr = expr
self.var_names = list(substs.keys())
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.expr = self.expr
self_copy.var_names = list(self.var_names)
return self_copy
def decomposition(self):
return self.operands[0]