forked from M-Labs/artiq
inline: basic function inlining
This commit is contained in:
parent
d87b207b8a
commit
08ab99d33e
|
@ -31,7 +31,7 @@ _UserVariable = namedtuple("_UserVariable", "name")
|
|||
|
||||
class _ReferenceManager:
|
||||
def __init__(self):
|
||||
# (id(obj), funcname, local) -> _UserVariable(name) / constant_object
|
||||
# (id(obj), funcname, local) -> _UserVariable(name) / ast / constant_object
|
||||
self.to_inlined = dict()
|
||||
# inlined_name -> use_count
|
||||
self.use_count = dict()
|
||||
|
@ -71,6 +71,9 @@ class _ReferenceManager:
|
|||
else:
|
||||
if isinstance(ival, _UserVariable):
|
||||
return ast.Name(ival.name, ref.ctx)
|
||||
elif isinstance(ival, ast.AST):
|
||||
assert(not store)
|
||||
return ival
|
||||
else:
|
||||
if store:
|
||||
raise NotImplementedError("Cannot turn object into user variable")
|
||||
|
@ -95,7 +98,7 @@ class _ReferenceManager:
|
|||
in self.to_inlined.items()
|
||||
if objid == id(r_obj)
|
||||
and funcname == r_funcname
|
||||
and not isinstance(v, _UserVariable)}
|
||||
and not isinstance(v, (_UserVariable, ast.AST))}
|
||||
|
||||
_embeddable_calls = {
|
||||
units.Quantity,
|
||||
|
@ -122,21 +125,39 @@ class _ReferenceReplacer(ast.NodeTransformer):
|
|||
calldict.update(self.module.__dict__)
|
||||
func = eval_ast(node.func, calldict)
|
||||
|
||||
new_args = [self.visit(arg) for arg in node.args]
|
||||
|
||||
if func in _embeddable_calls:
|
||||
new_func = ast.Name(func.__name__, ast.Load())
|
||||
new_args = [self.visit(arg) for arg in node.args]
|
||||
return ast.Call(func=new_func, args=new_args,
|
||||
keywords=[], starargs=None, kwargs=None)
|
||||
elif hasattr(func, "k_function_info"):
|
||||
print(func.k_function_info)
|
||||
# TODO: inline called kernel
|
||||
return node
|
||||
args = [func.__self__] + new_args
|
||||
inlined, _ = inline(func.k_function_info.k_function, args, dict(), self.rm)
|
||||
return inlined
|
||||
else:
|
||||
args = [ast.Str("rpc"), ast.Num(self.rm.rpc_map[func])]
|
||||
args += [self.visit(arg) for arg in node.args]
|
||||
args += new_args
|
||||
return ast.Call(func=ast.Name("syscall", ast.Load()),
|
||||
args=args, keywords=[], starargs=None, kwargs=None)
|
||||
|
||||
def visit_Expr(self, node):
|
||||
if isinstance(node.value, ast.Call):
|
||||
r = self.visit_Call(node.value)
|
||||
if isinstance(r, list):
|
||||
return r
|
||||
else:
|
||||
node.value = r
|
||||
return node
|
||||
else:
|
||||
self.generic_visit(node)
|
||||
return node
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
node.decorator_list = []
|
||||
self.generic_visit(node)
|
||||
return node
|
||||
|
||||
class _ListReadOnlyParams(ast.NodeVisitor):
|
||||
def visit_FunctionDef(self, node):
|
||||
if hasattr(self, "read_only_params"):
|
||||
|
@ -182,8 +203,7 @@ def inline(k_function, k_args, k_kwargs, rm=None):
|
|||
obj = k_args[0]
|
||||
funcname = funcdef.name
|
||||
rr = _ReferenceReplacer(rm, obj, funcname)
|
||||
for stmt in funcdef.body:
|
||||
rr.visit(stmt)
|
||||
rr.visit(funcdef)
|
||||
|
||||
funcdef.body[0:0] = param_init
|
||||
|
||||
|
|
Loading…
Reference in New Issue