diff --git a/artiq/compiler/algorithms/__init__.py b/artiq/compiler/algorithms/__init__.py index 47fcd2dbf..50fc4304b 100644 --- a/artiq/compiler/algorithms/__init__.py +++ b/artiq/compiler/algorithms/__init__.py @@ -1 +1,2 @@ from .inline import inline +from .unroll import unroll diff --git a/artiq/compiler/algorithms/unroll.py b/artiq/compiler/algorithms/unroll.py new file mode 100644 index 000000000..392a31a8c --- /dev/null +++ b/artiq/compiler/algorithms/unroll.py @@ -0,0 +1,97 @@ +""" +:func:`unroll` unrolls a loop instruction in ARTIQ IR. +The loop's trip count must be constant. +The loop body must not have any control flow instructions +except for one branch back to the loop head. +The loop body must be executed if the condition to which +the instruction refers is true. +""" + +from .. import types, builtins, iodelay, ir +from ..analyses import domination + +def _get_body_blocks(root, limit): + postorder = [] + + visited = set() + def visit(block): + visited.add(block) + for next_block in block.successors(): + if next_block not in visited and next_block is not limit: + visit(next_block) + postorder.append(block) + + visit(root) + + postorder.reverse() + return postorder + +def unroll(loop_insn): + loop_head = loop_insn.basic_block + function = loop_head.function + assert isinstance(loop_insn, ir.Loop) + assert len(loop_head.predecessors()) == 2 + assert len(loop_insn.if_false().predecessors()) == 1 + assert iodelay.is_const(loop_insn.trip_count) + + trip_count = loop_insn.trip_count.fold().value + if trip_count == 0: + loop_insn.replace_with(ir.Branch(loop_insn.if_false())) + return + + source_blocks = _get_body_blocks(loop_insn.if_true(), loop_head) + source_indvar = loop_insn.induction_variable() + source_tail = loop_insn.if_false() + unroll_target = loop_head + for n in range(trip_count): + value_map = {source_indvar: ir.Constant(n, source_indvar.type)} + + for source_block in source_blocks: + target_block = ir.BasicBlock([], "u{}.{}".format(n, source_block.name)) + function.add(target_block) + value_map[source_block] = target_block + + def mapper(value): + if isinstance(value, ir.Constant): + return value + elif value in value_map: + return value_map[value] + else: + return value + + for source_block in source_blocks: + target_block = value_map[source_block] + for source_insn in source_block.instructions: + if isinstance(source_insn, ir.Phi): + target_insn = ir.Phi() + else: + target_insn = source_insn.copy(mapper) + target_insn.name = "u{}.{}".format(n, source_insn.name) + target_block.append(target_insn) + value_map[source_insn] = target_insn + + for source_block in source_blocks: + for source_insn in source_block.instructions: + if isinstance(source_insn, ir.Phi): + target_insn = value_map[source_insn] + for block, value in source_insn.incoming(): + target_insn.add_incoming(value_map[value], value_map[block]) + + assert isinstance(unroll_target.terminator(), (ir.Branch, ir.Loop)) + unroll_target.terminator().replace_with(ir.Branch(value_map[source_blocks[0]])) + unroll_target = value_map[source_blocks[-1]] + + assert isinstance(unroll_target.terminator(), ir.Branch) + assert len(source_blocks[-1].successors()) == 1 + unroll_target.terminator().replace_with(ir.Branch(source_tail)) + + for source_block in reversed(source_blocks): + for source_insn in reversed(source_block.instructions): + for use in set(source_insn.uses): + if isinstance(use, ir.Phi): + assert use.basic_block == loop_head + use.remove_incoming_value(source_insn) + source_insn.erase() + + for source_block in reversed(source_blocks): + source_block.erase() diff --git a/artiq/compiler/ir.py b/artiq/compiler/ir.py index 6ef6dbf35..b7fb18402 100644 --- a/artiq/compiler/ir.py +++ b/artiq/compiler/ir.py @@ -1371,17 +1371,21 @@ class Loop(Terminator): :param trip_count: (:class:`iodelay.Expr`) expression :param substs: (dict of str to :class:`Value`) SSA values corresponding to iodelay variable names + :param indvar: (:class:`Phi`) + phi node corresponding to the induction SSA value, + which advances from ``0`` to ``trip_count - 1`` :param cond: (:class:`Value`) branch condition :param if_true: (:class:`BasicBlock`) branch target if condition is truthful :param if_false: (:class:`BasicBlock`) branch target if condition is falseful """ - def __init__(self, trip_count, substs, cond, if_true, if_false, name=""): + def __init__(self, trip_count, substs, indvar, cond, if_true, if_false, name=""): for var_name in substs: assert isinstance(var_name, str) + assert isinstance(indvar, Phi) assert isinstance(cond, Value) assert builtins.is_bool(cond.type) assert isinstance(if_true, BasicBlock) assert isinstance(if_false, BasicBlock) - super().__init__([cond, if_true, if_false, *substs.values()], builtins.TNone(), name) + super().__init__([indvar, cond, if_true, if_false, *substs.values()], builtins.TNone(), name) self.trip_count = trip_count self.var_names = list(substs.keys()) @@ -1391,17 +1395,20 @@ class Loop(Terminator): self_copy.var_names = list(self.var_names) return self_copy - def condition(self): + def induction_variable(self): return self.operands[0] - def if_true(self): + def condition(self): return self.operands[1] - def if_false(self): + def if_true(self): return self.operands[2] + def if_false(self): + return self.operands[3] + def substs(self): - return {key: value for key, value in zip(self.var_names, self.operands[3:])} + return {key: value for key, value in zip(self.var_names, self.operands[4:])} def _operands_as_string(self, type_printer): substs = self.substs() @@ -1409,8 +1416,8 @@ class Loop(Terminator): for var_name in substs: substs_as_strings.append("{} = {}".format(var_name, substs[var_name])) result = "[{}]".format(", ".join(substs_as_strings)) - result += ", {}, {}, {}".format(*list(map(lambda value: value.as_operand(type_printer), - self.operands[0:3]))) + result += ", indvar {}, if {}, {}, {}".format( + *list(map(lambda value: value.as_operand(type_printer), self.operands[0:4]))) return result def opcode(self): diff --git a/artiq/compiler/module.py b/artiq/compiler/module.py index 03ddaae91..8f7a1cb52 100644 --- a/artiq/compiler/module.py +++ b/artiq/compiler/module.py @@ -73,6 +73,7 @@ class Module: dead_code_eliminator.process(self.artiq_ir) local_access_validator.process(self.artiq_ir) interleaver.process(self.artiq_ir) + dead_code_eliminator.process(self.artiq_ir) def build_llvm_ir(self, target): """Compile the module to LLVM IR for the specified target.""" diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index de5ad63cb..e3a6ae000 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -525,7 +525,7 @@ class ARTIQIRGenerator(algorithm.Visitor): if node.trip_count is not None: substs = {var_name: self.current_args[var_name] for var_name in node.trip_count.free_vars()} - head.append(ir.Loop(node.trip_count, substs, cond, body, else_tail)) + head.append(ir.Loop(node.trip_count, substs, phi, cond, body, else_tail)) else: head.append(ir.BranchIf(cond, body, else_tail)) if not post_body.is_terminated(): diff --git a/artiq/compiler/transforms/interleaver.py b/artiq/compiler/transforms/interleaver.py index e26acd54b..d1f292979 100644 --- a/artiq/compiler/transforms/interleaver.py +++ b/artiq/compiler/transforms/interleaver.py @@ -7,7 +7,7 @@ from pythonparser import diagnostic from .. import types, builtins, ir, iodelay from ..analyses import domination -from ..algorithms import inline +from ..algorithms import inline, unroll def delay_free_subgraph(root, limit): visited = set() @@ -152,6 +152,11 @@ class Interleaver: source_terminator.interval = iodelay.Const(target_time_delta) else: source_terminator.replace_with(ir.Branch(source_terminator.target())) + elif isinstance(source_terminator, ir.Loop): + unroll(source_terminator) + + postdom_tree = domination.PostDominatorTree(func) + continue else: assert False diff --git a/lit-test/test/interleaving/unrolling.py b/lit-test/test/interleaving/unrolling.py new file mode 100644 index 000000000..f83a9255b --- /dev/null +++ b/lit-test/test/interleaving/unrolling.py @@ -0,0 +1,23 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +with parallel: + for x in range(10): + delay_mu(1) + print("a", x) + with sequential: + delay_mu(5) + print("c") + with sequential: + delay_mu(3) + print("b") + +# CHECK-L: a 0 +# CHECK-L: a 1 +# CHECK-L: a 2 +# CHECK-L: b +# CHECK-L: a 3 +# CHECK-L: a 4 +# CHECK-L: c +# CHECK-L: a 5 +# CHECK-L: a 6