transforms/unroll_loop: do not unroll breakable loops

This commit is contained in:
Sebastien Bourdeauducq 2014-09-18 09:53:08 +08:00
parent d52d641dcd
commit 5fe3cffc84

View File

@ -14,6 +14,24 @@ def _count_stmts(node):
return 1 return 1
def _loop_breakable(node):
if isinstance(node, ast.Break):
return 1
elif isinstance(node, ast.Return):
return 2
elif isinstance(node, list):
return max(map(_loop_breakable, node), default=0)
elif isinstance(node, ast.If):
return max(_loop_breakable(node.body), _loop_breakable(node.orelse))
elif isinstance(node, (ast.For, ast.While)):
bb = _loop_breakable(node.body)
if bb == 1:
bb = 0
return max(bb, _loop_breakable(node.orelse))
else:
return 0
class _LoopUnroller(ast.NodeTransformer): class _LoopUnroller(ast.NodeTransformer):
def __init__(self, limit): def __init__(self, limit):
self.limit = limit self.limit = limit
@ -26,8 +44,8 @@ class _LoopUnroller(ast.NodeTransformer):
return node return node
l_it = len(it) l_it = len(it)
if l_it: if l_it:
n = l_it*_count_stmts(node.body) if (not _loop_breakable(node.body)
if n < self.limit: and l_it*_count_stmts(node.body) < self.limit):
replacement = [] replacement = []
for i in it: for i in it:
if not isinstance(i, int): if not isinstance(i, int):