compiler: explicitly represent loops in IR.

This commit is contained in:
whitequark 2015-12-16 15:33:15 +08:00
parent 33860820b9
commit f8eaeaa43f
8 changed files with 109 additions and 9 deletions

View File

@ -46,7 +46,7 @@ class FunctionResolver(algorithm.Visitor):
self.visit(node.value)
self.visit_in_assign(node.targets)
def visit_For(self, node):
def visit_ForT(self, node):
self.visit(node.iter)
self.visit_in_assign(node.target)
self.visit(node.body)

View File

@ -36,6 +36,12 @@ class ExceptHandlerT(ast.ExceptHandler):
_fields = ("filter", "name", "body") # rename ast.ExceptHandler.type to filter
_types = ("name_type",)
class ForT(ast.For):
"""
:ivar trip_count: (:class:`iodelay.Expr`)
:ivar trip_interval: (:class:`iodelay.Expr`)
"""
class SliceT(ast.Slice, commontyped):
pass

View File

@ -1296,7 +1296,7 @@ class Delay(Terminator):
:ivar interval: (:class:`iodelay.Expr`) expression
:ivar var_names: (list of string)
iodelay variable names corresponding to operands
iodelay variable names corresponding to SSA values
"""
"""
@ -1354,6 +1354,68 @@ class Delay(Terminator):
def opcode(self):
return "delay({})".format(self.interval)
class Loop(Terminator):
"""
A terminator for loop headers that carries metadata useful
for unrolling. It includes an :class:`iodelay.Expr` specifying
the trip count, tied to SSA values so that inlining could lead
to the expression folding to a constant.
:ivar trip_count: (:class:`iodelay.Expr`)
expression for trip count
:ivar var_names: (list of string)
iodelay variable names corresponding to ``trip_count`` operands
"""
"""
:param trip_count: (:class:`iodelay.Expr`) expression
:param substs: (dict of str to :class:`Value`)
SSA values corresponding to iodelay variable names
:param cond: (:class:`Value`) branch condition
:param if_true: (:class:`BasicBlock`) branch target if condition is truthful
:param if_false: (:class:`BasicBlock`) branch target if condition is falseful
"""
def __init__(self, trip_count, substs, cond, if_true, if_false, name=""):
for var_name in substs: assert isinstance(var_name, str)
assert isinstance(cond, Value)
assert builtins.is_bool(cond.type)
assert isinstance(if_true, BasicBlock)
assert isinstance(if_false, BasicBlock)
super().__init__([cond, if_true, if_false, *substs.values()], builtins.TNone(), name)
self.trip_count = trip_count
self.var_names = list(substs.keys())
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.trip_count = self.trip_count
self_copy.var_names = list(self.var_names)
return self_copy
def condition(self):
return self.operands[0]
def if_true(self):
return self.operands[1]
def if_false(self):
return self.operands[2]
def substs(self):
return {key: value for key, value in zip(self.var_names, self.operands[3:])}
def _operands_as_string(self, type_printer):
substs = self.substs()
substs_as_strings = []
for var_name in substs:
substs_as_strings.append("{} = {}".format(var_name, substs[var_name]))
result = "[{}]".format(", ".join(substs_as_strings))
result += ", {}, {}, {}".format(*list(map(lambda value: value.as_operand(type_printer),
self.operands[0:3])))
return result
def opcode(self):
return "loop({} times)".format(self.trip_count)
class Parallel(Terminator):
"""
An instruction that schedules several threads of execution

View File

@ -2,6 +2,7 @@ import sys, fileinput, os
from pythonparser import source, diagnostic, algorithm, parse_buffer
from .. import prelude, types
from ..transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer
from ..transforms import IODelayEstimator
class Printer(algorithm.Visitor):
"""
@ -34,6 +35,14 @@ class Printer(algorithm.Visitor):
self.rewriter.insert_after(node.name_loc,
":{}".format(self.type_printer.name(node.name_type)))
def visit_ForT(self, node):
super().generic_visit(node)
if node.trip_count is not None and node.trip_interval is not None:
self.rewriter.insert_after(node.keyword_loc,
"[{} x {} mu]".format(node.trip_count.fold(),
node.trip_interval.fold()))
def generic_visit(self, node):
super().generic_visit(node)
@ -48,6 +57,12 @@ def main():
else:
monomorphize = False
if len(sys.argv) > 1 and sys.argv[1] == "+iodelay":
del sys.argv[1]
iodelay = True
else:
iodelay = False
if len(sys.argv) > 1 and sys.argv[1] == "+diag":
del sys.argv[1]
def process_diagnostic(diag):
@ -71,6 +86,8 @@ def main():
if monomorphize:
IntMonomorphizer(engine=engine).visit(typed)
Inferencer(engine=engine).visit(typed)
if iodelay:
IODelayEstimator(engine=engine, ref_period=1e6).visit_fixpoint(typed)
printer = Printer(buf)
printer.visit(typed)

View File

@ -472,7 +472,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
else:
assert False
def visit_For(self, node):
def visit_ForT(self, node):
try:
iterable = self.visit(node.iter)
length = self.iterable_len(iterable)
@ -522,6 +522,11 @@ class ARTIQIRGenerator(algorithm.Visitor):
else:
else_tail = tail
if node.trip_count is not None:
substs = {var_name: self.current_args[var_name]
for var_name in node.trip_count.free_vars()}
head.append(ir.Loop(node.trip_count, substs, cond, body, else_tail))
else:
head.append(ir.BranchIf(cond, body, else_tail))
if not post_body.is_terminated():
post_body.append(ir.Branch(continue_block))

View File

@ -471,6 +471,15 @@ class ASTTypedRewriter(algorithm.Transformer):
self.engine.process(diag)
return node
def visit_For(self, node):
node = self.generic_visit(node)
node = asttyped.ForT(
target=node.target, iter=node.iter, body=node.body, orelse=node.orelse,
trip_count=None, trip_interval=None,
keyword_loc=node.keyword_loc, in_loc=node.in_loc, for_colon_loc=node.for_colon_loc,
else_loc=node.else_loc, else_colon_loc=node.else_colon_loc)
return node
# Unsupported visitors
#
def visit_unsupported(self, node):

View File

@ -916,7 +916,7 @@ class Inferencer(algorithm.Visitor):
node.value = self._coerce_one(value_type, node.value, other_node=node.target)
def visit_For(self, node):
def visit_ForT(self, node):
old_in_loop, self.in_loop = self.in_loop, True
self.generic_visit(node)
self.in_loop = old_in_loop

View File

@ -167,7 +167,7 @@ class IODelayEstimator(algorithm.Visitor):
node.loc)
abort([note])
def visit_For(self, node):
def visit_ForT(self, node):
self.visit(node.iter)
old_goto, self.current_goto = self.current_goto, None
@ -180,8 +180,9 @@ class IODelayEstimator(algorithm.Visitor):
self.abort("loop trip count is indeterminate because of control flow",
self.current_goto.loc)
trip_count = self.get_iterable_length(node.iter)
self.current_delay = old_delay + self.current_delay * trip_count
node.trip_count = self.get_iterable_length(node.iter).fold()
node.trip_interval = self.current_delay.fold()
self.current_delay = old_delay + node.trip_interval * node.trip_count
self.current_goto = old_goto
self.visit(node.orelse)