forked from M-Labs/artiq
transforms/unroll_loop: do not unroll breakable loops
This commit is contained in:
parent
d52d641dcd
commit
5fe3cffc84
|
@ -14,6 +14,24 @@ def _count_stmts(node):
|
|||
return 1
|
||||
|
||||
|
||||
def _loop_breakable(node):
|
||||
if isinstance(node, ast.Break):
|
||||
return 1
|
||||
elif isinstance(node, ast.Return):
|
||||
return 2
|
||||
elif isinstance(node, list):
|
||||
return max(map(_loop_breakable, node), default=0)
|
||||
elif isinstance(node, ast.If):
|
||||
return max(_loop_breakable(node.body), _loop_breakable(node.orelse))
|
||||
elif isinstance(node, (ast.For, ast.While)):
|
||||
bb = _loop_breakable(node.body)
|
||||
if bb == 1:
|
||||
bb = 0
|
||||
return max(bb, _loop_breakable(node.orelse))
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
class _LoopUnroller(ast.NodeTransformer):
|
||||
def __init__(self, limit):
|
||||
self.limit = limit
|
||||
|
@ -26,8 +44,8 @@ class _LoopUnroller(ast.NodeTransformer):
|
|||
return node
|
||||
l_it = len(it)
|
||||
if l_it:
|
||||
n = l_it*_count_stmts(node.body)
|
||||
if n < self.limit:
|
||||
if (not _loop_breakable(node.body)
|
||||
and l_it*_count_stmts(node.body) < self.limit):
|
||||
replacement = []
|
||||
for i in it:
|
||||
if not isinstance(i, int):
|
||||
|
|
Loading…
Reference in New Issue