From 61dd9762d82d63b6815901b17ecb9a984c1f5567 Mon Sep 17 00:00:00 2001 From: lyken Date: Mon, 15 Jul 2024 00:49:07 +0800 Subject: [PATCH] core: irrt add unchecked ndarray broadcasting --- .../irrt/irrt/numpy/ndarray_broadcast.hpp | 58 +++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/nac3core/irrt/irrt/numpy/ndarray_broadcast.hpp b/nac3core/irrt/irrt/numpy/ndarray_broadcast.hpp index 1fae3939..4de4976b 100644 --- a/nac3core/irrt/irrt/numpy/ndarray_broadcast.hpp +++ b/nac3core/irrt/irrt/numpy/ndarray_broadcast.hpp @@ -54,5 +54,63 @@ namespace ndarray { return true; } } + + // Similar to `np.broadcast_to(, )` + // Assumptions: + // - `src_ndarray` has to be fully initialized. + // - `dst_ndarray->ndims` has to be set. + // - `dst_ndarray->shape` has to be set, this determines the shape `this` broadcasts to. + // + // Other notes: + // - `dst_ndarray->data` does not have to be set, it will be set to `src_ndarray->data`. + // - `dst_ndarray->itemsize` does not have to be set, it will be set to `src_ndarray->data`. + // - `dst_ndarray->strides` does not have to be set, it will be overwritten. + // + // Cautions: + // ``` + // xs = np.zeros((4,)) + // ys = np.zero((4, 1)) + // ys[:] = xs # ok + // + // xs = np.zeros((1, 4)) + // ys = np.zero((4,)) + // ys[:] = xs # allowed + // # However `np.broadcast_to(xs, (4,))` would fails, as per numpy's broadcasting rule. + // # and apparently numpy will "deprecate" this? SEE https://github.com/numpy/numpy/issues/21744 + // # This implementation will NOT support this assignment. + // ``` + template + void broadcast_to(NDArray* src_ndarray, NDArray* dst_ndarray) { + dst_ndarray->data = src_ndarray->data; + dst_ndarray->itemsize = src_ndarray->itemsize; + + // irrt_assert( + // ndarray_util::can_broadcast_shape_to( + // dst_ndarray->ndims, + // dst_ndarray->shape, + // src_ndarray->ndims, + // src_ndarray->shape + // ) + // ); + + SizeT stride_product = 1; + for (SizeT i = 0; i < max(src_ndarray->ndims, dst_ndarray->ndims); i++) { + SizeT this_dim_i = src_ndarray->ndims - i - 1; + SizeT dst_dim_i = dst_ndarray->ndims - i - 1; + + bool this_dim_exists = this_dim_i >= 0; + bool dst_dim_exists = dst_dim_i >= 0; + + // TODO: Explain how this works + bool c1 = this_dim_exists && src_ndarray->shape[this_dim_i] == 1; + bool c2 = dst_dim_exists && dst_ndarray->shape[dst_dim_i] != 1; + if (!this_dim_exists || (c1 && c2)) { + dst_ndarray->strides[dst_dim_i] = 0; // Freeze it in-place + } else { + dst_ndarray->strides[dst_dim_i] = stride_product * src_ndarray->itemsize; + stride_product *= src_ndarray->shape[this_dim_i]; // NOTE: this_dim_exist must be true here. + } + } + } } } \ No newline at end of file