forked from M-Labs/artiq
1
0
Fork 0

transforms.artiq_ir_generator: add support for user-defined context managers.

This commit is contained in:
whitequark 2016-01-05 04:10:31 +00:00
parent d633c8e1f8
commit 6a6d7dab19
2 changed files with 171 additions and 102 deletions

View File

@ -555,7 +555,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
def visit_Continue(self, node): def visit_Continue(self, node):
self.append(ir.Branch(self.continue_target)) self.append(ir.Branch(self.continue_target))
def raise_exn(self, exn, loc=None): def raise_exn(self, exn=None, loc=None):
if self.final_branch is not None: if self.final_branch is not None:
raise_proxy = self.add_block("try.raise") raise_proxy = self.add_block("try.raise")
self.final_branch(raise_proxy, self.current_block) self.final_branch(raise_proxy, self.current_block)
@ -718,19 +718,34 @@ class ARTIQIRGenerator(algorithm.Visitor):
if not post_handler.is_terminated(): if not post_handler.is_terminated():
post_handler.append(ir.Branch(tail)) post_handler.append(ir.Branch(tail))
def visit_With(self, node): def _try_finally(self, body_gen, finally_gen, name):
if len(node.items) != 1: dispatcher = self.add_block("{}.dispatch".format(name))
diag = diagnostic.Diagnostic("fatal",
"only one expression per 'with' statement is supported",
{"type": types.TypePrinter().name(typ)},
node.context_expr.loc)
self.engine.process(diag)
try:
old_unwind, self.unwind_target = self.unwind_target, dispatcher
body_gen()
finally:
self.unwind_target = old_unwind
if not self.current_block.is_terminated():
finally_gen()
self.post_body = self.current_block
self.current_block = self.add_block("{}.cleanup".format(name))
dispatcher.append(ir.LandingPad(self.current_block))
finally_gen()
self.raise_exn()
self.current_block = self.post_body
def visit_With(self, node):
context_expr_node = node.items[0].context_expr context_expr_node = node.items[0].context_expr
optional_vars_node = node.items[0].optional_vars optional_vars_node = node.items[0].optional_vars
if types.is_builtin(context_expr_node.type, "sequential"): if types.is_builtin(context_expr_node.type, "sequential"):
self.visit(node.body) self.visit(node.body)
return
elif types.is_builtin(context_expr_node.type, "parallel"): elif types.is_builtin(context_expr_node.type, "parallel"):
parallel = self.append(ir.Parallel([])) parallel = self.append(ir.Parallel([]))
@ -748,32 +763,44 @@ class ARTIQIRGenerator(algorithm.Visitor):
for tail in tails: for tail in tails:
if not tail.is_terminated(): if not tail.is_terminated():
tail.append(ir.Branch(self.current_block)) tail.append(ir.Branch(self.current_block))
elif isinstance(context_expr_node, asttyped.CallT) and \ return
types.is_builtin(context_expr_node.func.type, "watchdog"):
timeout = self.visit(context_expr_node.args[0])
timeout_ms = self.append(ir.Arith(ast.Mult(loc=None), timeout,
ir.Constant(1000, builtins.TFloat())))
timeout_ms_int = self.append(ir.Coerce(timeout_ms, builtins.TInt32()))
watchdog = self.append(ir.Builtin("watchdog_set", [timeout_ms_int], builtins.TInt32())) cleanup = []
for item_node in node.items:
context_expr_node = item_node.context_expr
optional_vars_node = item_node.optional_vars
dispatcher = self.add_block("watchdog.dispatch") if isinstance(context_expr_node, asttyped.CallT) and \
types.is_builtin(context_expr_node.func.type, "watchdog"):
timeout = self.visit(context_expr_node.args[0])
timeout_ms = self.append(ir.Arith(ast.Mult(loc=None), timeout,
ir.Constant(1000, builtins.TFloat())))
timeout_ms_int = self.append(ir.Coerce(timeout_ms, builtins.TInt32()))
try: watchdog_id = self.append(ir.Builtin("watchdog_set", [timeout_ms_int],
old_unwind, self.unwind_target = self.unwind_target, dispatcher builtins.TInt32()))
self.visit(node.body) cleanup.append(lambda:
finally: self.append(ir.Builtin("watchdog_clear", [watchdog_id], builtins.TNone())))
self.unwind_target = old_unwind else: # user-defined context manager
context_mgr = self.visit(context_expr_node)
enter_fn = self._get_attribute(context_mgr, '__enter__')
exit_fn = self._get_attribute(context_mgr, '__exit__')
cleanup = self.add_block('watchdog.cleanup') try:
landingpad = dispatcher.append(ir.LandingPad(cleanup)) self.current_assign = self._user_call(enter_fn, [], {})
cleanup.append(ir.Builtin("watchdog_clear", [watchdog], builtins.TNone())) if optional_vars_node is not None:
cleanup.append(ir.Reraise(self.unwind_target)) self.visit(optional_vars_node)
finally:
self.current_assign = None
if not self.current_block.is_terminated(): none = self.append(ir.Alloc([], builtins.TNone()))
self.append(ir.Builtin("watchdog_clear", [watchdog], builtins.TNone())) cleanup.append(lambda:
else: self._user_call(exit_fn, [none, none, none], {}))
assert False
self._try_finally(
body_gen=lambda: self.visit(node.body),
finally_gen=lambda: [thunk() for thunk in cleanup],
name="with")
# Expression visitors # Expression visitors
# These visitors return a node in addition to mutating # These visitors return a node in addition to mutating
@ -850,14 +877,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
else: else:
return self._set_local(node.id, self.current_assign) return self._set_local(node.id, self.current_assign)
def visit_AttributeT(self, node): def _get_attribute(self, obj, attr_name):
try: if attr_name not in obj.type.find().attributes:
old_assign, self.current_assign = self.current_assign, None
obj = self.visit(node.value)
finally:
self.current_assign = old_assign
if node.attr not in obj.type.find().attributes:
# A class attribute. Get the constructor (class object) and # A class attribute. Get the constructor (class object) and
# extract the attribute from it. # extract the attribute from it.
constr_type = obj.type.constructor constr_type = obj.type.constructor
@ -865,16 +886,26 @@ class ARTIQIRGenerator(algorithm.Visitor):
constr_type.name, constr_type, constr_type.name, constr_type,
name="constructor." + constr_type.name)) name="constructor." + constr_type.name))
if types.is_function(constr.type.attributes[node.attr]): if types.is_function(constr.type.attributes[attr_name]):
# A method. Construct a method object instead. # A method. Construct a method object instead.
func = self.append(ir.GetAttr(constr, node.attr)) func = self.append(ir.GetAttr(constr, attr_name))
return self.append(ir.Alloc([func, obj], node.type)) return self.append(ir.Alloc([func, obj],
types.TMethod(obj.type, func.type)))
else: else:
obj = constr obj = constr
return self.append(ir.GetAttr(obj, attr_name,
name="{}.{}".format(_readable_name(obj), attr_name)))
def visit_AttributeT(self, node):
try:
old_assign, self.current_assign = self.current_assign, None
obj = self.visit(node.value)
finally:
self.current_assign = old_assign
if self.current_assign is None: if self.current_assign is None:
return self.append(ir.GetAttr(obj, node.attr, return self._get_attribute(obj, node.attr)
name="{}.{}".format(_readable_name(obj), node.attr)))
elif types.is_rpc_function(self.current_assign.type): elif types.is_rpc_function(self.current_assign.type):
# RPC functions are just type-level markers # RPC functions are just type-level markers
return self.append(ir.Builtin("nop", [], builtins.TNone())) return self.append(ir.Builtin("nop", [], builtins.TNone()))
@ -1624,8 +1655,70 @@ class ARTIQIRGenerator(algorithm.Visitor):
node.loc) node.loc)
self.engine.process(diag) self.engine.process(diag)
def _user_call(self, callee, positional, keywords, arg_exprs={}):
if types.is_function(callee.type):
func = callee
self_arg = None
fn_typ = callee.type
offset = 0
elif types.is_method(callee.type):
func = self.append(ir.GetAttr(callee, "__func__"))
self_arg = self.append(ir.GetAttr(callee, "__self__"))
fn_typ = types.get_method_function(callee.type)
offset = 1
else:
assert False
args = [None] * (len(fn_typ.args) + len(fn_typ.optargs))
for index, arg in enumerate(positional):
if index + offset < len(fn_typ.args):
args[index + offset] = arg
else:
args[index + offset] = self.append(ir.Alloc([arg], ir.TOption(arg.type)))
for keyword in keywords:
arg = keywords[keyword]
if keyword in fn_typ.args:
for index, arg_name in enumerate(fn_typ.args):
if keyword == arg_name:
assert args[index] is None
args[index] = arg
break
elif keyword in fn_typ.optargs:
for index, optarg_name in enumerate(fn_typ.optargs):
if keyword == optarg_name:
assert args[len(fn_typ.args) + index] is None
args[len(fn_typ.args) + index] = \
self.append(ir.Alloc([arg], ir.TOption(arg.type)))
break
for index, optarg_name in enumerate(fn_typ.optargs):
if args[len(fn_typ.args) + index] is None:
args[len(fn_typ.args) + index] = \
self.append(ir.Alloc([], ir.TOption(fn_typ.optargs[optarg_name])))
if self_arg is not None:
assert args[0] is None
args[0] = self_arg
assert None not in args
if self.unwind_target is None:
insn = self.append(ir.Call(func, args, arg_exprs))
else:
after_invoke = self.add_block()
insn = self.append(ir.Invoke(func, args, arg_exprs,
after_invoke, self.unwind_target))
self.current_block = after_invoke
return insn
def visit_CallT(self, node): def visit_CallT(self, node):
typ = node.func.type.find() if not types.is_builtin(node.func.type):
callee = self.visit(node.func)
args = [self.visit(arg_node) for arg_node in node.args]
keywords = {kw_node.arg: self.visit(kw_node.value) for kw_node in node.keywords}
if node.iodelay is not None and not iodelay.is_const(node.iodelay, 0): if node.iodelay is not None and not iodelay.is_const(node.iodelay, 0):
before_delay = self.current_block before_delay = self.current_block
@ -1633,68 +1726,11 @@ class ARTIQIRGenerator(algorithm.Visitor):
before_delay.append(ir.Branch(during_delay)) before_delay.append(ir.Branch(during_delay))
self.current_block = during_delay self.current_block = during_delay
if types.is_builtin(typ): if types.is_builtin(node.func.type):
insn = self.visit_builtin_call(node) insn = self.visit_builtin_call(node)
else: else:
if types.is_function(typ): insn = self._user_call(callee, args, keywords, node.arg_exprs)
func = self.visit(node.func)
self_arg = None
fn_typ = typ
offset = 0
elif types.is_method(typ):
method = self.visit(node.func)
func = self.append(ir.GetAttr(method, "__func__"))
self_arg = self.append(ir.GetAttr(method, "__self__"))
fn_typ = types.get_method_function(typ)
offset = 1
else:
assert False
args = [None] * (len(fn_typ.args) + len(fn_typ.optargs))
for index, arg_node in enumerate(node.args):
arg = self.visit(arg_node)
if index + offset < len(fn_typ.args):
args[index + offset] = arg
else:
args[index + offset] = self.append(ir.Alloc([arg], ir.TOption(arg.type)))
for keyword in node.keywords:
arg = self.visit(keyword.value)
if keyword.arg in fn_typ.args:
for index, arg_name in enumerate(fn_typ.args):
if keyword.arg == arg_name:
assert args[index] is None
args[index] = arg
break
elif keyword.arg in fn_typ.optargs:
for index, optarg_name in enumerate(fn_typ.optargs):
if keyword.arg == optarg_name:
assert args[len(fn_typ.args) + index] is None
args[len(fn_typ.args) + index] = \
self.append(ir.Alloc([arg], ir.TOption(arg.type)))
break
for index, optarg_name in enumerate(fn_typ.optargs):
if args[len(fn_typ.args) + index] is None:
args[len(fn_typ.args) + index] = \
self.append(ir.Alloc([], ir.TOption(fn_typ.optargs[optarg_name])))
if self_arg is not None:
assert args[0] is None
args[0] = self_arg
assert None not in args
if self.unwind_target is None:
insn = self.append(ir.Call(func, args, node.arg_exprs))
else:
after_invoke = self.add_block()
insn = self.append(ir.Invoke(func, args, node.arg_exprs,
after_invoke, self.unwind_target))
self.current_block = after_invoke
method_key = None
if isinstance(node.func, asttyped.AttributeT): if isinstance(node.func, asttyped.AttributeT):
attr_node = node.func attr_node = node.func
self.method_map[(attr_node.value.type.find(), attr_node.attr)].append(insn) self.method_map[(attr_node.value.type.find(), attr_node.attr)].append(insn)

View File

@ -0,0 +1,33 @@
# RUN: %python -m artiq.compiler.testbench.jit %s
# RUN: %python %s
class contextmgr:
def __enter__(self):
print(2)
def __exit__(self, n1, n2, n3):
print(4)
# CHECK-L: a 1
# CHECK-L: 2
# CHECK-L: a 3
# CHECK-L: 4
# CHECK-L: a 5
print("a", 1)
with contextmgr():
print("a", 3)
print("a", 5)
# CHECK-L: b 1
# CHECK-L: 2
# CHECK-L: 4
# CHECK-L: b 6
try:
print("b", 1)
with contextmgr():
[0][1]
print("b", 3)
print("b", 5)
except:
pass
print("b", 6)