forked from M-Labs/artiq
compiler: pass funcdef instead of statement list
This commit is contained in:
parent
86577ff64f
commit
a5e5b5c870
|
@ -57,4 +57,5 @@ class _ConstantFolder(ast.NodeTransformer):
|
|||
return node
|
||||
return ast.copy_location(result, node)
|
||||
|
||||
fold_constants = make_stmt_transformer(_ConstantFolder)
|
||||
def fold_constants(node):
|
||||
_ConstantFolder().visit(node)
|
||||
|
|
|
@ -151,7 +151,7 @@ class _ReferenceReplacer(ast.NodeTransformer):
|
|||
elif hasattr(func, "k_function_info") and getattr(func.__self__, func.k_function_info.core_name) is self.core:
|
||||
args = [func.__self__] + new_args
|
||||
inlined, _ = inline(self.core, func.k_function_info.k_function, args, dict(), self.rm)
|
||||
return inlined
|
||||
return inlined.body
|
||||
else:
|
||||
args = [ast.Str("rpc"), value_to_ast(self.rm.rpc_map[func])]
|
||||
args += new_args
|
||||
|
@ -173,6 +173,7 @@ class _ReferenceReplacer(ast.NodeTransformer):
|
|||
return node
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
node.args = ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[])
|
||||
node.decorator_list = []
|
||||
self.generic_visit(node)
|
||||
return node
|
||||
|
@ -230,4 +231,4 @@ def inline(core, k_function, k_args, k_kwargs, rm=None):
|
|||
funcdef.body[0:0] = rm.kernel_attr_init
|
||||
|
||||
r_rpc_map = dict((rpc_num, rpc_fun) for rpc_fun, rpc_num in rm.rpc_map.items())
|
||||
return funcdef.body, r_rpc_map
|
||||
return funcdef, r_rpc_map
|
||||
|
|
|
@ -84,21 +84,21 @@ def _interleave_timelines(timelines):
|
|||
|
||||
return r
|
||||
|
||||
def interleave(stmts):
|
||||
def _interleave_stmts(stmts):
|
||||
replacements = []
|
||||
for stmt_i, stmt in enumerate(stmts):
|
||||
if isinstance(stmt, (ast.For, ast.While, ast.If)):
|
||||
interleave(stmt.body)
|
||||
interleave(stmt.orelse)
|
||||
_interleave_stmts(stmt.body)
|
||||
_interleave_stmts(stmt.orelse)
|
||||
elif isinstance(stmt, ast.With):
|
||||
btype = stmt.items[0].context_expr.id
|
||||
if btype == "sequential":
|
||||
interleave(stmt.body)
|
||||
_interleave_stmts(stmt.body)
|
||||
replacements.append((stmt_i, stmt.body))
|
||||
elif btype == "parallel":
|
||||
timelines = [[s] for s in stmt.body]
|
||||
for timeline in timelines:
|
||||
interleave(timeline)
|
||||
_interleave_stmts(timeline)
|
||||
merged = _interleave_timelines(timelines)
|
||||
if merged is not None:
|
||||
replacements.append((stmt_i, merged))
|
||||
|
@ -108,3 +108,6 @@ def interleave(stmts):
|
|||
for location, new_stmts in replacements:
|
||||
stmts[offset+location:offset+location+1] = new_stmts
|
||||
offset += len(new_stmts) - 1
|
||||
|
||||
def interleave(funcdef):
|
||||
_interleave_stmts(funcdef.body)
|
||||
|
|
|
@ -34,10 +34,8 @@ class _TimeLowerer(ast.NodeTransformer):
|
|||
else:
|
||||
return node
|
||||
|
||||
def lower_time(stmts, ref_period, initial_time):
|
||||
transformer = _TimeLowerer(ref_period)
|
||||
new_stmts = [transformer.visit(stmt) for stmt in stmts]
|
||||
new_stmts.insert(0, ast.copy_location(
|
||||
def lower_time(funcdef, initial_time):
|
||||
_TimeLowerer().visit(funcdef)
|
||||
funcdef.body.insert(0, ast.copy_location(
|
||||
ast.Assign(targets=[ast.Name("now", ast.Store())], value=value_to_ast(initial_time)),
|
||||
stmts[0]))
|
||||
stmts[:] = new_stmts
|
||||
funcdef))
|
||||
|
|
|
@ -59,10 +59,3 @@ def eval_constant(node):
|
|||
return units.Quantity(amount, unit)
|
||||
else:
|
||||
raise NotConstant
|
||||
|
||||
def make_stmt_transformer(transformer_class):
|
||||
def stmt_transformer(stmts, *args, **kwargs):
|
||||
transformer = transformer_class(*args, **kwargs)
|
||||
new_stmts = [transformer.visit(stmt) for stmt in stmts]
|
||||
stmts[:] = new_stmts
|
||||
return stmt_transformer
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import ast
|
||||
|
||||
from artiq.compiler.tools import eval_ast, make_stmt_transformer, value_to_ast
|
||||
from artiq.compiler.tools import eval_ast, value_to_ast
|
||||
|
||||
def _count_stmts(node):
|
||||
if isinstance(node, (ast.For, ast.While, ast.If)):
|
||||
|
@ -43,4 +43,5 @@ class _LoopUnroller(ast.NodeTransformer):
|
|||
else:
|
||||
return node.orelse
|
||||
|
||||
unroll_loops = make_stmt_transformer(_LoopUnroller)
|
||||
def unroll_loops(node, limit):
|
||||
_LoopUnroller(limit).visit(node)
|
||||
|
|
|
@ -13,14 +13,13 @@ class Core:
|
|||
self.core_com = core_com
|
||||
|
||||
def run(self, k_function, k_args, k_kwargs):
|
||||
stmts, rpc_map = inline(self, k_function, k_args, k_kwargs)
|
||||
fold_constants(stmts)
|
||||
unroll_loops(stmts, 50)
|
||||
interleave(stmts)
|
||||
lower_time(stmts, self.runtime_env.ref_period,
|
||||
getattr(self.runtime_env, "initial_time", 0))
|
||||
fold_constants(stmts)
|
||||
funcdef, rpc_map = inline(self, k_function, k_args, k_kwargs)
|
||||
fold_constants(funcdef)
|
||||
unroll_loops(funcdef, 50)
|
||||
interleave(funcdef)
|
||||
lower_time(funcdef, getattr(self.runtime_env, "initial_time", 0))
|
||||
fold_constants(funcdef)
|
||||
|
||||
binary = get_runtime_binary(self.runtime_env, stmts)
|
||||
binary = get_runtime_binary(self.runtime_env, funcdef)
|
||||
self.core_com.run(binary)
|
||||
self.core_com.serve(rpc_map)
|
||||
|
|
Loading…
Reference in New Issue