forked from M-Labs/nac3
196 lines
6.0 KiB
C++
196 lines
6.0 KiB
C++
|
#pragma once
|
||
|
|
||
|
#include "irrt_utils.hpp"
|
||
|
#include "irrt_typedefs.hpp"
|
||
|
|
||
|
/*
|
||
|
NDArray-related implementations.
|
||
|
*/
|
||
|
|
||
|
// Fixed-width index type for ndarrays: an unsigned 32-bit integer.
// NOTE(review): the templated `NDArray` / `NDArrayIndicesIter` code in this
// file indexes with its `SizeT` parameter rather than with `NDIndex`;
// confirm where this alias is actually consumed.
using NDIndex = uint32_t;
|
||
|
|
||
|
namespace {
|
||
|
namespace ndarray_util {
// Compute the strides of an ndarray given an ndarray `shape`
// and assuming that the ndarray is *fully C-contiguous*.
//
// The values written to `dst_strides` (length `ndims`) are counted in
// ELEMENTS: the innermost (last) axis gets stride 1, and each outer axis gets
// the product of all inner extents. They are not scaled by any item size.
//
// You might want to read up on https://ajcr.net/stride-guide-part-1/.
template <typename SizeT>
static void set_strides_by_shape(SizeT ndims, SizeT* dst_strides, const SizeT* shape) {
    SizeT stride_product = 1;
    for (SizeT i = 0; i < ndims; i++) {
        // Walk from the innermost axis outwards. The axis index is kept as
        // `SizeT` (previously `int`) so 64-bit `SizeT` values are not narrowed.
        SizeT dim_i = ndims - i - 1;
        dst_strides[dim_i] = stride_product;
        stride_product *= shape[dim_i];
    }
}

// Compute the size/# of elements of an ndarray given its shape.
//
// This is the product of the first `ndims` entries of `shape`: an empty shape
// (`ndims == 0`) yields 1, and any zero extent yields 0.
template <typename SizeT>
static SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) {
    SizeT size = 1;
    for (SizeT dim_i = 0; dim_i < ndims; dim_i++) size *= shape[dim_i];
    return size;
}
}
|
||
|
|
||
|
// Row-major ("odometer") iterator over the index tuples of an ndarray.
//
// `indices` is a caller-provided buffer of `ndims` entries. Each `next()`
// advances the last axis first, carrying into earlier axes when an axis
// reaches its extent in `shape`. Incrementing past the final tuple wraps
// back to all zeros.
template <typename SizeT>
struct NDArrayIndicesIter {
    SizeT ndims;
    const SizeT *shape;
    SizeT *indices;

    // Reset every entry of `indices` to 0.
    void set_indices_zero() {
        const auto nbytes = sizeof(SizeT) * ndims;
        __builtin_memset(indices, 0, nbytes);
    }

