transforms/inferencer: add support for user-defined context manager.

This commit is contained in:
whitequark 2016-01-05 00:11:03 +08:00
parent 5baf18ba0d
commit dfbf55fed2
9 changed files with 166 additions and 9 deletions

View File

@ -52,7 +52,7 @@ class FunctionResolver(algorithm.Visitor):
self.visit(node.body)
self.visit(node.orelse)
def visit_withitem(self, node):
def visit_withitemT(self, node):
self.visit(node.context_expr)
self.visit_in_assign(node.optional_vars)

View File

@ -42,6 +42,9 @@ class ForT(ast.For):
:ivar trip_interval: (:class:`iodelay.Expr`)
"""
class withitemT(ast.withitem):
_types = ("enter_type", "exit_type")
class SliceT(ast.Slice, commontyped):
pass

View File

@ -480,6 +480,14 @@ class ASTTypedRewriter(algorithm.Transformer):
else_loc=node.else_loc, else_colon_loc=node.else_colon_loc)
return node
def visit_withitem(self, node):
node = self.generic_visit(node)
node = asttyped.withitemT(
context_expr=node.context_expr, optional_vars=node.optional_vars,
enter_type=types.TVar(), exit_type=types.TVar(),
as_loc=node.as_loc, loc=node.loc)
return node
# Unsupported visitors
#
def visit_unsupported(self, node):

View File

@ -160,7 +160,7 @@ class Inferencer(algorithm.Visitor):
self._unify(result_type, attr_type,
loc, None)
else:
if attr_name_loc.source_buffer == value_node.loc.source_buffer:
if attr_loc.source_buffer == value_node.loc.source_buffer:
highlights, notes = [value_node.loc], []
else:
# This happens when the object being accessed is embedded
@ -173,7 +173,7 @@ class Inferencer(algorithm.Visitor):
diag = diagnostic.Diagnostic("error",
"type {type} does not have an attribute '{attr}'",
{"type": types.TypePrinter().name(object_type), "attr": attr_name},
node.attr_loc, highlights, notes)
attr_loc, highlights, notes)
self.engine.process(diag)
def _unify_iterable(self, element, collection):
@ -970,23 +970,104 @@ class Inferencer(algorithm.Visitor):
node.keyword_loc)
self.engine.process(diag)
def visit_withitem(self, node):
def visit_withitemT(self, node):
self.generic_visit(node)
typ = node.context_expr.type
if (types.is_builtin(typ, "parallel") or types.is_builtin(typ, "sequential") or
(isinstance(node.context_expr, asttyped.CallT) and
types.is_builtin(node.context_expr.func.type, "watchdog"))):
# builtin context managers
if node.optional_vars is not None:
self._unify(node.optional_vars.type, builtins.TNone(),
node.optional_vars.loc, None)
elif types.is_instance(typ) or types.is_constructor(typ):
# user-defined context managers
self._unify_attribute(result_type=node.enter_type, value_node=node.context_expr,
attr_name='__enter__', attr_loc=None, loc=node.loc)
self._unify_attribute(result_type=node.exit_type, value_node=node.context_expr,
attr_name='__exit__', attr_loc=None, loc=node.loc)
printer = types.TypePrinter()
def check_callback(attr_name, typ, arity):
if types.is_var(typ):
return
if not (types.is_method(typ) or types.is_function(typ)):
diag = diagnostic.Diagnostic("error",
"attribute '{attr}' of type {manager_type} must be a function",
{"attr": attr_name,
"manager_type": printer.name(node.context_expr.type)},
node.context_expr.loc)
self.engine.process(diag)
return
if types.is_method(typ):
typ = types.get_method_function(typ).find()
else:
typ = typ.find()
if not (len(typ.args) == arity and len(typ.optargs) == 0):
diag = diagnostic.Diagnostic("error",
"function '{attr}{attr_type}' must accept "
"{arity} positional argument{s} and no optional arguments",
{"attr": attr_name,
"attr_type": printer.name(typ),
"arity": arity, "s": "s" if arity > 1 else ""},
node.context_expr.loc)
self.engine.process(diag)
for formal_arg_name in list(typ.args)[1:]:
formal_arg_type = typ.args[formal_arg_name]
def makenotes(printer, typea, typeb, loca, locb):
return [
diagnostic.Diagnostic("note",
"exception handling via context managers is not supported; "
"the argument '{arg}' of function '{attr}{attr_type}' "
"will always be None",
{"arg": formal_arg_name,
"attr": attr_name,
"attr_type": printer.name(typ)},
loca),
]
self._unify(formal_arg_type, builtins.TNone(),
node.context_expr.loc, None,
makenotes=makenotes)
check_callback('__enter__', node.enter_type, 1)
check_callback('__exit__', node.exit_type, 4)
if node.optional_vars is not None:
if types.is_method(node.exit_type):
var_type = types.get_method_function(node.exit_type).find().ret
else:
var_type = node.exit_type.find().ret
def makenotes(printer, typea, typeb, loca, locb):
return [
diagnostic.Diagnostic("note",
"expression of type {typea}",
{"typea": printer.name(typea)},
loca),
diagnostic.Diagnostic("note",
"context manager with an '__enter__' method returning {typeb}",
{"typeb": printer.name(typeb)},
locb)
]
self._unify(node.optional_vars.type, var_type,
node.optional_vars.loc, node.context_expr.loc,
makenotes=makenotes)
else:
diag = diagnostic.Diagnostic("error",
"value of type {type} cannot act as a context manager",
{"type": types.TypePrinter().name(typ)},
node.context_expr.loc)
self.engine.process(diag)
if node.optional_vars is not None:
self._unify(node.optional_vars.type, node.context_expr.type,
node.optional_vars.loc, node.context_expr.loc)
def visit_With(self, node):
self.generic_visit(node)

View File

@ -0,0 +1,15 @@
# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t
# RUN: OutputCheck %s --file-to-check=%t
class contextmgr:
def __enter__(self, n1):
pass
def __exit__(self, n1, n2):
pass
def foo():
# CHECK-L: ${LINE:+2}: error: function '__enter__(self:<instance contextmgr {}>, n1:'a)->NoneType delay('b)' must accept 1 positional argument and no optional arguments
# CHECK-L: ${LINE:+1}: error: function '__exit__(self:<instance contextmgr>, n1:'c, n2:'d)->NoneType delay('e)' must accept 4 positional arguments and no optional arguments
with contextmgr():
pass

View File

@ -0,0 +1,16 @@
# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t
# RUN: OutputCheck %s --file-to-check=%t
class contextmgr:
def __enter__(self):
pass
def __exit__(self, n1, n2, n3):
n3 = 1
pass
def foo():
# CHECK-L: ${LINE:+2}: error: cannot unify int(width='a) with NoneType
# CHECK-L: ${LINE:+1}: note: exception handling via context managers is not supported; the argument 'n3' of function '__exit__(self:<instance contextmgr {}>, n1:NoneType, n2:NoneType, n3:int(width='a))->NoneType delay('b)' will always be None
with contextmgr():
pass

View File

@ -0,0 +1,17 @@
# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t
# RUN: OutputCheck %s --file-to-check=%t
class contextmgr:
def __enter__(self):
pass
def __exit__(self, n1, n2, n3):
pass
def foo():
contextmgr.__enter__(1)
# CHECK-L: ${LINE:+3}: error: cannot unify <instance contextmgr> with int(width='a) while inferring the type for self argument
# CHECK-L: ${LINE:+2}: note: expression of type <instance contextmgr {}>
# CHECK-L: ${LINE:+1}: note: reference to an instance with a method '__enter__(self:int(width='a))->NoneType delay('b)'
with contextmgr():
pass

View File

@ -0,0 +1,17 @@
# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t
# RUN: OutputCheck %s --file-to-check=%t
class contextmgr:
def __enter__(self):
return 1
def __exit__(self, n1, n2, n3):
pass
def foo():
x = "x"
# CHECK-L: ${LINE:+3}: error: cannot unify str with NoneType
# CHECK-L: ${LINE:+2}: note: expression of type str
# CHECK-L: ${LINE:+1}: note: context manager with an '__enter__' method returning NoneType
with contextmgr() as x:
pass

View File

@ -1,5 +1,5 @@
# RUN: %python -m artiq.compiler.testbench.inferencer %s >%t
# RUN: OutputCheck %s --file-to-check=%t
# CHECK-L: as x:<builtin parallel>
# CHECK-L: as x:NoneType
with parallel as x: pass