nac3/nac3core/irrt/test/test_ndarray_subscript.hpp
lyken 628965e519 core: irrt reformat & more progress
progress details:
- organize test suite sources with namespaces
- organize ndarray implementations with namespaces
- remove extraneous code/comment
- add more tests
- some renaming
- fix pre-existing bugs
- ndarray::subscript now throws errors
2024-07-15 11:50:45 +08:00

182 lines
6.5 KiB
C++

#pragma once
#include <test/core.hpp>
#include <irrt_everything.hpp>
namespace test { namespace ndarray_subscript {
/**
 * Test `ndarray::subscript::subscript` with two slice subscripts.
 *
 * Subscripts a 3x4 `double` ndarray with the equivalent of `ndarray[-2:, 1::2]` and checks that:
 * - no error is recorded in the `ErrorContext`,
 * - the destination shape is `{ 2, 2 }` and its strides are `{ 32, 16 }` (in bytes, matching NumPy),
 * - the four destination elements read back as 5, 7, 9 and 11.
 */
void test_ndsubscript_normal_1() {
    /*
    Reference Python code:
    ```python
    ndarray = np.arange(12, dtype=np.float64).reshape((3, 4))
    # array([[ 0.,  1.,  2.,  3.],
    #        [ 4.,  5.,  6.,  7.],
    #        [ 8.,  9., 10., 11.]])

    dst_ndarray = ndarray[-2:, 1::2]
    # array([[ 5.,  7.],
    #        [ 9., 11.]])

    assert dst_ndarray.shape == (2, 2)
    assert dst_ndarray.strides == (32, 16)
    assert dst_ndarray[0, 0] == 5.0
    assert dst_ndarray[0, 1] == 7.0
    assert dst_ndarray[1, 0] == 9.0
    assert dst_ndarray[1, 1] == 11.0
    ```
    */
    BEGIN_TEST();

    // Source ndarray: twelve doubles viewed as a 3x4 array.
    double src_data[12] = { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0 };
    int32_t src_itemsize = sizeof(double);
    const int32_t src_ndims = 2;
    int32_t src_shape[src_ndims] = { 3, 4 };
    int32_t src_strides[src_ndims] = {};
    NDArray<int32_t> src_ndarray = {
        .data = (uint8_t*) src_data,
        .itemsize = src_itemsize,
        .ndims = src_ndims,
        .shape = src_shape,
        .strides = src_strides
    };
    // `src_strides` starts zeroed; this call is expected to derive the strides from the shape.
    ndarray::basic::set_strides_by_shape(&src_ndarray);

    // Destination ndarray
    // As documented, ndims and shape & strides must be allocated and determined by the caller.
    const int32_t dst_ndims = 2;
    // 999 is a placeholder: the shape/strides assertions below only pass if `subscript` overwrote it.
    int32_t dst_shape[dst_ndims] = { 999, 999 };
    int32_t dst_strides[dst_ndims] = { 999, 999 };
    NDArray<int32_t> dst_ndarray = {
        .data = nullptr,
        .ndims = dst_ndims,
        .shape = dst_shape,
        .strides = dst_strides
    };

    // Create the slices in `ndarray[-2:, 1::2]`
    UserSlice subscript_1; // `-2:`
    subscript_1.set_start(-2);

    UserSlice subscript_2; // `1::2`
    subscript_2.set_start(1);
    subscript_2.set_step(2);

    const int32_t num_ndsubscripts = 2;
    NDSubscript ndsubscripts[num_ndsubscripts] = {
        { .type = INPUT_SUBSCRIPT_TYPE_SLICE, .data = (uint8_t*) &subscript_1 },
        { .type = INPUT_SUBSCRIPT_TYPE_SLICE, .data = (uint8_t*) &subscript_2 }
    };

    ErrorContext errctx = create_testing_errctx();
    ndarray::subscript::subscript(&errctx, num_ndsubscripts, ndsubscripts, &src_ndarray, &dst_ndarray);
    assert_errctx_no_error(&errctx);

    int32_t expected_shape[dst_ndims] = { 2, 2 };
    int32_t expected_strides[dst_ndims] = { 32, 16 };
    assert_arrays_match(dst_ndims, expected_shape, dst_ndarray.shape);
    assert_arrays_match(dst_ndims, expected_strides, dst_ndarray.strides);

    // The index buffers are named locals instead of `(int32_t[dst_ndims]) { ... }` compound
    // literals: compound literals are a C99 feature that C++ compilers only offer as an extension.
    int32_t indices_0_0[dst_ndims] = { 0, 0 };
    int32_t indices_0_1[dst_ndims] = { 0, 1 };
    int32_t indices_1_0[dst_ndims] = { 1, 0 };
    int32_t indices_1_1[dst_ndims] = { 1, 1 };

    // dst_ndarray[0, 0]
    assert_values_match(
        5.0,
        *((double *) ndarray::basic::get_pelement_by_indices(&dst_ndarray, indices_0_0))
    );
    // dst_ndarray[0, 1]
    assert_values_match(
        7.0,
        *((double *) ndarray::basic::get_pelement_by_indices(&dst_ndarray, indices_0_1))
    );
    // dst_ndarray[1, 0]
    assert_values_match(
        9.0,
        *((double *) ndarray::basic::get_pelement_by_indices(&dst_ndarray, indices_1_0))
    );
    // dst_ndarray[1, 1]
    assert_values_match(
        11.0,
        *((double *) ndarray::basic::get_pelement_by_indices(&dst_ndarray, indices_1_1))
    );
}
/**
 * Test `ndarray::subscript::subscript` with an index subscript followed by a negative-step slice.
 *
 * Subscripts a 3x4 `double` ndarray with the equivalent of `ndarray[2, ::-2]` and checks that:
 * - no error is recorded in the `ErrorContext`,
 * - the destination is 1-dimensional with shape `{ 2 }` and strides `{ -16 }` (in bytes, matching NumPy),
 * - the two destination elements read back as 11 and 9.
 */
void test_ndsubscript_normal_2() {
    /*
    Reference Python code:
    ```python
    ndarray = np.arange(12, dtype=np.float64).reshape((3, 4))
    # array([[ 0.,  1.,  2.,  3.],
    #        [ 4.,  5.,  6.,  7.],
    #        [ 8.,  9., 10., 11.]])

    dst_ndarray = ndarray[2, ::-2]
    # array([11.,  9.])

    assert dst_ndarray.shape == (2,)
    assert dst_ndarray.strides == (-16,)
    assert dst_ndarray[0] == 11.0
    assert dst_ndarray[1] == 9.0
    ```
    */
    BEGIN_TEST();

    // Source ndarray: twelve doubles viewed as a 3x4 array.
    double src_data[12] = { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0 };
    int32_t src_itemsize = sizeof(double);
    const int32_t src_ndims = 2;
    int32_t src_shape[src_ndims] = { 3, 4 };
    int32_t src_strides[src_ndims] = {};
    NDArray<int32_t> src_ndarray = {
        .data = (uint8_t*) src_data,
        .itemsize = src_itemsize,
        .ndims = src_ndims,
        .shape = src_shape,
        .strides = src_strides
    };
    // `src_strides` starts zeroed; this call is expected to derive the strides from the shape.
    ndarray::basic::set_strides_by_shape(&src_ndarray);

    // Destination ndarray
    // As documented, ndims and shape & strides must be allocated and determined by the caller.
    // The index subscript removes one axis, hence a single destination dimension.
    const int32_t dst_ndims = 1;
    // 999 is a placeholder: the shape/strides assertions below only pass if `subscript` overwrote it.
    int32_t dst_shape[dst_ndims] = { 999 };
    int32_t dst_strides[dst_ndims] = { 999 };
    NDArray<int32_t> dst_ndarray = {
        .data = nullptr,
        .ndims = dst_ndims,
        .shape = dst_shape,
        .strides = dst_strides
    };

    // Create the subscripts in `ndarray[2, ::-2]`
    int32_t subscript_1 = 2; // `2`

    UserSlice subscript_2; // `::-2` -- only the step is set, start/stop keep UserSlice's defaults
    subscript_2.set_step(-2);

    const int32_t num_ndsubscripts = 2;
    NDSubscript ndsubscripts[num_ndsubscripts] = {
        { .type = INPUT_SUBSCRIPT_TYPE_INDEX, .data = (uint8_t*) &subscript_1 },
        { .type = INPUT_SUBSCRIPT_TYPE_SLICE, .data = (uint8_t*) &subscript_2 }
    };

    ErrorContext errctx = create_testing_errctx();
    ndarray::subscript::subscript(&errctx, num_ndsubscripts, ndsubscripts, &src_ndarray, &dst_ndarray);
    assert_errctx_no_error(&errctx);

    int32_t expected_shape[dst_ndims] = { 2 };
    int32_t expected_strides[dst_ndims] = { -16 };
    assert_arrays_match(dst_ndims, expected_shape, dst_ndarray.shape);
    assert_arrays_match(dst_ndims, expected_strides, dst_ndarray.strides);

    // The index buffers are named locals instead of `(int32_t[dst_ndims]) { ... }` compound
    // literals: compound literals are a C99 feature that C++ compilers only offer as an extension.
    int32_t indices_0[dst_ndims] = { 0 };
    int32_t indices_1[dst_ndims] = { 1 };

    // dst_ndarray[0]
    assert_values_match(
        11.0,
        *((double *) ndarray::basic::get_pelement_by_indices(&dst_ndarray, indices_0))
    );
    // dst_ndarray[1]
    assert_values_match(
        9.0,
        *((double *) ndarray::basic::get_pelement_by_indices(&dst_ndarray, indices_1))
    );
}
/// Run every ndarray subscript test case of this namespace, in declaration order.
void run() {
    // Table of test cases; add new ones here.
    void (*const test_cases[])() = {
        test_ndsubscript_normal_1,
        test_ndsubscript_normal_2,
    };

    for (const auto test_case : test_cases) {
        test_case();
    }
}
} }