1
0
forked from M-Labs/artiq

compiler: Raise exception on failed assert()s rather than panic

This allows assert() to be used on Zynq, where abort() is not
currently implemented for kernels. Furthermore, this is arguably
the more natural implementation of assertions on all kernel targets
(i.e. where embedding into host Python is used), as it matches host
Python behavior, and the exception information actually makes it to
the user rather than leading to a ConnectionClosed error.

Since this does not implement printing of the subexpressions, I
left the old print+abort implementation as default for the time
being.

The lit/integration/instance.py diff isn't just a spurious change;
the exception-based assert implementation exposes a limitation in
the existing closure lifetime tracking algorithm (which is not
supposed to be what is tested there).

GitHub: Fixes #1539.
This commit is contained in:
David Nadlinger 2020-11-04 19:45:10 +01:00
parent f0ec987d23
commit 4f311e7448
7 changed files with 146 additions and 16 deletions

View File

@ -40,7 +40,8 @@ class Source:
return cls(source.Buffer(f.read(), filename, 1), engine=engine)
class Module:
def __init__(self, src, ref_period=1e-6, attribute_writeback=True, remarks=False):
def __init__(self, src, ref_period=1e-6, attribute_writeback=True, remarks=False,
raise_assertion_errors=False):
self.attribute_writeback = attribute_writeback
self.engine = src.engine
self.embedding_map = src.embedding_map
@ -55,9 +56,11 @@ class Module:
iodelay_estimator = transforms.IODelayEstimator(engine=self.engine,
ref_period=ref_period)
constness_validator = validators.ConstnessValidator(engine=self.engine)
artiq_ir_generator = transforms.ARTIQIRGenerator(engine=self.engine,
module_name=src.name,
ref_period=ref_period)
artiq_ir_generator = transforms.ARTIQIRGenerator(
engine=self.engine,
module_name=src.name,
ref_period=ref_period,
raise_assertion_errors=raise_assertion_errors)
dead_code_eliminator = transforms.DeadCodeEliminator(engine=self.engine)
local_access_validator = validators.LocalAccessValidator(engine=self.engine)
local_demoter = transforms.LocalDemoter()

View File

