forked from M-Labs/artiq
1
0
Fork 0

transforms/inline: support for function return values

This commit is contained in:
Sebastien Bourdeauducq 2014-09-13 16:17:16 +08:00
parent 3d440d5f15
commit dc9515fc62
1 changed files with 73 additions and 34 deletions

View File

@ -119,16 +119,47 @@ _embeddable_calls = {
} }
class _ReferenceReplacer(ast.NodeTransformer): class _ReferenceReplacer(ast.NodeVisitor):
def __init__(self, core, rm, obj, funcname): def __init__(self, core, rm, obj, func_name, retval_name):
self.core = core self.core = core
self.rm = rm self.rm = rm
self.obj = obj self.obj = obj
self.funcname = funcname self.func_name = func_name
self.retval_name = retval_name
self._insertion_point = None
# This is ast.NodeTransformer.generic_visit from CPython, modified
# to update self._insertion_point.
def generic_visit(self, node):
for field, old_value in ast.iter_fields(node):
old_value = getattr(node, field, None)
if isinstance(old_value, list):
prev_insertion_point = self._insertion_point
new_values = []
if field in ("body", "orelse", "finalbody"):
self._insertion_point = new_values
for value in old_value:
if isinstance(value, ast.AST):
value = self.visit(value)
if value is None:
continue
elif not isinstance(value, ast.AST):
new_values.extend(value)
continue
new_values.append(value)
old_value[:] = new_values
self._insertion_point = prev_insertion_point
elif isinstance(old_value, ast.AST):
new_node = self.visit(old_value)
if new_node is None:
delattr(node, field)
else:
setattr(node, field, new_node)
return node
def visit_ref(self, node): def visit_ref(self, node):
store = isinstance(node.ctx, ast.Store) store = isinstance(node.ctx, ast.Store)
ival = self.rm.get(self.obj, self.funcname, node) ival = self.rm.get(self.obj, self.func_name, node)
if isinstance(ival, _UserVariable): if isinstance(ival, _UserVariable):
newnode = ast.Name(ival.name, node.ctx) newnode = ast.Name(ival.name, node.ctx)
elif isinstance(ival, ast.AST): elif isinstance(ival, ast.AST):
@ -149,7 +180,7 @@ class _ReferenceReplacer(ast.NodeTransformer):
visit_Attribute = visit_ref visit_Attribute = visit_ref
def visit_Call(self, node): def visit_Call(self, node):
func = self.rm.get(self.obj, self.funcname, node.func) func = self.rm.get(self.obj, self.func_name, node.func)
new_args = [self.visit(arg) for arg in node.args] new_args = [self.visit(arg) for arg in node.args]
if func in _embeddable_calls: if func in _embeddable_calls:
@ -161,15 +192,17 @@ class _ReferenceReplacer(ast.NodeTransformer):
elif (hasattr(func, "k_function_info") elif (hasattr(func, "k_function_info")
and getattr(func.__self__, func.k_function_info.core_name) and getattr(func.__self__, func.k_function_info.core_name)
is self.core): is self.core):
retval_name = self.rm.new_name(
func.k_function_info.k_function.__name__ + "_return")
args = [func.__self__] + new_args args = [func.__self__] + new_args
inlined, _ = inline(self.core, func.k_function_info.k_function, inlined, _ = inline(self.core, func.k_function_info.k_function,
args, dict(), self.rm) args, dict(), self.rm, retval_name)
r = ast.With( self._insertion_point.append(ast.With(
items=[ast.withitem(context_expr=ast.Name(id="sequential", items=[ast.withitem(context_expr=ast.Name(id="sequential",
ctx=ast.Load()), ctx=ast.Load()),
optional_vars=None)], optional_vars=None)],
body=inlined.body) body=inlined.body))
return ast.copy_location(r, node) return ast.copy_location(ast.Name(retval_name, ast.Load()), node)
else: else:
args = [ast.Str("rpc"), value_to_ast(self.rm.rpc_map[func])] args = [ast.Str("rpc"), value_to_ast(self.rm.rpc_map[func])]
args += new_args args += new_args
@ -178,16 +211,22 @@ class _ReferenceReplacer(ast.NodeTransformer):
args=args, keywords=[], starargs=None, kwargs=None), args=args, keywords=[], starargs=None, kwargs=None),
node) node)
def visit_Expr(self, node): def visit_Return(self, node):
if isinstance(node.value, ast.Call):
r = self.visit_Call(node.value)
if isinstance(r, ast.With):
return r
else:
node.value = r
return node
else:
self.generic_visit(node) self.generic_visit(node)
return ast.copy_location(
ast.Assign(targets=[ast.Name(self.retval_name, ast.Store())],
value=node.value),
node)
def visit_Expr(self, node):
self.generic_visit(node)
if isinstance(node.value, ast.Name):
# Remove Expr nodes that contain only a name, likely due to
# function call inlining. Such nodes that were originally in the
# code are also removed, but this does not affect the semantics of
# the code as they are nops.
return None
else:
return node return node
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
@ -213,46 +252,46 @@ class _ListReadOnlyParams(ast.NodeVisitor):
pass pass
def _list_read_only_params(funcdef): def _list_read_only_params(func_def):
lrp = _ListReadOnlyParams() lrp = _ListReadOnlyParams()
lrp.visit(funcdef) lrp.visit(func_def)
return lrp.read_only_params return lrp.read_only_params
def _initialize_function_params(funcdef, k_args, k_kwargs, rm): def _initialize_function_params(func_def, k_args, k_kwargs, rm):
obj = k_args[0] obj = k_args[0]
funcname = funcdef.name func_name = func_def.name
param_init = [] param_init = []
rop = _list_read_only_params(funcdef) rop = _list_read_only_params(func_def)
for arg_ast, arg_value in zip(funcdef.args.args, k_args): for arg_ast, arg_value in zip(func_def.args.args, k_args):
arg_name = arg_ast.arg arg_name = arg_ast.arg
if arg_name in rop: if arg_name in rop:
rm.set(obj, funcname, arg_name, arg_value) rm.set(obj, func_name, arg_name, arg_value)
else: else:
target = rm.get(obj, funcname, ast.Name(arg_name, ast.Store())) target = rm.get(obj, func_name, ast.Name(arg_name, ast.Store()))
value = value_to_ast(arg_value) value = value_to_ast(arg_value)
param_init.append(ast.Assign(targets=[target], value=value)) param_init.append(ast.Assign(targets=[target], value=value))
return param_init return param_init
def inline(core, k_function, k_args, k_kwargs, rm=None): def inline(core, k_function, k_args, k_kwargs, rm=None, retval_name=None):
init_kernel_attr = rm is None init_kernel_attr = rm is None
if rm is None: if rm is None:
rm = _ReferenceManager() rm = _ReferenceManager()
funcdef = ast.parse(textwrap.dedent(inspect.getsource(k_function))).body[0] func_def = ast.parse(textwrap.dedent(inspect.getsource(k_function))).body[0]
param_init = _initialize_function_params(funcdef, k_args, k_kwargs, rm) param_init = _initialize_function_params(func_def, k_args, k_kwargs, rm)
obj = k_args[0] obj = k_args[0]
funcname = funcdef.name func_name = func_def.name
rr = _ReferenceReplacer(core, rm, obj, funcname) rr = _ReferenceReplacer(core, rm, obj, func_name, retval_name)
rr.visit(funcdef) rr.visit(func_def)
funcdef.body[0:0] = param_init func_def.body[0:0] = param_init
if init_kernel_attr: if init_kernel_attr:
funcdef.body[0:0] = rm.kernel_attr_init func_def.body[0:0] = rm.kernel_attr_init
r_rpc_map = dict((rpc_num, rpc_fun) r_rpc_map = dict((rpc_num, rpc_fun)
for rpc_fun, rpc_num in rm.rpc_map.items()) for rpc_fun, rpc_num in rm.rpc_map.items())
return funcdef, r_rpc_map return func_def, r_rpc_map