1
0
forked from M-Labs/nac3

core/irrt: add Slice and Range

Needed for implementing general ndarray indexing.

Currently the IRRT slice and range have nothing to do with NAC3's slice
and range.
This commit is contained in:
lyken 2024-08-24 15:37:45 +08:00
parent ad5afb52c4
commit 5537645395
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
3 changed files with 202 additions and 1 deletions

View File

@ -4,4 +4,6 @@
#include <irrt/ndarray/basic.hpp> #include <irrt/ndarray/basic.hpp>
#include <irrt/ndarray/def.hpp> #include <irrt/ndarray/def.hpp>
#include <irrt/ndarray/iter.hpp> #include <irrt/ndarray/iter.hpp>
#include <irrt/original.hpp> #include <irrt/original.hpp>
#include <irrt/range.hpp>
#include <irrt/slice.hpp>

View File

@ -0,0 +1,41 @@
#pragma once
#include <irrt/debug.hpp>
#include <irrt/int_types.hpp>
namespace
{
namespace range
{
template <typename T> T len(T start, T stop, T step)
{
// Reference:
// https://github.com/python/cpython/blob/9dbd12375561a393eaec4b21ee4ac568a407cdb0/Objects/rangeobject.c#L933
if (step > 0 && start < stop)
return 1 + (stop - 1 - start) / step;
else if (step < 0 && start > stop)
return 1 + (start - 1 - stop) / (-step);
else
return 0;
}
} // namespace range
/**
* @brief A Python range.
*/
template <typename T> struct Range
{
T start;
T stop;
T step;
/**
* @brief Calculate the `len()` of this range.
*/
template <typename SizeT> T len()
{
debug_assert(SizeT, step != 0);
return range::len(start, stop, step);
}
};
} // namespace

View File

@ -0,0 +1,158 @@
#pragma once
#include <irrt/debug.hpp>
#include <irrt/exception.hpp>
#include <irrt/int_types.hpp>
#include <irrt/math_util.hpp>
#include <irrt/range.hpp>
namespace
{
namespace slice
{
/**
* @brief Resolve a possibly negative index in a list of a known length.
*
* Returns -1 if the resolved index is out of the list's bounds.
*/
template <typename T> T resolve_index_in_length(T length, T index)
{
T resolved = index < 0 ? length + index : index;
if (0 <= resolved && resolved < length)
{
return resolved;
}
else
{
return -1;
}
}
/**
* @brief Resolve a slice as a range.
*
* This is equivalent to `range(*slice(start, stop, step).indices(length))` in Python.
*/
template <typename T>
void indices(bool start_defined, T start, bool stop_defined, T stop, bool step_defined, T step, T length,
T *range_start, T *range_stop, T *range_step)
{
// Reference: https://github.com/python/cpython/blob/main/Objects/sliceobject.c#L388
*range_step = step_defined ? step : 1;
bool step_is_negative = *range_step < 0;
T lower, upper;
if (step_is_negative)
{
lower = -1;
upper = length - 1;
}
else
{
lower = 0;
upper = length;
}
if (start_defined)
{
*range_start = start < 0 ? max(lower, start + length) : min(upper, start);
}
else
{
*range_start = step_is_negative ? upper : lower;
}
if (stop_defined)
{
*range_stop = stop < 0 ? max(lower, stop + length) : min(upper, stop);
}
else
{
*range_stop = step_is_negative ? lower : upper;
}
}
} // namespace slice
/**
* @brief A Python-like slice with **unresolved** indices.
*/
template <typename T> struct Slice
{
bool start_defined;
T start;
bool stop_defined;
T stop;
bool step_defined;
T step;
Slice()
{
this->reset();
}
void reset()
{
this->start_defined = false;
this->stop_defined = false;
this->step_defined = false;
}
void set_start(T start)
{
this->start_defined = true;
this->start = start;
}
void set_stop(T stop)
{
this->stop_defined = true;
this->stop = stop;
}
void set_step(T step)
{
this->step_defined = true;
this->step = step;
}
/**
* @brief Resolve this slice as a range.
*
* In Python, this would be `range(*slice(start, stop, step).indices(length))`.
*/
template <typename SizeT> Range<T> indices(T length)
{
// Reference:
// https://github.com/python/cpython/blob/main/Objects/sliceobject.c#L388
debug_assert(SizeT, length >= 0);
Range<T> result;
slice::indices(start_defined, start, stop_defined, stop, step_defined, step, length, &result.start,
&result.stop, &result.step);
return result;
}
/**
* @brief Like `.indices()` but with assertions.
*/
template <typename SizeT> Range<T> indices_checked(T length)
{
// TODO: Switch to `SizeT length`
if (length < 0)
{
raise_exception(SizeT, EXN_VALUE_ERROR, "length should not be negative, got {0}", length, NO_PARAM,
NO_PARAM);
}
if (this->step_defined && this->step == 0)
{
raise_exception(SizeT, EXN_VALUE_ERROR, "slice step cannot be zero", NO_PARAM, NO_PARAM, NO_PARAM);
}
return this->indices<SizeT>(length);
}
};
} // namespace