#pragma once #include namespace test { namespace ndarray_broadcast { void test_can_broadcast_shape() { BEGIN_TEST(); assert_values_match(true, ndarray::broadcast::util::can_broadcast_shape_to( 1, (int32_t[]){3}, 5, (int32_t[]){1, 1, 1, 1, 3})); assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to( 1, (int32_t[]){3}, 2, (int32_t[]){3, 1})); assert_values_match(true, ndarray::broadcast::util::can_broadcast_shape_to( 1, (int32_t[]){3}, 1, (int32_t[]){3})); assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to( 1, (int32_t[]){1}, 1, (int32_t[]){3})); assert_values_match(true, ndarray::broadcast::util::can_broadcast_shape_to( 1, (int32_t[]){1}, 1, (int32_t[]){1})); assert_values_match( true, ndarray::broadcast::util::can_broadcast_shape_to( 3, (int32_t[]){256, 256, 3}, 3, (int32_t[]){256, 1, 3})); assert_values_match(true, ndarray::broadcast::util::can_broadcast_shape_to( 3, (int32_t[]){256, 256, 3}, 1, (int32_t[]){3})); assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to( 3, (int32_t[]){256, 256, 3}, 1, (int32_t[]){2})); assert_values_match(true, ndarray::broadcast::util::can_broadcast_shape_to( 3, (int32_t[]){256, 256, 3}, 1, (int32_t[]){1})); // In cases when the shapes contain zero(es) assert_values_match(true, ndarray::broadcast::util::can_broadcast_shape_to( 1, (int32_t[]){0}, 1, (int32_t[]){1})); assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to( 1, (int32_t[]){0}, 1, (int32_t[]){2})); assert_values_match(true, ndarray::broadcast::util::can_broadcast_shape_to( 4, (int32_t[]){0, 4, 0, 0}, 1, (int32_t[]){1})); assert_values_match( true, ndarray::broadcast::util::can_broadcast_shape_to( 4, (int32_t[]){0, 4, 0, 0}, 4, (int32_t[]){1, 1, 1, 1})); assert_values_match( true, ndarray::broadcast::util::can_broadcast_shape_to( 4, (int32_t[]){0, 4, 0, 0}, 4, (int32_t[]){1, 4, 1, 1})); assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to( 2, (int32_t[]){4, 3}, 2, (int32_t[]){0, 3})); assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to( 2, (int32_t[]){4, 3}, 2, (int32_t[]){0, 0})); } void test_ndarray_broadcast() { /* # array = np.array([[19.9, 29.9, 39.9, 49.9]], dtype=np.float64) # >>> [[19.9 29.9 39.9 49.9]] # # array = np.broadcast_to(array, (2, 3, 4)) # >>> [[[19.9 29.9 39.9 49.9] # >>> [19.9 29.9 39.9 49.9] # >>> [19.9 29.9 39.9 49.9]] # >>> [[19.9 29.9 39.9 49.9] # >>> [19.9 29.9 39.9 49.9] # >>> [19.9 29.9 39.9 49.9]]] # # assery array.strides == (0, 0, 8) */ BEGIN_TEST(); double in_data[4] = {19.9, 29.9, 39.9, 49.9}; const int32_t in_ndims = 2; int32_t in_shape[in_ndims] = {1, 4}; int32_t in_strides[in_ndims] = {}; NDArray ndarray = {.data = (uint8_t*)in_data, .itemsize = sizeof(double), .ndims = in_ndims, .shape = in_shape, .strides = in_strides}; ndarray::basic::set_strides_by_shape(&ndarray); const int32_t dst_ndims = 3; int32_t dst_shape[dst_ndims] = {2, 3, 4}; int32_t dst_strides[dst_ndims] = {}; NDArray dst_ndarray = { .ndims = dst_ndims, .shape = dst_shape, .strides = dst_strides}; ndarray::broadcast::broadcast_to(&ndarray, &dst_ndarray); assert_arrays_match(dst_ndims, ((int32_t[]){0, 0, 8}), dst_ndarray.strides); assert_values_match(19.9, *((double*)ndarray::basic::get_pelement_by_indices( &dst_ndarray, ((int32_t[]){0, 0, 0})))); assert_values_match(29.9, *((double*)ndarray::basic::get_pelement_by_indices( &dst_ndarray, ((int32_t[]){0, 0, 1})))); assert_values_match(39.9, *((double*)ndarray::basic::get_pelement_by_indices( &dst_ndarray, ((int32_t[]){0, 0, 2})))); assert_values_match(49.9, *((double*)ndarray::basic::get_pelement_by_indices( &dst_ndarray, ((int32_t[]){0, 0, 3})))); assert_values_match(19.9, *((double*)ndarray::basic::get_pelement_by_indices( &dst_ndarray, ((int32_t[]){0, 1, 0})))); assert_values_match(29.9, *((double*)ndarray::basic::get_pelement_by_indices( &dst_ndarray, ((int32_t[]){0, 1, 1})))); assert_values_match(39.9, *((double*)ndarray::basic::get_pelement_by_indices( &dst_ndarray, ((int32_t[]){0, 1, 2})))); assert_values_match(49.9, *((double*)ndarray::basic::get_pelement_by_indices( &dst_ndarray, ((int32_t[]){0, 1, 3})))); assert_values_match(49.9, *((double*)ndarray::basic::get_pelement_by_indices( &dst_ndarray, ((int32_t[]){1, 2, 3})))); } void run() { test_can_broadcast_shape(); test_ndarray_broadcast(); } } // namespace ndarray_broadcast } // namespace test