forked from M-Labs/nac3
core: irrt split ndarray.hpp
This commit is contained in:
parent
3344a2bcd3
commit
d92cccb85e
|
@ -1,155 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <irrt/int_defs.hpp>
|
|
||||||
#include <irrt/numpy/ndarray_util.hpp>
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
// The NDArray object. `SizeT` is the *signed* size type of this ndarray.
|
|
||||||
//
|
|
||||||
// NOTE: The order of fields is IMPORTANT. DON'T TOUCH IT
|
|
||||||
//
|
|
||||||
// Some resources you might find helpful:
|
|
||||||
// - The official numpy implementations:
|
|
||||||
// - https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst
|
|
||||||
// - On strides (about reshaping, slicing, C-contagiousness, etc)
|
|
||||||
// - https://ajcr.net/stride-guide-part-1/.
|
|
||||||
// - https://ajcr.net/stride-guide-part-2/.
|
|
||||||
// - https://ajcr.net/stride-guide-part-3/.
|
|
||||||
template <typename SizeT>
|
|
||||||
struct NDArray {
|
|
||||||
// The underlying data this `ndarray` is pointing to.
|
|
||||||
//
|
|
||||||
// NOTE: Formally this should be of type `void *`, but clang
|
|
||||||
// translates `void *` to `i8 *` when run with `-S -emit-llvm`,
|
|
||||||
// so we will put `uint8_t *` here for clarity.
|
|
||||||
//
|
|
||||||
// This pointer should point to the first element of the ndarray directly
|
|
||||||
uint8_t *data;
|
|
||||||
|
|
||||||
// The number of bytes of a single element in `data`.
|
|
||||||
//
|
|
||||||
// The `SizeT` is treated as `unsigned`.
|
|
||||||
SizeT itemsize;
|
|
||||||
|
|
||||||
// The number of dimensions of this shape.
|
|
||||||
//
|
|
||||||
// The `SizeT` is treated as `unsigned`.
|
|
||||||
SizeT ndims;
|
|
||||||
|
|
||||||
// Array shape, with length equal to `ndims`.
|
|
||||||
//
|
|
||||||
// The `SizeT` is treated as `unsigned`.
|
|
||||||
//
|
|
||||||
// NOTE: `shape` can contain 0.
|
|
||||||
// (those appear when the user makes an out of bounds slice into an ndarray, e.g., `np.zeros((3, 3))[400:].shape == (0, 3)`)
|
|
||||||
SizeT *shape;
|
|
||||||
|
|
||||||
// Array strides (stride value is in number of bytes, NOT number of elements), with length equal to `ndims`.
|
|
||||||
//
|
|
||||||
// The `SizeT` is treated as `signed`.
|
|
||||||
//
|
|
||||||
// NOTE: `strides` can have negative numbers.
|
|
||||||
// (those appear when there is a slice with a negative step, e.g., `my_array[::-1]`)
|
|
||||||
SizeT *strides;
|
|
||||||
|
|
||||||
// Calculate the size/# of elements of an `ndarray`.
|
|
||||||
// This function corresponds to `np.size(<ndarray>)` or `ndarray.size`
|
|
||||||
SizeT size() {
|
|
||||||
return ndarray_util::calc_size_from_shape(ndims, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calculate the number of bytes of its content of an `ndarray` *in its view*.
|
|
||||||
// This function corresponds to `ndarray.nbytes`
|
|
||||||
SizeT nbytes() {
|
|
||||||
return this->size() * itemsize;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the strides of the ndarray with `ndarray_util::set_strides_by_shape`
|
|
||||||
void set_strides_by_shape() {
|
|
||||||
ndarray_util::set_strides_by_shape(itemsize, ndims, strides, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint8_t* get_pelement_by_indices(const SizeT *indices) {
|
|
||||||
uint8_t* element = data;
|
|
||||||
for (SizeT dim_i = 0; dim_i < ndims; dim_i++)
|
|
||||||
element += indices[dim_i] * strides[dim_i];
|
|
||||||
return element;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint8_t* get_nth_pelement(SizeT nth) {
|
|
||||||
SizeT* indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * this->ndims);
|
|
||||||
ndarray_util::set_indices_by_nth(this->ndims, this->shape, indices, nth);
|
|
||||||
return get_pelement_by_indices(indices);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the pointer to the nth element of the ndarray as if it were flattened.
|
|
||||||
uint8_t* checked_get_nth_pelement(ErrorContext* errctx, SizeT nth) {
|
|
||||||
SizeT arr_size = this->size();
|
|
||||||
if (!(0 <= nth && nth < arr_size)) {
|
|
||||||
errctx->set_error(
|
|
||||||
errctx->error_ids->index_error,
|
|
||||||
"index {0} is out of bounds, valid range is {1} <= index < {2}",
|
|
||||||
nth, 0, arr_size
|
|
||||||
);
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
return get_nth_pelement(nth);
|
|
||||||
}
|
|
||||||
|
|
||||||
void set_pelement_value(uint8_t* pelement, const uint8_t* pvalue) {
|
|
||||||
__builtin_memcpy(pelement, pvalue, itemsize);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fill the ndarray with a value
|
|
||||||
void fill_generic(const uint8_t* pvalue) {
|
|
||||||
const SizeT size = this->size();
|
|
||||||
for (SizeT i = 0; i < size; i++) {
|
|
||||||
uint8_t* pelement = get_nth_pelement(i); // No need for checked_get_nth_pelement
|
|
||||||
set_pelement_value(pelement, pvalue);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
uint32_t __nac3_ndarray_size(NDArray<int32_t>* ndarray) {
|
|
||||||
return ndarray->size();
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t __nac3_ndarray_size64(NDArray<int64_t>* ndarray) {
|
|
||||||
return ndarray->size();
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t __nac3_ndarray_nbytes(NDArray<int32_t>* ndarray) {
|
|
||||||
return ndarray->nbytes();
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t __nac3_ndarray_nbytes64(NDArray<int64_t>* ndarray) {
|
|
||||||
return ndarray->nbytes();
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_util_assert_shape_no_negative(ErrorContext* errctx, int32_t ndims, int32_t* shape) {
|
|
||||||
ndarray_util::assert_shape_no_negative(errctx, ndims, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_util_assert_shape_no_negative64(ErrorContext* errctx, int64_t ndims, int64_t* shape) {
|
|
||||||
ndarray_util::assert_shape_no_negative(errctx, ndims, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_set_strides_by_shape(NDArray<int32_t>* ndarray) {
|
|
||||||
ndarray->set_strides_by_shape();
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_set_strides_by_shape64(NDArray<int64_t>* ndarray) {
|
|
||||||
ndarray->set_strides_by_shape();
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_fill_generic(NDArray<int32_t>* ndarray, uint8_t* pvalue) {
|
|
||||||
ndarray->fill_generic(pvalue);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_fill_generic64(NDArray<int64_t>* ndarray, uint8_t* pvalue) {
|
|
||||||
ndarray->fill_generic(pvalue);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,151 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <irrt/int_defs.hpp>
|
||||||
|
#include <irrt/error_context.hpp>
|
||||||
|
#include <irrt/numpy/ndarray_def.hpp>
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
namespace ndarray {
|
||||||
|
namespace util {
|
||||||
|
// Throw an error if there is an axis with negative dimension
|
||||||
|
template <typename SizeT>
|
||||||
|
void assert_shape_no_negative(ErrorContext* errctx, SizeT ndims, const SizeT* shape) {
|
||||||
|
for (SizeT axis = 0; axis < ndims; axis++) {
|
||||||
|
if (shape[axis] < 0) {
|
||||||
|
errctx->set_error(
|
||||||
|
errctx->error_ids->value_error,
|
||||||
|
"negative dimensions are not allowed; axis {0} has dimension {1}",
|
||||||
|
axis, shape[axis]
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the size/# of elements of an ndarray given its shape
|
||||||
|
template <typename SizeT>
|
||||||
|
SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) {
|
||||||
|
SizeT size = 1;
|
||||||
|
for (SizeT axis = 0; axis < ndims; axis++) size *= shape[axis];
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the strides of an ndarray given an ndarray `shape`
|
||||||
|
// and assuming that the ndarray is *fully C-contagious*.
|
||||||
|
//
|
||||||
|
// You might want to read up on https://ajcr.net/stride-guide-part-1/.
|
||||||
|
//
|
||||||
|
// This function might be used in isolation without an ndarray. That's
|
||||||
|
// why it separated out into its own util function.
|
||||||
|
template <typename SizeT>
|
||||||
|
void set_strides_by_shape(SizeT itemsize, SizeT ndims, SizeT* dst_strides, const SizeT* shape) {
|
||||||
|
SizeT stride_product = 1;
|
||||||
|
for (SizeT i = 0; i < ndims; i++) {
|
||||||
|
int axis = ndims - i - 1;
|
||||||
|
dst_strides[axis] = stride_product * itemsize;
|
||||||
|
stride_product *= shape[axis];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SizeT>
|
||||||
|
void set_indices_by_nth(SizeT ndims, const SizeT* shape, SizeT* indices, SizeT nth) {
|
||||||
|
for (int32_t i = 0; i < ndims; i++) {
|
||||||
|
int32_t axis = ndims - i - 1;
|
||||||
|
int32_t dim = shape[axis];
|
||||||
|
|
||||||
|
indices[axis] = nth % dim;
|
||||||
|
nth /= dim;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate the size/# of elements of an `ndarray`.
|
||||||
|
// This function corresponds to `np.size(<ndarray>)` or `ndarray.size`
|
||||||
|
template <typename SizeT>
|
||||||
|
SizeT size(NDArray<SizeT>* ndarray) {
|
||||||
|
return ndarray::util::calc_size_from_shape(ndarray->ndims, ndarray->shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate the number of bytes of its content of an `ndarray` *in its view*.
|
||||||
|
// This function corresponds to `ndarray.nbytes`
|
||||||
|
template <typename SizeT>
|
||||||
|
SizeT nbytes(NDArray<SizeT>* ndarray) {
|
||||||
|
return ndarray::size(ndarray) * ndarray->itemsize;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the strides of the ndarray with `ndarray_util::set_strides_by_shape`
|
||||||
|
template <typename SizeT>
|
||||||
|
void set_strides_by_shape(NDArray<SizeT>* ndarray) {
|
||||||
|
ndarray::util::set_strides_by_shape(ndarray->itemsize, ndarray->ndims, ndarray->strides, ndarray->shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SizeT>
|
||||||
|
uint8_t* get_pelement_by_indices(NDArray<SizeT>* ndarray, const SizeT *indices) {
|
||||||
|
uint8_t* element = ndarray->data;
|
||||||
|
for (SizeT dim_i = 0; dim_i < ndarray->ndims; dim_i++)
|
||||||
|
element += indices[dim_i] * ndarray->strides[dim_i];
|
||||||
|
return element;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SizeT>
|
||||||
|
uint8_t* get_nth_pelement(NDArray<SizeT>* ndarray, SizeT nth) {
|
||||||
|
SizeT* indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * ndarray->ndims);
|
||||||
|
ndarray::util::set_indices_by_nth(ndarray->ndims, ndarray->shape, indices, nth);
|
||||||
|
return ndarray::get_pelement_by_indices(ndarray, indices);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the pointer to the nth element of the ndarray as if it were flattened.
|
||||||
|
template <typename SizeT>
|
||||||
|
uint8_t* checked_get_nth_pelement(NDArray<SizeT>* ndarray, ErrorContext* errctx, SizeT nth) {
|
||||||
|
SizeT arr_size = ndarray->size();
|
||||||
|
if (!(0 <= nth && nth < arr_size)) {
|
||||||
|
errctx->set_error(
|
||||||
|
errctx->error_ids->index_error,
|
||||||
|
"index {0} is out of bounds, valid range is {1} <= index < {2}",
|
||||||
|
nth, 0, arr_size
|
||||||
|
);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
return ndarray::get_nth_pelement(ndarray, nth);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SizeT>
|
||||||
|
void set_pelement_value(NDArray<SizeT>* ndarray, uint8_t* pelement, const uint8_t* pvalue) {
|
||||||
|
__builtin_memcpy(pelement, pvalue, ndarray->itemsize);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
uint32_t __nac3_ndarray_size(NDArray<int32_t>* ndarray) {
|
||||||
|
return ndarray::size(ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t __nac3_ndarray_size64(NDArray<int64_t>* ndarray) {
|
||||||
|
return ndarray::size(ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t __nac3_ndarray_nbytes(NDArray<int32_t>* ndarray) {
|
||||||
|
return ndarray::nbytes(ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t __nac3_ndarray_nbytes64(NDArray<int64_t>* ndarray) {
|
||||||
|
return ndarray::nbytes(ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_util_assert_shape_no_negative(ErrorContext* errctx, int32_t ndims, int32_t* shape) {
|
||||||
|
ndarray::util::assert_shape_no_negative(errctx, ndims, shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_util_assert_shape_no_negative64(ErrorContext* errctx, int64_t ndims, int64_t* shape) {
|
||||||
|
ndarray::util::assert_shape_no_negative(errctx, ndims, shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_set_strides_by_shape(NDArray<int32_t>* ndarray) {
|
||||||
|
ndarray::set_strides_by_shape(ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_set_strides_by_shape64(NDArray<int64_t>* ndarray) {
|
||||||
|
ndarray::set_strides_by_shape(ndarray);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,58 @@
|
||||||
|
#include <irrt/numpy/ndarray_def.hpp>
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
namespace ndarray {
|
||||||
|
namespace util {
|
||||||
|
template <typename SizeT>
|
||||||
|
bool can_broadcast_shape_to(
|
||||||
|
const SizeT target_ndims,
|
||||||
|
const SizeT *target_shape,
|
||||||
|
const SizeT src_ndims,
|
||||||
|
const SizeT *src_shape
|
||||||
|
) {
|
||||||
|
/*
|
||||||
|
// See https://numpy.org/doc/stable/user/basics.broadcasting.html
|
||||||
|
|
||||||
|
This function handles this example:
|
||||||
|
```
|
||||||
|
Image (3d array): 256 x 256 x 3
|
||||||
|
Scale (1d array): 3
|
||||||
|
Result (3d array): 256 x 256 x 3
|
||||||
|
```
|
||||||
|
|
||||||
|
Other interesting examples to consider:
|
||||||
|
- `can_broadcast_shape_to([3], [1, 1, 1, 1, 3]) == true`
|
||||||
|
- `can_broadcast_shape_to([3], [3, 1]) == false`
|
||||||
|
- `can_broadcast_shape_to([256, 256, 3], [256, 1, 3]) == true`
|
||||||
|
|
||||||
|
In cases when the shapes contain zero(es):
|
||||||
|
- `can_broadcast_shape_to([0], [1]) == true`
|
||||||
|
- `can_broadcast_shape_to([0], [2]) == false`
|
||||||
|
- `can_broadcast_shape_to([0, 4, 0, 0], [1]) == true`
|
||||||
|
- `can_broadcast_shape_to([0, 4, 0, 0], [1, 1, 1, 1]) == true`
|
||||||
|
- `can_broadcast_shape_to([0, 4, 0, 0], [1, 4, 1, 1]) == true`
|
||||||
|
- `can_broadcast_shape_to([4, 3], [0, 3]) == false`
|
||||||
|
- `can_broadcast_shape_to([4, 3], [0, 0]) == false`
|
||||||
|
*/
|
||||||
|
|
||||||
|
// This is essentially doing the following in Python:
|
||||||
|
// `for target_dim, src_dim in itertools.zip_longest(target_shape[::-1], src_shape[::-1], fillvalue=1)`
|
||||||
|
for (SizeT i = 0; i < max(target_ndims, src_ndims); i++) {
|
||||||
|
SizeT target_axis = target_ndims - i - 1;
|
||||||
|
SizeT src_axis = src_ndims - i - 1;
|
||||||
|
|
||||||
|
bool target_dim_exists = target_axis >= 0;
|
||||||
|
bool src_dim_exists = src_axis >= 0;
|
||||||
|
|
||||||
|
SizeT target_dim = target_dim_exists ? target_shape[target_axis] : 1;
|
||||||
|
SizeT src_dim = src_dim_exists ? src_shape[src_axis] : 1;
|
||||||
|
|
||||||
|
bool ok = src_dim == 1 || target_dim == src_dim;
|
||||||
|
if (!ok) return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,52 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// The NDArray object. `SizeT` is the *signed* size type of this ndarray.
|
||||||
|
//
|
||||||
|
// NOTE: The order of fields is IMPORTANT. DON'T TOUCH IT
|
||||||
|
//
|
||||||
|
// Some resources you might find helpful:
|
||||||
|
// - The official numpy implementations:
|
||||||
|
// - https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst
|
||||||
|
// - On strides (about reshaping, slicing, C-contagiousness, etc)
|
||||||
|
// - https://ajcr.net/stride-guide-part-1/.
|
||||||
|
// - https://ajcr.net/stride-guide-part-2/.
|
||||||
|
// - https://ajcr.net/stride-guide-part-3/.
|
||||||
|
template <typename SizeT>
|
||||||
|
struct NDArray {
|
||||||
|
// The underlying data this `ndarray` is pointing to.
|
||||||
|
//
|
||||||
|
// NOTE: Formally this should be of type `void *`, but clang
|
||||||
|
// translates `void *` to `i8 *` when run with `-S -emit-llvm`,
|
||||||
|
// so we will put `uint8_t *` here for clarity.
|
||||||
|
//
|
||||||
|
// This pointer should point to the first element of the ndarray directly
|
||||||
|
uint8_t *data;
|
||||||
|
|
||||||
|
// The number of bytes of a single element in `data`.
|
||||||
|
//
|
||||||
|
// The `SizeT` is treated as `unsigned`.
|
||||||
|
SizeT itemsize;
|
||||||
|
|
||||||
|
// The number of dimensions of this shape.
|
||||||
|
//
|
||||||
|
// The `SizeT` is treated as `unsigned`.
|
||||||
|
SizeT ndims;
|
||||||
|
|
||||||
|
// Array shape, with length equal to `ndims`.
|
||||||
|
//
|
||||||
|
// The `SizeT` is treated as `unsigned`.
|
||||||
|
//
|
||||||
|
// NOTE: `shape` can contain 0.
|
||||||
|
// (those appear when the user makes an out of bounds slice into an ndarray, e.g., `np.zeros((3, 3))[400:].shape == (0, 3)`)
|
||||||
|
SizeT *shape;
|
||||||
|
|
||||||
|
// Array strides (stride value is in number of bytes, NOT number of elements), with length equal to `ndims`.
|
||||||
|
//
|
||||||
|
// The `SizeT` is treated as `signed`.
|
||||||
|
//
|
||||||
|
// NOTE: `strides` can have negative numbers.
|
||||||
|
// (those appear when there is a slice with a negative step, e.g., `my_array[::-1]`)
|
||||||
|
SizeT *strides;
|
||||||
|
};
|
||||||
|
}
|
|
@ -0,0 +1,28 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <irrt/numpy/ndarray_def.hpp>
|
||||||
|
#include <irrt/numpy/ndarray_basic.hpp>
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
namespace ndarray {
|
||||||
|
// Fill the ndarray with a value
|
||||||
|
template <typename SizeT>
|
||||||
|
void fill_generic(NDArray<SizeT>* ndarray, const uint8_t* pvalue) {
|
||||||
|
const SizeT size = ndarray::size(ndarray);
|
||||||
|
for (SizeT i = 0; i < size; i++) {
|
||||||
|
uint8_t* pelement = ndarray::get_nth_pelement(ndarray, i); // No need for checked_get_nth_pelement
|
||||||
|
ndarray::set_pelement_value(ndarray, pelement, pvalue);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
void __nac3_ndarray_fill_generic(NDArray<int32_t>* ndarray, uint8_t* pvalue) {
|
||||||
|
ndarray::fill_generic(ndarray, pvalue);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_fill_generic64(NDArray<int64_t>* ndarray, uint8_t* pvalue) {
|
||||||
|
ndarray::fill_generic(ndarray, pvalue);
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,107 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <irrt/int_defs.hpp>
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
namespace ndarray_util {
|
|
||||||
|
|
||||||
// Throw an error if there is an axis with negative dimension
|
|
||||||
template <typename SizeT>
|
|
||||||
void assert_shape_no_negative(ErrorContext* errctx, SizeT ndims, const SizeT* shape) {
|
|
||||||
for (SizeT axis = 0; axis < ndims; axis++) {
|
|
||||||
if (shape[axis] < 0) {
|
|
||||||
errctx->set_error(
|
|
||||||
errctx->error_ids->value_error,
|
|
||||||
"negative dimensions are not allowed; axis {0} has dimension {1}",
|
|
||||||
axis, shape[axis]
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compute the size/# of elements of an ndarray given its shape
|
|
||||||
template <typename SizeT>
|
|
||||||
SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) {
|
|
||||||
SizeT size = 1;
|
|
||||||
for (SizeT axis = 0; axis < ndims; axis++) size *= shape[axis];
|
|
||||||
return size;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compute the strides of an ndarray given an ndarray `shape`
|
|
||||||
// and assuming that the ndarray is *fully C-contagious*.
|
|
||||||
//
|
|
||||||
// You might want to read up on https://ajcr.net/stride-guide-part-1/.
|
|
||||||
template <typename SizeT>
|
|
||||||
void set_strides_by_shape(SizeT itemsize, SizeT ndims, SizeT* dst_strides, const SizeT* shape) {
|
|
||||||
SizeT stride_product = 1;
|
|
||||||
for (SizeT i = 0; i < ndims; i++) {
|
|
||||||
int axis = ndims - i - 1;
|
|
||||||
dst_strides[axis] = stride_product * itemsize;
|
|
||||||
stride_product *= shape[axis];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename SizeT>
|
|
||||||
void set_indices_by_nth(SizeT ndims, const SizeT* shape, SizeT* indices, SizeT nth) {
|
|
||||||
for (int32_t i = 0; i < ndims; i++) {
|
|
||||||
int32_t axis = ndims - i - 1;
|
|
||||||
int32_t dim = shape[axis];
|
|
||||||
|
|
||||||
indices[axis] = nth % dim;
|
|
||||||
nth /= dim;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename SizeT>
|
|
||||||
bool can_broadcast_shape_to(
|
|
||||||
const SizeT target_ndims,
|
|
||||||
const SizeT *target_shape,
|
|
||||||
const SizeT src_ndims,
|
|
||||||
const SizeT *src_shape
|
|
||||||
) {
|
|
||||||
/*
|
|
||||||
// See https://numpy.org/doc/stable/user/basics.broadcasting.html
|
|
||||||
|
|
||||||
This function handles this example:
|
|
||||||
```
|
|
||||||
Image (3d array): 256 x 256 x 3
|
|
||||||
Scale (1d array): 3
|
|
||||||
Result (3d array): 256 x 256 x 3
|
|
||||||
```
|
|
||||||
|
|
||||||
Other interesting examples to consider:
|
|
||||||
- `can_broadcast_shape_to([3], [1, 1, 1, 1, 3]) == true`
|
|
||||||
- `can_broadcast_shape_to([3], [3, 1]) == false`
|
|
||||||
- `can_broadcast_shape_to([256, 256, 3], [256, 1, 3]) == true`
|
|
||||||
|
|
||||||
In cases when the shapes contain zero(es):
|
|
||||||
- `can_broadcast_shape_to([0], [1]) == true`
|
|
||||||
- `can_broadcast_shape_to([0], [2]) == false`
|
|
||||||
- `can_broadcast_shape_to([0, 4, 0, 0], [1]) == true`
|
|
||||||
- `can_broadcast_shape_to([0, 4, 0, 0], [1, 1, 1, 1]) == true`
|
|
||||||
- `can_broadcast_shape_to([0, 4, 0, 0], [1, 4, 1, 1]) == true`
|
|
||||||
- `can_broadcast_shape_to([4, 3], [0, 3]) == false`
|
|
||||||
- `can_broadcast_shape_to([4, 3], [0, 0]) == false`
|
|
||||||
*/
|
|
||||||
|
|
||||||
// This is essentially doing the following in Python:
|
|
||||||
// `for target_dim, src_dim in itertools.zip_longest(target_shape[::-1], src_shape[::-1], fillvalue=1)`
|
|
||||||
for (SizeT i = 0; i < max(target_ndims, src_ndims); i++) {
|
|
||||||
SizeT target_axis = target_ndims - i - 1;
|
|
||||||
SizeT src_axis = src_ndims - i - 1;
|
|
||||||
|
|
||||||
bool target_dim_exists = target_axis >= 0;
|
|
||||||
bool src_dim_exists = src_axis >= 0;
|
|
||||||
|
|
||||||
SizeT target_dim = target_dim_exists ? target_shape[target_axis] : 1;
|
|
||||||
SizeT src_dim = src_dim_exists ? src_shape[src_axis] : 1;
|
|
||||||
|
|
||||||
bool ok = src_dim == 1 || target_dim == src_dim;
|
|
||||||
if (!ok) return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -3,6 +3,8 @@
|
||||||
#include <irrt/core.hpp>
|
#include <irrt/core.hpp>
|
||||||
#include <irrt/error_context.hpp>
|
#include <irrt/error_context.hpp>
|
||||||
#include <irrt/int_defs.hpp>
|
#include <irrt/int_defs.hpp>
|
||||||
#include <irrt/numpy/ndarray.hpp>
|
#include <irrt/numpy/ndarray_def.hpp>
|
||||||
#include <irrt/numpy/ndarray_util.hpp>
|
#include <irrt/numpy/ndarray_basic.hpp>
|
||||||
|
#include <irrt/numpy/ndarray_broadcast.hpp>
|
||||||
|
#include <irrt/numpy/ndarray_fill.hpp>
|
||||||
#include <irrt/utils.hpp>
|
#include <irrt/utils.hpp>
|
Loading…
Reference in New Issue