diff --git a/artiq/transforms/inline.py b/artiq/transforms/inline.py index b3f56a7b6..0d5f435ba 100644 --- a/artiq/transforms/inline.py +++ b/artiq/transforms/inline.py @@ -119,16 +119,47 @@ _embeddable_calls = { } -class _ReferenceReplacer(ast.NodeTransformer): - def __init__(self, core, rm, obj, funcname): +class _ReferenceReplacer(ast.NodeVisitor): + def __init__(self, core, rm, obj, func_name, retval_name): self.core = core self.rm = rm self.obj = obj - self.funcname = funcname + self.func_name = func_name + self.retval_name = retval_name + self._insertion_point = None + + # This is ast.NodeTransformer.generic_visit from CPython, modified + # to update self._insertion_point. + def generic_visit(self, node): + for field, old_value in ast.iter_fields(node): + old_value = getattr(node, field, None) + if isinstance(old_value, list): + prev_insertion_point = self._insertion_point + new_values = [] + if field in ("body", "orelse", "finalbody"): + self._insertion_point = new_values + for value in old_value: + if isinstance(value, ast.AST): + value = self.visit(value) + if value is None: + continue + elif not isinstance(value, ast.AST): + new_values.extend(value) + continue + new_values.append(value) + old_value[:] = new_values + self._insertion_point = prev_insertion_point + elif isinstance(old_value, ast.AST): + new_node = self.visit(old_value) + if new_node is None: + delattr(node, field) + else: + setattr(node, field, new_node) + return node def visit_ref(self, node): store = isinstance(node.ctx, ast.Store) - ival = self.rm.get(self.obj, self.funcname, node) + ival = self.rm.get(self.obj, self.func_name, node) if isinstance(ival, _UserVariable): newnode = ast.Name(ival.name, node.ctx) elif isinstance(ival, ast.AST): @@ -149,7 +180,7 @@ class _ReferenceReplacer(ast.NodeTransformer): visit_Attribute = visit_ref def visit_Call(self, node): - func = self.rm.get(self.obj, self.funcname, node.func) + func = self.rm.get(self.obj, self.func_name, node.func) new_args = [self.visit(arg) for arg in node.args] if func in _embeddable_calls: @@ -161,15 +192,17 @@ class _ReferenceReplacer(ast.NodeTransformer): elif (hasattr(func, "k_function_info") and getattr(func.__self__, func.k_function_info.core_name) is self.core): + retval_name = self.rm.new_name( + func.k_function_info.k_function.__name__ + "_return") args = [func.__self__] + new_args inlined, _ = inline(self.core, func.k_function_info.k_function, - args, dict(), self.rm) - r = ast.With( + args, dict(), self.rm, retval_name) + self._insertion_point.append(ast.With( items=[ast.withitem(context_expr=ast.Name(id="sequential", ctx=ast.Load()), optional_vars=None)], - body=inlined.body) - return ast.copy_location(r, node) + body=inlined.body)) + return ast.copy_location(ast.Name(retval_name, ast.Load()), node) else: args = [ast.Str("rpc"), value_to_ast(self.rm.rpc_map[func])] args += new_args @@ -178,16 +211,22 @@ class _ReferenceReplacer(ast.NodeTransformer): args=args, keywords=[], starargs=None, kwargs=None), node) + def visit_Return(self, node): + self.generic_visit(node) + return ast.copy_location( + ast.Assign(targets=[ast.Name(self.retval_name, ast.Store())], + value=node.value), + node) + def visit_Expr(self, node): - if isinstance(node.value, ast.Call): - r = self.visit_Call(node.value) - if isinstance(r, ast.With): - return r - else: - node.value = r - return node + self.generic_visit(node) + if isinstance(node.value, ast.Name): + # Remove Expr nodes that contain only a name, likely due to + # function call inlining. Such nodes that were originally in the + # code are also removed, but this does not affect the semantics of + # the code as they are nops. + return None else: - self.generic_visit(node) return node def visit_FunctionDef(self, node): @@ -213,46 +252,46 @@ class _ListReadOnlyParams(ast.NodeVisitor): pass -def _list_read_only_params(funcdef): +def _list_read_only_params(func_def): lrp = _ListReadOnlyParams() - lrp.visit(funcdef) + lrp.visit(func_def) return lrp.read_only_params -def _initialize_function_params(funcdef, k_args, k_kwargs, rm): +def _initialize_function_params(func_def, k_args, k_kwargs, rm): obj = k_args[0] - funcname = funcdef.name + func_name = func_def.name param_init = [] - rop = _list_read_only_params(funcdef) - for arg_ast, arg_value in zip(funcdef.args.args, k_args): + rop = _list_read_only_params(func_def) + for arg_ast, arg_value in zip(func_def.args.args, k_args): arg_name = arg_ast.arg if arg_name in rop: - rm.set(obj, funcname, arg_name, arg_value) + rm.set(obj, func_name, arg_name, arg_value) else: - target = rm.get(obj, funcname, ast.Name(arg_name, ast.Store())) + target = rm.get(obj, func_name, ast.Name(arg_name, ast.Store())) value = value_to_ast(arg_value) param_init.append(ast.Assign(targets=[target], value=value)) return param_init -def inline(core, k_function, k_args, k_kwargs, rm=None): +def inline(core, k_function, k_args, k_kwargs, rm=None, retval_name=None): init_kernel_attr = rm is None if rm is None: rm = _ReferenceManager() - funcdef = ast.parse(textwrap.dedent(inspect.getsource(k_function))).body[0] + func_def = ast.parse(textwrap.dedent(inspect.getsource(k_function))).body[0] - param_init = _initialize_function_params(funcdef, k_args, k_kwargs, rm) + param_init = _initialize_function_params(func_def, k_args, k_kwargs, rm) obj = k_args[0] - funcname = funcdef.name - rr = _ReferenceReplacer(core, rm, obj, funcname) - rr.visit(funcdef) + func_name = func_def.name + rr = _ReferenceReplacer(core, rm, obj, func_name, retval_name) + rr.visit(func_def) - funcdef.body[0:0] = param_init + func_def.body[0:0] = param_init if init_kernel_attr: - funcdef.body[0:0] = rm.kernel_attr_init + func_def.body[0:0] = rm.kernel_attr_init r_rpc_map = dict((rpc_num, rpc_fun) for rpc_fun, rpc_num in rm.rpc_map.items()) - return funcdef, r_rpc_map + return func_def, r_rpc_map