// nac3/nac3core/irrt/test/ndarray.hpp (forked from M-Labs/nac3)
// C++ header, 44 lines, 1.1 KiB.

#pragma once
#include <test/core.hpp>
#include <irrt/numpy/ndarray.hpp>
#include <irrt/numpy/ndarray_util.hpp>
/// Verifies `ndarray_util::calc_size_from_shape` on a shape made only of
/// non-zero dimensions: for { 2, 3, 5, 7 } the test expects 210 (2 * 3 * 5 * 7).
void test_calc_size_from_shape_normal() {
    BEGIN_TEST();

    int32_t dims[4] = { 2, 3, 5, 7 };
    const auto total = ndarray_util::calc_size_from_shape<int32_t>(4, dims);

    assert_values_match(210, total);
}
/// Verifies `ndarray_util::calc_size_from_shape` on a shape that contains a
/// zero-length dimension: for { 2, 0, 5, 7 } the test expects a size of 0.
void test_calc_size_from_shape_has_zero() {
    BEGIN_TEST();

    int32_t dims[4] = { 2, 0, 5, 7 };
    const auto total = ndarray_util::calc_size_from_shape<int32_t>(4, dims);

    assert_values_match(0, total);
}
/// Verifies `ndarray_util::set_strides_by_shape()` for the shape { 99, 3, 5, 7 }
/// with an item size of `sizeof(int32_t)`.
///
/// The expected stride of each axis is the product of the dimensions that
/// follow it, multiplied by the item size; the leading dimension (99) does
/// not appear in any expected stride.
void test_set_strides_by_shape() {
    BEGIN_TEST();

    const int32_t itemsize = static_cast<int32_t>(sizeof(int32_t));

    int32_t dims[4] = { 99, 3, 5, 7 };
    int32_t actual[4] = {}; // zero-filled before the call under test
    ndarray_util::set_strides_by_shape(itemsize, 4, actual, dims);

    int32_t wanted[4] = {
        3 * 5 * 7 * itemsize, // 105 items
        5 * 7 * itemsize,     //  35 items
        7 * itemsize,         //   7 items
        1 * itemsize          //   1 item
    };
    assert_arrays_match(4, wanted, actual);
}
/// Runs every ndarray test defined in this header, in declaration order.
void run_all_tests_ndarray() {
    void (*const cases[])() = {
        test_calc_size_from_shape_normal,
        test_calc_size_from_shape_has_zero,
        test_set_strides_by_shape,
    };

    for (auto run_case : cases) {
        run_case();
    }
}