forked from M-Labs/nac3
core: irrt add unchecked ndarray broadcasting
This commit is contained in:
parent
cc8103152f
commit
61dd9762d8
@ -54,5 +54,63 @@ namespace ndarray {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Similar to `np.broadcast_to(<ndarray>, <target_shape>)`
|
||||
// Assumptions:
|
||||
// - `src_ndarray` has to be fully initialized.
|
||||
// - `dst_ndarray->ndims` has to be set.
|
||||
// - `dst_ndarray->shape` has to be set, this determines the shape `this` broadcasts to.
|
||||
//
|
||||
// Other notes:
|
||||
// - `dst_ndarray->data` does not have to be set, it will be set to `src_ndarray->data`.
|
||||
// - `dst_ndarray->itemsize` does not have to be set, it will be set to `src_ndarray->data`.
|
||||
// - `dst_ndarray->strides` does not have to be set, it will be overwritten.
|
||||
//
|
||||
// Cautions:
|
||||
// ```
|
||||
// xs = np.zeros((4,))
|
||||
// ys = np.zero((4, 1))
|
||||
// ys[:] = xs # ok
|
||||
//
|
||||
// xs = np.zeros((1, 4))
|
||||
// ys = np.zero((4,))
|
||||
// ys[:] = xs # allowed
|
||||
// # However `np.broadcast_to(xs, (4,))` would fails, as per numpy's broadcasting rule.
|
||||
// # and apparently numpy will "deprecate" this? SEE https://github.com/numpy/numpy/issues/21744
|
||||
// # This implementation will NOT support this assignment.
|
||||
// ```
|
||||
template <typename SizeT>
|
||||
void broadcast_to(NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
||||
dst_ndarray->data = src_ndarray->data;
|
||||
dst_ndarray->itemsize = src_ndarray->itemsize;
|
||||
|
||||
// irrt_assert(
|
||||
// ndarray_util::can_broadcast_shape_to(
|
||||
// dst_ndarray->ndims,
|
||||
// dst_ndarray->shape,
|
||||
// src_ndarray->ndims,
|
||||
// src_ndarray->shape
|
||||
// )
|
||||
// );
|
||||
|
||||
SizeT stride_product = 1;
|
||||
for (SizeT i = 0; i < max(src_ndarray->ndims, dst_ndarray->ndims); i++) {
|
||||
SizeT this_dim_i = src_ndarray->ndims - i - 1;
|
||||
SizeT dst_dim_i = dst_ndarray->ndims - i - 1;
|
||||
|
||||
bool this_dim_exists = this_dim_i >= 0;
|
||||
bool dst_dim_exists = dst_dim_i >= 0;
|
||||
|
||||
// TODO: Explain how this works
|
||||
bool c1 = this_dim_exists && src_ndarray->shape[this_dim_i] == 1;
|
||||
bool c2 = dst_dim_exists && dst_ndarray->shape[dst_dim_i] != 1;
|
||||
if (!this_dim_exists || (c1 && c2)) {
|
||||
dst_ndarray->strides[dst_dim_i] = 0; // Freeze it in-place
|
||||
} else {
|
||||
dst_ndarray->strides[dst_dim_i] = stride_product * src_ndarray->itemsize;
|
||||
stride_product *= src_ndarray->shape[this_dim_i]; // NOTE: this_dim_exist must be true here.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user