@ -80,6 +80,9 @@ class Target:
determined from data_layout due to JIT.
:var now_pinning: (boolean)
Whether the target implements the now-pinning RTIO optimization.
:var raise_assertion_errors: (bool)
Whether to raise an AssertionError on failed assertions or abort/panic
instead.
"""
triple = "unknown"
data_layout = ""
@ -87,6 +90,7 @@ class Target:
print_function = "printf"
little_endian = False
now_pinning = True
raise_assertion_errors = False
tool_ld = "ld.lld"
tool_strip = "llvm-strip"
@ -277,6 +281,10 @@ class CortexA9Target(Target):
little_endian = True
now_pinning = False
# Support for marshalling kernel CPU panics as RunAborted errors is not
# implemented in the ARTIQ Zynq runtime.
raise_assertion_errors = True
tool_ld = "armv7-unknown-linux-gnueabihf-ld"
tool_strip = "armv7-unknown-linux-gnueabihf-strip"
tool_addr2line = "armv7-unknown-linux-gnueabihf-addr2line"

View File

@ -94,11 +94,12 @@ class ARTIQIRGenerator(algorithm.Visitor):
_size_type = builtins.TInt32()
def __init__(self, module_name, engine, ref_period):
def __init__(self, module_name, engine, ref_period, raise_assertion_errors):
self.engine = engine
self.functions = []
self.name = [module_name] if module_name != "" else []
self.ref_period = ir.Constant(ref_period, builtins.TFloat())
self.raise_assertion_errors = raise_assertion_errors
self.current_loc = None
self.current_function = None
self.current_class = None
@ -119,6 +120,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
self.variable_map = dict()
self.method_map = defaultdict(lambda: [])
self.array_op_funcs = dict()
self.raise_assert_func = None
def annotate_calls(self, devirtualization):
for var_node in devirtualization.variable_map:
@ -1017,19 +1019,30 @@ class ARTIQIRGenerator(algorithm.Visitor):
cond_block = self.current_block
self.current_block = body_block = self.add_block("check.body")
closure = self.append(ir.Closure(func, ir.Constant(None, ir.TEnvironment("check", {}))))
self._invoke_raising_func(func, params, "check")
self.current_block = tail_block = self.add_block("check.tail")
cond_block.append(ir.BranchIf(cond, tail_block, body_block))
def _invoke_raising_func(self, func, params, block_name):
"""Emit a call/invoke instruction as appropriate to terminte the current
basic block with a call to a helper function that always raises an
exception.
(This is done for compiler-inserted checks and assertions to keep the
generated code tight for the normal case.)
"""
closure = self.append(ir.Closure(func,
ir.Constant(None, ir.TEnvironment("raise", {}))))
if self.unwind_target is None:
insn = self.append(ir.Call(closure, params, {}))
else:
after_invoke = self.add_block("check.invoke")
after_invoke = self.add_block(block_name + ".invoke")
insn = self.append(ir.Invoke(closure, params, {}, after_invoke, self.unwind_target))
self.current_block = after_invoke
insn.is_cold = True
self.append(ir.Unreachable())
self.current_block = tail_block = self.add_block("check.tail")
cond_block.append(ir.BranchIf(cond, tail_block, body_block))
def _map_index(self, length, index, one_past_the_end=False, loc=None):
lt_0 = self.append(ir.Compare(ast.Lt(loc=None),
index, ir.Constant(0, index.type)))
@ -2478,6 +2491,56 @@ class ARTIQIRGenerator(algorithm.Visitor):
loc=node.loc)
self.current_assert_subexprs.append((node, name))
def _get_raise_assert_func(self):
"""Emit the helper function that constructs AssertionErrors and raises
them, if it does not already exist in the current module.
A separate function is used for code size reasons. (This could also be
compiled into a stand-alone support library instead.)
"""
if self.raise_assert_func:
return self.raise_assert_func
try:
msg = ir.Argument(builtins.TStr(), "msg")
file = ir.Argument(builtins.TStr(), "file")
line = ir.Argument(builtins.TInt32(), "line")
col = ir.Argument(builtins.TInt32(), "col")
function = ir.Argument(builtins.TStr(), "function")
args = [msg, file, line, col, function]
typ = types.TFunction(args=OrderedDict([(arg.name, arg.type)
for arg in args]),
optargs=OrderedDict(),
ret=builtins.TNone())
env = ir.TEnvironment(name="raise", vars={})
env_arg = ir.EnvironmentArgument(env, "ARG.ENV")
func = ir.Function(typ, "_artiq_raise_assert", [env_arg] + args)
func.is_internal = True
func.is_cold = True
func.is_generated = True
self.functions.append(func)
old_func, self.current_function = self.current_function, func
entry = self.add_block("entry")
old_block, self.current_block = self.current_block, entry
old_final_branch, self.final_branch = self.final_branch, None
old_unwind, self.unwind_target = self.unwind_target, None
exn = self.alloc_exn(builtins.TException("AssertionError"), message=msg)
self.append(ir.SetAttr(exn, "__file__", file))
self.append(ir.SetAttr(exn, "__line__", line))
self.append(ir.SetAttr(exn, "__col__", col))
self.append(ir.SetAttr(exn, "__func__", function))
self.append(ir.Raise(exn))
finally:
self.current_function = old_func
self.current_block = old_block
self.final_branch = old_final_branch
self.unwind_target = old_unwind
self.raise_assert_func = func
return self.raise_assert_func
def visit_Assert(self, node):
try:
assert_suffix = ".assert@{}:{}".format(node.loc.line(), node.loc.column())
@ -2502,6 +2565,26 @@ class ARTIQIRGenerator(algorithm.Visitor):
if_failed = self.current_block = self.add_block("assert.fail")
if self.raise_assertion_errors:
self._raise_assertion_error(node)
else:
self._abort_after_assertion(node, assert_subexprs, assert_env)
tail = self.current_block = self.add_block("assert.tail")
self.append(ir.BranchIf(cond, tail, if_failed), block=head)
def _raise_assertion_error(self, node):
text = str(node.msg.s) if node.msg else "AssertionError"
msg = ir.Constant(text, builtins.TStr())
loc_file = ir.Constant(node.loc.source_buffer.name, builtins.TStr())
loc_line = ir.Constant(node.loc.line(), builtins.TInt32())
loc_column = ir.Constant(node.loc.column(), builtins.TInt32())
loc_function = ir.Constant(".".join(self.name), builtins.TStr())
self._invoke_raising_func(self._get_raise_assert_func(), [
msg, loc_file, loc_line, loc_column, loc_function
], "assert.fail")
def _abort_after_assertion(self, node, assert_subexprs, assert_env):
if node.msg:
explanation = node.msg.s
else:
@ -2535,9 +2618,6 @@ class ARTIQIRGenerator(algorithm.Visitor):
self.append(ir.Builtin("abort", [], builtins.TNone()))
self.append(ir.Unreachable())
tail = self.current_block = self.add_block("assert.tail")
self.append(ir.BranchIf(cond, tail, if_failed), block=head)
def polymorphic_print(self, values, separator, suffix="", as_repr=False, as_rtio=False):
def printf(format_string, *args):
format = ir.Constant(format_string, builtins.TStr())

View File

@ -106,7 +106,8 @@ class Core:
module = Module(stitcher,
ref_period=self.ref_period,
attribute_writeback=attribute_writeback)
attribute_writeback=attribute_writeback,
raise_assertion_errors=self.target_cls.raise_assertion_errors)
target = self.target_cls()
library = target.compile_and_link([module])

View File

@ -11,6 +11,7 @@ ZeroDivisionError = builtins.ZeroDivisionError
ValueError = builtins.ValueError
IndexError = builtins.IndexError
RuntimeError = builtins.RuntimeError
AssertionError = builtins.AssertionError
class CoreException:

View File

@ -460,3 +460,40 @@ class _ArrayQuoting(EnvExperiment):
class ArrayQuotingTest(ExperimentCase):
def test_quoting(self):
self.create(_ArrayQuoting).run()
class _Assert(EnvExperiment):
def build(self):
self.setattr_device("core")
def raises_assertion_errors(self):
return self.core.target_cls.raises_assertion_errors
@kernel
def check(self, value):
assert value
@kernel
def check_msg(self, value):
assert value, "foo"
class AssertTest(ExperimentCase):
def test_assert(self):
exp = self.create(_Assert)
def check_fail(fn, msg):
if exp.raises_assertion_errors:
with self.assertRaises(AssertionError) as ctx:
fn()
self.assertEqual(str(ctx.exception), msg)
else:
# Without assertion exceptions, core device panics should still lead
# to a cleanly dropped connectionr rather than a hang/…
with self.assertRaises(ConnectionResetError):
fn()
exp.check(True)
check_fail(lambda: exp.check(False), "AssertionError")
exp.check_msg(True)
check_fail(lambda: exp.check_msg(False), "foo")

View File

@ -6,9 +6,9 @@ class c:
i = c()
assert i.a == 1
def f():
c = None
assert i.a == 1
assert i.a == 1
f()