diff --git a/artiq/compiler/inline.py b/artiq/compiler/inline.py index 991861a8d..516a0447e 100644 --- a/artiq/compiler/inline.py +++ b/artiq/compiler/inline.py @@ -31,7 +31,7 @@ _UserVariable = namedtuple("_UserVariable", "name") class _ReferenceManager: def __init__(self): - # (id(obj), funcname, local) -> _UserVariable(name) / constant_object + # (id(obj), funcname, local) -> _UserVariable(name) / ast / constant_object self.to_inlined = dict() # inlined_name -> use_count self.use_count = dict() @@ -71,6 +71,9 @@ class _ReferenceManager: else: if isinstance(ival, _UserVariable): return ast.Name(ival.name, ref.ctx) + elif isinstance(ival, ast.AST): + assert(not store) + return ival else: if store: raise NotImplementedError("Cannot turn object into user variable") @@ -95,7 +98,7 @@ class _ReferenceManager: in self.to_inlined.items() if objid == id(r_obj) and funcname == r_funcname - and not isinstance(v, _UserVariable)} + and not isinstance(v, (_UserVariable, ast.AST))} _embeddable_calls = { units.Quantity, @@ -122,21 +125,39 @@ class _ReferenceReplacer(ast.NodeTransformer): calldict.update(self.module.__dict__) func = eval_ast(node.func, calldict) + new_args = [self.visit(arg) for arg in node.args] + if func in _embeddable_calls: new_func = ast.Name(func.__name__, ast.Load()) - new_args = [self.visit(arg) for arg in node.args] return ast.Call(func=new_func, args=new_args, keywords=[], starargs=None, kwargs=None) elif hasattr(func, "k_function_info"): - print(func.k_function_info) - # TODO: inline called kernel - return node + args = [func.__self__] + new_args + inlined, _ = inline(func.k_function_info.k_function, args, dict(), self.rm) + return inlined else: args = [ast.Str("rpc"), ast.Num(self.rm.rpc_map[func])] - args += [self.visit(arg) for arg in node.args] + args += new_args return ast.Call(func=ast.Name("syscall", ast.Load()), args=args, keywords=[], starargs=None, kwargs=None) + def visit_Expr(self, node): + if isinstance(node.value, ast.Call): + r = self.visit_Call(node.value) + if isinstance(r, list): + return r + else: + node.value = r + return node + else: + self.generic_visit(node) + return node + + def visit_FunctionDef(self, node): + node.decorator_list = [] + self.generic_visit(node) + return node + class _ListReadOnlyParams(ast.NodeVisitor): def visit_FunctionDef(self, node): if hasattr(self, "read_only_params"): @@ -182,8 +203,7 @@ def inline(k_function, k_args, k_kwargs, rm=None): obj = k_args[0] funcname = funcdef.name rr = _ReferenceReplacer(rm, obj, funcname) - for stmt in funcdef.body: - rr.visit(stmt) + rr.visit(funcdef) funcdef.body[0:0] = param_init