2024-07-10 11:56:31 +08:00
|
|
|
#pragma once
|
|
|
|
|
|
|
|
#include "irrt_utils.hpp"
|
|
|
|
#include "irrt_typedefs.hpp"
|
|
|
|
#include "irrt_slice.hpp"
|
|
|
|
|
|
|
|
/*
|
|
|
|
NDArray-related implementations.
|
|
|
|
*/
|
|
|
|
|
|
|
|
// NDArray indices are always `uint32_t`.
// NOTE(review): `NDIndex` is not referenced by the code in this file; the index arrays below
// are typed with the template parameter `SizeT` (instantiated with int32_t/int64_t) instead.
// Confirm whether other translation units still rely on this alias before changing it.
using NDIndex = uint32_t;
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
namespace ndarray_util {
|
2024-07-10 17:05:01 +08:00
|
|
|
// Unravel the flat, row-major (C-order) element number `nth` into per-axis `indices`
// for an ndarray of the given `shape` -- like `np.unravel_index(nth, shape)`.
//
// - `shape` and `indices` must both have `ndims` elements; the last axis varies fastest.
// - The caller is expected to guarantee `0 <= nth < product(shape)`; nothing is checked here.
//
// Every intermediate value is kept in `SizeT` so that 64-bit shapes/indices are not
// truncated (narrowing them to a 32-bit type silently produces wrong indices).
template <typename SizeT>
static void set_indices_by_nth(SizeT ndims, const SizeT* shape, SizeT* indices, SizeT nth) {
    // Walk the axes from the last one to the first one.
    for (SizeT dim_i = ndims; dim_i > 0;) {
        dim_i--;
        const SizeT dim = shape[dim_i];
        indices[dim_i] = nth % dim;
        nth /= dim;
    }
}
|
|
|
|
|
2024-07-10 11:56:31 +08:00
|
|
|
// Compute the strides of an ndarray given an ndarray `shape`,
// assuming that the ndarray is *fully C-contiguous* (row-major, no gaps).
//
// The strides written to `dst_strides` (length `ndims`) are in number of BYTES,
// i.e., they are already multiplied by `itemsize`.
//
// Example: itemsize = 8, shape = [2, 3, 4]  =>  strides = [96, 32, 8].
//
// You might want to read up on https://ajcr.net/stride-guide-part-1/.
template <typename SizeT>
static void set_strides_by_shape(SizeT itemsize, SizeT ndims, SizeT* dst_strides, const SizeT* shape) {
    // Number of elements skipped by one step along the axis currently being visited.
    SizeT stride_product = 1;
    for (SizeT i = 0; i < ndims; i++) {
        // Keep the axis index in `SizeT` (not `int`) so it is never narrowed.
        SizeT dim_i = ndims - i - 1;
        dst_strides[dim_i] = stride_product * itemsize;
        stride_product *= shape[dim_i];
    }
}
|
|
|
|
|
|
|
|
// Number of elements of an ndarray whose shape is `shape[0..ndims)`:
// the product of all dimensions (1 for a 0-d array, 0 if any dimension is 0).
template <typename SizeT>
static SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) {
    SizeT num_elements = 1;
    SizeT axis = 0;
    while (axis < ndims) {
        num_elements *= shape[axis];
        ++axis;
    }
    return num_elements;
}
|
2024-07-10 17:05:01 +08:00
|
|
|
|
|
|
|
template <typename SizeT>
static bool can_broadcast_shape_to(
    const SizeT target_ndims,
    const SizeT *target_shape,
    const SizeT src_ndims,
    const SizeT *src_shape
) {
    /*
    Decide whether an array of shape `src_shape` can be broadcast to `target_shape`.
    See https://numpy.org/doc/stable/user/basics.broadcasting.html

    Shapes are aligned at their LAST axis; a missing axis counts as 1. Every aligned
    source dimension must be either 1 or equal to the target dimension.

    The classic example:
    ```
    Image  (3d array): 256 x 256 x 3
    Scale  (1d array):             3
    Result (3d array): 256 x 256 x 3
    ```

    More examples:
    - `can_broadcast_shape_to([3], [1, 1, 1, 1, 3]) == true`
    - `can_broadcast_shape_to([3], [3, 1]) == false`
    - `can_broadcast_shape_to([256, 256, 3], [256, 1, 3]) == true`

    With zero-sized dimensions:
    - `can_broadcast_shape_to([0], [1]) == true`
    - `can_broadcast_shape_to([0], [2]) == false`
    - `can_broadcast_shape_to([0, 4, 0, 0], [1]) == true`
    - `can_broadcast_shape_to([0, 4, 0, 0], [1, 1, 1, 1]) == true`
    - `can_broadcast_shape_to([0, 4, 0, 0], [1, 4, 1, 1]) == true`
    - `can_broadcast_shape_to([4, 3], [0, 3]) == false`
    - `can_broadcast_shape_to([4, 3], [0, 0]) == false`
    */

    // Python equivalent:
    // `for target_dim, src_dim in itertools.zip_longest(target_shape[::-1], src_shape[::-1], fillvalue=1)`
    const SizeT longest_ndims = max(target_ndims, src_ndims);
    for (SizeT i = 0; i < longest_ndims; i++) {
        const SizeT from_back = i + 1;
        const SizeT target_axis = target_ndims - from_back;
        const SizeT src_axis = src_ndims - from_back;

        // Axes that "fall off" the front of a shape behave like size-1 axes.
        const SizeT target_dim = (target_axis >= 0) ? target_shape[target_axis] : 1;
        const SizeT src_dim = (src_axis >= 0) ? src_shape[src_axis] : 1;

        if (src_dim != 1 && src_dim != target_dim) {
            return false;
        }
    }
    return true;
}
|
2024-07-10 11:56:31 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
// Discriminant stored in `NDSlice::type`.
using NDSliceType = uint8_t;

