forked from M-Labs/artiq
compiler: explicitly represent loops in IR.
This commit is contained in:
parent
33860820b9
commit
f8eaeaa43f
|
@ -46,7 +46,7 @@ class FunctionResolver(algorithm.Visitor):
|
|||
self.visit(node.value)
|
||||
self.visit_in_assign(node.targets)
|
||||
|
||||
def visit_For(self, node):
|
||||
def visit_ForT(self, node):
|
||||
self.visit(node.iter)
|
||||
self.visit_in_assign(node.target)
|
||||
self.visit(node.body)
|
||||
|
|
|
@ -36,6 +36,12 @@ class ExceptHandlerT(ast.ExceptHandler):
|
|||
_fields = ("filter", "name", "body") # rename ast.ExceptHandler.type to filter
|
||||
_types = ("name_type",)
|
||||
|
||||
class ForT(ast.For):
|
||||
"""
|
||||
:ivar trip_count: (:class:`iodelay.Expr`)
|
||||
:ivar trip_interval: (:class:`iodelay.Expr`)
|
||||
"""
|
||||
|
||||
class SliceT(ast.Slice, commontyped):
|
||||
pass
|
||||
|
||||
|
|
|
@ -1296,7 +1296,7 @@ class Delay(Terminator):
|
|||
|
||||
:ivar interval: (:class:`iodelay.Expr`) expression
|
||||
:ivar var_names: (list of string)
|
||||
iodelay variable names corresponding to operands
|
||||
iodelay variable names corresponding to SSA values
|
||||
"""
|
||||
|
||||
"""
|
||||
|
@ -1354,6 +1354,68 @@ class Delay(Terminator):
|
|||
def opcode(self):
|
||||
return "delay({})".format(self.interval)
|
||||
|
||||
class Loop(Terminator):
|
||||
"""
|
||||
A terminator for loop headers that carries metadata useful
|
||||
for unrolling. It includes an :class:`iodelay.Expr` specifying
|
||||
the trip count, tied to SSA values so that inlining could lead
|
||||
to the expression folding to a constant.
|
||||
|
||||
:ivar trip_count: (:class:`iodelay.Expr`)
|
||||
expression for trip count
|
||||
:ivar var_names: (list of string)
|
||||
iodelay variable names corresponding to ``trip_count`` operands
|
||||
"""
|
||||
|
||||
"""
|
||||
:param trip_count: (:class:`iodelay.Expr`) expression
|
||||
:param substs: (dict of str to :class:`Value`)
|
||||
SSA values corresponding to iodelay variable names
|
||||
:param cond: (:class:`Value`) branch condition
|
||||
:param if_true: (:class:`BasicBlock`) branch target if condition is truthful
|
||||
:param if_false: (:class:`BasicBlock`) branch target if condition is falseful
|
||||
"""
|
||||
def __init__(self, trip_count, substs, cond, if_true, if_false, name=""):
|
||||
for var_name in substs: assert isinstance(var_name, str)
|
||||
assert isinstance(cond, Value)
|
||||
assert builtins.is_bool(cond.type)
|
||||
assert isinstance(if_true, BasicBlock)
|
||||
assert isinstance(if_false, BasicBlock)
|
||||
super().__init__([cond, if_true, if_false, *substs.values()], builtins.TNone(), name)
|
||||
self.trip_count = trip_count
|
||||
self.var_names = list(substs.keys())
|
||||
|
||||
def copy(self, mapper):
|
||||
self_copy = super().copy(mapper)
|
||||
self_copy.trip_count = self.trip_count
|
||||
self_copy.var_names = list(self.var_names)
|
||||
return self_copy
|
||||
|
||||
def condition(self):
|
||||
return self.operands[0]
|
||||
|
||||
def if_true(self):
|
||||
return self.operands[1]
|
||||
|
||||
def if_false(self):
|
||||
return self.operands[2]
|
||||
|
||||
def substs(self):
|
||||
return {key: value for key, value in zip(self.var_names, self.operands[3:])}
|
||||
|
||||
def _operands_as_string(self, type_printer):
|
||||
substs = self.substs()
|
||||
substs_as_strings = []
|
||||
for var_name in substs:
|
||||
substs_as_strings.append("{} = {}".format(var_name, substs[var_name]))
|
||||
result = "[{}]".format(", ".join(substs_as_strings))
|
||||
result += ", {}, {}, {}".format(*list(map(lambda value: value.as_operand(type_printer),
|
||||
self.operands[0:3])))
|
||||
return result
|
||||
|
||||
def opcode(self):
|
||||
return "loop({} times)".format(self.trip_count)
|
||||
|
||||
class Parallel(Terminator):
|
||||
"""
|
||||
An instruction that schedules several threads of execution
|
||||
|
|
|
@ -2,6 +2,7 @@ import sys, fileinput, os
|
|||
from pythonparser import source, diagnostic, algorithm, parse_buffer
|
||||
from .. import prelude, types
|
||||
from ..transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer
|
||||
from ..transforms import IODelayEstimator
|
||||
|
||||
class Printer(algorithm.Visitor):
|
||||
"""
|
||||
|
@ -34,6 +35,14 @@ class Printer(algorithm.Visitor):
|
|||
self.rewriter.insert_after(node.name_loc,
|
||||
":{}".format(self.type_printer.name(node.name_type)))
|
||||
|
||||
def visit_ForT(self, node):
|
||||
super().generic_visit(node)
|
||||
|
||||
if node.trip_count is not None and node.trip_interval is not None:
|
||||
self.rewriter.insert_after(node.keyword_loc,
|
||||
"[{} x {} mu]".format(node.trip_count.fold(),
|
||||
node.trip_interval.fold()))
|
||||
|
||||
def generic_visit(self, node):
|
||||
super().generic_visit(node)
|
||||
|
||||
|
@ -48,6 +57,12 @@ def main():
|
|||
else:
|
||||
monomorphize = False
|
||||
|
||||
if len(sys.argv) > 1 and sys.argv[1] == "+iodelay":
|
||||
del sys.argv[1]
|
||||
iodelay = True
|
||||
else:
|
||||
iodelay = False
|
||||
|
||||
if len(sys.argv) > 1 and sys.argv[1] == "+diag":
|
||||
del sys.argv[1]
|
||||
def process_diagnostic(diag):
|
||||
|
@ -71,6 +86,8 @@ def main():
|
|||
if monomorphize:
|
||||
IntMonomorphizer(engine=engine).visit(typed)
|
||||
Inferencer(engine=engine).visit(typed)
|
||||
if iodelay:
|
||||
IODelayEstimator(engine=engine, ref_period=1e6).visit_fixpoint(typed)
|
||||
|
||||
printer = Printer(buf)
|
||||
printer.visit(typed)
|
||||
|
|
|
@ -472,7 +472,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
else:
|
||||
assert False
|
||||
|
||||
def visit_For(self, node):
|
||||
def visit_ForT(self, node):
|
||||
try:
|
||||
iterable = self.visit(node.iter)
|
||||
length = self.iterable_len(iterable)
|
||||
|
@ -522,6 +522,11 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
else:
|
||||
else_tail = tail
|
||||
|
||||
if node.trip_count is not None:
|
||||
substs = {var_name: self.current_args[var_name]
|
||||
for var_name in node.trip_count.free_vars()}
|
||||
head.append(ir.Loop(node.trip_count, substs, cond, body, else_tail))
|
||||
else:
|
||||
head.append(ir.BranchIf(cond, body, else_tail))
|
||||
if not post_body.is_terminated():
|
||||
post_body.append(ir.Branch(continue_block))
|
||||
|
|
|
@ -471,6 +471,15 @@ class ASTTypedRewriter(algorithm.Transformer):
|
|||
self.engine.process(diag)
|
||||
return node
|
||||
|
||||
def visit_For(self, node):
|
||||
node = self.generic_visit(node)
|
||||
node = asttyped.ForT(
|
||||
target=node.target, iter=node.iter, body=node.body, orelse=node.orelse,
|
||||
trip_count=None, trip_interval=None,
|
||||
keyword_loc=node.keyword_loc, in_loc=node.in_loc, for_colon_loc=node.for_colon_loc,
|
||||
else_loc=node.else_loc, else_colon_loc=node.else_colon_loc)
|
||||
return node
|
||||
|
||||
# Unsupported visitors
|
||||
#
|
||||
def visit_unsupported(self, node):
|
||||
|
|
|
@ -916,7 +916,7 @@ class Inferencer(algorithm.Visitor):
|
|||
|
||||
node.value = self._coerce_one(value_type, node.value, other_node=node.target)
|
||||
|
||||
def visit_For(self, node):
|
||||
def visit_ForT(self, node):
|
||||
old_in_loop, self.in_loop = self.in_loop, True
|
||||
self.generic_visit(node)
|
||||
self.in_loop = old_in_loop
|
||||
|
|
|
@ -167,7 +167,7 @@ class IODelayEstimator(algorithm.Visitor):
|
|||
node.loc)
|
||||
abort([note])
|
||||
|
||||
def visit_For(self, node):
|
||||
def visit_ForT(self, node):
|
||||
self.visit(node.iter)
|
||||
|
||||
old_goto, self.current_goto = self.current_goto, None
|
||||
|
@ -180,8 +180,9 @@ class IODelayEstimator(algorithm.Visitor):
|
|||
self.abort("loop trip count is indeterminate because of control flow",
|
||||
self.current_goto.loc)
|
||||
|
||||
trip_count = self.get_iterable_length(node.iter)
|
||||
self.current_delay = old_delay + self.current_delay * trip_count
|
||||
node.trip_count = self.get_iterable_length(node.iter).fold()
|
||||
node.trip_interval = self.current_delay.fold()
|
||||
self.current_delay = old_delay + node.trip_interval * node.trip_count
|
||||
self.current_goto = old_goto
|
||||
|
||||
self.visit(node.orelse)
|
||||
|
|
Loading…
Reference in New Issue