#pragma once #include namespace test { namespace ndarray_indexing { void test_normal_1() { /* Reference Python code: ```python ndarray = np.arange(12, dtype=np.float64).reshape((3, 4)); # array([[ 0., 1., 2., 3.], # [ 4., 5., 6., 7.], # [ 8., 9., 10., 11.]]) dst_ndarray = ndarray[-2:, 1::2] # array([[ 5., 7.], # [ 9., 11.]]) assert dst_ndarray.shape == (2, 2) assert dst_ndarray.strides == (32, 16) assert dst_ndarray[0, 0] == 5.0 assert dst_ndarray[0, 1] == 7.0 assert dst_ndarray[1, 0] == 9.0 assert dst_ndarray[1, 1] == 11.0 ``` */ BEGIN_TEST(); // Prepare src_ndarray double src_data[12] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0}; int64_t src_itemsize = sizeof(double); const int64_t src_ndims = 2; int64_t src_shape[src_ndims] = {3, 4}; int64_t src_strides[src_ndims] = {}; NDArray src_ndarray = {.data = (uint8_t *)src_data, .itemsize = src_itemsize, .ndims = src_ndims, .shape = src_shape, .strides = src_strides}; ndarray::basic::set_strides_by_shape(&src_ndarray); // Prepare dst_ndarray const int64_t dst_ndims = 2; int64_t dst_shape[dst_ndims] = {999, 999}; // Empty values int64_t dst_strides[dst_ndims] = {999, 999}; // Empty values NDArray dst_ndarray = {.data = nullptr, .ndims = dst_ndims, .shape = dst_shape, .strides = dst_strides}; // Create the subscripts in `ndarray[-2::, 1::2]` UserSlice subscript_1; subscript_1.set_start(-2); UserSlice subscript_2; subscript_2.set_start(1); subscript_2.set_step(2); const int64_t num_indexes = 2; NDIndex indexes[num_indexes] = { {.type = ND_INDEX_TYPE_SLICE, .data = (uint8_t *)&subscript_1}, {.type = ND_INDEX_TYPE_SLICE, .data = (uint8_t *)&subscript_2}}; ndarray::indexing::index(num_indexes, indexes, &src_ndarray, &dst_ndarray); int64_t expected_shape[dst_ndims] = {2, 2}; int64_t expected_strides[dst_ndims] = {32, 16}; assert_arrays_match(dst_ndims, expected_shape, dst_ndarray.shape); assert_arrays_match(dst_ndims, expected_strides, dst_ndarray.strides); // dst_ndarray[0, 0] assert_values_match(5.0, *((double *)ndarray::basic::get_pelement_by_indices( &dst_ndarray, (int64_t[dst_ndims]){0, 0}))); // dst_ndarray[0, 1] assert_values_match(7.0, *((double *)ndarray::basic::get_pelement_by_indices( &dst_ndarray, (int64_t[dst_ndims]){0, 1}))); // dst_ndarray[1, 0] assert_values_match(9.0, *((double *)ndarray::basic::get_pelement_by_indices( &dst_ndarray, (int64_t[dst_ndims]){1, 0}))); // dst_ndarray[1, 1] assert_values_match(11.0, *((double *)ndarray::basic::get_pelement_by_indices( &dst_ndarray, (int64_t[dst_ndims]){1, 1}))); } void test_normal_2() { /* ```python ndarray = np.arange(12, dtype=np.float64).reshape((3, 4)) # array([[ 0., 1., 2., 3.], # [ 4., 5., 6., 7.], # [ 8., 9., 10., 11.]]) dst_ndarray = ndarray[2, ::-2] # array([11., 9.]) assert dst_ndarray.shape == (2,) assert dst_ndarray.strides == (-16,) assert dst_ndarray[0] == 11.0 assert dst_ndarray[1] == 9.0 ``` */ BEGIN_TEST(); // Prepare src_ndarray double src_data[12] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0}; int64_t src_itemsize = sizeof(double); const int64_t src_ndims = 2; int64_t src_shape[src_ndims] = {3, 4}; int64_t src_strides[src_ndims] = {}; NDArray src_ndarray = {.data = (uint8_t *)src_data, .itemsize = src_itemsize, .ndims = src_ndims, .shape = src_shape, .strides = src_strides}; ndarray::basic::set_strides_by_shape(&src_ndarray); // Prepare dst_ndarray const int64_t dst_ndims = 1; int64_t dst_shape[dst_ndims] = {999}; // Empty values int64_t dst_strides[dst_ndims] = {999}; // Empty values NDArray dst_ndarray = {.data = nullptr, .ndims = dst_ndims, .shape = dst_shape, .strides = dst_strides}; // Create the subscripts in `ndarray[2, ::-2]` int64_t subscript_1 = 2; UserSlice subscript_2; subscript_2.set_step(-2); const int64_t num_indexes = 2; NDIndex indexes[num_indexes] = { {.type = ND_INDEX_TYPE_SINGLE_ELEMENT, .data = (uint8_t *)&subscript_1}, {.type = ND_INDEX_TYPE_SLICE, .data = (uint8_t *)&subscript_2}}; ndarray::indexing::index(num_indexes, indexes, &src_ndarray, &dst_ndarray); int64_t expected_shape[dst_ndims] = {2}; int64_t expected_strides[dst_ndims] = {-16}; assert_arrays_match(dst_ndims, expected_shape, dst_ndarray.shape); assert_arrays_match(dst_ndims, expected_strides, dst_ndarray.strides); assert_values_match(11.0, *((double *)ndarray::basic::get_pelement_by_indices( &dst_ndarray, (int64_t[dst_ndims]){0}))); assert_values_match(9.0, *((double *)ndarray::basic::get_pelement_by_indices( &dst_ndarray, (int64_t[dst_ndims]){1}))); } void run() { test_normal_1(); test_normal_2(); } } // namespace ndarray_indexing } // namespace test