165 lines
6.0 KiB
C++
165 lines
6.0 KiB
C++
|
#pragma once
|
||
|
|
||
|
#include <test/includes.hpp>
|
||
|
|
||
|
namespace test {
|
||
|
namespace ndarray_indexing {
|
||
|
void test_normal_1() {
|
||
|
/*
|
||
|
Reference Python code:
|
||
|
```python
|
||
|
ndarray = np.arange(12, dtype=np.float64).reshape((3, 4));
|
||
|
# array([[ 0., 1., 2., 3.],
|
||
|
# [ 4., 5., 6., 7.],
|
||
|
# [ 8., 9., 10., 11.]])
|
||
|
|
||
|
dst_ndarray = ndarray[-2:, 1::2]
|
||
|
# array([[ 5., 7.],
|
||
|
# [ 9., 11.]])
|
||
|
|
||
|
assert dst_ndarray.shape == (2, 2)
|
||
|
assert dst_ndarray.strides == (32, 16)
|
||
|
assert dst_ndarray[0, 0] == 5.0
|
||
|
assert dst_ndarray[0, 1] == 7.0
|
||
|
assert dst_ndarray[1, 0] == 9.0
|
||
|
assert dst_ndarray[1, 1] == 11.0
|
||
|
```
|
||
|
*/
|
||
|
BEGIN_TEST();
|
||
|
|
||
|
// Prepare src_ndarray
|
||
|
double src_data[12] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0,
|
||
|
6.0, 7.0, 8.0, 9.0, 10.0, 11.0};
|
||
|
int64_t src_itemsize = sizeof(double);
|
||
|
const int64_t src_ndims = 2;
|
||
|
int64_t src_shape[src_ndims] = {3, 4};
|
||
|
int64_t src_strides[src_ndims] = {};
|
||
|
NDArray<int64_t> src_ndarray = {.data = (uint8_t *)src_data,
|
||
|
.itemsize = src_itemsize,
|
||
|
.ndims = src_ndims,
|
||
|
.shape = src_shape,
|
||
|
.strides = src_strides};
|
||
|
ndarray::basic::set_strides_by_shape(&src_ndarray);
|
||
|
|
||
|
// Prepare dst_ndarray
|
||
|
const int64_t dst_ndims = 2;
|
||
|
int64_t dst_shape[dst_ndims] = {999, 999}; // Empty values
|
||
|
int64_t dst_strides[dst_ndims] = {999, 999}; // Empty values
|
||
|
NDArray<int64_t> dst_ndarray = {.data = nullptr,
|
||
|
.ndims = dst_ndims,
|
||
|
.shape = dst_shape,
|
||
|
.strides = dst_strides};
|
||
|
|
||
|
// Create the subscripts in `ndarray[-2::, 1::2]`
|
||
|
UserSlice subscript_1;
|
||
|
subscript_1.set_start(-2);
|
||
|
|
||
|
UserSlice subscript_2;
|
||
|
subscript_2.set_start(1);
|
||
|
subscript_2.set_step(2);
|
||
|
|
||
|
const int64_t num_indexes = 2;
|
||
|
NDIndex indexes[num_indexes] = {
|
||
|
{.type = ND_INDEX_TYPE_SLICE, .data = (uint8_t *)&subscript_1},
|
||
|
{.type = ND_INDEX_TYPE_SLICE, .data = (uint8_t *)&subscript_2}};
|
||
|
|
||
|
ndarray::indexing::index(num_indexes, indexes, &src_ndarray, &dst_ndarray);
|
||
|
|
||
|
int64_t expected_shape[dst_ndims] = {2, 2};
|
||
|
int64_t expected_strides[dst_ndims] = {32, 16};
|
||
|
|
||
|
assert_arrays_match(dst_ndims, expected_shape, dst_ndarray.shape);
|
||
|
assert_arrays_match(dst_ndims, expected_strides, dst_ndarray.strides);
|
||
|
|
||
|
// dst_ndarray[0, 0]
|
||
|
assert_values_match(5.0,
|
||
|
*((double *)ndarray::basic::get_pelement_by_indices(
|
||
|
&dst_ndarray, (int64_t[dst_ndims]){0, 0})));
|
||
|
// dst_ndarray[0, 1]
|
||
|
assert_values_match(7.0,
|
||
|
*((double *)ndarray::basic::get_pelement_by_indices(
|
||
|
&dst_ndarray, (int64_t[dst_ndims]){0, 1})));
|
||
|
// dst_ndarray[1, 0]
|
||
|
assert_values_match(9.0,
|
||
|
*((double *)ndarray::basic::get_pelement_by_indices(
|
||
|
&dst_ndarray, (int64_t[dst_ndims]){1, 0})));
|
||
|
// dst_ndarray[1, 1]
|
||
|
assert_values_match(11.0,
|
||
|
*((double *)ndarray::basic::get_pelement_by_indices(
|
||
|
&dst_ndarray, (int64_t[dst_ndims]){1, 1})));
|
||
|
}
|
||
|
|
||
|
void test_normal_2() {
|
||
|
/*
|
||
|
```python
|
||
|
ndarray = np.arange(12, dtype=np.float64).reshape((3, 4))
|
||
|
# array([[ 0., 1., 2., 3.],
|
||
|
# [ 4., 5., 6., 7.],
|
||
|
# [ 8., 9., 10., 11.]])
|
||
|
|
||
|
dst_ndarray = ndarray[2, ::-2]
|
||
|
# array([11., 9.])
|
||
|
|
||
|
assert dst_ndarray.shape == (2,)
|
||
|
assert dst_ndarray.strides == (-16,)
|
||
|
assert dst_ndarray[0] == 11.0
|
||
|
assert dst_ndarray[1] == 9.0
|
||
|
```
|
||
|
*/
|
||
|
BEGIN_TEST();
|
||
|
|
||
|
// Prepare src_ndarray
|
||
|
double src_data[12] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0,
|
||
|
6.0, 7.0, 8.0, 9.0, 10.0, 11.0};
|
||
|
int64_t src_itemsize = sizeof(double);
|
||
|
const int64_t src_ndims = 2;
|
||
|
int64_t src_shape[src_ndims] = {3, 4};
|
||
|
int64_t src_strides[src_ndims] = {};
|
||
|
NDArray<int64_t> src_ndarray = {.data = (uint8_t *)src_data,
|
||
|
.itemsize = src_itemsize,
|
||
|
.ndims = src_ndims,
|
||
|
.shape = src_shape,
|
||
|
.strides = src_strides};
|
||
|
ndarray::basic::set_strides_by_shape(&src_ndarray);
|
||
|
|
||
|
// Prepare dst_ndarray
|
||
|
const int64_t dst_ndims = 1;
|
||
|
int64_t dst_shape[dst_ndims] = {999}; // Empty values
|
||
|
int64_t dst_strides[dst_ndims] = {999}; // Empty values
|
||
|
NDArray<int64_t> dst_ndarray = {.data = nullptr,
|
||
|
.ndims = dst_ndims,
|
||
|
.shape = dst_shape,
|
||
|
.strides = dst_strides};
|
||
|
|
||
|
// Create the subscripts in `ndarray[2, ::-2]`
|
||
|
int64_t subscript_1 = 2;
|
||
|
|
||
|
UserSlice subscript_2;
|
||
|
subscript_2.set_step(-2);
|
||
|
|
||
|
const int64_t num_indexes = 2;
|
||
|
NDIndex indexes[num_indexes] = {
|
||
|
{.type = ND_INDEX_TYPE_SINGLE_ELEMENT, .data = (uint8_t *)&subscript_1},
|
||
|
{.type = ND_INDEX_TYPE_SLICE, .data = (uint8_t *)&subscript_2}};
|
||
|
|
||
|
ndarray::indexing::index(num_indexes, indexes, &src_ndarray, &dst_ndarray);
|
||
|
|
||
|
int64_t expected_shape[dst_ndims] = {2};
|
||
|
int64_t expected_strides[dst_ndims] = {-16};
|
||
|
assert_arrays_match(dst_ndims, expected_shape, dst_ndarray.shape);
|
||
|
assert_arrays_match(dst_ndims, expected_strides, dst_ndarray.strides);
|
||
|
|
||
|
assert_values_match(11.0,
|
||
|
*((double *)ndarray::basic::get_pelement_by_indices(
|
||
|
&dst_ndarray, (int64_t[dst_ndims]){0})));
|
||
|
assert_values_match(9.0,
|
||
|
*((double *)ndarray::basic::get_pelement_by_indices(
|
||
|
&dst_ndarray, (int64_t[dst_ndims]){1})));
|
||
|
}
|
||
|
|
||
|
void run() {
|
||
|
test_normal_1();
|
||
|
test_normal_2();
|
||
|
}
|
||
|
} // namespace ndarray_indexing
|
||
|
} // namespace test
|