diff --git a/artiq/transforms/inline.py b/artiq/transforms/inline.py index 0804578d3..dd67e0c69 100644 --- a/artiq/transforms/inline.py +++ b/artiq/transforms/inline.py @@ -3,7 +3,6 @@ import textwrap import ast import types import builtins -from copy import copy from fractions import Fraction from collections import OrderedDict from functools import partial @@ -199,8 +198,10 @@ class Function: # This is ast.NodeTransformer.generic_visit from CPython, modified # to update self._insertion_point. - def code_generic_visit(self, node): + def code_generic_visit(self, node, exclude_fields=set()): for field, old_value in ast.iter_fields(node): + if field in exclude_fields: + continue old_value = getattr(node, field, None) if isinstance(old_value, list): prev_insertion_point = self._insertion_point @@ -378,6 +379,60 @@ class Function: self.code_generic_visit(node) return node + def get_user_ctxm(self, context_expr): + try: + ctxm = self.static_visit(context_expr) + except: + # this also catches watchdog() + return None + else: + if (ctxm is core_language.sequential + or ctxm is core_language.parallel): + return None + return ctxm + + def code_visit_With(self, node): + if len(node.items) != 1: + raise NotImplementedError + item = node.items[0] + if item.optional_vars is not None: + raise NotImplementedError + ctxm = self.get_user_ctxm(item.context_expr) + if ctxm is None: + self.code_generic_visit(node) + return node + + # user context manager + self.code_generic_visit(node, {"items"}) + if (not hasattr(ctxm, "__enter__") + or not hasattr(ctxm.__enter__, "k_function_info")): + raise NotImplementedError + enter = get_inline(self.core, + self.attribute_namespace, self.in_use_names, + None, self.mappers, + ctxm.__enter__.k_function_info.k_function, + [ctxm], dict()) + if (not hasattr(ctxm, "__exit__") + or not hasattr(ctxm.__exit__, "k_function_info")): + raise NotImplementedError + exit = get_inline(self.core, + self.attribute_namespace, self.in_use_names, + None, self.mappers, + ctxm.__exit__.k_function_info.k_function, + [ctxm, None, None, None], dict()) + try_stmt = ast.copy_location( + ast.Try(body=node.body, + handlers=[], + orelse=[], + finalbody=exit.body), node) + return ast.copy_location( + ast.With( + items=[ast.withitem(context_expr=ast.Name(id="sequential", + ctx=ast.Load()), + optional_vars=None)], + body=enter.body + [try_stmt]), + node) + def code_visit_FunctionDef(self, node): node.args = ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]) @@ -470,7 +525,9 @@ def get_attr_writeback(attribute_namespace, rpc_mapper, loc_node): def inline(core, k_function, k_args, k_kwargs, with_attr_writeback): # OrderedDict prevents non-determinism in attribute init attribute_namespace = OrderedDict() - in_use_names = copy(embeddable_func_names) + # NOTE: in_use_names will be mutated. Do not mutate embeddable_func_names! + in_use_names = embeddable_func_names | {"sequential", "parallel", + "watchdog"} mappers = types.SimpleNamespace( rpc=HostObjectMapper(), exception=HostObjectMapper(core_language.first_user_eid)