forked from M-Labs/artiq
1
0
Fork 0

transforms.interleaver: unroll loops.

This commit is contained in:
whitequark 2015-12-17 00:52:03 +08:00
parent 5dd1fc993e
commit 8cb7844621
7 changed files with 144 additions and 10 deletions

View File

@ -1 +1,2 @@
from .inline import inline from .inline import inline
from .unroll import unroll

View File

@ -0,0 +1,97 @@
"""
:func:`unroll` unrolls a loop instruction in ARTIQ IR.
The loop's trip count must be constant.
The loop body must not have any control flow instructions
except for one branch back to the loop head.
The loop body must be executed if the condition to which
the instruction refers is true.
"""
from .. import types, builtins, iodelay, ir
from ..analyses import domination
def _get_body_blocks(root, limit):
postorder = []
visited = set()
def visit(block):
visited.add(block)
for next_block in block.successors():
if next_block not in visited and next_block is not limit:
visit(next_block)
postorder.append(block)
visit(root)
postorder.reverse()
return postorder
def unroll(loop_insn):
loop_head = loop_insn.basic_block
function = loop_head.function
assert isinstance(loop_insn, ir.Loop)
assert len(loop_head.predecessors()) == 2
assert len(loop_insn.if_false().predecessors()) == 1
assert iodelay.is_const(loop_insn.trip_count)
trip_count = loop_insn.trip_count.fold().value
if trip_count == 0:
loop_insn.replace_with(ir.Branch(loop_insn.if_false()))
return
source_blocks = _get_body_blocks(loop_insn.if_true(), loop_head)
source_indvar = loop_insn.induction_variable()
source_tail = loop_insn.if_false()
unroll_target = loop_head
for n in range(trip_count):
value_map = {source_indvar: ir.Constant(n, source_indvar.type)}
for source_block in source_blocks:
target_block = ir.BasicBlock([], "u{}.{}".format(n, source_block.name))
function.add(target_block)
value_map[source_block] = target_block
def mapper(value):
if isinstance(value, ir.Constant):
return value
elif value in value_map:
return value_map[value]
else:
return value
for source_block in source_blocks:
target_block = value_map[source_block]
for source_insn in source_block.instructions:
if isinstance(source_insn, ir.Phi):
target_insn = ir.Phi()
else:
target_insn = source_insn.copy(mapper)
target_insn.name = "u{}.{}".format(n, source_insn.name)
target_block.append(target_insn)
value_map[source_insn] = target_insn
for source_block in source_blocks:
for source_insn in source_block.instructions:
if isinstance(source_insn, ir.Phi):
target_insn = value_map[source_insn]
for block, value in source_insn.incoming():
target_insn.add_incoming(value_map[value], value_map[block])
assert isinstance(unroll_target.terminator(), (ir.Branch, ir.Loop))
unroll_target.terminator().replace_with(ir.Branch(value_map[source_blocks[0]]))
unroll_target = value_map[source_blocks[-1]]
assert isinstance(unroll_target.terminator(), ir.Branch)
assert len(source_blocks[-1].successors()) == 1
unroll_target.terminator().replace_with(ir.Branch(source_tail))
for source_block in reversed(source_blocks):
for source_insn in reversed(source_block.instructions):
for use in set(source_insn.uses):
if isinstance(use, ir.Phi):
assert use.basic_block == loop_head
use.remove_incoming_value(source_insn)
source_insn.erase()
for source_block in reversed(source_blocks):
source_block.erase()

View File

