forked from M-Labs/artiq
transforms.interleaver: unroll loops.
This commit is contained in:
parent
5dd1fc993e
commit
8cb7844621
@ -1 +1,2 @@
|
|||||||
from .inline import inline
|
from .inline import inline
|
||||||
|
from .unroll import unroll
|
||||||
|
97
artiq/compiler/algorithms/unroll.py
Normal file
97
artiq/compiler/algorithms/unroll.py
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
"""
|
||||||
|
:func:`unroll` unrolls a loop instruction in ARTIQ IR.
|
||||||
|
The loop's trip count must be constant.
|
||||||
|
The loop body must not have any control flow instructions
|
||||||
|
except for one branch back to the loop head.
|
||||||
|
The loop body must be executed if the condition to which
|
||||||
|
the instruction refers is true.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .. import types, builtins, iodelay, ir
|
||||||
|
from ..analyses import domination
|
||||||
|
|
||||||
|
def _get_body_blocks(root, limit):
|
||||||
|
postorder = []
|
||||||
|
|
||||||
|
visited = set()
|
||||||
|
def visit(block):
|
||||||
|
visited.add(block)
|
||||||
|
for next_block in block.successors():
|
||||||
|
if next_block not in visited and next_block is not limit:
|
||||||
|
visit(next_block)
|
||||||
|
postorder.append(block)
|
||||||
|
|
||||||
|
visit(root)
|
||||||
|
|
||||||
|
postorder.reverse()
|
||||||
|
return postorder
|
||||||
|
|
||||||
|
def unroll(loop_insn):
|
||||||
|
loop_head = loop_insn.basic_block
|
||||||
|
function = loop_head.function
|
||||||
|
assert isinstance(loop_insn, ir.Loop)
|
||||||
|
assert len(loop_head.predecessors()) == 2
|
||||||
|
assert len(loop_insn.if_false().predecessors()) == 1
|
||||||
|
assert iodelay.is_const(loop_insn.trip_count)
|
||||||
|
|
||||||
|
trip_count = loop_insn.trip_count.fold().value
|
||||||
|
if trip_count == 0:
|
||||||
|
loop_insn.replace_with(ir.Branch(loop_insn.if_false()))
|
||||||
|
return
|
||||||
|
|
||||||
|
source_blocks = _get_body_blocks(loop_insn.if_true(), loop_head)
|
||||||
|
source_indvar = loop_insn.induction_variable()
|
||||||
|
source_tail = loop_insn.if_false()
|
||||||
|
unroll_target = loop_head
|
||||||
|
for n in range(trip_count):
|
||||||
|
value_map = {source_indvar: ir.Constant(n, source_indvar.type)}
|
||||||
|
|
||||||
|
for source_block in source_blocks:
|
||||||
|
target_block = ir.BasicBlock([], "u{}.{}".format(n, source_block.name))
|
||||||
|
function.add(target_block)
|
||||||
|
value_map[source_block] = target_block
|
||||||
|
|
||||||
|
def mapper(value):
|
||||||
|
if isinstance(value, ir.Constant):
|
||||||
|
return value
|
||||||
|
elif value in value_map:
|
||||||
|
return value_map[value]
|
||||||
|
else:
|
||||||
|
return value
|
||||||
|
|
||||||
|
for source_block in source_blocks:
|
||||||
|
target_block = value_map[source_block]
|
||||||
|
for source_insn in source_block.instructions:
|
||||||
|
if isinstance(source_insn, ir.Phi):
|
||||||
|
target_insn = ir.Phi()
|
||||||
|
else:
|
||||||
|
target_insn = source_insn.copy(mapper)
|
||||||
|
target_insn.name = "u{}.{}".format(n, source_insn.name)
|
||||||
|
target_block.append(target_insn)
|
||||||
|
value_map[source_insn] = target_insn
|
||||||
|
|
||||||
|
for source_block in source_blocks:
|
||||||
|
for source_insn in source_block.instructions:
|
||||||
|
if isinstance(source_insn, ir.Phi):
|
||||||
|
target_insn = value_map[source_insn]
|
||||||
|
for block, value in source_insn.incoming():
|
||||||
|
target_insn.add_incoming(value_map[value], value_map[block])
|
||||||
|
|
||||||
|
assert isinstance(unroll_target.terminator(), (ir.Branch, ir.Loop))
|
||||||
|
unroll_target.terminator().replace_with(ir.Branch(value_map[source_blocks[0]]))
|
||||||
|
unroll_target = value_map[source_blocks[-1]]
|
||||||
|
|
||||||
|
assert isinstance(unroll_target.terminator(), ir.Branch)
|
||||||
|
assert len(source_blocks[-1].successors()) == 1
|
||||||
|
unroll_target.terminator().replace_with(ir.Branch(source_tail))
|
||||||
|
|
||||||
|
for source_block in reversed(source_blocks):
|
||||||
|
for source_insn in reversed(source_block.instructions):
|
||||||
|
for use in set(source_insn.uses):
|
||||||
|
if isinstance(use, ir.Phi):
|
||||||
|
assert use.basic_block == loop_head
|
||||||
|
use.remove_incoming_value(source_insn)
|
||||||
|
source_insn.erase()
|
||||||
|
|
||||||
|
for source_block in reversed(source_blocks):
|
||||||
|
source_block.erase()
|
@ -1371,17 +1371,21 @@ class Loop(Terminator):
|
|||||||
:param trip_count: (:class:`iodelay.Expr`) expression
|
:param trip_count: (:class:`iodelay.Expr`) expression
|
||||||
:param substs: (dict of str to :class:`Value`)
|
:param substs: (dict of str to :class:`Value`)
|
||||||
SSA values corresponding to iodelay variable names
|
SSA values corresponding to iodelay variable names
|
||||||
|
:param indvar: (:class:`Phi`)
|
||||||
|
phi node corresponding to the induction SSA value,
|
||||||
|
which advances from ``0`` to ``trip_count - 1``
|
||||||
:param cond: (:class:`Value`) branch condition
|
:param cond: (:class:`Value`) branch condition
|
||||||
:param if_true: (:class:`BasicBlock`) branch target if condition is truthful
|
:param if_true: (:class:`BasicBlock`) branch target if condition is truthful
|
||||||
:param if_false: (:class:`BasicBlock`) branch target if condition is falseful
|
:param if_false: (:class:`BasicBlock`) branch target if condition is falseful
|
||||||
"""
|
"""
|
||||||
def __init__(self, trip_count, substs, cond, if_true, if_false, name=""):
|
def __init__(self, trip_count, substs, indvar, cond, if_true, if_false, name=""):
|
||||||
for var_name in substs: assert isinstance(var_name, str)
|
for var_name in substs: assert isinstance(var_name, str)
|
||||||
|
assert isinstance(indvar, Phi)
|
||||||
assert isinstance(cond, Value)
|
assert isinstance(cond, Value)
|
||||||
assert builtins.is_bool(cond.type)
|
assert builtins.is_bool(cond.type)
|
||||||
assert isinstance(if_true, BasicBlock)
|
assert isinstance(if_true, BasicBlock)
|
||||||
assert isinstance(if_false, BasicBlock)
|
assert isinstance(if_false, BasicBlock)
|
||||||
super().__init__([cond, if_true, if_false, *substs.values()], builtins.TNone(), name)
|
super().__init__([indvar, cond, if_true, if_false, *substs.values()], builtins.TNone(), name)
|
||||||
self.trip_count = trip_count
|
self.trip_count = trip_count
|
||||||
self.var_names = list(substs.keys())
|
self.var_names = list(substs.keys())
|
||||||
|
|
||||||
@ -1391,17 +1395,20 @@ class Loop(Terminator):
|
|||||||
self_copy.var_names = list(self.var_names)
|
self_copy.var_names = list(self.var_names)
|
||||||
return self_copy
|
return self_copy
|
||||||
|
|
||||||
def condition(self):
|
def induction_variable(self):
|
||||||
return self.operands[0]
|
return self.operands[0]
|
||||||
|
|
||||||
def if_true(self):
|
def condition(self):
|
||||||
return self.operands[1]
|
return self.operands[1]
|
||||||
|
|
||||||
def if_false(self):
|
def if_true(self):
|
||||||
return self.operands[2]
|
return self.operands[2]
|
||||||
|
|
||||||
|
def if_false(self):
|
||||||
|
return self.operands[3]
|
||||||
|
|
||||||
def substs(self):
|
def substs(self):
|
||||||
return {key: value for key, value in zip(self.var_names, self.operands[3:])}
|
return {key: value for key, value in zip(self.var_names, self.operands[4:])}
|
||||||
|
|
||||||
def _operands_as_string(self, type_printer):
|
def _operands_as_string(self, type_printer):
|
||||||
substs = self.substs()
|
substs = self.substs()
|
||||||
@ -1409,8 +1416,8 @@ class Loop(Terminator):
|
|||||||
for var_name in substs:
|
for var_name in substs:
|
||||||
substs_as_strings.append("{} = {}".format(var_name, substs[var_name]))
|
substs_as_strings.append("{} = {}".format(var_name, substs[var_name]))
|
||||||
result = "[{}]".format(", ".join(substs_as_strings))
|
result = "[{}]".format(", ".join(substs_as_strings))
|
||||||
result += ", {}, {}, {}".format(*list(map(lambda value: value.as_operand(type_printer),
|
result += ", indvar {}, if {}, {}, {}".format(
|
||||||
self.operands[0:3])))
|
*list(map(lambda value: value.as_operand(type_printer), self.operands[0:4])))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def opcode(self):
|
def opcode(self):
|
||||||
|
@ -73,6 +73,7 @@ class Module:
|
|||||||
dead_code_eliminator.process(self.artiq_ir)
|
dead_code_eliminator.process(self.artiq_ir)
|
||||||
local_access_validator.process(self.artiq_ir)
|
local_access_validator.process(self.artiq_ir)
|
||||||
interleaver.process(self.artiq_ir)
|
interleaver.process(self.artiq_ir)
|
||||||
|
dead_code_eliminator.process(self.artiq_ir)
|
||||||
|
|
||||||
def build_llvm_ir(self, target):
|
def build_llvm_ir(self, target):
|
||||||
"""Compile the module to LLVM IR for the specified target."""
|
"""Compile the module to LLVM IR for the specified target."""
|
||||||
|
@ -525,7 +525,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||||||
if node.trip_count is not None:
|
if node.trip_count is not None:
|
||||||
substs = {var_name: self.current_args[var_name]
|
substs = {var_name: self.current_args[var_name]
|
||||||
for var_name in node.trip_count.free_vars()}
|
for var_name in node.trip_count.free_vars()}
|
||||||
head.append(ir.Loop(node.trip_count, substs, cond, body, else_tail))
|
head.append(ir.Loop(node.trip_count, substs, phi, cond, body, else_tail))
|
||||||
else:
|
else:
|
||||||
head.append(ir.BranchIf(cond, body, else_tail))
|
head.append(ir.BranchIf(cond, body, else_tail))
|
||||||
if not post_body.is_terminated():
|
if not post_body.is_terminated():
|
||||||
|
@ -7,7 +7,7 @@ from pythonparser import diagnostic
|
|||||||
|
|
||||||
from .. import types, builtins, ir, iodelay
|
from .. import types, builtins, ir, iodelay
|
||||||
from ..analyses import domination
|
from ..analyses import domination
|
||||||
from ..algorithms import inline
|
from ..algorithms import inline, unroll
|
||||||
|
|
||||||
def delay_free_subgraph(root, limit):
|
def delay_free_subgraph(root, limit):
|
||||||
visited = set()
|
visited = set()
|
||||||
@ -152,6 +152,11 @@ class Interleaver:
|
|||||||
source_terminator.interval = iodelay.Const(target_time_delta)
|
source_terminator.interval = iodelay.Const(target_time_delta)
|
||||||
else:
|
else:
|
||||||
source_terminator.replace_with(ir.Branch(source_terminator.target()))
|
source_terminator.replace_with(ir.Branch(source_terminator.target()))
|
||||||
|
elif isinstance(source_terminator, ir.Loop):
|
||||||
|
unroll(source_terminator)
|
||||||
|
|
||||||
|
postdom_tree = domination.PostDominatorTree(func)
|
||||||
|
continue
|
||||||
else:
|
else:
|
||||||
assert False
|
assert False
|
||||||
|
|
||||||
|
23
lit-test/test/interleaving/unrolling.py
Normal file
23
lit-test/test/interleaving/unrolling.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
# RUN: %python -m artiq.compiler.testbench.jit %s >%t
|
||||||
|
# RUN: OutputCheck %s --file-to-check=%t
|
||||||
|
|
||||||
|
with parallel:
|
||||||
|
for x in range(10):
|
||||||
|
delay_mu(1)
|
||||||
|
print("a", x)
|
||||||
|
with sequential:
|
||||||
|
delay_mu(5)
|
||||||
|
print("c")
|
||||||
|
with sequential:
|
||||||
|
delay_mu(3)
|
||||||
|
print("b")
|
||||||
|
|
||||||
|
# CHECK-L: a 0
|
||||||
|
# CHECK-L: a 1
|
||||||
|
# CHECK-L: a 2
|
||||||
|
# CHECK-L: b
|
||||||
|
# CHECK-L: a 3
|
||||||
|
# CHECK-L: a 4
|
||||||
|
# CHECK-L: c
|
||||||
|
# CHECK-L: a 5
|
||||||
|
# CHECK-L: a 6
|
Loading…
Reference in New Issue
Block a user