forked from M-Labs/nac3
1
0
Fork 0

core: make IRRT slice and range work for any int types

This commit is contained in:
lyken 2024-08-23 09:52:08 +08:00
parent ac6c7c5985
commit 3782791323
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
4 changed files with 52 additions and 57 deletions

View File

@ -13,14 +13,14 @@ typedef uint8_t NDIndexType;
/** /**
* @brief A single element index * @brief A single element index
* *
* `data` points to a `SliceIndex`. * `data` points to a `int32_t`.
*/ */
const NDIndexType ND_INDEX_TYPE_SINGLE_ELEMENT = 0; const NDIndexType ND_INDEX_TYPE_SINGLE_ELEMENT = 0;
/** /**
* @brief A slice index * @brief A slice index
* *
* `data` points to a `Slice`. * `data` points to a `Slice<int32_t>`.
*/ */
const NDIndexType ND_INDEX_TYPE_SLICE = 1; const NDIndexType ND_INDEX_TYPE_SLICE = 1;
@ -155,15 +155,15 @@ void index(SizeT num_indices, const NDIndex *indices, const NDArray<SizeT> *src_
SizeT src_axis = 0; SizeT src_axis = 0;
SizeT dst_axis = 0; SizeT dst_axis = 0;
for (SliceIndex i = 0; i < num_indices; i++) for (int32_t i = 0; i < num_indices; i++)
{ {
const NDIndex *index = &indices[i]; const NDIndex *index = &indices[i];
if (index->type == ND_INDEX_TYPE_SINGLE_ELEMENT) if (index->type == ND_INDEX_TYPE_SINGLE_ELEMENT)
{ {
SliceIndex input = *((SliceIndex *)index->data); SizeT input = (SizeT) * ((int32_t *)index->data);
SliceIndex k = slice::resolve_index_in_length(src_ndarray->shape[src_axis], input); SizeT k = slice::resolve_index_in_length(src_ndarray->shape[src_axis], input);
if (k == slice::OUT_OF_BOUNDS) if (k == -1)
{ {
raise_exception(SizeT, EXN_INDEX_ERROR, raise_exception(SizeT, EXN_INDEX_ERROR,
"index {0} is out of bounds for axis {1} " "index {0} is out of bounds for axis {1} "
@ -177,9 +177,9 @@ void index(SizeT num_indices, const NDIndex *indices, const NDArray<SizeT> *src_
} }
else if (index->type == ND_INDEX_TYPE_SLICE) else if (index->type == ND_INDEX_TYPE_SLICE)
{ {
Slice *slice = (Slice *)index->data; Slice<int32_t> *slice = (Slice<int32_t> *)index->data;
Range range = slice->indices_checked<SizeT>(src_ndarray->shape[src_axis]); Range<int32_t> range = slice->indices_checked<SizeT>(src_ndarray->shape[src_axis]);
dst_ndarray->data += (SizeT)range.start * src_ndarray->strides[src_axis]; dst_ndarray->data += (SizeT)range.start * src_ndarray->strides[src_axis];
dst_ndarray->strides[dst_axis] = ((SizeT)range.step) * src_ndarray->strides[src_axis]; dst_ndarray->strides[dst_axis] = ((SizeT)range.step) * src_ndarray->strides[src_axis];

View File

@ -45,7 +45,7 @@ template <typename SizeT> void assert_transpose_axes(SizeT ndims, SizeT num_axes
for (SizeT i = 0; i < ndims; i++) for (SizeT i = 0; i < ndims; i++)
{ {
SizeT axis = slice::resolve_index_in_length(ndims, axes[i]); SizeT axis = slice::resolve_index_in_length(ndims, axes[i]);
if (axis == slice::OUT_OF_BOUNDS) if (axis == -1)
{ {
// TODO: numpy actually throws a `numpy.exceptions.AxisError` // TODO: numpy actually throws a `numpy.exceptions.AxisError`
raise_exception(SizeT, EXN_VALUE_ERROR, "axis {0} is out of bounds for array of dimension {1}", axis, ndims, raise_exception(SizeT, EXN_VALUE_ERROR, "axis {0} is out of bounds for array of dimension {1}", axis, ndims,

View File

@ -5,26 +5,22 @@
#include <irrt/int_types.hpp> #include <irrt/int_types.hpp>
#include <irrt/math_util.hpp> #include <irrt/math_util.hpp>
// The type of an index or a value describing the length of a
// range/slice is always `int32_t`.
using SliceIndex = int32_t;
namespace namespace
{ {
/** /**
* @brief A Python range. * @brief A Python range.
*/ */
struct Range template <typename T> struct Range
{ {
SliceIndex start; T start;
SliceIndex stop; T stop;
SliceIndex step; T step;
/** /**
* @brief Calculate the `len()` of this range. * @brief Calculate the `len()` of this range.
*/ */
template <typename SizeT> SliceIndex len() template <typename SizeT> T len()
{ {
// Reference: https://github.com/python/cpython/blob/9dbd12375561a393eaec4b21ee4ac568a407cdb0/Objects/rangeobject.c#L933 // Reference: https://github.com/python/cpython/blob/9dbd12375561a393eaec4b21ee4ac568a407cdb0/Objects/rangeobject.c#L933
debug_assert(SizeT, step != 0); debug_assert(SizeT, step != 0);
@ -52,34 +48,31 @@ namespace slice
* `length - 1` (inclusive). * `length - 1` (inclusive).
* *
*/ */
SliceIndex resolve_index_in_length_clamped(SliceIndex length, SliceIndex index) template <typename T> T resolve_index_in_length_clamped(T length, T index)
{ {
if (index < 0) if (index < 0)
{ {
return max<SliceIndex>(length + index, 0); return max<T>(length + index, 0);
} }
else else
{ {
return min<SliceIndex>(length, index); return min<T>(length, index);
} }
} }
const SliceIndex OUT_OF_BOUNDS = -1;
/** /**
* @brief Like `resolve_index_in_length_clamped`, but returns `OUT_OF_BOUNDS` * @brief Like `resolve_index_in_length_clamped`, but returns `-1` if `index` is out of bounds.
* if `index` is out of bounds.
*/ */
SliceIndex resolve_index_in_length(SliceIndex length, SliceIndex index) template <typename T> T resolve_index_in_length(T length, T index)
{ {
SliceIndex resolved = index < 0 ? length + index : index; T resolved = index < 0 ? length + index : index;
if (0 <= resolved && resolved < length) if (0 <= resolved && resolved < length)
{ {
return resolved; return resolved;
} }
else else
{ {
return OUT_OF_BOUNDS; return -1;
} }
} }
} // namespace slice } // namespace slice
@ -87,16 +80,16 @@ SliceIndex resolve_index_in_length(SliceIndex length, SliceIndex index)
/** /**
* @brief A Python-like slice with **unresolved** indices. * @brief A Python-like slice with **unresolved** indices.
*/ */
struct Slice template <typename T> struct Slice
{ {
bool start_defined; bool start_defined;
SliceIndex start; T start;
bool stop_defined; bool stop_defined;
SliceIndex stop; T stop;
bool step_defined; bool step_defined;
SliceIndex step; T step;
Slice() Slice()
{ {
@ -110,19 +103,19 @@ struct Slice
this->step_defined = false; this->step_defined = false;
} }
void set_start(SliceIndex start) void set_start(T start)
{ {
this->start_defined = true; this->start_defined = true;
this->start = start; this->start = start;
} }
void set_stop(SliceIndex stop) void set_stop(T stop)
{ {
this->stop_defined = true; this->stop_defined = true;
this->stop = stop; this->stop = stop;
} }
void set_step(SliceIndex step) void set_step(T step)
{ {
this->step_defined = true; this->step_defined = true;
this->step = step; this->step = step;
@ -133,17 +126,17 @@ struct Slice
* *
* In Python, this would be `range(*slice(start, stop, step).indices(length))`. * In Python, this would be `range(*slice(start, stop, step).indices(length))`.
*/ */
template <typename SizeT> Range indices(SliceIndex length) template <typename SizeT> Range<T> indices(T length)
{ {
// Reference: https://github.com/python/cpython/blob/main/Objects/sliceobject.c#L388 // Reference: https://github.com/python/cpython/blob/main/Objects/sliceobject.c#L388
debug_assert(SizeT, length >= 0); debug_assert(SizeT, length >= 0);
Range result; Range<T> result;
result.step = step_defined ? step : 1; result.step = step_defined ? step : 1;
bool step_is_negative = result.step < 0; bool step_is_negative = result.step < 0;
SliceIndex lower, upper; T lower, upper;
if (step_is_negative) if (step_is_negative)
{ {
lower = -1; lower = -1;
@ -179,7 +172,7 @@ struct Slice
/** /**
* @brief Like `.indices()` but with assertions. * @brief Like `.indices()` but with assertions.
*/ */
template <typename SizeT> Range indices_checked(SliceIndex length) template <typename SizeT> Range<T> indices_checked(T length)
{ {
// TODO: Switch to `SizeT length` // TODO: Switch to `SizeT length`

View File

@ -25,49 +25,51 @@ impl<'ctx> StructKind<'ctx> for NDIndex {
/// Fields of [`Slice`] /// Fields of [`Slice`]
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct SliceFields<'ctx, F: FieldTraversal<'ctx>> { pub struct SliceFields<'ctx, F: FieldTraversal<'ctx>, N: IntKind<'ctx>> {
pub start_defined: F::Out<Int<Bool>>, pub start_defined: F::Out<Int<Bool>>,
pub start: F::Out<Int<Int32>>, pub start: F::Out<Int<N>>,
pub stop_defined: F::Out<Int<Bool>>, pub stop_defined: F::Out<Int<Bool>>,
pub stop: F::Out<Int<Int32>>, pub stop: F::Out<Int<N>>,
pub step_defined: F::Out<Int<Bool>>, pub step_defined: F::Out<Int<Bool>>,
pub step: F::Out<Int<Int32>>, pub step: F::Out<Int<N>>,
} }
/// An IRRT representation of an (unresolved) slice. /// An IRRT representation of an (unresolved) slice.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Slice; pub struct Slice<N> {
int_kind: N,
}
impl<'ctx> StructKind<'ctx> for Slice { impl<'ctx, N: IntKind<'ctx>> StructKind<'ctx> for Slice<N> {
type Fields<F: FieldTraversal<'ctx>> = SliceFields<'ctx, F>; type Fields<F: FieldTraversal<'ctx>> = SliceFields<'ctx, F, N>;
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> { fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
Self::Fields { Self::Fields {
start_defined: traversal.add_auto("start_defined"), start_defined: traversal.add_auto("start_defined"),
start: traversal.add_auto("start"), start: traversal.add("start", Int(self.int_kind)),
stop_defined: traversal.add_auto("stop_defined"), stop_defined: traversal.add_auto("stop_defined"),
stop: traversal.add_auto("stop"), stop: traversal.add("stop", Int(self.int_kind)),
step_defined: traversal.add_auto("step_defined"), step_defined: traversal.add_auto("step_defined"),
step: traversal.add_auto("step"), step: traversal.add("step", Int(self.int_kind)),
} }
} }
} }
/// A convenience structure to prepare a [`Slice`]. /// A convenience structure to prepare a [`Slice`].
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct RustSlice<'ctx> { pub struct RustSlice<'ctx, N: IntKind<'ctx>> {
pub start: Option<Instance<'ctx, Int<Int32>>>, pub start: Option<Instance<'ctx, Int<N>>>,
pub stop: Option<Instance<'ctx, Int<Int32>>>, pub stop: Option<Instance<'ctx, Int<N>>>,
pub step: Option<Instance<'ctx, Int<Int32>>>, pub step: Option<Instance<'ctx, Int<N>>>,
} }
impl<'ctx> RustSlice<'ctx> { impl<'ctx, N: IntKind<'ctx>> RustSlice<'ctx, N> {
/// Write the contents to an LLVM [`Slice`]. /// Write the contents to an LLVM [`Slice`].
pub fn write_to_slice<G: CodeGenerator + ?Sized>( pub fn write_to_slice<G: CodeGenerator + ?Sized>(
&self, &self,
generator: &mut G, generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
dst_slice_ptr: Instance<'ctx, Ptr<Struct<Slice>>>, dst_slice_ptr: Instance<'ctx, Ptr<Struct<Slice<N>>>>,
) { ) {
let false_ = Int(Bool).const_false(generator, ctx.ctx); let false_ = Int(Bool).const_false(generator, ctx.ctx);
let true_ = Int(Bool).const_true(generator, ctx.ctx); let true_ = Int(Bool).const_true(generator, ctx.ctx);
@ -102,7 +104,7 @@ impl<'ctx> RustSlice<'ctx> {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum RustNDIndex<'ctx> { pub enum RustNDIndex<'ctx> {
SingleElement(Instance<'ctx, Int<Int32>>), // TODO: To be SizeT SingleElement(Instance<'ctx, Int<Int32>>), // TODO: To be SizeT
Slice(RustSlice<'ctx>), Slice(RustSlice<'ctx, Int32>),
NewAxis, NewAxis,
Ellipsis, Ellipsis,
} }
@ -143,7 +145,7 @@ impl<'ctx> RustNDIndex<'ctx> {
.store(ctx, index_ptr.pointer_cast(generator, ctx, Int(Byte))); .store(ctx, index_ptr.pointer_cast(generator, ctx, Int(Byte)));
} }
RustNDIndex::Slice(in_rust_slice) => { RustNDIndex::Slice(in_rust_slice) => {
let user_slice_ptr = Struct(Slice).alloca(generator, ctx); let user_slice_ptr = Struct(Slice { int_kind: Int32 }).alloca(generator, ctx);
in_rust_slice.write_to_slice(generator, ctx, user_slice_ptr); in_rust_slice.write_to_slice(generator, ctx, user_slice_ptr);
dst_ndindex_ptr dst_ndindex_ptr