forked from M-Labs/nac3
core: irrt add unchecked ndarray slicing
This commit is contained in:
parent
b8c0d5836f
commit
cc8103152f
|
@ -0,0 +1,137 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <irrt/slice.hpp>
|
||||||
|
#include <irrt/numpy/ndarray_def.hpp>
|
||||||
|
#include <irrt/numpy/ndarray_basic.hpp>
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Tag type telling which kind of subscript an `NDSubscript` carries.
using NDSubscriptType = uint8_t;

extern "C" {
// The subscript is a single integer index, e.g. the `-5` in `my_array[::2, -5]`.
const NDSubscriptType INPUT_SUBSCRIPT_TYPE_INDEX = 0;
// The subscript is a slice, e.g. the `::2` in `my_array[::2, -5]`.
const NDSubscriptType INPUT_SUBSCRIPT_TYPE_SLICE = 1;
}

// One entry of a subscript list: a hand-rolled tagged union (tag + untyped payload pointer).
struct NDSubscript {
    // Which alternative `data` points to; one of the `INPUT_SUBSCRIPT_TYPE_*` constants.
    NDSubscriptType type;

    /*
    Payload, interpreted according to `type` by `ndarray::subscript`:

      type == INPUT_SUBSCRIPT_TYPE_INDEX => `data` is read as a pointer to a single `SizeT`
      type == INPUT_SUBSCRIPT_TYPE_SLICE => `data` is read as a pointer to a single `UserSlice`

    `SizeT` is chosen by the caller. `NDSubscript` only records where the payload
    lives; it neither knows nor cares about the payload's `sizeof()`.
    */
    uint8_t* data;
};
|
||||||
|
|
||||||
|
namespace ndarray {
|
||||||
|
namespace util {
|
||||||
|
template<typename SizeT>
|
||||||
|
SizeT deduce_ndims_after_slicing(SizeT ndims, SizeT num_subscripts, const NDSubscript* subscripts) {
|
||||||
|
irrt_assert(num_subscripts <= ndims);
|
||||||
|
|
||||||
|
SizeT final_ndims = ndims;
|
||||||
|
for (SizeT i = 0; i < num_subscripts; i++) {
|
||||||
|
if (subscripts[i].type == INPUT_SUBSCRIPT_TYPE_INDEX) {
|
||||||
|
final_ndims--; // An index demotes the rank by 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return final_ndims;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// To support numpy "basic indexing" https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing
|
||||||
|
// "Advanced indexing" https://numpy.org/doc/stable/user/basics.indexing.html#advanced-indexing is not supported
|
||||||
|
//
|
||||||
|
// This function supports:
|
||||||
|
// - "scalar indexing",
|
||||||
|
// - "slicing and strides",
|
||||||
|
// - and "dimensional indexing tools" (TODO, but this is really easy to implement).
|
||||||
|
//
|
||||||
|
// Things assumed by this function:
|
||||||
|
// - `dst_ndarray` is allocated by the caller
|
||||||
|
// - `dst_ndarray.ndims` has the correct value (according to `ndarray::util::deduce_ndims_after_slicing`).
|
||||||
|
// - ... and `dst_ndarray.shape` and `dst_ndarray.strides` have been allocated by the caller as well
|
||||||
|
//
|
||||||
|
// Other notes:
|
||||||
|
// - `dst_ndarray->data` does not have to be set, it will be derived.
|
||||||
|
// - `dst_ndarray->itemsize` does not have to be set, it will be set to `src_ndarray->itemsize`
|
||||||
|
// - `dst_ndarray->shape` and `dst_ndarray.strides` can contain empty values
|
||||||
|
template <typename SizeT>
|
||||||
|
void subscript(SizeT num_subscripts, NDSubscript* subscripts, NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
||||||
|
// REFERENCE CODE (check out `_index_helper` in `__getitem__`):
|
||||||
|
// https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652
|
||||||
|
|
||||||
|
// irrt_assert(dst_ndarray->ndims == ndarray::util::deduce_ndims_after_slicing(src_ndarray->ndims, num_subscripts, subscripts));
|
||||||
|
|
||||||
|
dst_ndarray->data = src_ndarray->data;
|
||||||
|
dst_ndarray->itemsize = src_ndarray->itemsize;
|
||||||
|
|
||||||
|
SizeT src_axis = 0;
|
||||||
|
SizeT dst_axis = 0;
|
||||||
|
|
||||||
|
for (SizeT i = 0; i < num_subscripts; i++) {
|
||||||
|
NDSubscript *ndsubscript = &subscripts[i];
|
||||||
|
if (ndsubscript->type == INPUT_SUBSCRIPT_TYPE_INDEX) {
|
||||||
|
// Handle when the ndsubscript is just a single (possibly negative) integer
|
||||||
|
// e.g., `my_array[::2, -5, ::-1]`
|
||||||
|
// ^^------ like this
|
||||||
|
SizeT index_user = *((SizeT*) ndsubscript->data);
|
||||||
|
SizeT index = slice::resolve_index_in_length(src_ndarray->shape[src_axis], index_user);
|
||||||
|
dst_ndarray->data += index * src_ndarray->strides[src_axis]; // Add offset
|
||||||
|
|
||||||
|
// Next
|
||||||
|
src_axis++;
|
||||||
|
} else if (ndsubscript->type == INPUT_SUBSCRIPT_TYPE_SLICE) {
|
||||||
|
// Handle when the ndsubscript is a slice (represented by UserSlice in IRRT)
|
||||||
|
// e.g., `my_array[::2, -5, ::-1]`
|
||||||
|
// ^^^------^^^^----- like these
|
||||||
|
UserSlice* user_slice = (UserSlice*) ndsubscript->data;
|
||||||
|
|
||||||
|
// TODO: use checked indices
|
||||||
|
Slice slice;
|
||||||
|
user_slice->indices(src_ndarray->shape[src_axis], &slice); // To resolve negative indices and other funny stuff written by the user
|
||||||
|
|
||||||
|
// NOTE: There is no need to write special code to handle negative steps/strides.
|
||||||
|
// This simple implementation meticulously handles both positive and negative steps/strides.
|
||||||
|
// Check out the tinynumpy and IRRT's test cases if you are not convinced.
|
||||||
|
dst_ndarray->data += (SizeT) slice.start * src_ndarray->strides[src_axis]; // Add offset (NOTE: no need to `* itemsize`, strides count in # of bytes)
|
||||||
|
dst_ndarray->strides[dst_axis] = ((SizeT) slice.step) * src_ndarray->strides[src_axis]; // Determine stride
|
||||||
|
dst_ndarray->shape[dst_axis] = (SizeT) slice.len(); // Determine shape dimension
|
||||||
|
|
||||||
|
// Next
|
||||||
|
dst_axis++;
|
||||||
|
src_axis++;
|
||||||
|
} else {
|
||||||
|
__builtin_unreachable();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
Reference python code:
|
||||||
|
```python
|
||||||
|
dst_ndarray.shape.extend(src_ndarray.shape[src_axis:])
|
||||||
|
dst_ndarray.strides.extend(src_ndarray.strides[src_axis:])
|
||||||
|
```
|
||||||
|
*/
|
||||||
|
|
||||||
|
for (; dst_axis < dst_ndarray->ndims; dst_axis++, src_axis++) {
|
||||||
|
dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis];
|
||||||
|
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
void __nac3_ndarray_subscript(int32_t num_subscripts, NDSubscript* subscripts, NDArray<int32_t>* src_ndarray, NDArray<int32_t> *dst_ndarray) {
|
||||||
|
ndarray::subscript(num_subscripts, subscripts, src_ndarray, dst_ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_subscript64(int64_t num_subscripts, NDSubscript* subscripts, NDArray<int64_t>* src_ndarray, NDArray<int64_t> *dst_ndarray) {
|
||||||
|
ndarray::subscript(num_subscripts, subscripts, src_ndarray, dst_ndarray);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,132 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <irrt/int_defs.hpp>
|
||||||
|
#include <irrt/slice.hpp>
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
struct Slice {
|
||||||
|
SliceIndex start;
|
||||||
|
SliceIndex stop;
|
||||||
|
SliceIndex step;
|
||||||
|
|
||||||
|
// The length/The number of elements of the slice if it were a range,
|
||||||
|
// i.e., the value of `len(range(this->start, this->stop, this->end))`
|
||||||
|
SliceIndex len() {
|
||||||
|
SliceIndex diff = stop - start;
|
||||||
|
if (diff > 0 && step > 0) {
|
||||||
|
return ((diff - 1) / step) + 1;
|
||||||
|
} else if (diff < 0 && step < 0) {
|
||||||
|
return ((diff + 1) / step) + 1;
|
||||||
|
} else {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
namespace slice {
|
||||||
|
// "Resolve" an index value under a length in Python lists.
|
||||||
|
// If you have a `list` of length 100, `list[-1]` would resolve to `list[100-1] == list[99]`.
|
||||||
|
//
|
||||||
|
// If length == 0, this function returns 0
|
||||||
|
//
|
||||||
|
// If index is out of bounds, this function clamps the value
|
||||||
|
// (to `list[0]` or `list[-1]` in the context of a list and depending on if index is + or -)
|
||||||
|
SliceIndex resolve_index_in_length(SliceIndex length, SliceIndex index) {
|
||||||
|
if (index < 0) {
|
||||||
|
// Remember that index is negative, so do a plus here
|
||||||
|
return max<SliceIndex>(length + index, 0);
|
||||||
|
} else {
|
||||||
|
return min<SliceIndex>(length, index);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// A user-written Python-like slice.
|
||||||
|
//
|
||||||
|
// i.e., this slice is a triple of either an int or nothing. (e.g., `my_array[:10:2]`, `start` is None)
|
||||||
|
//
|
||||||
|
// You can "resolve" a `UserSlice` by using `UserSlice::indices(<length>)`
|
||||||
|
struct UserSlice {
|
||||||
|
// Did the user specify `start`? If 0, `start` is undefined (and contains an empty value)
|
||||||
|
bool start_defined;
|
||||||
|
SliceIndex start;
|
||||||
|
|
||||||
|
// Similar to `start_defined`
|
||||||
|
bool stop_defined;
|
||||||
|
SliceIndex stop;
|
||||||
|
|
||||||
|
// Similar to `start_defined`
|
||||||
|
bool step_defined;
|
||||||
|
SliceIndex step;
|
||||||
|
|
||||||
|
// Constructor faithfully follows Python's `slice()`.
|
||||||
|
explicit UserSlice(SliceIndex stop) {
|
||||||
|
start_defined = false;
|
||||||
|
stop_defined = true;
|
||||||
|
step_defined = false;
|
||||||
|
|
||||||
|
this->stop = stop;
|
||||||
|
}
|
||||||
|
|
||||||
|
explicit UserSlice(SliceIndex start, SliceIndex stop) {
|
||||||
|
start_defined = true;
|
||||||
|
stop_defined = true;
|
||||||
|
step_defined = false;
|
||||||
|
|
||||||
|
this->start = start;
|
||||||
|
this->stop = stop;
|
||||||
|
}
|
||||||
|
|
||||||
|
explicit UserSlice(SliceIndex start, SliceIndex stop, SliceIndex step) {
|
||||||
|
start_defined = true;
|
||||||
|
stop_defined = true;
|
||||||
|
step_defined = true;
|
||||||
|
|
||||||
|
this->start = start;
|
||||||
|
this->stop = stop;
|
||||||
|
this->step = step;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Like Python's `slice(start, stop, step).indices(length)`
|
||||||
|
void indices(SliceIndex length, Slice* result) {
|
||||||
|
// NOTE: This function implements Python's `slice.indices` *FAITHFULLY*.
|
||||||
|
// SEE: https://github.com/python/cpython/blob/f62161837e68c1c77961435f1b954412dd5c2b65/Objects/sliceobject.c#L546
|
||||||
|
result->step = step_defined ? step : 1;
|
||||||
|
bool step_is_negative = result->step < 0;
|
||||||
|
|
||||||
|
if (start_defined) {
|
||||||
|
result->start = slice::resolve_index_in_length(length, start);
|
||||||
|
} else {
|
||||||
|
result->start = step_is_negative ? length - 1 : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (stop_defined) {
|
||||||
|
result->stop = slice::resolve_index_in_length(length, stop);
|
||||||
|
} else {
|
||||||
|
result->stop = step_is_negative ? -1 : length;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// `indices()` but asserts `this->step != 0` and `this->length >= 0`
|
||||||
|
void checked_indices(ErrorContext* errctx, SliceIndex length, Slice* result) {
|
||||||
|
if (!(length >= 0)) {
|
||||||
|
errctx->set_error(
|
||||||
|
errctx->error_ids->value_error,
|
||||||
|
"length should not be negative, got {0}", // Edited. Error message copied from python by doing `slice(0, 0, 0).indices(100)`
|
||||||
|
length
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!(this->step_defined && this->step != 0)) {
|
||||||
|
// Error message
|
||||||
|
errctx->set_error(
|
||||||
|
errctx->error_ids->value_error,
|
||||||
|
"slice step cannot be zero" // Error message copied from python by doing `slice(0, 0, 0).indices(100)`
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
this->indices(length, result);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
|
@ -3,8 +3,10 @@
|
||||||
#include <irrt/core.hpp>
|
#include <irrt/core.hpp>
|
||||||
#include <irrt/error_context.hpp>
|
#include <irrt/error_context.hpp>
|
||||||
#include <irrt/int_defs.hpp>
|
#include <irrt/int_defs.hpp>
|
||||||
#include <irrt/numpy/ndarray_def.hpp>
|
|
||||||
#include <irrt/numpy/ndarray_basic.hpp>
|
#include <irrt/numpy/ndarray_basic.hpp>
|
||||||
#include <irrt/numpy/ndarray_broadcast.hpp>
|
#include <irrt/numpy/ndarray_broadcast.hpp>
|
||||||
|
#include <irrt/numpy/ndarray_def.hpp>
|
||||||
#include <irrt/numpy/ndarray_fill.hpp>
|
#include <irrt/numpy/ndarray_fill.hpp>
|
||||||
|
#include <irrt/numpy/ndarray_subscript.hpp>
|
||||||
|
#include <irrt/slice.hpp>
|
||||||
#include <irrt/utils.hpp>
|
#include <irrt/utils.hpp>
|
|
@ -8,11 +8,13 @@
|
||||||
#include <irrt_everything.hpp>
|
#include <irrt_everything.hpp>
|
||||||
|
|
||||||
#include <test/core.hpp>
|
#include <test/core.hpp>
|
||||||
#include <test/ndarray.hpp>
|
|
||||||
#include <test/test_core.hpp>
|
#include <test/test_core.hpp>
|
||||||
|
#include <test/test_ndarray.hpp>
|
||||||
|
#include <test/test_slice.hpp>
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
test_int_exp();
|
test_int_exp();
|
||||||
run_all_tests_ndarray();
|
run_all_tests_ndarray();
|
||||||
|
run_all_tests_ndarray_slice();
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
|
@ -1,15 +1,14 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <test/core.hpp>
|
#include <test/core.hpp>
|
||||||
#include <irrt/numpy/ndarray.hpp>
|
#include <irrt_everything.hpp>
|
||||||
#include <irrt/numpy/ndarray_util.hpp>
|
|
||||||
|
|
||||||
void test_calc_size_from_shape_normal() {
|
void test_calc_size_from_shape_normal() {
|
||||||
// Test shapes with normal values
|
// Test shapes with normal values
|
||||||
BEGIN_TEST();
|
BEGIN_TEST();
|
||||||
|
|
||||||
int32_t shape[4] = { 2, 3, 5, 7 };
|
int32_t shape[4] = { 2, 3, 5, 7 };
|
||||||
assert_values_match(210, ndarray_util::calc_size_from_shape<int32_t>(4, shape));
|
assert_values_match(210, ndarray::util::calc_size_from_shape<int32_t>(4, shape));
|
||||||
}
|
}
|
||||||
|
|
||||||
void test_calc_size_from_shape_has_zero() {
|
void test_calc_size_from_shape_has_zero() {
|
||||||
|
@ -17,7 +16,7 @@ void test_calc_size_from_shape_has_zero() {
|
||||||
BEGIN_TEST();
|
BEGIN_TEST();
|
||||||
|
|
||||||
int32_t shape[4] = { 2, 0, 5, 7 };
|
int32_t shape[4] = { 2, 0, 5, 7 };
|
||||||
assert_values_match(0, ndarray_util::calc_size_from_shape<int32_t>(4, shape));
|
assert_values_match(0, ndarray::util::calc_size_from_shape<int32_t>(4, shape));
|
||||||
}
|
}
|
||||||
|
|
||||||
void test_set_strides_by_shape() {
|
void test_set_strides_by_shape() {
|
||||||
|
@ -26,7 +25,7 @@ void test_set_strides_by_shape() {
|
||||||
|
|
||||||
int32_t shape[4] = { 99, 3, 5, 7 };
|
int32_t shape[4] = { 99, 3, 5, 7 };
|
||||||
int32_t strides[4] = { 0 };
|
int32_t strides[4] = { 0 };
|
||||||
ndarray_util::set_strides_by_shape((int32_t) sizeof(int32_t), 4, strides, shape);
|
ndarray::util::set_strides_by_shape((int32_t) sizeof(int32_t), 4, strides, shape);
|
||||||
|
|
||||||
int32_t expected_strides[4] = {
|
int32_t expected_strides[4] = {
|
||||||
105 * sizeof(int32_t),
|
105 * sizeof(int32_t),
|
|
@ -0,0 +1,20 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <test/core.hpp>
|
||||||
|
#include <irrt_everything.hpp>
|
||||||
|
|
||||||
|
void test_slice_1() {
|
||||||
|
BEGIN_TEST();
|
||||||
|
|
||||||
|
UserSlice user_slice(5);
|
||||||
|
Slice slice;
|
||||||
|
user_slice.indices(100, &slice);
|
||||||
|
|
||||||
|
assert_values_match(0, slice.start);
|
||||||
|
assert_values_match(5, slice.stop);
|
||||||
|
assert_values_match(1, slice.step);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Runs every slice test case defined in this header.
void run_all_tests_ndarray_slice() {
    test_slice_1();
}
|
Loading…
Reference in New Issue