From dfbf55fed290aab7607e2370c7e85f757c41418f Mon Sep 17 00:00:00 2001 From: whitequark Date: Tue, 5 Jan 2016 00:11:03 +0800 Subject: [PATCH] transforms/inferencer: add support for user-defined context manager. --- artiq/compiler/analyses/devirtualization.py | 2 +- artiq/compiler/asttyped.py | 3 + .../compiler/transforms/asttyped_rewriter.py | 8 ++ artiq/compiler/transforms/inferencer.py | 95 +++++++++++++++++-- lit-test/test/inferencer/error_with_arity.py | 15 +++ lit-test/test/inferencer/error_with_exn.py | 16 ++++ lit-test/test/inferencer/error_with_self.py | 17 ++++ lit-test/test/inferencer/error_with_var.py | 17 ++++ lit-test/test/inferencer/with.py | 2 +- 9 files changed, 166 insertions(+), 9 deletions(-) create mode 100644 lit-test/test/inferencer/error_with_arity.py create mode 100644 lit-test/test/inferencer/error_with_exn.py create mode 100644 lit-test/test/inferencer/error_with_self.py create mode 100644 lit-test/test/inferencer/error_with_var.py diff --git a/artiq/compiler/analyses/devirtualization.py b/artiq/compiler/analyses/devirtualization.py index 0fef44ddf..3a35639f2 100644 --- a/artiq/compiler/analyses/devirtualization.py +++ b/artiq/compiler/analyses/devirtualization.py @@ -52,7 +52,7 @@ class FunctionResolver(algorithm.Visitor): self.visit(node.body) self.visit(node.orelse) - def visit_withitem(self, node): + def visit_withitemT(self, node): self.visit(node.context_expr) self.visit_in_assign(node.optional_vars) diff --git a/artiq/compiler/asttyped.py b/artiq/compiler/asttyped.py index 9d4274470..85e86d59d 100644 --- a/artiq/compiler/asttyped.py +++ b/artiq/compiler/asttyped.py @@ -42,6 +42,9 @@ class ForT(ast.For): :ivar trip_interval: (:class:`iodelay.Expr`) """ +class withitemT(ast.withitem): + _types = ("enter_type", "exit_type") + class SliceT(ast.Slice, commontyped): pass diff --git a/artiq/compiler/transforms/asttyped_rewriter.py b/artiq/compiler/transforms/asttyped_rewriter.py index a14d55c5d..70790a613 100644 --- a/artiq/compiler/transforms/asttyped_rewriter.py +++ b/artiq/compiler/transforms/asttyped_rewriter.py @@ -480,6 +480,14 @@ class ASTTypedRewriter(algorithm.Transformer): else_loc=node.else_loc, else_colon_loc=node.else_colon_loc) return node + def visit_withitem(self, node): + node = self.generic_visit(node) + node = asttyped.withitemT( + context_expr=node.context_expr, optional_vars=node.optional_vars, + enter_type=types.TVar(), exit_type=types.TVar(), + as_loc=node.as_loc, loc=node.loc) + return node + # Unsupported visitors # def visit_unsupported(self, node): diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py index 1cc2a30da..07ee75198 100644 --- a/artiq/compiler/transforms/inferencer.py +++ b/artiq/compiler/transforms/inferencer.py @@ -160,7 +160,7 @@ class Inferencer(algorithm.Visitor): self._unify(result_type, attr_type, loc, None) else: - if attr_name_loc.source_buffer == value_node.loc.source_buffer: + if attr_loc.source_buffer == value_node.loc.source_buffer: highlights, notes = [value_node.loc], [] else: # This happens when the object being accessed is embedded @@ -173,7 +173,7 @@ class Inferencer(algorithm.Visitor): diag = diagnostic.Diagnostic("error", "type {type} does not have an attribute '{attr}'", {"type": types.TypePrinter().name(object_type), "attr": attr_name}, - node.attr_loc, highlights, notes) + attr_loc, highlights, notes) self.engine.process(diag) def _unify_iterable(self, element, collection): @@ -970,23 +970,104 @@ class Inferencer(algorithm.Visitor): node.keyword_loc) self.engine.process(diag) - def visit_withitem(self, node): + def visit_withitemT(self, node): self.generic_visit(node) typ = node.context_expr.type if (types.is_builtin(typ, "parallel") or types.is_builtin(typ, "sequential") or (isinstance(node.context_expr, asttyped.CallT) and types.is_builtin(node.context_expr.func.type, "watchdog"))): + # builtin context managers + if node.optional_vars is not None: + self._unify(node.optional_vars.type, builtins.TNone(), + node.optional_vars.loc, None) + elif types.is_instance(typ) or types.is_constructor(typ): + # user-defined context managers + self._unify_attribute(result_type=node.enter_type, value_node=node.context_expr, + attr_name='__enter__', attr_loc=None, loc=node.loc) + self._unify_attribute(result_type=node.exit_type, value_node=node.context_expr, + attr_name='__exit__', attr_loc=None, loc=node.loc) + + printer = types.TypePrinter() + + def check_callback(attr_name, typ, arity): + if types.is_var(typ): + return + + if not (types.is_method(typ) or types.is_function(typ)): + diag = diagnostic.Diagnostic("error", + "attribute '{attr}' of type {manager_type} must be a function", + {"attr": attr_name, + "manager_type": printer.name(node.context_expr.type)}, + node.context_expr.loc) + self.engine.process(diag) + return + + if types.is_method(typ): + typ = types.get_method_function(typ).find() + else: + typ = typ.find() + + if not (len(typ.args) == arity and len(typ.optargs) == 0): + diag = diagnostic.Diagnostic("error", + "function '{attr}{attr_type}' must accept " + "{arity} positional argument{s} and no optional arguments", + {"attr": attr_name, + "attr_type": printer.name(typ), + "arity": arity, "s": "s" if arity > 1 else ""}, + node.context_expr.loc) + self.engine.process(diag) + + for formal_arg_name in list(typ.args)[1:]: + formal_arg_type = typ.args[formal_arg_name] + def makenotes(printer, typea, typeb, loca, locb): + return [ + diagnostic.Diagnostic("note", + "exception handling via context managers is not supported; " + "the argument '{arg}' of function '{attr}{attr_type}' " + "will always be None", + {"arg": formal_arg_name, + "attr": attr_name, + "attr_type": printer.name(typ)}, + loca), + ] + + self._unify(formal_arg_type, builtins.TNone(), + node.context_expr.loc, None, + makenotes=makenotes) + + check_callback('__enter__', node.enter_type, 1) + check_callback('__exit__', node.exit_type, 4) + + if node.optional_vars is not None: + if types.is_method(node.exit_type): + var_type = types.get_method_function(node.exit_type).find().ret + else: + var_type = node.exit_type.find().ret + + def makenotes(printer, typea, typeb, loca, locb): + return [ + diagnostic.Diagnostic("note", + "expression of type {typea}", + {"typea": printer.name(typea)}, + loca), + diagnostic.Diagnostic("note", + "context manager with an '__enter__' method returning {typeb}", + {"typeb": printer.name(typeb)}, + locb) + ] + + self._unify(node.optional_vars.type, var_type, + node.optional_vars.loc, node.context_expr.loc, + makenotes=makenotes) + + else: diag = diagnostic.Diagnostic("error", "value of type {type} cannot act as a context manager", {"type": types.TypePrinter().name(typ)}, node.context_expr.loc) self.engine.process(diag) - if node.optional_vars is not None: - self._unify(node.optional_vars.type, node.context_expr.type, - node.optional_vars.loc, node.context_expr.loc) - def visit_With(self, node): self.generic_visit(node) diff --git a/lit-test/test/inferencer/error_with_arity.py b/lit-test/test/inferencer/error_with_arity.py new file mode 100644 index 000000000..b056073ae --- /dev/null +++ b/lit-test/test/inferencer/error_with_arity.py @@ -0,0 +1,15 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +class contextmgr: + def __enter__(self, n1): + pass + + def __exit__(self, n1, n2): + pass + +def foo(): + # CHECK-L: ${LINE:+2}: error: function '__enter__(self:, n1:'a)->NoneType delay('b)' must accept 1 positional argument and no optional arguments + # CHECK-L: ${LINE:+1}: error: function '__exit__(self:, n1:'c, n2:'d)->NoneType delay('e)' must accept 4 positional arguments and no optional arguments + with contextmgr(): + pass diff --git a/lit-test/test/inferencer/error_with_exn.py b/lit-test/test/inferencer/error_with_exn.py new file mode 100644 index 000000000..5aa8ed1ec --- /dev/null +++ b/lit-test/test/inferencer/error_with_exn.py @@ -0,0 +1,16 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +class contextmgr: + def __enter__(self): + pass + + def __exit__(self, n1, n2, n3): + n3 = 1 + pass + +def foo(): + # CHECK-L: ${LINE:+2}: error: cannot unify int(width='a) with NoneType + # CHECK-L: ${LINE:+1}: note: exception handling via context managers is not supported; the argument 'n3' of function '__exit__(self:, n1:NoneType, n2:NoneType, n3:int(width='a))->NoneType delay('b)' will always be None + with contextmgr(): + pass diff --git a/lit-test/test/inferencer/error_with_self.py b/lit-test/test/inferencer/error_with_self.py new file mode 100644 index 000000000..afe53e531 --- /dev/null +++ b/lit-test/test/inferencer/error_with_self.py @@ -0,0 +1,17 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +class contextmgr: + def __enter__(self): + pass + + def __exit__(self, n1, n2, n3): + pass + +def foo(): + contextmgr.__enter__(1) + # CHECK-L: ${LINE:+3}: error: cannot unify with int(width='a) while inferring the type for self argument + # CHECK-L: ${LINE:+2}: note: expression of type + # CHECK-L: ${LINE:+1}: note: reference to an instance with a method '__enter__(self:int(width='a))->NoneType delay('b)' + with contextmgr(): + pass diff --git a/lit-test/test/inferencer/error_with_var.py b/lit-test/test/inferencer/error_with_var.py new file mode 100644 index 000000000..97b1b345a --- /dev/null +++ b/lit-test/test/inferencer/error_with_var.py @@ -0,0 +1,17 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +class contextmgr: + def __enter__(self): + return 1 + + def __exit__(self, n1, n2, n3): + pass + +def foo(): + x = "x" + # CHECK-L: ${LINE:+3}: error: cannot unify str with NoneType + # CHECK-L: ${LINE:+2}: note: expression of type str + # CHECK-L: ${LINE:+1}: note: context manager with an '__enter__' method returning NoneType + with contextmgr() as x: + pass diff --git a/lit-test/test/inferencer/with.py b/lit-test/test/inferencer/with.py index 2228b6204..a48dae4b7 100644 --- a/lit-test/test/inferencer/with.py +++ b/lit-test/test/inferencer/with.py @@ -1,5 +1,5 @@ # RUN: %python -m artiq.compiler.testbench.inferencer %s >%t # RUN: OutputCheck %s --file-to-check=%t -# CHECK-L: as x: +# CHECK-L: as x:NoneType with parallel as x: pass