diff --git a/artiq/transforms/inline.py b/artiq/transforms/inline.py index abc5a06f1..c3151b54c 100644 --- a/artiq/transforms/inline.py +++ b/artiq/transforms/inline.py @@ -6,7 +6,7 @@ import ast import builtins from copy import deepcopy -from artiq.transforms.tools import eval_ast, value_to_ast +from artiq.transforms.tools import eval_ast, value_to_ast, NotASTRepresentable from artiq.language import core as core_language from artiq.language import units @@ -45,7 +45,7 @@ class _ReferenceManager: self.kernel_attr_init = [] # (id(obj), func_name, ref_name) or (id(obj), kernel_attr_name) - # -> _UserVariable(name) / ast / constant_object + # -> _UserVariable(name) / complex object self._to_inlined = dict() # inlined_name -> use_count self._use_count = dict() @@ -57,10 +57,15 @@ class _ReferenceManager: "range"): self._use_count[name] = 1 - # node_or_value can be a AST node, used to inline function parameter values - # that can be simplified later through constant folding. - def register_replace(self, obj, func_name, ref_name, node_or_value): - self._to_inlined[(id(obj), func_name, ref_name)] = node_or_value + # Complex objects in the namespace of functions can be used in two ways: + # 1. Calling a method on them (which gets inlined or RPCd) + # 2. Getting or setting attributes (which are turned into local variables) + # They are needed to implement "self", which is the only supported use + # case. + def register_complex_object(self, obj, func_name, ref_name, + complex_object): + assert(not isinstance(complex_object, ast.AST)) + self._to_inlined[(id(obj), func_name, ref_name)] = complex_object def new_name(self, base_name): if base_name[-1].isdigit(): @@ -112,7 +117,7 @@ class _ReferenceManager: def resolve_constant(self, obj, func_name, node): if isinstance(node, ast.Name): c = self.resolve_name(obj, func_name, node.id, False) - if isinstance(c, (_UserVariable, ast.AST)): + if isinstance(c, _UserVariable): raise ValueError("Not a constant") return c elif isinstance(node, ast.Attribute): @@ -192,18 +197,10 @@ class _ReferenceReplacer(ast.NodeVisitor): ival = self.rm.resolve_name(self.obj, self.func_name, node.id, store) if isinstance(ival, _UserVariable): newnode = ast.Name(ival.name, node.ctx) - elif isinstance(ival, ast.AST): - assert(not store) - newnode = deepcopy(ival) else: if store: - raise NotImplementedError( - "Cannot turn object into user variable") - else: - newnode = value_to_ast(ival) - if newnode is None: - raise NotImplementedError( - "Cannot represent inlined value") + raise NotImplementedError("Cannot assign to this object") + newnode = value_to_ast(ival) return ast.copy_location(newnode, node) def _resolve_attribute(self, node): @@ -325,40 +322,23 @@ class _ReferenceReplacer(ast.NodeVisitor): return node -class _ListReadOnlyParams(ast.NodeVisitor): - def visit_FunctionDef(self, node): - if hasattr(self, "read_only_params"): - raise ValueError("More than one function definition") - self.read_only_params = {arg.arg for arg in node.args.args} - self.generic_visit(node) - - def visit_Name(self, node): - if isinstance(node.ctx, ast.Store): - try: - self.read_only_params.remove(node.id) - except KeyError: - pass - - -def _list_read_only_params(func_def): - lrp = _ListReadOnlyParams() - lrp.visit(func_def) - return lrp.read_only_params - - def _initialize_function_params(func_def, k_args, k_kwargs, rm): obj = k_args[0] func_name = func_def.name param_init = [] - rop = _list_read_only_params(func_def) for arg_ast, arg_value in zip(func_def.args.args, k_args): - arg_name = arg_ast.arg - if arg_name in rop: - rm.register_replace(obj, func_name, arg_name, arg_value) + if isinstance(arg_value, ast.AST): + value = arg_value else: - uservar = rm.resolve_name(obj, func_name, arg_name, True) + try: + value = value_to_ast(arg_value) + except NotASTRepresentable: + value = None + if value is None: + rm.register_complex_object(obj, func_name, arg_ast.arg, arg_value) + else: + uservar = rm.resolve_name(obj, func_name, arg_ast.arg, True) target = ast.Name(uservar.name, ast.Store()) - value = value_to_ast(arg_value) param_init.append(ast.Assign(targets=[target], value=value)) return param_init diff --git a/artiq/transforms/tools.py b/artiq/transforms/tools.py index e6a8baa7e..8a7b4d08e 100644 --- a/artiq/transforms/tools.py +++ b/artiq/transforms/tools.py @@ -13,6 +13,10 @@ def eval_ast(expr, symdict=dict()): return eval(code, symdict) +class NotASTRepresentable(Exception): + pass + + def value_to_ast(value): if isinstance(value, core_language.int64): # must be before int return ast.Call( @@ -41,7 +45,7 @@ def value_to_ast(value): func=ast.Name("Quantity", ast.Load()), args=[value_to_ast(value.amount), ast.Str(value.unit)], keywords=[], starargs=None, kwargs=None) - return None + raise NotASTRepresentable(str(value)) class NotConstant(Exception):