From 6a6d7dab19f320d5ff2c1287eff6cd3ee3001916 Mon Sep 17 00:00:00 2001 From: whitequark Date: Tue, 5 Jan 2016 04:10:31 +0000 Subject: [PATCH] transforms.artiq_ir_generator: add support for user-defined context managers. --- .../compiler/transforms/artiq_ir_generator.py | 240 ++++++++++-------- lit-test/test/integration/with.py | 33 +++ 2 files changed, 171 insertions(+), 102 deletions(-) create mode 100644 lit-test/test/integration/with.py diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index a27afadbd..eb80f14fc 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -555,7 +555,7 @@ class ARTIQIRGenerator(algorithm.Visitor): def visit_Continue(self, node): self.append(ir.Branch(self.continue_target)) - def raise_exn(self, exn, loc=None): + def raise_exn(self, exn=None, loc=None): if self.final_branch is not None: raise_proxy = self.add_block("try.raise") self.final_branch(raise_proxy, self.current_block) @@ -718,19 +718,34 @@ class ARTIQIRGenerator(algorithm.Visitor): if not post_handler.is_terminated(): post_handler.append(ir.Branch(tail)) - def visit_With(self, node): - if len(node.items) != 1: - diag = diagnostic.Diagnostic("fatal", - "only one expression per 'with' statement is supported", - {"type": types.TypePrinter().name(typ)}, - node.context_expr.loc) - self.engine.process(diag) + def _try_finally(self, body_gen, finally_gen, name): + dispatcher = self.add_block("{}.dispatch".format(name)) + try: + old_unwind, self.unwind_target = self.unwind_target, dispatcher + body_gen() + finally: + self.unwind_target = old_unwind + + if not self.current_block.is_terminated(): + finally_gen() + + self.post_body = self.current_block + + self.current_block = self.add_block("{}.cleanup".format(name)) + dispatcher.append(ir.LandingPad(self.current_block)) + finally_gen() + self.raise_exn() + + self.current_block = self.post_body + + def visit_With(self, node): context_expr_node = node.items[0].context_expr optional_vars_node = node.items[0].optional_vars if types.is_builtin(context_expr_node.type, "sequential"): self.visit(node.body) + return elif types.is_builtin(context_expr_node.type, "parallel"): parallel = self.append(ir.Parallel([])) @@ -748,32 +763,44 @@ class ARTIQIRGenerator(algorithm.Visitor): for tail in tails: if not tail.is_terminated(): tail.append(ir.Branch(self.current_block)) - elif isinstance(context_expr_node, asttyped.CallT) and \ - types.is_builtin(context_expr_node.func.type, "watchdog"): - timeout = self.visit(context_expr_node.args[0]) - timeout_ms = self.append(ir.Arith(ast.Mult(loc=None), timeout, - ir.Constant(1000, builtins.TFloat()))) - timeout_ms_int = self.append(ir.Coerce(timeout_ms, builtins.TInt32())) + return - watchdog = self.append(ir.Builtin("watchdog_set", [timeout_ms_int], builtins.TInt32())) + cleanup = [] + for item_node in node.items: + context_expr_node = item_node.context_expr + optional_vars_node = item_node.optional_vars - dispatcher = self.add_block("watchdog.dispatch") + if isinstance(context_expr_node, asttyped.CallT) and \ + types.is_builtin(context_expr_node.func.type, "watchdog"): + timeout = self.visit(context_expr_node.args[0]) + timeout_ms = self.append(ir.Arith(ast.Mult(loc=None), timeout, + ir.Constant(1000, builtins.TFloat()))) + timeout_ms_int = self.append(ir.Coerce(timeout_ms, builtins.TInt32())) - try: - old_unwind, self.unwind_target = self.unwind_target, dispatcher - self.visit(node.body) - finally: - self.unwind_target = old_unwind + watchdog_id = self.append(ir.Builtin("watchdog_set", [timeout_ms_int], + builtins.TInt32())) + cleanup.append(lambda: + self.append(ir.Builtin("watchdog_clear", [watchdog_id], builtins.TNone()))) + else: # user-defined context manager + context_mgr = self.visit(context_expr_node) + enter_fn = self._get_attribute(context_mgr, '__enter__') + exit_fn = self._get_attribute(context_mgr, '__exit__') - cleanup = self.add_block('watchdog.cleanup') - landingpad = dispatcher.append(ir.LandingPad(cleanup)) - cleanup.append(ir.Builtin("watchdog_clear", [watchdog], builtins.TNone())) - cleanup.append(ir.Reraise(self.unwind_target)) + try: + self.current_assign = self._user_call(enter_fn, [], {}) + if optional_vars_node is not None: + self.visit(optional_vars_node) + finally: + self.current_assign = None - if not self.current_block.is_terminated(): - self.append(ir.Builtin("watchdog_clear", [watchdog], builtins.TNone())) - else: - assert False + none = self.append(ir.Alloc([], builtins.TNone())) + cleanup.append(lambda: + self._user_call(exit_fn, [none, none, none], {})) + + self._try_finally( + body_gen=lambda: self.visit(node.body), + finally_gen=lambda: [thunk() for thunk in cleanup], + name="with") # Expression visitors # These visitors return a node in addition to mutating @@ -850,14 +877,8 @@ class ARTIQIRGenerator(algorithm.Visitor): else: return self._set_local(node.id, self.current_assign) - def visit_AttributeT(self, node): - try: - old_assign, self.current_assign = self.current_assign, None - obj = self.visit(node.value) - finally: - self.current_assign = old_assign - - if node.attr not in obj.type.find().attributes: + def _get_attribute(self, obj, attr_name): + if attr_name not in obj.type.find().attributes: # A class attribute. Get the constructor (class object) and # extract the attribute from it. constr_type = obj.type.constructor @@ -865,16 +886,26 @@ class ARTIQIRGenerator(algorithm.Visitor): constr_type.name, constr_type, name="constructor." + constr_type.name)) - if types.is_function(constr.type.attributes[node.attr]): + if types.is_function(constr.type.attributes[attr_name]): # A method. Construct a method object instead. - func = self.append(ir.GetAttr(constr, node.attr)) - return self.append(ir.Alloc([func, obj], node.type)) + func = self.append(ir.GetAttr(constr, attr_name)) + return self.append(ir.Alloc([func, obj], + types.TMethod(obj.type, func.type))) else: obj = constr + return self.append(ir.GetAttr(obj, attr_name, + name="{}.{}".format(_readable_name(obj), attr_name))) + + def visit_AttributeT(self, node): + try: + old_assign, self.current_assign = self.current_assign, None + obj = self.visit(node.value) + finally: + self.current_assign = old_assign + if self.current_assign is None: - return self.append(ir.GetAttr(obj, node.attr, - name="{}.{}".format(_readable_name(obj), node.attr))) + return self._get_attribute(obj, node.attr) elif types.is_rpc_function(self.current_assign.type): # RPC functions are just type-level markers return self.append(ir.Builtin("nop", [], builtins.TNone())) @@ -1624,8 +1655,70 @@ class ARTIQIRGenerator(algorithm.Visitor): node.loc) self.engine.process(diag) + def _user_call(self, callee, positional, keywords, arg_exprs={}): + if types.is_function(callee.type): + func = callee + self_arg = None + fn_typ = callee.type + offset = 0 + elif types.is_method(callee.type): + func = self.append(ir.GetAttr(callee, "__func__")) + self_arg = self.append(ir.GetAttr(callee, "__self__")) + fn_typ = types.get_method_function(callee.type) + offset = 1 + else: + assert False + + args = [None] * (len(fn_typ.args) + len(fn_typ.optargs)) + + for index, arg in enumerate(positional): + if index + offset < len(fn_typ.args): + args[index + offset] = arg + else: + args[index + offset] = self.append(ir.Alloc([arg], ir.TOption(arg.type))) + + for keyword in keywords: + arg = keywords[keyword] + if keyword in fn_typ.args: + for index, arg_name in enumerate(fn_typ.args): + if keyword == arg_name: + assert args[index] is None + args[index] = arg + break + elif keyword in fn_typ.optargs: + for index, optarg_name in enumerate(fn_typ.optargs): + if keyword == optarg_name: + assert args[len(fn_typ.args) + index] is None + args[len(fn_typ.args) + index] = \ + self.append(ir.Alloc([arg], ir.TOption(arg.type))) + break + + for index, optarg_name in enumerate(fn_typ.optargs): + if args[len(fn_typ.args) + index] is None: + args[len(fn_typ.args) + index] = \ + self.append(ir.Alloc([], ir.TOption(fn_typ.optargs[optarg_name]))) + + if self_arg is not None: + assert args[0] is None + args[0] = self_arg + + assert None not in args + + if self.unwind_target is None: + insn = self.append(ir.Call(func, args, arg_exprs)) + else: + after_invoke = self.add_block() + insn = self.append(ir.Invoke(func, args, arg_exprs, + after_invoke, self.unwind_target)) + self.current_block = after_invoke + + return insn + def visit_CallT(self, node): - typ = node.func.type.find() + if not types.is_builtin(node.func.type): + callee = self.visit(node.func) + args = [self.visit(arg_node) for arg_node in node.args] + keywords = {kw_node.arg: self.visit(kw_node.value) for kw_node in node.keywords} if node.iodelay is not None and not iodelay.is_const(node.iodelay, 0): before_delay = self.current_block @@ -1633,68 +1726,11 @@ class ARTIQIRGenerator(algorithm.Visitor): before_delay.append(ir.Branch(during_delay)) self.current_block = during_delay - if types.is_builtin(typ): + if types.is_builtin(node.func.type): insn = self.visit_builtin_call(node) else: - if types.is_function(typ): - func = self.visit(node.func) - self_arg = None - fn_typ = typ - offset = 0 - elif types.is_method(typ): - method = self.visit(node.func) - func = self.append(ir.GetAttr(method, "__func__")) - self_arg = self.append(ir.GetAttr(method, "__self__")) - fn_typ = types.get_method_function(typ) - offset = 1 - else: - assert False + insn = self._user_call(callee, args, keywords, node.arg_exprs) - args = [None] * (len(fn_typ.args) + len(fn_typ.optargs)) - - for index, arg_node in enumerate(node.args): - arg = self.visit(arg_node) - if index + offset < len(fn_typ.args): - args[index + offset] = arg - else: - args[index + offset] = self.append(ir.Alloc([arg], ir.TOption(arg.type))) - - for keyword in node.keywords: - arg = self.visit(keyword.value) - if keyword.arg in fn_typ.args: - for index, arg_name in enumerate(fn_typ.args): - if keyword.arg == arg_name: - assert args[index] is None - args[index] = arg - break - elif keyword.arg in fn_typ.optargs: - for index, optarg_name in enumerate(fn_typ.optargs): - if keyword.arg == optarg_name: - assert args[len(fn_typ.args) + index] is None - args[len(fn_typ.args) + index] = \ - self.append(ir.Alloc([arg], ir.TOption(arg.type))) - break - - for index, optarg_name in enumerate(fn_typ.optargs): - if args[len(fn_typ.args) + index] is None: - args[len(fn_typ.args) + index] = \ - self.append(ir.Alloc([], ir.TOption(fn_typ.optargs[optarg_name]))) - - if self_arg is not None: - assert args[0] is None - args[0] = self_arg - - assert None not in args - - if self.unwind_target is None: - insn = self.append(ir.Call(func, args, node.arg_exprs)) - else: - after_invoke = self.add_block() - insn = self.append(ir.Invoke(func, args, node.arg_exprs, - after_invoke, self.unwind_target)) - self.current_block = after_invoke - - method_key = None if isinstance(node.func, asttyped.AttributeT): attr_node = node.func self.method_map[(attr_node.value.type.find(), attr_node.attr)].append(insn) diff --git a/lit-test/test/integration/with.py b/lit-test/test/integration/with.py new file mode 100644 index 000000000..8f2b9b9a8 --- /dev/null +++ b/lit-test/test/integration/with.py @@ -0,0 +1,33 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s +# RUN: %python %s + +class contextmgr: + def __enter__(self): + print(2) + + def __exit__(self, n1, n2, n3): + print(4) + +# CHECK-L: a 1 +# CHECK-L: 2 +# CHECK-L: a 3 +# CHECK-L: 4 +# CHECK-L: a 5 +print("a", 1) +with contextmgr(): + print("a", 3) +print("a", 5) + +# CHECK-L: b 1 +# CHECK-L: 2 +# CHECK-L: 4 +# CHECK-L: b 6 +try: + print("b", 1) + with contextmgr(): + [0][1] + print("b", 3) + print("b", 5) +except: + pass +print("b", 6)