1
0
forked from M-Labs/nac3

core: make IRRT slice and range work for any int types

This commit is contained in:
lyken 2024-08-23 09:52:08 +08:00
parent ac6c7c5985
commit 3782791323
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
4 changed files with 52 additions and 57 deletions

View File

@ -13,14 +13,14 @@ typedef uint8_t NDIndexType;
/**
* @brief A single element index
*
* `data` points to a `SliceIndex`.
* `data` points to a `int32_t`.
*/
const NDIndexType ND_INDEX_TYPE_SINGLE_ELEMENT = 0;
/**
* @brief A slice index
*
* `data` points to a `Slice`.
* `data` points to a `Slice<int32_t>`.
*/
const NDIndexType ND_INDEX_TYPE_SLICE = 1;
@ -155,15 +155,15 @@ void index(SizeT num_indices, const NDIndex *indices, const NDArray<SizeT> *src_
SizeT src_axis = 0;
SizeT dst_axis = 0;
for (SliceIndex i = 0; i < num_indices; i++)
for (int32_t i = 0; i < num_indices; i++)
{
const NDIndex *index = &indices[i];
if (index->type == ND_INDEX_TYPE_SINGLE_ELEMENT)
{
SliceIndex input = *((SliceIndex *)index->data);
SliceIndex k = slice::resolve_index_in_length(src_ndarray->shape[src_axis], input);
SizeT input = (SizeT) * ((int32_t *)index->data);
SizeT k = slice::resolve_index_in_length(src_ndarray->shape[src_axis], input);
if (k == slice::OUT_OF_BOUNDS)
if (k == -1)
{
raise_exception(SizeT, EXN_INDEX_ERROR,
"index {0} is out of bounds for axis {1} "
@ -177,9 +177,9 @@ void index(SizeT num_indices, const NDIndex *indices, const NDArray<SizeT> *src_
}
else if (index->type == ND_INDEX_TYPE_SLICE)
{
Slice *slice = (Slice *)index->data;
Slice<int32_t> *slice = (Slice<int32_t> *)index->data;
Range range = slice->indices_checked<SizeT>(src_ndarray->shape[src_axis]);
Range<int32_t> range = slice->indices_checked<SizeT>(src_ndarray->shape[src_axis]);
dst_ndarray->data += (SizeT)range.start * src_ndarray->strides[src_axis];
dst_ndarray->strides[dst_axis] = ((SizeT)range.step) * src_ndarray->strides[src_axis];

View File

@ -45,7 +45,7 @@ template <typename SizeT> void assert_transpose_axes(SizeT ndims, SizeT num_axes
for (SizeT i = 0; i < ndims; i++)
{
SizeT axis = slice::resolve_index_in_length(ndims, axes[i]);
if (axis == slice::OUT_OF_BOUNDS)
if (axis == -1)
{
// TODO: numpy actually throws a `numpy.exceptions.AxisError`
raise_exception(SizeT, EXN_VALUE_ERROR, "axis {0} is out of bounds for array of dimension {1}", axis, ndims,

View File

@ -5,26 +5,22 @@
#include <irrt/int_types.hpp>
#include <irrt/math_util.hpp>
// The type of an index or a value describing the length of a
// range/slice is always `int32_t`.
using SliceIndex = int32_t;
namespace
{
/**
* @brief A Python range.
*/
struct Range
template <typename T> struct Range
{
SliceIndex start;
SliceIndex stop;
SliceIndex step;
T start;
T stop;
T step;
/**
* @brief Calculate the `len()` of this range.
*/
template <typename SizeT> SliceIndex len()
template <typename SizeT> T len()
{
// Reference: https://github.com/python/cpython/blob/9dbd12375561a393eaec4b21ee4ac568a407cdb0/Objects/rangeobject.c#L933
debug_assert(SizeT, step != 0);
@ -52,34 +48,31 @@ namespace slice
* `length - 1` (inclusive).
*
*/
SliceIndex resolve_index_in_length_clamped(SliceIndex length, SliceIndex index)
template <typename T> T resolve_index_in_length_clamped(T length, T index)
{
if (index < 0)
{
return max<SliceIndex>(length + index, 0);
return max<T>(length + index, 0);
}
else
{
return min<SliceIndex>(length, index);
return min<T>(length, index);
}
}
const SliceIndex OUT_OF_BOUNDS = -1;
/**
* @brief Like `resolve_index_in_length_clamped`, but returns `OUT_OF_BOUNDS`
* if `index` is out of bounds.
* @brief Like `resolve_index_in_length_clamped`, but returns `-1` if `index` is out of bounds.
*/
SliceIndex resolve_index_in_length(SliceIndex length, SliceIndex index)
template <typename T> T resolve_index_in_length(T length, T index)
{
SliceIndex resolved = index < 0 ? length + index : index;
T resolved = index < 0 ? length + index : index;
if (0 <= resolved && resolved < length)
{
return resolved;
}
else
{
return OUT_OF_BOUNDS;
return -1;
}
}
} // namespace slice
@ -87,16 +80,16 @@ SliceIndex resolve_index_in_length(SliceIndex length, SliceIndex index)
/**
* @brief A Python-like slice with **unresolved** indices.
*/
struct Slice
template <typename T> struct Slice
{
bool start_defined;
SliceIndex start;
T start;
bool stop_defined;
SliceIndex stop;
T stop;
bool step_defined;
SliceIndex step;
T step;
Slice()
{
@ -110,19 +103,19 @@ struct Slice
this->step_defined = false;
}
void set_start(SliceIndex start)
void set_start(T start)
{
this->start_defined = true;
this->start = start;
}
void set_stop(SliceIndex stop)
void set_stop(T stop)
{
this->stop_defined = true;
this->stop = stop;
}
void set_step(SliceIndex step)
void set_step(T step)
{
this->step_defined = true;
this->step = step;
@ -133,17 +126,17 @@ struct Slice
*
* In Python, this would be `range(*slice(start, stop, step).indices(length))`.
*/
template <typename SizeT> Range indices(SliceIndex length)
template <typename SizeT> Range<T> indices(T length)
{
// Reference: https://github.com/python/cpython/blob/main/Objects/sliceobject.c#L388
debug_assert(SizeT, length >= 0);
Range result;
Range<T> result;
result.step = step_defined ? step : 1;
bool step_is_negative = result.step < 0;
SliceIndex lower, upper;
T lower, upper;
if (step_is_negative)
{
lower = -1;
@ -179,7 +172,7 @@ struct Slice
/**
* @brief Like `.indices()` but with assertions.
*/
template <typename SizeT> Range indices_checked(SliceIndex length)
template <typename SizeT> Range<T> indices_checked(T length)
{
// TODO: Switch to `SizeT length`

View File

@ -25,49 +25,51 @@ impl<'ctx> StructKind<'ctx> for NDIndex {
/// Fields of [`Slice`]
#[derive(Debug, Clone)]
pub struct SliceFields<'ctx, F: FieldTraversal<'ctx>> {
pub struct SliceFields<'ctx, F: FieldTraversal<'ctx>, N: IntKind<'ctx>> {
pub start_defined: F::Out<Int<Bool>>,
pub start: F::Out<Int<Int32>>,
pub start: F::Out<Int<N>>,
pub stop_defined: F::Out<Int<Bool>>,
pub stop: F::Out<Int<Int32>>,
pub stop: F::Out<Int<N>>,
pub step_defined: F::Out<Int<Bool>>,
pub step: F::Out<Int<Int32>>,
pub step: F::Out<Int<N>>,
}
/// An IRRT representation of an (unresolved) slice.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Slice;
pub struct Slice<N> {
int_kind: N,
}
impl<'ctx> StructKind<'ctx> for Slice {
type Fields<F: FieldTraversal<'ctx>> = SliceFields<'ctx, F>;
impl<'ctx, N: IntKind<'ctx>> StructKind<'ctx> for Slice<N> {
type Fields<F: FieldTraversal<'ctx>> = SliceFields<'ctx, F, N>;
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
Self::Fields {
start_defined: traversal.add_auto("start_defined"),
start: traversal.add_auto("start"),
start: traversal.add("start", Int(self.int_kind)),
stop_defined: traversal.add_auto("stop_defined"),
stop: traversal.add_auto("stop"),
stop: traversal.add("stop", Int(self.int_kind)),
step_defined: traversal.add_auto("step_defined"),
step: traversal.add_auto("step"),
step: traversal.add("step", Int(self.int_kind)),
}
}
}
/// A convenience structure to prepare a [`Slice`].
#[derive(Debug, Clone)]
pub struct RustSlice<'ctx> {
pub start: Option<Instance<'ctx, Int<Int32>>>,
pub stop: Option<Instance<'ctx, Int<Int32>>>,
pub step: Option<Instance<'ctx, Int<Int32>>>,
pub struct RustSlice<'ctx, N: IntKind<'ctx>> {
pub start: Option<Instance<'ctx, Int<N>>>,
pub stop: Option<Instance<'ctx, Int<N>>>,
pub step: Option<Instance<'ctx, Int<N>>>,
}
impl<'ctx> RustSlice<'ctx> {
impl<'ctx, N: IntKind<'ctx>> RustSlice<'ctx, N> {
/// Write the contents to an LLVM [`Slice`].
pub fn write_to_slice<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
dst_slice_ptr: Instance<'ctx, Ptr<Struct<Slice>>>,
dst_slice_ptr: Instance<'ctx, Ptr<Struct<Slice<N>>>>,
) {
let false_ = Int(Bool).const_false(generator, ctx.ctx);
let true_ = Int(Bool).const_true(generator, ctx.ctx);
@ -102,7 +104,7 @@ impl<'ctx> RustSlice<'ctx> {
#[derive(Debug, Clone)]
pub enum RustNDIndex<'ctx> {
SingleElement(Instance<'ctx, Int<Int32>>), // TODO: To be SizeT
Slice(RustSlice<'ctx>),
Slice(RustSlice<'ctx, Int32>),
NewAxis,
Ellipsis,
}
@ -143,7 +145,7 @@ impl<'ctx> RustNDIndex<'ctx> {
.store(ctx, index_ptr.pointer_cast(generator, ctx, Int(Byte)));
}
RustNDIndex::Slice(in_rust_slice) => {
let user_slice_ptr = Struct(Slice).alloca(generator, ctx);
let user_slice_ptr = Struct(Slice { int_kind: Int32 }).alloca(generator, ctx);
in_rust_slice.write_to_slice(generator, ctx, user_slice_ptr);
dst_ndindex_ptr