forked from M-Labs/nac3
core/irrt: add Slice and Range
Needed for implementing general ndarray indexing
This commit is contained in:
parent
5411ac5c88
commit
bda003989e
|
@ -4,4 +4,5 @@
|
||||||
#include <irrt/ndarray/basic.hpp>
|
#include <irrt/ndarray/basic.hpp>
|
||||||
#include <irrt/ndarray/def.hpp>
|
#include <irrt/ndarray/def.hpp>
|
||||||
#include <irrt/ndarray/iter.hpp>
|
#include <irrt/ndarray/iter.hpp>
|
||||||
#include <irrt/original.hpp>
|
#include <irrt/original.hpp>
|
||||||
|
#include <irrt/slice.hpp>
|
|
@ -0,0 +1,200 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <irrt/debug.hpp>
|
||||||
|
#include <irrt/exception.hpp>
|
||||||
|
#include <irrt/int_types.hpp>
|
||||||
|
#include <irrt/math_util.hpp>
|
||||||
|
|
||||||
|
// The type of an index or a value describing the length of a
|
||||||
|
// range/slice is always `int32_t`.
|
||||||
|
using SliceIndex = int32_t;
|
||||||
|
|
||||||
|
namespace
|
||||||
|
{
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief A Python range.
|
||||||
|
*/
|
||||||
|
struct Range
|
||||||
|
{
|
||||||
|
SliceIndex start;
|
||||||
|
SliceIndex stop;
|
||||||
|
SliceIndex step;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Calculate the `len()` of this range.
|
||||||
|
*/
|
||||||
|
template <typename SizeT> SliceIndex len()
|
||||||
|
{
|
||||||
|
// Reference: https://github.com/python/cpython/blob/9dbd12375561a393eaec4b21ee4ac568a407cdb0/Objects/rangeobject.c#L933
|
||||||
|
debug_assert(SizeT, step != 0);
|
||||||
|
|
||||||
|
if (step > 0 && start < stop)
|
||||||
|
return 1 + (stop - 1 - start) / step;
|
||||||
|
else if (step < 0 && start > stop)
|
||||||
|
return 1 + (start - 1 - stop) / (-step);
|
||||||
|
else
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
namespace slice
|
||||||
|
{
|
||||||
|
/**
|
||||||
|
* @brief Resolve a slice index under a given length like Python indexing.
|
||||||
|
*
|
||||||
|
* In Python, if you have a `list` of length 100, `list[-1]` resolves to
|
||||||
|
* `list[99]`, so `resolve_index_in_length_clamped(100, -1)` returns `99`.
|
||||||
|
*
|
||||||
|
* If `length` is 0, 0 is returned for any value of `index`.
|
||||||
|
*
|
||||||
|
* If `index` is out of bounds, clamps the returned value between `0` and
|
||||||
|
* `length - 1` (inclusive).
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
SliceIndex resolve_index_in_length_clamped(SliceIndex length, SliceIndex index)
|
||||||
|
{
|
||||||
|
if (index < 0)
|
||||||
|
{
|
||||||
|
return max<SliceIndex>(length + index, 0);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
return min<SliceIndex>(length, index);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const SliceIndex OUT_OF_BOUNDS = -1;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Like `resolve_index_in_length_clamped`, but returns `OUT_OF_BOUNDS`
|
||||||
|
* if `index` is out of bounds.
|
||||||
|
*/
|
||||||
|
SliceIndex resolve_index_in_length(SliceIndex length, SliceIndex index)
|
||||||
|
{
|
||||||
|
SliceIndex resolved = index < 0 ? length + index : index;
|
||||||
|
if (0 <= resolved && resolved < length)
|
||||||
|
{
|
||||||
|
return resolved;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
return OUT_OF_BOUNDS;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace slice
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief A Python-like slice with **unresolved** indices.
|
||||||
|
*/
|
||||||
|
struct Slice
|
||||||
|
{
|
||||||
|
bool start_defined;
|
||||||
|
SliceIndex start;
|
||||||
|
|
||||||
|
bool stop_defined;
|
||||||
|
SliceIndex stop;
|
||||||
|
|
||||||
|
bool step_defined;
|
||||||
|
SliceIndex step;
|
||||||
|
|
||||||
|
Slice()
|
||||||
|
{
|
||||||
|
this->reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
void reset()
|
||||||
|
{
|
||||||
|
this->start_defined = false;
|
||||||
|
this->stop_defined = false;
|
||||||
|
this->step_defined = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_start(SliceIndex start)
|
||||||
|
{
|
||||||
|
this->start_defined = true;
|
||||||
|
this->start = start;
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_stop(SliceIndex stop)
|
||||||
|
{
|
||||||
|
this->stop_defined = true;
|
||||||
|
this->stop = stop;
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_step(SliceIndex step)
|
||||||
|
{
|
||||||
|
this->step_defined = true;
|
||||||
|
this->step = step;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Resolve this slice as a range.
|
||||||
|
*
|
||||||
|
* In Python, this would be `range(*slice(start, stop, step).indices(length))`.
|
||||||
|
*/
|
||||||
|
template <typename SizeT> Range indices(SliceIndex length)
|
||||||
|
{
|
||||||
|
// Reference: https://github.com/python/cpython/blob/main/Objects/sliceobject.c#L388
|
||||||
|
debug_assert(SizeT, length >= 0);
|
||||||
|
|
||||||
|
Range result;
|
||||||
|
|
||||||
|
result.step = step_defined ? step : 1;
|
||||||
|
bool step_is_negative = result.step < 0;
|
||||||
|
|
||||||
|
SliceIndex lower, upper;
|
||||||
|
if (step_is_negative)
|
||||||
|
{
|
||||||
|
lower = -1;
|
||||||
|
upper = length - 1;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
lower = 0;
|
||||||
|
upper = length;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (start_defined)
|
||||||
|
{
|
||||||
|
result.start = start < 0 ? max(lower, start + length) : min(upper, start);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
result.start = step_is_negative ? upper : lower;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (stop_defined)
|
||||||
|
{
|
||||||
|
result.stop = stop < 0 ? max(lower, stop + length) : min(upper, stop);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
result.stop = step_is_negative ? lower : upper;
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Like `.indices()` but with assertions.
|
||||||
|
*/
|
||||||
|
template <typename SizeT> Range indices_checked(SliceIndex length)
|
||||||
|
{
|
||||||
|
// TODO: Switch to `SizeT length`
|
||||||
|
|
||||||
|
if (length < 0)
|
||||||
|
{
|
||||||
|
raise_exception(SizeT, EXN_VALUE_ERROR, "length should not be negative, got {0}", length, NO_PARAM,
|
||||||
|
NO_PARAM);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (this->step_defined && this->step == 0)
|
||||||
|
{
|
||||||
|
raise_exception(SizeT, EXN_VALUE_ERROR, "slice step cannot be zero", NO_PARAM, NO_PARAM, NO_PARAM);
|
||||||
|
}
|
||||||
|
|
||||||
|
return this->indices<SizeT>(length);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
Loading…
Reference in New Issue