core/ndstrides: implement IRRT slice

Needed by ndarray indexing
This commit is contained in:
lyken 2024-07-26 16:31:28 +08:00
parent 2211c4d852
commit 4b14609342
4 changed files with 255 additions and 0 deletions

View File

@ -0,0 +1,160 @@
#pragma once
#include <irrt/int_defs.hpp>
#include <irrt/slice.hpp>
namespace {
/**
* @brief A Python-like slice with resolved indices.
*
* "Resolved indices" means that `start` and `stop` must be positive and are
* bound to a known length.
*/
struct Slice {
SliceIndex start;
SliceIndex stop;
SliceIndex step;
/**
* @brief Calculate and return the length / the number of the slice.
*
* If this were a Python range, this function would be `len(range(start, stop, step))`.
*/
SliceIndex len() {
SliceIndex diff = stop - start;
if (diff > 0 && step > 0) {
return ((diff - 1) / step) + 1;
} else if (diff < 0 && step < 0) {
return ((diff + 1) / step) + 1;
} else {
return 0;
}
}
};
namespace slice {
/**
* @brief Resolve a slice index under a given length like Python indexing.
*
* In Python, if you have a `list` of length 100, `list[-1]` resolves to
* `list[99]`, so `resolve_index_in_length_clamped(100, -1)` returns `99`.
*
* If `length` is 0, 0 is returned for any value of `index`.
*
* If `index` is out of bounds, clamps the returned value between `0` and
* `length - 1` (inclusive).
*
*/
SliceIndex resolve_index_in_length_clamped(SliceIndex length,
SliceIndex index) {
if (index < 0) {
return max<SliceIndex>(length + index, 0);
} else {
return min<SliceIndex>(length, index);
}
}
const SliceIndex OUT_OF_BOUNDS = -1;
/**
* @brief Like `resolve_index_in_length_clamped`, but returns `OUT_OF_BOUNDS`
* if `index` is out of bounds.
*/
SliceIndex resolve_index_in_length(SliceIndex length, SliceIndex index) {
SliceIndex resolved = index < 0 ? length + index : index;
if (0 <= resolved && resolved < length) {
return resolved;
} else {
return OUT_OF_BOUNDS;
}
}
} // namespace slice
/**
* @brief A Python-like slice with **unresolved** indices.
*/
struct UserSlice {
bool start_defined;
SliceIndex start;
bool stop_defined;
SliceIndex stop;
bool step_defined;
SliceIndex step;
UserSlice() { this->reset(); }
void reset() {
this->start_defined = false;
this->stop_defined = false;
this->step_defined = false;
}
void set_start(SliceIndex start) {
this->start_defined = true;
this->start = start;
}
void set_stop(SliceIndex stop) {
this->stop_defined = true;
this->stop = stop;
}
void set_step(SliceIndex step) {
this->step_defined = true;
this->step = step;
}
/**
* @brief Resolve this slice.
*
* In Python, this would be `slice(start, stop, step).indices(length)`.
*
* @return A `Slice` with the resolved indices.
*/
Slice indices(SliceIndex length) {
Slice result;
result.step = step_defined ? step : 1;
bool step_is_negative = result.step < 0;
if (start_defined) {
result.start =
slice::resolve_index_in_length_clamped(length, start);
} else {
result.start = step_is_negative ? length - 1 : 0;
}
if (stop_defined) {
result.stop = slice::resolve_index_in_length_clamped(length, stop);
} else {
result.stop = step_is_negative ? -1 : length;
}
return result;
}
/**
* @brief Like `.indices()` but with assertions.
*/
void indices_checked(ErrorContext* errctx, SliceIndex length,
Slice* result) {
if (length < 0) {
errctx->set_exception(errctx->exceptions->value_error,
"length should not be negative, got {0}",
length);
return;
}
if (this->step_defined && this->step == 0) {
errctx->set_exception(errctx->exceptions->value_error,
"slice step cannot be zero");
return;
}
*result = this->indices(length);
}
};
} // namespace

View File

@ -6,4 +6,5 @@
#include <irrt/int_defs.hpp>
#include <irrt/ndarray/basic.hpp>
#include <irrt/ndarray/def.hpp>
#include <irrt/slice.hpp>
#include <irrt/utils.hpp>

View File

@ -6,9 +6,11 @@
#include <cstdlib>
#include <test/test_core.hpp>
#include <test/test_ndarray_basic.hpp>
#include <test/test_slice.hpp>
int main() {
test::core::run();
test::slice::run();
test::ndarray_basic::run();
return 0;
}

View File

@ -0,0 +1,92 @@
#pragma once
#include <irrt_everything.hpp>
#include <test/includes.hpp>
namespace test {
namespace slice {
void test_slice_normal() {
// Normal situation
BEGIN_TEST();
UserSlice user_slice;
user_slice.set_stop(5);
Slice slice = user_slice.indices(100);
printf("%d, %d, %d\n", slice.start, slice.stop, slice.step);
assert_values_match(0, slice.start);
assert_values_match(5, slice.stop);
assert_values_match(1, slice.step);
}
void test_slice_start_too_large() {
// Start is too large and should be clamped to length
BEGIN_TEST();
UserSlice user_slice;
user_slice.set_start(400);
Slice slice = user_slice.indices(100);
assert_values_match(100, slice.start);
assert_values_match(100, slice.stop);
assert_values_match(1, slice.step);
}
void test_slice_negative_start_stop() {
// Negative start/stop should be resolved
BEGIN_TEST();
UserSlice user_slice;
user_slice.set_start(-10);
user_slice.set_stop(-5);
Slice slice = user_slice.indices(100);
assert_values_match(90, slice.start);
assert_values_match(95, slice.stop);
assert_values_match(1, slice.step);
}
void test_slice_only_negative_step() {
// Things like `[::-5]` should be handled correctly
BEGIN_TEST();
UserSlice user_slice;
user_slice.set_step(-5);
Slice slice = user_slice.indices(100);
assert_values_match(99, slice.start);
assert_values_match(-1, slice.stop);
assert_values_match(-5, slice.step);
}
void test_slice_step_zero() {
// Step = 0 is a value error
BEGIN_TEST();
ErrorContext errctx = create_testing_errctx();
UserSlice user_slice;
user_slice.set_start(2);
user_slice.set_stop(12);
user_slice.set_step(0);
Slice slice;
user_slice.indices_checked(&errctx, 100, &slice);
assert_errctx_has_exception(&errctx, errctx.exceptions->value_error);
}
void run() {
test_slice_normal();
test_slice_start_too_large();
test_slice_negative_start_stop();
test_slice_only_negative_step();
test_slice_step_zero();
}
} // namespace slice
} // namespace test