From c021b2ef412c3735d1318285d9e186038dc43e65 Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Mon, 16 Jun 2014 21:52:38 +0200 Subject: [PATCH] compiler/inline: handle function params and builtin calls --- artiq/compiler/inline.py | 60 +++++++++++++++++++++++++++++++++------- 1 file changed, 50 insertions(+), 10 deletions(-) diff --git a/artiq/compiler/inline.py b/artiq/compiler/inline.py index a051ad395..4b49176ce 100644 --- a/artiq/compiler/inline.py +++ b/artiq/compiler/inline.py @@ -1,5 +1,5 @@ -from collections import namedtuple -import inspect, textwrap, ast +from collections import namedtuple, defaultdict +import inspect, textwrap, ast, builtins from artiq.compiler import unparse from artiq.compiler.tools import eval_ast @@ -44,6 +44,8 @@ class _ReferenceManager: self.use_count["base_Hz_unit"] = 1 for kg in experiment.kernel_globals: self.use_count[kg] = 1 + for bi in dir(builtins): + self.use_count[bi] = 1 def new_name(self, base_name): if base_name[-1].isdigit(): @@ -91,11 +93,19 @@ class _ReferenceManager: def set(self, obj, funcname, name, value): self.to_inlined[(id(obj), funcname, name)] = value + def get_constants(self, r_obj, r_funcname): + return {local: v for (objid, funcname, local), v + in self.to_inlined.items() + if objid == id(r_obj) + and funcname == r_funcname + and not isinstance(v, _UserVariable)} + class _ReferenceReplacer(ast.NodeTransformer): def __init__(self, rm, obj, funcname): self.rm = rm self.obj = obj self.funcname = funcname + self.module = inspect.getmodule(self.obj) def visit_ref(self, node): return self.rm.get(self.obj, self.funcname, node) @@ -104,11 +114,29 @@ class _ReferenceReplacer(ast.NodeTransformer): visit_Attribute = visit_ref visit_Subscript = visit_ref + def visit_Call(self, node): + calldict = self.rm.get_constants(self.obj, self.funcname) + calldict.update(self.module.__dict__) + func = eval_ast(node.func, calldict) + + if inspect.getmodule(func) is builtins: + new_func = ast.Name(func.__name__, ast.Load()) + new_args = [self.visit(arg) for arg in node.args] + return ast.Call(func=new_func, args=new_args, + keywords=[], starargs=None, kwargs=None) + elif hasattr(func, "k_function_info"): + print(func.k_function_info) + # TODO: inline called kernel + + self.generic_visit(node) + return node + class _ListReadOnlyParams(ast.NodeVisitor): def visit_FunctionDef(self, node): if hasattr(self, "read_only_params"): raise ValueError("More than one function definition") self.read_only_params = {arg.arg for arg in node.args.args} + self.generic_visit(node) def visit_Name(self, node): if isinstance(node.ctx, ast.Store): @@ -122,23 +150,35 @@ def _list_read_only_params(funcdef): lrp.visit(funcdef) return lrp.read_only_params -def inline(k_function, k_args, k_kwargs, rm=None): - if rm is None: - rm = _ReferenceManager() - - funcdef = ast.parse(textwrap.dedent(inspect.getsource(k_function))).body[0] +def _initialize_function_params(funcdef, k_args, k_kwargs, rm): obj = k_args[0] funcname = funcdef.name - + param_init = [] rop = _list_read_only_params(funcdef) for arg_ast, arg_value in zip(funcdef.args.args, k_args): arg_name = arg_ast.arg if arg_name in rop: rm.set(obj, funcname, arg_name, arg_value) + else: + target = rm.get(obj, funcname, ast.Name(arg_name, ast.Store())) + value = _value_to_ast(arg_value) + param_init.append(ast.Assign(targets=[target], value=value)) + return param_init +def inline(k_function, k_args, k_kwargs, rm=None): + if rm is None: + rm = _ReferenceManager() + + funcdef = ast.parse(textwrap.dedent(inspect.getsource(k_function))).body[0] + + param_init = _initialize_function_params(funcdef, k_args, k_kwargs, rm) + + obj = k_args[0] + funcname = funcdef.name rr = _ReferenceReplacer(rm, obj, funcname) for stmt in funcdef.body: rr.visit(stmt) - print(ast.dump(funcdef)) - unparse.Unparser(funcdef) + funcdef.body[0:0] = param_init + + unparse.Unparser(funcdef.body)