forked from M-Labs/artiq
transforms.artiq_ir_generator: add support for user-defined context managers.
This commit is contained in:
parent
d633c8e1f8
commit
6a6d7dab19
|
@ -555,7 +555,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
def visit_Continue(self, node):
|
def visit_Continue(self, node):
|
||||||
self.append(ir.Branch(self.continue_target))
|
self.append(ir.Branch(self.continue_target))
|
||||||
|
|
||||||
def raise_exn(self, exn, loc=None):
|
def raise_exn(self, exn=None, loc=None):
|
||||||
if self.final_branch is not None:
|
if self.final_branch is not None:
|
||||||
raise_proxy = self.add_block("try.raise")
|
raise_proxy = self.add_block("try.raise")
|
||||||
self.final_branch(raise_proxy, self.current_block)
|
self.final_branch(raise_proxy, self.current_block)
|
||||||
|
@ -718,19 +718,34 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
if not post_handler.is_terminated():
|
if not post_handler.is_terminated():
|
||||||
post_handler.append(ir.Branch(tail))
|
post_handler.append(ir.Branch(tail))
|
||||||
|
|
||||||
def visit_With(self, node):
|
def _try_finally(self, body_gen, finally_gen, name):
|
||||||
if len(node.items) != 1:
|
dispatcher = self.add_block("{}.dispatch".format(name))
|
||||||
diag = diagnostic.Diagnostic("fatal",
|
|
||||||
"only one expression per 'with' statement is supported",
|
|
||||||
{"type": types.TypePrinter().name(typ)},
|
|
||||||
node.context_expr.loc)
|
|
||||||
self.engine.process(diag)
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
old_unwind, self.unwind_target = self.unwind_target, dispatcher
|
||||||
|
body_gen()
|
||||||
|
finally:
|
||||||
|
self.unwind_target = old_unwind
|
||||||
|
|
||||||
|
if not self.current_block.is_terminated():
|
||||||
|
finally_gen()
|
||||||
|
|
||||||
|
self.post_body = self.current_block
|
||||||
|
|
||||||
|
self.current_block = self.add_block("{}.cleanup".format(name))
|
||||||
|
dispatcher.append(ir.LandingPad(self.current_block))
|
||||||
|
finally_gen()
|
||||||
|
self.raise_exn()
|
||||||
|
|
||||||
|
self.current_block = self.post_body
|
||||||
|
|
||||||
|
def visit_With(self, node):
|
||||||
context_expr_node = node.items[0].context_expr
|
context_expr_node = node.items[0].context_expr
|
||||||
optional_vars_node = node.items[0].optional_vars
|
optional_vars_node = node.items[0].optional_vars
|
||||||
|
|
||||||
if types.is_builtin(context_expr_node.type, "sequential"):
|
if types.is_builtin(context_expr_node.type, "sequential"):
|
||||||
self.visit(node.body)
|
self.visit(node.body)
|
||||||
|
return
|
||||||
elif types.is_builtin(context_expr_node.type, "parallel"):
|
elif types.is_builtin(context_expr_node.type, "parallel"):
|
||||||
parallel = self.append(ir.Parallel([]))
|
parallel = self.append(ir.Parallel([]))
|
||||||
|
|
||||||
|
@ -748,32 +763,44 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
for tail in tails:
|
for tail in tails:
|
||||||
if not tail.is_terminated():
|
if not tail.is_terminated():
|
||||||
tail.append(ir.Branch(self.current_block))
|
tail.append(ir.Branch(self.current_block))
|
||||||
elif isinstance(context_expr_node, asttyped.CallT) and \
|
return
|
||||||
|
|
||||||
|
cleanup = []
|
||||||
|
for item_node in node.items:
|
||||||
|
context_expr_node = item_node.context_expr
|
||||||
|
optional_vars_node = item_node.optional_vars
|
||||||
|
|
||||||
|
if isinstance(context_expr_node, asttyped.CallT) and \
|
||||||
types.is_builtin(context_expr_node.func.type, "watchdog"):
|
types.is_builtin(context_expr_node.func.type, "watchdog"):
|
||||||
timeout = self.visit(context_expr_node.args[0])
|
timeout = self.visit(context_expr_node.args[0])
|
||||||
timeout_ms = self.append(ir.Arith(ast.Mult(loc=None), timeout,
|
timeout_ms = self.append(ir.Arith(ast.Mult(loc=None), timeout,
|
||||||
ir.Constant(1000, builtins.TFloat())))
|
ir.Constant(1000, builtins.TFloat())))
|
||||||
timeout_ms_int = self.append(ir.Coerce(timeout_ms, builtins.TInt32()))
|
timeout_ms_int = self.append(ir.Coerce(timeout_ms, builtins.TInt32()))
|
||||||
|
|
||||||
watchdog = self.append(ir.Builtin("watchdog_set", [timeout_ms_int], builtins.TInt32()))
|
watchdog_id = self.append(ir.Builtin("watchdog_set", [timeout_ms_int],
|
||||||
|
builtins.TInt32()))
|
||||||
dispatcher = self.add_block("watchdog.dispatch")
|
cleanup.append(lambda:
|
||||||
|
self.append(ir.Builtin("watchdog_clear", [watchdog_id], builtins.TNone())))
|
||||||
|
else: # user-defined context manager
|
||||||
|
context_mgr = self.visit(context_expr_node)
|
||||||
|
enter_fn = self._get_attribute(context_mgr, '__enter__')
|
||||||
|
exit_fn = self._get_attribute(context_mgr, '__exit__')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
old_unwind, self.unwind_target = self.unwind_target, dispatcher
|
self.current_assign = self._user_call(enter_fn, [], {})
|
||||||
self.visit(node.body)
|
if optional_vars_node is not None:
|
||||||
|
self.visit(optional_vars_node)
|
||||||
finally:
|
finally:
|
||||||
self.unwind_target = old_unwind
|
self.current_assign = None
|
||||||
|
|
||||||
cleanup = self.add_block('watchdog.cleanup')
|
none = self.append(ir.Alloc([], builtins.TNone()))
|
||||||
landingpad = dispatcher.append(ir.LandingPad(cleanup))
|
cleanup.append(lambda:
|
||||||
cleanup.append(ir.Builtin("watchdog_clear", [watchdog], builtins.TNone()))
|
self._user_call(exit_fn, [none, none, none], {}))
|
||||||
cleanup.append(ir.Reraise(self.unwind_target))
|
|
||||||
|
|
||||||
if not self.current_block.is_terminated():
|
self._try_finally(
|
||||||
self.append(ir.Builtin("watchdog_clear", [watchdog], builtins.TNone()))
|
body_gen=lambda: self.visit(node.body),
|
||||||
else:
|
finally_gen=lambda: [thunk() for thunk in cleanup],
|
||||||
assert False
|
name="with")
|
||||||
|
|
||||||
# Expression visitors
|
# Expression visitors
|
||||||
# These visitors return a node in addition to mutating
|
# These visitors return a node in addition to mutating
|
||||||
|
@ -850,14 +877,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
else:
|
else:
|
||||||
return self._set_local(node.id, self.current_assign)
|
return self._set_local(node.id, self.current_assign)
|
||||||
|
|
||||||
def visit_AttributeT(self, node):
|
def _get_attribute(self, obj, attr_name):
|
||||||
try:
|
if attr_name not in obj.type.find().attributes:
|
||||||
old_assign, self.current_assign = self.current_assign, None
|
|
||||||
obj = self.visit(node.value)
|
|
||||||
finally:
|
|
||||||
self.current_assign = old_assign
|
|
||||||
|
|
||||||
if node.attr not in obj.type.find().attributes:
|
|
||||||
# A class attribute. Get the constructor (class object) and
|
# A class attribute. Get the constructor (class object) and
|
||||||
# extract the attribute from it.
|
# extract the attribute from it.
|
||||||
constr_type = obj.type.constructor
|
constr_type = obj.type.constructor
|
||||||
|
@ -865,16 +886,26 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
constr_type.name, constr_type,
|
constr_type.name, constr_type,
|
||||||
name="constructor." + constr_type.name))
|
name="constructor." + constr_type.name))
|
||||||
|
|
||||||
if types.is_function(constr.type.attributes[node.attr]):
|
if types.is_function(constr.type.attributes[attr_name]):
|
||||||
# A method. Construct a method object instead.
|
# A method. Construct a method object instead.
|
||||||
func = self.append(ir.GetAttr(constr, node.attr))
|
func = self.append(ir.GetAttr(constr, attr_name))
|
||||||
return self.append(ir.Alloc([func, obj], node.type))
|
return self.append(ir.Alloc([func, obj],
|
||||||
|
types.TMethod(obj.type, func.type)))
|
||||||
else:
|
else:
|
||||||
obj = constr
|
obj = constr
|
||||||
|
|
||||||
|
return self.append(ir.GetAttr(obj, attr_name,
|
||||||
|
name="{}.{}".format(_readable_name(obj), attr_name)))
|
||||||
|
|
||||||
|
def visit_AttributeT(self, node):
|
||||||
|
try:
|
||||||
|
old_assign, self.current_assign = self.current_assign, None
|
||||||
|
obj = self.visit(node.value)
|
||||||
|
finally:
|
||||||
|
self.current_assign = old_assign
|
||||||
|
|
||||||
if self.current_assign is None:
|
if self.current_assign is None:
|
||||||
return self.append(ir.GetAttr(obj, node.attr,
|
return self._get_attribute(obj, node.attr)
|
||||||
name="{}.{}".format(_readable_name(obj), node.attr)))
|
|
||||||
elif types.is_rpc_function(self.current_assign.type):
|
elif types.is_rpc_function(self.current_assign.type):
|
||||||
# RPC functions are just type-level markers
|
# RPC functions are just type-level markers
|
||||||
return self.append(ir.Builtin("nop", [], builtins.TNone()))
|
return self.append(ir.Builtin("nop", [], builtins.TNone()))
|
||||||
|
@ -1624,52 +1655,39 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
node.loc)
|
node.loc)
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
|
|
||||||
def visit_CallT(self, node):
|
def _user_call(self, callee, positional, keywords, arg_exprs={}):
|
||||||
typ = node.func.type.find()
|
if types.is_function(callee.type):
|
||||||
|
func = callee
|
||||||
if node.iodelay is not None and not iodelay.is_const(node.iodelay, 0):
|
|
||||||
before_delay = self.current_block
|
|
||||||
during_delay = self.add_block()
|
|
||||||
before_delay.append(ir.Branch(during_delay))
|
|
||||||
self.current_block = during_delay
|
|
||||||
|
|
||||||
if types.is_builtin(typ):
|
|
||||||
insn = self.visit_builtin_call(node)
|
|
||||||
else:
|
|
||||||
if types.is_function(typ):
|
|
||||||
func = self.visit(node.func)
|
|
||||||
self_arg = None
|
self_arg = None
|
||||||
fn_typ = typ
|
fn_typ = callee.type
|
||||||
offset = 0
|
offset = 0
|
||||||
elif types.is_method(typ):
|
elif types.is_method(callee.type):
|
||||||
method = self.visit(node.func)
|
func = self.append(ir.GetAttr(callee, "__func__"))
|
||||||
func = self.append(ir.GetAttr(method, "__func__"))
|
self_arg = self.append(ir.GetAttr(callee, "__self__"))
|
||||||
self_arg = self.append(ir.GetAttr(method, "__self__"))
|
fn_typ = types.get_method_function(callee.type)
|
||||||
fn_typ = types.get_method_function(typ)
|
|
||||||
offset = 1
|
offset = 1
|
||||||
else:
|
else:
|
||||||
assert False
|
assert False
|
||||||
|
|
||||||
args = [None] * (len(fn_typ.args) + len(fn_typ.optargs))
|
args = [None] * (len(fn_typ.args) + len(fn_typ.optargs))
|
||||||
|
|
||||||
for index, arg_node in enumerate(node.args):
|
for index, arg in enumerate(positional):
|
||||||
arg = self.visit(arg_node)
|
|
||||||
if index + offset < len(fn_typ.args):
|
if index + offset < len(fn_typ.args):
|
||||||
args[index + offset] = arg
|
args[index + offset] = arg
|
||||||
else:
|
else:
|
||||||
args[index + offset] = self.append(ir.Alloc([arg], ir.TOption(arg.type)))
|
args[index + offset] = self.append(ir.Alloc([arg], ir.TOption(arg.type)))
|
||||||
|
|
||||||
for keyword in node.keywords:
|
for keyword in keywords:
|
||||||
arg = self.visit(keyword.value)
|
arg = keywords[keyword]
|
||||||
if keyword.arg in fn_typ.args:
|
if keyword in fn_typ.args:
|
||||||
for index, arg_name in enumerate(fn_typ.args):
|
for index, arg_name in enumerate(fn_typ.args):
|
||||||
if keyword.arg == arg_name:
|
if keyword == arg_name:
|
||||||
assert args[index] is None
|
assert args[index] is None
|
||||||
args[index] = arg
|
args[index] = arg
|
||||||
break
|
break
|
||||||
elif keyword.arg in fn_typ.optargs:
|
elif keyword in fn_typ.optargs:
|
||||||
for index, optarg_name in enumerate(fn_typ.optargs):
|
for index, optarg_name in enumerate(fn_typ.optargs):
|
||||||
if keyword.arg == optarg_name:
|
if keyword == optarg_name:
|
||||||
assert args[len(fn_typ.args) + index] is None
|
assert args[len(fn_typ.args) + index] is None
|
||||||
args[len(fn_typ.args) + index] = \
|
args[len(fn_typ.args) + index] = \
|
||||||
self.append(ir.Alloc([arg], ir.TOption(arg.type)))
|
self.append(ir.Alloc([arg], ir.TOption(arg.type)))
|
||||||
|
@ -1687,14 +1705,32 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
assert None not in args
|
assert None not in args
|
||||||
|
|
||||||
if self.unwind_target is None:
|
if self.unwind_target is None:
|
||||||
insn = self.append(ir.Call(func, args, node.arg_exprs))
|
insn = self.append(ir.Call(func, args, arg_exprs))
|
||||||
else:
|
else:
|
||||||
after_invoke = self.add_block()
|
after_invoke = self.add_block()
|
||||||
insn = self.append(ir.Invoke(func, args, node.arg_exprs,
|
insn = self.append(ir.Invoke(func, args, arg_exprs,
|
||||||
after_invoke, self.unwind_target))
|
after_invoke, self.unwind_target))
|
||||||
self.current_block = after_invoke
|
self.current_block = after_invoke
|
||||||
|
|
||||||
method_key = None
|
return insn
|
||||||
|
|
||||||
|
def visit_CallT(self, node):
|
||||||
|
if not types.is_builtin(node.func.type):
|
||||||
|
callee = self.visit(node.func)
|
||||||
|
args = [self.visit(arg_node) for arg_node in node.args]
|
||||||
|
keywords = {kw_node.arg: self.visit(kw_node.value) for kw_node in node.keywords}
|
||||||
|
|
||||||
|
if node.iodelay is not None and not iodelay.is_const(node.iodelay, 0):
|
||||||
|
before_delay = self.current_block
|
||||||
|
during_delay = self.add_block()
|
||||||
|
before_delay.append(ir.Branch(during_delay))
|
||||||
|
self.current_block = during_delay
|
||||||
|
|
||||||
|
if types.is_builtin(node.func.type):
|
||||||
|
insn = self.visit_builtin_call(node)
|
||||||
|
else:
|
||||||
|
insn = self._user_call(callee, args, keywords, node.arg_exprs)
|
||||||
|
|
||||||
if isinstance(node.func, asttyped.AttributeT):
|
if isinstance(node.func, asttyped.AttributeT):
|
||||||
attr_node = node.func
|
attr_node = node.func
|
||||||
self.method_map[(attr_node.value.type.find(), attr_node.attr)].append(insn)
|
self.method_map[(attr_node.value.type.find(), attr_node.attr)].append(insn)
|
||||||
|
|
|
@ -0,0 +1,33 @@
|
||||||
|
# RUN: %python -m artiq.compiler.testbench.jit %s
|
||||||
|
# RUN: %python %s
|
||||||
|
|
||||||
|
class contextmgr:
|
||||||
|
def __enter__(self):
|
||||||
|
print(2)
|
||||||
|
|
||||||
|
def __exit__(self, n1, n2, n3):
|
||||||
|
print(4)
|
||||||
|
|
||||||
|
# CHECK-L: a 1
|
||||||
|
# CHECK-L: 2
|
||||||
|
# CHECK-L: a 3
|
||||||
|
# CHECK-L: 4
|
||||||
|
# CHECK-L: a 5
|
||||||
|
print("a", 1)
|
||||||
|
with contextmgr():
|
||||||
|
print("a", 3)
|
||||||
|
print("a", 5)
|
||||||
|
|
||||||
|
# CHECK-L: b 1
|
||||||
|
# CHECK-L: 2
|
||||||
|
# CHECK-L: 4
|
||||||
|
# CHECK-L: b 6
|
||||||
|
try:
|
||||||
|
print("b", 1)
|
||||||
|
with contextmgr():
|
||||||
|
[0][1]
|
||||||
|
print("b", 3)
|
||||||
|
print("b", 5)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
print("b", 6)
|
Loading…
Reference in New Issue