diff --git a/artiq/devices/core.py b/artiq/devices/core.py index 56a3e1e51..8630ef4d0 100644 --- a/artiq/devices/core.py +++ b/artiq/devices/core.py @@ -44,7 +44,8 @@ class Core: # transform/simplify AST _debug_unparse = _make_debug_unparse("fold_constants_2") - func_def, rpc_map = inline(self, k_function, k_args, k_kwargs) + func_def, rpc_map, exception_map = inline( + self, k_function, k_args, k_kwargs) _debug_unparse("inline", func_def) lower_units(func_def, self.runtime_env.ref_period) diff --git a/artiq/transforms/inline.py b/artiq/transforms/inline.py index 630341c43..0a9fc6b56 100644 --- a/artiq/transforms/inline.py +++ b/artiq/transforms/inline.py @@ -27,13 +27,14 @@ class _ReferenceManager: # inlined_name -> use_count self.use_count = dict() self.rpc_map = defaultdict(lambda: len(self.rpc_map)) + self.exception_map = defaultdict(lambda: len(self.exception_map)) self.kernel_attr_init = [] # reserved names for kg in core_language.kernel_globals: self.use_count[kg] = 1 for name in ("int", "round", "int64", "round64", "float", "array", - "range", "Fraction", "Quantity", + "range", "Fraction", "Quantity", "EncodedException", "s_unit", "Hz_unit", "microcycle_unit"): self.use_count[name] = 1 @@ -93,7 +94,7 @@ _embeddable_calls = { core_language.syscall, range, int, float, round, core_language.int64, core_language.round64, core_language.array, - Fraction, units.Quantity + Fraction, units.Quantity, core_language.EncodedException } @@ -173,8 +174,8 @@ class _ReferenceReplacer(ast.NodeVisitor): retval_name = self.rm.new_name( func.k_function_info.k_function.__name__ + "_return") args = [func.__self__] + new_args - inlined, _ = inline(self.core, func.k_function_info.k_function, - args, dict(), self.rm, retval_name) + inlined, _, _ = inline(self.core, func.k_function_info.k_function, + args, dict(), self.rm, retval_name) self._insertion_point.append(ast.With( items=[ast.withitem(context_expr=ast.Name(id="sequential", ctx=ast.Load()), @@ -214,6 +215,35 @@ class _ReferenceReplacer(ast.NodeVisitor): self.generic_visit(node) return node + def visit_Raise(self, node): + if node.cause is not None: + raise NotImplementedError("Exception causes are not supported") + exception_class = self.rm.get(self.obj, self.func_name, node.exc) + if not inspect.isclass(exception_class): + raise NotImplementedError("Exception must be a class") + exception_id = self.rm.exception_map[exception_class] + node.exc = ast.copy_location( + ast.Call(func=ast.Name("EncodedException", ast.Load()), + args=[value_to_ast(exception_id)], + keywords=[], starargs=None, kwargs=None), + node.exc) + return node + + def visit_ExceptHandler(self, node): + if node.name is not None: + raise NotImplementedError("'as target' is not supported") + exception_class = self.rm.get(self.obj, self.func_name, node.type) + if not inspect.isclass(exception_class): + raise NotImplementedError("Exception type must be a class") + exception_id = self.rm.exception_map[exception_class] + node.type = ast.copy_location( + ast.Call(func=ast.Name("EncodedException", ast.Load()), + args=[value_to_ast(exception_id)], + keywords=[], starargs=None, kwargs=None), + node.type) + self.generic_visit(node) + return node + class _ListReadOnlyParams(ast.NodeVisitor): def visit_FunctionDef(self, node): @@ -272,4 +302,7 @@ def inline(core, k_function, k_args, k_kwargs, rm=None, retval_name=None): r_rpc_map = dict((rpc_num, rpc_fun) for rpc_fun, rpc_num in rm.rpc_map.items()) - return func_def, r_rpc_map + r_exception_map = dict((exception_num, exception_class) + for exception_class, exception_num + in rm.exception_map.items()) + return func_def, r_rpc_map, r_exception_map