transforms.interleaver: inline calls.

This commit is contained in:
whitequark 2015-11-24 00:02:07 +08:00
parent 2a82eb7219
commit 178ff74da2
1 changed files with 50 additions and 45 deletions

View File

@ -5,6 +5,7 @@ the timestamp would always monotonically nondecrease.
from .. import types, builtins, ir, iodelay from .. import types, builtins, ir, iodelay
from ..analyses import domination from ..analyses import domination
from ..algorithms import inline
def delay_free_subgraph(root, limit): def delay_free_subgraph(root, limit):
visited = set() visited = set()
@ -25,9 +26,24 @@ def delay_free_subgraph(root, limit):
return True return True
def iodelay_of_block(block):
terminator = block.terminator()
if isinstance(terminator, ir.Delay):
# We should be able to fold everything without free variables.
folded_expr = terminator.expr.fold()
assert iodelay.is_const(folded_expr)
return folded_expr.value
else:
return 0
def is_pure_delay(insn): def is_pure_delay(insn):
return isinstance(insn, ir.Builtin) and insn.op in ("delay", "delay_mu") return isinstance(insn, ir.Builtin) and insn.op in ("delay", "delay_mu")
def is_impure_delay_block(block):
terminator = block.terminator()
return isinstance(terminator, ir.Delay) and \
not is_pure_delay(terminator.decomposition())
class Interleaver: class Interleaver:
def __init__(self, engine): def __init__(self, engine):
self.engine = engine self.engine = engine
@ -64,25 +80,24 @@ class Interleaver:
source_times = [0 for _ in source_blocks] source_times = [0 for _ in source_blocks]
while len(source_blocks) > 0: while len(source_blocks) > 0:
def iodelay_of_block(block):
terminator = block.terminator()
if isinstance(terminator, ir.Delay):
# We should be able to fold everything without free variables.
folded_expr = terminator.expr.fold()
assert iodelay.is_const(folded_expr)
return folded_expr.value
else:
return 0
def time_after_block(pair): def time_after_block(pair):
index, block = pair index, block = pair
return source_times[index] + iodelay_of_block(block) return source_times[index] + iodelay_of_block(block)
index, source_block = min(enumerate(source_blocks), key=time_after_block) # Always prefer impure blocks (with calls) to pure blocks, because
# impure blocks may expand with smaller delays appearing, and in
# case of a tie, if a pure block is preferred, this would violate
# the timeline monotonicity.
available_source_blocks = list(filter(is_impure_delay_block, source_blocks))
if not any(available_source_blocks):
available_source_blocks = source_blocks
index, source_block = min(enumerate(available_source_blocks), key=time_after_block)
source_block_delay = iodelay_of_block(source_block) source_block_delay = iodelay_of_block(source_block)
new_target_time = source_times[index] + source_block_delay new_target_time = source_times[index] + source_block_delay
target_time_delta = new_target_time - target_time target_time_delta = new_target_time - target_time
assert target_time_delta >= 0
target_terminator = target_block.terminator() target_terminator = target_block.terminator()
if isinstance(target_terminator, ir.Parallel): if isinstance(target_terminator, ir.Parallel):
@ -93,42 +108,32 @@ class Interleaver:
source_terminator = source_block.terminator() source_terminator = source_block.terminator()
if isinstance(source_terminator, ir.Delay): if not isinstance(source_terminator, ir.Delay):
old_decomp = source_terminator.decomposition()
else:
old_decomp = None
if target_time_delta > 0:
assert isinstance(source_terminator, ir.Delay)
if is_pure_delay(old_decomp):
new_decomp_expr = ir.Constant(int(target_time_delta), builtins.TInt64())
new_decomp = ir.Builtin("delay_mu", [new_decomp_expr], builtins.TNone())
new_decomp.loc = old_decomp.loc
source_terminator.basic_block.insert(source_terminator, new_decomp)
else: # It's a call.
need_to_inline = False
for other_source_block in filter(lambda block: block != source_block,
source_blocks):
other_source_terminator = other_source_block.terminator()
if not (is_pure_delay(other_source_terminator.decomposition()) and \
iodelay.is_const(other_source_terminator.expr) and \
other_source_terminator.expr.fold().value >= source_block_delay):
need_to_inline = True
break
if need_to_inline:
assert False
else:
old_decomp, new_decomp = None, old_decomp
source_terminator.replace_with(ir.Delay(iodelay.Const(target_time_delta), {},
new_decomp, source_terminator.target()))
else:
source_terminator.replace_with(ir.Branch(source_terminator.target())) source_terminator.replace_with(ir.Branch(source_terminator.target()))
else:
old_decomp = source_terminator.decomposition()
if is_pure_delay(old_decomp):
if target_time_delta > 0:
new_decomp_expr = ir.Constant(int(target_time_delta), builtins.TInt64())
new_decomp = ir.Builtin("delay_mu", [new_decomp_expr], builtins.TNone())
new_decomp.loc = old_decomp.loc
if old_decomp is not None: source_terminator.basic_block.insert(new_decomp, before=source_terminator)
old_decomp.erase() source_terminator.expr = iodelay.Const(target_time_delta)
source_terminator.set_decomposition(new_decomp)
else:
source_terminator.replace_with(ir.Branch(source_terminator.target()))
old_decomp.erase()
else: # It's a call.
need_to_inline = len(source_blocks) > 1
if need_to_inline:
inline(old_decomp)
postdom_tree = domination.PostDominatorTree(func)
continue
elif target_time_delta > 0:
source_terminator.expr = iodelay.Const(target_time_delta)
else:
source_terminator.replace_with(ir.Branch(source_terminator.target()))
target_block = source_block target_block = source_block
target_time = new_target_time target_time = new_target_time