forked from M-Labs/nac3
core: irrt add unchecked ndarray broadcasting
This commit is contained in:
parent
cc8103152f
commit
61dd9762d8
|
@ -54,5 +54,63 @@ namespace ndarray {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Similar to `np.broadcast_to(<ndarray>, <target_shape>)`
|
||||||
|
// Assumptions:
|
||||||
|
// - `src_ndarray` has to be fully initialized.
|
||||||
|
// - `dst_ndarray->ndims` has to be set.
|
||||||
|
// - `dst_ndarray->shape` has to be set, this determines the shape `this` broadcasts to.
|
||||||
|
//
|
||||||
|
// Other notes:
|
||||||
|
// - `dst_ndarray->data` does not have to be set, it will be set to `src_ndarray->data`.
|
||||||
|
// - `dst_ndarray->itemsize` does not have to be set, it will be set to `src_ndarray->data`.
|
||||||
|
// - `dst_ndarray->strides` does not have to be set, it will be overwritten.
|
||||||
|
//
|
||||||
|
// Cautions:
|
||||||
|
// ```
|
||||||
|
// xs = np.zeros((4,))
|
||||||
|
// ys = np.zero((4, 1))
|
||||||
|
// ys[:] = xs # ok
|
||||||
|
//
|
||||||
|
// xs = np.zeros((1, 4))
|
||||||
|
// ys = np.zero((4,))
|
||||||
|
// ys[:] = xs # allowed
|
||||||
|
// # However `np.broadcast_to(xs, (4,))` would fails, as per numpy's broadcasting rule.
|
||||||
|
// # and apparently numpy will "deprecate" this? SEE https://github.com/numpy/numpy/issues/21744
|
||||||
|
// # This implementation will NOT support this assignment.
|
||||||
|
// ```
|
||||||
|
template <typename SizeT>
|
||||||
|
void broadcast_to(NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
||||||
|
dst_ndarray->data = src_ndarray->data;
|
||||||
|
dst_ndarray->itemsize = src_ndarray->itemsize;
|
||||||
|
|
||||||
|
// irrt_assert(
|
||||||
|
// ndarray_util::can_broadcast_shape_to(
|
||||||
|
// dst_ndarray->ndims,
|
||||||
|
// dst_ndarray->shape,
|
||||||
|
// src_ndarray->ndims,
|
||||||
|
// src_ndarray->shape
|
||||||
|
// )
|
||||||
|
// );
|
||||||
|
|
||||||
|
SizeT stride_product = 1;
|
||||||
|
for (SizeT i = 0; i < max(src_ndarray->ndims, dst_ndarray->ndims); i++) {
|
||||||
|
SizeT this_dim_i = src_ndarray->ndims - i - 1;
|
||||||
|
SizeT dst_dim_i = dst_ndarray->ndims - i - 1;
|
||||||
|
|
||||||
|
bool this_dim_exists = this_dim_i >= 0;
|
||||||
|
bool dst_dim_exists = dst_dim_i >= 0;
|
||||||
|
|
||||||
|
// TODO: Explain how this works
|
||||||
|
bool c1 = this_dim_exists && src_ndarray->shape[this_dim_i] == 1;
|
||||||
|
bool c2 = dst_dim_exists && dst_ndarray->shape[dst_dim_i] != 1;
|
||||||
|
if (!this_dim_exists || (c1 && c2)) {
|
||||||
|
dst_ndarray->strides[dst_dim_i] = 0; // Freeze it in-place
|
||||||
|
} else {
|
||||||
|
dst_ndarray->strides[dst_dim_i] = stride_product * src_ndarray->itemsize;
|
||||||
|
stride_product *= src_ndarray->shape[this_dim_i]; // NOTE: this_dim_exist must be true here.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue