core/irrt: add Slice and ResolvedSlice

Needed for implementing general ndarray indexing
This commit is contained in:
lyken 2024-08-20 12:33:31 +08:00
parent 9f94e613f1
commit 49b5c92f15
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
2 changed files with 197 additions and 1 deletions

View File

@ -5,3 +5,4 @@
#include <irrt/ndarray/def.hpp>
#include <irrt/ndarray/iter.hpp>
#include <irrt/original.hpp>
#include <irrt/slice.hpp>

View File

@ -0,0 +1,195 @@
#pragma once
#include <irrt/exception.hpp>
#include <irrt/int_types.hpp>
#include <irrt/math_util.hpp>
// The type of an index or a value describing the length of a
// range/slice is always `int32_t`.
using SliceIndex = int32_t;
namespace
{
/**
* @brief A Python-like slice with resolved indices.
*
* "Resolved indices" means that `start` and `stop` must be positive and are
* bound to a known length.
*/
struct ResolvedSlice
{
SliceIndex start;
SliceIndex stop;
SliceIndex step;
/**
* @brief Calculate and return the length / the number of the slice.
*
* If this were a Python range, this function would be `len(range(start, stop, step))`.
*/
SliceIndex len()
{
SliceIndex diff = stop - start;
if (diff > 0 && step > 0)
{
return ((diff - 1) / step) + 1;
}
else if (diff < 0 && step < 0)
{
return ((diff + 1) / step) + 1;
}
else
{
return 0;
}
}
};
namespace slice
{
/**
* @brief Resolve a slice index under a given length like Python indexing.
*
* In Python, if you have a `list` of length 100, `list[-1]` resolves to
* `list[99]`, so `resolve_index_in_length_clamped(100, -1)` returns `99`.
*
* If `length` is 0, 0 is returned for any value of `index`.
*
* If `index` is out of bounds, clamps the returned value between `0` and
* `length - 1` (inclusive).
*
*/
SliceIndex resolve_index_in_length_clamped(SliceIndex length, SliceIndex index)
{
if (index < 0)
{
return max<SliceIndex>(length + index, 0);
}
else
{
return min<SliceIndex>(length, index);
}
}
const SliceIndex OUT_OF_BOUNDS = -1;
/**
* @brief Like `resolve_index_in_length_clamped`, but returns `OUT_OF_BOUNDS`
* if `index` is out of bounds.
*/
SliceIndex resolve_index_in_length(SliceIndex length, SliceIndex index)
{
SliceIndex resolved = index < 0 ? length + index : index;
if (0 <= resolved && resolved < length)
{
return resolved;
}
else
{
return OUT_OF_BOUNDS;
}
}
} // namespace slice
/**
* @brief A Python-like slice with **unresolved** indices.
*/
struct Slice
{
bool start_defined;
SliceIndex start;
bool stop_defined;
SliceIndex stop;
bool step_defined;
SliceIndex step;
Slice()
{
this->reset();
}
void reset()
{
this->start_defined = false;
this->stop_defined = false;
this->step_defined = false;
}
void set_start(SliceIndex start)
{
this->start_defined = true;
this->start = start;
}
void set_stop(SliceIndex stop)
{
this->stop_defined = true;
this->stop = stop;
}
void set_step(SliceIndex step)
{
this->step_defined = true;
this->step = step;
}
/**
* @brief Resolve this slice.
*
* In Python, this would be `slice(start, stop, step).indices(length)`.
*
* @return A `ResolvedSlice` with the resolved indices.
*/
ResolvedSlice indices(SliceIndex length)
{
ResolvedSlice result;
result.step = step_defined ? step : 1;
bool step_is_negative = result.step < 0;
if (start_defined)
{
result.start = slice::resolve_index_in_length_clamped(length, start);
}
else
{
result.start = step_is_negative ? length - 1 : 0;
}
if (stop_defined)
{
result.stop = slice::resolve_index_in_length_clamped(length, stop);
}
else
{
result.stop = step_is_negative ? -1 : length;
}
return result;
}
/**
* @brief Like `.indices()` but with assertions.
*/
template <typename SizeT> ResolvedSlice indices_checked(SliceIndex length)
{
// TODO: Switch to `SizeT length`
if (length < 0)
{
raise_exception(SizeT, EXN_VALUE_ERROR, "length should not be negative, got {0}", length, NO_PARAM,
NO_PARAM);
}
if (this->step_defined && this->step == 0)
{
raise_exception(SizeT, EXN_VALUE_ERROR, "slice step cannot be zero", NO_PARAM, NO_PARAM, NO_PARAM);
}
return this->indices(length);
}
};
} // namespace