inline: basic function inlining

pull/231/head
Sebastien Bourdeauducq 2014-06-17 18:37:51 +02:00
parent d87b207b8a
commit 08ab99d33e
1 changed files with 29 additions and 9 deletions

View File

@ -31,7 +31,7 @@ _UserVariable = namedtuple("_UserVariable", "name")
class _ReferenceManager:
def __init__(self):
# (id(obj), funcname, local) -> _UserVariable(name) / constant_object
# (id(obj), funcname, local) -> _UserVariable(name) / ast / constant_object
self.to_inlined = dict()
# inlined_name -> use_count
self.use_count = dict()
@ -71,6 +71,9 @@ class _ReferenceManager:
else:
if isinstance(ival, _UserVariable):
return ast.Name(ival.name, ref.ctx)
elif isinstance(ival, ast.AST):
assert(not store)
return ival
else:
if store:
raise NotImplementedError("Cannot turn object into user variable")
@ -95,7 +98,7 @@ class _ReferenceManager:
in self.to_inlined.items()
if objid == id(r_obj)
and funcname == r_funcname
and not isinstance(v, _UserVariable)}
and not isinstance(v, (_UserVariable, ast.AST))}
_embeddable_calls = {
units.Quantity,
@ -122,21 +125,39 @@ class _ReferenceReplacer(ast.NodeTransformer):
calldict.update(self.module.__dict__)
func = eval_ast(node.func, calldict)
new_args = [self.visit(arg) for arg in node.args]
if func in _embeddable_calls:
new_func = ast.Name(func.__name__, ast.Load())
new_args = [self.visit(arg) for arg in node.args]
return ast.Call(func=new_func, args=new_args,
keywords=[], starargs=None, kwargs=None)
elif hasattr(func, "k_function_info"):
print(func.k_function_info)
# TODO: inline called kernel
return node
args = [func.__self__] + new_args
inlined, _ = inline(func.k_function_info.k_function, args, dict(), self.rm)
return inlined
else:
args = [ast.Str("rpc"), ast.Num(self.rm.rpc_map[func])]
args += [self.visit(arg) for arg in node.args]
args += new_args
return ast.Call(func=ast.Name("syscall", ast.Load()),
args=args, keywords=[], starargs=None, kwargs=None)
def visit_Expr(self, node):
if isinstance(node.value, ast.Call):
r = self.visit_Call(node.value)
if isinstance(r, list):
return r
else:
node.value = r
return node
else:
self.generic_visit(node)
return node
def visit_FunctionDef(self, node):
node.decorator_list = []
self.generic_visit(node)
return node
class _ListReadOnlyParams(ast.NodeVisitor):
def visit_FunctionDef(self, node):
if hasattr(self, "read_only_params"):
@ -182,8 +203,7 @@ def inline(k_function, k_args, k_kwargs, rm=None):
obj = k_args[0]
funcname = funcdef.name
rr = _ReferenceReplacer(rm, obj, funcname)
for stmt in funcdef.body:
rr.visit(stmt)
rr.visit(funcdef)
funcdef.body[0:0] = param_init