forked from M-Labs/artiq
transforms/unroll_loop: do not unroll breakable loops
This commit is contained in:
parent
d52d641dcd
commit
5fe3cffc84
|
@ -14,6 +14,24 @@ def _count_stmts(node):
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
def _loop_breakable(node):
|
||||||
|
if isinstance(node, ast.Break):
|
||||||
|
return 1
|
||||||
|
elif isinstance(node, ast.Return):
|
||||||
|
return 2
|
||||||
|
elif isinstance(node, list):
|
||||||
|
return max(map(_loop_breakable, node), default=0)
|
||||||
|
elif isinstance(node, ast.If):
|
||||||
|
return max(_loop_breakable(node.body), _loop_breakable(node.orelse))
|
||||||
|
elif isinstance(node, (ast.For, ast.While)):
|
||||||
|
bb = _loop_breakable(node.body)
|
||||||
|
if bb == 1:
|
||||||
|
bb = 0
|
||||||
|
return max(bb, _loop_breakable(node.orelse))
|
||||||
|
else:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
class _LoopUnroller(ast.NodeTransformer):
|
class _LoopUnroller(ast.NodeTransformer):
|
||||||
def __init__(self, limit):
|
def __init__(self, limit):
|
||||||
self.limit = limit
|
self.limit = limit
|
||||||
|
@ -26,8 +44,8 @@ class _LoopUnroller(ast.NodeTransformer):
|
||||||
return node
|
return node
|
||||||
l_it = len(it)
|
l_it = len(it)
|
||||||
if l_it:
|
if l_it:
|
||||||
n = l_it*_count_stmts(node.body)
|
if (not _loop_breakable(node.body)
|
||||||
if n < self.limit:
|
and l_it*_count_stmts(node.body) < self.limit):
|
||||||
replacement = []
|
replacement = []
|
||||||
for i in it:
|
for i in it:
|
||||||
if not isinstance(i, int):
|
if not isinstance(i, int):
|
||||||
|
|
Loading…
Reference in New Issue