mirror of https://github.com/m-labs/artiq.git
47 lines
1.1 KiB
Python
47 lines
1.1 KiB
Python
import ast
|
|
|
|
from artiq.compiler.tools import eval_ast, make_stmt_transformer
|
|
|
|
def _count_stmts(node):
|
|
if isinstance(node, (ast.For, ast.While, ast.If)):
|
|
return 1 + _count_stmts(node.body) + _count_stmts(node.orelse)
|
|
elif isinstance(node, ast.With):
|
|
return 1 + _count_stmts(node.body)
|
|
elif isinstance(node, list):
|
|
return sum(map(_count_stmts, node))
|
|
else:
|
|
return 1
|
|
|
|
class _LoopUnroller(ast.NodeTransformer):
|
|
def __init__(self, limit):
|
|
self.limit = limit
|
|
|
|
def visit_For(self, node):
|
|
self.generic_visit(node)
|
|
try:
|
|
it = eval_ast(node.iter)
|
|
except:
|
|
return node
|
|
l_it = len(it)
|
|
if l_it:
|
|
n = l_it*_count_stmts(node.body)
|
|
if n < self.limit:
|
|
replacement = []
|
|
for i in it:
|
|
if not isinstance(i, int):
|
|
replacement = None
|
|
break
|
|
replacement.append(ast.copy_location(
|
|
ast.Assign(targets=[node.target], value=ast.Num(n=i)), node))
|
|
replacement += node.body
|
|
if replacement is not None:
|
|
return replacement
|
|
else:
|
|
return node
|
|
else:
|
|
return node
|
|
else:
|
|
return node.orelse
|
|
|
|
unroll_loops = make_stmt_transformer(_LoopUnroller)
|