#pragma once #include #include namespace test { namespace ndarray_subscript { void test_ndsubscript_normal_1() { /* Reference Python code: ```python ndarray = np.arange(12, dtype=np.float64).reshape((3, 4)); # array([[ 0., 1., 2., 3.], # [ 4., 5., 6., 7.], # [ 8., 9., 10., 11.]]) dst_ndarray = ndarray[-2:, 1::2] # array([[ 5., 7.], # [ 9., 11.]]) assert dst_ndarray.shape == (2, 2) assert dst_ndarray.strides == (32, 16) assert dst_ndarray[0, 0] == 5.0 assert dst_ndarray[0, 1] == 7.0 assert dst_ndarray[1, 0] == 9.0 assert dst_ndarray[1, 1] == 11.0 ``` */ BEGIN_TEST(); // Prepare src_ndarray double src_data[12] = { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0 }; int32_t src_itemsize = sizeof(double); const int32_t src_ndims = 2; int32_t src_shape[src_ndims] = { 3, 4 }; int32_t src_strides[src_ndims] = {}; NDArray src_ndarray = { .data = (uint8_t*) src_data, .itemsize = src_itemsize, .ndims = src_ndims, .shape = src_shape, .strides = src_strides }; ndarray::basic::set_strides_by_shape(&src_ndarray); // Prepare dst_ndarray const int32_t dst_ndims = 2; int32_t dst_shape[dst_ndims] = {999, 999}; // Empty values int32_t dst_strides[dst_ndims] = {999, 999}; // Empty values NDArray dst_ndarray = { .data = nullptr, .ndims = dst_ndims, .shape = dst_shape, .strides = dst_strides }; // Create the subscripts in `ndarray[-2::, 1::2]` UserSlice subscript_1; subscript_1.set_start(-2); UserSlice subscript_2; subscript_2.set_start(1); subscript_2.set_step(2); const int32_t num_ndsubscripts = 2; NDSubscript ndsubscripts[num_ndsubscripts] = { { .type = INPUT_SUBSCRIPT_TYPE_SLICE, .data = (uint8_t*) &subscript_1 }, { .type = INPUT_SUBSCRIPT_TYPE_SLICE, .data = (uint8_t*) &subscript_2 } }; ErrorContext errctx = create_testing_errctx(); ndarray::subscript::subscript(&errctx, num_ndsubscripts, ndsubscripts, &src_ndarray, &dst_ndarray); assert_errctx_no_error(&errctx); int32_t expected_shape[dst_ndims] = { 2, 2 }; int32_t expected_strides[dst_ndims] = { 32, 16 }; assert_arrays_match(dst_ndims, expected_shape, dst_ndarray.shape); assert_arrays_match(dst_ndims, expected_strides, dst_ndarray.strides); // dst_ndarray[0, 0] assert_values_match( 5.0, *((double *) ndarray::basic::get_pelement_by_indices(&dst_ndarray, (int32_t[dst_ndims]) { 0, 0 })) ); // dst_ndarray[0, 1] assert_values_match( 7.0, *((double *) ndarray::basic::get_pelement_by_indices(&dst_ndarray, (int32_t[dst_ndims]) { 0, 1 })) ); // dst_ndarray[1, 0] assert_values_match( 9.0, *((double *) ndarray::basic::get_pelement_by_indices(&dst_ndarray, (int32_t[dst_ndims]) { 1, 0 })) ); // dst_ndarray[1, 1] assert_values_match( 11.0, *((double *) ndarray::basic::get_pelement_by_indices(&dst_ndarray, (int32_t[dst_ndims]) { 1, 1 })) ); } void test_ndsubscript_normal_2() { /* ```python ndarray = np.arange(12, dtype=np.float64).reshape((3, 4)) # array([[ 0., 1., 2., 3.], # [ 4., 5., 6., 7.], # [ 8., 9., 10., 11.]]) dst_ndarray = ndarray[2, ::-2] # array([11., 9.]) assert dst_ndarray.shape == (2,) assert dst_ndarray.strides == (-16,) assert dst_ndarray[0] == 11.0 assert dst_ndarray[1] == 9.0 ``` */ BEGIN_TEST(); // Prepare src_ndarray double src_data[12] = { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0 }; int32_t src_itemsize = sizeof(double); const int32_t src_ndims = 2; int32_t src_shape[src_ndims] = { 3, 4 }; int32_t src_strides[src_ndims] = {}; NDArray src_ndarray = { .data = (uint8_t*) src_data, .itemsize = src_itemsize, .ndims = src_ndims, .shape = src_shape, .strides = src_strides }; ndarray::basic::set_strides_by_shape(&src_ndarray); // Prepare dst_ndarray const int32_t dst_ndims = 1; int32_t dst_shape[dst_ndims] = {999}; // Empty values int32_t dst_strides[dst_ndims] = {999}; // Empty values NDArray dst_ndarray = { .data = nullptr, .ndims = dst_ndims, .shape = dst_shape, .strides = dst_strides }; // Create the subscripts in `ndarray[2, ::-2]` int32_t subscript_1 = 2; UserSlice subscript_2; subscript_2.set_step(-2); const int32_t num_ndsubscripts = 2; NDSubscript ndsubscripts[num_ndsubscripts] = { { .type = INPUT_SUBSCRIPT_TYPE_INDEX, .data = (uint8_t*) &subscript_1 }, { .type = INPUT_SUBSCRIPT_TYPE_SLICE, .data = (uint8_t*) &subscript_2 } }; ErrorContext errctx = create_testing_errctx(); ndarray::subscript::subscript(&errctx, num_ndsubscripts, ndsubscripts, &src_ndarray, &dst_ndarray); assert_errctx_no_error(&errctx); int32_t expected_shape[dst_ndims] = { 2 }; int32_t expected_strides[dst_ndims] = { -16 }; assert_arrays_match(dst_ndims, expected_shape, dst_ndarray.shape); assert_arrays_match(dst_ndims, expected_strides, dst_ndarray.strides); assert_values_match( 11.0, *((double *) ndarray::basic::get_pelement_by_indices(&dst_ndarray, (int32_t[dst_ndims]) { 0 })) ); assert_values_match( 9.0, *((double *) ndarray::basic::get_pelement_by_indices(&dst_ndarray, (int32_t[dst_ndims]) { 1 })) ); } void test_ndsubscript_index_subscript_out_of_bounds() { /* # Consider `my_array` print(my_array.shape) # (4, 5, 6) my_array[2, 100] # error, index subscript at axis 1 is out of bounds */ BEGIN_TEST(); // Prepare src_ndarray const int32_t src_ndims = 2; int32_t src_shape[src_ndims] = { 3, 4 }; int32_t src_strides[src_ndims] = {}; NDArray src_ndarray = { .data = (uint8_t*) nullptr, // placeholder, we wouldn't access it .itemsize = sizeof(double), // placeholder .ndims = src_ndims, .shape = src_shape, .strides = src_strides }; ndarray::basic::set_strides_by_shape(&src_ndarray); // Create the subscripts in `my_array[2, 100]` int32_t subscript_1 = 2; int32_t subscript_2 = 100; const int32_t num_ndsubscripts = 2; NDSubscript ndsubscripts[num_ndsubscripts] = { { .type = INPUT_SUBSCRIPT_TYPE_INDEX, .data = (uint8_t*) &subscript_1 }, { .type = INPUT_SUBSCRIPT_TYPE_INDEX, .data = (uint8_t*) &subscript_2 } }; // Prepare dst_ndarray const int32_t dst_ndims = 0; int32_t dst_shape[dst_ndims] = {}; int32_t dst_strides[dst_ndims] = {}; NDArray dst_ndarray = { .data = nullptr, // placehloder .ndims = dst_ndims, .shape = dst_shape, .strides = dst_strides }; ErrorContext errctx = create_testing_errctx(); ndarray::subscript::subscript(&errctx, num_ndsubscripts, ndsubscripts, &src_ndarray, &dst_ndarray); assert_errctx_has_error(&errctx, errctx.error_ids->index_error); } void run() { test_ndsubscript_normal_1(); test_ndsubscript_normal_2(); test_ndsubscript_index_subscript_out_of_bounds(); } } }