transforms/inline: support user-defined context managers

This commit is contained in:
Sebastien Bourdeauducq 2015-05-09 14:47:08 +08:00
parent 5c08423b29
commit fdc406f062
1 changed files with 60 additions and 3 deletions

View File

@ -3,7 +3,6 @@ import textwrap
import ast import ast
import types import types
import builtins import builtins
from copy import copy
from fractions import Fraction from fractions import Fraction
from collections import OrderedDict from collections import OrderedDict
from functools import partial from functools import partial
@ -199,8 +198,10 @@ class Function:
# This is ast.NodeTransformer.generic_visit from CPython, modified # This is ast.NodeTransformer.generic_visit from CPython, modified
# to update self._insertion_point. # to update self._insertion_point.
def code_generic_visit(self, node): def code_generic_visit(self, node, exclude_fields=set()):
for field, old_value in ast.iter_fields(node): for field, old_value in ast.iter_fields(node):
if field in exclude_fields:
continue
old_value = getattr(node, field, None) old_value = getattr(node, field, None)
if isinstance(old_value, list): if isinstance(old_value, list):
prev_insertion_point = self._insertion_point prev_insertion_point = self._insertion_point
@ -378,6 +379,60 @@ class Function:
self.code_generic_visit(node) self.code_generic_visit(node)
return node return node
def get_user_ctxm(self, context_expr):
try:
ctxm = self.static_visit(context_expr)
except:
# this also catches watchdog()
return None
else:
if (ctxm is core_language.sequential
or ctxm is core_language.parallel):
return None
return ctxm
def code_visit_With(self, node):
if len(node.items) != 1:
raise NotImplementedError
item = node.items[0]
if item.optional_vars is not None:
raise NotImplementedError
ctxm = self.get_user_ctxm(item.context_expr)
if ctxm is None:
self.code_generic_visit(node)
return node
# user context manager
self.code_generic_visit(node, {"items"})
if (not hasattr(ctxm, "__enter__")
or not hasattr(ctxm.__enter__, "k_function_info")):
raise NotImplementedError
enter = get_inline(self.core,
self.attribute_namespace, self.in_use_names,
None, self.mappers,
ctxm.__enter__.k_function_info.k_function,
[ctxm], dict())
if (not hasattr(ctxm, "__exit__")
or not hasattr(ctxm.__exit__, "k_function_info")):
raise NotImplementedError
exit = get_inline(self.core,
self.attribute_namespace, self.in_use_names,
None, self.mappers,
ctxm.__exit__.k_function_info.k_function,
[ctxm, None, None, None], dict())
try_stmt = ast.copy_location(
ast.Try(body=node.body,
handlers=[],
orelse=[],
finalbody=exit.body), node)
return ast.copy_location(
ast.With(
items=[ast.withitem(context_expr=ast.Name(id="sequential",
ctx=ast.Load()),
optional_vars=None)],
body=enter.body + [try_stmt]),
node)
def code_visit_FunctionDef(self, node): def code_visit_FunctionDef(self, node):
node.args = ast.arguments(args=[], vararg=None, kwonlyargs=[], node.args = ast.arguments(args=[], vararg=None, kwonlyargs=[],
kw_defaults=[], kwarg=None, defaults=[]) kw_defaults=[], kwarg=None, defaults=[])
@ -470,7 +525,9 @@ def get_attr_writeback(attribute_namespace, rpc_mapper, loc_node):
def inline(core, k_function, k_args, k_kwargs, with_attr_writeback): def inline(core, k_function, k_args, k_kwargs, with_attr_writeback):
# OrderedDict prevents non-determinism in attribute init # OrderedDict prevents non-determinism in attribute init
attribute_namespace = OrderedDict() attribute_namespace = OrderedDict()
in_use_names = copy(embeddable_func_names) # NOTE: in_use_names will be mutated. Do not mutate embeddable_func_names!
in_use_names = embeddable_func_names | {"sequential", "parallel",
"watchdog"}
mappers = types.SimpleNamespace( mappers = types.SimpleNamespace(
rpc=HostObjectMapper(), rpc=HostObjectMapper(),
exception=HostObjectMapper(core_language.first_user_eid) exception=HostObjectMapper(core_language.first_user_eid)