diff --git a/artiq/compiler/module.py b/artiq/compiler/module.py index d43404b20..3cce61105 100644 --- a/artiq/compiler/module.py +++ b/artiq/compiler/module.py @@ -40,7 +40,8 @@ class Source: return cls(source.Buffer(f.read(), filename, 1), engine=engine) class Module: - def __init__(self, src, ref_period=1e-6, attribute_writeback=True, remarks=False): + def __init__(self, src, ref_period=1e-6, attribute_writeback=True, remarks=False, + raise_assertion_errors=False): self.attribute_writeback = attribute_writeback self.engine = src.engine self.embedding_map = src.embedding_map @@ -55,9 +56,11 @@ class Module: iodelay_estimator = transforms.IODelayEstimator(engine=self.engine, ref_period=ref_period) constness_validator = validators.ConstnessValidator(engine=self.engine) - artiq_ir_generator = transforms.ARTIQIRGenerator(engine=self.engine, - module_name=src.name, - ref_period=ref_period) + artiq_ir_generator = transforms.ARTIQIRGenerator( + engine=self.engine, + module_name=src.name, + ref_period=ref_period, + raise_assertion_errors=raise_assertion_errors) dead_code_eliminator = transforms.DeadCodeEliminator(engine=self.engine) local_access_validator = validators.LocalAccessValidator(engine=self.engine) local_demoter = transforms.LocalDemoter() diff --git a/artiq/compiler/targets.py b/artiq/compiler/targets.py index 9ebc7907d..427d9f07c 100644 --- a/artiq/compiler/targets.py +++ b/artiq/compiler/targets.py @@ -80,6 +80,9 @@ class Target: determined from data_layout due to JIT. :var now_pinning: (boolean) Whether the target implements the now-pinning RTIO optimization. + :var raise_assertion_errors: (bool) + Whether to raise an AssertionError on failed assertions or abort/panic + instead. """ triple = "unknown" data_layout = "" @@ -87,6 +90,7 @@ class Target: print_function = "printf" little_endian = False now_pinning = True + raise_assertion_errors = False tool_ld = "ld.lld" tool_strip = "llvm-strip" @@ -277,6 +281,10 @@ class CortexA9Target(Target): little_endian = True now_pinning = False + # Support for marshalling kernel CPU panics as RunAborted errors is not + # implemented in the ARTIQ Zynq runtime. + raise_assertion_errors = True + tool_ld = "armv7-unknown-linux-gnueabihf-ld" tool_strip = "armv7-unknown-linux-gnueabihf-strip" tool_addr2line = "armv7-unknown-linux-gnueabihf-addr2line" diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index f5d66aa24..075bf581d 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -94,11 +94,12 @@ class ARTIQIRGenerator(algorithm.Visitor): _size_type = builtins.TInt32() - def __init__(self, module_name, engine, ref_period): + def __init__(self, module_name, engine, ref_period, raise_assertion_errors): self.engine = engine self.functions = [] self.name = [module_name] if module_name != "" else [] self.ref_period = ir.Constant(ref_period, builtins.TFloat()) + self.raise_assertion_errors = raise_assertion_errors self.current_loc = None self.current_function = None self.current_class = None @@ -119,6 +120,7 @@ class ARTIQIRGenerator(algorithm.Visitor): self.variable_map = dict() self.method_map = defaultdict(lambda: []) self.array_op_funcs = dict() + self.raise_assert_func = None def annotate_calls(self, devirtualization): for var_node in devirtualization.variable_map: @@ -1017,19 +1019,30 @@ class ARTIQIRGenerator(algorithm.Visitor): cond_block = self.current_block self.current_block = body_block = self.add_block("check.body") - closure = self.append(ir.Closure(func, ir.Constant(None, ir.TEnvironment("check", {})))) + self._invoke_raising_func(func, params, "check") + + self.current_block = tail_block = self.add_block("check.tail") + cond_block.append(ir.BranchIf(cond, tail_block, body_block)) + + def _invoke_raising_func(self, func, params, block_name): + """Emit a call/invoke instruction as appropriate to terminte the current + basic block with a call to a helper function that always raises an + exception. + + (This is done for compiler-inserted checks and assertions to keep the + generated code tight for the normal case.) + """ + closure = self.append(ir.Closure(func, + ir.Constant(None, ir.TEnvironment("raise", {})))) if self.unwind_target is None: insn = self.append(ir.Call(closure, params, {})) else: - after_invoke = self.add_block("check.invoke") + after_invoke = self.add_block(block_name + ".invoke") insn = self.append(ir.Invoke(closure, params, {}, after_invoke, self.unwind_target)) self.current_block = after_invoke insn.is_cold = True self.append(ir.Unreachable()) - self.current_block = tail_block = self.add_block("check.tail") - cond_block.append(ir.BranchIf(cond, tail_block, body_block)) - def _map_index(self, length, index, one_past_the_end=False, loc=None): lt_0 = self.append(ir.Compare(ast.Lt(loc=None), index, ir.Constant(0, index.type))) @@ -2478,6 +2491,56 @@ class ARTIQIRGenerator(algorithm.Visitor): loc=node.loc) self.current_assert_subexprs.append((node, name)) + def _get_raise_assert_func(self): + """Emit the helper function that constructs AssertionErrors and raises + them, if it does not already exist in the current module. + + A separate function is used for code size reasons. (This could also be + compiled into a stand-alone support library instead.) + """ + if self.raise_assert_func: + return self.raise_assert_func + try: + msg = ir.Argument(builtins.TStr(), "msg") + file = ir.Argument(builtins.TStr(), "file") + line = ir.Argument(builtins.TInt32(), "line") + col = ir.Argument(builtins.TInt32(), "col") + function = ir.Argument(builtins.TStr(), "function") + + args = [msg, file, line, col, function] + typ = types.TFunction(args=OrderedDict([(arg.name, arg.type) + for arg in args]), + optargs=OrderedDict(), + ret=builtins.TNone()) + env = ir.TEnvironment(name="raise", vars={}) + env_arg = ir.EnvironmentArgument(env, "ARG.ENV") + func = ir.Function(typ, "_artiq_raise_assert", [env_arg] + args) + func.is_internal = True + func.is_cold = True + func.is_generated = True + self.functions.append(func) + old_func, self.current_function = self.current_function, func + + entry = self.add_block("entry") + old_block, self.current_block = self.current_block, entry + old_final_branch, self.final_branch = self.final_branch, None + old_unwind, self.unwind_target = self.unwind_target, None + + exn = self.alloc_exn(builtins.TException("AssertionError"), message=msg) + self.append(ir.SetAttr(exn, "__file__", file)) + self.append(ir.SetAttr(exn, "__line__", line)) + self.append(ir.SetAttr(exn, "__col__", col)) + self.append(ir.SetAttr(exn, "__func__", function)) + self.append(ir.Raise(exn)) + finally: + self.current_function = old_func + self.current_block = old_block + self.final_branch = old_final_branch + self.unwind_target = old_unwind + + self.raise_assert_func = func + return self.raise_assert_func + def visit_Assert(self, node): try: assert_suffix = ".assert@{}:{}".format(node.loc.line(), node.loc.column()) @@ -2502,6 +2565,26 @@ class ARTIQIRGenerator(algorithm.Visitor): if_failed = self.current_block = self.add_block("assert.fail") + if self.raise_assertion_errors: + self._raise_assertion_error(node) + else: + self._abort_after_assertion(node, assert_subexprs, assert_env) + + tail = self.current_block = self.add_block("assert.tail") + self.append(ir.BranchIf(cond, tail, if_failed), block=head) + + def _raise_assertion_error(self, node): + text = str(node.msg.s) if node.msg else "AssertionError" + msg = ir.Constant(text, builtins.TStr()) + loc_file = ir.Constant(node.loc.source_buffer.name, builtins.TStr()) + loc_line = ir.Constant(node.loc.line(), builtins.TInt32()) + loc_column = ir.Constant(node.loc.column(), builtins.TInt32()) + loc_function = ir.Constant(".".join(self.name), builtins.TStr()) + self._invoke_raising_func(self._get_raise_assert_func(), [ + msg, loc_file, loc_line, loc_column, loc_function + ], "assert.fail") + + def _abort_after_assertion(self, node, assert_subexprs, assert_env): if node.msg: explanation = node.msg.s else: @@ -2535,9 +2618,6 @@ class ARTIQIRGenerator(algorithm.Visitor): self.append(ir.Builtin("abort", [], builtins.TNone())) self.append(ir.Unreachable()) - tail = self.current_block = self.add_block("assert.tail") - self.append(ir.BranchIf(cond, tail, if_failed), block=head) - def polymorphic_print(self, values, separator, suffix="", as_repr=False, as_rtio=False): def printf(format_string, *args): format = ir.Constant(format_string, builtins.TStr()) diff --git a/artiq/coredevice/core.py b/artiq/coredevice/core.py index d150df596..87d6a854f 100644 --- a/artiq/coredevice/core.py +++ b/artiq/coredevice/core.py @@ -106,7 +106,8 @@ class Core: module = Module(stitcher, ref_period=self.ref_period, - attribute_writeback=attribute_writeback) + attribute_writeback=attribute_writeback, + raise_assertion_errors=self.target_cls.raise_assertion_errors) target = self.target_cls() library = target.compile_and_link([module]) diff --git a/artiq/coredevice/exceptions.py b/artiq/coredevice/exceptions.py index 0d84d49c0..8260739d2 100644 --- a/artiq/coredevice/exceptions.py +++ b/artiq/coredevice/exceptions.py @@ -11,6 +11,7 @@ ZeroDivisionError = builtins.ZeroDivisionError ValueError = builtins.ValueError IndexError = builtins.IndexError RuntimeError = builtins.RuntimeError +AssertionError = builtins.AssertionError class CoreException: diff --git a/artiq/test/coredevice/test_embedding.py b/artiq/test/coredevice/test_embedding.py index 57c2a8fab..0d6da1cd4 100644 --- a/artiq/test/coredevice/test_embedding.py +++ b/artiq/test/coredevice/test_embedding.py @@ -460,3 +460,40 @@ class _ArrayQuoting(EnvExperiment): class ArrayQuotingTest(ExperimentCase): def test_quoting(self): self.create(_ArrayQuoting).run() + + +class _Assert(EnvExperiment): + def build(self): + self.setattr_device("core") + + def raises_assertion_errors(self): + return self.core.target_cls.raises_assertion_errors + + @kernel + def check(self, value): + assert value + + @kernel + def check_msg(self, value): + assert value, "foo" + + +class AssertTest(ExperimentCase): + def test_assert(self): + exp = self.create(_Assert) + + def check_fail(fn, msg): + if exp.raises_assertion_errors: + with self.assertRaises(AssertionError) as ctx: + fn() + self.assertEqual(str(ctx.exception), msg) + else: + # Without assertion exceptions, core device panics should still lead + # to a cleanly dropped connectionr rather than a hang/… + with self.assertRaises(ConnectionResetError): + fn() + + exp.check(True) + check_fail(lambda: exp.check(False), "AssertionError") + exp.check_msg(True) + check_fail(lambda: exp.check_msg(False), "foo") diff --git a/artiq/test/lit/integration/instance.py b/artiq/test/lit/integration/instance.py index bf255d88f..5acea8721 100644 --- a/artiq/test/lit/integration/instance.py +++ b/artiq/test/lit/integration/instance.py @@ -6,9 +6,9 @@ class c: i = c() -assert i.a == 1 - def f(): c = None assert i.a == 1 + +assert i.a == 1 f()