// nac3/nac3core/irrt/test/test_ndarray_indexing.hpp (forked from M-Labs/nac3)

#pragma once
#include <test/includes.hpp>
namespace test {
namespace ndarray_indexing {
void test_normal_1() {
/*
Reference Python code:
```python
ndarray = np.arange(12, dtype=np.float64).reshape((3, 4));
# array([[ 0., 1., 2., 3.],
# [ 4., 5., 6., 7.],
# [ 8., 9., 10., 11.]])
dst_ndarray = ndarray[-2:, 1::2]
# array([[ 5., 7.],
# [ 9., 11.]])
assert dst_ndarray.shape == (2, 2)
assert dst_ndarray.strides == (32, 16)
assert dst_ndarray[0, 0] == 5.0
assert dst_ndarray[0, 1] == 7.0
assert dst_ndarray[1, 0] == 9.0
assert dst_ndarray[1, 1] == 11.0
```
*/
BEGIN_TEST();
// Prepare src_ndarray
double src_data[12] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0,
6.0, 7.0, 8.0, 9.0, 10.0, 11.0};
int32_t src_itemsize = sizeof(double);
const int32_t src_ndims = 2;
int32_t src_shape[src_ndims] = {3, 4};
int32_t src_strides[src_ndims] = {};
NDArray<int32_t> src_ndarray = {.data = (uint8_t *)src_data,
.itemsize = src_itemsize,
.ndims = src_ndims,
.shape = src_shape,
.strides = src_strides};
ndarray::basic::set_strides_by_shape(&src_ndarray);
// Prepare dst_ndarray
const int32_t dst_ndims = 2;
int32_t dst_shape[dst_ndims] = {999, 999}; // Empty values
int32_t dst_strides[dst_ndims] = {999, 999}; // Empty values
NDArray<int32_t> dst_ndarray = {.data = nullptr,
.ndims = dst_ndims,
.shape = dst_shape,
.strides = dst_strides};
// Create the subscripts in `ndarray[-2::, 1::2]`
UserSlice subscript_1;
subscript_1.set_start(-2);
UserSlice subscript_2;
subscript_2.set_start(1);
subscript_2.set_step(2);
const int32_t num_indexes = 2;
NDIndex indexes[num_indexes] = {
{.type = ND_INDEX_TYPE_SLICE, .data = (uint8_t *)&subscript_1},
{.type = ND_INDEX_TYPE_SLICE, .data = (uint8_t *)&subscript_2}};
ErrorContext errctx = create_testing_errctx();
ndarray::indexing::index(&errctx, num_indexes, indexes, &src_ndarray,
&dst_ndarray);
assert_errctx_no_error(&errctx);
int32_t expected_shape[dst_ndims] = {2, 2};
int32_t expected_strides[dst_ndims] = {32, 16};
assert_arrays_match(dst_ndims, expected_shape, dst_ndarray.shape);
assert_arrays_match(dst_ndims, expected_strides, dst_ndarray.strides);
// dst_ndarray[0, 0]
assert_values_match(5.0,
*((double *)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, (int32_t[dst_ndims]){0, 0})));
// dst_ndarray[0, 1]
assert_values_match(7.0,
*((double *)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, (int32_t[dst_ndims]){0, 1})));
// dst_ndarray[1, 0]
assert_values_match(9.0,
*((double *)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, (int32_t[dst_ndims]){1, 0})));
// dst_ndarray[1, 1]
assert_values_match(11.0,
*((double *)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, (int32_t[dst_ndims]){1, 1})));
}
void test_normal_2() {
/*
```python
ndarray = np.arange(12, dtype=np.float64).reshape((3, 4))
# array([[ 0., 1., 2., 3.],
# [ 4., 5., 6., 7.],
# [ 8., 9., 10., 11.]])
dst_ndarray = ndarray[2, ::-2]
# array([11., 9.])
assert dst_ndarray.shape == (2,)
assert dst_ndarray.strides == (-16,)
assert dst_ndarray[0] == 11.0
assert dst_ndarray[1] == 9.0
```
*/
BEGIN_TEST();
// Prepare src_ndarray
double src_data[12] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0,
6.0, 7.0, 8.0, 9.0, 10.0, 11.0};
int32_t src_itemsize = sizeof(double);
const int32_t src_ndims = 2;
int32_t src_shape[src_ndims] = {3, 4};
int32_t src_strides[src_ndims] = {};
NDArray<int32_t> src_ndarray = {.data = (uint8_t *)src_data,
.itemsize = src_itemsize,
.ndims = src_ndims,
.shape = src_shape,
.strides = src_strides};
ndarray::basic::set_strides_by_shape(&src_ndarray);
// Prepare dst_ndarray
const int32_t dst_ndims = 1;
int32_t dst_shape[dst_ndims] = {999}; // Empty values
int32_t dst_strides[dst_ndims] = {999}; // Empty values
NDArray<int32_t> dst_ndarray = {.data = nullptr,
.ndims = dst_ndims,
.shape = dst_shape,
.strides = dst_strides};
// Create the subscripts in `ndarray[2, ::-2]`
int32_t subscript_1 = 2;
UserSlice subscript_2;
subscript_2.set_step(-2);
const int32_t num_indexes = 2;
NDIndex indexes[num_indexes] = {
{.type = ND_INDEX_TYPE_SINGLE_ELEMENT, .data = (uint8_t *)&subscript_1},
{.type = ND_INDEX_TYPE_SLICE, .data = (uint8_t *)&subscript_2}};
ErrorContext errctx = create_testing_errctx();
ndarray::indexing::index(&errctx, num_indexes, indexes, &src_ndarray,
&dst_ndarray);
assert_errctx_no_error(&errctx);
int32_t expected_shape[dst_ndims] = {2};
int32_t expected_strides[dst_ndims] = {-16};
assert_arrays_match(dst_ndims, expected_shape, dst_ndarray.shape);
assert_arrays_match(dst_ndims, expected_strides, dst_ndarray.strides);
assert_values_match(11.0,
*((double *)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, (int32_t[dst_ndims]){0})));
assert_values_match(9.0,
*((double *)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, (int32_t[dst_ndims]){1})));
}
void test_index_subscript_out_of_bounds() {
/*
# Consider `my_array`
print(my_array.shape)
# (4, 5, 6)
my_array[2, 100] # error, index subscript at axis 1 is out of bounds
*/
BEGIN_TEST();
// Prepare src_ndarray
const int32_t src_ndims = 2;
int32_t src_shape[src_ndims] = {3, 4};
int32_t src_strides[src_ndims] = {};
NDArray<int32_t> src_ndarray = {
.data = (uint8_t *)nullptr, // placeholder, we wouldn't access it
.itemsize = sizeof(double), // placeholder
.ndims = src_ndims,
.shape = src_shape,
.strides = src_strides};
ndarray::basic::set_strides_by_shape(&src_ndarray);
// Create the subscripts in `my_array[2, 100]`
int32_t subscript_1 = 2;
int32_t subscript_2 = 100;
const int32_t num_indexes = 2;
NDIndex indexes[num_indexes] = {
{.type = ND_INDEX_TYPE_SINGLE_ELEMENT, .data = (uint8_t *)&subscript_1},
{.type = ND_INDEX_TYPE_SINGLE_ELEMENT,
.data = (uint8_t *)&subscript_2}};
// Prepare dst_ndarray
const int32_t dst_ndims = 0;
int32_t dst_shape[dst_ndims] = {};
int32_t dst_strides[dst_ndims] = {};
NDArray<int32_t> dst_ndarray = {.data = nullptr, // placehloder
.ndims = dst_ndims,
.shape = dst_shape,
.strides = dst_strides};
ErrorContext errctx = create_testing_errctx();
ndarray::indexing::index(&errctx, num_indexes, indexes, &src_ndarray,
&dst_ndarray);
assert_errctx_has_error(&errctx, errctx.error_ids->index_error);
}
void run() {
    // Runs every test case of this suite, in the order listed.
    void (*const cases[])() = {
        test_normal_1,
        test_normal_2,
        test_index_subscript_out_of_bounds,
    };
    for (void (*test_case)() : cases) {
        test_case();
    }
}
} // namespace ndarray_indexing
} // namespace test