forked from M-Labs/artiq
transforms/inline: support for function return values
This commit is contained in:
parent
3d440d5f15
commit
dc9515fc62
|
@ -119,16 +119,47 @@ _embeddable_calls = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class _ReferenceReplacer(ast.NodeTransformer):
|
class _ReferenceReplacer(ast.NodeVisitor):
|
||||||
def __init__(self, core, rm, obj, funcname):
|
def __init__(self, core, rm, obj, func_name, retval_name):
|
||||||
self.core = core
|
self.core = core
|
||||||
self.rm = rm
|
self.rm = rm
|
||||||
self.obj = obj
|
self.obj = obj
|
||||||
self.funcname = funcname
|
self.func_name = func_name
|
||||||
|
self.retval_name = retval_name
|
||||||
|
self._insertion_point = None
|
||||||
|
|
||||||
|
# This is ast.NodeTransformer.generic_visit from CPython, modified
|
||||||
|
# to update self._insertion_point.
|
||||||
|
def generic_visit(self, node):
|
||||||
|
for field, old_value in ast.iter_fields(node):
|
||||||
|
old_value = getattr(node, field, None)
|
||||||
|
if isinstance(old_value, list):
|
||||||
|
prev_insertion_point = self._insertion_point
|
||||||
|
new_values = []
|
||||||
|
if field in ("body", "orelse", "finalbody"):
|
||||||
|
self._insertion_point = new_values
|
||||||
|
for value in old_value:
|
||||||
|
if isinstance(value, ast.AST):
|
||||||
|
value = self.visit(value)
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
|
elif not isinstance(value, ast.AST):
|
||||||
|
new_values.extend(value)
|
||||||
|
continue
|
||||||
|
new_values.append(value)
|
||||||
|
old_value[:] = new_values
|
||||||
|
self._insertion_point = prev_insertion_point
|
||||||
|
elif isinstance(old_value, ast.AST):
|
||||||
|
new_node = self.visit(old_value)
|
||||||
|
if new_node is None:
|
||||||
|
delattr(node, field)
|
||||||
|
else:
|
||||||
|
setattr(node, field, new_node)
|
||||||
|
return node
|
||||||
|
|
||||||
def visit_ref(self, node):
|
def visit_ref(self, node):
|
||||||
store = isinstance(node.ctx, ast.Store)
|
store = isinstance(node.ctx, ast.Store)
|
||||||
ival = self.rm.get(self.obj, self.funcname, node)
|
ival = self.rm.get(self.obj, self.func_name, node)
|
||||||
if isinstance(ival, _UserVariable):
|
if isinstance(ival, _UserVariable):
|
||||||
newnode = ast.Name(ival.name, node.ctx)
|
newnode = ast.Name(ival.name, node.ctx)
|
||||||
elif isinstance(ival, ast.AST):
|
elif isinstance(ival, ast.AST):
|
||||||
|
@ -149,7 +180,7 @@ class _ReferenceReplacer(ast.NodeTransformer):
|
||||||
visit_Attribute = visit_ref
|
visit_Attribute = visit_ref
|
||||||
|
|
||||||
def visit_Call(self, node):
|
def visit_Call(self, node):
|
||||||
func = self.rm.get(self.obj, self.funcname, node.func)
|
func = self.rm.get(self.obj, self.func_name, node.func)
|
||||||
new_args = [self.visit(arg) for arg in node.args]
|
new_args = [self.visit(arg) for arg in node.args]
|
||||||
|
|
||||||
if func in _embeddable_calls:
|
if func in _embeddable_calls:
|
||||||
|
@ -161,15 +192,17 @@ class _ReferenceReplacer(ast.NodeTransformer):
|
||||||
elif (hasattr(func, "k_function_info")
|
elif (hasattr(func, "k_function_info")
|
||||||
and getattr(func.__self__, func.k_function_info.core_name)
|
and getattr(func.__self__, func.k_function_info.core_name)
|
||||||
is self.core):
|
is self.core):
|
||||||
|
retval_name = self.rm.new_name(
|
||||||
|
func.k_function_info.k_function.__name__ + "_return")
|
||||||
args = [func.__self__] + new_args
|
args = [func.__self__] + new_args
|
||||||
inlined, _ = inline(self.core, func.k_function_info.k_function,
|
inlined, _ = inline(self.core, func.k_function_info.k_function,
|
||||||
args, dict(), self.rm)
|
args, dict(), self.rm, retval_name)
|
||||||
r = ast.With(
|
self._insertion_point.append(ast.With(
|
||||||
items=[ast.withitem(context_expr=ast.Name(id="sequential",
|
items=[ast.withitem(context_expr=ast.Name(id="sequential",
|
||||||
ctx=ast.Load()),
|
ctx=ast.Load()),
|
||||||
optional_vars=None)],
|
optional_vars=None)],
|
||||||
body=inlined.body)
|
body=inlined.body))
|
||||||
return ast.copy_location(r, node)
|
return ast.copy_location(ast.Name(retval_name, ast.Load()), node)
|
||||||
else:
|
else:
|
||||||
args = [ast.Str("rpc"), value_to_ast(self.rm.rpc_map[func])]
|
args = [ast.Str("rpc"), value_to_ast(self.rm.rpc_map[func])]
|
||||||
args += new_args
|
args += new_args
|
||||||
|
@ -178,16 +211,22 @@ class _ReferenceReplacer(ast.NodeTransformer):
|
||||||
args=args, keywords=[], starargs=None, kwargs=None),
|
args=args, keywords=[], starargs=None, kwargs=None),
|
||||||
node)
|
node)
|
||||||
|
|
||||||
|
def visit_Return(self, node):
|
||||||
|
self.generic_visit(node)
|
||||||
|
return ast.copy_location(
|
||||||
|
ast.Assign(targets=[ast.Name(self.retval_name, ast.Store())],
|
||||||
|
value=node.value),
|
||||||
|
node)
|
||||||
|
|
||||||
def visit_Expr(self, node):
|
def visit_Expr(self, node):
|
||||||
if isinstance(node.value, ast.Call):
|
self.generic_visit(node)
|
||||||
r = self.visit_Call(node.value)
|
if isinstance(node.value, ast.Name):
|
||||||
if isinstance(r, ast.With):
|
# Remove Expr nodes that contain only a name, likely due to
|
||||||
return r
|
# function call inlining. Such nodes that were originally in the
|
||||||
else:
|
# code are also removed, but this does not affect the semantics of
|
||||||
node.value = r
|
# the code as they are nops.
|
||||||
return node
|
return None
|
||||||
else:
|
else:
|
||||||
self.generic_visit(node)
|
|
||||||
return node
|
return node
|
||||||
|
|
||||||
def visit_FunctionDef(self, node):
|
def visit_FunctionDef(self, node):
|
||||||
|
@ -213,46 +252,46 @@ class _ListReadOnlyParams(ast.NodeVisitor):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _list_read_only_params(funcdef):
|
def _list_read_only_params(func_def):
|
||||||
lrp = _ListReadOnlyParams()
|
lrp = _ListReadOnlyParams()
|
||||||
lrp.visit(funcdef)
|
lrp.visit(func_def)
|
||||||
return lrp.read_only_params
|
return lrp.read_only_params
|
||||||
|
|
||||||
|
|
||||||
def _initialize_function_params(funcdef, k_args, k_kwargs, rm):
|
def _initialize_function_params(func_def, k_args, k_kwargs, rm):
|
||||||
obj = k_args[0]
|
obj = k_args[0]
|
||||||
funcname = funcdef.name
|
func_name = func_def.name
|
||||||
param_init = []
|
param_init = []
|
||||||
rop = _list_read_only_params(funcdef)
|
rop = _list_read_only_params(func_def)
|
||||||
for arg_ast, arg_value in zip(funcdef.args.args, k_args):
|
for arg_ast, arg_value in zip(func_def.args.args, k_args):
|
||||||
arg_name = arg_ast.arg
|
arg_name = arg_ast.arg
|
||||||
if arg_name in rop:
|
if arg_name in rop:
|
||||||
rm.set(obj, funcname, arg_name, arg_value)
|
rm.set(obj, func_name, arg_name, arg_value)
|
||||||
else:
|
else:
|
||||||
target = rm.get(obj, funcname, ast.Name(arg_name, ast.Store()))
|
target = rm.get(obj, func_name, ast.Name(arg_name, ast.Store()))
|
||||||
value = value_to_ast(arg_value)
|
value = value_to_ast(arg_value)
|
||||||
param_init.append(ast.Assign(targets=[target], value=value))
|
param_init.append(ast.Assign(targets=[target], value=value))
|
||||||
return param_init
|
return param_init
|
||||||
|
|
||||||
|
|
||||||
def inline(core, k_function, k_args, k_kwargs, rm=None):
|
def inline(core, k_function, k_args, k_kwargs, rm=None, retval_name=None):
|
||||||
init_kernel_attr = rm is None
|
init_kernel_attr = rm is None
|
||||||
if rm is None:
|
if rm is None:
|
||||||
rm = _ReferenceManager()
|
rm = _ReferenceManager()
|
||||||
|
|
||||||
funcdef = ast.parse(textwrap.dedent(inspect.getsource(k_function))).body[0]
|
func_def = ast.parse(textwrap.dedent(inspect.getsource(k_function))).body[0]
|
||||||
|
|
||||||
param_init = _initialize_function_params(funcdef, k_args, k_kwargs, rm)
|
param_init = _initialize_function_params(func_def, k_args, k_kwargs, rm)
|
||||||
|
|
||||||
obj = k_args[0]
|
obj = k_args[0]
|
||||||
funcname = funcdef.name
|
func_name = func_def.name
|
||||||
rr = _ReferenceReplacer(core, rm, obj, funcname)
|
rr = _ReferenceReplacer(core, rm, obj, func_name, retval_name)
|
||||||
rr.visit(funcdef)
|
rr.visit(func_def)
|
||||||
|
|
||||||
funcdef.body[0:0] = param_init
|
func_def.body[0:0] = param_init
|
||||||
if init_kernel_attr:
|
if init_kernel_attr:
|
||||||
funcdef.body[0:0] = rm.kernel_attr_init
|
func_def.body[0:0] = rm.kernel_attr_init
|
||||||
|
|
||||||
r_rpc_map = dict((rpc_num, rpc_fun)
|
r_rpc_map = dict((rpc_num, rpc_fun)
|
||||||
for rpc_fun, rpc_num in rm.rpc_map.items())
|
for rpc_fun, rpc_num in rm.rpc_map.items())
|
||||||
return funcdef, r_rpc_map
|
return func_def, r_rpc_map
|
||||||
|
|
Loading…
Reference in New Issue