forked from M-Labs/artiq
compiler: Raise exception on failed assert()s rather than panic
This allows assert() to be used on Zynq, where abort() is not currently implemented for kernels. Furthermore, this is arguably the more natural implementation of assertions on all kernel targets (i.e. where embedding into host Python is used), as it matches host Python behavior, and the exception information actually makes it to the user rather than leading to a ConnectionClosed error. Since this does not implement printing of the subexpressions, I left the old print+abort implementation as default for the time being. The lit/integration/instance.py diff isn't just a spurious change; the exception-based assert implementation exposes a limitation in the existing closure lifetime tracking algorithm (which is not supposed to be what is tested there). GitHub: Fixes #1539.
This commit is contained in:
parent
f0ec987d23
commit
4f311e7448
|
@ -40,7 +40,8 @@ class Source:
|
||||||
return cls(source.Buffer(f.read(), filename, 1), engine=engine)
|
return cls(source.Buffer(f.read(), filename, 1), engine=engine)
|
||||||
|
|
||||||
class Module:
|
class Module:
|
||||||
def __init__(self, src, ref_period=1e-6, attribute_writeback=True, remarks=False):
|
def __init__(self, src, ref_period=1e-6, attribute_writeback=True, remarks=False,
|
||||||
|
raise_assertion_errors=False):
|
||||||
self.attribute_writeback = attribute_writeback
|
self.attribute_writeback = attribute_writeback
|
||||||
self.engine = src.engine
|
self.engine = src.engine
|
||||||
self.embedding_map = src.embedding_map
|
self.embedding_map = src.embedding_map
|
||||||
|
@ -55,9 +56,11 @@ class Module:
|
||||||
iodelay_estimator = transforms.IODelayEstimator(engine=self.engine,
|
iodelay_estimator = transforms.IODelayEstimator(engine=self.engine,
|
||||||
ref_period=ref_period)
|
ref_period=ref_period)
|
||||||
constness_validator = validators.ConstnessValidator(engine=self.engine)
|
constness_validator = validators.ConstnessValidator(engine=self.engine)
|
||||||
artiq_ir_generator = transforms.ARTIQIRGenerator(engine=self.engine,
|
artiq_ir_generator = transforms.ARTIQIRGenerator(
|
||||||
module_name=src.name,
|
engine=self.engine,
|
||||||
ref_period=ref_period)
|
module_name=src.name,
|
||||||
|
ref_period=ref_period,
|
||||||
|
raise_assertion_errors=raise_assertion_errors)
|
||||||
dead_code_eliminator = transforms.DeadCodeEliminator(engine=self.engine)
|
dead_code_eliminator = transforms.DeadCodeEliminator(engine=self.engine)
|
||||||
local_access_validator = validators.LocalAccessValidator(engine=self.engine)
|
local_access_validator = validators.LocalAccessValidator(engine=self.engine)
|
||||||
local_demoter = transforms.LocalDemoter()
|
local_demoter = transforms.LocalDemoter()
|
||||||
|
|
|
@ -80,6 +80,9 @@ class Target:
|
||||||
determined from data_layout due to JIT.
|
determined from data_layout due to JIT.
|
||||||
:var now_pinning: (boolean)
|
:var now_pinning: (boolean)
|
||||||
Whether the target implements the now-pinning RTIO optimization.
|
Whether the target implements the now-pinning RTIO optimization.
|
||||||
|
:var raise_assertion_errors: (bool)
|
||||||
|
Whether to raise an AssertionError on failed assertions or abort/panic
|
||||||
|
instead.
|
||||||
"""
|
"""
|
||||||
triple = "unknown"
|
triple = "unknown"
|
||||||
data_layout = ""
|
data_layout = ""
|
||||||
|
@ -87,6 +90,7 @@ class Target:
|
||||||
print_function = "printf"
|
print_function = "printf"
|
||||||
little_endian = False
|
little_endian = False
|
||||||
now_pinning = True
|
now_pinning = True
|
||||||
|
raise_assertion_errors = False
|
||||||
|
|
||||||
tool_ld = "ld.lld"
|
tool_ld = "ld.lld"
|
||||||
tool_strip = "llvm-strip"
|
tool_strip = "llvm-strip"
|
||||||
|
@ -277,6 +281,10 @@ class CortexA9Target(Target):
|
||||||
little_endian = True
|
little_endian = True
|
||||||
now_pinning = False
|
now_pinning = False
|
||||||
|
|
||||||
|
# Support for marshalling kernel CPU panics as RunAborted errors is not
|
||||||
|
# implemented in the ARTIQ Zynq runtime.
|
||||||
|
raise_assertion_errors = True
|
||||||
|
|
||||||
tool_ld = "armv7-unknown-linux-gnueabihf-ld"
|
tool_ld = "armv7-unknown-linux-gnueabihf-ld"
|
||||||
tool_strip = "armv7-unknown-linux-gnueabihf-strip"
|
tool_strip = "armv7-unknown-linux-gnueabihf-strip"
|
||||||
tool_addr2line = "armv7-unknown-linux-gnueabihf-addr2line"
|
tool_addr2line = "armv7-unknown-linux-gnueabihf-addr2line"
|
||||||
|
|
|
@ -94,11 +94,12 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
|
|
||||||
_size_type = builtins.TInt32()
|
_size_type = builtins.TInt32()
|
||||||
|
|
||||||
def __init__(self, module_name, engine, ref_period):
|
def __init__(self, module_name, engine, ref_period, raise_assertion_errors):
|
||||||
self.engine = engine
|
self.engine = engine
|
||||||
self.functions = []
|
self.functions = []
|
||||||
self.name = [module_name] if module_name != "" else []
|
self.name = [module_name] if module_name != "" else []
|
||||||
self.ref_period = ir.Constant(ref_period, builtins.TFloat())
|
self.ref_period = ir.Constant(ref_period, builtins.TFloat())
|
||||||
|
self.raise_assertion_errors = raise_assertion_errors
|
||||||
self.current_loc = None
|
self.current_loc = None
|
||||||
self.current_function = None
|
self.current_function = None
|
||||||
self.current_class = None
|
self.current_class = None
|
||||||
|
@ -119,6 +120,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
self.variable_map = dict()
|
self.variable_map = dict()
|
||||||
self.method_map = defaultdict(lambda: [])
|
self.method_map = defaultdict(lambda: [])
|
||||||
self.array_op_funcs = dict()
|
self.array_op_funcs = dict()
|
||||||
|
self.raise_assert_func = None
|
||||||
|
|
||||||
def annotate_calls(self, devirtualization):
|
def annotate_calls(self, devirtualization):
|
||||||
for var_node in devirtualization.variable_map:
|
for var_node in devirtualization.variable_map:
|
||||||
|
@ -1017,19 +1019,30 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
cond_block = self.current_block
|
cond_block = self.current_block
|
||||||
|
|
||||||
self.current_block = body_block = self.add_block("check.body")
|
self.current_block = body_block = self.add_block("check.body")
|
||||||
closure = self.append(ir.Closure(func, ir.Constant(None, ir.TEnvironment("check", {}))))
|
self._invoke_raising_func(func, params, "check")
|
||||||
|
|
||||||
|
self.current_block = tail_block = self.add_block("check.tail")
|
||||||
|
cond_block.append(ir.BranchIf(cond, tail_block, body_block))
|
||||||
|
|
||||||
|
def _invoke_raising_func(self, func, params, block_name):
|
||||||
|
"""Emit a call/invoke instruction as appropriate to terminte the current
|
||||||
|
basic block with a call to a helper function that always raises an
|
||||||
|
exception.
|
||||||
|
|
||||||
|
(This is done for compiler-inserted checks and assertions to keep the
|
||||||
|
generated code tight for the normal case.)
|
||||||
|
"""
|
||||||
|
closure = self.append(ir.Closure(func,
|
||||||
|
ir.Constant(None, ir.TEnvironment("raise", {}))))
|
||||||
if self.unwind_target is None:
|
if self.unwind_target is None:
|
||||||
insn = self.append(ir.Call(closure, params, {}))
|
insn = self.append(ir.Call(closure, params, {}))
|
||||||
else:
|
else:
|
||||||
after_invoke = self.add_block("check.invoke")
|
after_invoke = self.add_block(block_name + ".invoke")
|
||||||
insn = self.append(ir.Invoke(closure, params, {}, after_invoke, self.unwind_target))
|
insn = self.append(ir.Invoke(closure, params, {}, after_invoke, self.unwind_target))
|
||||||
self.current_block = after_invoke
|
self.current_block = after_invoke
|
||||||
insn.is_cold = True
|
insn.is_cold = True
|
||||||
self.append(ir.Unreachable())
|
self.append(ir.Unreachable())
|
||||||
|
|
||||||
self.current_block = tail_block = self.add_block("check.tail")
|
|
||||||
cond_block.append(ir.BranchIf(cond, tail_block, body_block))
|
|
||||||
|
|
||||||
def _map_index(self, length, index, one_past_the_end=False, loc=None):
|
def _map_index(self, length, index, one_past_the_end=False, loc=None):
|
||||||
lt_0 = self.append(ir.Compare(ast.Lt(loc=None),
|
lt_0 = self.append(ir.Compare(ast.Lt(loc=None),
|
||||||
index, ir.Constant(0, index.type)))
|
index, ir.Constant(0, index.type)))
|
||||||
|
@ -2478,6 +2491,56 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
loc=node.loc)
|
loc=node.loc)
|
||||||
self.current_assert_subexprs.append((node, name))
|
self.current_assert_subexprs.append((node, name))
|
||||||
|
|
||||||
|
def _get_raise_assert_func(self):
|
||||||
|
"""Emit the helper function that constructs AssertionErrors and raises
|
||||||
|
them, if it does not already exist in the current module.
|
||||||
|
|
||||||
|
A separate function is used for code size reasons. (This could also be
|
||||||
|
compiled into a stand-alone support library instead.)
|
||||||
|
"""
|
||||||
|
if self.raise_assert_func:
|
||||||
|
return self.raise_assert_func
|
||||||
|
try:
|
||||||
|
msg = ir.Argument(builtins.TStr(), "msg")
|
||||||
|
file = ir.Argument(builtins.TStr(), "file")
|
||||||
|
line = ir.Argument(builtins.TInt32(), "line")
|
||||||
|
col = ir.Argument(builtins.TInt32(), "col")
|
||||||
|
function = ir.Argument(builtins.TStr(), "function")
|
||||||
|
|
||||||
|
args = [msg, file, line, col, function]
|
||||||
|
typ = types.TFunction(args=OrderedDict([(arg.name, arg.type)
|
||||||
|
for arg in args]),
|
||||||
|
optargs=OrderedDict(),
|
||||||
|
ret=builtins.TNone())
|
||||||
|
env = ir.TEnvironment(name="raise", vars={})
|
||||||
|
env_arg = ir.EnvironmentArgument(env, "ARG.ENV")
|
||||||
|
func = ir.Function(typ, "_artiq_raise_assert", [env_arg] + args)
|
||||||
|
func.is_internal = True
|
||||||
|
func.is_cold = True
|
||||||
|
func.is_generated = True
|
||||||
|
self.functions.append(func)
|
||||||
|
old_func, self.current_function = self.current_function, func
|
||||||
|
|
||||||
|
entry = self.add_block("entry")
|
||||||
|
old_block, self.current_block = self.current_block, entry
|
||||||
|
old_final_branch, self.final_branch = self.final_branch, None
|
||||||
|
old_unwind, self.unwind_target = self.unwind_target, None
|
||||||
|
|
||||||
|
exn = self.alloc_exn(builtins.TException("AssertionError"), message=msg)
|
||||||
|
self.append(ir.SetAttr(exn, "__file__", file))
|
||||||
|
self.append(ir.SetAttr(exn, "__line__", line))
|
||||||
|
self.append(ir.SetAttr(exn, "__col__", col))
|
||||||
|
self.append(ir.SetAttr(exn, "__func__", function))
|
||||||
|
self.append(ir.Raise(exn))
|
||||||
|
finally:
|
||||||
|
self.current_function = old_func
|
||||||
|
self.current_block = old_block
|
||||||
|
self.final_branch = old_final_branch
|
||||||
|
self.unwind_target = old_unwind
|
||||||
|
|
||||||
|
self.raise_assert_func = func
|
||||||
|
return self.raise_assert_func
|
||||||
|
|
||||||
def visit_Assert(self, node):
|
def visit_Assert(self, node):
|
||||||
try:
|
try:
|
||||||
assert_suffix = ".assert@{}:{}".format(node.loc.line(), node.loc.column())
|
assert_suffix = ".assert@{}:{}".format(node.loc.line(), node.loc.column())
|
||||||
|
@ -2502,6 +2565,26 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
|
|
||||||
if_failed = self.current_block = self.add_block("assert.fail")
|
if_failed = self.current_block = self.add_block("assert.fail")
|
||||||
|
|
||||||
|
if self.raise_assertion_errors:
|
||||||
|
self._raise_assertion_error(node)
|
||||||
|
else:
|
||||||
|
self._abort_after_assertion(node, assert_subexprs, assert_env)
|
||||||
|
|
||||||
|
tail = self.current_block = self.add_block("assert.tail")
|
||||||
|
self.append(ir.BranchIf(cond, tail, if_failed), block=head)
|
||||||
|
|
||||||
|
def _raise_assertion_error(self, node):
|
||||||
|
text = str(node.msg.s) if node.msg else "AssertionError"
|
||||||
|
msg = ir.Constant(text, builtins.TStr())
|
||||||
|
loc_file = ir.Constant(node.loc.source_buffer.name, builtins.TStr())
|
||||||
|
loc_line = ir.Constant(node.loc.line(), builtins.TInt32())
|
||||||
|
loc_column = ir.Constant(node.loc.column(), builtins.TInt32())
|
||||||
|
loc_function = ir.Constant(".".join(self.name), builtins.TStr())
|
||||||
|
self._invoke_raising_func(self._get_raise_assert_func(), [
|
||||||
|
msg, loc_file, loc_line, loc_column, loc_function
|
||||||
|
], "assert.fail")
|
||||||
|
|
||||||
|
def _abort_after_assertion(self, node, assert_subexprs, assert_env):
|
||||||
if node.msg:
|
if node.msg:
|
||||||
explanation = node.msg.s
|
explanation = node.msg.s
|
||||||
else:
|
else:
|
||||||
|
@ -2535,9 +2618,6 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
self.append(ir.Builtin("abort", [], builtins.TNone()))
|
self.append(ir.Builtin("abort", [], builtins.TNone()))
|
||||||
self.append(ir.Unreachable())
|
self.append(ir.Unreachable())
|
||||||
|
|
||||||
tail = self.current_block = self.add_block("assert.tail")
|
|
||||||
self.append(ir.BranchIf(cond, tail, if_failed), block=head)
|
|
||||||
|
|
||||||
def polymorphic_print(self, values, separator, suffix="", as_repr=False, as_rtio=False):
|
def polymorphic_print(self, values, separator, suffix="", as_repr=False, as_rtio=False):
|
||||||
def printf(format_string, *args):
|
def printf(format_string, *args):
|
||||||
format = ir.Constant(format_string, builtins.TStr())
|
format = ir.Constant(format_string, builtins.TStr())
|
||||||
|
|
|
@ -106,7 +106,8 @@ class Core:
|
||||||
|
|
||||||
module = Module(stitcher,
|
module = Module(stitcher,
|
||||||
ref_period=self.ref_period,
|
ref_period=self.ref_period,
|
||||||
attribute_writeback=attribute_writeback)
|
attribute_writeback=attribute_writeback,
|
||||||
|
raise_assertion_errors=self.target_cls.raise_assertion_errors)
|
||||||
target = self.target_cls()
|
target = self.target_cls()
|
||||||
|
|
||||||
library = target.compile_and_link([module])
|
library = target.compile_and_link([module])
|
||||||
|
|
|
@ -11,6 +11,7 @@ ZeroDivisionError = builtins.ZeroDivisionError
|
||||||
ValueError = builtins.ValueError
|
ValueError = builtins.ValueError
|
||||||
IndexError = builtins.IndexError
|
IndexError = builtins.IndexError
|
||||||
RuntimeError = builtins.RuntimeError
|
RuntimeError = builtins.RuntimeError
|
||||||
|
AssertionError = builtins.AssertionError
|
||||||
|
|
||||||
|
|
||||||
class CoreException:
|
class CoreException:
|
||||||
|
|
|
@ -460,3 +460,40 @@ class _ArrayQuoting(EnvExperiment):
|
||||||
class ArrayQuotingTest(ExperimentCase):
|
class ArrayQuotingTest(ExperimentCase):
|
||||||
def test_quoting(self):
|
def test_quoting(self):
|
||||||
self.create(_ArrayQuoting).run()
|
self.create(_ArrayQuoting).run()
|
||||||
|
|
||||||
|
|
||||||
|
class _Assert(EnvExperiment):
|
||||||
|
def build(self):
|
||||||
|
self.setattr_device("core")
|
||||||
|
|
||||||
|
def raises_assertion_errors(self):
|
||||||
|
return self.core.target_cls.raises_assertion_errors
|
||||||
|
|
||||||
|
@kernel
|
||||||
|
def check(self, value):
|
||||||
|
assert value
|
||||||
|
|
||||||
|
@kernel
|
||||||
|
def check_msg(self, value):
|
||||||
|
assert value, "foo"
|
||||||
|
|
||||||
|
|
||||||
|
class AssertTest(ExperimentCase):
|
||||||
|
def test_assert(self):
|
||||||
|
exp = self.create(_Assert)
|
||||||
|
|
||||||
|
def check_fail(fn, msg):
|
||||||
|
if exp.raises_assertion_errors:
|
||||||
|
with self.assertRaises(AssertionError) as ctx:
|
||||||
|
fn()
|
||||||
|
self.assertEqual(str(ctx.exception), msg)
|
||||||
|
else:
|
||||||
|
# Without assertion exceptions, core device panics should still lead
|
||||||
|
# to a cleanly dropped connectionr rather than a hang/…
|
||||||
|
with self.assertRaises(ConnectionResetError):
|
||||||
|
fn()
|
||||||
|
|
||||||
|
exp.check(True)
|
||||||
|
check_fail(lambda: exp.check(False), "AssertionError")
|
||||||
|
exp.check_msg(True)
|
||||||
|
check_fail(lambda: exp.check_msg(False), "foo")
|
||||||
|
|
|
@ -6,9 +6,9 @@ class c:
|
||||||
|
|
||||||
i = c()
|
i = c()
|
||||||
|
|
||||||
assert i.a == 1
|
|
||||||
|
|
||||||
def f():
|
def f():
|
||||||
c = None
|
c = None
|
||||||
assert i.a == 1
|
assert i.a == 1
|
||||||
|
|
||||||
|
assert i.a == 1
|
||||||
f()
|
f()
|
||||||
|
|
Loading…
Reference in New Issue