// Origin: forked from M-Labs/nac3 (C++ source, 650 lines, 22 KiB).
// This file will be compiled like a real C++ program,
// and we do have the luxury of using the standard libraries.
// That is, if the nix flakes do not have issues... especially on msys2...
|
|
#include <cstdint>
|
|
#include <cstdio>
|
|
#include <cstdlib>
|
|
|
|
// Set `IRRT_DONT_TYPEDEF_INTS` because `cstdint` defines them
|
|
#define IRRT_DONT_TYPEDEF_INTS
|
|
#include "irrt_everything.hpp"
|
|
|
|
// Report a failed test on stdout and terminate the whole test program
// with exit status 1. This function does not return to its caller.
void test_fail() {
    fputs("[!] Test failed\n", stdout);
    exit(1);
}
|
|
|
|
// Print a banner line announcing which test is starting and the source
// location (file and line) it was started from.
void __begin_test(const char* function_name, const char* file, int line) {
    fprintf(stdout, "######### Running %s @ %s:%d\n", function_name, file, line);
}
|
|
|
|
// Announce the enclosing test: expands to a `__begin_test` call that passes
// the current function name, file name and line number.
#define BEGIN_TEST() __begin_test(__FUNCTION__, __FILE__, __LINE__)
|
|
|
|
// Print the first `len` elements of `as` to stdout as "[a, b, c]" (no
// trailing newline). Every element is passed to printf together with the
// caller-supplied `format`.
template <typename T>
void debug_print_array(const char* format, int len, T* as) {
    printf("[");
    int index = 0;
    while (index < len) {
        // Separator goes in front of every element except the first one.
        if (index > 0) {
            printf(", ");
        }
        printf(format, as[index]);
        ++index;
    }
    printf("]");
}
|
|
|
|
template <typename T>
|
|
void assert_arrays_match(const char* label, const char* format, int len, T* expected, T* got) {
|
|
if (!arrays_match(len, expected, got)) {
|
|
printf(">>>>>>> %s\n", label);
|
|
printf(" Expecting = ");
|
|
debug_print_array(format, len, expected);
|
|
printf("\n");
|
|
printf(" Got = ");
|
|
debug_print_array(format, len, got);
|
|
printf("\n");
|
|
test_fail();
|
|
}
|
|
}
|
|
|
|
template <typename T>
|
|
void assert_values_match(const char* label, const char* format, T expected, T got) {
|
|
if (expected != got) {
|
|
printf(">>>>>>> %s\n", label);
|
|
printf(" Expecting = ");
|
|
printf(format, expected);
|
|
printf("\n");
|
|
printf(" Got = ");
|
|
printf(format, got);
|
|
printf("\n");
|
|
test_fail();
|
|
}
|
|
}
|
|
|
|
// Write `str` to stdout `count` times in a row; a `count` of zero or less
// prints nothing.
void print_repeated(const char *str, int count) {
    for (int remaining = count; remaining > 0; --remaining) {
        printf("%s", str);
    }
}
|
|
|
|
template<typename SizeT, typename ElementT>
|
|
void __print_ndarray_aux(const char *format, bool first, bool last, SizeT* cursor, SizeT depth, NDArray<SizeT>* ndarray) {
|
|
// A really lazy recursive implementation
|
|
|
|
// Add left padding unless its the first entry (since there would be "[[[" before it)
|
|
if (!first) {
|
|
print_repeated(" ", depth);
|
|
}
|
|
|
|
const SizeT dim = ndarray->shape[depth];
|
|
if (depth + 1 == ndarray->ndims) {
|
|
// Recursed down to last dimension, print the values in a nice list
|
|
printf("[");
|
|
|
|
SizeT* indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * ndarray->ndims);
|
|
for (SizeT i = 0; i < dim; i++) {
|
|
ndarray_util::set_indices_by_nth(ndarray->ndims, ndarray->shape, indices, *cursor);
|
|
ElementT* pelement = (ElementT*) ndarray->get_pelement_by_indices(indices);
|
|
ElementT element = *pelement;
|
|
|
|
if (i != 0) printf(", "); // List delimiter
|
|
printf(format, element);
|
|
printf("(@");
|
|
debug_print_array("%d", ndarray->ndims, indices);
|
|
printf(")");
|
|
|
|
(*cursor)++;
|
|
}
|
|
printf("]");
|
|
} else {
|
|
printf("[");
|
|
for (SizeT i = 0; i < ndarray->shape[depth]; i++) {
|
|
__print_ndarray_aux<SizeT, ElementT>(
|
|
format,
|
|
i == 0, // first?
|
|
i + 1 == dim, // last?
|
|
cursor,
|
|
depth + 1,
|
|
ndarray
|
|
);
|
|
}
|
|
printf("]");
|
|
}
|
|
|
|
// Add newline unless its the last entry (since there will be "]]]" after it)
|
|
if (!last) {
|
|
print_repeated("\n", depth);
|
|
}
|
|
}
|
|
|
|
template<typename SizeT, typename ElementT>
|
|
void print_ndarray(const char *format, NDArray<SizeT>* ndarray) {
|
|
if (ndarray->ndims == 0) {
|
|
printf("<empty ndarray>");
|
|
} else {
|
|
SizeT cursor = 0;
|
|
__print_ndarray_aux<SizeT, ElementT>(format, true, true, &cursor, 0, ndarray);
|
|
}
|
|
printf("\n");
|
|
}
|
|
|
|
// `calc_size_from_shape` on a shape without zeroes: 2 * 3 * 5 * 7 == 210.
void test_calc_size_from_shape_normal() {
    BEGIN_TEST();

    int32_t shape[4] = { 2, 3, 5, 7 };
    const auto size = ndarray_util::calc_size_from_shape<int32_t>(4, shape);
    assert_values_match("size", "%d", 210, size);
}
|
|
|
|
// `calc_size_from_shape` on a shape that contains a 0: the expected size is 0.
void test_calc_size_from_shape_has_zero() {
    BEGIN_TEST();

    int32_t shape[4] = { 2, 0, 5, 7 };
    const auto size = ndarray_util::calc_size_from_shape<int32_t>(4, shape);
    assert_values_match("size", "%d", 0, size);
}
|
|
|
|
// Test `set_strides_by_shape()`: for shape (99, 3, 5, 7) and an item size of
// `sizeof(int32_t)`, the expected strides are (105, 35, 7, 1) items,
// expressed in bytes.
void test_set_strides_by_shape() {
    BEGIN_TEST();

    const int32_t itemsize = (int32_t) sizeof(int32_t);

    int32_t shape[4] = { 99, 3, 5, 7 };
    int32_t strides[4] = { 0 };
    ndarray_util::set_strides_by_shape(itemsize, 4, strides, shape);

    int32_t expected_strides[4] = {
        105 * itemsize,
        35 * itemsize,
        7 * itemsize,
        1 * itemsize
    };

    // The elements are signed `int32_t` and the length parameter is an
    // `int`, so use "%d" and a signed length like the other tests do.
    assert_arrays_match("strides", "%d", 4, expected_strides, strides);
}
|
|
|
|
// void test_ndarray_indices_iter_normal() {
|
|
// // Test NDArrayIndicesIter normal behavior
|
|
// BEGIN_TEST();
|
|
//
|
|
// int32_t shape[3] = { 1, 2, 3 };
|
|
// int32_t indices[3] = { 0, 0, 0 };
|
|
// auto iter = NDArrayIndicesIter<int32_t> {
|
|
// .ndims = 3,
|
|
// .shape = shape,
|
|
// .indices = indices
|
|
// };
|
|
//
|
|
// assert_arrays_match("indices #0", "%u", 3u, iter.indices, (int32_t[3]) { 0, 0, 0 });
|
|
// iter.next();
|
|
// assert_arrays_match("indices #1", "%u", 3u, iter.indices, (int32_t[3]) { 0, 0, 1 });
|
|
// iter.next();
|
|
// assert_arrays_match("indices #2", "%u", 3u, iter.indices, (int32_t[3]) { 0, 0, 2 });
|
|
// iter.next();
|
|
// assert_arrays_match("indices #3", "%u", 3u, iter.indices, (int32_t[3]) { 0, 1, 0 });
|
|
// iter.next();
|
|
// assert_arrays_match("indices #4", "%u", 3u, iter.indices, (int32_t[3]) { 0, 1, 1 });
|
|
// iter.next();
|
|
// assert_arrays_match("indices #5", "%u", 3u, iter.indices, (int32_t[3]) { 0, 1, 2 });
|
|
// iter.next();
|
|
// assert_arrays_match("indices #6", "%u", 3u, iter.indices, (int32_t[3]) { 0, 0, 0 }); // Loops back
|
|
// iter.next();
|
|
// assert_arrays_match("indices #7", "%u", 3u, iter.indices, (int32_t[3]) { 0, 0, 1 });
|
|
// }
|
|
|
|
// Test `NDArray::fill_generic`: after filling a 2x3 ndarray of `uint16_t`
// with one value, every element of the backing buffer must equal it.
void test_ndarray_fill_generic() {
    BEGIN_TEST();

    // Choose a type that's neither int32_t nor uint64_t (candidates of SizeT) to spice it up.
    // Both octets of the value are non-zero, so a partially copied element would be noticed.
    uint16_t fill_value = 0xFACE;

    // Backing storage, pre-filled with values that differ from `fill_value`.
    uint16_t in_data[6] = { 100, 101, 102, 103, 104, 105 };

    const int32_t in_ndims = 2;
    int32_t in_itemsize = sizeof(uint16_t);
    int32_t in_shape[in_ndims] = { 2, 3 };
    int32_t in_strides[in_ndims] = {};
    NDArray<int32_t> ndarray = {
        .data = (uint8_t*) in_data,
        .itemsize = in_itemsize,
        .ndims = in_ndims,
        .shape = in_shape,
        .strides = in_strides,
    };
    ndarray.set_strides_by_shape();

    // The call under test
    ndarray.fill_generic((uint8_t*) &fill_value);

    uint16_t expected_data[6];
    for (uint16_t& slot : expected_data) {
        slot = fill_value;
    }
    assert_arrays_match("data", "0x%hX", 6, expected_data, in_data);
}
|
|
|
|
// Test `set_to_eye` behavior (helper function to implement `np.eye()`):
// a 3x3 ndarray of doubles, called with a first argument of 1, is expected
// to end up as
//     0 1 0
//     0 0 1
//     0 0 0
void test_ndarray_set_to_eye() {
    BEGIN_TEST();

    // Start from a buffer full of 99.0 so that every element must be overwritten.
    double in_data[9];
    for (double& value : in_data) {
        value = 99.0;
    }

    const int32_t in_ndims = 2;
    int32_t in_itemsize = sizeof(double);
    int32_t in_shape[in_ndims] = { 3, 3 };
    int32_t in_strides[in_ndims] = {};
    NDArray<int32_t> ndarray = {
        .data = (uint8_t*) in_data,
        .itemsize = in_itemsize,
        .ndims = in_ndims,
        .shape = in_shape,
        .strides = in_strides,
    };
    ndarray.set_strides_by_shape();

    double zero = 0.0;
    double one = 1.0;
    ndarray.set_to_eye(1, (uint8_t*) &zero, (uint8_t*) &one);

    const double expected[9] = {
        0.0, 1.0, 0.0,
        0.0, 0.0, 1.0,
        0.0, 0.0, 0.0,
    };
    for (int i = 0; i < 9; i++) {
        // Builds the labels "in_data[0]" ... "in_data[8]".
        char label[32];
        snprintf(label, sizeof(label), "in_data[%d]", i);
        assert_values_match(label, "%f", expected[i], in_data[i]);
    }
}
|
|
|
|
void test_slice_1() {
|
|
// Test `slice(5, None, None).indices(100) == slice(5, 100, 1)`
|
|
BEGIN_TEST();
|
|
|
|
UserSlice<int> user_slice = {
|
|
.start_defined = 1,
|
|
.start = 5,
|
|
.stop_defined = 0,
|
|
.step_defined = 0,
|
|
};
|
|
|
|
auto slice = user_slice.indices(100);
|
|
assert_values_match("start", "%d", 5, slice.start);
|
|
assert_values_match("stop", "%d", 100, slice.stop);
|
|
assert_values_match("step", "%d", 1, slice.step);
|
|
}
|
|
|
|
void test_slice_2() {
|
|
// Test `slice(400, 999, None).indices(100) == slice(100, 100, 1)`
|
|
BEGIN_TEST();
|
|
|
|
UserSlice<int> user_slice = {
|
|
.start_defined = 1,
|
|
.start = 400,
|
|
.stop_defined = 0,
|
|
.step_defined = 0,
|
|
};
|
|
|
|
auto slice = user_slice.indices(100);
|
|
assert_values_match("start", "%d", 100, slice.start);
|
|
assert_values_match("stop", "%d", 100, slice.stop);
|
|
assert_values_match("step", "%d", 1, slice.step);
|
|
}
|
|
|
|
void test_slice_3() {
|
|
// Test `slice(-10, -5, None).indices(100) == slice(90, 95, 1)`
|
|
BEGIN_TEST();
|
|
|
|
UserSlice<int> user_slice = {
|
|
.start_defined = 1,
|
|
.start = -10,
|
|
.stop_defined = 1,
|
|
.stop = -5,
|
|
.step_defined = 0,
|
|
};
|
|
|
|
auto slice = user_slice.indices(100);
|
|
assert_values_match("start", "%d", 90, slice.start);
|
|
assert_values_match("stop", "%d", 95, slice.stop);
|
|
assert_values_match("step", "%d", 1, slice.step);
|
|
}
|
|
|
|
void test_slice_4() {
|
|
// Test `slice(None, None, -5).indices(100) == (99, -1, -5)`
|
|
BEGIN_TEST();
|
|
|
|
UserSlice<int> user_slice = {
|
|
.start_defined = 0,
|
|
.stop_defined = 0,
|
|
.step_defined = 1,
|
|
.step = -5
|
|
};
|
|
|
|
auto slice = user_slice.indices(100);
|
|
assert_values_match("start", "%d", 99, slice.start);
|
|
assert_values_match("stop", "%d", -1, slice.stop);
|
|
assert_values_match("step", "%d", -5, slice.step);
|
|
}
|
|
|
|
void test_ndslice_1() {
|
|
/*
|
|
Reference Python code:
|
|
```python
|
|
ndarray = np.arange(12, dtype=np.float64).reshape((3, 4));
|
|
# array([[ 0., 1., 2., 3.],
|
|
# [ 4., 5., 6., 7.],
|
|
# [ 8., 9., 10., 11.]])
|
|
|
|
dst_ndarray = ndarray[-2:, 1::2]
|
|
# array([[ 5., 7.],
|
|
# [ 9., 11.]])
|
|
|
|
assert dst_ndarray.shape == (2, 2)
|
|
assert dst_ndarray.strides == (32, 16)
|
|
assert dst_ndarray[0, 0] == 5.0
|
|
assert dst_ndarray[0, 1] == 7.0
|
|
assert dst_ndarray[1, 0] == 9.0
|
|
assert dst_ndarray[1, 1] == 11.0
|
|
```
|
|
*/
|
|
BEGIN_TEST();
|
|
|
|
double in_data[12] = { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0 };
|
|
int32_t in_itemsize = sizeof(double);
|
|
const int32_t in_ndims = 2;
|
|
int32_t in_shape[in_ndims] = { 3, 4 };
|
|
int32_t in_strides[in_ndims] = {};
|
|
NDArray<int32_t> ndarray = {
|
|
.data = (uint8_t*) in_data,
|
|
.itemsize = in_itemsize,
|
|
.ndims = in_ndims,
|
|
.shape = in_shape,
|
|
.strides = in_strides
|
|
};
|
|
ndarray.set_strides_by_shape();
|
|
|
|
// Destination ndarray
|
|
// As documented, ndims and shape & strides must be allocated and determined by the caller.
|
|
const int32_t dst_ndims = 2;
|
|
int32_t dst_shape[dst_ndims] = {999, 999}; // Empty values
|
|
int32_t dst_strides[dst_ndims] = {999, 999}; // Empty values
|
|
NDArray<int32_t> dst_ndarray = {
|
|
.data = nullptr,
|
|
.ndims = dst_ndims,
|
|
.shape = dst_shape,
|
|
.strides = dst_strides
|
|
};
|
|
|
|
// Create the slice in `ndarray[-2::, 1::2]`
|
|
UserSlice<int32_t> user_slice_1 = {
|
|
.start_defined = 1,
|
|
.start = -2,
|
|
.stop_defined = 0,
|
|
.step_defined = 0
|
|
};
|
|
|
|
UserSlice<int32_t> user_slice_2 = {
|
|
.start_defined = 1,
|
|
.start = 1,
|
|
.stop_defined = 0,
|
|
.step_defined = 1,
|
|
.step = 2
|
|
};
|
|
|
|
const int32_t num_ndslices = 2;
|
|
NDSlice ndslices[num_ndslices] = {
|
|
{ .type = INPUT_SLICE_TYPE_SLICE, .slice = (uint8_t*) &user_slice_1 },
|
|
{ .type = INPUT_SLICE_TYPE_SLICE, .slice = (uint8_t*) &user_slice_2 }
|
|
};
|
|
|
|
ndarray.slice(num_ndslices, ndslices, &dst_ndarray);
|
|
|
|
int32_t expected_shape[dst_ndims] = { 2, 2 };
|
|
int32_t expected_strides[dst_ndims] = { 32, 16 };
|
|
assert_arrays_match("shape", "%d", dst_ndims, expected_shape, dst_ndarray.shape);
|
|
assert_arrays_match("strides", "%d", dst_ndims, expected_strides, dst_ndarray.strides);
|
|
|
|
assert_values_match("dst_ndarray[0, 0]", "%f", 5.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 0, 0 })));
|
|
assert_values_match("dst_ndarray[0, 1]", "%f", 7.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 0, 1 })));
|
|
assert_values_match("dst_ndarray[1, 0]", "%f", 9.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 1, 0 })));
|
|
assert_values_match("dst_ndarray[1, 1]", "%f", 11.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 1, 1 })));
|
|
}
|
|
|
|
void test_ndslice_2() {
|
|
/*
|
|
```python
|
|
ndarray = np.arange(12, dtype=np.float64).reshape((3, 4))
|
|
# array([[ 0., 1., 2., 3.],
|
|
# [ 4., 5., 6., 7.],
|
|
# [ 8., 9., 10., 11.]])
|
|
|
|
dst_ndarray = ndarray[2, ::-2]
|
|
# array([11., 9.])
|
|
|
|
assert dst_ndarray.shape == (2,)
|
|
assert dst_ndarray.strides == (-16,)
|
|
assert dst_ndarray[0] == 11.0
|
|
assert dst_ndarray[1] == 9.0
|
|
|
|
dst_ndarray[1, 0] == 99 # If you write to `dst_ndarray`
|
|
assert ndarray[1, 3] == 99 # `ndarray` also updates!!
|
|
```
|
|
*/
|
|
BEGIN_TEST();
|
|
|
|
double in_data[12] = { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0 };
|
|
int32_t in_itemsize = sizeof(double);
|
|
const int32_t in_ndims = 2;
|
|
int32_t in_shape[in_ndims] = { 3, 4 };
|
|
int32_t in_strides[in_ndims] = {};
|
|
NDArray<int32_t> ndarray = {
|
|
.data = (uint8_t*) in_data,
|
|
.itemsize = in_itemsize,
|
|
.ndims = in_ndims,
|
|
.shape = in_shape,
|
|
.strides = in_strides
|
|
};
|
|
ndarray.set_strides_by_shape();
|
|
|
|
// Destination ndarray
|
|
// As documented, ndims and shape & strides must be allocated and determined by the caller.
|
|
const int32_t dst_ndims = 1;
|
|
int32_t dst_shape[dst_ndims] = {999}; // Empty values
|
|
int32_t dst_strides[dst_ndims] = {999}; // Empty values
|
|
NDArray<int32_t> dst_ndarray = {
|
|
.data = nullptr,
|
|
.ndims = dst_ndims,
|
|
.shape = dst_shape,
|
|
.strides = dst_strides
|
|
};
|
|
|
|
// Create the slice in `ndarray[2, ::-2]`
|
|
int32_t user_slice_1 = 2;
|
|
UserSlice<int32_t> user_slice_2 = {
|
|
.start_defined = 0,
|
|
.stop_defined = 0,
|
|
.step_defined = 1,
|
|
.step = -2
|
|
};
|
|
|
|
const int32_t num_ndslices = 2;
|
|
NDSlice ndslices[num_ndslices] = {
|
|
{ .type = INPUT_SLICE_TYPE_INDEX, .slice = (uint8_t*) &user_slice_1 },
|
|
{ .type = INPUT_SLICE_TYPE_SLICE, .slice = (uint8_t*) &user_slice_2 }
|
|
};
|
|
|
|
ndarray.slice(num_ndslices, ndslices, &dst_ndarray);
|
|
|
|
int32_t expected_shape[dst_ndims] = { 2 };
|
|
int32_t expected_strides[dst_ndims] = { -16 };
|
|
assert_arrays_match("shape", "%d", dst_ndims, expected_shape, dst_ndarray.shape);
|
|
assert_arrays_match("strides", "%d", dst_ndims, expected_strides, dst_ndarray.strides);
|
|
|
|
// [5.0, 3.0]
|
|
assert_values_match("dst_ndarray[0]", "%f", 11.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 0 })));
|
|
assert_values_match("dst_ndarray[1]", "%f", 9.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 1 })));
|
|
}
|
|
|
|
// Table-driven checks of `ndarray_util::can_broadcast_shape_to`. Each case
// lists the two shapes in the order they are passed to the function (the
// same order as in the label) together with the expected boolean result.
void test_can_broadcast_shape() {
    BEGIN_TEST();

    struct BroadcastCase {
        const char* label;
        bool expected;
        int32_t lhs_ndims;    // Number of dimensions of the first shape
        int32_t lhs_shape[5]; // First shape; only `lhs_ndims` entries are meaningful
        int32_t rhs_ndims;    // Number of dimensions of the second shape
        int32_t rhs_shape[5]; // Second shape; only `rhs_ndims` entries are meaningful
    };

    BroadcastCase cases[] = {
        { "can_broadcast_shape_to([3], [1, 1, 1, 1, 3]) == true", true, 1, { 3 }, 5, { 1, 1, 1, 1, 3 } },
        { "can_broadcast_shape_to([3], [3, 1]) == false", false, 1, { 3 }, 2, { 3, 1 } },
        { "can_broadcast_shape_to([3], [3]) == true", true, 1, { 3 }, 1, { 3 } },
        { "can_broadcast_shape_to([1], [3]) == false", false, 1, { 1 }, 1, { 3 } },
        { "can_broadcast_shape_to([1], [1]) == true", true, 1, { 1 }, 1, { 1 } },
        { "can_broadcast_shape_to([256, 256, 3], [256, 1, 3]) == true", true, 3, { 256, 256, 3 }, 3, { 256, 1, 3 } },
        { "can_broadcast_shape_to([256, 256, 3], [3]) == true", true, 3, { 256, 256, 3 }, 1, { 3 } },
        { "can_broadcast_shape_to([256, 256, 3], [2]) == false", false, 3, { 256, 256, 3 }, 1, { 2 } },
        { "can_broadcast_shape_to([256, 256, 3], [1]) == true", true, 3, { 256, 256, 3 }, 1, { 1 } },

        // In cases when the shapes contain zero(es)
        { "can_broadcast_shape_to([0], [1]) == true", true, 1, { 0 }, 1, { 1 } },
        { "can_broadcast_shape_to([0], [2]) == false", false, 1, { 0 }, 1, { 2 } },
        { "can_broadcast_shape_to([0, 4, 0, 0], [1]) == true", true, 4, { 0, 4, 0, 0 }, 1, { 1 } },
        { "can_broadcast_shape_to([0, 4, 0, 0], [1, 1, 1, 1]) == true", true, 4, { 0, 4, 0, 0 }, 4, { 1, 1, 1, 1 } },
        { "can_broadcast_shape_to([0, 4, 0, 0], [1, 4, 1, 1]) == true", true, 4, { 0, 4, 0, 0 }, 4, { 1, 4, 1, 1 } },
        { "can_broadcast_shape_to([4, 3], [0, 3]) == false", false, 2, { 4, 3 }, 2, { 0, 3 } },
        { "can_broadcast_shape_to([4, 3], [0, 0]) == false", false, 2, { 4, 3 }, 2, { 0, 0 } },
    };

    for (BroadcastCase& c : cases) {
        assert_values_match(
            c.label,
            "%d",
            c.expected,
            ndarray_util::can_broadcast_shape_to(c.lhs_ndims, c.lhs_shape, c.rhs_ndims, c.rhs_shape)
        );
    }
}
|
|
|
|
void test_ndarray_broadcast_1() {
|
|
/*
|
|
```python
|
|
array = np.array([[19.9, 29.9, 39.9, 49.9]], dtype=np.float64)
|
|
>>> [[19.9 29.9 39.9 49.9]]
|
|
|
|
array = np.broadcast_to(array, (2, 3, 4))
|
|
>>> [[[19.9 29.9 39.9 49.9]
|
|
>>> [19.9 29.9 39.9 49.9]
|
|
>>> [19.9 29.9 39.9 49.9]]
|
|
>>> [[19.9 29.9 39.9 49.9]
|
|
>>> [19.9 29.9 39.9 49.9]
|
|
>>> [19.9 29.9 39.9 49.9]]]
|
|
|
|
assert array.strides == (0, 0, 8)
|
|
# and then pick some values in `array` and check them...
|
|
```
|
|
*/
|
|
BEGIN_TEST();
|
|
|
|
double in_data[4] = { 19.9, 29.9, 39.9, 49.9 };
|
|
const int32_t in_ndims = 2;
|
|
int32_t in_shape[in_ndims] = {1, 4};
|
|
int32_t in_strides[in_ndims] = {};
|
|
NDArray<int32_t> ndarray = {
|
|
.data = (uint8_t*) in_data,
|
|
.itemsize = sizeof(double),
|
|
.ndims = in_ndims,
|
|
.shape = in_shape,
|
|
.strides = in_strides
|
|
};
|
|
ndarray.set_strides_by_shape();
|
|
|
|
const int32_t dst_ndims = 3;
|
|
int32_t dst_shape[dst_ndims] = {2, 3, 4};
|
|
int32_t dst_strides[dst_ndims] = {};
|
|
NDArray<int32_t> dst_ndarray = {
|
|
.ndims = dst_ndims,
|
|
.shape = dst_shape,
|
|
.strides = dst_strides
|
|
};
|
|
|
|
ndarray.broadcast_to(&dst_ndarray);
|
|
|
|
assert_arrays_match("dst_ndarray->strides", "%d", dst_ndims, (int32_t[]) { 0, 0, 8 }, dst_ndarray.strides);
|
|
|
|
assert_values_match("dst_ndarray[0, 0, 0]", "%f", 19.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 0, 0})));
|
|
assert_values_match("dst_ndarray[0, 0, 1]", "%f", 29.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 0, 1})));
|
|
assert_values_match("dst_ndarray[0, 0, 2]", "%f", 39.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 0, 2})));
|
|
assert_values_match("dst_ndarray[0, 0, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 0, 3})));
|
|
assert_values_match("dst_ndarray[0, 1, 0]", "%f", 19.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 1, 0})));
|
|
assert_values_match("dst_ndarray[0, 1, 1]", "%f", 29.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 1, 1})));
|
|
assert_values_match("dst_ndarray[0, 1, 2]", "%f", 39.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 1, 2})));
|
|
assert_values_match("dst_ndarray[0, 1, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 1, 3})));
|
|
assert_values_match("dst_ndarray[1, 2, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {1, 2, 3})));
|
|
}
|
|
|
|
// Entry point: run every test in order. A failing test terminates the
// process through `test_fail()`, so reaching the end means all tests passed.
int main() {
    void (*const tests[])() = {
        test_calc_size_from_shape_normal,
        test_calc_size_from_shape_has_zero,
        test_set_strides_by_shape,
        // test_ndarray_indices_iter_normal, (disabled: its definition is commented out)
        test_ndarray_fill_generic,
        test_ndarray_set_to_eye,
        test_slice_1,
        test_slice_2,
        test_slice_3,
        test_slice_4,
        test_ndslice_1,
        test_ndslice_2,
        test_can_broadcast_shape,
        test_ndarray_broadcast_1,
    };

    for (auto run_test : tests) {
        run_test();
    }
    return 0;
}