transforms.artiq_ir_generator: add support for user-defined context managers.

This commit is contained in:
whitequark 2016-01-05 04:10:31 +00:00
parent d633c8e1f8
commit 6a6d7dab19
2 changed files with 171 additions and 102 deletions

View File

@ -555,7 +555,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
def visit_Continue(self, node):
self.append(ir.Branch(self.continue_target))
def raise_exn(self, exn, loc=None):
def raise_exn(self, exn=None, loc=None):
if self.final_branch is not None:
raise_proxy = self.add_block("try.raise")
self.final_branch(raise_proxy, self.current_block)
@ -718,19 +718,34 @@ class ARTIQIRGenerator(algorithm.Visitor):
if not post_handler.is_terminated():
post_handler.append(ir.Branch(tail))
def visit_With(self, node):
if len(node.items) != 1:
diag = diagnostic.Diagnostic("fatal",
"only one expression per 'with' statement is supported",
{"type": types.TypePrinter().name(typ)},
node.context_expr.loc)
self.engine.process(diag)
def _try_finally(self, body_gen, finally_gen, name):
dispatcher = self.add_block("{}.dispatch".format(name))
try:
old_unwind, self.unwind_target = self.unwind_target, dispatcher
body_gen()
finally:
self.unwind_target = old_unwind
if not self.current_block.is_terminated():
finally_gen()
self.post_body = self.current_block
self.current_block = self.add_block("{}.cleanup".format(name))
dispatcher.append(ir.LandingPad(self.current_block))
finally_gen()
self.raise_exn()
self.current_block = self.post_body
def visit_With(self, node):
context_expr_node = node.items[0].context_expr
optional_vars_node = node.items[0].optional_vars
if types.is_builtin(context_expr_node.type, "sequential"):
self.visit(node.body)
return
elif types.is_builtin(context_expr_node.type, "parallel"):
parallel = self.append(ir.Parallel([]))
@ -748,32 +763,44 @@ class ARTIQIRGenerator(algorithm.Visitor):
for tail in tails:
if not tail.is_terminated():
tail.append(ir.Branch(self.current_block))
elif isinstance(context_expr_node, asttyped.CallT) and \
types.is_builtin(context_expr_node.func.type, "watchdog"):
timeout = self.visit(context_expr_node.args[0])
timeout_ms = self.append(ir.Arith(ast.Mult(loc=None), timeout,
ir.Constant(1000, builtins.TFloat())))
timeout_ms_int = self.append(ir.Coerce(timeout_ms, builtins.TInt32()))
return
watchdog = self.append(ir.Builtin("watchdog_set", [timeout_ms_int], builtins.TInt32()))
cleanup = []
for item_node in node.items:
context_expr_node = item_node.context_expr
optional_vars_node = item_node.optional_vars
dispatcher = self.add_block("watchdog.dispatch")
if isinstance(context_expr_node, asttyped.CallT) and \
types.is_builtin(context_expr_node.func.type, "watchdog"):
timeout = self.visit(context_expr_node.args[0])
timeout_ms = self.append(ir.Arith(ast.Mult(loc=None), timeout,
ir.Constant(1000, builtins.TFloat())))
timeout_ms_int = self.append(ir.Coerce(timeout_ms, builtins.TInt32()))
try:
old_unwind, self.unwind_target = self.unwind_target, dispatcher
self.visit(node.body)
finally:
self.unwind_target = old_unwind
watchdog_id = self.append(ir.Builtin("watchdog_set", [timeout_ms_int],
builtins.TInt32()))
cleanup.append(lambda:
self.append(ir.Builtin("watchdog_clear", [watchdog_id], builtins.TNone())))
else: # user-defined context manager
context_mgr = self.visit(context_expr_node)
enter_fn = self._get_attribute(context_mgr, '__enter__')
exit_fn = self._get_attribute(context_mgr, '__exit__')
cleanup = self.add_block('watchdog.cleanup')
landingpad = dispatcher.append(ir.LandingPad(cleanup))
cleanup.append(ir.Builtin("watchdog_clear", [watchdog], builtins.TNone()))
cleanup.append(ir.Reraise(self.unwind_target))
try:
self.current_assign = self._user_call(enter_fn, [], {})
if optional_vars_node is not None:
self.visit(optional_vars_node)
finally:
self.current_assign = None
if not self.current_block.is_terminated():
self.append(ir.Builtin("watchdog_clear", [watchdog], builtins.TNone()))
else:
assert False
none = self.append(ir.Alloc([], builtins.TNone()))
cleanup.append(lambda:
self._user_call(exit_fn, [none, none, none], {}))
self._try_finally(
body_gen=lambda: self.visit(node.body),
finally_gen=lambda: [thunk() for thunk in cleanup],
name="with")
# Expression visitors
# These visitors return a node in addition to mutating
@ -850,14 +877,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
else:
return self._set_local(node.id, self.current_assign)
def visit_AttributeT(self, node):
try:
old_assign, self.current_assign = self.current_assign, None
obj = self.visit(node.value)
finally:
self.current_assign = old_assign
if node.attr not in obj.type.find().attributes:
def _get_attribute(self, obj, attr_name):
if attr_name not in obj.type.find().attributes:
# A class attribute. Get the constructor (class object) and
# extract the attribute from it.
constr_type = obj.type.constructor
@ -865,16 +886,26 @@ class ARTIQIRGenerator(algorithm.Visitor):
constr_type.name, constr_type,
name="constructor." + constr_type.name))
if types.is_function(constr.type.attributes[node.attr]):
if types.is_function(constr.type.attributes[attr_name]):
# A method. Construct a method object instead.
func = self.append(ir.GetAttr(constr, node.attr))
return self.append(ir.Alloc([func, obj], node.type))
func = self.append(ir.GetAttr(constr, attr_name))
return self.append(ir.Alloc([func, obj],
types.TMethod(obj.type, func.type)))
else:
obj = constr
return self.append(ir.GetAttr(obj, attr_name,
name="{}.{}".format(_readable_name(obj), attr_name)))
def visit_AttributeT(self, node):
try:
old_assign, self.current_assign = self.current_assign, None
obj = self.visit(node.value)
finally:
self.current_assign = old_assign
if self.current_assign is None:
return self.append(ir.GetAttr(obj, node.attr,
name="{}.{}".format(_readable_name(obj), node.attr)))
return self._get_attribute(obj, node.attr)
elif types.is_rpc_function(self.current_assign.type):
# RPC functions are just type-level markers
return self.append(ir.Builtin("nop", [], builtins.TNone()))
@ -1624,8 +1655,70 @@ class ARTIQIRGenerator(algorithm.Visitor):
node.loc)
self.engine.process(diag)
def _user_call(self, callee, positional, keywords, arg_exprs={}):
if types.is_function(callee.type):
func = callee
self_arg = None
fn_typ = callee.type
offset = 0
elif types.is_method(callee.type):
func = self.append(ir.GetAttr(callee, "__func__"))
self_arg = self.append(ir.GetAttr(callee, "__self__"))
fn_typ = types.get_method_function(callee.type)
offset = 1
else:
assert False
args = [None] * (len(fn_typ.args) + len(fn_typ.optargs))
for index, arg in enumerate(positional):
if index + offset < len(fn_typ.args):
args[index + offset] = arg
else:
args[index + offset] = self.append(ir.Alloc([arg], ir.TOption(arg.type)))
for keyword in keywords:
arg = keywords[keyword]
if keyword in fn_typ.args:
for index, arg_name in enumerate(fn_typ.args):
if keyword == arg_name:
assert args[index] is None
args[index] = arg
break
elif keyword in fn_typ.optargs:
for index, optarg_name in enumerate(fn_typ.optargs):
if keyword == optarg_name:
assert args[len(fn_typ.args) + index] is None
args[len(fn_typ.args) + index] = \
self.append(ir.Alloc([arg], ir.TOption(arg.type)))
break
for index, optarg_name in enumerate(fn_typ.optargs):
if args[len(fn_typ.args) + index] is None:
args[len(fn_typ.args) + index] = \
self.append(ir.Alloc([], ir.TOption(fn_typ.optargs[optarg_name])))
if self_arg is not None:
assert args[0] is None
args[0] = self_arg
assert None not in args
if self.unwind_target is None:
insn = self.append(ir.Call(func, args, arg_exprs))
else:
after_invoke = self.add_block()
insn = self.append(ir.Invoke(func, args, arg_exprs,
after_invoke, self.unwind_target))
self.current_block = after_invoke
return insn
def visit_CallT(self, node):
typ = node.func.type.find()
if not types.is_builtin(node.func.type):
callee = self.visit(node.func)
args = [self.visit(arg_node) for arg_node in node.args]
keywords = {kw_node.arg: self.visit(kw_node.value) for kw_node in node.keywords}
if node.iodelay is not None and not iodelay.is_const(node.iodelay, 0):
before_delay = self.current_block
@ -1633,68 +1726,11 @@ class ARTIQIRGenerator(algorithm.Visitor):
before_delay.append(ir.Branch(during_delay))
self.current_block = during_delay
if types.is_builtin(typ):
if types.is_builtin(node.func.type):
insn = self.visit_builtin_call(node)
else:
if types.is_function(typ):
func = self.visit(node.func)
self_arg = None
fn_typ = typ
offset = 0
elif types.is_method(typ):
method = self.visit(node.func)
func = self.append(ir.GetAttr(method, "__func__"))
self_arg = self.append(ir.GetAttr(method, "__self__"))
fn_typ = types.get_method_function(typ)
offset = 1
else:
assert False
insn = self._user_call(callee, args, keywords, node.arg_exprs)
args = [None] * (len(fn_typ.args) + len(fn_typ.optargs))
for index, arg_node in enumerate(node.args):
arg = self.visit(arg_node)
if index + offset < len(fn_typ.args):
args[index + offset] = arg
else:
args[index + offset] = self.append(ir.Alloc([arg], ir.TOption(arg.type)))
for keyword in node.keywords:
arg = self.visit(keyword.value)
if keyword.arg in fn_typ.args:
for index, arg_name in enumerate(fn_typ.args):
if keyword.arg == arg_name:
assert args[index] is None
args[index] = arg
break
elif keyword.arg in fn_typ.optargs:
for index, optarg_name in enumerate(fn_typ.optargs):
if keyword.arg == optarg_name:
assert args[len(fn_typ.args) + index] is None
args[len(fn_typ.args) + index] = \
self.append(ir.Alloc([arg], ir.TOption(arg.type)))
break
for index, optarg_name in enumerate(fn_typ.optargs):
if args[len(fn_typ.args) + index] is None:
args[len(fn_typ.args) + index] = \
self.append(ir.Alloc([], ir.TOption(fn_typ.optargs[optarg_name])))
if self_arg is not None:
assert args[0] is None
args[0] = self_arg
assert None not in args
if self.unwind_target is None:
insn = self.append(ir.Call(func, args, node.arg_exprs))
else:
after_invoke = self.add_block()
insn = self.append(ir.Invoke(func, args, node.arg_exprs,
after_invoke, self.unwind_target))
self.current_block = after_invoke
method_key = None
if isinstance(node.func, asttyped.AttributeT):
attr_node = node.func
self.method_map[(attr_node.value.type.find(), attr_node.attr)].append(insn)

View File

@ -0,0 +1,33 @@
# RUN: %python -m artiq.compiler.testbench.jit %s
# RUN: %python %s
class contextmgr:
def __enter__(self):
print(2)
def __exit__(self, n1, n2, n3):
print(4)
# CHECK-L: a 1
# CHECK-L: 2
# CHECK-L: a 3
# CHECK-L: 4
# CHECK-L: a 5
print("a", 1)
with contextmgr():
print("a", 3)
print("a", 5)
# CHECK-L: b 1
# CHECK-L: 2
# CHECK-L: 4
# CHECK-L: b 6
try:
print("b", 1)
with contextmgr():
[0][1]
print("b", 3)
print("b", 5)
except:
pass
print("b", 6)