diff --git a/artiq/transforms/inline.py b/artiq/transforms/inline.py index ed9fab890..4c3b661e5 100644 --- a/artiq/transforms/inline.py +++ b/artiq/transforms/inline.py @@ -33,71 +33,91 @@ class _HostObjectMapper: _UserVariable = namedtuple("_UserVariable", "name") +def _is_kernel_attr(value, attr): + return hasattr(value, "kernel_attr") and attr in value.kernel_attr.split() + + class _ReferenceManager: def __init__(self): - # (id(obj), func_name, local_name) or (id(obj), kernel_attr_name) - # -> _UserVariable(name) / ast / constant_object - self.to_inlined = dict() - # inlined_name -> use_count - self.use_count = dict() self.rpc_mapper = _HostObjectMapper() self.exception_mapper = _HostObjectMapper(core_language.first_user_eid) self.kernel_attr_init = [] + # (id(obj), func_name, ref_name) or (id(obj), kernel_attr_name) + # -> _UserVariable(name) / ast / constant_object + self._to_inlined = dict() + # inlined_name -> use_count + self._use_count = dict() # reserved names for kg in core_language.kernel_globals: - self.use_count[kg] = 1 + self._use_count[kg] = 1 for name in ("int", "round", "int64", "round64", "float", "array", "range", "Fraction", "Quantity", "EncodedException"): - self.use_count[name] = 1 + self._use_count[name] = 1 + + # node_or_value can be a AST node, used to inline function parameter values + # that can be simplified later through constant folding. + def register_replace(self, obj, func_name, ref_name, node_or_value): + self._to_inlined[(id(obj), func_name, ref_name)] = node_or_value def new_name(self, base_name): if base_name[-1].isdigit(): base_name += "_" - if base_name in self.use_count: - r = base_name + str(self.use_count[base_name]) - self.use_count[base_name] += 1 + if base_name in self._use_count: + r = base_name + str(self._use_count[base_name]) + self._use_count[base_name] += 1 return r else: - self.use_count[base_name] = 1 + self._use_count[base_name] = 1 return base_name - def get(self, obj, func_name, ref): - if isinstance(ref, ast.Name): - key = (id(obj), func_name, ref.id) - try: - return self.to_inlined[key] - except KeyError: - if isinstance(ref.ctx, ast.Store): - ival = _UserVariable(self.new_name(ref.id)) - self.to_inlined[key] = ival - return ival - else: - try: - return inspect.getmodule(obj).__dict__[ref.id] - except KeyError: - return getattr(builtins, ref.id) - elif isinstance(ref, ast.Attribute): - target = self.get(obj, func_name, ref.value) - if hasattr(target, "kernel_attr") and ref.attr in target.kernel_attr.split(): - key = (id(target), ref.attr) - try: - ival = self.to_inlined[key] - assert(isinstance(ival, _UserVariable)) - except KeyError: - iname = self.new_name(ref.attr) - ival = _UserVariable(iname) - self.to_inlined[key] = ival - a = value_to_ast(getattr(target, ref.attr)) - if a is None: - raise NotImplementedError( - "Cannot represent initial value" - " of kernel attribute") - self.kernel_attr_init.append(ast.Assign( - [ast.Name(iname, ast.Store())], a)) + def resolve_name(self, obj, func_name, ref_name, store): + key = (id(obj), func_name, ref_name) + try: + return self._to_inlined[key] + except KeyError: + if store: + ival = _UserVariable(self.new_name(ref_name)) + self._to_inlined[key] = ival return ival else: - return getattr(target, ref.attr) + try: + return inspect.getmodule(obj).__dict__[ref_name] + except KeyError: + return getattr(builtins, ref_name) + + def resolve_attr(self, value, attr): + if _is_kernel_attr(value, attr): + key = (id(value), attr) + try: + ival = self._to_inlined[key] + assert(isinstance(ival, _UserVariable)) + except KeyError: + iname = self.new_name(attr) + ival = _UserVariable(iname) + self._to_inlined[key] = ival + a = value_to_ast(getattr(value, attr)) + if a is None: + raise NotImplementedError( + "Cannot represent initial value" + " of kernel attribute") + self.kernel_attr_init.append(ast.Assign( + [ast.Name(iname, ast.Store())], a)) + return ival + else: + return getattr(value, attr) + + def resolve_constant(self, obj, func_name, node): + if isinstance(node, ast.Name): + c = self.resolve_name(obj, func_name, node.id, False) + if isinstance(c, (_UserVariable, ast.AST)): + raise ValueError("Not a constant") + return c + elif isinstance(node, ast.Attribute): + value = self.resolve_constant(obj, func_name, node.value) + if _is_kernel_attr(value, node.attr): + raise ValueError("Not a constant") + return getattr(value, node.attr) else: raise NotImplementedError @@ -156,9 +176,9 @@ class _ReferenceReplacer(ast.NodeVisitor): setattr(node, field, new_node) return node - def visit_ref(self, node): + def visit_Name(self, node): store = isinstance(node.ctx, ast.Store) - ival = self.rm.get(self.obj, self.func_name, node) + ival = self.rm.resolve_name(self.obj, self.func_name, node.id, store) if isinstance(ival, _UserVariable): newnode = ast.Name(ival.name, node.ctx) elif isinstance(ival, ast.AST): @@ -175,11 +195,34 @@ class _ReferenceReplacer(ast.NodeVisitor): "Cannot represent inlined value") return ast.copy_location(newnode, node) - visit_Name = visit_ref - visit_Attribute = visit_ref + def _resolve_attribute(self, node): + if isinstance(node, ast.Name): + ival = self.rm.resolve_name(self.obj, self.func_name, node.id, False) + if isinstance(ival, _UserVariable): + return ast.copy_location(ast.Name(ival.name, ast.Load()), node) + else: + return ival + elif isinstance(node, ast.Attribute): + value = self._resolve_attribute(node.value) + if isinstance(value, ast.AST): + node.value = value + return node + else: + return self.rm.resolve_attr(value, node.attr) + else: + return self.visit(node) + + def visit_Attribute(self, node): + ival = self._resolve_attribute(node) + if isinstance(ival, ast.AST): + return ival + elif isinstance(ival, _UserVariable): + return ast.copy_location(ast.Name(ival.name, ast.Load()), node) + else: + return value_to_ast(ival) def visit_Call(self, node): - func = self.rm.get(self.obj, self.func_name, node.func) + func = self.rm.resolve_constant(self.obj, self.func_name, node.func) new_args = [self.visit(arg) for arg in node.args] if _is_embeddable(func): @@ -240,7 +283,7 @@ class _ReferenceReplacer(ast.NodeVisitor): return node def _encode_exception(self, e): - exception_class = self.rm.get(self.obj, self.func_name, e) + exception_class = self.rm.resolve_constant(self.obj, self.func_name, e) if not inspect.isclass(exception_class): raise NotImplementedError("Exception type must be a class") if issubclass(exception_class, core_language.RuntimeException): @@ -301,9 +344,10 @@ def _initialize_function_params(func_def, k_args, k_kwargs, rm): for arg_ast, arg_value in zip(func_def.args.args, k_args): arg_name = arg_ast.arg if arg_name in rop: - rm.to_inlined[(id(obj), func_name, arg_name)] = arg_value + rm.register_replace(obj, func_name, arg_name, arg_value) else: - target = rm.get(obj, func_name, ast.Name(arg_name, ast.Store())) + uservar = rm.resolve_name(obj, func_name, arg_name, True) + target = ast.Name(uservar.name, ast.Store()) value = value_to_ast(arg_value) param_init.append(ast.Assign(targets=[target], value=value)) return param_init