forked from M-Labs/artiq
1
0
Fork 0

compiler: add unroll_loops transform

This commit is contained in:
Sebastien Bourdeauducq 2014-06-21 15:06:15 +02:00
parent 5a8074a12f
commit b28fdf5fb0
2 changed files with 48 additions and 1 deletions

View File

@ -0,0 +1,46 @@
import ast
from artiq.compiler.tools import eval_ast, make_stmt_transformer
def _count_stmts(node):
if isinstance(node, (ast.For, ast.While, ast.If)):
return 1 + _count_stmts(node.body) + _count_stmts(node.orelse)
elif isinstance(node, ast.With):
return 1 + _count_stmts(node.body)
elif isinstance(node, list):
return sum(map(_count_stmts, node))
else:
return 1
class _LoopUnroller(ast.NodeTransformer):
def __init__(self, limit):
self.limit = limit
def visit_For(self, node):
self.generic_visit(node)
try:
it = eval_ast(node.iter)
except:
return node
l_it = len(it)
if l_it:
n = l_it*_count_stmts(node.body)
if n < self.limit:
replacement = []
for i in it:
if not isinstance(i, int):
replacement = None
break
replacement.append(ast.copy_location(
ast.Assign(targets=[node.target], value=ast.Num(n=i)), node))
replacement += node.body
if replacement is not None:
return replacement
else:
return node
else:
return node
else:
return node.orelse
unroll_loops = make_stmt_transformer(_LoopUnroller)

View File

@ -2,13 +2,14 @@ from operator import itemgetter
from artiq.compiler.inline import inline from artiq.compiler.inline import inline
from artiq.compiler.fold_constants import fold_constants from artiq.compiler.fold_constants import fold_constants
from artiq.compiler.unroll_loops import unroll_loops
from artiq.compiler.unparse import Unparser from artiq.compiler.unparse import Unparser
class Core: class Core:
def run(self, k_function, k_args, k_kwargs): def run(self, k_function, k_args, k_kwargs):
stmts, rpc_map = inline(self, k_function, k_args, k_kwargs) stmts, rpc_map = inline(self, k_function, k_args, k_kwargs)
fold_constants(stmts) fold_constants(stmts)
unroll_loops(stmts, 50)
print("=========================") print("=========================")
print(" Inlined") print(" Inlined")