forked from M-Labs/artiq
compiler: explicitly represent loops in IR.
This commit is contained in:
parent
33860820b9
commit
f8eaeaa43f
@ -46,7 +46,7 @@ class FunctionResolver(algorithm.Visitor):
|
|||||||
self.visit(node.value)
|
self.visit(node.value)
|
||||||
self.visit_in_assign(node.targets)
|
self.visit_in_assign(node.targets)
|
||||||
|
|
||||||
def visit_For(self, node):
|
def visit_ForT(self, node):
|
||||||
self.visit(node.iter)
|
self.visit(node.iter)
|
||||||
self.visit_in_assign(node.target)
|
self.visit_in_assign(node.target)
|
||||||
self.visit(node.body)
|
self.visit(node.body)
|
||||||
|
@ -36,6 +36,12 @@ class ExceptHandlerT(ast.ExceptHandler):
|
|||||||
_fields = ("filter", "name", "body") # rename ast.ExceptHandler.type to filter
|
_fields = ("filter", "name", "body") # rename ast.ExceptHandler.type to filter
|
||||||
_types = ("name_type",)
|
_types = ("name_type",)
|
||||||
|
|
||||||
|
class ForT(ast.For):
|
||||||
|
"""
|
||||||
|
:ivar trip_count: (:class:`iodelay.Expr`)
|
||||||
|
:ivar trip_interval: (:class:`iodelay.Expr`)
|
||||||
|
"""
|
||||||
|
|
||||||
class SliceT(ast.Slice, commontyped):
|
class SliceT(ast.Slice, commontyped):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -1296,7 +1296,7 @@ class Delay(Terminator):
|
|||||||
|
|
||||||
:ivar interval: (:class:`iodelay.Expr`) expression
|
:ivar interval: (:class:`iodelay.Expr`) expression
|
||||||
:ivar var_names: (list of string)
|
:ivar var_names: (list of string)
|
||||||
iodelay variable names corresponding to operands
|
iodelay variable names corresponding to SSA values
|
||||||
"""
|
"""
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -1354,6 +1354,68 @@ class Delay(Terminator):
|
|||||||
def opcode(self):
|
def opcode(self):
|
||||||
return "delay({})".format(self.interval)
|
return "delay({})".format(self.interval)
|
||||||
|
|
||||||
|
class Loop(Terminator):
|
||||||
|
"""
|
||||||
|
A terminator for loop headers that carries metadata useful
|
||||||
|
for unrolling. It includes an :class:`iodelay.Expr` specifying
|
||||||
|
the trip count, tied to SSA values so that inlining could lead
|
||||||
|
to the expression folding to a constant.
|
||||||
|
|
||||||
|
:ivar trip_count: (:class:`iodelay.Expr`)
|
||||||
|
expression for trip count
|
||||||
|
:ivar var_names: (list of string)
|
||||||
|
iodelay variable names corresponding to ``trip_count`` operands
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
:param trip_count: (:class:`iodelay.Expr`) expression
|
||||||
|
:param substs: (dict of str to :class:`Value`)
|
||||||
|
SSA values corresponding to iodelay variable names
|
||||||
|
:param cond: (:class:`Value`) branch condition
|
||||||
|
:param if_true: (:class:`BasicBlock`) branch target if condition is truthful
|
||||||
|
:param if_false: (:class:`BasicBlock`) branch target if condition is falseful
|
||||||
|
"""
|
||||||
|
def __init__(self, trip_count, substs, cond, if_true, if_false, name=""):
|
||||||
|
for var_name in substs: assert isinstance(var_name, str)
|
||||||
|
assert isinstance(cond, Value)
|
||||||
|
assert builtins.is_bool(cond.type)
|
||||||
|
assert isinstance(if_true, BasicBlock)
|
||||||
|
assert isinstance(if_false, BasicBlock)
|
||||||
|
super().__init__([cond, if_true, if_false, *substs.values()], builtins.TNone(), name)
|
||||||
|
self.trip_count = trip_count
|
||||||
|
self.var_names = list(substs.keys())
|
||||||
|
|
||||||
|
def copy(self, mapper):
|
||||||
|
self_copy = super().copy(mapper)
|
||||||
|
self_copy.trip_count = self.trip_count
|
||||||
|
self_copy.var_names = list(self.var_names)
|
||||||
|
return self_copy
|
||||||
|
|
||||||
|
def condition(self):
|
||||||
|
return self.operands[0]
|
||||||
|
|
||||||
|
def if_true(self):
|
||||||
|
return self.operands[1]
|
||||||
|
|
||||||
|
def if_false(self):
|
||||||
|
return self.operands[2]
|
||||||
|
|
||||||
|
def substs(self):
|
||||||
|
return {key: value for key, value in zip(self.var_names, self.operands[3:])}
|
||||||
|
|
||||||
|
def _operands_as_string(self, type_printer):
|
||||||
|
substs = self.substs()
|
||||||
|
substs_as_strings = []
|
||||||
|
for var_name in substs:
|
||||||
|
substs_as_strings.append("{} = {}".format(var_name, substs[var_name]))
|
||||||
|
result = "[{}]".format(", ".join(substs_as_strings))
|
||||||
|
result += ", {}, {}, {}".format(*list(map(lambda value: value.as_operand(type_printer),
|
||||||
|
self.operands[0:3])))
|
||||||
|
return result
|
||||||
|
|
||||||
|
def opcode(self):
|
||||||
|
return "loop({} times)".format(self.trip_count)
|
||||||
|
|
||||||
class Parallel(Terminator):
|
class Parallel(Terminator):
|
||||||
"""
|
"""
|
||||||
An instruction that schedules several threads of execution
|
An instruction that schedules several threads of execution
|
||||||
|
@ -2,6 +2,7 @@ import sys, fileinput, os
|
|||||||
from pythonparser import source, diagnostic, algorithm, parse_buffer
|
from pythonparser import source, diagnostic, algorithm, parse_buffer
|
||||||
from .. import prelude, types
|
from .. import prelude, types
|
||||||
from ..transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer
|
from ..transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer
|
||||||
|
from ..transforms import IODelayEstimator
|
||||||
|
|
||||||
class Printer(algorithm.Visitor):
|
class Printer(algorithm.Visitor):
|
||||||
"""
|
"""
|
||||||
@ -34,6 +35,14 @@ class Printer(algorithm.Visitor):
|
|||||||
self.rewriter.insert_after(node.name_loc,
|
self.rewriter.insert_after(node.name_loc,
|
||||||
":{}".format(self.type_printer.name(node.name_type)))
|
":{}".format(self.type_printer.name(node.name_type)))
|
||||||
|
|
||||||
|
def visit_ForT(self, node):
|
||||||
|
super().generic_visit(node)
|
||||||
|
|
||||||
|
if node.trip_count is not None and node.trip_interval is not None:
|
||||||
|
self.rewriter.insert_after(node.keyword_loc,
|
||||||
|
"[{} x {} mu]".format(node.trip_count.fold(),
|
||||||
|
node.trip_interval.fold()))
|
||||||
|
|
||||||
def generic_visit(self, node):
|
def generic_visit(self, node):
|
||||||
super().generic_visit(node)
|
super().generic_visit(node)
|
||||||
|
|
||||||
@ -48,6 +57,12 @@ def main():
|
|||||||
else:
|
else:
|
||||||
monomorphize = False
|
monomorphize = False
|
||||||
|
|
||||||
|
if len(sys.argv) > 1 and sys.argv[1] == "+iodelay":
|
||||||
|
del sys.argv[1]
|
||||||
|
iodelay = True
|
||||||
|
else:
|
||||||
|
iodelay = False
|
||||||
|
|
||||||
if len(sys.argv) > 1 and sys.argv[1] == "+diag":
|
if len(sys.argv) > 1 and sys.argv[1] == "+diag":
|
||||||
del sys.argv[1]
|
del sys.argv[1]
|
||||||
def process_diagnostic(diag):
|
def process_diagnostic(diag):
|
||||||
@ -71,6 +86,8 @@ def main():
|
|||||||
if monomorphize:
|
if monomorphize:
|
||||||
IntMonomorphizer(engine=engine).visit(typed)
|
IntMonomorphizer(engine=engine).visit(typed)
|
||||||
Inferencer(engine=engine).visit(typed)
|
Inferencer(engine=engine).visit(typed)
|
||||||
|
if iodelay:
|
||||||
|
IODelayEstimator(engine=engine, ref_period=1e6).visit_fixpoint(typed)
|
||||||
|
|
||||||
printer = Printer(buf)
|
printer = Printer(buf)
|
||||||
printer.visit(typed)
|
printer.visit(typed)
|
||||||
|
@ -472,7 +472,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||||||
else:
|
else:
|
||||||
assert False
|
assert False
|
||||||
|
|
||||||
def visit_For(self, node):
|
def visit_ForT(self, node):
|
||||||
try:
|
try:
|
||||||
iterable = self.visit(node.iter)
|
iterable = self.visit(node.iter)
|
||||||
length = self.iterable_len(iterable)
|
length = self.iterable_len(iterable)
|
||||||
@ -522,6 +522,11 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||||||
else:
|
else:
|
||||||
else_tail = tail
|
else_tail = tail
|
||||||
|
|
||||||
|
if node.trip_count is not None:
|
||||||
|
substs = {var_name: self.current_args[var_name]
|
||||||
|
for var_name in node.trip_count.free_vars()}
|
||||||
|
head.append(ir.Loop(node.trip_count, substs, cond, body, else_tail))
|
||||||
|
else:
|
||||||
head.append(ir.BranchIf(cond, body, else_tail))
|
head.append(ir.BranchIf(cond, body, else_tail))
|
||||||
if not post_body.is_terminated():
|
if not post_body.is_terminated():
|
||||||
post_body.append(ir.Branch(continue_block))
|
post_body.append(ir.Branch(continue_block))
|
||||||
|
@ -471,6 +471,15 @@ class ASTTypedRewriter(algorithm.Transformer):
|
|||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
def visit_For(self, node):
|
||||||
|
node = self.generic_visit(node)
|
||||||
|
node = asttyped.ForT(
|
||||||
|
target=node.target, iter=node.iter, body=node.body, orelse=node.orelse,
|
||||||
|
trip_count=None, trip_interval=None,
|
||||||
|
keyword_loc=node.keyword_loc, in_loc=node.in_loc, for_colon_loc=node.for_colon_loc,
|
||||||
|
else_loc=node.else_loc, else_colon_loc=node.else_colon_loc)
|
||||||
|
return node
|
||||||
|
|
||||||
# Unsupported visitors
|
# Unsupported visitors
|
||||||
#
|
#
|
||||||
def visit_unsupported(self, node):
|
def visit_unsupported(self, node):
|
||||||
|
@ -916,7 +916,7 @@ class Inferencer(algorithm.Visitor):
|
|||||||
|
|
||||||
node.value = self._coerce_one(value_type, node.value, other_node=node.target)
|
node.value = self._coerce_one(value_type, node.value, other_node=node.target)
|
||||||
|
|
||||||
def visit_For(self, node):
|
def visit_ForT(self, node):
|
||||||
old_in_loop, self.in_loop = self.in_loop, True
|
old_in_loop, self.in_loop = self.in_loop, True
|
||||||
self.generic_visit(node)
|
self.generic_visit(node)
|
||||||
self.in_loop = old_in_loop
|
self.in_loop = old_in_loop
|
||||||
|
@ -167,7 +167,7 @@ class IODelayEstimator(algorithm.Visitor):
|
|||||||
node.loc)
|
node.loc)
|
||||||
abort([note])
|
abort([note])
|
||||||
|
|
||||||
def visit_For(self, node):
|
def visit_ForT(self, node):
|
||||||
self.visit(node.iter)
|
self.visit(node.iter)
|
||||||
|
|
||||||
old_goto, self.current_goto = self.current_goto, None
|
old_goto, self.current_goto = self.current_goto, None
|
||||||
@ -180,8 +180,9 @@ class IODelayEstimator(algorithm.Visitor):
|
|||||||
self.abort("loop trip count is indeterminate because of control flow",
|
self.abort("loop trip count is indeterminate because of control flow",
|
||||||
self.current_goto.loc)
|
self.current_goto.loc)
|
||||||
|
|
||||||
trip_count = self.get_iterable_length(node.iter)
|
node.trip_count = self.get_iterable_length(node.iter).fold()
|
||||||
self.current_delay = old_delay + self.current_delay * trip_count
|
node.trip_interval = self.current_delay.fold()
|
||||||
|
self.current_delay = old_delay + node.trip_interval * node.trip_count
|
||||||
self.current_goto = old_goto
|
self.current_goto = old_goto
|
||||||
|
|
||||||
self.visit(node.orelse)
|
self.visit(node.orelse)
|
||||||
|
Loading…
Reference in New Issue
Block a user