forked from M-Labs/artiq
inline: basic function inlining
This commit is contained in:
parent
d87b207b8a
commit
08ab99d33e
|
@ -31,7 +31,7 @@ _UserVariable = namedtuple("_UserVariable", "name")
|
||||||
|
|
||||||
class _ReferenceManager:
|
class _ReferenceManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# (id(obj), funcname, local) -> _UserVariable(name) / constant_object
|
# (id(obj), funcname, local) -> _UserVariable(name) / ast / constant_object
|
||||||
self.to_inlined = dict()
|
self.to_inlined = dict()
|
||||||
# inlined_name -> use_count
|
# inlined_name -> use_count
|
||||||
self.use_count = dict()
|
self.use_count = dict()
|
||||||
|
@ -71,6 +71,9 @@ class _ReferenceManager:
|
||||||
else:
|
else:
|
||||||
if isinstance(ival, _UserVariable):
|
if isinstance(ival, _UserVariable):
|
||||||
return ast.Name(ival.name, ref.ctx)
|
return ast.Name(ival.name, ref.ctx)
|
||||||
|
elif isinstance(ival, ast.AST):
|
||||||
|
assert(not store)
|
||||||
|
return ival
|
||||||
else:
|
else:
|
||||||
if store:
|
if store:
|
||||||
raise NotImplementedError("Cannot turn object into user variable")
|
raise NotImplementedError("Cannot turn object into user variable")
|
||||||
|
@ -95,7 +98,7 @@ class _ReferenceManager:
|
||||||
in self.to_inlined.items()
|
in self.to_inlined.items()
|
||||||
if objid == id(r_obj)
|
if objid == id(r_obj)
|
||||||
and funcname == r_funcname
|
and funcname == r_funcname
|
||||||
and not isinstance(v, _UserVariable)}
|
and not isinstance(v, (_UserVariable, ast.AST))}
|
||||||
|
|
||||||
_embeddable_calls = {
|
_embeddable_calls = {
|
||||||
units.Quantity,
|
units.Quantity,
|
||||||
|
@ -122,21 +125,39 @@ class _ReferenceReplacer(ast.NodeTransformer):
|
||||||
calldict.update(self.module.__dict__)
|
calldict.update(self.module.__dict__)
|
||||||
func = eval_ast(node.func, calldict)
|
func = eval_ast(node.func, calldict)
|
||||||
|
|
||||||
|
new_args = [self.visit(arg) for arg in node.args]
|
||||||
|
|
||||||
if func in _embeddable_calls:
|
if func in _embeddable_calls:
|
||||||
new_func = ast.Name(func.__name__, ast.Load())
|
new_func = ast.Name(func.__name__, ast.Load())
|
||||||
new_args = [self.visit(arg) for arg in node.args]
|
|
||||||
return ast.Call(func=new_func, args=new_args,
|
return ast.Call(func=new_func, args=new_args,
|
||||||
keywords=[], starargs=None, kwargs=None)
|
keywords=[], starargs=None, kwargs=None)
|
||||||
elif hasattr(func, "k_function_info"):
|
elif hasattr(func, "k_function_info"):
|
||||||
print(func.k_function_info)
|
args = [func.__self__] + new_args
|
||||||
# TODO: inline called kernel
|
inlined, _ = inline(func.k_function_info.k_function, args, dict(), self.rm)
|
||||||
return node
|
return inlined
|
||||||
else:
|
else:
|
||||||
args = [ast.Str("rpc"), ast.Num(self.rm.rpc_map[func])]
|
args = [ast.Str("rpc"), ast.Num(self.rm.rpc_map[func])]
|
||||||
args += [self.visit(arg) for arg in node.args]
|
args += new_args
|
||||||
return ast.Call(func=ast.Name("syscall", ast.Load()),
|
return ast.Call(func=ast.Name("syscall", ast.Load()),
|
||||||
args=args, keywords=[], starargs=None, kwargs=None)
|
args=args, keywords=[], starargs=None, kwargs=None)
|
||||||
|
|
||||||
|
def visit_Expr(self, node):
|
||||||
|
if isinstance(node.value, ast.Call):
|
||||||
|
r = self.visit_Call(node.value)
|
||||||
|
if isinstance(r, list):
|
||||||
|
return r
|
||||||
|
else:
|
||||||
|
node.value = r
|
||||||
|
return node
|
||||||
|
else:
|
||||||
|
self.generic_visit(node)
|
||||||
|
return node
|
||||||
|
|
||||||
|
def visit_FunctionDef(self, node):
|
||||||
|
node.decorator_list = []
|
||||||
|
self.generic_visit(node)
|
||||||
|
return node
|
||||||
|
|
||||||
class _ListReadOnlyParams(ast.NodeVisitor):
|
class _ListReadOnlyParams(ast.NodeVisitor):
|
||||||
def visit_FunctionDef(self, node):
|
def visit_FunctionDef(self, node):
|
||||||
if hasattr(self, "read_only_params"):
|
if hasattr(self, "read_only_params"):
|
||||||
|
@ -182,8 +203,7 @@ def inline(k_function, k_args, k_kwargs, rm=None):
|
||||||
obj = k_args[0]
|
obj = k_args[0]
|
||||||
funcname = funcdef.name
|
funcname = funcdef.name
|
||||||
rr = _ReferenceReplacer(rm, obj, funcname)
|
rr = _ReferenceReplacer(rm, obj, funcname)
|
||||||
for stmt in funcdef.body:
|
rr.visit(funcdef)
|
||||||
rr.visit(stmt)
|
|
||||||
|
|
||||||
funcdef.body[0:0] = param_init
|
funcdef.body[0:0] = param_init
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue