forked from M-Labs/artiq
transforms/inline: support user-defined context managers
This commit is contained in:
parent
5c08423b29
commit
fdc406f062
|
@ -3,7 +3,6 @@ import textwrap
|
||||||
import ast
|
import ast
|
||||||
import types
|
import types
|
||||||
import builtins
|
import builtins
|
||||||
from copy import copy
|
|
||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
@ -199,8 +198,10 @@ class Function:
|
||||||
|
|
||||||
# This is ast.NodeTransformer.generic_visit from CPython, modified
|
# This is ast.NodeTransformer.generic_visit from CPython, modified
|
||||||
# to update self._insertion_point.
|
# to update self._insertion_point.
|
||||||
def code_generic_visit(self, node):
|
def code_generic_visit(self, node, exclude_fields=set()):
|
||||||
for field, old_value in ast.iter_fields(node):
|
for field, old_value in ast.iter_fields(node):
|
||||||
|
if field in exclude_fields:
|
||||||
|
continue
|
||||||
old_value = getattr(node, field, None)
|
old_value = getattr(node, field, None)
|
||||||
if isinstance(old_value, list):
|
if isinstance(old_value, list):
|
||||||
prev_insertion_point = self._insertion_point
|
prev_insertion_point = self._insertion_point
|
||||||
|
@ -378,6 +379,60 @@ class Function:
|
||||||
self.code_generic_visit(node)
|
self.code_generic_visit(node)
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
def get_user_ctxm(self, context_expr):
|
||||||
|
try:
|
||||||
|
ctxm = self.static_visit(context_expr)
|
||||||
|
except:
|
||||||
|
# this also catches watchdog()
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
if (ctxm is core_language.sequential
|
||||||
|
or ctxm is core_language.parallel):
|
||||||
|
return None
|
||||||
|
return ctxm
|
||||||
|
|
||||||
|
def code_visit_With(self, node):
|
||||||
|
if len(node.items) != 1:
|
||||||
|
raise NotImplementedError
|
||||||
|
item = node.items[0]
|
||||||
|
if item.optional_vars is not None:
|
||||||
|
raise NotImplementedError
|
||||||
|
ctxm = self.get_user_ctxm(item.context_expr)
|
||||||
|
if ctxm is None:
|
||||||
|
self.code_generic_visit(node)
|
||||||
|
return node
|
||||||
|
|
||||||
|
# user context manager
|
||||||
|
self.code_generic_visit(node, {"items"})
|
||||||
|
if (not hasattr(ctxm, "__enter__")
|
||||||
|
or not hasattr(ctxm.__enter__, "k_function_info")):
|
||||||
|
raise NotImplementedError
|
||||||
|
enter = get_inline(self.core,
|
||||||
|
self.attribute_namespace, self.in_use_names,
|
||||||
|
None, self.mappers,
|
||||||
|
ctxm.__enter__.k_function_info.k_function,
|
||||||
|
[ctxm], dict())
|
||||||
|
if (not hasattr(ctxm, "__exit__")
|
||||||
|
or not hasattr(ctxm.__exit__, "k_function_info")):
|
||||||
|
raise NotImplementedError
|
||||||
|
exit = get_inline(self.core,
|
||||||
|
self.attribute_namespace, self.in_use_names,
|
||||||
|
None, self.mappers,
|
||||||
|
ctxm.__exit__.k_function_info.k_function,
|
||||||
|
[ctxm, None, None, None], dict())
|
||||||
|
try_stmt = ast.copy_location(
|
||||||
|
ast.Try(body=node.body,
|
||||||
|
handlers=[],
|
||||||
|
orelse=[],
|
||||||
|
finalbody=exit.body), node)
|
||||||
|
return ast.copy_location(
|
||||||
|
ast.With(
|
||||||
|
items=[ast.withitem(context_expr=ast.Name(id="sequential",
|
||||||
|
ctx=ast.Load()),
|
||||||
|
optional_vars=None)],
|
||||||
|
body=enter.body + [try_stmt]),
|
||||||
|
node)
|
||||||
|
|
||||||
def code_visit_FunctionDef(self, node):
|
def code_visit_FunctionDef(self, node):
|
||||||
node.args = ast.arguments(args=[], vararg=None, kwonlyargs=[],
|
node.args = ast.arguments(args=[], vararg=None, kwonlyargs=[],
|
||||||
kw_defaults=[], kwarg=None, defaults=[])
|
kw_defaults=[], kwarg=None, defaults=[])
|
||||||
|
@ -470,7 +525,9 @@ def get_attr_writeback(attribute_namespace, rpc_mapper, loc_node):
|
||||||
def inline(core, k_function, k_args, k_kwargs, with_attr_writeback):
|
def inline(core, k_function, k_args, k_kwargs, with_attr_writeback):
|
||||||
# OrderedDict prevents non-determinism in attribute init
|
# OrderedDict prevents non-determinism in attribute init
|
||||||
attribute_namespace = OrderedDict()
|
attribute_namespace = OrderedDict()
|
||||||
in_use_names = copy(embeddable_func_names)
|
# NOTE: in_use_names will be mutated. Do not mutate embeddable_func_names!
|
||||||
|
in_use_names = embeddable_func_names | {"sequential", "parallel",
|
||||||
|
"watchdog"}
|
||||||
mappers = types.SimpleNamespace(
|
mappers = types.SimpleNamespace(
|
||||||
rpc=HostObjectMapper(),
|
rpc=HostObjectMapper(),
|
||||||
exception=HostObjectMapper(core_language.first_user_eid)
|
exception=HostObjectMapper(core_language.first_user_eid)
|
||||||
|
|
Loading…
Reference in New Issue