forked from M-Labs/nac3
core: irrt general numpy slicing
This commit is contained in:
parent
94c547ee22
commit
c192256b78
|
@ -12,7 +12,6 @@
|
||||||
|
|
||||||
// The type of an index or a value describing the length of a range/slice is
|
// The type of an index or a value describing the length of a range/slice is
|
||||||
// always `int32_t`.
|
// always `int32_t`.
|
||||||
typedef int32_t SliceIndex;
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
|
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
|
||||||
|
|
|
@ -18,11 +18,11 @@ namespace {
|
||||||
//
|
//
|
||||||
// You might want to read up on https://ajcr.net/stride-guide-part-1/.
|
// You might want to read up on https://ajcr.net/stride-guide-part-1/.
|
||||||
template <typename SizeT>
|
template <typename SizeT>
|
||||||
static void set_strides_by_shape(SizeT ndims, SizeT* dst_strides, const SizeT* shape) {
|
static void set_strides_by_shape(SizeT itemsize, SizeT ndims, SizeT* dst_strides, const SizeT* shape) {
|
||||||
SizeT stride_product = 1;
|
SizeT stride_product = 1;
|
||||||
for (SizeT i = 0; i < ndims; i++) {
|
for (SizeT i = 0; i < ndims; i++) {
|
||||||
int dim_i = ndims - i - 1;
|
int dim_i = ndims - i - 1;
|
||||||
dst_strides[dim_i] = stride_product;
|
dst_strides[dim_i] = stride_product * itemsize;
|
||||||
stride_product *= shape[dim_i];
|
stride_product *= shape[dim_i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -38,32 +38,35 @@ namespace {
|
||||||
|
|
||||||
typedef uint8_t NDSliceType;
|
typedef uint8_t NDSliceType;
|
||||||
extern "C" {
|
extern "C" {
|
||||||
const NDSliceType INPUT_SLICE_TYPE_INTEGER = 0;
|
const NDSliceType INPUT_SLICE_TYPE_INDEX = 0;
|
||||||
const NDSliceType INPUT_SLICE_TYPE_SLICE = 1;
|
const NDSliceType INPUT_SLICE_TYPE_SLICE = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct NDSlice {
|
struct NDSlice {
|
||||||
|
// A poor-man's `std::variant<int, UserRange>`
|
||||||
NDSliceType type;
|
NDSliceType type;
|
||||||
|
|
||||||
/*
|
/*
|
||||||
type = INPUT_SLICE_TYPE_INTEGER => `slice` points to a single `SizeT`
|
if type == INPUT_SLICE_TYPE_INDEX => `slice` points to a single `SizeT`
|
||||||
type = INPUT_SLICE_TYPE_SLICE => `slice` points to a single `NDSliceRange`
|
if type == INPUT_SLICE_TYPE_SLICE => `slice` points to a single `UserRange`
|
||||||
*/
|
*/
|
||||||
uint8_t *slice;
|
uint8_t *slice;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
namespace ndarray_util {
|
||||||
template<typename SizeT>
|
template<typename SizeT>
|
||||||
SizeT deduce_ndims_after_slicing(SizeT ndims, const SizeT num_slices, const NDSlice *slices) {
|
SizeT deduce_ndims_after_slicing(SizeT ndims, const SizeT num_slices, const NDSlice *slices) {
|
||||||
nac3_assert(num_slices <= ndims);
|
irrt_assert(num_slices <= ndims);
|
||||||
|
|
||||||
SizeT final_ndims = ndims;
|
SizeT final_ndims = ndims;
|
||||||
for (SizeT i = 0; i < num_slices; i++) {
|
for (SizeT i = 0; i < num_slices; i++) {
|
||||||
if (slices[i].type == INPUT_SLICE_TYPE_INTEGER) {
|
if (slices[i].type == INPUT_SLICE_TYPE_INDEX) {
|
||||||
final_ndims--; // An integer slice demotes the rank by 1
|
final_ndims--; // An integer slice demotes the rank by 1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return final_ndims;
|
return final_ndims;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename SizeT>
|
template <typename SizeT>
|
||||||
struct NDArrayIndicesIter {
|
struct NDArrayIndicesIter {
|
||||||
|
@ -154,10 +157,19 @@ namespace {
|
||||||
uint8_t* get_pelement(SizeT *indices) {
|
uint8_t* get_pelement(SizeT *indices) {
|
||||||
uint8_t* element = data;
|
uint8_t* element = data;
|
||||||
for (SizeT dim_i = 0; dim_i < ndims; dim_i++)
|
for (SizeT dim_i = 0; dim_i < ndims; dim_i++)
|
||||||
element += indices[dim_i] * strides[dim_i] * itemsize;
|
element += indices[dim_i] * strides[dim_i];
|
||||||
return element;
|
return element;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get pointer to the first element of this ndarray, assuming
|
||||||
|
// `this->size() > 0`, i.e., not "degenerate" due to zeroes in `this->shape`)
|
||||||
|
//
|
||||||
|
// This is particularly useful for when the ndarray is just containing a single scalar.
|
||||||
|
uint8_t* get_first_pelement() {
|
||||||
|
irrt_assert(this->size() > 0);
|
||||||
|
return this->data; // ...It is simply `this->data`
|
||||||
|
}
|
||||||
|
|
||||||
// Is the given `indices` valid/in-bounds?
|
// Is the given `indices` valid/in-bounds?
|
||||||
bool in_bounds(SizeT *indices) {
|
bool in_bounds(SizeT *indices) {
|
||||||
for (SizeT dim_i = 0; dim_i < ndims; dim_i++) {
|
for (SizeT dim_i = 0; dim_i < ndims; dim_i++) {
|
||||||
|
@ -183,7 +195,7 @@ namespace {
|
||||||
|
|
||||||
// Set the strides of the ndarray with `ndarray_util::set_strides_by_shape`
|
// Set the strides of the ndarray with `ndarray_util::set_strides_by_shape`
|
||||||
void set_strides_by_shape() {
|
void set_strides_by_shape() {
|
||||||
ndarray_util::set_strides_by_shape(ndims, strides, shape);
|
ndarray_util::set_strides_by_shape(itemsize, ndims, strides, shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
// https://numpy.org/doc/stable/reference/generated/numpy.eye.html
|
// https://numpy.org/doc/stable/reference/generated/numpy.eye.html
|
||||||
|
@ -206,15 +218,62 @@ namespace {
|
||||||
}
|
}
|
||||||
|
|
||||||
// To support numpy complex slices (e.g., `my_array[:50:2,4,:2:-1]`)
|
// To support numpy complex slices (e.g., `my_array[:50:2,4,:2:-1]`)
|
||||||
void slice(SizeT num_slices, NDSlice* slices, NDArray<SizeT>*dst_ndarray) {
|
//
|
||||||
// It is assumed that `dst_ndarray` is allocated by the caller and
|
// Things assumed by this function:
|
||||||
// has the correct `ndims`.
|
// - `dst_ndarray` is allocated by the caller
|
||||||
nac3_assert(dst_ndarray->ndims == deduce_ndims_after_slicing(this->ndims, num_slices, slices));
|
// - `dst_ndarray.ndims` has the correct value (according to `ndarray_util::deduce_ndims_after_slicing`).
|
||||||
|
// - ... and `dst_ndarray.shape` and `dst_ndarray.strides` have been allocated by the caller as well
|
||||||
|
//
|
||||||
|
// Other notes:
|
||||||
|
// - `dst_ndarray->data` does not have to be set, it will be derived.
|
||||||
|
// - `dst_ndarray->itemsize` does not have to be set, it will be set to `this->itemsize`
|
||||||
|
// - `dst_ndarray->shape` and `dst_ndarray.strides` can contain empty values
|
||||||
|
void slice(SizeT num_ndslices, NDSlice* ndslices, NDArray<SizeT>* dst_ndarray) {
|
||||||
|
// REFERENCE CODE (check out `_index_helper` in `__getitem__`):
|
||||||
|
// https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652
|
||||||
|
|
||||||
|
irrt_assert(dst_ndarray->ndims == ndarray_util::deduce_ndims_after_slicing(this->ndims, num_ndslices, ndslices));
|
||||||
|
|
||||||
|
dst_ndarray->data = this->data;
|
||||||
|
|
||||||
SizeT this_axis = 0;
|
SizeT this_axis = 0;
|
||||||
SizeT guest_axis = 0;
|
SizeT dst_axis = 0;
|
||||||
// for () {
|
|
||||||
// }
|
for (SizeT i = 0; i < num_ndslices; i++) {
|
||||||
|
NDSlice *ndslice = &ndslices[i];
|
||||||
|
if (ndslice->type == INPUT_SLICE_TYPE_INDEX) {
|
||||||
|
// Handle when the ndslice is just a single (possibly negative) integer
|
||||||
|
// e.g., `my_array[::2, -5, ::-1]`
|
||||||
|
// ^^------ like this
|
||||||
|
SizeT index_user = *((SizeT*) ndslice->slice);
|
||||||
|
SizeT index = resolve_index_in_length(this->shape[this_axis], index_user);
|
||||||
|
dst_ndarray->data += index * this->strides[this_axis]; // Add offset
|
||||||
|
|
||||||
|
// Next
|
||||||
|
this_axis++;
|
||||||
|
} else if (ndslice->type == INPUT_SLICE_TYPE_SLICE) {
|
||||||
|
// Handle when the ndslice is a slice (represented by UserSlice in IRRT)
|
||||||
|
// e.g., `my_array[::2, -5, ::-1]`
|
||||||
|
// ^^^------^^^^----- like these
|
||||||
|
UserSlice<SizeT>* user_slice = (UserSlice<SizeT>*) ndslice->slice;
|
||||||
|
Slice<SizeT> slice = user_slice->indices(this->shape[this_axis]); // To resolve negative indices and other funny stuff written by the user
|
||||||
|
|
||||||
|
// NOTE: There is no need to write special code to handle negative steps/strides.
|
||||||
|
// This simple implementation meticulously handles both positive and negative steps/strides.
|
||||||
|
// Check out the tinynumpy and IRRT's test cases if you are not convinced.
|
||||||
|
dst_ndarray->data += slice.start * this->strides[this_axis]; // Add offset (NOTE: no need to `* itemsize`, strides count in # of bytes)
|
||||||
|
dst_ndarray->strides[dst_axis] = slice.step * this->strides[this_axis]; // Determine stride
|
||||||
|
dst_ndarray->shape[dst_axis] = slice.len(); // Determine shape dimension
|
||||||
|
|
||||||
|
// Next
|
||||||
|
dst_axis++;
|
||||||
|
this_axis++;
|
||||||
|
} else {
|
||||||
|
__builtin_unreachable();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
irrt_assert(dst_axis == dst_ndarray->ndims); // Sanity check on the implementation
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,16 +5,31 @@
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// A proper slice in IRRT, all negative indices have be resolved to absolute values.
|
// A proper slice in IRRT, all negative indices have be resolved to absolute values.
|
||||||
|
// Even though nac3core's slices are always `int32_t`, we will template slice anyway
|
||||||
|
// since this struct is used as a general utility.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct Slice {
|
struct Slice {
|
||||||
T start;
|
T start;
|
||||||
T stop;
|
T stop;
|
||||||
T step;
|
T step;
|
||||||
|
|
||||||
|
// The length/The number of elements of the slice if it were a range,
|
||||||
|
// i.e., the value of `len(range(this->start, this->stop, this->end))`
|
||||||
|
T len() {
|
||||||
|
T diff = stop - start;
|
||||||
|
if (diff > 0 && step > 0) {
|
||||||
|
return ((diff - 1) / step) + 1;
|
||||||
|
} else if (diff < 0 && step < 0) {
|
||||||
|
return ((diff + 1) / step) + 1;
|
||||||
|
} else {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
T resolve_index_in_length(T length, T index) {
|
T resolve_index_in_length(T length, T index) {
|
||||||
nac3_assert(length >= 0);
|
irrt_assert(length >= 0);
|
||||||
if (index < 0) {
|
if (index < 0) {
|
||||||
// Remember that index is negative, so do a plus here
|
// Remember that index is negative, so do a plus here
|
||||||
return max(length + index, 0);
|
return max(length + index, 0);
|
||||||
|
@ -40,8 +55,8 @@ namespace {
|
||||||
Slice<T> indices(T length) {
|
Slice<T> indices(T length) {
|
||||||
// NOTE: This function implements Python's `slice.indices` *FAITHFULLY*.
|
// NOTE: This function implements Python's `slice.indices` *FAITHFULLY*.
|
||||||
// SEE: https://github.com/python/cpython/blob/f62161837e68c1c77961435f1b954412dd5c2b65/Objects/sliceobject.c#L546
|
// SEE: https://github.com/python/cpython/blob/f62161837e68c1c77961435f1b954412dd5c2b65/Objects/sliceobject.c#L546
|
||||||
nac3_assert(length >= 0);
|
irrt_assert(length >= 0);
|
||||||
nac3_assert(!step_defined || step != 0); // step_defined -> step != 0; step cannot be zero if specified by user
|
irrt_assert(!step_defined || step != 0); // step_defined -> step != 0; step cannot be zero if specified by user
|
||||||
|
|
||||||
Slice<T> result;
|
Slice<T> result;
|
||||||
result.step = step_defined ? step : 1;
|
result.step = step_defined ? step : 1;
|
||||||
|
|
|
@ -1,8 +1,11 @@
|
||||||
|
// This file will be compiled like a real C++ program,
|
||||||
|
// and we do have the luxury to use the standard libraries.
|
||||||
|
// That is if the nix flakes do not have issues... especially on msys2...
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
|
|
||||||
// set `IRRT_DONT_TYPEDEF_INTS` because `cstdint` has it all
|
// Set `IRRT_DONT_TYPEDEF_INTS` because `cstdint` defines them
|
||||||
#define IRRT_DONT_TYPEDEF_INTS
|
#define IRRT_DONT_TYPEDEF_INTS
|
||||||
#include "irrt_everything.hpp"
|
#include "irrt_everything.hpp"
|
||||||
|
|
||||||
|
@ -17,14 +20,6 @@ void __begin_test(const char* function_name, const char* file, int line) {
|
||||||
|
|
||||||
#define BEGIN_TEST() __begin_test(__FUNCTION__, __FILE__, __LINE__)
|
#define BEGIN_TEST() __begin_test(__FUNCTION__, __FILE__, __LINE__)
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
bool arrays_match(int len, T *as, T *bs) {
|
|
||||||
for (int i = 0; i < len; i++) {
|
|
||||||
if (as[i] != bs[i]) return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void debug_print_array(const char* format, int len, T* as) {
|
void debug_print_array(const char* format, int len, T* as) {
|
||||||
printf("[");
|
printf("[");
|
||||||
|
@ -84,9 +79,14 @@ void test_set_strides_by_shape() {
|
||||||
|
|
||||||
int32_t shape[4] = { 99, 3, 5, 7 };
|
int32_t shape[4] = { 99, 3, 5, 7 };
|
||||||
int32_t strides[4] = { 0 };
|
int32_t strides[4] = { 0 };
|
||||||
ndarray_util::set_strides_by_shape(4, strides, shape);
|
ndarray_util::set_strides_by_shape((int32_t) sizeof(int32_t), 4, strides, shape);
|
||||||
|
|
||||||
int32_t expected_strides[4] = { 105, 35, 7, 1 };
|
int32_t expected_strides[4] = {
|
||||||
|
105 * sizeof(int32_t),
|
||||||
|
35 * sizeof(int32_t),
|
||||||
|
7 * sizeof(int32_t),
|
||||||
|
1 * sizeof(int32_t)
|
||||||
|
};
|
||||||
assert_arrays_match("strides", "%u", 4u, expected_strides, strides);
|
assert_arrays_match("strides", "%u", 4u, expected_strides, strides);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -248,6 +248,168 @@ void test_slice_4() {
|
||||||
assert_values_match("step", "%d", -5, slice.step);
|
assert_values_match("step", "%d", -5, slice.step);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void test_ndslice_1() {
|
||||||
|
/*
|
||||||
|
Reference Python code:
|
||||||
|
```python
|
||||||
|
ndarray = np.arange(12, dtype=np.float64).reshape((3, 4));
|
||||||
|
# array([[ 0., 1., 2., 3.],
|
||||||
|
# [ 4., 5., 6., 7.],
|
||||||
|
# [ 8., 9., 10., 11.]])
|
||||||
|
|
||||||
|
dst_ndarray = ndarray[-2:, 1::2]
|
||||||
|
# array([[ 5., 7.],
|
||||||
|
# [ 9., 11.]])
|
||||||
|
|
||||||
|
assert dst_ndarray.shape == (2, 2)
|
||||||
|
assert dst_ndarray.strides == (32, 16)
|
||||||
|
assert dst_ndarray[0, 0] == 5.0
|
||||||
|
assert dst_ndarray[0, 1] == 7.0
|
||||||
|
assert dst_ndarray[1, 0] == 9.0
|
||||||
|
assert dst_ndarray[1, 1] == 11.0
|
||||||
|
|
||||||
|
dst_ndarray[1, 0] == 99 # Write to `dst_ndarray`
|
||||||
|
assert ndarray[1, 3] == 99 # `ndarray` also updates!!
|
||||||
|
```
|
||||||
|
*/
|
||||||
|
BEGIN_TEST();
|
||||||
|
|
||||||
|
double in_data[12] = { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0 };
|
||||||
|
int32_t in_itemsize = sizeof(double);
|
||||||
|
const int32_t in_ndims = 2;
|
||||||
|
int32_t in_shape[in_ndims] = { 3, 4 };
|
||||||
|
int32_t in_strides[in_ndims] = {};
|
||||||
|
NDArray<int32_t> ndarray = {
|
||||||
|
.data = (uint8_t*) in_data,
|
||||||
|
.itemsize = in_itemsize,
|
||||||
|
.ndims = in_ndims,
|
||||||
|
.shape = in_shape,
|
||||||
|
.strides = in_strides
|
||||||
|
};
|
||||||
|
ndarray.set_strides_by_shape();
|
||||||
|
|
||||||
|
// Destination ndarray
|
||||||
|
// As documented, ndims and shape & strides must be allocated and determined by the caller.
|
||||||
|
const int32_t dst_ndims = 2;
|
||||||
|
int32_t dst_shape[dst_ndims] = {999, 999}; // Empty values
|
||||||
|
int32_t dst_strides[dst_ndims] = {999, 999}; // Empty values
|
||||||
|
NDArray<int32_t> dst_ndarray = {
|
||||||
|
.data = nullptr,
|
||||||
|
.ndims = dst_ndims,
|
||||||
|
.shape = dst_shape,
|
||||||
|
.strides = dst_strides
|
||||||
|
};
|
||||||
|
|
||||||
|
// Create the slice in `ndarray[-2::, 1::2]`
|
||||||
|
UserSlice<int32_t> user_slice_1 = {
|
||||||
|
.start_defined = 1,
|
||||||
|
.start = -2,
|
||||||
|
.stop_defined = 0,
|
||||||
|
.step_defined = 0
|
||||||
|
};
|
||||||
|
|
||||||
|
UserSlice<int32_t> user_slice_2 = {
|
||||||
|
.start_defined = 1,
|
||||||
|
.start = 1,
|
||||||
|
.stop_defined = 0,
|
||||||
|
.step_defined = 1,
|
||||||
|
.step = 2
|
||||||
|
};
|
||||||
|
|
||||||
|
const int32_t num_ndslices = 2;
|
||||||
|
NDSlice ndslices[num_ndslices] = {
|
||||||
|
{ .type = INPUT_SLICE_TYPE_SLICE, .slice = (uint8_t*) &user_slice_1 },
|
||||||
|
{ .type = INPUT_SLICE_TYPE_SLICE, .slice = (uint8_t*) &user_slice_2 }
|
||||||
|
};
|
||||||
|
|
||||||
|
ndarray.slice(num_ndslices, ndslices, &dst_ndarray);
|
||||||
|
|
||||||
|
int32_t expected_shape[dst_ndims] = { 2, 2 };
|
||||||
|
int32_t expected_strides[dst_ndims] = { 32, 16 };
|
||||||
|
assert_arrays_match("shape", "%d", dst_ndims, expected_shape, dst_ndarray.shape);
|
||||||
|
assert_arrays_match("strides", "%d", dst_ndims, expected_strides, dst_ndarray.strides);
|
||||||
|
|
||||||
|
assert_values_match("dst_ndarray[0, 0]", "%f", 5.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 0, 0 })));
|
||||||
|
assert_values_match("dst_ndarray[0, 1]", "%f", 7.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 0, 1 })));
|
||||||
|
assert_values_match("dst_ndarray[1, 0]", "%f", 9.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 1, 0 })));
|
||||||
|
assert_values_match("dst_ndarray[1, 1]", "%f", 11.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 1, 1 })));
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_ndslice_2() {
|
||||||
|
/*
|
||||||
|
```python
|
||||||
|
ndarray = np.arange(12, dtype=np.float64).reshape((3, 4))
|
||||||
|
# array([[ 0., 1., 2., 3.],
|
||||||
|
# [ 4., 5., 6., 7.],
|
||||||
|
# [ 8., 9., 10., 11.]])
|
||||||
|
|
||||||
|
dst_ndarray = ndarray[2, ::-2]
|
||||||
|
# array([11., 9.])
|
||||||
|
|
||||||
|
assert dst_ndarray.shape == (2,)
|
||||||
|
assert dst_ndarray.strides == (-16,)
|
||||||
|
assert dst_ndarray[0] == 11.0
|
||||||
|
assert dst_ndarray[1] == 9.0
|
||||||
|
|
||||||
|
dst_ndarray[1, 0] == 99 # If you write to `dst_ndarray`
|
||||||
|
assert ndarray[1, 3] == 99 # `ndarray` also updates!!
|
||||||
|
```
|
||||||
|
*/
|
||||||
|
BEGIN_TEST();
|
||||||
|
|
||||||
|
double in_data[12] = { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0 };
|
||||||
|
int32_t in_itemsize = sizeof(double);
|
||||||
|
const int32_t in_ndims = 2;
|
||||||
|
int32_t in_shape[in_ndims] = { 3, 4 };
|
||||||
|
int32_t in_strides[in_ndims] = {};
|
||||||
|
NDArray<int32_t> ndarray = {
|
||||||
|
.data = (uint8_t*) in_data,
|
||||||
|
.itemsize = in_itemsize,
|
||||||
|
.ndims = in_ndims,
|
||||||
|
.shape = in_shape,
|
||||||
|
.strides = in_strides
|
||||||
|
};
|
||||||
|
ndarray.set_strides_by_shape();
|
||||||
|
|
||||||
|
// Destination ndarray
|
||||||
|
// As documented, ndims and shape & strides must be allocated and determined by the caller.
|
||||||
|
const int32_t dst_ndims = 1;
|
||||||
|
int32_t dst_shape[dst_ndims] = {999}; // Empty values
|
||||||
|
int32_t dst_strides[dst_ndims] = {999}; // Empty values
|
||||||
|
NDArray<int32_t> dst_ndarray = {
|
||||||
|
.data = nullptr,
|
||||||
|
.ndims = dst_ndims,
|
||||||
|
.shape = dst_shape,
|
||||||
|
.strides = dst_strides
|
||||||
|
};
|
||||||
|
|
||||||
|
// Create the slice in `ndarray[2, ::-2]`
|
||||||
|
int32_t user_slice_1 = 2;
|
||||||
|
UserSlice<int32_t> user_slice_2 = {
|
||||||
|
.start_defined = 0,
|
||||||
|
.stop_defined = 0,
|
||||||
|
.step_defined = 1,
|
||||||
|
.step = -2
|
||||||
|
};
|
||||||
|
|
||||||
|
const int32_t num_ndslices = 2;
|
||||||
|
NDSlice ndslices[num_ndslices] = {
|
||||||
|
{ .type = INPUT_SLICE_TYPE_INDEX, .slice = (uint8_t*) &user_slice_1 },
|
||||||
|
{ .type = INPUT_SLICE_TYPE_SLICE, .slice = (uint8_t*) &user_slice_2 }
|
||||||
|
};
|
||||||
|
|
||||||
|
ndarray.slice(num_ndslices, ndslices, &dst_ndarray);
|
||||||
|
|
||||||
|
int32_t expected_shape[dst_ndims] = { 2 };
|
||||||
|
int32_t expected_strides[dst_ndims] = { -16 };
|
||||||
|
assert_arrays_match("shape", "%d", dst_ndims, expected_shape, dst_ndarray.shape);
|
||||||
|
assert_arrays_match("strides", "%d", dst_ndims, expected_strides, dst_ndarray.strides);
|
||||||
|
|
||||||
|
// [5.0, 3.0]
|
||||||
|
assert_values_match("dst_ndarray[0]", "%f", 11.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 0 })));
|
||||||
|
assert_values_match("dst_ndarray[1]", "%f", 9.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 1 })));
|
||||||
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
test_calc_size_from_shape_normal();
|
test_calc_size_from_shape_normal();
|
||||||
test_calc_size_from_shape_has_zero();
|
test_calc_size_from_shape_has_zero();
|
||||||
|
@ -259,5 +421,7 @@ int main() {
|
||||||
test_slice_2();
|
test_slice_2();
|
||||||
test_slice_3();
|
test_slice_3();
|
||||||
test_slice_4();
|
test_slice_4();
|
||||||
|
test_ndslice_1();
|
||||||
|
test_ndslice_2();
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
|
@ -10,3 +10,5 @@ typedef unsigned _BitInt(32) uint32_t;
|
||||||
typedef _BitInt(64) int64_t;
|
typedef _BitInt(64) int64_t;
|
||||||
typedef unsigned _BitInt(64) uint64_t;
|
typedef unsigned _BitInt(64) uint64_t;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
typedef int32_t SliceIndex;
|
|
@ -13,15 +13,24 @@ namespace {
|
||||||
return a > b ? b : a;
|
return a > b ? b : a;
|
||||||
}
|
}
|
||||||
|
|
||||||
void nac3_assert(bool condition) {
|
template <typename T>
|
||||||
// Doesn't do anything (for now (?))
|
bool arrays_match(int len, T *as, T *bs) {
|
||||||
// Helps to make code self-documenting
|
for (int i = 0; i < len; i++) {
|
||||||
|
if (as[i] != bs[i]) return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void irrt_panic() {
|
||||||
|
// Crash the program for now.
|
||||||
|
// TODO: Don't crash the program
|
||||||
|
// ... or at least produce a good message when doing testing IRRT
|
||||||
|
|
||||||
if (!condition) {
|
|
||||||
// TODO: don't crash the program
|
|
||||||
// TODO: address 0 on hardware might be writable?
|
|
||||||
uint8_t* death = nullptr;
|
uint8_t* death = nullptr;
|
||||||
*death = 0;
|
*death = 0; // TODO: address 0 on hardware might be writable?
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void irrt_assert(bool condition) {
|
||||||
|
if (!condition) irrt_panic();
|
||||||
}
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue