diff --git a/artiq/compiler/transforms/interleaver.py b/artiq/compiler/transforms/interleaver.py index 282c84534..873e55eac 100644 --- a/artiq/compiler/transforms/interleaver.py +++ b/artiq/compiler/transforms/interleaver.py @@ -5,6 +5,7 @@ the timestamp would always monotonically nondecrease. from .. import types, builtins, ir, iodelay from ..analyses import domination +from ..algorithms import inline def delay_free_subgraph(root, limit): visited = set() @@ -25,9 +26,24 @@ def delay_free_subgraph(root, limit): return True +def iodelay_of_block(block): + terminator = block.terminator() + if isinstance(terminator, ir.Delay): + # We should be able to fold everything without free variables. + folded_expr = terminator.expr.fold() + assert iodelay.is_const(folded_expr) + return folded_expr.value + else: + return 0 + def is_pure_delay(insn): return isinstance(insn, ir.Builtin) and insn.op in ("delay", "delay_mu") +def is_impure_delay_block(block): + terminator = block.terminator() + return isinstance(terminator, ir.Delay) and \ + not is_pure_delay(terminator.decomposition()) + class Interleaver: def __init__(self, engine): self.engine = engine @@ -64,25 +80,24 @@ class Interleaver: source_times = [0 for _ in source_blocks] while len(source_blocks) > 0: - def iodelay_of_block(block): - terminator = block.terminator() - if isinstance(terminator, ir.Delay): - # We should be able to fold everything without free variables. - folded_expr = terminator.expr.fold() - assert iodelay.is_const(folded_expr) - return folded_expr.value - else: - return 0 - def time_after_block(pair): index, block = pair return source_times[index] + iodelay_of_block(block) - index, source_block = min(enumerate(source_blocks), key=time_after_block) + # Always prefer impure blocks (with calls) to pure blocks, because + # impure blocks may expand with smaller delays appearing, and in + # case of a tie, if a pure block is preferred, this would violate + # the timeline monotonicity. + available_source_blocks = list(filter(is_impure_delay_block, source_blocks)) + if not any(available_source_blocks): + available_source_blocks = source_blocks + + index, source_block = min(enumerate(available_source_blocks), key=time_after_block) source_block_delay = iodelay_of_block(source_block) new_target_time = source_times[index] + source_block_delay target_time_delta = new_target_time - target_time + assert target_time_delta >= 0 target_terminator = target_block.terminator() if isinstance(target_terminator, ir.Parallel): @@ -93,42 +108,32 @@ class Interleaver: source_terminator = source_block.terminator() - if isinstance(source_terminator, ir.Delay): - old_decomp = source_terminator.decomposition() - else: - old_decomp = None - - if target_time_delta > 0: - assert isinstance(source_terminator, ir.Delay) - - if is_pure_delay(old_decomp): - new_decomp_expr = ir.Constant(int(target_time_delta), builtins.TInt64()) - new_decomp = ir.Builtin("delay_mu", [new_decomp_expr], builtins.TNone()) - new_decomp.loc = old_decomp.loc - source_terminator.basic_block.insert(source_terminator, new_decomp) - else: # It's a call. - need_to_inline = False - for other_source_block in filter(lambda block: block != source_block, - source_blocks): - other_source_terminator = other_source_block.terminator() - if not (is_pure_delay(other_source_terminator.decomposition()) and \ - iodelay.is_const(other_source_terminator.expr) and \ - other_source_terminator.expr.fold().value >= source_block_delay): - need_to_inline = True - break - - if need_to_inline: - assert False - else: - old_decomp, new_decomp = None, old_decomp - - source_terminator.replace_with(ir.Delay(iodelay.Const(target_time_delta), {}, - new_decomp, source_terminator.target())) - else: + if not isinstance(source_terminator, ir.Delay): source_terminator.replace_with(ir.Branch(source_terminator.target())) + else: + old_decomp = source_terminator.decomposition() + if is_pure_delay(old_decomp): + if target_time_delta > 0: + new_decomp_expr = ir.Constant(int(target_time_delta), builtins.TInt64()) + new_decomp = ir.Builtin("delay_mu", [new_decomp_expr], builtins.TNone()) + new_decomp.loc = old_decomp.loc - if old_decomp is not None: - old_decomp.erase() + source_terminator.basic_block.insert(new_decomp, before=source_terminator) + source_terminator.expr = iodelay.Const(target_time_delta) + source_terminator.set_decomposition(new_decomp) + else: + source_terminator.replace_with(ir.Branch(source_terminator.target())) + old_decomp.erase() + else: # It's a call. + need_to_inline = len(source_blocks) > 1 + if need_to_inline: + inline(old_decomp) + postdom_tree = domination.PostDominatorTree(func) + continue + elif target_time_delta > 0: + source_terminator.expr = iodelay.Const(target_time_delta) + else: + source_terminator.replace_with(ir.Branch(source_terminator.target())) target_block = source_block target_time = new_target_time