    // Advance `indices` to the next tuple in row-major order.
    void next() {
        SizeT axis = ndims;
        while (axis > 0) {
            axis--;
            indices[axis]++;
            if (indices[axis] < shape[axis]) {
                return;  // no carry needed
            }
            indices[axis] = 0;  // carry into the next outer axis
        }
    }
};
|
||
|
|
||
|
// The NDArray object. `SizeT` is the *signed* size type of this ndarray.
|
||
|
//
|
||
|
// NOTE: The order of fields is IMPORTANT. DON'T TOUCH IT
|
||
|
//
|
||
|
// Some resources you might find helpful:
|
||
|
// - The official numpy implementations:
|
||
|
// - https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst
|
||
|
// - On strides (about reshaping, slicing, C-contagiousness, etc)
|
||
|
// - https://ajcr.net/stride-guide-part-1/.
|
||
|
// - https://ajcr.net/stride-guide-part-2/.
|
||
|
// - https://ajcr.net/stride-guide-part-3/.
|
||
|
template <typename SizeT>
|
||
|
struct NDArray {
|
||
|
// The underlying data this `ndarray` is pointing to.
|
||
|
//
|
||
|
// NOTE: Formally this should be of type `void *`, but clang
|
||
|
// translates `void *` to `i8 *` when run with `-S -emit-llvm`,
|
||
|
// so we will put `uint8_t *` here for clarity.
|
||
|
uint8_t *data;
|
||
|
|
||
|
// The number of bytes of a single element in `data`.
|
||
|
//
|
||
|
// The `SizeT` is treated as `unsigned`.
|
||
|
SizeT itemsize;
|
||
|
|
||
|
// The number of dimensions of this shape.
|
||
|
//
|
||
|
// The `SizeT` is treated as `unsigned`.
|
||
|
SizeT ndims;
|
||
|
|
||
|
// Array shape, with length equal to `ndims`.
|
||
|
//
|
||
|
// The `SizeT` is treated as `unsigned`.
|
||
|
//
|
||
|
// NOTE: `shape` can contain 0.
|
||
|
// (those appear when the user makes an out of bounds slice into an ndarray, e.g., `np.zeros((3, 3))[400:].shape == (0, 3)`)
|
||
|
SizeT *shape;
|
||
|
|
||
|
// Array strides (stride value is in number of bytes, NOT number of elements), with length equal to `ndims`.
|
||
|
//
|
||
|
// The `SizeT` is treated as `signed`.
|
||
|
//
|
||
|
// NOTE: `strides` can have negative numbers.
|
||
|
// (those appear when there is a slice with a negative step, e.g., `my_array[::-1]`)
|
||
|
SizeT *strides;
|
||
|
|
||
|
// Calculate the size/# of elements of an `ndarray`.
|
||
|
// This function corresponds to `np.size(<ndarray>)` or `ndarray.size`
|
||
|
SizeT size() {
|
||
|
return ndarray_util::calc_size_from_shape(ndims, shape);
|
||
|
}
|
||
|
|
||
|
// Calculate the number of bytes of its content of an `ndarray` *in its view*.
|
||
|
// This function corresponds to `ndarray.nbytes`
|
||
|
SizeT nbytes() {
|
||
|
return this->size() * itemsize;
|
||
|
}
|
||
|
|
||
|
void set_value_at_pelement(uint8_t* pelement, uint8_t* pvalue) {
|
||
|
__builtin_memcpy(pelement, pvalue, itemsize);
|
||
|
}
|
||
|
|
||
|
uint8_t* get_pelement(SizeT *indices) {
|
||
|
uint8_t* element = data;
|
||
|
for (SizeT dim_i = 0; dim_i < ndims; dim_i++)
|
||
|
element += indices[dim_i] * strides[dim_i] * itemsize;
|
||
|
return element;
|
||
|
}
|
||
|
|
||
|
// Is the given `indices` valid/in-bounds?
|
||
|
bool in_bounds(SizeT *indices) {
|
||
|
for (SizeT dim_i = 0; dim_i < ndims; dim_i++) {
|
||
|
bool dim_ok = indices[dim_i] < shape[dim_i];
|
||
|
if (!dim_ok) return false;
|
||
|
}
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
// Fill the ndarray with a value
|
||
|
void fill_generic(uint8_t* pvalue) {
|
||
|
NDArrayIndicesIter<SizeT> iter;
|
||
|
iter.ndims = this->ndims;
|
||
|
iter.shape = this->shape;
|
||
|
iter.indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * ndims);
|
||
|
iter.set_indices_zero();
|
||
|
|
||
|
for (SizeT i = 0; i < this->size(); i++, iter.next()) {
|
||
|
uint8_t* pelement = get_pelement(iter.indices);
|
||
|
set_value_at_pelement(pelement, pvalue);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Set the strides of the ndarray with `ndarray_util::set_strides_by_shape`
|
||
|
void set_strides_by_shape() {
|
||
|
ndarray_util::set_strides_by_shape(ndims, strides, shape);
|
||
|
}
|
||
|
|
||
|
// https://numpy.org/doc/stable/reference/generated/numpy.eye.html
|
||
|
void set_to_eye(SizeT k, uint8_t* zero_pvalue, uint8_t* one_pvalue) {
|
||
|
__builtin_assume(ndims == 2);
|
||
|
|
||
|
// TODO: Better implementation
|
||
|
|
||
|
fill_generic(zero_pvalue);
|
||
|
for (SizeT i = 0; i < min(shape[0], shape[1]); i++) {
|
||
|
SizeT row = i;
|
||
|
SizeT col = i + k;
|
||
|
SizeT indices[2] = { row, col };
|
||
|
|
||
|
if (!in_bounds(indices)) continue;
|
||
|
|
||
|
uint8_t* pelement = get_pelement(indices);
|
||
|
set_value_at_pelement(pelement, one_pvalue);
|
||
|
}
|
||
|
}
|
||
|
};
|
||
|
}
|
||
|
|
||
|
extern "C" {
|
||
|
uint32_t __nac3_ndarray_size(NDArray<int32_t>* ndarray) {
|
||
|
return ndarray->size();
|
||
|
}
|
||
|
|
||
|
uint64_t __nac3_ndarray_size64(NDArray<int64_t>* ndarray) {
|
||
|
return ndarray->size();
|
||
|
}
|
||
|
|
||
|
void __nac3_ndarray_fill_generic(NDArray<int32_t>* ndarray, uint8_t* pvalue) {
|
||
|
ndarray->fill_generic(pvalue);
|
||
|
}
|
||
|
|
||
|
void __nac3_ndarray_fill_generic64(NDArray<int64_t>* ndarray, uint8_t* pvalue) {
|
||
|
ndarray->fill_generic(pvalue);
|
||
|
}
|
||
|
}
|