From d18c769cdc491c75f1f8a8acc63e09506d807da2 Mon Sep 17 00:00:00 2001 From: lyken Date: Wed, 10 Jul 2024 14:05:08 +0800 Subject: [PATCH] core: irrt general numpy slicing --- nac3core/irrt/irrt_basic.hpp | 1 - nac3core/irrt/irrt_numpy_ndarray.hpp | 103 +++++++++++---- nac3core/irrt/irrt_slice.hpp | 21 ++- nac3core/irrt/irrt_test.cpp | 186 +++++++++++++++++++++++++-- nac3core/irrt/irrt_typedefs.hpp | 4 +- nac3core/irrt/irrt_utils.hpp | 27 ++-- 6 files changed, 295 insertions(+), 47 deletions(-) diff --git a/nac3core/irrt/irrt_basic.hpp b/nac3core/irrt/irrt_basic.hpp index 08214927..4f6a9b4c 100644 --- a/nac3core/irrt/irrt_basic.hpp +++ b/nac3core/irrt/irrt_basic.hpp @@ -12,7 +12,6 @@ // The type of an index or a value describing the length of a range/slice is // always `int32_t`. -typedef int32_t SliceIndex; namespace { // adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c diff --git a/nac3core/irrt/irrt_numpy_ndarray.hpp b/nac3core/irrt/irrt_numpy_ndarray.hpp index 8a4784be..38f62754 100644 --- a/nac3core/irrt/irrt_numpy_ndarray.hpp +++ b/nac3core/irrt/irrt_numpy_ndarray.hpp @@ -18,11 +18,11 @@ namespace { // // You might want to read up on https://ajcr.net/stride-guide-part-1/. template - static void set_strides_by_shape(SizeT ndims, SizeT* dst_strides, const SizeT* shape) { + static void set_strides_by_shape(SizeT itemsize, SizeT ndims, SizeT* dst_strides, const SizeT* shape) { SizeT stride_product = 1; for (SizeT i = 0; i < ndims; i++) { int dim_i = ndims - i - 1; - dst_strides[dim_i] = stride_product; + dst_strides[dim_i] = stride_product * itemsize; stride_product *= shape[dim_i]; } } @@ -38,31 +38,34 @@ namespace { typedef uint8_t NDSliceType; extern "C" { - const NDSliceType INPUT_SLICE_TYPE_INTEGER = 0; + const NDSliceType INPUT_SLICE_TYPE_INDEX = 0; const NDSliceType INPUT_SLICE_TYPE_SLICE = 1; } struct NDSlice { + // A poor-man's `std::variant` NDSliceType type; /* - type = INPUT_SLICE_TYPE_INTEGER => `slice` points to a single `SizeT` - type = INPUT_SLICE_TYPE_SLICE => `slice` points to a single `NDSliceRange` + if type == INPUT_SLICE_TYPE_INDEX => `slice` points to a single `SizeT` + if type == INPUT_SLICE_TYPE_SLICE => `slice` points to a single `UserRange` */ uint8_t *slice; }; - template - SizeT deduce_ndims_after_slicing(SizeT ndims, const SizeT num_slices, const NDSlice *slices) { - nac3_assert(num_slices <= ndims); + namespace ndarray_util { + template + SizeT deduce_ndims_after_slicing(SizeT ndims, const SizeT num_slices, const NDSlice *slices) { + irrt_assert(num_slices <= ndims); - SizeT final_ndims = ndims; - for (SizeT i = 0; i < num_slices; i++) { - if (slices[i].type == INPUT_SLICE_TYPE_INTEGER) { - final_ndims--; // An integer slice demotes the rank by 1 + SizeT final_ndims = ndims; + for (SizeT i = 0; i < num_slices; i++) { + if (slices[i].type == INPUT_SLICE_TYPE_INDEX) { + final_ndims--; // An integer slice demotes the rank by 1 + } } + return final_ndims; } - return final_ndims; } template @@ -154,10 +157,19 @@ namespace { uint8_t* get_pelement(SizeT *indices) { uint8_t* element = data; for (SizeT dim_i = 0; dim_i < ndims; dim_i++) - element += indices[dim_i] * strides[dim_i] * itemsize; + element += indices[dim_i] * strides[dim_i]; return element; } + // Get pointer to the first element of this ndarray, assuming + // `this->size() > 0`, i.e., not "degenerate" due to zeroes in `this->shape`) + // + // This is particularly useful for when the ndarray is just containing a single scalar. + uint8_t* get_first_pelement() { + irrt_assert(this->size() > 0); + return this->data; // ...It is simply `this->data` + } + // Is the given `indices` valid/in-bounds? bool in_bounds(SizeT *indices) { for (SizeT dim_i = 0; dim_i < ndims; dim_i++) { @@ -183,7 +195,7 @@ namespace { // Set the strides of the ndarray with `ndarray_util::set_strides_by_shape` void set_strides_by_shape() { - ndarray_util::set_strides_by_shape(ndims, strides, shape); + ndarray_util::set_strides_by_shape(itemsize, ndims, strides, shape); } // https://numpy.org/doc/stable/reference/generated/numpy.eye.html @@ -206,15 +218,62 @@ namespace { } // To support numpy complex slices (e.g., `my_array[:50:2,4,:2:-1]`) - void slice(SizeT num_slices, NDSlice* slices, NDArray*dst_ndarray) { - // It is assumed that `dst_ndarray` is allocated by the caller and - // has the correct `ndims`. - nac3_assert(dst_ndarray->ndims == deduce_ndims_after_slicing(this->ndims, num_slices, slices)); + // + // Things assumed by this function: + // - `dst_ndarray` is allocated by the caller + // - `dst_ndarray.ndims` has the correct value (according to `ndarray_util::deduce_ndims_after_slicing`). + // - ... and `dst_ndarray.shape` and `dst_ndarray.strides` have been allocated by the caller as well + // + // Other notes: + // - `dst_ndarray->data` does not have to be set, it will be derived. + // - `dst_ndarray->itemsize` does not have to be set, it will be set to `this->itemsize` + // - `dst_ndarray->shape` and `dst_ndarray.strides` can contain empty values + void slice(SizeT num_ndslices, NDSlice* ndslices, NDArray* dst_ndarray) { + // REFERENCE CODE (check out `_index_helper` in `__getitem__`): + // https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652 + + irrt_assert(dst_ndarray->ndims == ndarray_util::deduce_ndims_after_slicing(this->ndims, num_ndslices, ndslices)); + + dst_ndarray->data = this->data; SizeT this_axis = 0; - SizeT guest_axis = 0; - // for () { - // } + SizeT dst_axis = 0; + + for (SizeT i = 0; i < num_ndslices; i++) { + NDSlice *ndslice = &ndslices[i]; + if (ndslice->type == INPUT_SLICE_TYPE_INDEX) { + // Handle when the ndslice is just a single (possibly negative) integer + // e.g., `my_array[::2, -5, ::-1]` + // ^^------ like this + SizeT index_user = *((SizeT*) ndslice->slice); + SizeT index = resolve_index_in_length(this->shape[this_axis], index_user); + dst_ndarray->data += index * this->strides[this_axis]; // Add offset + + // Next + this_axis++; + } else if (ndslice->type == INPUT_SLICE_TYPE_SLICE) { + // Handle when the ndslice is a slice (represented by UserSlice in IRRT) + // e.g., `my_array[::2, -5, ::-1]` + // ^^^------^^^^----- like these + UserSlice* user_slice = (UserSlice*) ndslice->slice; + Slice slice = user_slice->indices(this->shape[this_axis]); // To resolve negative indices and other funny stuff written by the user + + // NOTE: There is no need to write special code to handle negative steps/strides. + // This simple implementation meticulously handles both positive and negative steps/strides. + // Check out the tinynumpy and IRRT's test cases if you are not convinced. + dst_ndarray->data += slice.start * this->strides[this_axis]; // Add offset (NOTE: no need to `* itemsize`, strides count in # of bytes) + dst_ndarray->strides[dst_axis] = slice.step * this->strides[this_axis]; // Determine stride + dst_ndarray->shape[dst_axis] = slice.len(); // Determine shape dimension + + // Next + dst_axis++; + this_axis++; + } else { + __builtin_unreachable(); + } + } + + irrt_assert(dst_axis == dst_ndarray->ndims); // Sanity check on the implementation } }; } diff --git a/nac3core/irrt/irrt_slice.hpp b/nac3core/irrt/irrt_slice.hpp index 02802d44..4a565245 100644 --- a/nac3core/irrt/irrt_slice.hpp +++ b/nac3core/irrt/irrt_slice.hpp @@ -5,16 +5,31 @@ namespace { // A proper slice in IRRT, all negative indices have be resolved to absolute values. + // Even though nac3core's slices are always `int32_t`, we will template slice anyway + // since this struct is used as a general utility. template struct Slice { T start; T stop; T step; + + // The length/The number of elements of the slice if it were a range, + // i.e., the value of `len(range(this->start, this->stop, this->end))` + T len() { + T diff = stop - start; + if (diff > 0 && step > 0) { + return ((diff - 1) / step) + 1; + } else if (diff < 0 && step < 0) { + return ((diff + 1) / step) + 1; + } else { + return 0; + } + } }; template T resolve_index_in_length(T length, T index) { - nac3_assert(length >= 0); + irrt_assert(length >= 0); if (index < 0) { // Remember that index is negative, so do a plus here return max(length + index, 0); @@ -40,8 +55,8 @@ namespace { Slice indices(T length) { // NOTE: This function implements Python's `slice.indices` *FAITHFULLY*. // SEE: https://github.com/python/cpython/blob/f62161837e68c1c77961435f1b954412dd5c2b65/Objects/sliceobject.c#L546 - nac3_assert(length >= 0); - nac3_assert(!step_defined || step != 0); // step_defined -> step != 0; step cannot be zero if specified by user + irrt_assert(length >= 0); + irrt_assert(!step_defined || step != 0); // step_defined -> step != 0; step cannot be zero if specified by user Slice result; result.step = step_defined ? step : 1; diff --git a/nac3core/irrt/irrt_test.cpp b/nac3core/irrt/irrt_test.cpp index e541865a..1dda0b0d 100644 --- a/nac3core/irrt/irrt_test.cpp +++ b/nac3core/irrt/irrt_test.cpp @@ -1,8 +1,11 @@ +// This file will be compiled like a real C++ program, +// and we do have the luxury to use the standard libraries. +// That is if the nix flakes do not have issues... especially on msys2... #include #include #include -// set `IRRT_DONT_TYPEDEF_INTS` because `cstdint` has it all +// Set `IRRT_DONT_TYPEDEF_INTS` because `cstdint` defines them #define IRRT_DONT_TYPEDEF_INTS #include "irrt_everything.hpp" @@ -17,14 +20,6 @@ void __begin_test(const char* function_name, const char* file, int line) { #define BEGIN_TEST() __begin_test(__FUNCTION__, __FILE__, __LINE__) -template -bool arrays_match(int len, T *as, T *bs) { - for (int i = 0; i < len; i++) { - if (as[i] != bs[i]) return false; - } - return true; -} - template void debug_print_array(const char* format, int len, T* as) { printf("["); @@ -84,9 +79,14 @@ void test_set_strides_by_shape() { int32_t shape[4] = { 99, 3, 5, 7 }; int32_t strides[4] = { 0 }; - ndarray_util::set_strides_by_shape(4, strides, shape); + ndarray_util::set_strides_by_shape((int32_t) sizeof(int32_t), 4, strides, shape); - int32_t expected_strides[4] = { 105, 35, 7, 1 }; + int32_t expected_strides[4] = { + 105 * sizeof(int32_t), + 35 * sizeof(int32_t), + 7 * sizeof(int32_t), + 1 * sizeof(int32_t) + }; assert_arrays_match("strides", "%u", 4u, expected_strides, strides); } @@ -248,6 +248,168 @@ void test_slice_4() { assert_values_match("step", "%d", -5, slice.step); } +void test_ndslice_1() { + /* + Reference Python code: + ```python + ndarray = np.arange(12, dtype=np.float64).reshape((3, 4)); + # array([[ 0., 1., 2., 3.], + # [ 4., 5., 6., 7.], + # [ 8., 9., 10., 11.]]) + + dst_ndarray = ndarray[-2:, 1::2] + # array([[ 5., 7.], + # [ 9., 11.]]) + + assert dst_ndarray.shape == (2, 2) + assert dst_ndarray.strides == (32, 16) + assert dst_ndarray[0, 0] == 5.0 + assert dst_ndarray[0, 1] == 7.0 + assert dst_ndarray[1, 0] == 9.0 + assert dst_ndarray[1, 1] == 11.0 + + dst_ndarray[1, 0] == 99 # Write to `dst_ndarray` + assert ndarray[1, 3] == 99 # `ndarray` also updates!! + ``` + */ + BEGIN_TEST(); + + double in_data[12] = { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0 }; + int32_t in_itemsize = sizeof(double); + const int32_t in_ndims = 2; + int32_t in_shape[in_ndims] = { 3, 4 }; + int32_t in_strides[in_ndims] = {}; + NDArray ndarray = { + .data = (uint8_t*) in_data, + .itemsize = in_itemsize, + .ndims = in_ndims, + .shape = in_shape, + .strides = in_strides + }; + ndarray.set_strides_by_shape(); + + // Destination ndarray + // As documented, ndims and shape & strides must be allocated and determined by the caller. + const int32_t dst_ndims = 2; + int32_t dst_shape[dst_ndims] = {999, 999}; // Empty values + int32_t dst_strides[dst_ndims] = {999, 999}; // Empty values + NDArray dst_ndarray = { + .data = nullptr, + .ndims = dst_ndims, + .shape = dst_shape, + .strides = dst_strides + }; + + // Create the slice in `ndarray[-2::, 1::2]` + UserSlice user_slice_1 = { + .start_defined = 1, + .start = -2, + .stop_defined = 0, + .step_defined = 0 + }; + + UserSlice user_slice_2 = { + .start_defined = 1, + .start = 1, + .stop_defined = 0, + .step_defined = 1, + .step = 2 + }; + + const int32_t num_ndslices = 2; + NDSlice ndslices[num_ndslices] = { + { .type = INPUT_SLICE_TYPE_SLICE, .slice = (uint8_t*) &user_slice_1 }, + { .type = INPUT_SLICE_TYPE_SLICE, .slice = (uint8_t*) &user_slice_2 } + }; + + ndarray.slice(num_ndslices, ndslices, &dst_ndarray); + + int32_t expected_shape[dst_ndims] = { 2, 2 }; + int32_t expected_strides[dst_ndims] = { 32, 16 }; + assert_arrays_match("shape", "%d", dst_ndims, expected_shape, dst_ndarray.shape); + assert_arrays_match("strides", "%d", dst_ndims, expected_strides, dst_ndarray.strides); + + assert_values_match("dst_ndarray[0, 0]", "%f", 5.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 0, 0 }))); + assert_values_match("dst_ndarray[0, 1]", "%f", 7.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 0, 1 }))); + assert_values_match("dst_ndarray[1, 0]", "%f", 9.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 1, 0 }))); + assert_values_match("dst_ndarray[1, 1]", "%f", 11.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 1, 1 }))); +} + +void test_ndslice_2() { + /* + ```python + ndarray = np.arange(12, dtype=np.float64).reshape((3, 4)) + # array([[ 0., 1., 2., 3.], + # [ 4., 5., 6., 7.], + # [ 8., 9., 10., 11.]]) + + dst_ndarray = ndarray[2, ::-2] + # array([11., 9.]) + + assert dst_ndarray.shape == (2,) + assert dst_ndarray.strides == (-16,) + assert dst_ndarray[0] == 11.0 + assert dst_ndarray[1] == 9.0 + + dst_ndarray[1, 0] == 99 # If you write to `dst_ndarray` + assert ndarray[1, 3] == 99 # `ndarray` also updates!! + ``` + */ + BEGIN_TEST(); + + double in_data[12] = { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0 }; + int32_t in_itemsize = sizeof(double); + const int32_t in_ndims = 2; + int32_t in_shape[in_ndims] = { 3, 4 }; + int32_t in_strides[in_ndims] = {}; + NDArray ndarray = { + .data = (uint8_t*) in_data, + .itemsize = in_itemsize, + .ndims = in_ndims, + .shape = in_shape, + .strides = in_strides + }; + ndarray.set_strides_by_shape(); + + // Destination ndarray + // As documented, ndims and shape & strides must be allocated and determined by the caller. + const int32_t dst_ndims = 1; + int32_t dst_shape[dst_ndims] = {999}; // Empty values + int32_t dst_strides[dst_ndims] = {999}; // Empty values + NDArray dst_ndarray = { + .data = nullptr, + .ndims = dst_ndims, + .shape = dst_shape, + .strides = dst_strides + }; + + // Create the slice in `ndarray[2, ::-2]` + int32_t user_slice_1 = 2; + UserSlice user_slice_2 = { + .start_defined = 0, + .stop_defined = 0, + .step_defined = 1, + .step = -2 + }; + + const int32_t num_ndslices = 2; + NDSlice ndslices[num_ndslices] = { + { .type = INPUT_SLICE_TYPE_INDEX, .slice = (uint8_t*) &user_slice_1 }, + { .type = INPUT_SLICE_TYPE_SLICE, .slice = (uint8_t*) &user_slice_2 } + }; + + ndarray.slice(num_ndslices, ndslices, &dst_ndarray); + + int32_t expected_shape[dst_ndims] = { 2 }; + int32_t expected_strides[dst_ndims] = { -16 }; + assert_arrays_match("shape", "%d", dst_ndims, expected_shape, dst_ndarray.shape); + assert_arrays_match("strides", "%d", dst_ndims, expected_strides, dst_ndarray.strides); + + // [5.0, 3.0] + assert_values_match("dst_ndarray[0]", "%f", 11.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 0 }))); + assert_values_match("dst_ndarray[1]", "%f", 9.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 1 }))); +} + int main() { test_calc_size_from_shape_normal(); test_calc_size_from_shape_has_zero(); @@ -259,5 +421,7 @@ int main() { test_slice_2(); test_slice_3(); test_slice_4(); + test_ndslice_1(); + test_ndslice_2(); return 0; } \ No newline at end of file diff --git a/nac3core/irrt/irrt_typedefs.hpp b/nac3core/irrt/irrt_typedefs.hpp index acd75da8..7a10a03d 100644 --- a/nac3core/irrt/irrt_typedefs.hpp +++ b/nac3core/irrt/irrt_typedefs.hpp @@ -9,4 +9,6 @@ typedef _BitInt(32) int32_t; typedef unsigned _BitInt(32) uint32_t; typedef _BitInt(64) int64_t; typedef unsigned _BitInt(64) uint64_t; -#endif \ No newline at end of file +#endif + +typedef int32_t SliceIndex; \ No newline at end of file diff --git a/nac3core/irrt/irrt_utils.hpp b/nac3core/irrt/irrt_utils.hpp index 7ddc9ac0..033b995e 100644 --- a/nac3core/irrt/irrt_utils.hpp +++ b/nac3core/irrt/irrt_utils.hpp @@ -13,15 +13,24 @@ namespace { return a > b ? b : a; } - void nac3_assert(bool condition) { - // Doesn't do anything (for now (?)) - // Helps to make code self-documenting - - if (!condition) { - // TODO: don't crash the program - // TODO: address 0 on hardware might be writable? - uint8_t* death = nullptr; - *death = 0; + template + bool arrays_match(int len, T *as, T *bs) { + for (int i = 0; i < len; i++) { + if (as[i] != bs[i]) return false; } + return true; + } + + void irrt_panic() { + // Crash the program for now. + // TODO: Don't crash the program + // ... or at least produce a good message when doing testing IRRT + + uint8_t* death = nullptr; + *death = 0; // TODO: address 0 on hardware might be writable? + } + + void irrt_assert(bool condition) { + if (!condition) irrt_panic(); } } \ No newline at end of file