diff --git a/artiq/compiler/algorithms/__init__.py b/artiq/compiler/algorithms/__init__.py new file mode 100644 index 000000000..47fcd2dbf --- /dev/null +++ b/artiq/compiler/algorithms/__init__.py @@ -0,0 +1 @@ +from .inline import inline diff --git a/artiq/compiler/algorithms/inline.py b/artiq/compiler/algorithms/inline.py new file mode 100644 index 000000000..5030f23fa --- /dev/null +++ b/artiq/compiler/algorithms/inline.py @@ -0,0 +1,80 @@ +""" +:func:`inline` inlines a call instruction in ARTIQ IR. +The call instruction must have a statically known callee, +it must be second to last in the basic block, and the basic +block must have exactly one successor. +""" + +from .. import types, builtins, iodelay, ir + +def inline(call_insn): + assert isinstance(call_insn, ir.Call) + assert call_insn.static_target_function is not None + assert len(call_insn.basic_block.successors()) == 1 + assert call_insn.basic_block.index(call_insn) == \ + len(call_insn.basic_block.instructions) - 2 + + value_map = {} + source_function = call_insn.static_target_function + target_function = call_insn.basic_block.function + target_predecessor = call_insn.basic_block + target_successor = call_insn.basic_block.successors()[0] + + if builtins.is_none(source_function.type.ret): + target_return_phi = None + else: + target_return_phi = target_successor.prepend(ir.Phi(source_function.type.ret)) + + closure = target_predecessor.insert(call_insn, + ir.GetAttr(call_insn.target_function(), '__closure__')) + for actual_arg, formal_arg in zip([closure] + call_insn.arguments(), + source_function.arguments): + value_map[formal_arg] = actual_arg + + for source_block in source_function.basic_blocks: + target_block = ir.BasicBlock([], "i." + source_block.name) + target_function.add(target_block) + value_map[source_block] = target_block + + def mapper(value): + if isinstance(value, ir.Constant): + return value + else: + return value_map[value] + + for source_insn in source_function.instructions(): + target_block = value_map[source_insn.basic_block] + if isinstance(source_insn, ir.Return): + if target_return_phi is not None: + target_return_phi.add_incoming(mapper(source_insn.value()), target_block) + target_insn = ir.Branch(target_successor) + elif isinstance(source_insn, ir.Phi): + target_insn = ir.Phi() + elif isinstance(source_insn, ir.Delay): + substs = source_insn.substs() + mapped_substs = {var: value_map[substs[var]] for var in substs} + const_substs = {var: iodelay.Const(mapped_substs[var].value) + for var in mapped_substs + if isinstance(mapped_substs[var], ir.Constant)} + other_substs = {var: mapped_substs[var] + for var in mapped_substs + if not isinstance(mapped_substs[var], ir.Constant)} + target_insn = ir.Delay(source_insn.expr.fold(const_substs), other_substs, + value_map[source_insn.decomposition()], + value_map[source_insn.target()]) + else: + target_insn = source_insn.copy(mapper) + target_insn.name = "i." + source_insn.name + value_map[source_insn] = target_insn + target_block.append(target_insn) + + for source_insn in source_function.instructions(): + if isinstance(source_insn, ir.Phi): + target_insn = value_map[source_insn] + for block, value in source_insn.incoming(): + target_insn.add_incoming(value_map[value], value_map[block]) + + target_predecessor.terminator().replace_with(ir.Branch(value_map[source_function.entry()])) + if target_return_phi is not None: + call_insn.replace_all_uses_with(target_return_phi) + call_insn.erase() diff --git a/artiq/compiler/ir.py b/artiq/compiler/ir.py index 6a363ef43..1586c5bd4 100644 --- a/artiq/compiler/ir.py +++ b/artiq/compiler/ir.py @@ -163,6 +163,12 @@ class Instruction(User): self.basic_block = None self.loc = None + def copy(self, mapper): + self_copy = self.__class__.__new__(self.__class__) + Instruction.__init__(self_copy, list(map(mapper, self.operands)), + self.type, self.name) + return self_copy + def set_basic_block(self, new_basic_block): self.basic_block = new_basic_block if self.basic_block is not None: @@ -585,6 +591,11 @@ class GetLocal(Instruction): super().__init__([env], env.type.type_of(var_name), name) self.var_name = var_name + def copy(self, mapper): + self_copy = super().copy(mapper) + self_copy.var_name = self.var_name + return self_copy + def opcode(self): return "getlocal({})".format(repr(self.var_name)) @@ -613,6 +624,11 @@ class SetLocal(Instruction): super().__init__([env, value], builtins.TNone(), name) self.var_name = var_name + def copy(self, mapper): + self_copy = super().copy(mapper) + self_copy.var_name = self.var_name + return self_copy + def opcode(self): return "setlocal({})".format(repr(self.var_name)) @@ -643,6 +659,11 @@ class GetConstructor(Instruction): super().__init__([env], var_type, name) self.var_name = var_name + def copy(self, mapper): + self_copy = super().copy(mapper) + self_copy.var_name = self.var_name + return self_copy + def opcode(self): return "getconstructor({})".format(repr(self.var_name)) @@ -672,6 +693,11 @@ class GetAttr(Instruction): super().__init__([obj], typ, name) self.attr = attr + def copy(self, mapper): + self_copy = super().copy(mapper) + self_copy.attr = self.attr + return self_copy + def opcode(self): return "getattr({})".format(repr(self.attr)) @@ -701,6 +727,11 @@ class SetAttr(Instruction): super().__init__([obj, value], builtins.TNone(), name) self.attr = attr + def copy(self, mapper): + self_copy = super().copy(mapper) + self_copy.attr = self.attr + return self_copy + def opcode(self): return "setattr({})".format(repr(self.attr)) @@ -798,6 +829,11 @@ class Arith(Instruction): super().__init__([lhs, rhs], lhs.type, name) self.op = op + def copy(self, mapper): + self_copy = super().copy(mapper) + self_copy.op = self.op + return self_copy + def opcode(self): return "arith({})".format(type(self.op).__name__) @@ -827,6 +863,11 @@ class Compare(Instruction): super().__init__([lhs, rhs], builtins.TBool(), name) self.op = op + def copy(self, mapper): + self_copy = super().copy(mapper) + self_copy.op = self.op + return self_copy + def opcode(self): return "compare({})".format(type(self.op).__name__) @@ -853,6 +894,11 @@ class Builtin(Instruction): super().__init__(operands, typ, name) self.op = op + def copy(self, mapper): + self_copy = super().copy(mapper) + self_copy.op = self.op + return self_copy + def opcode(self): return "builtin({})".format(self.op) @@ -874,6 +920,11 @@ class Closure(Instruction): super().__init__([env], func.type, name) self.target_function = func + def copy(self, mapper): + self_copy = super().copy(mapper) + self_copy.target_function = self.target_function + return self_copy + def opcode(self): return "closure({})".format(self.target_function.name) @@ -898,6 +949,11 @@ class Call(Instruction): super().__init__([func] + args, func.type.ret, name) self.static_target_function = None + def copy(self, mapper): + self_copy = super().copy(mapper) + self_copy.static_target_function = self.static_target_function + return self_copy + def opcode(self): return "call" @@ -957,6 +1013,11 @@ class Quote(Instruction): super().__init__([], typ, name) self.value = value + def copy(self, mapper): + self_copy = super().copy(mapper) + self_copy.value = self.value + return self_copy + def opcode(self): return "quote({})".format(repr(self.value)) @@ -1148,6 +1209,11 @@ class Invoke(Terminator): super().__init__([func] + args + [normal, exn], func.type.ret, name) self.static_target_function = None + def copy(self, mapper): + self_copy = super().copy(mapper) + self_copy.static_target_function = self.static_target_function + return self_copy + def opcode(self): return "invoke" @@ -1191,6 +1257,11 @@ class LandingPad(Terminator): super().__init__([cleanup], builtins.TException(), name) self.types = [] + def copy(self, mapper): + self_copy = super().copy(mapper) + self_copy.types = list(self.types) + return self_copy + def opcode(self): return "landingpad" @@ -1245,6 +1316,12 @@ class Delay(Terminator): self.expr = expr self.var_names = list(substs.keys()) + def copy(self, mapper): + self_copy = super().copy(mapper) + self_copy.expr = self.expr + self_copy.var_names = list(self.var_names) + return self_copy + def decomposition(self): return self.operands[0]