extern "C" {
// The subscript is a single (possibly negative) integer, e.g. the `-5` in `a[::2, -5]`.
const NDSliceType INPUT_SLICE_TYPE_INDEX = 0;
// The subscript is a slice, e.g. the `::2` in `a[::2, -5]`.
const NDSliceType INPUT_SLICE_TYPE_SLICE = 1;
}
|
|
|
|
|
|
|
|
// One user-written subscript for a single axis of an ndarray
// (e.g., one of the comma-separated parts of `my_array[::2, -5, ::-1]`).
struct NDSlice {
    // A poor-man's `std::variant<SizeT, UserSlice<SizeT>>`, discriminated by this field.
    NDSliceType type;

    /*
    if type == INPUT_SLICE_TYPE_INDEX => `slice` points to a single `SizeT`
    if type == INPUT_SLICE_TYPE_SLICE => `slice` points to a single `UserSlice<SizeT>`
    (`NDArray::slice` reinterprets this pointer accordingly.)
    */
    uint8_t *slice;
};
|
|
|
|
|
2024-07-10 14:05:08 +08:00
|
|
|
namespace ndarray_util {
|
|
|
|
template<typename SizeT>
|
2024-07-10 17:05:01 +08:00
|
|
|
SizeT deduce_ndims_after_slicing(SizeT ndims, SizeT num_slices, const NDSlice *slices) {
|
2024-07-10 14:05:08 +08:00
|
|
|
irrt_assert(num_slices <= ndims);
|
|
|
|
|
|
|
|
SizeT final_ndims = ndims;
|
|
|
|
for (SizeT i = 0; i < num_slices; i++) {
|
|
|
|
if (slices[i].type == INPUT_SLICE_TYPE_INDEX) {
|
|
|
|
final_ndims--; // An integer slice demotes the rank by 1
|
|
|
|
}
|
2024-07-10 11:56:31 +08:00
|
|
|
}
|
2024-07-10 14:05:08 +08:00
|
|
|
return final_ndims;
|
2024-07-10 11:56:31 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Row-major (C-order) odometer over all index tuples of `shape`.
// `indices` is caller-owned storage for `ndims` values; `next()` advances it in place,
// wrapping back to all-zeros after the last tuple.
template <typename SizeT>
struct NDArrayIndicesIter {
    SizeT ndims;
    const SizeT *shape;
    SizeT *indices;

    // Reset the odometer to (0, 0, ..., 0).
    void set_indices_zero() {
        __builtin_memset(indices, 0, sizeof(SizeT) * ndims);
    }

    // Advance to the next index tuple; the last axis varies fastest.
    void next() {
        SizeT axis = ndims;
        while (axis > 0) {
            axis--;
            indices[axis]++;
            if (indices[axis] < shape[axis]) {
                return;  // No carry needed
            }
            indices[axis] = 0;  // Carry into the previous axis
        }
    }
};
|
|
|
|
|
|
|
|
// The NDArray object. `SizeT` is the *signed* size type of this ndarray.
|
|
|
|
//
|
|
|
|
// NOTE: The order of fields is IMPORTANT. DON'T TOUCH IT
|
|
|
|
//
|
|
|
|
// Some resources you might find helpful:
|
|
|
|
// - The official numpy implementations:
|
|
|
|
// - https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst
|
|
|
|
// - On strides (about reshaping, slicing, C-contagiousness, etc)
|
|
|
|
// - https://ajcr.net/stride-guide-part-1/.
|
|
|
|
// - https://ajcr.net/stride-guide-part-2/.
|
|
|
|
// - https://ajcr.net/stride-guide-part-3/.
|
|
|
|
template <typename SizeT>
|
|
|
|
struct NDArray {
|
|
|
|
// The underlying data this `ndarray` is pointing to.
|
|
|
|
//
|
|
|
|
// NOTE: Formally this should be of type `void *`, but clang
|
|
|
|
// translates `void *` to `i8 *` when run with `-S -emit-llvm`,
|
|
|
|
// so we will put `uint8_t *` here for clarity.
|
|
|
|
uint8_t *data;
|
|
|
|
|
|
|
|
// The number of bytes of a single element in `data`.
|
|
|
|
//
|
|
|
|
// The `SizeT` is treated as `unsigned`.
|
|
|
|
SizeT itemsize;
|
|
|
|
|
|
|
|
// The number of dimensions of this shape.
|
|
|
|
//
|
|
|
|
// The `SizeT` is treated as `unsigned`.
|
|
|
|
SizeT ndims;
|
|
|
|
|
|
|
|
// Array shape, with length equal to `ndims`.
|
|
|
|
//
|
|
|
|
// The `SizeT` is treated as `unsigned`.
|
|
|
|
//
|
|
|
|
// NOTE: `shape` can contain 0.
|
|
|
|
// (those appear when the user makes an out of bounds slice into an ndarray, e.g., `np.zeros((3, 3))[400:].shape == (0, 3)`)
|
|
|
|
SizeT *shape;
|
|
|
|
|
|
|
|
// Array strides (stride value is in number of bytes, NOT number of elements), with length equal to `ndims`.
|
|
|
|
//
|
|
|
|
// The `SizeT` is treated as `signed`.
|
|
|
|
//
|
|
|
|
// NOTE: `strides` can have negative numbers.
|
|
|
|
// (those appear when there is a slice with a negative step, e.g., `my_array[::-1]`)
|
|
|
|
SizeT *strides;
|
|
|
|
|
|
|
|
// Calculate the size/# of elements of an `ndarray`.
|
|
|
|
// This function corresponds to `np.size(<ndarray>)` or `ndarray.size`
|
|
|
|
SizeT size() {
|
|
|
|
return ndarray_util::calc_size_from_shape(ndims, shape);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Calculate the number of bytes of its content of an `ndarray` *in its view*.
|
|
|
|
// This function corresponds to `ndarray.nbytes`
|
|
|
|
SizeT nbytes() {
|
|
|
|
return this->size() * itemsize;
|
|
|
|
}
|
|
|
|
|
2024-07-10 17:05:01 +08:00
|
|
|
void set_value_at_pelement(uint8_t* pelement, const uint8_t* pvalue) {
|
2024-07-10 11:56:31 +08:00
|
|
|
__builtin_memcpy(pelement, pvalue, itemsize);
|
|
|
|
}
|
|
|
|
|
2024-07-10 17:05:01 +08:00
|
|
|
uint8_t* get_pelement(const SizeT *indices) {
|
2024-07-10 11:56:31 +08:00
|
|
|
uint8_t* element = data;
|
|
|
|
for (SizeT dim_i = 0; dim_i < ndims; dim_i++)
|
2024-07-10 14:05:08 +08:00
|
|
|
element += indices[dim_i] * strides[dim_i];
|
2024-07-10 11:56:31 +08:00
|
|
|
return element;
|
|
|
|
}
|
|
|
|
|
2024-07-10 17:05:01 +08:00
|
|
|
uint8_t* get_nth_pelement(SizeT nth) {
|
|
|
|
irrt_assert(0 <= nth);
|
|
|
|
irrt_assert(nth < this->size());
|
|
|
|
|
|
|
|
SizeT* indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * this->ndims);
|
|
|
|
ndarray_util::set_indices_by_nth(this->ndims, this->shape, indices, nth);
|
|
|
|
return get_pelement(indices);
|
|
|
|
}
|
|
|
|
|
2024-07-10 14:05:08 +08:00
|
|
|
// Get pointer to the first element of this ndarray, assuming
|
|
|
|
// `this->size() > 0`, i.e., not "degenerate" due to zeroes in `this->shape`)
|
|
|
|
//
|
|
|
|
// This is particularly useful for when the ndarray is just containing a single scalar.
|
|
|
|
uint8_t* get_first_pelement() {
|
|
|
|
irrt_assert(this->size() > 0);
|
|
|
|
return this->data; // ...It is simply `this->data`
|
|
|
|
}
|
|
|
|
|
2024-07-10 11:56:31 +08:00
|
|
|
// Is the given `indices` valid/in-bounds?
|
2024-07-10 17:05:01 +08:00
|
|
|
bool in_bounds(const SizeT *indices) {
|
2024-07-10 11:56:31 +08:00
|
|
|
for (SizeT dim_i = 0; dim_i < ndims; dim_i++) {
|
|
|
|
bool dim_ok = indices[dim_i] < shape[dim_i];
|
|
|
|
if (!dim_ok) return false;
|
|
|
|
}
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Fill the ndarray with a value
|
2024-07-10 17:05:01 +08:00
|
|
|
void fill_generic(const uint8_t* pvalue) {
|
2024-07-10 11:56:31 +08:00
|
|
|
NDArrayIndicesIter<SizeT> iter;
|
|
|
|
iter.ndims = this->ndims;
|
|
|
|
iter.shape = this->shape;
|
|
|
|
iter.indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * ndims);
|
|
|
|
iter.set_indices_zero();
|
|
|
|
|
|
|
|
for (SizeT i = 0; i < this->size(); i++, iter.next()) {
|
|
|
|
uint8_t* pelement = get_pelement(iter.indices);
|
|
|
|
set_value_at_pelement(pelement, pvalue);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Set the strides of the ndarray with `ndarray_util::set_strides_by_shape`
|
|
|
|
void set_strides_by_shape() {
|
2024-07-10 14:05:08 +08:00
|
|
|
ndarray_util::set_strides_by_shape(itemsize, ndims, strides, shape);
|
2024-07-10 11:56:31 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
// https://numpy.org/doc/stable/reference/generated/numpy.eye.html
|
2024-07-10 17:05:01 +08:00
|
|
|
void set_to_eye(SizeT k, const uint8_t* zero_pvalue, const uint8_t* one_pvalue) {
|
2024-07-10 11:56:31 +08:00
|
|
|
__builtin_assume(ndims == 2);
|
|
|
|
|
|
|
|
// TODO: Better implementation
|
|
|
|
|
|
|
|
fill_generic(zero_pvalue);
|
|
|
|
for (SizeT i = 0; i < min(shape[0], shape[1]); i++) {
|
|
|
|
SizeT row = i;
|
|
|
|
SizeT col = i + k;
|
|
|
|
SizeT indices[2] = { row, col };
|
|
|
|
|
|
|
|
if (!in_bounds(indices)) continue;
|
|
|
|
|
|
|
|
uint8_t* pelement = get_pelement(indices);
|
|
|
|
set_value_at_pelement(pelement, one_pvalue);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// To support numpy complex slices (e.g., `my_array[:50:2,4,:2:-1]`)
|
2024-07-10 14:05:08 +08:00
|
|
|
//
|
|
|
|
// Things assumed by this function:
|
|
|
|
// - `dst_ndarray` is allocated by the caller
|
|
|
|
// - `dst_ndarray.ndims` has the correct value (according to `ndarray_util::deduce_ndims_after_slicing`).
|
|
|
|
// - ... and `dst_ndarray.shape` and `dst_ndarray.strides` have been allocated by the caller as well
|
|
|
|
//
|
|
|
|
// Other notes:
|
|
|
|
// - `dst_ndarray->data` does not have to be set, it will be derived.
|
|
|
|
// - `dst_ndarray->itemsize` does not have to be set, it will be set to `this->itemsize`
|
|
|
|
// - `dst_ndarray->shape` and `dst_ndarray.strides` can contain empty values
|
|
|
|
void slice(SizeT num_ndslices, NDSlice* ndslices, NDArray<SizeT>* dst_ndarray) {
|
|
|
|
// REFERENCE CODE (check out `_index_helper` in `__getitem__`):
|
|
|
|
// https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652
|
|
|
|
|
|
|
|
irrt_assert(dst_ndarray->ndims == ndarray_util::deduce_ndims_after_slicing(this->ndims, num_ndslices, ndslices));
|
|
|
|
|
|
|
|
dst_ndarray->data = this->data;
|
2024-07-10 11:56:31 +08:00
|
|
|
|
|
|
|
SizeT this_axis = 0;
|
2024-07-10 14:05:08 +08:00
|
|
|
SizeT dst_axis = 0;
|
|
|
|
|
|
|
|
for (SizeT i = 0; i < num_ndslices; i++) {
|
|
|
|
NDSlice *ndslice = &ndslices[i];
|
|
|
|
if (ndslice->type == INPUT_SLICE_TYPE_INDEX) {
|
|
|
|
// Handle when the ndslice is just a single (possibly negative) integer
|
|
|
|
// e.g., `my_array[::2, -5, ::-1]`
|
|
|
|
// ^^------ like this
|
|
|
|
SizeT index_user = *((SizeT*) ndslice->slice);
|
|
|
|
SizeT index = resolve_index_in_length(this->shape[this_axis], index_user);
|
|
|
|
dst_ndarray->data += index * this->strides[this_axis]; // Add offset
|
|
|
|
|
|
|
|
// Next
|
|
|
|
this_axis++;
|
|
|
|
} else if (ndslice->type == INPUT_SLICE_TYPE_SLICE) {
|
|
|
|
// Handle when the ndslice is a slice (represented by UserSlice in IRRT)
|
|
|
|
// e.g., `my_array[::2, -5, ::-1]`
|
|
|
|
// ^^^------^^^^----- like these
|
|
|
|
UserSlice<SizeT>* user_slice = (UserSlice<SizeT>*) ndslice->slice;
|
|
|
|
Slice<SizeT> slice = user_slice->indices(this->shape[this_axis]); // To resolve negative indices and other funny stuff written by the user
|
|
|
|
|
|
|
|
// NOTE: There is no need to write special code to handle negative steps/strides.
|
|
|
|
// This simple implementation meticulously handles both positive and negative steps/strides.
|
|
|
|
// Check out the tinynumpy and IRRT's test cases if you are not convinced.
|
|
|
|
dst_ndarray->data += slice.start * this->strides[this_axis]; // Add offset (NOTE: no need to `* itemsize`, strides count in # of bytes)
|
|
|
|
dst_ndarray->strides[dst_axis] = slice.step * this->strides[this_axis]; // Determine stride
|
|
|
|
dst_ndarray->shape[dst_axis] = slice.len(); // Determine shape dimension
|
|
|
|
|
|
|
|
// Next
|
|
|
|
dst_axis++;
|
|
|
|
this_axis++;
|
|
|
|
} else {
|
|
|
|
__builtin_unreachable();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
irrt_assert(dst_axis == dst_ndarray->ndims); // Sanity check on the implementation
|
2024-07-10 11:56:31 +08:00
|
|
|
}
|
2024-07-10 17:05:01 +08:00
|
|
|
|
|
|
|
// Similar to `np.broadcast_to(<ndarray>, <target_shape>)`
|
|
|
|
// Assumptions:
|
|
|
|
// - `this` has to be fully initialized.
|
|
|
|
// - `dst_ndarray->ndims` has to be set.
|
|
|
|
// - `dst_ndarray->shape` has to be set, this determines the shape `this` broadcasts to.
|
|
|
|
//
|
|
|
|
// Other notes:
|
|
|
|
// - `dst_ndarray->data` does not have to be set, it will be set to `this->data`.
|
|
|
|
// - `dst_ndarray->itemsize` does not have to be set, it will be set to `this->data`.
|
|
|
|
// - `dst_ndarray->strides` does not have to be set, it will be overwritten.
|
|
|
|
//
|
|
|
|
// Cautions:
|
|
|
|
// ```
|
|
|
|
// xs = np.zeros((4,))
|
|
|
|
// ys = np.zero((4, 1))
|
|
|
|
// ys[:] = xs # ok
|
|
|
|
//
|
|
|
|
// xs = np.zeros((1, 4))
|
|
|
|
// ys = np.zero((4,))
|
|
|
|
// ys[:] = xs # allowed
|
|
|
|
// # However `np.broadcast_to(xs, (4,))` would fails, as per numpy's broadcasting rule.
|
|
|
|
// # and apparently numpy will "deprecate" this? SEE https://github.com/numpy/numpy/issues/21744
|
|
|
|
// # This implementation will NOT support this assignment.
|
|
|
|
// ```
|
|
|
|
void broadcast_to(NDArray<SizeT>* dst_ndarray) {
|
|
|
|
dst_ndarray->data = this->data;
|
|
|
|
dst_ndarray->itemsize = this->itemsize;
|
|
|
|
|
|
|
|
irrt_assert(
|
|
|
|
ndarray_util::can_broadcast_shape_to(
|
|
|
|
dst_ndarray->ndims,
|
|
|
|
dst_ndarray->shape,
|
|
|
|
this->ndims,
|
|
|
|
this->shape
|
|
|
|
)
|
|
|
|
);
|
|
|
|
|
|
|
|
SizeT stride_product = 1;
|
|
|
|
for (SizeT i = 0; i < max(this->ndims, dst_ndarray->ndims); i++) {
|
|
|
|
SizeT this_dim_i = this->ndims - i - 1;
|
|
|
|
SizeT dst_dim_i = dst_ndarray->ndims - i - 1;
|
|
|
|
|
|
|
|
bool this_dim_exists = this_dim_i >= 0;
|
|
|
|
bool dst_dim_exists = dst_dim_i >= 0;
|
|
|
|
|
|
|
|
// TODO: Explain how this works
|
|
|
|
bool c1 = this_dim_exists && this->shape[this_dim_i] == 1;
|
|
|
|
bool c2 = dst_dim_exists && dst_ndarray->shape[dst_dim_i] != 1;
|
|
|
|
if (!this_dim_exists || (c1 && c2)) {
|
|
|
|
dst_ndarray->strides[dst_dim_i] = 0; // Freeze it in-place
|
|
|
|
} else {
|
|
|
|
dst_ndarray->strides[dst_dim_i] = stride_product * this->itemsize;
|
|
|
|
stride_product *= this->shape[this_dim_i]; // NOTE: this_dim_exist must be true here.
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
2024-07-10 17:27:10 +08:00
|
|
|
|
|
|
|
// Simulates `this_ndarray[:] = src_ndarray`, with automatic broadcasting.
|
|
|
|
// Caution on https://github.com/numpy/numpy/issues/21744
|
|
|
|
// Also see `NDArray::broadcast_to`
|
|
|
|
void assign_with(NDArray<SizeT>* src_ndarray) {
|
|
|
|
irrt_assert(
|
|
|
|
ndarray_util::can_broadcast_shape_to(
|
|
|
|
this->ndims,
|
|
|
|
this->shape,
|
|
|
|
src_ndarray->ndims,
|
|
|
|
src_ndarray->shape
|
|
|
|
)
|
|
|
|
);
|
|
|
|
|
|
|
|
// Broadcast the `src_ndarray` to make the reading process *much* easier
|
|
|
|
SizeT* broadcasted_src_ndarray_strides = __builtin_alloca(sizeof(SizeT) * this->ndims); // Remember to allocate strides beforehand
|
|
|
|
NDArray<SizeT> broadcasted_src_ndarray = {
|
|
|
|
.ndims = this->ndims,
|
|
|
|
.shape = this->shape,
|
|
|
|
.strides = broadcasted_src_ndarray_strides
|
|
|
|
};
|
|
|
|
src_ndarray->broadcast_to(&broadcasted_src_ndarray);
|
|
|
|
|
|
|
|
// Using iter instead of `get_nth_pelement` because it is slightly faster
|
|
|
|
SizeT* indices = __builtin_alloca(sizeof(SizeT) * this->ndims);
|
|
|
|
auto iter = NDArrayIndicesIter<SizeT> {
|
|
|
|
.ndims = this->ndims,
|
|
|
|
.shape = this->shape,
|
|
|
|
.indices = indices
|
|
|
|
};
|
|
|
|
const SizeT this_size = this->size();
|
|
|
|
for (SizeT i = 0; i < this_size; i++, iter.next()) {
|
|
|
|
uint8_t* src_pelement = broadcasted_src_ndarray_strides->get_pelement(indices);
|
|
|
|
uint8_t* this_pelement = this->get_pelement(indices);
|
|
|
|
this->set_value_at_pelement(src_pelement, src_pelement);
|
|
|
|
}
|
|
|
|
}
|
2024-07-10 11:56:31 +08:00
|
|
|
};
|
|
|
|
}
|
|
|
|
|
|
|
|
extern "C" {
|
|
|
|
uint32_t __nac3_ndarray_size(NDArray<int32_t>* ndarray) {
|
|
|
|
return ndarray->size();
|
|
|
|
}
|
|
|
|
|
|
|
|
uint64_t __nac3_ndarray_size64(NDArray<int64_t>* ndarray) {
|
|
|
|
return ndarray->size();
|
|
|
|
}
|
|
|
|
|
|
|
|
void __nac3_ndarray_fill_generic(NDArray<int32_t>* ndarray, uint8_t* pvalue) {
|
|
|
|
ndarray->fill_generic(pvalue);
|
|
|
|
}
|
|
|
|
|
|
|
|
void __nac3_ndarray_fill_generic64(NDArray<int64_t>* ndarray, uint8_t* pvalue) {
|
|
|
|
ndarray->fill_generic(pvalue);
|
|
|
|
}
|
|
|
|
|
|
|
|
// void __nac3_ndarray_slice(NDArray<int32_t>* ndarray, int32_t num_slices, NDSlice<int32_t> *slices, NDArray<int32_t> *dst_ndarray) {
|
|
|
|
// // ndarray->slice(num_slices, slices, dst_ndarray);
|
|
|
|
// }
|
|
|
|
}
|