mirror of
https://github.com/m-labs/artiq.git
synced 2025-01-13 04:18:55 +08:00
transforms.interleaver: unroll loops.
This commit is contained in:
parent
5dd1fc993e
commit
8cb7844621
@ -1 +1,2 @@
|
||||
from .inline import inline
|
||||
from .unroll import unroll
|
||||
|
97
artiq/compiler/algorithms/unroll.py
Normal file
97
artiq/compiler/algorithms/unroll.py
Normal file
@ -0,0 +1,97 @@
|
||||
"""
|
||||
:func:`unroll` unrolls a loop instruction in ARTIQ IR.
|
||||
The loop's trip count must be constant.
|
||||
The loop body must not have any control flow instructions
|
||||
except for one branch back to the loop head.
|
||||
The loop body must be executed if the condition to which
|
||||
the instruction refers is true.
|
||||
"""
|
||||
|
||||
from .. import types, builtins, iodelay, ir
|
||||
from ..analyses import domination
|
||||
|
||||
def _get_body_blocks(root, limit):
|
||||
postorder = []
|
||||
|
||||
visited = set()
|
||||
def visit(block):
|
||||
visited.add(block)
|
||||
for next_block in block.successors():
|
||||
if next_block not in visited and next_block is not limit:
|
||||
visit(next_block)
|
||||
postorder.append(block)
|
||||
|
||||
visit(root)
|
||||
|
||||
postorder.reverse()
|
||||
return postorder
|
||||
|
||||
def unroll(loop_insn):
|
||||
loop_head = loop_insn.basic_block
|
||||
function = loop_head.function
|
||||
assert isinstance(loop_insn, ir.Loop)
|
||||
assert len(loop_head.predecessors()) == 2
|
||||
assert len(loop_insn.if_false().predecessors()) == 1
|
||||
assert iodelay.is_const(loop_insn.trip_count)
|
||||
|
||||
trip_count = loop_insn.trip_count.fold().value
|
||||
if trip_count == 0:
|
||||
loop_insn.replace_with(ir.Branch(loop_insn.if_false()))
|
||||
return
|
||||
|
||||
source_blocks = _get_body_blocks(loop_insn.if_true(), loop_head)
|
||||
source_indvar = loop_insn.induction_variable()
|
||||
source_tail = loop_insn.if_false()
|
||||
unroll_target = loop_head
|
||||
for n in range(trip_count):
|
||||
value_map = {source_indvar: ir.Constant(n, source_indvar.type)}
|
||||
|
||||
for source_block in source_blocks:
|
||||
target_block = ir.BasicBlock([], "u{}.{}".format(n, source_block.name))
|
||||
function.add(target_block)
|
||||
value_map[source_block] = target_block
|
||||
|
||||
def mapper(value):
|
||||
if isinstance(value, ir.Constant):
|
||||
return value
|
||||
elif value in value_map:
|
||||
return value_map[value]
|
||||
else:
|
||||
return value
|
||||
|
||||
for source_block in source_blocks:
|
||||
target_block = value_map[source_block]
|
||||
for source_insn in source_block.instructions:
|
||||
if isinstance(source_insn, ir.Phi):
|
||||
target_insn = ir.Phi()
|
||||
else:
|
||||
target_insn = source_insn.copy(mapper)
|
||||
target_insn.name = "u{}.{}".format(n, source_insn.name)
|
||||
target_block.append(target_insn)
|
||||
value_map[source_insn] = target_insn
|
||||
|
||||
for source_block in source_blocks:
|
||||
for source_insn in source_block.instructions:
|
||||
if isinstance(source_insn, ir.Phi):
|
||||
target_insn = value_map[source_insn]
|
||||
for block, value in source_insn.incoming():
|
||||
target_insn.add_incoming(value_map[value], value_map[block])
|
||||
|
||||
assert isinstance(unroll_target.terminator(), (ir.Branch, ir.Loop))
|
||||
unroll_target.terminator().replace_with(ir.Branch(value_map[source_blocks[0]]))
|
||||
unroll_target = value_map[source_blocks[-1]]
|
||||
|
||||
assert isinstance(unroll_target.terminator(), ir.Branch)
|
||||
assert len(source_blocks[-1].successors()) == 1
|
||||
unroll_target.terminator().replace_with(ir.Branch(source_tail))
|
||||
|
||||
for source_block in reversed(source_blocks):
|
||||
for source_insn in reversed(source_block.instructions):
|
||||
for use in set(source_insn.uses):
|
||||
if isinstance(use, ir.Phi):
|
||||
assert use.basic_block == loop_head
|
||||
use.remove_incoming_value(source_insn)
|
||||
source_insn.erase()
|
||||
|
||||
for source_block in reversed(source_blocks):
|
||||
source_block.erase()
|
@ -1371,17 +1371,21 @@ class Loop(Terminator):
|
||||
:param trip_count: (:class:`iodelay.Expr`) expression
|
||||
:param substs: (dict of str to :class:`Value`)
|
||||
SSA values corresponding to iodelay variable names
|
||||
:param indvar: (:class:`Phi`)
|
||||
phi node corresponding to the induction SSA value,
|
||||
which advances from ``0`` to ``trip_count - 1``
|
||||
:param cond: (:class:`Value`) branch condition
|
||||
:param if_true: (:class:`BasicBlock`) branch target if condition is truthful
|
||||
:param if_false: (:class:`BasicBlock`) branch target if condition is falseful
|
||||
"""
|
||||
def __init__(self, trip_count, substs, cond, if_true, if_false, name=""):
|
||||
def __init__(self, trip_count, substs, indvar, cond, if_true, if_false, name=""):
|
||||
for var_name in substs: assert isinstance(var_name, str)
|
||||
assert isinstance(indvar, Phi)
|
||||
assert isinstance(cond, Value)
|
||||
assert builtins.is_bool(cond.type)
|
||||
assert isinstance(if_true, BasicBlock)
|
||||
assert isinstance(if_false, BasicBlock)
|
||||
super().__init__([cond, if_true, if_false, *substs.values()], builtins.TNone(), name)
|
||||
super().__init__([indvar, cond, if_true, if_false, *substs.values()], builtins.TNone(), name)
|
||||
self.trip_count = trip_count
|
||||
self.var_names = list(substs.keys())
|
||||
|
||||
@ -1391,17 +1395,20 @@ class Loop(Terminator):
|
||||
self_copy.var_names = list(self.var_names)
|
||||
return self_copy
|
||||
|
||||
def condition(self):
|
||||
def induction_variable(self):
|
||||
return self.operands[0]
|
||||
|
||||
def if_true(self):
|
||||
def condition(self):
|
||||
return self.operands[1]
|
||||
|
||||
def if_false(self):
|
||||
def if_true(self):
|
||||
return self.operands[2]
|
||||
|
||||
def if_false(self):
|
||||
return self.operands[3]
|
||||
|
||||
def substs(self):
|
||||
return {key: value for key, value in zip(self.var_names, self.operands[3:])}
|
||||
return {key: value for key, value in zip(self.var_names, self.operands[4:])}
|
||||
|
||||
def _operands_as_string(self, type_printer):
|
||||
substs = self.substs()
|
||||
@ -1409,8 +1416,8 @@ class Loop(Terminator):
|
||||
for var_name in substs:
|
||||
substs_as_strings.append("{} = {}".format(var_name, substs[var_name]))
|
||||
result = "[{}]".format(", ".join(substs_as_strings))
|
||||
result += ", {}, {}, {}".format(*list(map(lambda value: value.as_operand(type_printer),
|
||||
self.operands[0:3])))
|
||||
result += ", indvar {}, if {}, {}, {}".format(
|
||||
*list(map(lambda value: value.as_operand(type_printer), self.operands[0:4])))
|
||||
return result
|
||||
|
||||
def opcode(self):
|
||||
|
@ -73,6 +73,7 @@ class Module:
|
||||
dead_code_eliminator.process(self.artiq_ir)
|
||||
local_access_validator.process(self.artiq_ir)
|
||||
interleaver.process(self.artiq_ir)
|
||||
dead_code_eliminator.process(self.artiq_ir)
|
||||
|
||||
def build_llvm_ir(self, target):
|
||||
"""Compile the module to LLVM IR for the specified target."""
|
||||
|
@ -525,7 +525,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
if node.trip_count is not None:
|
||||
substs = {var_name: self.current_args[var_name]
|
||||
for var_name in node.trip_count.free_vars()}
|
||||
head.append(ir.Loop(node.trip_count, substs, cond, body, else_tail))
|
||||
head.append(ir.Loop(node.trip_count, substs, phi, cond, body, else_tail))
|
||||
else:
|
||||
head.append(ir.BranchIf(cond, body, else_tail))
|
||||
if not post_body.is_terminated():
|
||||
|
@ -7,7 +7,7 @@ from pythonparser import diagnostic
|
||||
|
||||
from .. import types, builtins, ir, iodelay
|
||||
from ..analyses import domination
|
||||
from ..algorithms import inline
|
||||
from ..algorithms import inline, unroll
|
||||
|
||||
def delay_free_subgraph(root, limit):
|
||||
visited = set()
|
||||
@ -152,6 +152,11 @@ class Interleaver:
|
||||
source_terminator.interval = iodelay.Const(target_time_delta)
|
||||
else:
|
||||
source_terminator.replace_with(ir.Branch(source_terminator.target()))
|
||||
elif isinstance(source_terminator, ir.Loop):
|
||||
unroll(source_terminator)
|
||||
|
||||
postdom_tree = domination.PostDominatorTree(func)
|
||||
continue
|
||||
else:
|
||||
assert False
|
||||
|
||||
|
23
lit-test/test/interleaving/unrolling.py
Normal file
23
lit-test/test/interleaving/unrolling.py
Normal file
@ -0,0 +1,23 @@
|
||||
# RUN: %python -m artiq.compiler.testbench.jit %s >%t
|
||||
# RUN: OutputCheck %s --file-to-check=%t
|
||||
|
||||
with parallel:
|
||||
for x in range(10):
|
||||
delay_mu(1)
|
||||
print("a", x)
|
||||
with sequential:
|
||||
delay_mu(5)
|
||||
print("c")
|
||||
with sequential:
|
||||
delay_mu(3)
|
||||
print("b")
|
||||
|
||||
# CHECK-L: a 0
|
||||
# CHECK-L: a 1
|
||||
# CHECK-L: a 2
|
||||
# CHECK-L: b
|
||||
# CHECK-L: a 3
|
||||
# CHECK-L: a 4
|
||||
# CHECK-L: c
|
||||
# CHECK-L: a 5
|
||||
# CHECK-L: a 6
|
Loading…
Reference in New Issue
Block a user