forked from M-Labs/artiq
compiler/inline: handle function params and builtin calls
This commit is contained in:
parent
f035507bac
commit
c021b2ef41
|
@ -1,5 +1,5 @@
|
||||||
from collections import namedtuple
|
from collections import namedtuple, defaultdict
|
||||||
import inspect, textwrap, ast
|
import inspect, textwrap, ast, builtins
|
||||||
|
|
||||||
from artiq.compiler import unparse
|
from artiq.compiler import unparse
|
||||||
from artiq.compiler.tools import eval_ast
|
from artiq.compiler.tools import eval_ast
|
||||||
|
@ -44,6 +44,8 @@ class _ReferenceManager:
|
||||||
self.use_count["base_Hz_unit"] = 1
|
self.use_count["base_Hz_unit"] = 1
|
||||||
for kg in experiment.kernel_globals:
|
for kg in experiment.kernel_globals:
|
||||||
self.use_count[kg] = 1
|
self.use_count[kg] = 1
|
||||||
|
for bi in dir(builtins):
|
||||||
|
self.use_count[bi] = 1
|
||||||
|
|
||||||
def new_name(self, base_name):
|
def new_name(self, base_name):
|
||||||
if base_name[-1].isdigit():
|
if base_name[-1].isdigit():
|
||||||
|
@ -91,11 +93,19 @@ class _ReferenceManager:
|
||||||
def set(self, obj, funcname, name, value):
|
def set(self, obj, funcname, name, value):
|
||||||
self.to_inlined[(id(obj), funcname, name)] = value
|
self.to_inlined[(id(obj), funcname, name)] = value
|
||||||
|
|
||||||
|
def get_constants(self, r_obj, r_funcname):
|
||||||
|
return {local: v for (objid, funcname, local), v
|
||||||
|
in self.to_inlined.items()
|
||||||
|
if objid == id(r_obj)
|
||||||
|
and funcname == r_funcname
|
||||||
|
and not isinstance(v, _UserVariable)}
|
||||||
|
|
||||||
class _ReferenceReplacer(ast.NodeTransformer):
|
class _ReferenceReplacer(ast.NodeTransformer):
|
||||||
def __init__(self, rm, obj, funcname):
|
def __init__(self, rm, obj, funcname):
|
||||||
self.rm = rm
|
self.rm = rm
|
||||||
self.obj = obj
|
self.obj = obj
|
||||||
self.funcname = funcname
|
self.funcname = funcname
|
||||||
|
self.module = inspect.getmodule(self.obj)
|
||||||
|
|
||||||
def visit_ref(self, node):
|
def visit_ref(self, node):
|
||||||
return self.rm.get(self.obj, self.funcname, node)
|
return self.rm.get(self.obj, self.funcname, node)
|
||||||
|
@ -104,11 +114,29 @@ class _ReferenceReplacer(ast.NodeTransformer):
|
||||||
visit_Attribute = visit_ref
|
visit_Attribute = visit_ref
|
||||||
visit_Subscript = visit_ref
|
visit_Subscript = visit_ref
|
||||||
|
|
||||||
|
def visit_Call(self, node):
|
||||||
|
calldict = self.rm.get_constants(self.obj, self.funcname)
|
||||||
|
calldict.update(self.module.__dict__)
|
||||||
|
func = eval_ast(node.func, calldict)
|
||||||
|
|
||||||
|
if inspect.getmodule(func) is builtins:
|
||||||
|
new_func = ast.Name(func.__name__, ast.Load())
|
||||||
|
new_args = [self.visit(arg) for arg in node.args]
|
||||||
|
return ast.Call(func=new_func, args=new_args,
|
||||||
|
keywords=[], starargs=None, kwargs=None)
|
||||||
|
elif hasattr(func, "k_function_info"):
|
||||||
|
print(func.k_function_info)
|
||||||
|
# TODO: inline called kernel
|
||||||
|
|
||||||
|
self.generic_visit(node)
|
||||||
|
return node
|
||||||
|
|
||||||
class _ListReadOnlyParams(ast.NodeVisitor):
|
class _ListReadOnlyParams(ast.NodeVisitor):
|
||||||
def visit_FunctionDef(self, node):
|
def visit_FunctionDef(self, node):
|
||||||
if hasattr(self, "read_only_params"):
|
if hasattr(self, "read_only_params"):
|
||||||
raise ValueError("More than one function definition")
|
raise ValueError("More than one function definition")
|
||||||
self.read_only_params = {arg.arg for arg in node.args.args}
|
self.read_only_params = {arg.arg for arg in node.args.args}
|
||||||
|
self.generic_visit(node)
|
||||||
|
|
||||||
def visit_Name(self, node):
|
def visit_Name(self, node):
|
||||||
if isinstance(node.ctx, ast.Store):
|
if isinstance(node.ctx, ast.Store):
|
||||||
|
@ -122,23 +150,35 @@ def _list_read_only_params(funcdef):
|
||||||
lrp.visit(funcdef)
|
lrp.visit(funcdef)
|
||||||
return lrp.read_only_params
|
return lrp.read_only_params
|
||||||
|
|
||||||
def inline(k_function, k_args, k_kwargs, rm=None):
|
def _initialize_function_params(funcdef, k_args, k_kwargs, rm):
|
||||||
if rm is None:
|
|
||||||
rm = _ReferenceManager()
|
|
||||||
|
|
||||||
funcdef = ast.parse(textwrap.dedent(inspect.getsource(k_function))).body[0]
|
|
||||||
obj = k_args[0]
|
obj = k_args[0]
|
||||||
funcname = funcdef.name
|
funcname = funcdef.name
|
||||||
|
param_init = []
|
||||||
rop = _list_read_only_params(funcdef)
|
rop = _list_read_only_params(funcdef)
|
||||||
for arg_ast, arg_value in zip(funcdef.args.args, k_args):
|
for arg_ast, arg_value in zip(funcdef.args.args, k_args):
|
||||||
arg_name = arg_ast.arg
|
arg_name = arg_ast.arg
|
||||||
if arg_name in rop:
|
if arg_name in rop:
|
||||||
rm.set(obj, funcname, arg_name, arg_value)
|
rm.set(obj, funcname, arg_name, arg_value)
|
||||||
|
else:
|
||||||
|
target = rm.get(obj, funcname, ast.Name(arg_name, ast.Store()))
|
||||||
|
value = _value_to_ast(arg_value)
|
||||||
|
param_init.append(ast.Assign(targets=[target], value=value))
|
||||||
|
return param_init
|
||||||
|
|
||||||
|
def inline(k_function, k_args, k_kwargs, rm=None):
|
||||||
|
if rm is None:
|
||||||
|
rm = _ReferenceManager()
|
||||||
|
|
||||||
|
funcdef = ast.parse(textwrap.dedent(inspect.getsource(k_function))).body[0]
|
||||||
|
|
||||||
|
param_init = _initialize_function_params(funcdef, k_args, k_kwargs, rm)
|
||||||
|
|
||||||
|
obj = k_args[0]
|
||||||
|
funcname = funcdef.name
|
||||||
rr = _ReferenceReplacer(rm, obj, funcname)
|
rr = _ReferenceReplacer(rm, obj, funcname)
|
||||||
for stmt in funcdef.body:
|
for stmt in funcdef.body:
|
||||||
rr.visit(stmt)
|
rr.visit(stmt)
|
||||||
|
|
||||||
print(ast.dump(funcdef))
|
funcdef.body[0:0] = param_init
|
||||||
unparse.Unparser(funcdef)
|
|
||||||
|
unparse.Unparser(funcdef.body)
|
||||||
|
|
Loading…
Reference in New Issue