forked from M-Labs/artiq
transforms.artiq_ir_generator: add support for user-defined context managers.
This commit is contained in:
parent
d633c8e1f8
commit
6a6d7dab19
|
@ -555,7 +555,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
def visit_Continue(self, node):
|
||||
self.append(ir.Branch(self.continue_target))
|
||||
|
||||
def raise_exn(self, exn, loc=None):
|
||||
def raise_exn(self, exn=None, loc=None):
|
||||
if self.final_branch is not None:
|
||||
raise_proxy = self.add_block("try.raise")
|
||||
self.final_branch(raise_proxy, self.current_block)
|
||||
|
@ -718,19 +718,34 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
if not post_handler.is_terminated():
|
||||
post_handler.append(ir.Branch(tail))
|
||||
|
||||
def visit_With(self, node):
|
||||
if len(node.items) != 1:
|
||||
diag = diagnostic.Diagnostic("fatal",
|
||||
"only one expression per 'with' statement is supported",
|
||||
{"type": types.TypePrinter().name(typ)},
|
||||
node.context_expr.loc)
|
||||
self.engine.process(diag)
|
||||
def _try_finally(self, body_gen, finally_gen, name):
|
||||
dispatcher = self.add_block("{}.dispatch".format(name))
|
||||
|
||||
try:
|
||||
old_unwind, self.unwind_target = self.unwind_target, dispatcher
|
||||
body_gen()
|
||||
finally:
|
||||
self.unwind_target = old_unwind
|
||||
|
||||
if not self.current_block.is_terminated():
|
||||
finally_gen()
|
||||
|
||||
self.post_body = self.current_block
|
||||
|
||||
self.current_block = self.add_block("{}.cleanup".format(name))
|
||||
dispatcher.append(ir.LandingPad(self.current_block))
|
||||
finally_gen()
|
||||
self.raise_exn()
|
||||
|
||||
self.current_block = self.post_body
|
||||
|
||||
def visit_With(self, node):
|
||||
context_expr_node = node.items[0].context_expr
|
||||
optional_vars_node = node.items[0].optional_vars
|
||||
|
||||
if types.is_builtin(context_expr_node.type, "sequential"):
|
||||
self.visit(node.body)
|
||||
return
|
||||
elif types.is_builtin(context_expr_node.type, "parallel"):
|
||||
parallel = self.append(ir.Parallel([]))
|
||||
|
||||
|
@ -748,32 +763,44 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
for tail in tails:
|
||||
if not tail.is_terminated():
|
||||
tail.append(ir.Branch(self.current_block))
|
||||
elif isinstance(context_expr_node, asttyped.CallT) and \
|
||||
types.is_builtin(context_expr_node.func.type, "watchdog"):
|
||||
timeout = self.visit(context_expr_node.args[0])
|
||||
timeout_ms = self.append(ir.Arith(ast.Mult(loc=None), timeout,
|
||||
ir.Constant(1000, builtins.TFloat())))
|
||||
timeout_ms_int = self.append(ir.Coerce(timeout_ms, builtins.TInt32()))
|
||||
return
|
||||
|
||||
watchdog = self.append(ir.Builtin("watchdog_set", [timeout_ms_int], builtins.TInt32()))
|
||||
cleanup = []
|
||||
for item_node in node.items:
|
||||
context_expr_node = item_node.context_expr
|
||||
optional_vars_node = item_node.optional_vars
|
||||
|
||||
dispatcher = self.add_block("watchdog.dispatch")
|
||||
if isinstance(context_expr_node, asttyped.CallT) and \
|
||||
types.is_builtin(context_expr_node.func.type, "watchdog"):
|
||||
timeout = self.visit(context_expr_node.args[0])
|
||||
timeout_ms = self.append(ir.Arith(ast.Mult(loc=None), timeout,
|
||||
ir.Constant(1000, builtins.TFloat())))
|
||||
timeout_ms_int = self.append(ir.Coerce(timeout_ms, builtins.TInt32()))
|
||||
|
||||
try:
|
||||
old_unwind, self.unwind_target = self.unwind_target, dispatcher
|
||||
self.visit(node.body)
|
||||
finally:
|
||||
self.unwind_target = old_unwind
|
||||
watchdog_id = self.append(ir.Builtin("watchdog_set", [timeout_ms_int],
|
||||
builtins.TInt32()))
|
||||
cleanup.append(lambda:
|
||||
self.append(ir.Builtin("watchdog_clear", [watchdog_id], builtins.TNone())))
|
||||
else: # user-defined context manager
|
||||
context_mgr = self.visit(context_expr_node)
|
||||
enter_fn = self._get_attribute(context_mgr, '__enter__')
|
||||
exit_fn = self._get_attribute(context_mgr, '__exit__')
|
||||
|
||||
cleanup = self.add_block('watchdog.cleanup')
|
||||
landingpad = dispatcher.append(ir.LandingPad(cleanup))
|
||||
cleanup.append(ir.Builtin("watchdog_clear", [watchdog], builtins.TNone()))
|
||||
cleanup.append(ir.Reraise(self.unwind_target))
|
||||
try:
|
||||
self.current_assign = self._user_call(enter_fn, [], {})
|
||||
if optional_vars_node is not None:
|
||||
self.visit(optional_vars_node)
|
||||
finally:
|
||||
self.current_assign = None
|
||||
|
||||
if not self.current_block.is_terminated():
|
||||
self.append(ir.Builtin("watchdog_clear", [watchdog], builtins.TNone()))
|
||||
else:
|
||||
assert False
|
||||
none = self.append(ir.Alloc([], builtins.TNone()))
|
||||
cleanup.append(lambda:
|
||||
self._user_call(exit_fn, [none, none, none], {}))
|
||||
|
||||
self._try_finally(
|
||||
body_gen=lambda: self.visit(node.body),
|
||||
finally_gen=lambda: [thunk() for thunk in cleanup],
|
||||
name="with")
|
||||
|
||||
# Expression visitors
|
||||
# These visitors return a node in addition to mutating
|
||||
|
@ -850,14 +877,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
else:
|
||||
return self._set_local(node.id, self.current_assign)
|
||||
|
||||
def visit_AttributeT(self, node):
|
||||
try:
|
||||
old_assign, self.current_assign = self.current_assign, None
|
||||
obj = self.visit(node.value)
|
||||
finally:
|
||||
self.current_assign = old_assign
|
||||
|
||||
if node.attr not in obj.type.find().attributes:
|
||||
def _get_attribute(self, obj, attr_name):
|
||||
if attr_name not in obj.type.find().attributes:
|
||||
# A class attribute. Get the constructor (class object) and
|
||||
# extract the attribute from it.
|
||||
constr_type = obj.type.constructor
|
||||
|
@ -865,16 +886,26 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
constr_type.name, constr_type,
|
||||
name="constructor." + constr_type.name))
|
||||
|
||||
if types.is_function(constr.type.attributes[node.attr]):
|
||||
if types.is_function(constr.type.attributes[attr_name]):
|
||||
# A method. Construct a method object instead.
|
||||
func = self.append(ir.GetAttr(constr, node.attr))
|
||||
return self.append(ir.Alloc([func, obj], node.type))
|
||||
func = self.append(ir.GetAttr(constr, attr_name))
|
||||
return self.append(ir.Alloc([func, obj],
|
||||
types.TMethod(obj.type, func.type)))
|
||||
else:
|
||||
obj = constr
|
||||
|
||||
return self.append(ir.GetAttr(obj, attr_name,
|
||||
name="{}.{}".format(_readable_name(obj), attr_name)))
|
||||
|
||||
def visit_AttributeT(self, node):
|
||||
try:
|
||||
old_assign, self.current_assign = self.current_assign, None
|
||||
obj = self.visit(node.value)
|
||||
finally:
|
||||
self.current_assign = old_assign
|
||||
|
||||
if self.current_assign is None:
|
||||
return self.append(ir.GetAttr(obj, node.attr,
|
||||
name="{}.{}".format(_readable_name(obj), node.attr)))
|
||||
return self._get_attribute(obj, node.attr)
|
||||
elif types.is_rpc_function(self.current_assign.type):
|
||||
# RPC functions are just type-level markers
|
||||
return self.append(ir.Builtin("nop", [], builtins.TNone()))
|
||||
|
@ -1624,8 +1655,70 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
node.loc)
|
||||
self.engine.process(diag)
|
||||
|
||||
def _user_call(self, callee, positional, keywords, arg_exprs={}):
|
||||
if types.is_function(callee.type):
|
||||
func = callee
|
||||
self_arg = None
|
||||
fn_typ = callee.type
|
||||
offset = 0
|
||||
elif types.is_method(callee.type):
|
||||
func = self.append(ir.GetAttr(callee, "__func__"))
|
||||
self_arg = self.append(ir.GetAttr(callee, "__self__"))
|
||||
fn_typ = types.get_method_function(callee.type)
|
||||
offset = 1
|
||||
else:
|
||||
assert False
|
||||
|
||||
args = [None] * (len(fn_typ.args) + len(fn_typ.optargs))
|
||||
|
||||
for index, arg in enumerate(positional):
|
||||
if index + offset < len(fn_typ.args):
|
||||
args[index + offset] = arg
|
||||
else:
|
||||
args[index + offset] = self.append(ir.Alloc([arg], ir.TOption(arg.type)))
|
||||
|
||||
for keyword in keywords:
|
||||
arg = keywords[keyword]
|
||||
if keyword in fn_typ.args:
|
||||
for index, arg_name in enumerate(fn_typ.args):
|
||||
if keyword == arg_name:
|
||||
assert args[index] is None
|
||||
args[index] = arg
|
||||
break
|
||||
elif keyword in fn_typ.optargs:
|
||||
for index, optarg_name in enumerate(fn_typ.optargs):
|
||||
if keyword == optarg_name:
|
||||
assert args[len(fn_typ.args) + index] is None
|
||||
args[len(fn_typ.args) + index] = \
|
||||
self.append(ir.Alloc([arg], ir.TOption(arg.type)))
|
||||
break
|
||||
|
||||
for index, optarg_name in enumerate(fn_typ.optargs):
|
||||
if args[len(fn_typ.args) + index] is None:
|
||||
args[len(fn_typ.args) + index] = \
|
||||
self.append(ir.Alloc([], ir.TOption(fn_typ.optargs[optarg_name])))
|
||||
|
||||
if self_arg is not None:
|
||||
assert args[0] is None
|
||||
args[0] = self_arg
|
||||
|
||||
assert None not in args
|
||||
|
||||
if self.unwind_target is None:
|
||||
insn = self.append(ir.Call(func, args, arg_exprs))
|
||||
else:
|
||||
after_invoke = self.add_block()
|
||||
insn = self.append(ir.Invoke(func, args, arg_exprs,
|
||||
after_invoke, self.unwind_target))
|
||||
self.current_block = after_invoke
|
||||
|
||||
return insn
|
||||
|
||||
def visit_CallT(self, node):
|
||||
typ = node.func.type.find()
|
||||
if not types.is_builtin(node.func.type):
|
||||
callee = self.visit(node.func)
|
||||
args = [self.visit(arg_node) for arg_node in node.args]
|
||||
keywords = {kw_node.arg: self.visit(kw_node.value) for kw_node in node.keywords}
|
||||
|
||||
if node.iodelay is not None and not iodelay.is_const(node.iodelay, 0):
|
||||
before_delay = self.current_block
|
||||
|
@ -1633,68 +1726,11 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
before_delay.append(ir.Branch(during_delay))
|
||||
self.current_block = during_delay
|
||||
|
||||
if types.is_builtin(typ):
|
||||
if types.is_builtin(node.func.type):
|
||||
insn = self.visit_builtin_call(node)
|
||||
else:
|
||||
if types.is_function(typ):
|
||||
func = self.visit(node.func)
|
||||
self_arg = None
|
||||
fn_typ = typ
|
||||
offset = 0
|
||||
elif types.is_method(typ):
|
||||
method = self.visit(node.func)
|
||||
func = self.append(ir.GetAttr(method, "__func__"))
|
||||
self_arg = self.append(ir.GetAttr(method, "__self__"))
|
||||
fn_typ = types.get_method_function(typ)
|
||||
offset = 1
|
||||
else:
|
||||
assert False
|
||||
insn = self._user_call(callee, args, keywords, node.arg_exprs)
|
||||
|
||||
args = [None] * (len(fn_typ.args) + len(fn_typ.optargs))
|
||||
|
||||
for index, arg_node in enumerate(node.args):
|
||||
arg = self.visit(arg_node)
|
||||
if index + offset < len(fn_typ.args):
|
||||
args[index + offset] = arg
|
||||
else:
|
||||
args[index + offset] = self.append(ir.Alloc([arg], ir.TOption(arg.type)))
|
||||
|
||||
for keyword in node.keywords:
|
||||
arg = self.visit(keyword.value)
|
||||
if keyword.arg in fn_typ.args:
|
||||
for index, arg_name in enumerate(fn_typ.args):
|
||||
if keyword.arg == arg_name:
|
||||
assert args[index] is None
|
||||
args[index] = arg
|
||||
break
|
||||
elif keyword.arg in fn_typ.optargs:
|
||||
for index, optarg_name in enumerate(fn_typ.optargs):
|
||||
if keyword.arg == optarg_name:
|
||||
assert args[len(fn_typ.args) + index] is None
|
||||
args[len(fn_typ.args) + index] = \
|
||||
self.append(ir.Alloc([arg], ir.TOption(arg.type)))
|
||||
break
|
||||
|
||||
for index, optarg_name in enumerate(fn_typ.optargs):
|
||||
if args[len(fn_typ.args) + index] is None:
|
||||
args[len(fn_typ.args) + index] = \
|
||||
self.append(ir.Alloc([], ir.TOption(fn_typ.optargs[optarg_name])))
|
||||
|
||||
if self_arg is not None:
|
||||
assert args[0] is None
|
||||
args[0] = self_arg
|
||||
|
||||
assert None not in args
|
||||
|
||||
if self.unwind_target is None:
|
||||
insn = self.append(ir.Call(func, args, node.arg_exprs))
|
||||
else:
|
||||
after_invoke = self.add_block()
|
||||
insn = self.append(ir.Invoke(func, args, node.arg_exprs,
|
||||
after_invoke, self.unwind_target))
|
||||
self.current_block = after_invoke
|
||||
|
||||
method_key = None
|
||||
if isinstance(node.func, asttyped.AttributeT):
|
||||
attr_node = node.func
|
||||
self.method_map[(attr_node.value.type.find(), attr_node.attr)].append(insn)
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
# RUN: %python -m artiq.compiler.testbench.jit %s
|
||||
# RUN: %python %s
|
||||
|
||||
class contextmgr:
|
||||
def __enter__(self):
|
||||
print(2)
|
||||
|
||||
def __exit__(self, n1, n2, n3):
|
||||
print(4)
|
||||
|
||||
# CHECK-L: a 1
|
||||
# CHECK-L: 2
|
||||
# CHECK-L: a 3
|
||||
# CHECK-L: 4
|
||||
# CHECK-L: a 5
|
||||
print("a", 1)
|
||||
with contextmgr():
|
||||
print("a", 3)
|
||||
print("a", 5)
|
||||
|
||||
# CHECK-L: b 1
|
||||
# CHECK-L: 2
|
||||
# CHECK-L: 4
|
||||
# CHECK-L: b 6
|
||||
try:
|
||||
print("b", 1)
|
||||
with contextmgr():
|
||||
[0][1]
|
||||
print("b", 3)
|
||||
print("b", 5)
|
||||
except:
|
||||
pass
|
||||
print("b", 6)
|
Loading…
Reference in New Issue