forked from M-Labs/artiq
transforms/inline: support for function return values
This commit is contained in:
parent
3d440d5f15
commit
dc9515fc62
@ -119,16 +119,47 @@ _embeddable_calls = {
|
||||
}
|
||||
|
||||
|
||||
class _ReferenceReplacer(ast.NodeTransformer):
|
||||
def __init__(self, core, rm, obj, funcname):
|
||||
class _ReferenceReplacer(ast.NodeVisitor):
|
||||
def __init__(self, core, rm, obj, func_name, retval_name):
|
||||
self.core = core
|
||||
self.rm = rm
|
||||
self.obj = obj
|
||||
self.funcname = funcname
|
||||
self.func_name = func_name
|
||||
self.retval_name = retval_name
|
||||
self._insertion_point = None
|
||||
|
||||
# This is ast.NodeTransformer.generic_visit from CPython, modified
|
||||
# to update self._insertion_point.
|
||||
def generic_visit(self, node):
|
||||
for field, old_value in ast.iter_fields(node):
|
||||
old_value = getattr(node, field, None)
|
||||
if isinstance(old_value, list):
|
||||
prev_insertion_point = self._insertion_point
|
||||
new_values = []
|
||||
if field in ("body", "orelse", "finalbody"):
|
||||
self._insertion_point = new_values
|
||||
for value in old_value:
|
||||
if isinstance(value, ast.AST):
|
||||
value = self.visit(value)
|
||||
if value is None:
|
||||
continue
|
||||
elif not isinstance(value, ast.AST):
|
||||
new_values.extend(value)
|
||||
continue
|
||||
new_values.append(value)
|
||||
old_value[:] = new_values
|
||||
self._insertion_point = prev_insertion_point
|
||||
elif isinstance(old_value, ast.AST):
|
||||
new_node = self.visit(old_value)
|
||||
if new_node is None:
|
||||
delattr(node, field)
|
||||
else:
|
||||
setattr(node, field, new_node)
|
||||
return node
|
||||
|
||||
def visit_ref(self, node):
|
||||
store = isinstance(node.ctx, ast.Store)
|
||||
ival = self.rm.get(self.obj, self.funcname, node)
|
||||
ival = self.rm.get(self.obj, self.func_name, node)
|
||||
if isinstance(ival, _UserVariable):
|
||||
newnode = ast.Name(ival.name, node.ctx)
|
||||
elif isinstance(ival, ast.AST):
|
||||
@ -149,7 +180,7 @@ class _ReferenceReplacer(ast.NodeTransformer):
|
||||
visit_Attribute = visit_ref
|
||||
|
||||
def visit_Call(self, node):
|
||||
func = self.rm.get(self.obj, self.funcname, node.func)
|
||||
func = self.rm.get(self.obj, self.func_name, node.func)
|
||||
new_args = [self.visit(arg) for arg in node.args]
|
||||
|
||||
if func in _embeddable_calls:
|
||||
@ -161,15 +192,17 @@ class _ReferenceReplacer(ast.NodeTransformer):
|
||||
elif (hasattr(func, "k_function_info")
|
||||
and getattr(func.__self__, func.k_function_info.core_name)
|
||||
is self.core):
|
||||
retval_name = self.rm.new_name(
|
||||
func.k_function_info.k_function.__name__ + "_return")
|
||||
args = [func.__self__] + new_args
|
||||
inlined, _ = inline(self.core, func.k_function_info.k_function,
|
||||
args, dict(), self.rm)
|
||||
r = ast.With(
|
||||
args, dict(), self.rm, retval_name)
|
||||
self._insertion_point.append(ast.With(
|
||||
items=[ast.withitem(context_expr=ast.Name(id="sequential",
|
||||
ctx=ast.Load()),
|
||||
optional_vars=None)],
|
||||
body=inlined.body)
|
||||
return ast.copy_location(r, node)
|
||||
body=inlined.body))
|
||||
return ast.copy_location(ast.Name(retval_name, ast.Load()), node)
|
||||
else:
|
||||
args = [ast.Str("rpc"), value_to_ast(self.rm.rpc_map[func])]
|
||||
args += new_args
|
||||
@ -178,16 +211,22 @@ class _ReferenceReplacer(ast.NodeTransformer):
|
||||
args=args, keywords=[], starargs=None, kwargs=None),
|
||||
node)
|
||||
|
||||
def visit_Return(self, node):
|
||||
self.generic_visit(node)
|
||||
return ast.copy_location(
|
||||
ast.Assign(targets=[ast.Name(self.retval_name, ast.Store())],
|
||||
value=node.value),
|
||||
node)
|
||||
|
||||
def visit_Expr(self, node):
|
||||
if isinstance(node.value, ast.Call):
|
||||
r = self.visit_Call(node.value)
|
||||
if isinstance(r, ast.With):
|
||||
return r
|
||||
else:
|
||||
node.value = r
|
||||
return node
|
||||
self.generic_visit(node)
|
||||
if isinstance(node.value, ast.Name):
|
||||
# Remove Expr nodes that contain only a name, likely due to
|
||||
# function call inlining. Such nodes that were originally in the
|
||||
# code are also removed, but this does not affect the semantics of
|
||||
# the code as they are nops.
|
||||
return None
|
||||
else:
|
||||
self.generic_visit(node)
|
||||
return node
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
@ -213,46 +252,46 @@ class _ListReadOnlyParams(ast.NodeVisitor):
|
||||
pass
|
||||
|
||||
|
||||
def _list_read_only_params(funcdef):
|
||||
def _list_read_only_params(func_def):
|
||||
lrp = _ListReadOnlyParams()
|
||||
lrp.visit(funcdef)
|
||||
lrp.visit(func_def)
|
||||
return lrp.read_only_params
|
||||
|
||||
|
||||
def _initialize_function_params(funcdef, k_args, k_kwargs, rm):
|
||||
def _initialize_function_params(func_def, k_args, k_kwargs, rm):
|
||||
obj = k_args[0]
|
||||
funcname = funcdef.name
|
||||
func_name = func_def.name
|
||||
param_init = []
|
||||
rop = _list_read_only_params(funcdef)
|
||||
for arg_ast, arg_value in zip(funcdef.args.args, k_args):
|
||||
rop = _list_read_only_params(func_def)
|
||||
for arg_ast, arg_value in zip(func_def.args.args, k_args):
|
||||
arg_name = arg_ast.arg
|
||||
if arg_name in rop:
|
||||
rm.set(obj, funcname, arg_name, arg_value)
|
||||
rm.set(obj, func_name, arg_name, arg_value)
|
||||
else:
|
||||
target = rm.get(obj, funcname, ast.Name(arg_name, ast.Store()))
|
||||
target = rm.get(obj, func_name, ast.Name(arg_name, ast.Store()))
|
||||
value = value_to_ast(arg_value)
|
||||
param_init.append(ast.Assign(targets=[target], value=value))
|
||||
return param_init
|
||||
|
||||
|
||||
def inline(core, k_function, k_args, k_kwargs, rm=None):
|
||||
def inline(core, k_function, k_args, k_kwargs, rm=None, retval_name=None):
|
||||
init_kernel_attr = rm is None
|
||||
if rm is None:
|
||||
rm = _ReferenceManager()
|
||||
|
||||
funcdef = ast.parse(textwrap.dedent(inspect.getsource(k_function))).body[0]
|
||||
func_def = ast.parse(textwrap.dedent(inspect.getsource(k_function))).body[0]
|
||||
|
||||
param_init = _initialize_function_params(funcdef, k_args, k_kwargs, rm)
|
||||
param_init = _initialize_function_params(func_def, k_args, k_kwargs, rm)
|
||||
|
||||
obj = k_args[0]
|
||||
funcname = funcdef.name
|
||||
rr = _ReferenceReplacer(core, rm, obj, funcname)
|
||||
rr.visit(funcdef)
|
||||
func_name = func_def.name
|
||||
rr = _ReferenceReplacer(core, rm, obj, func_name, retval_name)
|
||||
rr.visit(func_def)
|
||||
|
||||
funcdef.body[0:0] = param_init
|
||||
func_def.body[0:0] = param_init
|
||||
if init_kernel_attr:
|
||||
funcdef.body[0:0] = rm.kernel_attr_init
|
||||
func_def.body[0:0] = rm.kernel_attr_init
|
||||
|
||||
r_rpc_map = dict((rpc_num, rpc_fun)
|
||||
for rpc_fun, rpc_num in rm.rpc_map.items())
|
||||
return funcdef, r_rpc_map
|
||||
return func_def, r_rpc_map
|
||||
|
Loading…
Reference in New Issue
Block a user