@ -1371,17 +1371,21 @@ class Loop(Terminator):
:param trip_count: (:class:`iodelay.Expr`) expression :param trip_count: (:class:`iodelay.Expr`) expression
:param substs: (dict of str to :class:`Value`) :param substs: (dict of str to :class:`Value`)
SSA values corresponding to iodelay variable names SSA values corresponding to iodelay variable names
:param indvar: (:class:`Phi`)
phi node corresponding to the induction SSA value,
which advances from ``0`` to ``trip_count - 1``
:param cond: (:class:`Value`) branch condition :param cond: (:class:`Value`) branch condition
:param if_true: (:class:`BasicBlock`) branch target if condition is truthful :param if_true: (:class:`BasicBlock`) branch target if condition is truthful
:param if_false: (:class:`BasicBlock`) branch target if condition is falseful :param if_false: (:class:`BasicBlock`) branch target if condition is falseful
""" """
def __init__(self, trip_count, substs, cond, if_true, if_false, name=""): def __init__(self, trip_count, substs, indvar, cond, if_true, if_false, name=""):
for var_name in substs: assert isinstance(var_name, str) for var_name in substs: assert isinstance(var_name, str)
assert isinstance(indvar, Phi)
assert isinstance(cond, Value) assert isinstance(cond, Value)
assert builtins.is_bool(cond.type) assert builtins.is_bool(cond.type)
assert isinstance(if_true, BasicBlock) assert isinstance(if_true, BasicBlock)
assert isinstance(if_false, BasicBlock) assert isinstance(if_false, BasicBlock)
super().__init__([cond, if_true, if_false, *substs.values()], builtins.TNone(), name) super().__init__([indvar, cond, if_true, if_false, *substs.values()], builtins.TNone(), name)
self.trip_count = trip_count self.trip_count = trip_count
self.var_names = list(substs.keys()) self.var_names = list(substs.keys())
@ -1391,17 +1395,20 @@ class Loop(Terminator):
self_copy.var_names = list(self.var_names) self_copy.var_names = list(self.var_names)
return self_copy return self_copy
def condition(self): def induction_variable(self):
return self.operands[0] return self.operands[0]
def if_true(self): def condition(self):
return self.operands[1] return self.operands[1]
def if_false(self): def if_true(self):
return self.operands[2] return self.operands[2]
def if_false(self):
return self.operands[3]
def substs(self): def substs(self):
return {key: value for key, value in zip(self.var_names, self.operands[3:])} return {key: value for key, value in zip(self.var_names, self.operands[4:])}
def _operands_as_string(self, type_printer): def _operands_as_string(self, type_printer):
substs = self.substs() substs = self.substs()
@ -1409,8 +1416,8 @@ class Loop(Terminator):
for var_name in substs: for var_name in substs:
substs_as_strings.append("{} = {}".format(var_name, substs[var_name])) substs_as_strings.append("{} = {}".format(var_name, substs[var_name]))
result = "[{}]".format(", ".join(substs_as_strings)) result = "[{}]".format(", ".join(substs_as_strings))
result += ", {}, {}, {}".format(*list(map(lambda value: value.as_operand(type_printer), result += ", indvar {}, if {}, {}, {}".format(
self.operands[0:3]))) *list(map(lambda value: value.as_operand(type_printer), self.operands[0:4])))
return result return result
def opcode(self): def opcode(self):

View File

@ -73,6 +73,7 @@ class Module:
dead_code_eliminator.process(self.artiq_ir) dead_code_eliminator.process(self.artiq_ir)
local_access_validator.process(self.artiq_ir) local_access_validator.process(self.artiq_ir)
interleaver.process(self.artiq_ir) interleaver.process(self.artiq_ir)
dead_code_eliminator.process(self.artiq_ir)
def build_llvm_ir(self, target): def build_llvm_ir(self, target):
"""Compile the module to LLVM IR for the specified target.""" """Compile the module to LLVM IR for the specified target."""

View File

@ -525,7 +525,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
if node.trip_count is not None: if node.trip_count is not None:
substs = {var_name: self.current_args[var_name] substs = {var_name: self.current_args[var_name]
for var_name in node.trip_count.free_vars()} for var_name in node.trip_count.free_vars()}
head.append(ir.Loop(node.trip_count, substs, cond, body, else_tail)) head.append(ir.Loop(node.trip_count, substs, phi, cond, body, else_tail))
else: else:
head.append(ir.BranchIf(cond, body, else_tail)) head.append(ir.BranchIf(cond, body, else_tail))
if not post_body.is_terminated(): if not post_body.is_terminated():

View File

@ -7,7 +7,7 @@ from pythonparser import diagnostic
from .. import types, builtins, ir, iodelay from .. import types, builtins, ir, iodelay
from ..analyses import domination from ..analyses import domination
from ..algorithms import inline from ..algorithms import inline, unroll
def delay_free_subgraph(root, limit): def delay_free_subgraph(root, limit):
visited = set() visited = set()
@ -152,6 +152,11 @@ class Interleaver:
source_terminator.interval = iodelay.Const(target_time_delta) source_terminator.interval = iodelay.Const(target_time_delta)
else: else:
source_terminator.replace_with(ir.Branch(source_terminator.target())) source_terminator.replace_with(ir.Branch(source_terminator.target()))
elif isinstance(source_terminator, ir.Loop):
unroll(source_terminator)
postdom_tree = domination.PostDominatorTree(func)
continue
else: else:
assert False assert False

View File

@ -0,0 +1,23 @@
# RUN: %python -m artiq.compiler.testbench.jit %s >%t
# RUN: OutputCheck %s --file-to-check=%t
with parallel:
for x in range(10):
delay_mu(1)
print("a", x)
with sequential:
delay_mu(5)
print("c")
with sequential:
delay_mu(3)
print("b")
# CHECK-L: a 0
# CHECK-L: a 1
# CHECK-L: a 2
# CHECK-L: b
# CHECK-L: a 3
# CHECK-L: a 4
# CHECK-L: c
# CHECK-L: a 5
# CHECK-L: a 6