diff --git a/nac3core/irrt/irrt/numpy/ndarray_subscript.hpp b/nac3core/irrt/irrt/numpy/ndarray_subscript.hpp new file mode 100644 index 00000000..6e9fce9e --- /dev/null +++ b/nac3core/irrt/irrt/numpy/ndarray_subscript.hpp @@ -0,0 +1,137 @@ +#pragma once + +#include +#include +#include + +namespace { +typedef uint8_t NDSubscriptType; + +extern "C" { +const NDSubscriptType INPUT_SUBSCRIPT_TYPE_INDEX = 0; +const NDSubscriptType INPUT_SUBSCRIPT_TYPE_SLICE = 1; +} + +struct NDSubscript { + // A poor-man's enum variant type + NDSubscriptType type; + + /* + if type == INPUT_SUBSCRIPT_TYPE_INDEX => `slice` points to a single `SizeT` + if type == INPUT_SUBSCRIPT_TYPE_SLICE => `slice` points to a single `UserRange` + + `SizeT` is controlled by the caller: `NDSubscript` only cares about where that + slice is (the pointer), `NDSubscript` does not care/know about the actual `sizeof()` + of the slice value. + */ + uint8_t* data; +}; + +namespace ndarray { + namespace util { + template + SizeT deduce_ndims_after_slicing(SizeT ndims, SizeT num_subscripts, const NDSubscript* subscripts) { + irrt_assert(num_subscripts <= ndims); + + SizeT final_ndims = ndims; + for (SizeT i = 0; i < num_subscripts; i++) { + if (subscripts[i].type == INPUT_SUBSCRIPT_TYPE_INDEX) { + final_ndims--; // An index demotes the rank by 1 + } + } + return final_ndims; + } + } + + // To support numpy "basic indexing" https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing + // "Advanced indexing" https://numpy.org/doc/stable/user/basics.indexing.html#advanced-indexing is not supported + // + // This function supports: + // - "scalar indexing", + // - "slicing and strides", + // - and "dimensional indexing tools" (TODO, but this is really easy to implement). + // + // Things assumed by this function: + // - `dst_ndarray` is allocated by the caller + // - `dst_ndarray.ndims` has the correct value (according to `ndarray::util::deduce_ndims_after_slicing`). + // - ... and `dst_ndarray.shape` and `dst_ndarray.strides` have been allocated by the caller as well + // + // Other notes: + // - `dst_ndarray->data` does not have to be set, it will be derived. + // - `dst_ndarray->itemsize` does not have to be set, it will be set to `src_ndarray->itemsize` + // - `dst_ndarray->shape` and `dst_ndarray.strides` can contain empty values + template + void subscript(SizeT num_subscripts, NDSubscript* subscripts, NDArray* src_ndarray, NDArray* dst_ndarray) { + // REFERENCE CODE (check out `_index_helper` in `__getitem__`): + // https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652 + + // irrt_assert(dst_ndarray->ndims == ndarray::util::deduce_ndims_after_slicing(src_ndarray->ndims, num_subscripts, subscripts)); + + dst_ndarray->data = src_ndarray->data; + dst_ndarray->itemsize = src_ndarray->itemsize; + + SizeT src_axis = 0; + SizeT dst_axis = 0; + + for (SizeT i = 0; i < num_subscripts; i++) { + NDSubscript *ndsubscript = &subscripts[i]; + if (ndsubscript->type == INPUT_SUBSCRIPT_TYPE_INDEX) { + // Handle when the ndsubscript is just a single (possibly negative) integer + // e.g., `my_array[::2, -5, ::-1]` + // ^^------ like this + SizeT index_user = *((SizeT*) ndsubscript->data); + SizeT index = slice::resolve_index_in_length(src_ndarray->shape[src_axis], index_user); + dst_ndarray->data += index * src_ndarray->strides[src_axis]; // Add offset + + // Next + src_axis++; + } else if (ndsubscript->type == INPUT_SUBSCRIPT_TYPE_SLICE) { + // Handle when the ndsubscript is a slice (represented by UserSlice in IRRT) + // e.g., `my_array[::2, -5, ::-1]` + // ^^^------^^^^----- like these + UserSlice* user_slice = (UserSlice*) ndsubscript->data; + + // TODO: use checked indices + Slice slice; + user_slice->indices(src_ndarray->shape[src_axis], &slice); // To resolve negative indices and other funny stuff written by the user + + // NOTE: There is no need to write special code to handle negative steps/strides. + // This simple implementation meticulously handles both positive and negative steps/strides. + // Check out the tinynumpy and IRRT's test cases if you are not convinced. + dst_ndarray->data += (SizeT) slice.start * src_ndarray->strides[src_axis]; // Add offset (NOTE: no need to `* itemsize`, strides count in # of bytes) + dst_ndarray->strides[dst_axis] = ((SizeT) slice.step) * src_ndarray->strides[src_axis]; // Determine stride + dst_ndarray->shape[dst_axis] = (SizeT) slice.len(); // Determine shape dimension + + // Next + dst_axis++; + src_axis++; + } else { + __builtin_unreachable(); + } + } + + /* + Reference python code: + ```python + dst_ndarray.shape.extend(src_ndarray.shape[src_axis:]) + dst_ndarray.strides.extend(src_ndarray.strides[src_axis:]) + ``` + */ + + for (; dst_axis < dst_ndarray->ndims; dst_axis++, src_axis++) { + dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis]; + dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis]; + } + } +} +} + +extern "C" { +void __nac3_ndarray_subscript(int32_t num_subscripts, NDSubscript* subscripts, NDArray* src_ndarray, NDArray *dst_ndarray) { + ndarray::subscript(num_subscripts, subscripts, src_ndarray, dst_ndarray); +} + +void __nac3_ndarray_subscript64(int64_t num_subscripts, NDSubscript* subscripts, NDArray* src_ndarray, NDArray *dst_ndarray) { + ndarray::subscript(num_subscripts, subscripts, src_ndarray, dst_ndarray); +} +} \ No newline at end of file diff --git a/nac3core/irrt/irrt/slice.hpp b/nac3core/irrt/irrt/slice.hpp new file mode 100644 index 00000000..826c9f7e --- /dev/null +++ b/nac3core/irrt/irrt/slice.hpp @@ -0,0 +1,132 @@ +#pragma once + +#include +#include + +namespace { +struct Slice { + SliceIndex start; + SliceIndex stop; + SliceIndex step; + + // The length/The number of elements of the slice if it were a range, + // i.e., the value of `len(range(this->start, this->stop, this->end))` + SliceIndex len() { + SliceIndex diff = stop - start; + if (diff > 0 && step > 0) { + return ((diff - 1) / step) + 1; + } else if (diff < 0 && step < 0) { + return ((diff + 1) / step) + 1; + } else { + return 0; + } + } +}; + +namespace slice { + // "Resolve" an index value under a length in Python lists. + // If you have a `list` of length 100, `list[-1]` would resolve to `list[100-1] == list[99]`. + // + // If length == 0, this function returns 0 + // + // If index is out of bounds, this function clamps the value + // (to `list[0]` or `list[-1]` in the context of a list and depending on if index is + or -) + SliceIndex resolve_index_in_length(SliceIndex length, SliceIndex index) { + if (index < 0) { + // Remember that index is negative, so do a plus here + return max(length + index, 0); + } else { + return min(length, index); + } + } +} + +// A user-written Python-like slice. +// +// i.e., this slice is a triple of either an int or nothing. (e.g., `my_array[:10:2]`, `start` is None) +// +// You can "resolve" a `UserSlice` by using `UserSlice::indices()` +struct UserSlice { + // Did the user specify `start`? If 0, `start` is undefined (and contains an empty value) + bool start_defined; + SliceIndex start; + + // Similar to `start_defined` + bool stop_defined; + SliceIndex stop; + + // Similar to `start_defined` + bool step_defined; + SliceIndex step; + + // Constructor faithfully follows Python's `slice()`. + explicit UserSlice(SliceIndex stop) { + start_defined = false; + stop_defined = true; + step_defined = false; + + this->stop = stop; + } + + explicit UserSlice(SliceIndex start, SliceIndex stop) { + start_defined = true; + stop_defined = true; + step_defined = false; + + this->start = start; + this->stop = stop; + } + + explicit UserSlice(SliceIndex start, SliceIndex stop, SliceIndex step) { + start_defined = true; + stop_defined = true; + step_defined = true; + + this->start = start; + this->stop = stop; + this->step = step; + } + + // Like Python's `slice(start, stop, step).indices(length)` + void indices(SliceIndex length, Slice* result) { + // NOTE: This function implements Python's `slice.indices` *FAITHFULLY*. + // SEE: https://github.com/python/cpython/blob/f62161837e68c1c77961435f1b954412dd5c2b65/Objects/sliceobject.c#L546 + result->step = step_defined ? step : 1; + bool step_is_negative = result->step < 0; + + if (start_defined) { + result->start = slice::resolve_index_in_length(length, start); + } else { + result->start = step_is_negative ? length - 1 : 0; + } + + if (stop_defined) { + result->stop = slice::resolve_index_in_length(length, stop); + } else { + result->stop = step_is_negative ? -1 : length; + } + } + + // `indices()` but asserts `this->step != 0` and `this->length >= 0` + void checked_indices(ErrorContext* errctx, SliceIndex length, Slice* result) { + if (!(length >= 0)) { + errctx->set_error( + errctx->error_ids->value_error, + "length should not be negative, got {0}", // Edited. Error message copied from python by doing `slice(0, 0, 0).indices(100)` + length + ); + return; + } + + if (!(this->step_defined && this->step != 0)) { + // Error message + errctx->set_error( + errctx->error_ids->value_error, + "slice step cannot be zero" // Error message copied from python by doing `slice(0, 0, 0).indices(100)` + ); + return; + } + this->indices(length, result); + } +}; +} \ No newline at end of file diff --git a/nac3core/irrt/irrt_everything.hpp b/nac3core/irrt/irrt_everything.hpp index 1be4f6e8..a555d3d3 100644 --- a/nac3core/irrt/irrt_everything.hpp +++ b/nac3core/irrt/irrt_everything.hpp @@ -3,8 +3,10 @@ #include #include #include -#include #include #include +#include #include +#include +#include #include \ No newline at end of file diff --git a/nac3core/irrt/irrt_test.cpp b/nac3core/irrt/irrt_test.cpp index 0332ef66..30ee1c65 100644 --- a/nac3core/irrt/irrt_test.cpp +++ b/nac3core/irrt/irrt_test.cpp @@ -8,11 +8,13 @@ #include #include -#include #include +#include +#include int main() { test_int_exp(); run_all_tests_ndarray(); + run_all_tests_ndarray_slice(); return 0; } \ No newline at end of file diff --git a/nac3core/irrt/test/ndarray.hpp b/nac3core/irrt/test/test_ndarray.hpp similarity index 72% rename from nac3core/irrt/test/ndarray.hpp rename to nac3core/irrt/test/test_ndarray.hpp index 3554f613..d1073fe3 100644 --- a/nac3core/irrt/test/ndarray.hpp +++ b/nac3core/irrt/test/test_ndarray.hpp @@ -1,15 +1,14 @@ #pragma once #include -#include -#include +#include void test_calc_size_from_shape_normal() { // Test shapes with normal values BEGIN_TEST(); int32_t shape[4] = { 2, 3, 5, 7 }; - assert_values_match(210, ndarray_util::calc_size_from_shape(4, shape)); + assert_values_match(210, ndarray::util::calc_size_from_shape(4, shape)); } void test_calc_size_from_shape_has_zero() { @@ -17,7 +16,7 @@ void test_calc_size_from_shape_has_zero() { BEGIN_TEST(); int32_t shape[4] = { 2, 0, 5, 7 }; - assert_values_match(0, ndarray_util::calc_size_from_shape(4, shape)); + assert_values_match(0, ndarray::util::calc_size_from_shape(4, shape)); } void test_set_strides_by_shape() { @@ -26,7 +25,7 @@ void test_set_strides_by_shape() { int32_t shape[4] = { 99, 3, 5, 7 }; int32_t strides[4] = { 0 }; - ndarray_util::set_strides_by_shape((int32_t) sizeof(int32_t), 4, strides, shape); + ndarray::util::set_strides_by_shape((int32_t) sizeof(int32_t), 4, strides, shape); int32_t expected_strides[4] = { 105 * sizeof(int32_t), diff --git a/nac3core/irrt/test/test_slice.hpp b/nac3core/irrt/test/test_slice.hpp new file mode 100644 index 00000000..4ba72019 --- /dev/null +++ b/nac3core/irrt/test/test_slice.hpp @@ -0,0 +1,20 @@ +#pragma once + +#include +#include + +void test_slice_1() { + BEGIN_TEST(); + + UserSlice user_slice(5); + Slice slice; + user_slice.indices(100, &slice); + + assert_values_match(0, slice.start); + assert_values_match(5, slice.stop); + assert_values_match(1, slice.step); +} + +void run_all_tests_ndarray_slice() { + test_slice_1(); +} \ No newline at end of file