transforms/inline: support user-defined context managers

This commit is contained in:
Sebastien Bourdeauducq 2015-05-09 14:47:08 +08:00
parent 5c08423b29
commit fdc406f062
1 changed files with 60 additions and 3 deletions

View File

@ -3,7 +3,6 @@ import textwrap
import ast
import types
import builtins
from copy import copy
from fractions import Fraction
from collections import OrderedDict
from functools import partial
@ -199,8 +198,10 @@ class Function:
# This is ast.NodeTransformer.generic_visit from CPython, modified
# to update self._insertion_point.
def code_generic_visit(self, node):
def code_generic_visit(self, node, exclude_fields=set()):
for field, old_value in ast.iter_fields(node):
if field in exclude_fields:
continue
old_value = getattr(node, field, None)
if isinstance(old_value, list):
prev_insertion_point = self._insertion_point
@ -378,6 +379,60 @@ class Function:
self.code_generic_visit(node)
return node
def get_user_ctxm(self, context_expr):
try:
ctxm = self.static_visit(context_expr)
except:
# this also catches watchdog()
return None
else:
if (ctxm is core_language.sequential
or ctxm is core_language.parallel):
return None
return ctxm
def code_visit_With(self, node):
if len(node.items) != 1:
raise NotImplementedError
item = node.items[0]
if item.optional_vars is not None:
raise NotImplementedError
ctxm = self.get_user_ctxm(item.context_expr)
if ctxm is None:
self.code_generic_visit(node)
return node
# user context manager
self.code_generic_visit(node, {"items"})
if (not hasattr(ctxm, "__enter__")
or not hasattr(ctxm.__enter__, "k_function_info")):
raise NotImplementedError
enter = get_inline(self.core,
self.attribute_namespace, self.in_use_names,
None, self.mappers,
ctxm.__enter__.k_function_info.k_function,
[ctxm], dict())
if (not hasattr(ctxm, "__exit__")
or not hasattr(ctxm.__exit__, "k_function_info")):
raise NotImplementedError
exit = get_inline(self.core,
self.attribute_namespace, self.in_use_names,
None, self.mappers,
ctxm.__exit__.k_function_info.k_function,
[ctxm, None, None, None], dict())
try_stmt = ast.copy_location(
ast.Try(body=node.body,
handlers=[],
orelse=[],
finalbody=exit.body), node)
return ast.copy_location(
ast.With(
items=[ast.withitem(context_expr=ast.Name(id="sequential",
ctx=ast.Load()),
optional_vars=None)],
body=enter.body + [try_stmt]),
node)
def code_visit_FunctionDef(self, node):
node.args = ast.arguments(args=[], vararg=None, kwonlyargs=[],
kw_defaults=[], kwarg=None, defaults=[])
@ -470,7 +525,9 @@ def get_attr_writeback(attribute_namespace, rpc_mapper, loc_node):
def inline(core, k_function, k_args, k_kwargs, with_attr_writeback):
# OrderedDict prevents non-determinism in attribute init
attribute_namespace = OrderedDict()
in_use_names = copy(embeddable_func_names)
# NOTE: in_use_names will be mutated. Do not mutate embeddable_func_names!
in_use_names = embeddable_func_names | {"sequential", "parallel",
"watchdog"}
mappers = types.SimpleNamespace(
rpc=HostObjectMapper(),
exception=HostObjectMapper(core_language.first_user_eid)