#pragma once #include #include namespace test { namespace ndarray_broadcast { void test_ndarray_broadcast_1() { /* ```python array = np.array([[19.9, 29.9, 39.9, 49.9]], dtype=np.float64) >>> [[19.9 29.9 39.9 49.9]] array = np.broadcast_to(array, (2, 3, 4)) >>> [[[19.9 29.9 39.9 49.9] >>> [19.9 29.9 39.9 49.9] >>> [19.9 29.9 39.9 49.9]] >>> [[19.9 29.9 39.9 49.9] >>> [19.9 29.9 39.9 49.9] >>> [19.9 29.9 39.9 49.9]]] assert array.strides == (0, 0, 8) # and then pick some values in `array` and check them... ``` */ BEGIN_TEST(); // Prepare src_ndarray double src_data[4] = { 19.9, 29.9, 39.9, 49.9 }; const int32_t src_ndims = 2; int32_t src_shape[src_ndims] = {1, 4}; int32_t src_strides[src_ndims] = {}; NDArray src_ndarray = { .data = (uint8_t*) src_data, .itemsize = sizeof(double), .ndims = src_ndims, .shape = src_shape, .strides = src_strides }; ndarray::basic::set_strides_by_shape(&src_ndarray); // Prepare dst_ndarray const int32_t dst_ndims = 3; int32_t dst_shape[dst_ndims] = {2, 3, 4}; int32_t dst_strides[dst_ndims] = {}; NDArray dst_ndarray = { .ndims = dst_ndims, .shape = dst_shape, .strides = dst_strides }; // Broadcast ErrorContext errctx = create_testing_errctx(); ndarray::broadcast::broadcast_to(&errctx, &src_ndarray, &dst_ndarray); assert_errctx_no_error(&errctx); assert_arrays_match(dst_ndims, ((int32_t[]) { 0, 0, 8 }), dst_ndarray.strides); assert_values_match(19.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 0, 0})))); assert_values_match(29.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 0, 1})))); assert_values_match(39.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 0, 2})))); assert_values_match(49.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 0, 3})))); assert_values_match(19.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 1, 0})))); assert_values_match(29.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 1, 1})))); assert_values_match(39.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 1, 2})))); assert_values_match(49.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 1, 3})))); assert_values_match(49.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {1, 2, 3})))); } void run() { test_ndarray_broadcast_1(); } }}