From 9aae29072795e5ac86a8bdbfbd5d819947c4f38f Mon Sep 17 00:00:00 2001 From: lyken Date: Wed, 10 Jul 2024 17:05:01 +0800 Subject: [PATCH] core: irrt general numpy broadcasting --- nac3core/irrt/irrt_numpy_ndarray.hpp | 140 +++++++++++++++- nac3core/irrt/irrt_test.cpp | 237 ++++++++++++++++++++++++++- nac3core/irrt/irrt_utils.hpp | 1 + 3 files changed, 364 insertions(+), 14 deletions(-) diff --git a/nac3core/irrt/irrt_numpy_ndarray.hpp b/nac3core/irrt/irrt_numpy_ndarray.hpp index 38f62754..8e1f1d50 100644 --- a/nac3core/irrt/irrt_numpy_ndarray.hpp +++ b/nac3core/irrt/irrt_numpy_ndarray.hpp @@ -13,6 +13,17 @@ using NDIndex = uint32_t; namespace { namespace ndarray_util { + template + static void set_indices_by_nth(SizeT ndims, const SizeT* shape, SizeT* indices, SizeT nth) { + for (int32_t i = 0; i < ndims; i++) { + int32_t dim_i = ndims - i - 1; + int32_t dim = shape[dim_i]; + + indices[dim_i] = nth % dim; + nth /= dim; + } + } + // Compute the strides of an ndarray given an ndarray `shape` // and assuming that the ndarray is *fully C-contagious*. // @@ -34,6 +45,57 @@ namespace { for (SizeT dim_i = 0; dim_i < ndims; dim_i++) size *= shape[dim_i]; return size; } + + template + static bool can_broadcast_shape_to( + const SizeT target_ndims, + const SizeT *target_shape, + const SizeT src_ndims, + const SizeT *src_shape + ) { + /* + // See https://numpy.org/doc/stable/user/basics.broadcasting.html + + This function handles this example: + ``` + Image (3d array): 256 x 256 x 3 + Scale (1d array): 3 + Result (3d array): 256 x 256 x 3 + ``` + + Other interesting examples to consider: + - `can_broadcast_shape_to([3], [1, 1, 1, 1, 3]) == true` + - `can_broadcast_shape_to([3], [3, 1]) == false` + - `can_broadcast_shape_to([256, 256, 3], [256, 1, 3]) == true` + + In cases when the shapes contain zero(es): + - `can_broadcast_shape_to([0], [1]) == true` + - `can_broadcast_shape_to([0], [2]) == false` + - `can_broadcast_shape_to([0, 4, 0, 0], [1]) == true` + - `can_broadcast_shape_to([0, 4, 0, 0], [1, 1, 1, 1]) == true` + - `can_broadcast_shape_to([0, 4, 0, 0], [1, 4, 1, 1]) == true` + - `can_broadcast_shape_to([4, 3], [0, 3]) == false` + - `can_broadcast_shape_to([4, 3], [0, 0]) == false` + */ + + // This is essentially doing the following in Python: + // `for target_dim, src_dim in itertools.zip_longest(target_shape[::-1], src_shape[::-1], fillvalue=1)` + for (SizeT i = 0; i < max(target_ndims, src_ndims); i++) { + SizeT target_dim_i = target_ndims - i - 1; + SizeT src_dim_i = src_ndims - i - 1; + + bool target_dim_exists = target_dim_i >= 0; + bool src_dim_exists = src_dim_i >= 0; + + SizeT target_dim = target_dim_exists ? target_shape[target_dim_i] : 1; + SizeT src_dim = src_dim_exists ? src_shape[src_dim_i] : 1; + + bool ok = src_dim == 1 || target_dim == src_dim; + if (!ok) return false; + } + + return true; + } } typedef uint8_t NDSliceType; @@ -55,7 +117,7 @@ namespace { namespace ndarray_util { template - SizeT deduce_ndims_after_slicing(SizeT ndims, const SizeT num_slices, const NDSlice *slices) { + SizeT deduce_ndims_after_slicing(SizeT ndims, SizeT num_slices, const NDSlice *slices) { irrt_assert(num_slices <= ndims); SizeT final_ndims = ndims; @@ -150,17 +212,26 @@ namespace { return this->size() * itemsize; } - void set_value_at_pelement(uint8_t* pelement, uint8_t* pvalue) { + void set_value_at_pelement(uint8_t* pelement, const uint8_t* pvalue) { __builtin_memcpy(pelement, pvalue, itemsize); } - uint8_t* get_pelement(SizeT *indices) { + uint8_t* get_pelement(const SizeT *indices) { uint8_t* element = data; for (SizeT dim_i = 0; dim_i < ndims; dim_i++) element += indices[dim_i] * strides[dim_i]; return element; } + uint8_t* get_nth_pelement(SizeT nth) { + irrt_assert(0 <= nth); + irrt_assert(nth < this->size()); + + SizeT* indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * this->ndims); + ndarray_util::set_indices_by_nth(this->ndims, this->shape, indices, nth); + return get_pelement(indices); + } + // Get pointer to the first element of this ndarray, assuming // `this->size() > 0`, i.e., not "degenerate" due to zeroes in `this->shape`) // @@ -171,7 +242,7 @@ namespace { } // Is the given `indices` valid/in-bounds? - bool in_bounds(SizeT *indices) { + bool in_bounds(const SizeT *indices) { for (SizeT dim_i = 0; dim_i < ndims; dim_i++) { bool dim_ok = indices[dim_i] < shape[dim_i]; if (!dim_ok) return false; @@ -180,7 +251,7 @@ namespace { } // Fill the ndarray with a value - void fill_generic(uint8_t* pvalue) { + void fill_generic(const uint8_t* pvalue) { NDArrayIndicesIter iter; iter.ndims = this->ndims; iter.shape = this->shape; @@ -199,7 +270,7 @@ namespace { } // https://numpy.org/doc/stable/reference/generated/numpy.eye.html - void set_to_eye(SizeT k, uint8_t* zero_pvalue, uint8_t* one_pvalue) { + void set_to_eye(SizeT k, const uint8_t* zero_pvalue, const uint8_t* one_pvalue) { __builtin_assume(ndims == 2); // TODO: Better implementation @@ -275,6 +346,63 @@ namespace { irrt_assert(dst_axis == dst_ndarray->ndims); // Sanity check on the implementation } + + // Similar to `np.broadcast_to(, )` + // Assumptions: + // - `this` has to be fully initialized. + // - `dst_ndarray->ndims` has to be set. + // - `dst_ndarray->shape` has to be set, this determines the shape `this` broadcasts to. + // + // Other notes: + // - `dst_ndarray->data` does not have to be set, it will be set to `this->data`. + // - `dst_ndarray->itemsize` does not have to be set, it will be set to `this->data`. + // - `dst_ndarray->strides` does not have to be set, it will be overwritten. + // + // Cautions: + // ``` + // xs = np.zeros((4,)) + // ys = np.zero((4, 1)) + // ys[:] = xs # ok + // + // xs = np.zeros((1, 4)) + // ys = np.zero((4,)) + // ys[:] = xs # allowed + // # However `np.broadcast_to(xs, (4,))` would fails, as per numpy's broadcasting rule. + // # and apparently numpy will "deprecate" this? SEE https://github.com/numpy/numpy/issues/21744 + // # This implementation will NOT support this assignment. + // ``` + void broadcast_to(NDArray* dst_ndarray) { + dst_ndarray->data = this->data; + dst_ndarray->itemsize = this->itemsize; + + irrt_assert( + ndarray_util::can_broadcast_shape_to( + dst_ndarray->ndims, + dst_ndarray->shape, + this->ndims, + this->shape + ) + ); + + SizeT stride_product = 1; + for (SizeT i = 0; i < max(this->ndims, dst_ndarray->ndims); i++) { + SizeT this_dim_i = this->ndims - i - 1; + SizeT dst_dim_i = dst_ndarray->ndims - i - 1; + + bool this_dim_exists = this_dim_i >= 0; + bool dst_dim_exists = dst_dim_i >= 0; + + // TODO: Explain how this works + bool c1 = this_dim_exists && this->shape[this_dim_i] == 1; + bool c2 = dst_dim_exists && dst_ndarray->shape[dst_dim_i] != 1; + if (!this_dim_exists || (c1 && c2)) { + dst_ndarray->strides[dst_dim_i] = 0; // Freeze it in-place + } else { + dst_ndarray->strides[dst_dim_i] = stride_product * this->itemsize; + stride_product *= this->shape[this_dim_i]; // NOTE: this_dim_exist must be true here. + } + } + } }; } diff --git a/nac3core/irrt/irrt_test.cpp b/nac3core/irrt/irrt_test.cpp index 1dda0b0d..f6e67ff3 100644 --- a/nac3core/irrt/irrt_test.cpp +++ b/nac3core/irrt/irrt_test.cpp @@ -33,10 +33,11 @@ void debug_print_array(const char* format, int len, T* as) { template void assert_arrays_match(const char* label, const char* format, int len, T* expected, T* got) { if (!arrays_match(len, expected, got)) { - printf("expected %s: ", label); + printf(">>>>>>> %s\n", label); + printf(" Expecting = "); debug_print_array(format, len, expected); printf("\n"); - printf("got %s: ", label); + printf(" Got = "); debug_print_array(format, len, got); printf("\n"); test_fail(); @@ -46,22 +47,89 @@ void assert_arrays_match(const char* label, const char* format, int len, T* expe template void assert_values_match(const char* label, const char* format, T expected, T got) { if (expected != got) { - printf("expected %s: ", label); + printf(">>>>>>> %s\n", label); + printf(" Expecting = "); printf(format, expected); printf("\n"); - printf("got %s: ", label); + printf(" Got = "); printf(format, got); printf("\n"); test_fail(); } } +void print_repeated(const char *str, int count) { + for (int i = 0; i < count; i++) { + printf("%s", str); + } +} + +template +void __print_ndarray_aux(const char *format, bool first, bool last, SizeT* cursor, SizeT depth, NDArray* ndarray) { + // A really lazy recursive implementation + + // Add left padding unless its the first entry (since there would be "[[[" before it) + if (!first) { + print_repeated(" ", depth); + } + + const SizeT dim = ndarray->shape[depth]; + if (depth + 1 == ndarray->ndims) { + // Recursed down to last dimension, print the values in a nice list + printf("["); + + SizeT* indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * ndarray->ndims); + for (SizeT i = 0; i < dim; i++) { + ndarray_util::set_indices_by_nth(ndarray->ndims, ndarray->shape, indices, *cursor); + ElementT* pelement = (ElementT*) ndarray->get_pelement(indices); + ElementT element = *pelement; + + if (i != 0) printf(", "); // List delimiter + printf(format, element); + printf("(@"); + debug_print_array("%d", ndarray->ndims, indices); + printf(")"); + + (*cursor)++; + } + printf("]"); + } else { + printf("["); + for (SizeT i = 0; i < ndarray->shape[depth]; i++) { + __print_ndarray_aux( + format, + i == 0, // first? + i + 1 == dim, // last? + cursor, + depth + 1, + ndarray + ); + } + printf("]"); + } + + // Add newline unless its the last entry (since there will be "]]]" after it) + if (!last) { + print_repeated("\n", depth); + } +} + +template +void print_ndarray(const char *format, NDArray* ndarray) { + if (ndarray->ndims == 0) { + printf(""); + } else { + SizeT cursor = 0; + __print_ndarray_aux(format, true, true, &cursor, 0, ndarray); + } + printf("\n"); +} + void test_calc_size_from_shape_normal() { // Test shapes with normal values BEGIN_TEST(); int32_t shape[4] = { 2, 3, 5, 7 }; - debug_print_array("%d", 4, shape); assert_values_match("size", "%d", 210, ndarray_util::calc_size_from_shape(4, shape)); } @@ -267,9 +335,6 @@ void test_ndslice_1() { assert dst_ndarray[0, 1] == 7.0 assert dst_ndarray[1, 0] == 9.0 assert dst_ndarray[1, 1] == 11.0 - - dst_ndarray[1, 0] == 99 # Write to `dst_ndarray` - assert ndarray[1, 3] == 99 # `ndarray` also updates!! ``` */ BEGIN_TEST(); @@ -410,6 +475,160 @@ void test_ndslice_2() { assert_values_match("dst_ndarray[1]", "%f", 9.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 1 }))); } +void test_can_broadcast_shape() { + BEGIN_TEST(); + + assert_values_match( + "can_broadcast_shape_to([3], [1, 1, 1, 1, 3]) == true", + "%d", + true, + ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 3 }, 5, (int32_t[]) { 1, 1, 1, 1, 3 }) + ); + assert_values_match( + "can_broadcast_shape_to([3], [3, 1]) == false", + "%d", + false, + ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 3 }, 2, (int32_t[]) { 3, 1 })); + assert_values_match( + "can_broadcast_shape_to([3], [3]) == true", + "%d", + true, + ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 3 }, 1, (int32_t[]) { 3 })); + assert_values_match( + "can_broadcast_shape_to([1], [3]) == false", + "%d", + false, + ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 1 }, 1, (int32_t[]) { 3 })); + assert_values_match( + "can_broadcast_shape_to([1], [1]) == true", + "%d", + true, + ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 1 }, 1, (int32_t[]) { 1 })); + assert_values_match( + "can_broadcast_shape_to([256, 256, 3], [256, 1, 3]) == true", + "%d", + true, + ndarray_util::can_broadcast_shape_to(3, (int32_t[]) { 256, 256, 3 }, 3, (int32_t[]) { 256, 1, 3 }) + ); + assert_values_match( + "can_broadcast_shape_to([256, 256, 3], [3]) == true", + "%d", + true, + ndarray_util::can_broadcast_shape_to(3, (int32_t[]) { 256, 256, 3 }, 1, (int32_t[]) { 3 }) + ); + assert_values_match( + "can_broadcast_shape_to([256, 256, 3], [2]) == false", + "%d", + false, + ndarray_util::can_broadcast_shape_to(3, (int32_t[]) { 256, 256, 3 }, 1, (int32_t[]) { 2 }) + ); + assert_values_match( + "can_broadcast_shape_to([256, 256, 3], [1]) == true", + "%d", + true, + ndarray_util::can_broadcast_shape_to(3, (int32_t[]) { 256, 256, 3 }, 1, (int32_t[]) { 1 }) + ); + + // In cases when the shapes contain zero(es) + assert_values_match( + "can_broadcast_shape_to([0], [1]) == true", + "%d", + true, + ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 0 }, 1, (int32_t[]) { 1 }) + ); + assert_values_match( + "can_broadcast_shape_to([0], [2]) == false", + "%d", + false, + ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 0 }, 1, (int32_t[]) { 2 }) + ); + assert_values_match( + "can_broadcast_shape_to([0, 4, 0, 0], [1]) == true", + "%d", + true, + ndarray_util::can_broadcast_shape_to(4, (int32_t[]) { 0, 4, 0, 0 }, 1, (int32_t[]) { 1 }) + ); + assert_values_match( + "can_broadcast_shape_to([0, 4, 0, 0], [1, 1, 1, 1]) == true", + "%d", + true, + ndarray_util::can_broadcast_shape_to(4, (int32_t[]) { 0, 4, 0, 0 }, 4, (int32_t[]) { 1, 1, 1, 1 }) + ); + assert_values_match( + "can_broadcast_shape_to([0, 4, 0, 0], [1, 4, 1, 1]) == true", + "%d", + true, + ndarray_util::can_broadcast_shape_to(4, (int32_t[]) { 0, 4, 0, 0 }, 4, (int32_t[]) { 1, 4, 1, 1 }) + ); + assert_values_match( + "can_broadcast_shape_to([4, 3], [0, 3]) == false", + "%d", + false, + ndarray_util::can_broadcast_shape_to(2, (int32_t[]) { 4, 3 }, 2, (int32_t[]) { 0, 3 }) + ); + assert_values_match( + "can_broadcast_shape_to([4, 3], [0, 0]) == false", + "%d", + false, + ndarray_util::can_broadcast_shape_to(2, (int32_t[]) { 4, 3 }, 2, (int32_t[]) { 0, 0 }) + ); +} + +void test_ndarray_broadcast_1() { + /* + # array = np.array([[19.9, 29.9, 39.9, 49.9]], dtype=np.float64) + # >>> [[19.9 29.9 39.9 49.9]] + # + # array = np.broadcast_to(array, (2, 3, 4)) + # >>> [[[19.9 29.9 39.9 49.9] + # >>> [19.9 29.9 39.9 49.9] + # >>> [19.9 29.9 39.9 49.9]] + # >>> [[19.9 29.9 39.9 49.9] + # >>> [19.9 29.9 39.9 49.9] + # >>> [19.9 29.9 39.9 49.9]]] + # + # assery array.strides == (0, 0, 8) + + */ + BEGIN_TEST(); + + double in_data[4] = { 19.9, 29.9, 39.9, 49.9 }; + const int32_t in_ndims = 2; + int32_t in_shape[in_ndims] = {1, 4}; + int32_t in_strides[in_ndims] = {}; + NDArray ndarray = { + .data = (uint8_t*) in_data, + .itemsize = sizeof(double), + .ndims = in_ndims, + .shape = in_shape, + .strides = in_strides + }; + ndarray.set_strides_by_shape(); + + const int32_t dst_ndims = 3; + int32_t dst_shape[dst_ndims] = {2, 3, 4}; + int32_t dst_strides[dst_ndims] = {}; + NDArray dst_ndarray = { + .ndims = dst_ndims, + .shape = dst_shape, + .strides = dst_strides + }; + + ndarray.broadcast_to(&dst_ndarray); + + assert_arrays_match("dst_ndarray->strides", "%d", dst_ndims, (int32_t[]) { 0, 0, 8 }, dst_ndarray.strides); + + assert_values_match("dst_ndarray[0, 0, 0]", "%f", 19.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 0, 0}))); + assert_values_match("dst_ndarray[0, 0, 1]", "%f", 29.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 0, 1}))); + assert_values_match("dst_ndarray[0, 0, 2]", "%f", 39.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 0, 2}))); + assert_values_match("dst_ndarray[0, 0, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 0, 3}))); + assert_values_match("dst_ndarray[0, 1, 0]", "%f", 19.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 1, 0}))); + assert_values_match("dst_ndarray[0, 1, 1]", "%f", 29.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 1, 1}))); + assert_values_match("dst_ndarray[0, 1, 2]", "%f", 39.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 1, 2}))); + assert_values_match("dst_ndarray[0, 1, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 1, 3}))); + assert_values_match("dst_ndarray[1, 2, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {1, 2, 3}))); +} + int main() { test_calc_size_from_shape_normal(); test_calc_size_from_shape_has_zero(); @@ -423,5 +642,7 @@ int main() { test_slice_4(); test_ndslice_1(); test_ndslice_2(); + test_can_broadcast_shape(); + test_ndarray_broadcast_1(); return 0; } \ No newline at end of file diff --git a/nac3core/irrt/irrt_utils.hpp b/nac3core/irrt/irrt_utils.hpp index 033b995e..8d69b6a1 100644 --- a/nac3core/irrt/irrt_utils.hpp +++ b/nac3core/irrt/irrt_utils.hpp @@ -30,6 +30,7 @@ namespace { *death = 0; // TODO: address 0 on hardware might be writable? } + // TODO: Make this a macro and allow it to be toggled on/off (e.g., debug vs release) void irrt_assert(bool condition) { if (!condition) irrt_panic(); }