From b940b0a3a1cfeb1df26b8bb1681aaeb5ce6a2004 Mon Sep 17 00:00:00 2001 From: lyken Date: Mon, 15 Jul 2024 12:05:10 +0800 Subject: [PATCH] core: irrt refactor print & add print_ndarray --- nac3core/irrt/test/core.hpp | 20 ++------ nac3core/irrt/test/print.hpp | 91 ++++++++++++++++++++++++++++++++---- 2 files changed, 86 insertions(+), 25 deletions(-) diff --git a/nac3core/irrt/test/core.hpp b/nac3core/irrt/test/core.hpp index 1fcb9d55..94e51777 100644 --- a/nac3core/irrt/test/core.hpp +++ b/nac3core/irrt/test/core.hpp @@ -22,16 +22,6 @@ void test_fail() { exit(1); } -template -void debug_print_array(int len, const T* as) { - printf("["); - for (int i = 0; i < len; i++) { - if (i != 0) printf(", "); - print_value(as[i]); - } - printf("]"); -} - void print_assertion_passed(const char* file, int line) { printf("[*] Assertion passed on %s:%d\n", file, line); } @@ -58,10 +48,10 @@ void __assert_arrays_match(const char* file, int line, int len, const T* expecte } else { print_assertion_failed(file, line); printf("Expect = "); - debug_print_array(len, expected); + print_array(len, expected); printf("\n"); printf(" Got = "); - debug_print_array(len, got); + print_array(len, got); printf("\n"); test_fail(); } @@ -102,7 +92,7 @@ ErrorContext create_testing_errctx() { return errctx; } -void debug_print_errctx_content(ErrorContext* errctx) { +void print_errctx_content(ErrorContext* errctx) { if (errctx->has_error()) { printf( "(Error ID %d): %s ... where param1 = %ld, param2 = %ld, param3 = %ld\n", @@ -121,7 +111,7 @@ void __assert_errctx_no_error(const char* file, int line, ErrorContext* errctx) if (errctx->has_error()) { print_assertion_failed(file, line); printf("Expecting no error but caught the following:\n\n"); - debug_print_errctx_content(errctx); + print_errctx_content(errctx); test_fail(); } } @@ -140,7 +130,7 @@ void __assert_errctx_has_error(const char* file, int line, ErrorContext *errctx, expected_error_id, errctx->error_id ); - debug_print_errctx_content(errctx); + print_errctx_content(errctx); test_fail(); } } else { diff --git a/nac3core/irrt/test/print.hpp b/nac3core/irrt/test/print.hpp index cccb2c2b..aed26679 100644 --- a/nac3core/irrt/test/print.hpp +++ b/nac3core/irrt/test/print.hpp @@ -4,7 +4,7 @@ #include template -void print_value(const T& value) {} +void print_value(const T& value); template <> void print_value(const bool& value) { @@ -31,17 +31,88 @@ void print_value(const uint32_t& value) { printf("%u", value); } +template <> +void print_value(const float& value) { + printf("%f", value); +} + template <> void print_value(const double& value) { printf("%f", value); } -// template -// void print_value(const double& value) { -// printf("%f", value); -// } -// -// template -// void print_value(const char*& value) { -// printf("%f", value); -// } \ No newline at end of file +void print_repeated(const char *str, int count) { + for (int i = 0; i < count; i++) { + printf("%s", str); + } +} + +template +void print_array(int len, const T* as) { + printf("["); + for (int i = 0; i < len; i++) { + if (i != 0) printf(", "); + print_value(as[i]); + } + printf("]"); +} + +template +void __print_ndarray_aux(bool first, bool last, SizeT* cursor, SizeT depth, NDArray* ndarray) { + // A really lazy recursive implementation + + // Add left padding unless its the first entry (since there would be "[[[" before it) + if (!first) { + print_repeated(" ", depth); + } + + const SizeT dim = ndarray->shape[depth]; + if (depth + 1 == ndarray->ndims) { + // Recursed down to last dimension, print the values in a nice list + printf("["); + + SizeT* indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * ndarray->ndims); + for (SizeT i = 0; i < dim; i++) { + ndarray::basic::util::set_indices_by_nth(ndarray->ndims, ndarray->shape, indices, *cursor); + ElementT* pelement = (ElementT*) ndarray::basic::get_pelement_by_indices(ndarray, indices); + ElementT element = *pelement; + + if (i != 0) printf(", "); // List delimiter + print_value(element); + printf("(@"); + print_array(ndarray->ndims, indices); + printf(")"); + + (*cursor)++; + } + printf("]"); + } else { + printf("["); + for (SizeT i = 0; i < ndarray->shape[depth]; i++) { + __print_ndarray_aux( + i == 0, // first? + i + 1 == dim, // last? + cursor, + depth + 1, + ndarray + ); + } + printf("]"); + } + + // Add newline unless its the last entry (since there will be "]]]" after it) + if (!last) { + print_repeated("\n", depth); + } +} + +template +void print_ndarray(NDArray* ndarray) { + if (ndarray->ndims == 0) { + printf(""); + } else { + SizeT cursor = 0; + __print_ndarray_aux(true, true, &cursor, 0, ndarray); + } + printf("\n"); +}