forked from M-Labs/nac3
233 lines
8.2 KiB
C++
233 lines
8.2 KiB
C++
#pragma once
|
|
|
|
#include <test/core.hpp>
|
|
#include <irrt_everything.hpp>
|
|
|
|
namespace test { namespace ndarray_subscript {
|
|
void test_ndsubscript_normal_1() {
|
|
/*
|
|
Reference Python code:
|
|
```python
|
|
ndarray = np.arange(12, dtype=np.float64).reshape((3, 4));
|
|
# array([[ 0., 1., 2., 3.],
|
|
# [ 4., 5., 6., 7.],
|
|
# [ 8., 9., 10., 11.]])
|
|
|
|
dst_ndarray = ndarray[-2:, 1::2]
|
|
# array([[ 5., 7.],
|
|
# [ 9., 11.]])
|
|
|
|
assert dst_ndarray.shape == (2, 2)
|
|
assert dst_ndarray.strides == (32, 16)
|
|
assert dst_ndarray[0, 0] == 5.0
|
|
assert dst_ndarray[0, 1] == 7.0
|
|
assert dst_ndarray[1, 0] == 9.0
|
|
assert dst_ndarray[1, 1] == 11.0
|
|
```
|
|
*/
|
|
BEGIN_TEST();
|
|
|
|
// Prepare src_ndarray
|
|
double src_data[12] = { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0 };
|
|
int32_t src_itemsize = sizeof(double);
|
|
const int32_t src_ndims = 2;
|
|
int32_t src_shape[src_ndims] = { 3, 4 };
|
|
int32_t src_strides[src_ndims] = {};
|
|
NDArray<int32_t> src_ndarray = {
|
|
.data = (uint8_t*) src_data,
|
|
.itemsize = src_itemsize,
|
|
.ndims = src_ndims,
|
|
.shape = src_shape,
|
|
.strides = src_strides
|
|
};
|
|
ndarray::basic::set_strides_by_shape(&src_ndarray);
|
|
|
|
// Prepare dst_ndarray
|
|
const int32_t dst_ndims = 2;
|
|
int32_t dst_shape[dst_ndims] = {999, 999}; // Empty values
|
|
int32_t dst_strides[dst_ndims] = {999, 999}; // Empty values
|
|
NDArray<int32_t> dst_ndarray = {
|
|
.data = nullptr,
|
|
.ndims = dst_ndims,
|
|
.shape = dst_shape,
|
|
.strides = dst_strides
|
|
};
|
|
|
|
// Create the subscripts in `ndarray[-2::, 1::2]`
|
|
UserSlice subscript_1;
|
|
subscript_1.set_start(-2);
|
|
|
|
UserSlice subscript_2;
|
|
subscript_2.set_start(1);
|
|
subscript_2.set_step(2);
|
|
|
|
const int32_t num_ndsubscripts = 2;
|
|
NDSubscript ndsubscripts[num_ndsubscripts] = {
|
|
{ .type = INPUT_SUBSCRIPT_TYPE_SLICE, .data = (uint8_t*) &subscript_1 },
|
|
{ .type = INPUT_SUBSCRIPT_TYPE_SLICE, .data = (uint8_t*) &subscript_2 }
|
|
};
|
|
|
|
ErrorContext errctx = create_testing_errctx();
|
|
ndarray::subscript::subscript(&errctx, num_ndsubscripts, ndsubscripts, &src_ndarray, &dst_ndarray);
|
|
assert_errctx_no_error(&errctx);
|
|
|
|
int32_t expected_shape[dst_ndims] = { 2, 2 };
|
|
int32_t expected_strides[dst_ndims] = { 32, 16 };
|
|
|
|
assert_arrays_match(dst_ndims, expected_shape, dst_ndarray.shape);
|
|
assert_arrays_match(dst_ndims, expected_strides, dst_ndarray.strides);
|
|
|
|
// dst_ndarray[0, 0]
|
|
assert_values_match(
|
|
5.0,
|
|
*((double *) ndarray::basic::get_pelement_by_indices(&dst_ndarray, (int32_t[dst_ndims]) { 0, 0 }))
|
|
);
|
|
// dst_ndarray[0, 1]
|
|
assert_values_match(
|
|
7.0,
|
|
*((double *) ndarray::basic::get_pelement_by_indices(&dst_ndarray, (int32_t[dst_ndims]) { 0, 1 }))
|
|
);
|
|
// dst_ndarray[1, 0]
|
|
assert_values_match(
|
|
9.0,
|
|
*((double *) ndarray::basic::get_pelement_by_indices(&dst_ndarray, (int32_t[dst_ndims]) { 1, 0 }))
|
|
);
|
|
// dst_ndarray[1, 1]
|
|
assert_values_match(
|
|
11.0,
|
|
*((double *) ndarray::basic::get_pelement_by_indices(&dst_ndarray, (int32_t[dst_ndims]) { 1, 1 }))
|
|
);
|
|
}
|
|
|
|
void test_ndsubscript_normal_2() {
|
|
/*
|
|
```python
|
|
ndarray = np.arange(12, dtype=np.float64).reshape((3, 4))
|
|
# array([[ 0., 1., 2., 3.],
|
|
# [ 4., 5., 6., 7.],
|
|
# [ 8., 9., 10., 11.]])
|
|
|
|
dst_ndarray = ndarray[2, ::-2]
|
|
# array([11., 9.])
|
|
|
|
assert dst_ndarray.shape == (2,)
|
|
assert dst_ndarray.strides == (-16,)
|
|
assert dst_ndarray[0] == 11.0
|
|
assert dst_ndarray[1] == 9.0
|
|
```
|
|
*/
|
|
BEGIN_TEST();
|
|
|
|
// Prepare src_ndarray
|
|
double src_data[12] = { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0 };
|
|
int32_t src_itemsize = sizeof(double);
|
|
const int32_t src_ndims = 2;
|
|
int32_t src_shape[src_ndims] = { 3, 4 };
|
|
int32_t src_strides[src_ndims] = {};
|
|
NDArray<int32_t> src_ndarray = {
|
|
.data = (uint8_t*) src_data,
|
|
.itemsize = src_itemsize,
|
|
.ndims = src_ndims,
|
|
.shape = src_shape,
|
|
.strides = src_strides
|
|
};
|
|
ndarray::basic::set_strides_by_shape(&src_ndarray);
|
|
|
|
// Prepare dst_ndarray
|
|
const int32_t dst_ndims = 1;
|
|
int32_t dst_shape[dst_ndims] = {999}; // Empty values
|
|
int32_t dst_strides[dst_ndims] = {999}; // Empty values
|
|
NDArray<int32_t> dst_ndarray = {
|
|
.data = nullptr,
|
|
.ndims = dst_ndims,
|
|
.shape = dst_shape,
|
|
.strides = dst_strides
|
|
};
|
|
|
|
// Create the subscripts in `ndarray[2, ::-2]`
|
|
int32_t subscript_1 = 2;
|
|
|
|
UserSlice subscript_2;
|
|
subscript_2.set_step(-2);
|
|
|
|
const int32_t num_ndsubscripts = 2;
|
|
NDSubscript ndsubscripts[num_ndsubscripts] = {
|
|
{ .type = INPUT_SUBSCRIPT_TYPE_INDEX, .data = (uint8_t*) &subscript_1 },
|
|
{ .type = INPUT_SUBSCRIPT_TYPE_SLICE, .data = (uint8_t*) &subscript_2 }
|
|
};
|
|
|
|
ErrorContext errctx = create_testing_errctx();
|
|
ndarray::subscript::subscript(&errctx, num_ndsubscripts, ndsubscripts, &src_ndarray, &dst_ndarray);
|
|
assert_errctx_no_error(&errctx);
|
|
|
|
int32_t expected_shape[dst_ndims] = { 2 };
|
|
int32_t expected_strides[dst_ndims] = { -16 };
|
|
assert_arrays_match(dst_ndims, expected_shape, dst_ndarray.shape);
|
|
assert_arrays_match(dst_ndims, expected_strides, dst_ndarray.strides);
|
|
|
|
assert_values_match(
|
|
11.0,
|
|
*((double *) ndarray::basic::get_pelement_by_indices(&dst_ndarray, (int32_t[dst_ndims]) { 0 }))
|
|
);
|
|
assert_values_match(
|
|
9.0,
|
|
*((double *) ndarray::basic::get_pelement_by_indices(&dst_ndarray, (int32_t[dst_ndims]) { 1 }))
|
|
);
|
|
}
|
|
|
|
void test_ndsubscript_index_subscript_out_of_bounds() {
|
|
/*
|
|
# Consider `my_array`
|
|
|
|
print(my_array.shape)
|
|
# (4, 5, 6)
|
|
|
|
my_array[2, 100] # error, index subscript at axis 1 is out of bounds
|
|
*/
|
|
BEGIN_TEST();
|
|
|
|
// Prepare src_ndarray
|
|
const int32_t src_ndims = 2;
|
|
int32_t src_shape[src_ndims] = { 3, 4 };
|
|
int32_t src_strides[src_ndims] = {};
|
|
NDArray<int32_t> src_ndarray = {
|
|
.data = (uint8_t*) nullptr, // placeholder, we wouldn't access it
|
|
.itemsize = sizeof(double), // placeholder
|
|
.ndims = src_ndims,
|
|
.shape = src_shape,
|
|
.strides = src_strides
|
|
};
|
|
ndarray::basic::set_strides_by_shape(&src_ndarray);
|
|
|
|
// Create the subscripts in `my_array[2, 100]`
|
|
int32_t subscript_1 = 2;
|
|
int32_t subscript_2 = 100;
|
|
|
|
const int32_t num_ndsubscripts = 2;
|
|
NDSubscript ndsubscripts[num_ndsubscripts] = {
|
|
{ .type = INPUT_SUBSCRIPT_TYPE_INDEX, .data = (uint8_t*) &subscript_1 },
|
|
{ .type = INPUT_SUBSCRIPT_TYPE_INDEX, .data = (uint8_t*) &subscript_2 }
|
|
};
|
|
|
|
// Prepare dst_ndarray
|
|
const int32_t dst_ndims = 0;
|
|
int32_t dst_shape[dst_ndims] = {};
|
|
int32_t dst_strides[dst_ndims] = {};
|
|
NDArray<int32_t> dst_ndarray = {
|
|
.data = nullptr, // placehloder
|
|
.ndims = dst_ndims,
|
|
.shape = dst_shape,
|
|
.strides = dst_strides
|
|
};
|
|
|
|
ErrorContext errctx = create_testing_errctx();
|
|
ndarray::subscript::subscript(&errctx, num_ndsubscripts, ndsubscripts, &src_ndarray, &dst_ndarray);
|
|
assert_errctx_has_error(&errctx, errctx.error_ids->index_error);
|
|
}
|
|
|
|
void run() {
|
|
test_ndsubscript_normal_1();
|
|
test_ndsubscript_normal_2();
|
|
test_ndsubscript_index_subscript_out_of_bounds();
|
|
}
|
|
} } |