From 4f311e74485d5c77d01395db63b3ea00eee12faf Mon Sep 17 00:00:00 2001 From: David Nadlinger Date: Wed, 4 Nov 2020 19:45:10 +0100 Subject: [PATCH] compiler: Raise exception on failed assert()s rather than panic This allows assert() to be used on Zynq, where abort() is not currently implemented for kernels. Furthermore, this is arguably the more natural implementation of assertions on all kernel targets (i.e. where embedding into host Python is used), as it matches host Python behavior, and the exception information actually makes it to the user rather than leading to a ConnectionClosed error. Since this does not implement printing of the subexpressions, I left the old print+abort implementation as default for the time being. The lit/integration/instance.py diff isn't just a spurious change; the exception-based assert implementation exposes a limitation in the existing closure lifetime tracking algorithm (which is not supposed to be what is tested there). GitHub: Fixes #1539. --- artiq/compiler/module.py | 11 ++- artiq/compiler/targets.py | 8 ++ .../compiler/transforms/artiq_ir_generator.py | 98 +++++++++++++++++-- artiq/coredevice/core.py | 3 +- artiq/coredevice/exceptions.py | 1 + artiq/test/coredevice/test_embedding.py | 37 +++++++ artiq/test/lit/integration/instance.py | 4 +- 7 files changed, 146 insertions(+), 16 deletions(-) diff --git a/artiq/compiler/module.py b/artiq/compiler/module.py index d43404b20..3cce61105 100644 --- a/artiq/compiler/module.py +++ b/artiq/compiler/module.py @@ -40,7 +40,8 @@ class Source: return cls(source.Buffer(f.read(), filename, 1), engine=engine) class Module: - def __init__(self, src, ref_period=1e-6, attribute_writeback=True, remarks=False): + def __init__(self, src, ref_period=1e-6, attribute_writeback=True, remarks=False, + raise_assertion_errors=False): self.attribute_writeback = attribute_writeback self.engine = src.engine self.embedding_map = src.embedding_map @@ -55,9 +56,11 @@ class Module: iodelay_estimator = transforms.IODelayEstimator(engine=self.engine, ref_period=ref_period) constness_validator = validators.ConstnessValidator(engine=self.engine) - artiq_ir_generator = transforms.ARTIQIRGenerator(engine=self.engine, - module_name=src.name, - ref_period=ref_period) + artiq_ir_generator = transforms.ARTIQIRGenerator( + engine=self.engine, + module_name=src.name, + ref_period=ref_period, + raise_assertion_errors=raise_assertion_errors) dead_code_eliminator = transforms.DeadCodeEliminator(engine=self.engine) local_access_validator = validators.LocalAccessValidator(engine=self.engine) local_demoter = transforms.LocalDemoter() diff --git a/artiq/compiler/targets.py b/artiq/compiler/targets.py index 9ebc7907d..427d9f07c 100644 --- a/artiq/compiler/targets.py +++ b/artiq/compiler/targets.py @@ -80,6 +80,9 @@ class Target: determined from data_layout due to JIT. :var now_pinning: (boolean) Whether the target implements the now-pinning RTIO optimization. + :var raise_assertion_errors: (bool) + Whether to raise an AssertionError on failed assertions or abort/panic + instead. """ triple = "unknown" data_layout = "" @@ -87,6 +90,7 @@ class Target: print_function = "printf" little_endian = False now_pinning = True + raise_assertion_errors = False tool_ld = "ld.lld" tool_strip = "llvm-strip" @@ -277,6 +281,10 @@ class CortexA9Target(Target): little_endian = True now_pinning = False + # Support for marshalling kernel CPU panics as RunAborted errors is not + # implemented in the ARTIQ Zynq runtime. + raise_assertion_errors = True + tool_ld = "armv7-unknown-linux-gnueabihf-ld" tool_strip = "armv7-unknown-linux-gnueabihf-strip" tool_addr2line = "armv7-unknown-linux-gnueabihf-addr2line" diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index f5d66aa24..075bf581d 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -94,11 +94,12 @@ class ARTIQIRGenerator(algorithm.Visitor): _size_type = builtins.TInt32() - def __init__(self, module_name, engine, ref_period): + def __init__(self, module_name, engine, ref_period, raise_assertion_errors): self.engine = engine self.functions = [] self.name = [module_name] if module_name != "" else [] self.ref_period = ir.Constant(ref_period, builtins.TFloat()) + self.raise_assertion_errors = raise_assertion_errors self.current_loc = None self.current_function = None self.current_class = None @@ -119,6 +120,7 @@ class ARTIQIRGenerator(algorithm.Visitor): self.variable_map = dict() self.method_map = defaultdict(lambda: []) self.array_op_funcs = dict() + self.raise_assert_func = None def annotate_calls(self, devirtualization): for var_node in devirtualization.variable_map: @@ -1017,19 +1019,30 @@ class ARTIQIRGenerator(algorithm.Visitor): cond_block = self.current_block self.current_block = body_block = self.add_block("check.body") - closure = self.append(ir.Closure(func, ir.Constant(None, ir.TEnvironment("check", {})))) + self._invoke_raising_func(func, params, "check") + + self.current_block = tail_block = self.add_block("check.tail") + cond_block.append(ir.BranchIf(cond, tail_block, body_block)) + + def _invoke_raising_func(self, func, params, block_name): + """Emit a call/invoke instruction as appropriate to terminte the current + basic block with a call to a helper function that always raises an + exception. + + (This is done for compiler-inserted checks and assertions to keep the + generated code tight for the normal case.) + """ + closure = self.append(ir.Closure(func, + ir.Constant(None, ir.TEnvironment("raise", {})))) if self.unwind_target is None: insn = self.append(ir.Call(closure, params, {})) else: - after_invoke = self.add_block("check.invoke") + after_invoke = self.add_block(block_name + ".invoke") insn = self.append(ir.Invoke(closure, params, {}, after_invoke, self.unwind_target)) self.current_block = after_invoke insn.is_cold = True self.append(ir.Unreachable()) - self.current_block = tail_block = self.add_block("check.tail") - cond_block.append(ir.BranchIf(cond, tail_block, body_block)) - def _map_index(self, length, index, one_past_the_end=False, loc=None): lt_0 = self.append(ir.Compare(ast.Lt(loc=None), index, ir.Constant(0, index.type))) @@ -2478,6 +2491,56 @@ class ARTIQIRGenerator(algorithm.Visitor): loc=node.loc) self.current_assert_subexprs.append((node, name)) + def _get_raise_assert_func(self): + """Emit the helper function that constructs AssertionErrors and raises + them, if it does not already exist in the current module. + + A separate function is used for code size reasons. (This could also be + compiled into a stand-alone support library instead.) + """ + if self.raise_assert_func: + return self.raise_assert_func + try: + msg = ir.Argument(builtins.TStr(), "msg") + file = ir.Argument(builtins.TStr(), "file") + line = ir.Argument(builtins.TInt32(), "line") + col = ir.Argument(builtins.TInt32(), "col") + function = ir.Argument(builtins.TStr(), "function") + + args = [msg, file, line, col, function] + typ = types.TFunction(args=OrderedDict([(arg.name, arg.type) + for arg in args]), + optargs=OrderedDict(), + ret=builtins.TNone()) + env = ir.TEnvironment(name="raise", vars={}) + env_arg = ir.EnvironmentArgument(env, "ARG.ENV") + func = ir.Function(typ, "_artiq_raise_assert", [env_arg] + args) + func.is_internal = True + func.is_cold = True + func.is_generated = True + self.functions.append(func) + old_func, self.current_function = self.current_function, func + + entry = self.add_block("entry") + old_block, self.current_block = self.current_block, entry + old_final_branch, self.final_branch = self.final_branch, None + old_unwind, self.unwind_target = self.unwind_target, None + + exn = self.alloc_exn(builtins.TException("AssertionError"), message=msg) + self.append(ir.SetAttr(exn, "__file__", file)) + self.append(ir.SetAttr(exn, "__line__", line)) + self.append(ir.SetAttr(exn, "__col__", col)) + self.append(ir.SetAttr(exn, "__func__", function)) + self.append(ir.Raise(exn)) + finally: + self.current_function = old_func + self.current_block = old_block + self.final_branch = old_final_branch + self.unwind_target = old_unwind + + self.raise_assert_func = func + return self.raise_assert_func + def visit_Assert(self, node): try: assert_suffix = ".assert@{}:{}".format(node.loc.line(), node.loc.column()) @@ -2502,6 +2565,26 @@ class ARTIQIRGenerator(algorithm.Visitor): if_failed = self.current_block = self.add_block("assert.fail") + if self.raise_assertion_errors: + self._raise_assertion_error(node) + else: + self._abort_after_assertion(node, assert_subexprs, assert_env) + + tail = self.current_block = self.add_block("assert.tail") + self.append(ir.BranchIf(cond, tail, if_failed), block=head) + + def _raise_assertion_error(self, node): + text = str(node.msg.s) if node.msg else "AssertionError" + msg = ir.Constant(text, builtins.TStr()) + loc_file = ir.Constant(node.loc.source_buffer.name, builtins.TStr()) + loc_line = ir.Constant(node.loc.line(), builtins.TInt32()) + loc_column = ir.Constant(node.loc.column(), builtins.TInt32()) + loc_function = ir.Constant(".".join(self.name), builtins.TStr()) + self._invoke_raising_func(self._get_raise_assert_func(), [ + msg, loc_file, loc_line, loc_column, loc_function + ], "assert.fail") + + def _abort_after_assertion(self, node, assert_subexprs, assert_env): if node.msg: explanation = node.msg.s else: @@ -2535,9 +2618,6 @@ class ARTIQIRGenerator(algorithm.Visitor): self.append(ir.Builtin("abort", [], builtins.TNone())) self.append(ir.Unreachable()) - tail = self.current_block = self.add_block("assert.tail") - self.append(ir.BranchIf(cond, tail, if_failed), block=head) - def polymorphic_print(self, values, separator, suffix="", as_repr=False, as_rtio=False): def printf(format_string, *args): format = ir.Constant(format_string, builtins.TStr()) diff --git a/artiq/coredevice/core.py b/artiq/coredevice/core.py index d150df596..87d6a854f 100644 --- a/artiq/coredevice/core.py +++ b/artiq/coredevice/core.py @@ -106,7 +106,8 @@ class Core: module = Module(stitcher, ref_period=self.ref_period, - attribute_writeback=attribute_writeback) + attribute_writeback=attribute_writeback, + raise_assertion_errors=self.target_cls.raise_assertion_errors) target = self.target_cls() library = target.compile_and_link([module]) diff --git a/artiq/coredevice/exceptions.py b/artiq/coredevice/exceptions.py index 0d84d49c0..8260739d2 100644 --- a/artiq/coredevice/exceptions.py +++ b/artiq/coredevice/exceptions.py @@ -11,6 +11,7 @@ ZeroDivisionError = builtins.ZeroDivisionError ValueError = builtins.ValueError IndexError = builtins.IndexError RuntimeError = builtins.RuntimeError +AssertionError = builtins.AssertionError class CoreException: diff --git a/artiq/test/coredevice/test_embedding.py b/artiq/test/coredevice/test_embedding.py index 57c2a8fab..0d6da1cd4 100644 --- a/artiq/test/coredevice/test_embedding.py +++ b/artiq/test/coredevice/test_embedding.py @@ -460,3 +460,40 @@ class _ArrayQuoting(EnvExperiment): class ArrayQuotingTest(ExperimentCase): def test_quoting(self): self.create(_ArrayQuoting).run() + + +class _Assert(EnvExperiment): + def build(self): + self.setattr_device("core") + + def raises_assertion_errors(self): + return self.core.target_cls.raises_assertion_errors + + @kernel + def check(self, value): + assert value + + @kernel + def check_msg(self, value): + assert value, "foo" + + +class AssertTest(ExperimentCase): + def test_assert(self): + exp = self.create(_Assert) + + def check_fail(fn, msg): + if exp.raises_assertion_errors: + with self.assertRaises(AssertionError) as ctx: + fn() + self.assertEqual(str(ctx.exception), msg) + else: + # Without assertion exceptions, core device panics should still lead + # to a cleanly dropped connectionr rather than a hang/… + with self.assertRaises(ConnectionResetError): + fn() + + exp.check(True) + check_fail(lambda: exp.check(False), "AssertionError") + exp.check_msg(True) + check_fail(lambda: exp.check_msg(False), "foo") diff --git a/artiq/test/lit/integration/instance.py b/artiq/test/lit/integration/instance.py index bf255d88f..5acea8721 100644 --- a/artiq/test/lit/integration/instance.py +++ b/artiq/test/lit/integration/instance.py @@ -6,9 +6,9 @@ class c: i = c() -assert i.a == 1 - def f(): c = None assert i.a == 1 + +assert i.a == 1 f()