core: irrt add unchecked ndarray broadcasting

This commit is contained in:
lyken 2024-07-15 00:49:07 +08:00
parent cc8103152f
commit 61dd9762d8

View File

@ -54,5 +54,63 @@ namespace ndarray {
return true;
}
}
// Similar to `np.broadcast_to(<ndarray>, <target_shape>)`
// Assumptions:
// - `src_ndarray` has to be fully initialized.
// - `dst_ndarray->ndims` has to be set.
// - `dst_ndarray->shape` has to be set, this determines the shape `this` broadcasts to.
//
// Other notes:
// - `dst_ndarray->data` does not have to be set, it will be set to `src_ndarray->data`.
// - `dst_ndarray->itemsize` does not have to be set, it will be set to `src_ndarray->data`.
// - `dst_ndarray->strides` does not have to be set, it will be overwritten.
//
// Cautions:
// ```
// xs = np.zeros((4,))
// ys = np.zero((4, 1))
// ys[:] = xs # ok
//
// xs = np.zeros((1, 4))
// ys = np.zero((4,))
// ys[:] = xs # allowed
// # However `np.broadcast_to(xs, (4,))` would fails, as per numpy's broadcasting rule.
// # and apparently numpy will "deprecate" this? SEE https://github.com/numpy/numpy/issues/21744
// # This implementation will NOT support this assignment.
// ```
template <typename SizeT>
void broadcast_to(NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
dst_ndarray->data = src_ndarray->data;
dst_ndarray->itemsize = src_ndarray->itemsize;
// irrt_assert(
// ndarray_util::can_broadcast_shape_to(
// dst_ndarray->ndims,
// dst_ndarray->shape,
// src_ndarray->ndims,
// src_ndarray->shape
// )
// );
SizeT stride_product = 1;
for (SizeT i = 0; i < max(src_ndarray->ndims, dst_ndarray->ndims); i++) {
SizeT this_dim_i = src_ndarray->ndims - i - 1;
SizeT dst_dim_i = dst_ndarray->ndims - i - 1;
bool this_dim_exists = this_dim_i >= 0;
bool dst_dim_exists = dst_dim_i >= 0;
// TODO: Explain how this works
bool c1 = this_dim_exists && src_ndarray->shape[this_dim_i] == 1;
bool c2 = dst_dim_exists && dst_ndarray->shape[dst_dim_i] != 1;
if (!this_dim_exists || (c1 && c2)) {
dst_ndarray->strides[dst_dim_i] = 0; // Freeze it in-place
} else {
dst_ndarray->strides[dst_dim_i] = stride_product * src_ndarray->itemsize;
stride_product *= src_ndarray->shape[this_dim_i]; // NOTE: this_dim_exist must be true here.
}
}
}
}
}