unroll_loops: handle for/else

This commit is contained in:
Sebastien Bourdeauducq 2014-05-25 13:37:12 +02:00
parent 08ff5154a8
commit 956f75d168
1 changed files with 18 additions and 13 deletions

View File

@ -75,18 +75,23 @@ def unroll_loops(stmts, limit):
pass pass
else: else:
unroll_loops(stmt.body, limit) unroll_loops(stmt.body, limit)
n = len(it)*_count_stmts(stmt.body) unroll_loops(stmt.orelse, limit)
if n < limit: l_it = len(it)
replacement = [] if l_it:
for i in it: n = l_it*_count_stmts(stmt.body)
if not isinstance(i, int): if n < limit:
replacement = None replacement = []
break for i in it:
replacement.append(ast.copy_location( if not isinstance(i, int):
ast.Assign(targets=[stmt.target], value=ast.Num(n=i)), stmt)) replacement = None
replacement += stmt.body break
if replacement is not None: replacement.append(ast.copy_location(
replacements.append((stmt_i, replacement)) ast.Assign(targets=[stmt.target], value=ast.Num(n=i)), stmt))
replacement += stmt.body
if replacement is not None:
replacements.append((stmt_i, replacement))
else:
replacements.append((stmt_i, stmt.orelse))
if isinstance(stmt, (ast.While, ast.If)): if isinstance(stmt, (ast.While, ast.If)):
unroll_loops(stmt.body, limit) unroll_loops(stmt.body, limit)
unroll_loops(stmt.orelse, limit) unroll_loops(stmt.orelse, limit)