From f8eaeaa43f9d87c8e573fca73ecc0c95ca889f73 Mon Sep 17 00:00:00 2001 From: whitequark Date: Wed, 16 Dec 2015 15:33:15 +0800 Subject: [PATCH] compiler: explicitly represent loops in IR. --- artiq/compiler/analyses/devirtualization.py | 2 +- artiq/compiler/asttyped.py | 6 ++ artiq/compiler/ir.py | 64 ++++++++++++++++++- artiq/compiler/testbench/inferencer.py | 19 +++++- .../compiler/transforms/artiq_ir_generator.py | 9 ++- .../compiler/transforms/asttyped_rewriter.py | 9 +++ artiq/compiler/transforms/inferencer.py | 2 +- .../compiler/transforms/iodelay_estimator.py | 7 +- 8 files changed, 109 insertions(+), 9 deletions(-) diff --git a/artiq/compiler/analyses/devirtualization.py b/artiq/compiler/analyses/devirtualization.py index 14f93c882..0fef44ddf 100644 --- a/artiq/compiler/analyses/devirtualization.py +++ b/artiq/compiler/analyses/devirtualization.py @@ -46,7 +46,7 @@ class FunctionResolver(algorithm.Visitor): self.visit(node.value) self.visit_in_assign(node.targets) - def visit_For(self, node): + def visit_ForT(self, node): self.visit(node.iter) self.visit_in_assign(node.target) self.visit(node.body) diff --git a/artiq/compiler/asttyped.py b/artiq/compiler/asttyped.py index 6d7908d8d..26df00ad4 100644 --- a/artiq/compiler/asttyped.py +++ b/artiq/compiler/asttyped.py @@ -36,6 +36,12 @@ class ExceptHandlerT(ast.ExceptHandler): _fields = ("filter", "name", "body") # rename ast.ExceptHandler.type to filter _types = ("name_type",) +class ForT(ast.For): + """ + :ivar trip_count: (:class:`iodelay.Expr`) + :ivar trip_interval: (:class:`iodelay.Expr`) + """ + class SliceT(ast.Slice, commontyped): pass diff --git a/artiq/compiler/ir.py b/artiq/compiler/ir.py index 5f20164d3..6ef6dbf35 100644 --- a/artiq/compiler/ir.py +++ b/artiq/compiler/ir.py @@ -1296,7 +1296,7 @@ class Delay(Terminator): :ivar interval: (:class:`iodelay.Expr`) expression :ivar var_names: (list of string) - iodelay variable names corresponding to operands + iodelay variable names corresponding to SSA values """ """ @@ -1354,6 +1354,68 @@ class Delay(Terminator): def opcode(self): return "delay({})".format(self.interval) +class Loop(Terminator): + """ + A terminator for loop headers that carries metadata useful + for unrolling. It includes an :class:`iodelay.Expr` specifying + the trip count, tied to SSA values so that inlining could lead + to the expression folding to a constant. + + :ivar trip_count: (:class:`iodelay.Expr`) + expression for trip count + :ivar var_names: (list of string) + iodelay variable names corresponding to ``trip_count`` operands + """ + + """ + :param trip_count: (:class:`iodelay.Expr`) expression + :param substs: (dict of str to :class:`Value`) + SSA values corresponding to iodelay variable names + :param cond: (:class:`Value`) branch condition + :param if_true: (:class:`BasicBlock`) branch target if condition is truthful + :param if_false: (:class:`BasicBlock`) branch target if condition is falseful + """ + def __init__(self, trip_count, substs, cond, if_true, if_false, name=""): + for var_name in substs: assert isinstance(var_name, str) + assert isinstance(cond, Value) + assert builtins.is_bool(cond.type) + assert isinstance(if_true, BasicBlock) + assert isinstance(if_false, BasicBlock) + super().__init__([cond, if_true, if_false, *substs.values()], builtins.TNone(), name) + self.trip_count = trip_count + self.var_names = list(substs.keys()) + + def copy(self, mapper): + self_copy = super().copy(mapper) + self_copy.trip_count = self.trip_count + self_copy.var_names = list(self.var_names) + return self_copy + + def condition(self): + return self.operands[0] + + def if_true(self): + return self.operands[1] + + def if_false(self): + return self.operands[2] + + def substs(self): + return {key: value for key, value in zip(self.var_names, self.operands[3:])} + + def _operands_as_string(self, type_printer): + substs = self.substs() + substs_as_strings = [] + for var_name in substs: + substs_as_strings.append("{} = {}".format(var_name, substs[var_name])) + result = "[{}]".format(", ".join(substs_as_strings)) + result += ", {}, {}, {}".format(*list(map(lambda value: value.as_operand(type_printer), + self.operands[0:3]))) + return result + + def opcode(self): + return "loop({} times)".format(self.trip_count) + class Parallel(Terminator): """ An instruction that schedules several threads of execution diff --git a/artiq/compiler/testbench/inferencer.py b/artiq/compiler/testbench/inferencer.py index 174baf8f4..4179cc777 100644 --- a/artiq/compiler/testbench/inferencer.py +++ b/artiq/compiler/testbench/inferencer.py @@ -2,6 +2,7 @@ import sys, fileinput, os from pythonparser import source, diagnostic, algorithm, parse_buffer from .. import prelude, types from ..transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer +from ..transforms import IODelayEstimator class Printer(algorithm.Visitor): """ @@ -32,7 +33,15 @@ class Printer(algorithm.Visitor): if node.name_loc: self.rewriter.insert_after(node.name_loc, - ":{}".format(self.type_printer.name(node.name_type))) + ":{}".format(self.type_printer.name(node.name_type))) + + def visit_ForT(self, node): + super().generic_visit(node) + + if node.trip_count is not None and node.trip_interval is not None: + self.rewriter.insert_after(node.keyword_loc, + "[{} x {} mu]".format(node.trip_count.fold(), + node.trip_interval.fold())) def generic_visit(self, node): super().generic_visit(node) @@ -48,6 +57,12 @@ def main(): else: monomorphize = False + if len(sys.argv) > 1 and sys.argv[1] == "+iodelay": + del sys.argv[1] + iodelay = True + else: + iodelay = False + if len(sys.argv) > 1 and sys.argv[1] == "+diag": del sys.argv[1] def process_diagnostic(diag): @@ -71,6 +86,8 @@ def main(): if monomorphize: IntMonomorphizer(engine=engine).visit(typed) Inferencer(engine=engine).visit(typed) + if iodelay: + IODelayEstimator(engine=engine, ref_period=1e6).visit_fixpoint(typed) printer = Printer(buf) printer.visit(typed) diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index 55f6bfadb..de5ad63cb 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -472,7 +472,7 @@ class ARTIQIRGenerator(algorithm.Visitor): else: assert False - def visit_For(self, node): + def visit_ForT(self, node): try: iterable = self.visit(node.iter) length = self.iterable_len(iterable) @@ -522,7 +522,12 @@ class ARTIQIRGenerator(algorithm.Visitor): else: else_tail = tail - head.append(ir.BranchIf(cond, body, else_tail)) + if node.trip_count is not None: + substs = {var_name: self.current_args[var_name] + for var_name in node.trip_count.free_vars()} + head.append(ir.Loop(node.trip_count, substs, cond, body, else_tail)) + else: + head.append(ir.BranchIf(cond, body, else_tail)) if not post_body.is_terminated(): post_body.append(ir.Branch(continue_block)) break_block.append(ir.Branch(tail)) diff --git a/artiq/compiler/transforms/asttyped_rewriter.py b/artiq/compiler/transforms/asttyped_rewriter.py index e8f811423..b7fa6f7a1 100644 --- a/artiq/compiler/transforms/asttyped_rewriter.py +++ b/artiq/compiler/transforms/asttyped_rewriter.py @@ -471,6 +471,15 @@ class ASTTypedRewriter(algorithm.Transformer): self.engine.process(diag) return node + def visit_For(self, node): + node = self.generic_visit(node) + node = asttyped.ForT( + target=node.target, iter=node.iter, body=node.body, orelse=node.orelse, + trip_count=None, trip_interval=None, + keyword_loc=node.keyword_loc, in_loc=node.in_loc, for_colon_loc=node.for_colon_loc, + else_loc=node.else_loc, else_colon_loc=node.else_colon_loc) + return node + # Unsupported visitors # def visit_unsupported(self, node): diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py index b9bb7554e..21bc7c62a 100644 --- a/artiq/compiler/transforms/inferencer.py +++ b/artiq/compiler/transforms/inferencer.py @@ -916,7 +916,7 @@ class Inferencer(algorithm.Visitor): node.value = self._coerce_one(value_type, node.value, other_node=node.target) - def visit_For(self, node): + def visit_ForT(self, node): old_in_loop, self.in_loop = self.in_loop, True self.generic_visit(node) self.in_loop = old_in_loop diff --git a/artiq/compiler/transforms/iodelay_estimator.py b/artiq/compiler/transforms/iodelay_estimator.py index 4170094c0..378933029 100644 --- a/artiq/compiler/transforms/iodelay_estimator.py +++ b/artiq/compiler/transforms/iodelay_estimator.py @@ -167,7 +167,7 @@ class IODelayEstimator(algorithm.Visitor): node.loc) abort([note]) - def visit_For(self, node): + def visit_ForT(self, node): self.visit(node.iter) old_goto, self.current_goto = self.current_goto, None @@ -180,8 +180,9 @@ class IODelayEstimator(algorithm.Visitor): self.abort("loop trip count is indeterminate because of control flow", self.current_goto.loc) - trip_count = self.get_iterable_length(node.iter) - self.current_delay = old_delay + self.current_delay * trip_count + node.trip_count = self.get_iterable_length(node.iter).fold() + node.trip_interval = self.current_delay.fold() + self.current_delay = old_delay + node.trip_interval * node.trip_count self.current_goto = old_goto self.visit(node.orelse)