forked from M-Labs/nac3
1
0
Fork 0

core: irrt refactor print & add print_ndarray

This commit is contained in:
lyken 2024-07-15 12:05:10 +08:00
parent b12d7fcb2d
commit b940b0a3a1
2 changed files with 86 additions and 25 deletions

View File

@ -22,16 +22,6 @@ void test_fail() {
exit(1); exit(1);
} }
template <typename T>
void debug_print_array(int len, const T* as) {
printf("[");
for (int i = 0; i < len; i++) {
if (i != 0) printf(", ");
print_value(as[i]);
}
printf("]");
}
void print_assertion_passed(const char* file, int line) { void print_assertion_passed(const char* file, int line) {
printf("[*] Assertion passed on %s:%d\n", file, line); printf("[*] Assertion passed on %s:%d\n", file, line);
} }
@ -58,10 +48,10 @@ void __assert_arrays_match(const char* file, int line, int len, const T* expecte
} else { } else {
print_assertion_failed(file, line); print_assertion_failed(file, line);
printf("Expect = "); printf("Expect = ");
debug_print_array(len, expected); print_array(len, expected);
printf("\n"); printf("\n");
printf(" Got = "); printf(" Got = ");
debug_print_array(len, got); print_array(len, got);
printf("\n"); printf("\n");
test_fail(); test_fail();
} }
@ -102,7 +92,7 @@ ErrorContext create_testing_errctx() {
return errctx; return errctx;
} }
void debug_print_errctx_content(ErrorContext* errctx) { void print_errctx_content(ErrorContext* errctx) {
if (errctx->has_error()) { if (errctx->has_error()) {
printf( printf(
"(Error ID %d): %s ... where param1 = %ld, param2 = %ld, param3 = %ld\n", "(Error ID %d): %s ... where param1 = %ld, param2 = %ld, param3 = %ld\n",
@ -121,7 +111,7 @@ void __assert_errctx_no_error(const char* file, int line, ErrorContext* errctx)
if (errctx->has_error()) { if (errctx->has_error()) {
print_assertion_failed(file, line); print_assertion_failed(file, line);
printf("Expecting no error but caught the following:\n\n"); printf("Expecting no error but caught the following:\n\n");
debug_print_errctx_content(errctx); print_errctx_content(errctx);
test_fail(); test_fail();
} }
} }
@ -140,7 +130,7 @@ void __assert_errctx_has_error(const char* file, int line, ErrorContext *errctx,
expected_error_id, expected_error_id,
errctx->error_id errctx->error_id
); );
debug_print_errctx_content(errctx); print_errctx_content(errctx);
test_fail(); test_fail();
} }
} else { } else {

View File

@ -4,7 +4,7 @@
#include <cstdio> #include <cstdio>
template <class T> template <class T>
void print_value(const T& value) {} void print_value(const T& value);
template <> template <>
void print_value(const bool& value) { void print_value(const bool& value) {
@ -31,17 +31,88 @@ void print_value(const uint32_t& value) {
printf("%u", value); printf("%u", value);
} }
template <>
void print_value(const float& value) {
printf("%f", value);
}
template <> template <>
void print_value(const double& value) { void print_value(const double& value) {
printf("%f", value); printf("%f", value);
} }
// template <double> void print_repeated(const char *str, int count) {
// void print_value(const double& value) { for (int i = 0; i < count; i++) {
// printf("%f", value); printf("%s", str);
// } }
// }
// template <char *>
// void print_value(const char*& value) { template <typename T>
// printf("%f", value); void print_array(int len, const T* as) {
// } printf("[");
for (int i = 0; i < len; i++) {
if (i != 0) printf(", ");
print_value(as[i]);
}
printf("]");
}
template<typename ElementT, typename SizeT>
void __print_ndarray_aux(bool first, bool last, SizeT* cursor, SizeT depth, NDArray<SizeT>* ndarray) {
// A really lazy recursive implementation
// Add left padding unless its the first entry (since there would be "[[[" before it)
if (!first) {
print_repeated(" ", depth);
}
const SizeT dim = ndarray->shape[depth];
if (depth + 1 == ndarray->ndims) {
// Recursed down to last dimension, print the values in a nice list
printf("[");
SizeT* indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * ndarray->ndims);
for (SizeT i = 0; i < dim; i++) {
ndarray::basic::util::set_indices_by_nth(ndarray->ndims, ndarray->shape, indices, *cursor);
ElementT* pelement = (ElementT*) ndarray::basic::get_pelement_by_indices<SizeT>(ndarray, indices);
ElementT element = *pelement;
if (i != 0) printf(", "); // List delimiter
print_value(element);
printf("(@");
print_array(ndarray->ndims, indices);
printf(")");
(*cursor)++;
}
printf("]");
} else {
printf("[");
for (SizeT i = 0; i < ndarray->shape[depth]; i++) {
__print_ndarray_aux<ElementT, SizeT>(
i == 0, // first?
i + 1 == dim, // last?
cursor,
depth + 1,
ndarray
);
}
printf("]");
}
// Add newline unless its the last entry (since there will be "]]]" after it)
if (!last) {
print_repeated("\n", depth);
}
}
template <typename ElementT, typename SizeT>
void print_ndarray(NDArray<SizeT>* ndarray) {
if (ndarray->ndims == 0) {
printf("<empty ndarray>");
} else {
SizeT cursor = 0;
__print_ndarray_aux<ElementT, SizeT>(true, true, &cursor, 0, ndarray);
}
printf("\n");
}