From 4b146093428633eb1bef2bd3cdcff7d19d50931c Mon Sep 17 00:00:00 2001 From: lyken Date: Fri, 26 Jul 2024 16:31:28 +0800 Subject: [PATCH] core/ndstrides: implement IRRT slice Needed by ndarray indexing --- nac3core/irrt/irrt/slice.hpp | 160 ++++++++++++++++++++++++++++++ nac3core/irrt/irrt_everything.hpp | 1 + nac3core/irrt/irrt_test.cpp | 2 + nac3core/irrt/test/test_slice.hpp | 92 +++++++++++++++++ 4 files changed, 255 insertions(+) create mode 100644 nac3core/irrt/irrt/slice.hpp create mode 100644 nac3core/irrt/test/test_slice.hpp diff --git a/nac3core/irrt/irrt/slice.hpp b/nac3core/irrt/irrt/slice.hpp new file mode 100644 index 00000000..711918f5 --- /dev/null +++ b/nac3core/irrt/irrt/slice.hpp @@ -0,0 +1,160 @@ +#pragma once + +#include +#include + +namespace { + +/** + * @brief A Python-like slice with resolved indices. + * + * "Resolved indices" means that `start` and `stop` must be positive and are + * bound to a known length. + */ +struct Slice { + SliceIndex start; + SliceIndex stop; + SliceIndex step; + + /** + * @brief Calculate and return the length / the number of the slice. + * + * If this were a Python range, this function would be `len(range(start, stop, step))`. + */ + SliceIndex len() { + SliceIndex diff = stop - start; + if (diff > 0 && step > 0) { + return ((diff - 1) / step) + 1; + } else if (diff < 0 && step < 0) { + return ((diff + 1) / step) + 1; + } else { + return 0; + } + } +}; + +namespace slice { +/** + * @brief Resolve a slice index under a given length like Python indexing. + * + * In Python, if you have a `list` of length 100, `list[-1]` resolves to + * `list[99]`, so `resolve_index_in_length_clamped(100, -1)` returns `99`. + * + * If `length` is 0, 0 is returned for any value of `index`. + * + * If `index` is out of bounds, clamps the returned value between `0` and + * `length - 1` (inclusive). + * + */ +SliceIndex resolve_index_in_length_clamped(SliceIndex length, + SliceIndex index) { + if (index < 0) { + return max(length + index, 0); + } else { + return min(length, index); + } +} + +const SliceIndex OUT_OF_BOUNDS = -1; + +/** + * @brief Like `resolve_index_in_length_clamped`, but returns `OUT_OF_BOUNDS` + * if `index` is out of bounds. + */ +SliceIndex resolve_index_in_length(SliceIndex length, SliceIndex index) { + SliceIndex resolved = index < 0 ? length + index : index; + if (0 <= resolved && resolved < length) { + return resolved; + } else { + return OUT_OF_BOUNDS; + } +} +} // namespace slice + +/** + * @brief A Python-like slice with **unresolved** indices. + */ +struct UserSlice { + bool start_defined; + SliceIndex start; + + bool stop_defined; + SliceIndex stop; + + bool step_defined; + SliceIndex step; + + UserSlice() { this->reset(); } + + void reset() { + this->start_defined = false; + this->stop_defined = false; + this->step_defined = false; + } + + void set_start(SliceIndex start) { + this->start_defined = true; + this->start = start; + } + + void set_stop(SliceIndex stop) { + this->stop_defined = true; + this->stop = stop; + } + + void set_step(SliceIndex step) { + this->step_defined = true; + this->step = step; + } + + /** + * @brief Resolve this slice. + * + * In Python, this would be `slice(start, stop, step).indices(length)`. + * + * @return A `Slice` with the resolved indices. + */ + Slice indices(SliceIndex length) { + Slice result; + + result.step = step_defined ? step : 1; + bool step_is_negative = result.step < 0; + + if (start_defined) { + result.start = + slice::resolve_index_in_length_clamped(length, start); + } else { + result.start = step_is_negative ? length - 1 : 0; + } + + if (stop_defined) { + result.stop = slice::resolve_index_in_length_clamped(length, stop); + } else { + result.stop = step_is_negative ? -1 : length; + } + + return result; + } + + /** + * @brief Like `.indices()` but with assertions. + */ + void indices_checked(ErrorContext* errctx, SliceIndex length, + Slice* result) { + if (length < 0) { + errctx->set_exception(errctx->exceptions->value_error, + "length should not be negative, got {0}", + length); + return; + } + + if (this->step_defined && this->step == 0) { + errctx->set_exception(errctx->exceptions->value_error, + "slice step cannot be zero"); + return; + } + + *result = this->indices(length); + } +}; +} // namespace \ No newline at end of file diff --git a/nac3core/irrt/irrt_everything.hpp b/nac3core/irrt/irrt_everything.hpp index f6558051..0471789b 100644 --- a/nac3core/irrt/irrt_everything.hpp +++ b/nac3core/irrt/irrt_everything.hpp @@ -6,4 +6,5 @@ #include #include #include +#include #include \ No newline at end of file diff --git a/nac3core/irrt/irrt_test.cpp b/nac3core/irrt/irrt_test.cpp index c8622565..4ba28d8e 100644 --- a/nac3core/irrt/irrt_test.cpp +++ b/nac3core/irrt/irrt_test.cpp @@ -6,9 +6,11 @@ #include #include #include +#include int main() { test::core::run(); + test::slice::run(); test::ndarray_basic::run(); return 0; } \ No newline at end of file diff --git a/nac3core/irrt/test/test_slice.hpp b/nac3core/irrt/test/test_slice.hpp new file mode 100644 index 00000000..8923cee5 --- /dev/null +++ b/nac3core/irrt/test/test_slice.hpp @@ -0,0 +1,92 @@ +#pragma once + +#include +#include + +namespace test { +namespace slice { +void test_slice_normal() { + // Normal situation + BEGIN_TEST(); + + UserSlice user_slice; + user_slice.set_stop(5); + + Slice slice = user_slice.indices(100); + + printf("%d, %d, %d\n", slice.start, slice.stop, slice.step); + + assert_values_match(0, slice.start); + assert_values_match(5, slice.stop); + assert_values_match(1, slice.step); +} + +void test_slice_start_too_large() { + // Start is too large and should be clamped to length + BEGIN_TEST(); + + UserSlice user_slice; + user_slice.set_start(400); + + Slice slice = user_slice.indices(100); + + assert_values_match(100, slice.start); + assert_values_match(100, slice.stop); + assert_values_match(1, slice.step); +} + +void test_slice_negative_start_stop() { + // Negative start/stop should be resolved + BEGIN_TEST(); + + UserSlice user_slice; + user_slice.set_start(-10); + user_slice.set_stop(-5); + + Slice slice = user_slice.indices(100); + + assert_values_match(90, slice.start); + assert_values_match(95, slice.stop); + assert_values_match(1, slice.step); +} + +void test_slice_only_negative_step() { + // Things like `[::-5]` should be handled correctly + BEGIN_TEST(); + + UserSlice user_slice; + user_slice.set_step(-5); + + Slice slice = user_slice.indices(100); + + assert_values_match(99, slice.start); + assert_values_match(-1, slice.stop); + assert_values_match(-5, slice.step); +} + +void test_slice_step_zero() { + // Step = 0 is a value error + BEGIN_TEST(); + + ErrorContext errctx = create_testing_errctx(); + + UserSlice user_slice; + user_slice.set_start(2); + user_slice.set_stop(12); + user_slice.set_step(0); + + Slice slice; + user_slice.indices_checked(&errctx, 100, &slice); + + assert_errctx_has_exception(&errctx, errctx.exceptions->value_error); +} + +void run() { + test_slice_normal(); + test_slice_start_too_large(); + test_slice_negative_start_stop(); + test_slice_only_negative_step(); + test_slice_step_zero(); +} +} // namespace slice +} // namespace test \ No newline at end of file