forked from M-Labs/nac3
1
0
Fork 0
nac3/nac3core/irrt/irrt_test.cpp

62 lines
1.6 KiB
C++
Raw Normal View History

2024-07-08 14:16:35 +08:00
#include <cinttypes>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#define IRRT_DONT_TYPEDEF_INTS
#include "irrt.hpp"
// Reports a failed test check and terminates the test binary with exit
// status 1. Written to stdout (not stderr) so it stays ordered with the
// mismatch dumps that assert_arrays_match prints to stdout.
[[noreturn]] static void __test_fail(const char *file, int line) {
    // NOTE: Try to make the location info follow a format that
    // VSCode/other IDEs would recognize as a clickable URL.
    printf("[!] test_fail() invoked at %s:%d\n", file, line);
    exit(1);
}
// Use as `test_fail();`. The expansion deliberately has no trailing semicolon
// so the macro behaves like a single statement, e.g. in an unbraced
// `if (...) test_fail(); else ...` (an extra `;` would orphan the `else`).
#define test_fail() __test_fail(__FILE__, __LINE__)
// Returns true iff the first `len` elements of `as` and `bs` are equal,
// comparing element-wise with operator!=. A non-positive `len` compares
// nothing and yields true.
//
// Both arrays are only read. Taking `const T *` lets a const "expected" array
// be compared with a mutable output buffer of the same element type; with
// `T *` parameters, T would be deduced as both `const E` and `E`, and the
// call would not compile.
template <typename T>
bool arrays_match(int len, const T *as, const T *bs) {
    for (int i = 0; i < len; i++) {
        if (as[i] != bs[i]) return false;
    }
    return true;
}
// Prints the first `len` elements of `as` to stdout as `[a, b, c]` followed
// by a newline, formatting each element with the printf conversion `format`.
// `format` must match the (promoted) type of T, e.g. "%d" for int or
// "%" PRIu64 for uint64_t; a mismatch is undefined behaviour in printf.
// The array is only read, so a const array is accepted as well.
template <typename T>
void debug_print_array(const char* format, int len, const T *as) {
    printf("[");
    for (int i = 0; i < len; i++) {
        if (i != 0) printf(", ");
        printf(format, as[i]);
    }
    printf("]\n");
}
// Compares `expected` against `got` over `len` elements. On a mismatch, both
// arrays are printed to stdout, labelled with `label`, and each element is
// formatted with the printf conversion `format` (which must match T).
// Returns whether the arrays matched; deciding how to fail (e.g. calling
// test_fail()) is left to the caller. Both arrays are only read.
template <typename T>
bool assert_arrays_match(const char *label, const char *format, int len, const T *expected, const T *got) {
    const bool match = arrays_match(len, expected, got);
    if (!match) {
        printf("expected %s: ", label);
        debug_print_array(format, len, expected);
        printf("got %s: ", label);
        debug_print_array(format, len, got);
    }
    return match;
}
// Checks __nac3_ndarray_strides_from_shape64 against a row-major layout: each
// expected stride is the product of all dimensions after it, so the leading
// dimension (999) never contributes. The expected values contain no itemsize
// factor, i.e. strides are counted in elements here.
static void test_strides_from_shape() {
    const uint64_t ndims = 4;
    uint64_t shape[ndims] = { 999, 3, 5, 7 };
    uint64_t strides[ndims] = { 0 };
    __nac3_ndarray_strides_from_shape64(ndims, shape, strides);
    uint64_t expected_strides[ndims] = { 3*5*7, 5*7, 7, 1 };
    // uint64_t is not necessarily `unsigned int`, so "%u" would be a printf
    // type mismatch (undefined behaviour); PRIu64 expands to the correct
    // conversion for this platform.
    if (!assert_arrays_match("strides", "%" PRIu64, static_cast<int>(ndims), expected_strides, strides)) test_fail();
}
// Runs every test in order. A failing test calls test_fail(), which exits with
// status 1, so getting past the loop means every test passed.
int main() {
    using TestFn = void (*)();
    const TestFn tests[] = {
        test_strides_from_shape,
    };
    for (TestFn run_test : tests) {
        run_test();
    }
    return 0;
}