diff --git a/artiq/transforms/unroll_loops.py b/artiq/transforms/unroll_loops.py index f289663b0..b362bd907 100644 --- a/artiq/transforms/unroll_loops.py +++ b/artiq/transforms/unroll_loops.py @@ -15,21 +15,25 @@ def _count_stmts(node): def _loop_breakable(node): - if isinstance(node, ast.Break): - return 1 - elif isinstance(node, ast.Return): - return 2 - elif isinstance(node, list): - return max(map(_loop_breakable, node), default=0) + if isinstance(node, list): + return any(map(_loop_breakable, node)) + elif isinstance(node, (ast.Break, ast.Continue)): + return True + elif isinstance(node, ast.With): + return _loop_breakable(node.body) elif isinstance(node, ast.If): - return max(_loop_breakable(node.body), _loop_breakable(node.orelse)) - elif isinstance(node, (ast.For, ast.While)): - bb = _loop_breakable(node.body) - if bb == 1: - bb = 0 - return max(bb, _loop_breakable(node.orelse)) + return _loop_breakable(node.body) or _loop_breakable(node.orelse) + elif isinstance(node, ast.Try): + if (_loop_breakable(node.body) + or _loop_breakable(node.orelse) + or _loop_breakable(node.finalbody)): + return True + for handler in node.handlers: + if _loop_breakable(handler.body): + return True + return False else: - return 0 + return False class _LoopUnroller(ast.NodeTransformer):