62 lines
1.6 KiB
C++
62 lines
1.6 KiB
C++
|
#include <cinttypes>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

// Must be defined before including irrt.hpp to have any effect there.
// NOTE(review): presumably stops irrt.hpp from declaring its own integer
// typedefs (which could clash with <cstdint>) — confirm against irrt.hpp.
#define IRRT_DONT_TYPEDEF_INTS
#include "irrt.hpp"
|
||
|
|
||
|
// Report a test failure at `file`:`line` on stdout and terminate the process
// with exit status 1. Normally reached through the test_fail() macro, which
// supplies the caller's __FILE__ and __LINE__.
[[noreturn]] static void __test_fail(const char *file, int line) {
    // NOTE: Try to make the location info follow a format that
    // VSCode/other IDEs would recognize as a clickable URL.
    // The trailing newline keeps the message on its own line instead of
    // running into whatever is printed next (e.g. the shell prompt).
    printf("[!] test_fail() invoked at %s:%d\n", file, line);
    exit(1);
}
|
||
|
|
||
|
// Fail the current test, reporting the call site. The do/while(0) wrapper makes
// `test_fail();` a single statement, so it is safe inside unbraced if/else
// (a trailing `;` in the expansion would break `if (c) test_fail(); else ...`).
#define test_fail() do { __test_fail(__FILE__, __LINE__); } while (0)
|
||
|
|
||
|
// Return true iff the first `len` elements of `as` and `bs` are pairwise equal
// (compared with `!=`). A non-positive `len` compares nothing and returns true.
// Both inputs are only read, so they are taken as pointers-to-const; callers
// passing mutable arrays keep working, and mixed const/non-const arguments
// now deduce a single T.
template <typename T>
bool arrays_match(int len, const T *as, const T *bs) {
    for (int i = 0; i < len; i++) {
        if (as[i] != bs[i]) {
            return false;
        }
    }
    return true;
}
|
||
|
|
||
|
// Print the first `len` elements of `as` to stdout as "[a, b, c]\n", rendering
// each element with the printf conversion `format`.
// `format` must match T exactly (e.g. "%d" for int, "%" PRIu64 for uint64_t).
// Because it is not a string literal, the compiler cannot check it, and a
// mismatched conversion is undefined behavior.
template <typename T>
void debug_print_array(const char *format, int len, const T *as) {
    printf("[");
    for (int i = 0; i < len; i++) {
        if (i != 0) {
            printf(", ");
        }
        printf(format, as[i]);
    }
    printf("]\n");
}
|
||
|
|
||
|
// Compare `expected` against `got` over `len` elements. On a mismatch, print
// both arrays to stdout, labelled with `label` and with each element
// formatted by the printf conversion `format` (which must match T).
// Returns whether the arrays matched. This function does not abort; the
// caller decides whether a mismatch is fatal (e.g. by calling test_fail()).
template <typename T>
bool assert_arrays_match(const char *label, const char *format, int len, const T *expected, const T *got) {
    const bool match = arrays_match(len, expected, got);

    if (!match) {
        printf("expected %s: ", label);
        debug_print_array(format, len, expected);
        printf("got %s: ", label);
        debug_print_array(format, len, got);
    }

    return match;
}
|
||
|
|
||
|
static void test_strides_from_shape() {
|
||
|
const uint64_t ndims = 4;
|
||
|
uint64_t shape[ndims] = { 999, 3, 5, 7 };
|
||
|
uint64_t strides[ndims] = { 0 };
|
||
|
__nac3_ndarray_strides_from_shape64(ndims, shape, strides);
|
||
|
|
||
|
uint64_t expected_strides[ndims] = { 3*5*7, 5*7, 7, 1 };
|
||
|
if (!assert_arrays_match("strides", "%u", ndims, expected_strides, strides)) test_fail();
|
||
|
}
|
||
|
|
||
|
// Test entry point: run each registered test in order. A failing test calls
// test_fail(), which exits with status 1; reaching the end means every test
// passed.
int main() {
    static void (*const tests[])() = {
        test_strides_from_shape,
    };

    for (auto run_test : tests) {
        run_test();
    }

    return 0;
}
|