compiler: pass funcdef instead of statement list

This commit is contained in:
Sebastien Bourdeauducq 2014-08-18 21:37:30 +08:00
parent 86577ff64f
commit a5e5b5c870
7 changed files with 27 additions and 31 deletions

View File

@ -57,4 +57,5 @@ class _ConstantFolder(ast.NodeTransformer):
return node
return ast.copy_location(result, node)
fold_constants = make_stmt_transformer(_ConstantFolder)
def fold_constants(node):
_ConstantFolder().visit(node)

View File

@ -151,7 +151,7 @@ class _ReferenceReplacer(ast.NodeTransformer):
elif hasattr(func, "k_function_info") and getattr(func.__self__, func.k_function_info.core_name) is self.core:
args = [func.__self__] + new_args
inlined, _ = inline(self.core, func.k_function_info.k_function, args, dict(), self.rm)
return inlined
return inlined.body
else:
args = [ast.Str("rpc"), value_to_ast(self.rm.rpc_map[func])]
args += new_args
@ -173,6 +173,7 @@ class _ReferenceReplacer(ast.NodeTransformer):
return node
def visit_FunctionDef(self, node):
node.args = ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[])
node.decorator_list = []
self.generic_visit(node)
return node
@ -230,4 +231,4 @@ def inline(core, k_function, k_args, k_kwargs, rm=None):
funcdef.body[0:0] = rm.kernel_attr_init
r_rpc_map = dict((rpc_num, rpc_fun) for rpc_fun, rpc_num in rm.rpc_map.items())
return funcdef.body, r_rpc_map
return funcdef, r_rpc_map

View File

@ -84,21 +84,21 @@ def _interleave_timelines(timelines):
return r
def interleave(stmts):
def _interleave_stmts(stmts):
replacements = []
for stmt_i, stmt in enumerate(stmts):
if isinstance(stmt, (ast.For, ast.While, ast.If)):
interleave(stmt.body)
interleave(stmt.orelse)
_interleave_stmts(stmt.body)
_interleave_stmts(stmt.orelse)
elif isinstance(stmt, ast.With):
btype = stmt.items[0].context_expr.id
if btype == "sequential":
interleave(stmt.body)
_interleave_stmts(stmt.body)
replacements.append((stmt_i, stmt.body))
elif btype == "parallel":
timelines = [[s] for s in stmt.body]
for timeline in timelines:
interleave(timeline)
_interleave_stmts(timeline)
merged = _interleave_timelines(timelines)
if merged is not None:
replacements.append((stmt_i, merged))
@ -108,3 +108,6 @@ def interleave(stmts):
for location, new_stmts in replacements:
stmts[offset+location:offset+location+1] = new_stmts
offset += len(new_stmts) - 1
def interleave(funcdef):
_interleave_stmts(funcdef.body)

View File

@ -34,10 +34,8 @@ class _TimeLowerer(ast.NodeTransformer):
else:
return node
def lower_time(stmts, ref_period, initial_time):
transformer = _TimeLowerer(ref_period)
new_stmts = [transformer.visit(stmt) for stmt in stmts]
new_stmts.insert(0, ast.copy_location(
def lower_time(funcdef, initial_time):
_TimeLowerer().visit(funcdef)
funcdef.body.insert(0, ast.copy_location(
ast.Assign(targets=[ast.Name("now", ast.Store())], value=value_to_ast(initial_time)),
stmts[0]))
stmts[:] = new_stmts
funcdef))

View File

@ -59,10 +59,3 @@ def eval_constant(node):
return units.Quantity(amount, unit)
else:
raise NotConstant
def make_stmt_transformer(transformer_class):
def stmt_transformer(stmts, *args, **kwargs):
transformer = transformer_class(*args, **kwargs)
new_stmts = [transformer.visit(stmt) for stmt in stmts]
stmts[:] = new_stmts
return stmt_transformer

View File

@ -1,6 +1,6 @@
import ast
from artiq.compiler.tools import eval_ast, make_stmt_transformer, value_to_ast
from artiq.compiler.tools import eval_ast, value_to_ast
def _count_stmts(node):
if isinstance(node, (ast.For, ast.While, ast.If)):
@ -43,4 +43,5 @@ class _LoopUnroller(ast.NodeTransformer):
else:
return node.orelse
unroll_loops = make_stmt_transformer(_LoopUnroller)
def unroll_loops(node, limit):
_LoopUnroller(limit).visit(node)

View File

@ -13,14 +13,13 @@ class Core:
self.core_com = core_com
def run(self, k_function, k_args, k_kwargs):
stmts, rpc_map = inline(self, k_function, k_args, k_kwargs)
fold_constants(stmts)
unroll_loops(stmts, 50)
interleave(stmts)
lower_time(stmts, self.runtime_env.ref_period,
getattr(self.runtime_env, "initial_time", 0))
fold_constants(stmts)
funcdef, rpc_map = inline(self, k_function, k_args, k_kwargs)
fold_constants(funcdef)
unroll_loops(funcdef, 50)
interleave(funcdef)
lower_time(funcdef, getattr(self.runtime_env, "initial_time", 0))
fold_constants(funcdef)
binary = get_runtime_binary(self.runtime_env, stmts)
binary = get_runtime_binary(self.runtime_env, funcdef)
self.core_com.run(binary)
self.core_com.serve(rpc_map)