forked from M-Labs/nac3
core: make IRRT slice and range work for any int types
This commit is contained in:
parent
ac6c7c5985
commit
3782791323
|
@ -13,14 +13,14 @@ typedef uint8_t NDIndexType;
|
|||
/**
|
||||
* @brief A single element index
|
||||
*
|
||||
* `data` points to a `SliceIndex`.
|
||||
* `data` points to a `int32_t`.
|
||||
*/
|
||||
|
||||
const NDIndexType ND_INDEX_TYPE_SINGLE_ELEMENT = 0;
|
||||
/**
|
||||
* @brief A slice index
|
||||
*
|
||||
* `data` points to a `Slice`.
|
||||
* `data` points to a `Slice<int32_t>`.
|
||||
*/
|
||||
const NDIndexType ND_INDEX_TYPE_SLICE = 1;
|
||||
|
||||
|
@ -155,15 +155,15 @@ void index(SizeT num_indices, const NDIndex *indices, const NDArray<SizeT> *src_
|
|||
SizeT src_axis = 0;
|
||||
SizeT dst_axis = 0;
|
||||
|
||||
for (SliceIndex i = 0; i < num_indices; i++)
|
||||
for (int32_t i = 0; i < num_indices; i++)
|
||||
{
|
||||
const NDIndex *index = &indices[i];
|
||||
if (index->type == ND_INDEX_TYPE_SINGLE_ELEMENT)
|
||||
{
|
||||
SliceIndex input = *((SliceIndex *)index->data);
|
||||
SliceIndex k = slice::resolve_index_in_length(src_ndarray->shape[src_axis], input);
|
||||
SizeT input = (SizeT) * ((int32_t *)index->data);
|
||||
SizeT k = slice::resolve_index_in_length(src_ndarray->shape[src_axis], input);
|
||||
|
||||
if (k == slice::OUT_OF_BOUNDS)
|
||||
if (k == -1)
|
||||
{
|
||||
raise_exception(SizeT, EXN_INDEX_ERROR,
|
||||
"index {0} is out of bounds for axis {1} "
|
||||
|
@ -177,9 +177,9 @@ void index(SizeT num_indices, const NDIndex *indices, const NDArray<SizeT> *src_
|
|||
}
|
||||
else if (index->type == ND_INDEX_TYPE_SLICE)
|
||||
{
|
||||
Slice *slice = (Slice *)index->data;
|
||||
Slice<int32_t> *slice = (Slice<int32_t> *)index->data;
|
||||
|
||||
Range range = slice->indices_checked<SizeT>(src_ndarray->shape[src_axis]);
|
||||
Range<int32_t> range = slice->indices_checked<SizeT>(src_ndarray->shape[src_axis]);
|
||||
|
||||
dst_ndarray->data += (SizeT)range.start * src_ndarray->strides[src_axis];
|
||||
dst_ndarray->strides[dst_axis] = ((SizeT)range.step) * src_ndarray->strides[src_axis];
|
||||
|
|
|
@ -45,7 +45,7 @@ template <typename SizeT> void assert_transpose_axes(SizeT ndims, SizeT num_axes
|
|||
for (SizeT i = 0; i < ndims; i++)
|
||||
{
|
||||
SizeT axis = slice::resolve_index_in_length(ndims, axes[i]);
|
||||
if (axis == slice::OUT_OF_BOUNDS)
|
||||
if (axis == -1)
|
||||
{
|
||||
// TODO: numpy actually throws a `numpy.exceptions.AxisError`
|
||||
raise_exception(SizeT, EXN_VALUE_ERROR, "axis {0} is out of bounds for array of dimension {1}", axis, ndims,
|
||||
|
|
|
@ -5,26 +5,22 @@
|
|||
#include <irrt/int_types.hpp>
|
||||
#include <irrt/math_util.hpp>
|
||||
|
||||
// The type of an index or a value describing the length of a
|
||||
// range/slice is always `int32_t`.
|
||||
using SliceIndex = int32_t;
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
/**
|
||||
* @brief A Python range.
|
||||
*/
|
||||
struct Range
|
||||
template <typename T> struct Range
|
||||
{
|
||||
SliceIndex start;
|
||||
SliceIndex stop;
|
||||
SliceIndex step;
|
||||
T start;
|
||||
T stop;
|
||||
T step;
|
||||
|
||||
/**
|
||||
* @brief Calculate the `len()` of this range.
|
||||
*/
|
||||
template <typename SizeT> SliceIndex len()
|
||||
template <typename SizeT> T len()
|
||||
{
|
||||
// Reference: https://github.com/python/cpython/blob/9dbd12375561a393eaec4b21ee4ac568a407cdb0/Objects/rangeobject.c#L933
|
||||
debug_assert(SizeT, step != 0);
|
||||
|
@ -52,34 +48,31 @@ namespace slice
|
|||
* `length - 1` (inclusive).
|
||||
*
|
||||
*/
|
||||
SliceIndex resolve_index_in_length_clamped(SliceIndex length, SliceIndex index)
|
||||
template <typename T> T resolve_index_in_length_clamped(T length, T index)
|
||||
{
|
||||
if (index < 0)
|
||||
{
|
||||
return max<SliceIndex>(length + index, 0);
|
||||
return max<T>(length + index, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
return min<SliceIndex>(length, index);
|
||||
return min<T>(length, index);
|
||||
}
|
||||
}
|
||||
|
||||
const SliceIndex OUT_OF_BOUNDS = -1;
|
||||
|
||||
/**
|
||||
* @brief Like `resolve_index_in_length_clamped`, but returns `OUT_OF_BOUNDS`
|
||||
* if `index` is out of bounds.
|
||||
* @brief Like `resolve_index_in_length_clamped`, but returns `-1` if `index` is out of bounds.
|
||||
*/
|
||||
SliceIndex resolve_index_in_length(SliceIndex length, SliceIndex index)
|
||||
template <typename T> T resolve_index_in_length(T length, T index)
|
||||
{
|
||||
SliceIndex resolved = index < 0 ? length + index : index;
|
||||
T resolved = index < 0 ? length + index : index;
|
||||
if (0 <= resolved && resolved < length)
|
||||
{
|
||||
return resolved;
|
||||
}
|
||||
else
|
||||
{
|
||||
return OUT_OF_BOUNDS;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
} // namespace slice
|
||||
|
@ -87,16 +80,16 @@ SliceIndex resolve_index_in_length(SliceIndex length, SliceIndex index)
|
|||
/**
|
||||
* @brief A Python-like slice with **unresolved** indices.
|
||||
*/
|
||||
struct Slice
|
||||
template <typename T> struct Slice
|
||||
{
|
||||
bool start_defined;
|
||||
SliceIndex start;
|
||||
T start;
|
||||
|
||||
bool stop_defined;
|
||||
SliceIndex stop;
|
||||
T stop;
|
||||
|
||||
bool step_defined;
|
||||
SliceIndex step;
|
||||
T step;
|
||||
|
||||
Slice()
|
||||
{
|
||||
|
@ -110,19 +103,19 @@ struct Slice
|
|||
this->step_defined = false;
|
||||
}
|
||||
|
||||
void set_start(SliceIndex start)
|
||||
void set_start(T start)
|
||||
{
|
||||
this->start_defined = true;
|
||||
this->start = start;
|
||||
}
|
||||
|
||||
void set_stop(SliceIndex stop)
|
||||
void set_stop(T stop)
|
||||
{
|
||||
this->stop_defined = true;
|
||||
this->stop = stop;
|
||||
}
|
||||
|
||||
void set_step(SliceIndex step)
|
||||
void set_step(T step)
|
||||
{
|
||||
this->step_defined = true;
|
||||
this->step = step;
|
||||
|
@ -133,17 +126,17 @@ struct Slice
|
|||
*
|
||||
* In Python, this would be `range(*slice(start, stop, step).indices(length))`.
|
||||
*/
|
||||
template <typename SizeT> Range indices(SliceIndex length)
|
||||
template <typename SizeT> Range<T> indices(T length)
|
||||
{
|
||||
// Reference: https://github.com/python/cpython/blob/main/Objects/sliceobject.c#L388
|
||||
debug_assert(SizeT, length >= 0);
|
||||
|
||||
Range result;
|
||||
Range<T> result;
|
||||
|
||||
result.step = step_defined ? step : 1;
|
||||
bool step_is_negative = result.step < 0;
|
||||
|
||||
SliceIndex lower, upper;
|
||||
T lower, upper;
|
||||
if (step_is_negative)
|
||||
{
|
||||
lower = -1;
|
||||
|
@ -179,7 +172,7 @@ struct Slice
|
|||
/**
|
||||
* @brief Like `.indices()` but with assertions.
|
||||
*/
|
||||
template <typename SizeT> Range indices_checked(SliceIndex length)
|
||||
template <typename SizeT> Range<T> indices_checked(T length)
|
||||
{
|
||||
// TODO: Switch to `SizeT length`
|
||||
|
||||
|
|
|
@ -25,49 +25,51 @@ impl<'ctx> StructKind<'ctx> for NDIndex {
|
|||
|
||||
/// Fields of [`Slice`]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SliceFields<'ctx, F: FieldTraversal<'ctx>> {
|
||||
pub struct SliceFields<'ctx, F: FieldTraversal<'ctx>, N: IntKind<'ctx>> {
|
||||
pub start_defined: F::Out<Int<Bool>>,
|
||||
pub start: F::Out<Int<Int32>>,
|
||||
pub start: F::Out<Int<N>>,
|
||||
pub stop_defined: F::Out<Int<Bool>>,
|
||||
pub stop: F::Out<Int<Int32>>,
|
||||
pub stop: F::Out<Int<N>>,
|
||||
pub step_defined: F::Out<Int<Bool>>,
|
||||
pub step: F::Out<Int<Int32>>,
|
||||
pub step: F::Out<Int<N>>,
|
||||
}
|
||||
|
||||
/// An IRRT representation of an (unresolved) slice.
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
|
||||
pub struct Slice;
|
||||
pub struct Slice<N> {
|
||||
int_kind: N,
|
||||
}
|
||||
|
||||
impl<'ctx> StructKind<'ctx> for Slice {
|
||||
type Fields<F: FieldTraversal<'ctx>> = SliceFields<'ctx, F>;
|
||||
impl<'ctx, N: IntKind<'ctx>> StructKind<'ctx> for Slice<N> {
|
||||
type Fields<F: FieldTraversal<'ctx>> = SliceFields<'ctx, F, N>;
|
||||
|
||||
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
|
||||
Self::Fields {
|
||||
start_defined: traversal.add_auto("start_defined"),
|
||||
start: traversal.add_auto("start"),
|
||||
start: traversal.add("start", Int(self.int_kind)),
|
||||
stop_defined: traversal.add_auto("stop_defined"),
|
||||
stop: traversal.add_auto("stop"),
|
||||
stop: traversal.add("stop", Int(self.int_kind)),
|
||||
step_defined: traversal.add_auto("step_defined"),
|
||||
step: traversal.add_auto("step"),
|
||||
step: traversal.add("step", Int(self.int_kind)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A convenience structure to prepare a [`Slice`].
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RustSlice<'ctx> {
|
||||
pub start: Option<Instance<'ctx, Int<Int32>>>,
|
||||
pub stop: Option<Instance<'ctx, Int<Int32>>>,
|
||||
pub step: Option<Instance<'ctx, Int<Int32>>>,
|
||||
pub struct RustSlice<'ctx, N: IntKind<'ctx>> {
|
||||
pub start: Option<Instance<'ctx, Int<N>>>,
|
||||
pub stop: Option<Instance<'ctx, Int<N>>>,
|
||||
pub step: Option<Instance<'ctx, Int<N>>>,
|
||||
}
|
||||
|
||||
impl<'ctx> RustSlice<'ctx> {
|
||||
impl<'ctx, N: IntKind<'ctx>> RustSlice<'ctx, N> {
|
||||
/// Write the contents to an LLVM [`Slice`].
|
||||
pub fn write_to_slice<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
dst_slice_ptr: Instance<'ctx, Ptr<Struct<Slice>>>,
|
||||
dst_slice_ptr: Instance<'ctx, Ptr<Struct<Slice<N>>>>,
|
||||
) {
|
||||
let false_ = Int(Bool).const_false(generator, ctx.ctx);
|
||||
let true_ = Int(Bool).const_true(generator, ctx.ctx);
|
||||
|
@ -102,7 +104,7 @@ impl<'ctx> RustSlice<'ctx> {
|
|||
#[derive(Debug, Clone)]
|
||||
pub enum RustNDIndex<'ctx> {
|
||||
SingleElement(Instance<'ctx, Int<Int32>>), // TODO: To be SizeT
|
||||
Slice(RustSlice<'ctx>),
|
||||
Slice(RustSlice<'ctx, Int32>),
|
||||
NewAxis,
|
||||
Ellipsis,
|
||||
}
|
||||
|
@ -143,7 +145,7 @@ impl<'ctx> RustNDIndex<'ctx> {
|
|||
.store(ctx, index_ptr.pointer_cast(generator, ctx, Int(Byte)));
|
||||
}
|
||||
RustNDIndex::Slice(in_rust_slice) => {
|
||||
let user_slice_ptr = Struct(Slice).alloca(generator, ctx);
|
||||
let user_slice_ptr = Struct(Slice { int_kind: Int32 }).alloca(generator, ctx);
|
||||
in_rust_slice.write_to_slice(generator, ctx, user_slice_ptr);
|
||||
|
||||
dst_ndindex_ptr
|
||||
|
|
Loading…
Reference in New Issue