From 5fe3cffc848651256018965da2eae2490981f956 Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Thu, 18 Sep 2014 09:53:08 +0800 Subject: [PATCH] transforms/unroll_loop: do not unroll breakable loops --- artiq/transforms/unroll_loops.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/artiq/transforms/unroll_loops.py b/artiq/transforms/unroll_loops.py index d000429ee..f289663b0 100644 --- a/artiq/transforms/unroll_loops.py +++ b/artiq/transforms/unroll_loops.py @@ -14,6 +14,24 @@ def _count_stmts(node): return 1 +def _loop_breakable(node): + if isinstance(node, ast.Break): + return 1 + elif isinstance(node, ast.Return): + return 2 + elif isinstance(node, list): + return max(map(_loop_breakable, node), default=0) + elif isinstance(node, ast.If): + return max(_loop_breakable(node.body), _loop_breakable(node.orelse)) + elif isinstance(node, (ast.For, ast.While)): + bb = _loop_breakable(node.body) + if bb == 1: + bb = 0 + return max(bb, _loop_breakable(node.orelse)) + else: + return 0 + + class _LoopUnroller(ast.NodeTransformer): def __init__(self, limit): self.limit = limit @@ -26,8 +44,8 @@ class _LoopUnroller(ast.NodeTransformer): return node l_it = len(it) if l_it: - n = l_it*_count_stmts(node.body) - if n < self.limit: + if (not _loop_breakable(node.body) + and l_it*_count_stmts(node.body) < self.limit): replacement = [] for i in it: if not isinstance(i, int):