forked from M-Labs/nac3
72 lines
3.0 KiB
C++
72 lines
3.0 KiB
C++
#pragma once
|
|
|
|
#include <test/core.hpp>
|
|
#include <irrt_everything.hpp>
|
|
|
|
namespace test { namespace ndarray_broadcast {
|
|
void test_ndarray_broadcast_1() {
|
|
/*
|
|
```python
|
|
array = np.array([[19.9, 29.9, 39.9, 49.9]], dtype=np.float64)
|
|
>>> [[19.9 29.9 39.9 49.9]]
|
|
|
|
array = np.broadcast_to(array, (2, 3, 4))
|
|
>>> [[[19.9 29.9 39.9 49.9]
|
|
>>> [19.9 29.9 39.9 49.9]
|
|
>>> [19.9 29.9 39.9 49.9]]
|
|
>>> [[19.9 29.9 39.9 49.9]
|
|
>>> [19.9 29.9 39.9 49.9]
|
|
>>> [19.9 29.9 39.9 49.9]]]
|
|
|
|
assert array.strides == (0, 0, 8)
|
|
# and then pick some values in `array` and check them...
|
|
```
|
|
*/
|
|
BEGIN_TEST();
|
|
|
|
// Prepare src_ndarray
|
|
double src_data[4] = { 19.9, 29.9, 39.9, 49.9 };
|
|
const int32_t src_ndims = 2;
|
|
int32_t src_shape[src_ndims] = {1, 4};
|
|
int32_t src_strides[src_ndims] = {};
|
|
NDArray<int32_t> src_ndarray = {
|
|
.data = (uint8_t*) src_data,
|
|
.itemsize = sizeof(double),
|
|
.ndims = src_ndims,
|
|
.shape = src_shape,
|
|
.strides = src_strides
|
|
};
|
|
ndarray::basic::set_strides_by_shape(&src_ndarray);
|
|
|
|
// Prepare dst_ndarray
|
|
const int32_t dst_ndims = 3;
|
|
int32_t dst_shape[dst_ndims] = {2, 3, 4};
|
|
int32_t dst_strides[dst_ndims] = {};
|
|
NDArray<int32_t> dst_ndarray = {
|
|
.ndims = dst_ndims,
|
|
.shape = dst_shape,
|
|
.strides = dst_strides
|
|
};
|
|
|
|
// Broadcast
|
|
ErrorContext errctx = create_testing_errctx();
|
|
ndarray::broadcast::broadcast_to(&errctx, &src_ndarray, &dst_ndarray);
|
|
assert_errctx_no_error(&errctx);
|
|
|
|
assert_arrays_match(dst_ndims, ((int32_t[]) { 0, 0, 8 }), dst_ndarray.strides);
|
|
|
|
assert_values_match(19.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 0, 0}))));
|
|
assert_values_match(29.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 0, 1}))));
|
|
assert_values_match(39.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 0, 2}))));
|
|
assert_values_match(49.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 0, 3}))));
|
|
assert_values_match(19.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 1, 0}))));
|
|
assert_values_match(29.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 1, 1}))));
|
|
assert_values_match(39.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 1, 2}))));
|
|
assert_values_match(49.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 1, 3}))));
|
|
assert_values_match(49.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {1, 2, 3}))));
|
|
}
|
|
|
|
void run() {
|
|
test_ndarray_broadcast_1();
|
|
}
|
|
}} |