forked from M-Labs/artiq
transforms/inferencer: add support for user-defined context manager.
This commit is contained in:
parent
5baf18ba0d
commit
dfbf55fed2
|
@ -52,7 +52,7 @@ class FunctionResolver(algorithm.Visitor):
|
||||||
self.visit(node.body)
|
self.visit(node.body)
|
||||||
self.visit(node.orelse)
|
self.visit(node.orelse)
|
||||||
|
|
||||||
def visit_withitem(self, node):
|
def visit_withitemT(self, node):
|
||||||
self.visit(node.context_expr)
|
self.visit(node.context_expr)
|
||||||
self.visit_in_assign(node.optional_vars)
|
self.visit_in_assign(node.optional_vars)
|
||||||
|
|
||||||
|
|
|
@ -42,6 +42,9 @@ class ForT(ast.For):
|
||||||
:ivar trip_interval: (:class:`iodelay.Expr`)
|
:ivar trip_interval: (:class:`iodelay.Expr`)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
class withitemT(ast.withitem):
|
||||||
|
_types = ("enter_type", "exit_type")
|
||||||
|
|
||||||
class SliceT(ast.Slice, commontyped):
|
class SliceT(ast.Slice, commontyped):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -480,6 +480,14 @@ class ASTTypedRewriter(algorithm.Transformer):
|
||||||
else_loc=node.else_loc, else_colon_loc=node.else_colon_loc)
|
else_loc=node.else_loc, else_colon_loc=node.else_colon_loc)
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
def visit_withitem(self, node):
|
||||||
|
node = self.generic_visit(node)
|
||||||
|
node = asttyped.withitemT(
|
||||||
|
context_expr=node.context_expr, optional_vars=node.optional_vars,
|
||||||
|
enter_type=types.TVar(), exit_type=types.TVar(),
|
||||||
|
as_loc=node.as_loc, loc=node.loc)
|
||||||
|
return node
|
||||||
|
|
||||||
# Unsupported visitors
|
# Unsupported visitors
|
||||||
#
|
#
|
||||||
def visit_unsupported(self, node):
|
def visit_unsupported(self, node):
|
||||||
|
|
|
@ -160,7 +160,7 @@ class Inferencer(algorithm.Visitor):
|
||||||
self._unify(result_type, attr_type,
|
self._unify(result_type, attr_type,
|
||||||
loc, None)
|
loc, None)
|
||||||
else:
|
else:
|
||||||
if attr_name_loc.source_buffer == value_node.loc.source_buffer:
|
if attr_loc.source_buffer == value_node.loc.source_buffer:
|
||||||
highlights, notes = [value_node.loc], []
|
highlights, notes = [value_node.loc], []
|
||||||
else:
|
else:
|
||||||
# This happens when the object being accessed is embedded
|
# This happens when the object being accessed is embedded
|
||||||
|
@ -173,7 +173,7 @@ class Inferencer(algorithm.Visitor):
|
||||||
diag = diagnostic.Diagnostic("error",
|
diag = diagnostic.Diagnostic("error",
|
||||||
"type {type} does not have an attribute '{attr}'",
|
"type {type} does not have an attribute '{attr}'",
|
||||||
{"type": types.TypePrinter().name(object_type), "attr": attr_name},
|
{"type": types.TypePrinter().name(object_type), "attr": attr_name},
|
||||||
node.attr_loc, highlights, notes)
|
attr_loc, highlights, notes)
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
|
|
||||||
def _unify_iterable(self, element, collection):
|
def _unify_iterable(self, element, collection):
|
||||||
|
@ -970,23 +970,104 @@ class Inferencer(algorithm.Visitor):
|
||||||
node.keyword_loc)
|
node.keyword_loc)
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
|
|
||||||
def visit_withitem(self, node):
|
def visit_withitemT(self, node):
|
||||||
self.generic_visit(node)
|
self.generic_visit(node)
|
||||||
|
|
||||||
typ = node.context_expr.type
|
typ = node.context_expr.type
|
||||||
if (types.is_builtin(typ, "parallel") or types.is_builtin(typ, "sequential") or
|
if (types.is_builtin(typ, "parallel") or types.is_builtin(typ, "sequential") or
|
||||||
(isinstance(node.context_expr, asttyped.CallT) and
|
(isinstance(node.context_expr, asttyped.CallT) and
|
||||||
types.is_builtin(node.context_expr.func.type, "watchdog"))):
|
types.is_builtin(node.context_expr.func.type, "watchdog"))):
|
||||||
|
# builtin context managers
|
||||||
|
if node.optional_vars is not None:
|
||||||
|
self._unify(node.optional_vars.type, builtins.TNone(),
|
||||||
|
node.optional_vars.loc, None)
|
||||||
|
elif types.is_instance(typ) or types.is_constructor(typ):
|
||||||
|
# user-defined context managers
|
||||||
|
self._unify_attribute(result_type=node.enter_type, value_node=node.context_expr,
|
||||||
|
attr_name='__enter__', attr_loc=None, loc=node.loc)
|
||||||
|
self._unify_attribute(result_type=node.exit_type, value_node=node.context_expr,
|
||||||
|
attr_name='__exit__', attr_loc=None, loc=node.loc)
|
||||||
|
|
||||||
|
printer = types.TypePrinter()
|
||||||
|
|
||||||
|
def check_callback(attr_name, typ, arity):
|
||||||
|
if types.is_var(typ):
|
||||||
|
return
|
||||||
|
|
||||||
|
if not (types.is_method(typ) or types.is_function(typ)):
|
||||||
|
diag = diagnostic.Diagnostic("error",
|
||||||
|
"attribute '{attr}' of type {manager_type} must be a function",
|
||||||
|
{"attr": attr_name,
|
||||||
|
"manager_type": printer.name(node.context_expr.type)},
|
||||||
|
node.context_expr.loc)
|
||||||
|
self.engine.process(diag)
|
||||||
|
return
|
||||||
|
|
||||||
|
if types.is_method(typ):
|
||||||
|
typ = types.get_method_function(typ).find()
|
||||||
|
else:
|
||||||
|
typ = typ.find()
|
||||||
|
|
||||||
|
if not (len(typ.args) == arity and len(typ.optargs) == 0):
|
||||||
|
diag = diagnostic.Diagnostic("error",
|
||||||
|
"function '{attr}{attr_type}' must accept "
|
||||||
|
"{arity} positional argument{s} and no optional arguments",
|
||||||
|
{"attr": attr_name,
|
||||||
|
"attr_type": printer.name(typ),
|
||||||
|
"arity": arity, "s": "s" if arity > 1 else ""},
|
||||||
|
node.context_expr.loc)
|
||||||
|
self.engine.process(diag)
|
||||||
|
|
||||||
|
for formal_arg_name in list(typ.args)[1:]:
|
||||||
|
formal_arg_type = typ.args[formal_arg_name]
|
||||||
|
def makenotes(printer, typea, typeb, loca, locb):
|
||||||
|
return [
|
||||||
|
diagnostic.Diagnostic("note",
|
||||||
|
"exception handling via context managers is not supported; "
|
||||||
|
"the argument '{arg}' of function '{attr}{attr_type}' "
|
||||||
|
"will always be None",
|
||||||
|
{"arg": formal_arg_name,
|
||||||
|
"attr": attr_name,
|
||||||
|
"attr_type": printer.name(typ)},
|
||||||
|
loca),
|
||||||
|
]
|
||||||
|
|
||||||
|
self._unify(formal_arg_type, builtins.TNone(),
|
||||||
|
node.context_expr.loc, None,
|
||||||
|
makenotes=makenotes)
|
||||||
|
|
||||||
|
check_callback('__enter__', node.enter_type, 1)
|
||||||
|
check_callback('__exit__', node.exit_type, 4)
|
||||||
|
|
||||||
|
if node.optional_vars is not None:
|
||||||
|
if types.is_method(node.exit_type):
|
||||||
|
var_type = types.get_method_function(node.exit_type).find().ret
|
||||||
|
else:
|
||||||
|
var_type = node.exit_type.find().ret
|
||||||
|
|
||||||
|
def makenotes(printer, typea, typeb, loca, locb):
|
||||||
|
return [
|
||||||
|
diagnostic.Diagnostic("note",
|
||||||
|
"expression of type {typea}",
|
||||||
|
{"typea": printer.name(typea)},
|
||||||
|
loca),
|
||||||
|
diagnostic.Diagnostic("note",
|
||||||
|
"context manager with an '__enter__' method returning {typeb}",
|
||||||
|
{"typeb": printer.name(typeb)},
|
||||||
|
locb)
|
||||||
|
]
|
||||||
|
|
||||||
|
self._unify(node.optional_vars.type, var_type,
|
||||||
|
node.optional_vars.loc, node.context_expr.loc,
|
||||||
|
makenotes=makenotes)
|
||||||
|
|
||||||
|
else:
|
||||||
diag = diagnostic.Diagnostic("error",
|
diag = diagnostic.Diagnostic("error",
|
||||||
"value of type {type} cannot act as a context manager",
|
"value of type {type} cannot act as a context manager",
|
||||||
{"type": types.TypePrinter().name(typ)},
|
{"type": types.TypePrinter().name(typ)},
|
||||||
node.context_expr.loc)
|
node.context_expr.loc)
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
|
|
||||||
if node.optional_vars is not None:
|
|
||||||
self._unify(node.optional_vars.type, node.context_expr.type,
|
|
||||||
node.optional_vars.loc, node.context_expr.loc)
|
|
||||||
|
|
||||||
def visit_With(self, node):
|
def visit_With(self, node):
|
||||||
self.generic_visit(node)
|
self.generic_visit(node)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,15 @@
|
||||||
|
# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t
|
||||||
|
# RUN: OutputCheck %s --file-to-check=%t
|
||||||
|
|
||||||
|
class contextmgr:
|
||||||
|
def __enter__(self, n1):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __exit__(self, n1, n2):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def foo():
|
||||||
|
# CHECK-L: ${LINE:+2}: error: function '__enter__(self:<instance contextmgr {}>, n1:'a)->NoneType delay('b)' must accept 1 positional argument and no optional arguments
|
||||||
|
# CHECK-L: ${LINE:+1}: error: function '__exit__(self:<instance contextmgr>, n1:'c, n2:'d)->NoneType delay('e)' must accept 4 positional arguments and no optional arguments
|
||||||
|
with contextmgr():
|
||||||
|
pass
|
|
@ -0,0 +1,16 @@
|
||||||
|
# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t
|
||||||
|
# RUN: OutputCheck %s --file-to-check=%t
|
||||||
|
|
||||||
|
class contextmgr:
|
||||||
|
def __enter__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __exit__(self, n1, n2, n3):
|
||||||
|
n3 = 1
|
||||||
|
pass
|
||||||
|
|
||||||
|
def foo():
|
||||||
|
# CHECK-L: ${LINE:+2}: error: cannot unify int(width='a) with NoneType
|
||||||
|
# CHECK-L: ${LINE:+1}: note: exception handling via context managers is not supported; the argument 'n3' of function '__exit__(self:<instance contextmgr {}>, n1:NoneType, n2:NoneType, n3:int(width='a))->NoneType delay('b)' will always be None
|
||||||
|
with contextmgr():
|
||||||
|
pass
|
|
@ -0,0 +1,17 @@
|
||||||
|
# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t
|
||||||
|
# RUN: OutputCheck %s --file-to-check=%t
|
||||||
|
|
||||||
|
class contextmgr:
|
||||||
|
def __enter__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __exit__(self, n1, n2, n3):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def foo():
|
||||||
|
contextmgr.__enter__(1)
|
||||||
|
# CHECK-L: ${LINE:+3}: error: cannot unify <instance contextmgr> with int(width='a) while inferring the type for self argument
|
||||||
|
# CHECK-L: ${LINE:+2}: note: expression of type <instance contextmgr {}>
|
||||||
|
# CHECK-L: ${LINE:+1}: note: reference to an instance with a method '__enter__(self:int(width='a))->NoneType delay('b)'
|
||||||
|
with contextmgr():
|
||||||
|
pass
|
|
@ -0,0 +1,17 @@
|
||||||
|
# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t
|
||||||
|
# RUN: OutputCheck %s --file-to-check=%t
|
||||||
|
|
||||||
|
class contextmgr:
|
||||||
|
def __enter__(self):
|
||||||
|
return 1
|
||||||
|
|
||||||
|
def __exit__(self, n1, n2, n3):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def foo():
|
||||||
|
x = "x"
|
||||||
|
# CHECK-L: ${LINE:+3}: error: cannot unify str with NoneType
|
||||||
|
# CHECK-L: ${LINE:+2}: note: expression of type str
|
||||||
|
# CHECK-L: ${LINE:+1}: note: context manager with an '__enter__' method returning NoneType
|
||||||
|
with contextmgr() as x:
|
||||||
|
pass
|
|
@ -1,5 +1,5 @@
|
||||||
# RUN: %python -m artiq.compiler.testbench.inferencer %s >%t
|
# RUN: %python -m artiq.compiler.testbench.inferencer %s >%t
|
||||||
# RUN: OutputCheck %s --file-to-check=%t
|
# RUN: OutputCheck %s --file-to-check=%t
|
||||||
|
|
||||||
# CHECK-L: as x:<builtin parallel>
|
# CHECK-L: as x:NoneType
|
||||||
with parallel as x: pass
|
with parallel as x: pass
|
||||||
|
|
Loading…
Reference in New Issue