mirror of
https://github.com/m-labs/artiq.git
synced 2025-01-27 02:48:12 +08:00
compiler: add unroll_loops transform
This commit is contained in:
parent
5a8074a12f
commit
b28fdf5fb0
46
artiq/compiler/unroll_loops.py
Normal file
46
artiq/compiler/unroll_loops.py
Normal file
@ -0,0 +1,46 @@
|
||||
import ast
|
||||
|
||||
from artiq.compiler.tools import eval_ast, make_stmt_transformer
|
||||
|
||||
def _count_stmts(node):
|
||||
if isinstance(node, (ast.For, ast.While, ast.If)):
|
||||
return 1 + _count_stmts(node.body) + _count_stmts(node.orelse)
|
||||
elif isinstance(node, ast.With):
|
||||
return 1 + _count_stmts(node.body)
|
||||
elif isinstance(node, list):
|
||||
return sum(map(_count_stmts, node))
|
||||
else:
|
||||
return 1
|
||||
|
||||
class _LoopUnroller(ast.NodeTransformer):
|
||||
def __init__(self, limit):
|
||||
self.limit = limit
|
||||
|
||||
def visit_For(self, node):
|
||||
self.generic_visit(node)
|
||||
try:
|
||||
it = eval_ast(node.iter)
|
||||
except:
|
||||
return node
|
||||
l_it = len(it)
|
||||
if l_it:
|
||||
n = l_it*_count_stmts(node.body)
|
||||
if n < self.limit:
|
||||
replacement = []
|
||||
for i in it:
|
||||
if not isinstance(i, int):
|
||||
replacement = None
|
||||
break
|
||||
replacement.append(ast.copy_location(
|
||||
ast.Assign(targets=[node.target], value=ast.Num(n=i)), node))
|
||||
replacement += node.body
|
||||
if replacement is not None:
|
||||
return replacement
|
||||
else:
|
||||
return node
|
||||
else:
|
||||
return node
|
||||
else:
|
||||
return node.orelse
|
||||
|
||||
unroll_loops = make_stmt_transformer(_LoopUnroller)
|
@ -2,13 +2,14 @@ from operator import itemgetter
|
||||
|
||||
from artiq.compiler.inline import inline
|
||||
from artiq.compiler.fold_constants import fold_constants
|
||||
from artiq.compiler.unroll_loops import unroll_loops
|
||||
from artiq.compiler.unparse import Unparser
|
||||
|
||||
class Core:
|
||||
def run(self, k_function, k_args, k_kwargs):
|
||||
stmts, rpc_map = inline(self, k_function, k_args, k_kwargs)
|
||||
fold_constants(stmts)
|
||||
|
||||
unroll_loops(stmts, 50)
|
||||
|
||||
print("=========================")
|
||||
print(" Inlined")
|
||||
|
Loading…
Reference in New Issue
Block a user