forked from M-Labs/artiq
compiler: pass funcdef instead of statement list
This commit is contained in:
parent
86577ff64f
commit
a5e5b5c870
|
@ -57,4 +57,5 @@ class _ConstantFolder(ast.NodeTransformer):
|
||||||
return node
|
return node
|
||||||
return ast.copy_location(result, node)
|
return ast.copy_location(result, node)
|
||||||
|
|
||||||
fold_constants = make_stmt_transformer(_ConstantFolder)
|
def fold_constants(node):
|
||||||
|
_ConstantFolder().visit(node)
|
||||||
|
|
|
@ -151,7 +151,7 @@ class _ReferenceReplacer(ast.NodeTransformer):
|
||||||
elif hasattr(func, "k_function_info") and getattr(func.__self__, func.k_function_info.core_name) is self.core:
|
elif hasattr(func, "k_function_info") and getattr(func.__self__, func.k_function_info.core_name) is self.core:
|
||||||
args = [func.__self__] + new_args
|
args = [func.__self__] + new_args
|
||||||
inlined, _ = inline(self.core, func.k_function_info.k_function, args, dict(), self.rm)
|
inlined, _ = inline(self.core, func.k_function_info.k_function, args, dict(), self.rm)
|
||||||
return inlined
|
return inlined.body
|
||||||
else:
|
else:
|
||||||
args = [ast.Str("rpc"), value_to_ast(self.rm.rpc_map[func])]
|
args = [ast.Str("rpc"), value_to_ast(self.rm.rpc_map[func])]
|
||||||
args += new_args
|
args += new_args
|
||||||
|
@ -173,6 +173,7 @@ class _ReferenceReplacer(ast.NodeTransformer):
|
||||||
return node
|
return node
|
||||||
|
|
||||||
def visit_FunctionDef(self, node):
|
def visit_FunctionDef(self, node):
|
||||||
|
node.args = ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[])
|
||||||
node.decorator_list = []
|
node.decorator_list = []
|
||||||
self.generic_visit(node)
|
self.generic_visit(node)
|
||||||
return node
|
return node
|
||||||
|
@ -230,4 +231,4 @@ def inline(core, k_function, k_args, k_kwargs, rm=None):
|
||||||
funcdef.body[0:0] = rm.kernel_attr_init
|
funcdef.body[0:0] = rm.kernel_attr_init
|
||||||
|
|
||||||
r_rpc_map = dict((rpc_num, rpc_fun) for rpc_fun, rpc_num in rm.rpc_map.items())
|
r_rpc_map = dict((rpc_num, rpc_fun) for rpc_fun, rpc_num in rm.rpc_map.items())
|
||||||
return funcdef.body, r_rpc_map
|
return funcdef, r_rpc_map
|
||||||
|
|
|
@ -84,21 +84,21 @@ def _interleave_timelines(timelines):
|
||||||
|
|
||||||
return r
|
return r
|
||||||
|
|
||||||
def interleave(stmts):
|
def _interleave_stmts(stmts):
|
||||||
replacements = []
|
replacements = []
|
||||||
for stmt_i, stmt in enumerate(stmts):
|
for stmt_i, stmt in enumerate(stmts):
|
||||||
if isinstance(stmt, (ast.For, ast.While, ast.If)):
|
if isinstance(stmt, (ast.For, ast.While, ast.If)):
|
||||||
interleave(stmt.body)
|
_interleave_stmts(stmt.body)
|
||||||
interleave(stmt.orelse)
|
_interleave_stmts(stmt.orelse)
|
||||||
elif isinstance(stmt, ast.With):
|
elif isinstance(stmt, ast.With):
|
||||||
btype = stmt.items[0].context_expr.id
|
btype = stmt.items[0].context_expr.id
|
||||||
if btype == "sequential":
|
if btype == "sequential":
|
||||||
interleave(stmt.body)
|
_interleave_stmts(stmt.body)
|
||||||
replacements.append((stmt_i, stmt.body))
|
replacements.append((stmt_i, stmt.body))
|
||||||
elif btype == "parallel":
|
elif btype == "parallel":
|
||||||
timelines = [[s] for s in stmt.body]
|
timelines = [[s] for s in stmt.body]
|
||||||
for timeline in timelines:
|
for timeline in timelines:
|
||||||
interleave(timeline)
|
_interleave_stmts(timeline)
|
||||||
merged = _interleave_timelines(timelines)
|
merged = _interleave_timelines(timelines)
|
||||||
if merged is not None:
|
if merged is not None:
|
||||||
replacements.append((stmt_i, merged))
|
replacements.append((stmt_i, merged))
|
||||||
|
@ -108,3 +108,6 @@ def interleave(stmts):
|
||||||
for location, new_stmts in replacements:
|
for location, new_stmts in replacements:
|
||||||
stmts[offset+location:offset+location+1] = new_stmts
|
stmts[offset+location:offset+location+1] = new_stmts
|
||||||
offset += len(new_stmts) - 1
|
offset += len(new_stmts) - 1
|
||||||
|
|
||||||
|
def interleave(funcdef):
|
||||||
|
_interleave_stmts(funcdef.body)
|
||||||
|
|
|
@ -34,10 +34,8 @@ class _TimeLowerer(ast.NodeTransformer):
|
||||||
else:
|
else:
|
||||||
return node
|
return node
|
||||||
|
|
||||||
def lower_time(stmts, ref_period, initial_time):
|
def lower_time(funcdef, initial_time):
|
||||||
transformer = _TimeLowerer(ref_period)
|
_TimeLowerer().visit(funcdef)
|
||||||
new_stmts = [transformer.visit(stmt) for stmt in stmts]
|
funcdef.body.insert(0, ast.copy_location(
|
||||||
new_stmts.insert(0, ast.copy_location(
|
|
||||||
ast.Assign(targets=[ast.Name("now", ast.Store())], value=value_to_ast(initial_time)),
|
ast.Assign(targets=[ast.Name("now", ast.Store())], value=value_to_ast(initial_time)),
|
||||||
stmts[0]))
|
funcdef))
|
||||||
stmts[:] = new_stmts
|
|
||||||
|
|
|
@ -59,10 +59,3 @@ def eval_constant(node):
|
||||||
return units.Quantity(amount, unit)
|
return units.Quantity(amount, unit)
|
||||||
else:
|
else:
|
||||||
raise NotConstant
|
raise NotConstant
|
||||||
|
|
||||||
def make_stmt_transformer(transformer_class):
|
|
||||||
def stmt_transformer(stmts, *args, **kwargs):
|
|
||||||
transformer = transformer_class(*args, **kwargs)
|
|
||||||
new_stmts = [transformer.visit(stmt) for stmt in stmts]
|
|
||||||
stmts[:] = new_stmts
|
|
||||||
return stmt_transformer
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import ast
|
import ast
|
||||||
|
|
||||||
from artiq.compiler.tools import eval_ast, make_stmt_transformer, value_to_ast
|
from artiq.compiler.tools import eval_ast, value_to_ast
|
||||||
|
|
||||||
def _count_stmts(node):
|
def _count_stmts(node):
|
||||||
if isinstance(node, (ast.For, ast.While, ast.If)):
|
if isinstance(node, (ast.For, ast.While, ast.If)):
|
||||||
|
@ -43,4 +43,5 @@ class _LoopUnroller(ast.NodeTransformer):
|
||||||
else:
|
else:
|
||||||
return node.orelse
|
return node.orelse
|
||||||
|
|
||||||
unroll_loops = make_stmt_transformer(_LoopUnroller)
|
def unroll_loops(node, limit):
|
||||||
|
_LoopUnroller(limit).visit(node)
|
||||||
|
|
|
@ -13,14 +13,13 @@ class Core:
|
||||||
self.core_com = core_com
|
self.core_com = core_com
|
||||||
|
|
||||||
def run(self, k_function, k_args, k_kwargs):
|
def run(self, k_function, k_args, k_kwargs):
|
||||||
stmts, rpc_map = inline(self, k_function, k_args, k_kwargs)
|
funcdef, rpc_map = inline(self, k_function, k_args, k_kwargs)
|
||||||
fold_constants(stmts)
|
fold_constants(funcdef)
|
||||||
unroll_loops(stmts, 50)
|
unroll_loops(funcdef, 50)
|
||||||
interleave(stmts)
|
interleave(funcdef)
|
||||||
lower_time(stmts, self.runtime_env.ref_period,
|
lower_time(funcdef, getattr(self.runtime_env, "initial_time", 0))
|
||||||
getattr(self.runtime_env, "initial_time", 0))
|
fold_constants(funcdef)
|
||||||
fold_constants(stmts)
|
|
||||||
|
|
||||||
binary = get_runtime_binary(self.runtime_env, stmts)
|
binary = get_runtime_binary(self.runtime_env, funcdef)
|
||||||
self.core_com.run(binary)
|
self.core_com.run(binary)
|
||||||
self.core_com.serve(rpc_map)
|
self.core_com.serve(rpc_map)
|
||||||
|
|
Loading…
Reference in New Issue