forked from M-Labs/artiq
transforms/inline: support user-defined context managers
This commit is contained in:
parent
5c08423b29
commit
fdc406f062
|
@ -3,7 +3,6 @@ import textwrap
|
|||
import ast
|
||||
import types
|
||||
import builtins
|
||||
from copy import copy
|
||||
from fractions import Fraction
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
|
@ -199,8 +198,10 @@ class Function:
|
|||
|
||||
# This is ast.NodeTransformer.generic_visit from CPython, modified
|
||||
# to update self._insertion_point.
|
||||
def code_generic_visit(self, node):
|
||||
def code_generic_visit(self, node, exclude_fields=set()):
|
||||
for field, old_value in ast.iter_fields(node):
|
||||
if field in exclude_fields:
|
||||
continue
|
||||
old_value = getattr(node, field, None)
|
||||
if isinstance(old_value, list):
|
||||
prev_insertion_point = self._insertion_point
|
||||
|
@ -378,6 +379,60 @@ class Function:
|
|||
self.code_generic_visit(node)
|
||||
return node
|
||||
|
||||
def get_user_ctxm(self, context_expr):
|
||||
try:
|
||||
ctxm = self.static_visit(context_expr)
|
||||
except:
|
||||
# this also catches watchdog()
|
||||
return None
|
||||
else:
|
||||
if (ctxm is core_language.sequential
|
||||
or ctxm is core_language.parallel):
|
||||
return None
|
||||
return ctxm
|
||||
|
||||
def code_visit_With(self, node):
|
||||
if len(node.items) != 1:
|
||||
raise NotImplementedError
|
||||
item = node.items[0]
|
||||
if item.optional_vars is not None:
|
||||
raise NotImplementedError
|
||||
ctxm = self.get_user_ctxm(item.context_expr)
|
||||
if ctxm is None:
|
||||
self.code_generic_visit(node)
|
||||
return node
|
||||
|
||||
# user context manager
|
||||
self.code_generic_visit(node, {"items"})
|
||||
if (not hasattr(ctxm, "__enter__")
|
||||
or not hasattr(ctxm.__enter__, "k_function_info")):
|
||||
raise NotImplementedError
|
||||
enter = get_inline(self.core,
|
||||
self.attribute_namespace, self.in_use_names,
|
||||
None, self.mappers,
|
||||
ctxm.__enter__.k_function_info.k_function,
|
||||
[ctxm], dict())
|
||||
if (not hasattr(ctxm, "__exit__")
|
||||
or not hasattr(ctxm.__exit__, "k_function_info")):
|
||||
raise NotImplementedError
|
||||
exit = get_inline(self.core,
|
||||
self.attribute_namespace, self.in_use_names,
|
||||
None, self.mappers,
|
||||
ctxm.__exit__.k_function_info.k_function,
|
||||
[ctxm, None, None, None], dict())
|
||||
try_stmt = ast.copy_location(
|
||||
ast.Try(body=node.body,
|
||||
handlers=[],
|
||||
orelse=[],
|
||||
finalbody=exit.body), node)
|
||||
return ast.copy_location(
|
||||
ast.With(
|
||||
items=[ast.withitem(context_expr=ast.Name(id="sequential",
|
||||
ctx=ast.Load()),
|
||||
optional_vars=None)],
|
||||
body=enter.body + [try_stmt]),
|
||||
node)
|
||||
|
||||
def code_visit_FunctionDef(self, node):
|
||||
node.args = ast.arguments(args=[], vararg=None, kwonlyargs=[],
|
||||
kw_defaults=[], kwarg=None, defaults=[])
|
||||
|
@ -470,7 +525,9 @@ def get_attr_writeback(attribute_namespace, rpc_mapper, loc_node):
|
|||
def inline(core, k_function, k_args, k_kwargs, with_attr_writeback):
|
||||
# OrderedDict prevents non-determinism in attribute init
|
||||
attribute_namespace = OrderedDict()
|
||||
in_use_names = copy(embeddable_func_names)
|
||||
# NOTE: in_use_names will be mutated. Do not mutate embeddable_func_names!
|
||||
in_use_names = embeddable_func_names | {"sequential", "parallel",
|
||||
"watchdog"}
|
||||
mappers = types.SimpleNamespace(
|
||||
rpc=HostObjectMapper(),
|
||||
exception=HostObjectMapper(core_language.first_user_eid)
|
||||
|
|
Loading…
Reference in New Issue