// Forked from M-Labs/nac3.
#pragma once
#include <test/includes.hpp>
namespace test {
|
||
|
namespace ndarray_broadcast {
|
||
|
void test_can_broadcast_shape() {
|
||
|
BEGIN_TEST();
|
||
|
|
||
|
assert_values_match(true,
|
||
|
ndarray::broadcast::util::can_broadcast_shape_to(
|
||
|
1, (int32_t[]){3}, 5, (int32_t[]){1, 1, 1, 1, 3}));
|
||
|
assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to(
|
||
|
1, (int32_t[]){3}, 2, (int32_t[]){3, 1}));
|
||
|
assert_values_match(true, ndarray::broadcast::util::can_broadcast_shape_to(
|
||
|
1, (int32_t[]){3}, 1, (int32_t[]){3}));
|
||
|
assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to(
|
||
|
1, (int32_t[]){1}, 1, (int32_t[]){3}));
|
||
|
assert_values_match(true, ndarray::broadcast::util::can_broadcast_shape_to(
|
||
|
1, (int32_t[]){1}, 1, (int32_t[]){1}));
|
||
|
assert_values_match(
|
||
|
true, ndarray::broadcast::util::can_broadcast_shape_to(
|
||
|
3, (int32_t[]){256, 256, 3}, 3, (int32_t[]){256, 1, 3}));
|
||
|
assert_values_match(true,
|
||
|
ndarray::broadcast::util::can_broadcast_shape_to(
|
||
|
3, (int32_t[]){256, 256, 3}, 1, (int32_t[]){3}));
|
||
|
assert_values_match(false,
|
||
|
ndarray::broadcast::util::can_broadcast_shape_to(
|
||
|
3, (int32_t[]){256, 256, 3}, 1, (int32_t[]){2}));
|
||
|
assert_values_match(true,
|
||
|
ndarray::broadcast::util::can_broadcast_shape_to(
|
||
|
3, (int32_t[]){256, 256, 3}, 1, (int32_t[]){1}));
|
||
|
|
||
|
// In cases when the shapes contain zero(es)
|
||
|
assert_values_match(true, ndarray::broadcast::util::can_broadcast_shape_to(
|
||
|
1, (int32_t[]){0}, 1, (int32_t[]){1}));
|
||
|
assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to(
|
||
|
1, (int32_t[]){0}, 1, (int32_t[]){2}));
|
||
|
assert_values_match(true,
|
||
|
ndarray::broadcast::util::can_broadcast_shape_to(
|
||
|
4, (int32_t[]){0, 4, 0, 0}, 1, (int32_t[]){1}));
|
||
|
assert_values_match(
|
||
|
true, ndarray::broadcast::util::can_broadcast_shape_to(
|
||
|
4, (int32_t[]){0, 4, 0, 0}, 4, (int32_t[]){1, 1, 1, 1}));
|
||
|
assert_values_match(
|
||
|
true, ndarray::broadcast::util::can_broadcast_shape_to(
|
||
|
4, (int32_t[]){0, 4, 0, 0}, 4, (int32_t[]){1, 4, 1, 1}));
|
||
|
assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to(
|
||
|
2, (int32_t[]){4, 3}, 2, (int32_t[]){0, 3}));
|
||
|
assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to(
|
||
|
2, (int32_t[]){4, 3}, 2, (int32_t[]){0, 0}));
|
||
|
}
|
||
|
|
||
|
void test_ndarray_broadcast() {
|
||
|
/*
|
||
|
# array = np.array([[19.9, 29.9, 39.9, 49.9]], dtype=np.float64)
|
||
|
# >>> [[19.9 29.9 39.9 49.9]]
|
||
|
#
|
||
|
# array = np.broadcast_to(array, (2, 3, 4))
|
||
|
# >>> [[[19.9 29.9 39.9 49.9]
|
||
|
# >>> [19.9 29.9 39.9 49.9]
|
||
|
# >>> [19.9 29.9 39.9 49.9]]
|
||
|
# >>> [[19.9 29.9 39.9 49.9]
|
||
|
# >>> [19.9 29.9 39.9 49.9]
|
||
|
# >>> [19.9 29.9 39.9 49.9]]]
|
||
|
#
|
||
|
# assery array.strides == (0, 0, 8)
|
||
|
|
||
|
*/
|
||
|
BEGIN_TEST();
|
||
|
|
||
|
double in_data[4] = {19.9, 29.9, 39.9, 49.9};
|
||
|
const int32_t in_ndims = 2;
|
||
|
int32_t in_shape[in_ndims] = {1, 4};
|
||
|
int32_t in_strides[in_ndims] = {};
|
||
|
NDArray<int32_t> ndarray = {.data = (uint8_t*)in_data,
|
||
|
.itemsize = sizeof(double),
|
||
|
.ndims = in_ndims,
|
||
|
.shape = in_shape,
|
||
|
.strides = in_strides};
|
||
|
ndarray::basic::set_strides_by_shape(&ndarray);
|
||
|
|
||
|
const int32_t dst_ndims = 3;
|
||
|
int32_t dst_shape[dst_ndims] = {2, 3, 4};
|
||
|
int32_t dst_strides[dst_ndims] = {};
|
||
|
NDArray<int32_t> dst_ndarray = {
|
||
|
.ndims = dst_ndims, .shape = dst_shape, .strides = dst_strides};
|
||
|
|
||
|
ndarray::broadcast::broadcast_to(&ndarray, &dst_ndarray);
|
||
|
|
||
|
assert_arrays_match(dst_ndims, ((int32_t[]){0, 0, 8}), dst_ndarray.strides);
|
||
|
|
||
|
assert_values_match(19.9,
|
||
|
*((double*)ndarray::basic::get_pelement_by_indices(
|
||
|
&dst_ndarray, ((int32_t[]){0, 0, 0}))));
|
||
|
assert_values_match(29.9,
|
||
|
*((double*)ndarray::basic::get_pelement_by_indices(
|
||
|
&dst_ndarray, ((int32_t[]){0, 0, 1}))));
|
||
|
assert_values_match(39.9,
|
||
|
*((double*)ndarray::basic::get_pelement_by_indices(
|
||
|
&dst_ndarray, ((int32_t[]){0, 0, 2}))));
|
||
|
assert_values_match(49.9,
|
||
|
*((double*)ndarray::basic::get_pelement_by_indices(
|
||
|
&dst_ndarray, ((int32_t[]){0, 0, 3}))));
|
||
|
assert_values_match(19.9,
|
||
|
*((double*)ndarray::basic::get_pelement_by_indices(
|
||
|
&dst_ndarray, ((int32_t[]){0, 1, 0}))));
|
||
|
assert_values_match(29.9,
|
||
|
*((double*)ndarray::basic::get_pelement_by_indices(
|
||
|
&dst_ndarray, ((int32_t[]){0, 1, 1}))));
|
||
|
assert_values_match(39.9,
|
||
|
*((double*)ndarray::basic::get_pelement_by_indices(
|
||
|
&dst_ndarray, ((int32_t[]){0, 1, 2}))));
|
||
|
assert_values_match(49.9,
|
||
|
*((double*)ndarray::basic::get_pelement_by_indices(
|
||
|
&dst_ndarray, ((int32_t[]){0, 1, 3}))));
|
||
|
assert_values_match(49.9,
|
||
|
*((double*)ndarray::basic::get_pelement_by_indices(
|
||
|
&dst_ndarray, ((int32_t[]){1, 2, 3}))));
|
||
|
}
|
||
|
|
||
|
void run() {
|
||
|
test_can_broadcast_shape();
|
||
|
test_ndarray_broadcast();
|
||
|
}
|
||
|
} // namespace ndarray_broadcast
|
||
|
} // namespace test