forked from M-Labs/nac3
1
0
Fork 0

core: irrt refactor print & add print_ndarray

This commit is contained in:
lyken 2024-07-15 12:05:10 +08:00
parent b12d7fcb2d
commit b940b0a3a1
2 changed files with 86 additions and 25 deletions

View File

@ -22,16 +22,6 @@ void test_fail() {
exit(1);
}
template <typename T>
void debug_print_array(int len, const T* as) {
printf("[");
for (int i = 0; i < len; i++) {
if (i != 0) printf(", ");
print_value(as[i]);
}
printf("]");
}
void print_assertion_passed(const char* file, int line) {
printf("[*] Assertion passed on %s:%d\n", file, line);
}
@ -58,10 +48,10 @@ void __assert_arrays_match(const char* file, int line, int len, const T* expecte
} else {
print_assertion_failed(file, line);
printf("Expect = ");
debug_print_array(len, expected);
print_array(len, expected);
printf("\n");
printf(" Got = ");
debug_print_array(len, got);
print_array(len, got);
printf("\n");
test_fail();
}
@ -102,7 +92,7 @@ ErrorContext create_testing_errctx() {
return errctx;
}
void debug_print_errctx_content(ErrorContext* errctx) {
void print_errctx_content(ErrorContext* errctx) {
if (errctx->has_error()) {
printf(
"(Error ID %d): %s ... where param1 = %ld, param2 = %ld, param3 = %ld\n",
@ -121,7 +111,7 @@ void __assert_errctx_no_error(const char* file, int line, ErrorContext* errctx)
if (errctx->has_error()) {
print_assertion_failed(file, line);
printf("Expecting no error but caught the following:\n\n");
debug_print_errctx_content(errctx);
print_errctx_content(errctx);
test_fail();
}
}
@ -140,7 +130,7 @@ void __assert_errctx_has_error(const char* file, int line, ErrorContext *errctx,
expected_error_id,
errctx->error_id
);
debug_print_errctx_content(errctx);
print_errctx_content(errctx);
test_fail();
}
} else {

View File

@ -4,7 +4,7 @@
#include <cstdio>
template <class T>
void print_value(const T& value) {}
void print_value(const T& value);
template <>
void print_value(const bool& value) {
@ -31,17 +31,88 @@ void print_value(const uint32_t& value) {
printf("%u", value);
}
template <>
void print_value(const float& value) {
printf("%f", value);
}
template <>
void print_value(const double& value) {
printf("%f", value);
}
// template <double>
// void print_value(const double& value) {
// printf("%f", value);
// }
//
// template <char *>
// void print_value(const char*& value) {
// printf("%f", value);
// }
void print_repeated(const char *str, int count) {
for (int i = 0; i < count; i++) {
printf("%s", str);
}
}
template <typename T>
void print_array(int len, const T* as) {
printf("[");
for (int i = 0; i < len; i++) {
if (i != 0) printf(", ");
print_value(as[i]);
}
printf("]");
}
template<typename ElementT, typename SizeT>
void __print_ndarray_aux(bool first, bool last, SizeT* cursor, SizeT depth, NDArray<SizeT>* ndarray) {
// A really lazy recursive implementation
// Add left padding unless its the first entry (since there would be "[[[" before it)
if (!first) {
print_repeated(" ", depth);
}
const SizeT dim = ndarray->shape[depth];
if (depth + 1 == ndarray->ndims) {
// Recursed down to last dimension, print the values in a nice list
printf("[");
SizeT* indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * ndarray->ndims);
for (SizeT i = 0; i < dim; i++) {
ndarray::basic::util::set_indices_by_nth(ndarray->ndims, ndarray->shape, indices, *cursor);
ElementT* pelement = (ElementT*) ndarray::basic::get_pelement_by_indices<SizeT>(ndarray, indices);
ElementT element = *pelement;
if (i != 0) printf(", "); // List delimiter
print_value(element);
printf("(@");
print_array(ndarray->ndims, indices);
printf(")");
(*cursor)++;
}
printf("]");
} else {
printf("[");
for (SizeT i = 0; i < ndarray->shape[depth]; i++) {
__print_ndarray_aux<ElementT, SizeT>(
i == 0, // first?
i + 1 == dim, // last?
cursor,
depth + 1,
ndarray
);
}
printf("]");
}
// Add newline unless its the last entry (since there will be "]]]" after it)
if (!last) {
print_repeated("\n", depth);
}
}
template <typename ElementT, typename SizeT>
void print_ndarray(NDArray<SizeT>* ndarray) {
if (ndarray->ndims == 0) {
printf("<empty ndarray>");
} else {
SizeT cursor = 0;
__print_ndarray_aux<ElementT, SizeT>(true, true, &cursor, 0, ndarray);
}
printf("\n");
}