forked from M-Labs/nac3
core: make IRRT slice and range work for any int types
This commit is contained in:
parent
ac6c7c5985
commit
3782791323
|
@ -13,14 +13,14 @@ typedef uint8_t NDIndexType;
|
||||||
/**
|
/**
|
||||||
* @brief A single element index
|
* @brief A single element index
|
||||||
*
|
*
|
||||||
* `data` points to a `SliceIndex`.
|
* `data` points to a `int32_t`.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
const NDIndexType ND_INDEX_TYPE_SINGLE_ELEMENT = 0;
|
const NDIndexType ND_INDEX_TYPE_SINGLE_ELEMENT = 0;
|
||||||
/**
|
/**
|
||||||
* @brief A slice index
|
* @brief A slice index
|
||||||
*
|
*
|
||||||
* `data` points to a `Slice`.
|
* `data` points to a `Slice<int32_t>`.
|
||||||
*/
|
*/
|
||||||
const NDIndexType ND_INDEX_TYPE_SLICE = 1;
|
const NDIndexType ND_INDEX_TYPE_SLICE = 1;
|
||||||
|
|
||||||
|
@ -155,15 +155,15 @@ void index(SizeT num_indices, const NDIndex *indices, const NDArray<SizeT> *src_
|
||||||
SizeT src_axis = 0;
|
SizeT src_axis = 0;
|
||||||
SizeT dst_axis = 0;
|
SizeT dst_axis = 0;
|
||||||
|
|
||||||
for (SliceIndex i = 0; i < num_indices; i++)
|
for (int32_t i = 0; i < num_indices; i++)
|
||||||
{
|
{
|
||||||
const NDIndex *index = &indices[i];
|
const NDIndex *index = &indices[i];
|
||||||
if (index->type == ND_INDEX_TYPE_SINGLE_ELEMENT)
|
if (index->type == ND_INDEX_TYPE_SINGLE_ELEMENT)
|
||||||
{
|
{
|
||||||
SliceIndex input = *((SliceIndex *)index->data);
|
SizeT input = (SizeT) * ((int32_t *)index->data);
|
||||||
SliceIndex k = slice::resolve_index_in_length(src_ndarray->shape[src_axis], input);
|
SizeT k = slice::resolve_index_in_length(src_ndarray->shape[src_axis], input);
|
||||||
|
|
||||||
if (k == slice::OUT_OF_BOUNDS)
|
if (k == -1)
|
||||||
{
|
{
|
||||||
raise_exception(SizeT, EXN_INDEX_ERROR,
|
raise_exception(SizeT, EXN_INDEX_ERROR,
|
||||||
"index {0} is out of bounds for axis {1} "
|
"index {0} is out of bounds for axis {1} "
|
||||||
|
@ -177,9 +177,9 @@ void index(SizeT num_indices, const NDIndex *indices, const NDArray<SizeT> *src_
|
||||||
}
|
}
|
||||||
else if (index->type == ND_INDEX_TYPE_SLICE)
|
else if (index->type == ND_INDEX_TYPE_SLICE)
|
||||||
{
|
{
|
||||||
Slice *slice = (Slice *)index->data;
|
Slice<int32_t> *slice = (Slice<int32_t> *)index->data;
|
||||||
|
|
||||||
Range range = slice->indices_checked<SizeT>(src_ndarray->shape[src_axis]);
|
Range<int32_t> range = slice->indices_checked<SizeT>(src_ndarray->shape[src_axis]);
|
||||||
|
|
||||||
dst_ndarray->data += (SizeT)range.start * src_ndarray->strides[src_axis];
|
dst_ndarray->data += (SizeT)range.start * src_ndarray->strides[src_axis];
|
||||||
dst_ndarray->strides[dst_axis] = ((SizeT)range.step) * src_ndarray->strides[src_axis];
|
dst_ndarray->strides[dst_axis] = ((SizeT)range.step) * src_ndarray->strides[src_axis];
|
||||||
|
|
|
@ -45,7 +45,7 @@ template <typename SizeT> void assert_transpose_axes(SizeT ndims, SizeT num_axes
|
||||||
for (SizeT i = 0; i < ndims; i++)
|
for (SizeT i = 0; i < ndims; i++)
|
||||||
{
|
{
|
||||||
SizeT axis = slice::resolve_index_in_length(ndims, axes[i]);
|
SizeT axis = slice::resolve_index_in_length(ndims, axes[i]);
|
||||||
if (axis == slice::OUT_OF_BOUNDS)
|
if (axis == -1)
|
||||||
{
|
{
|
||||||
// TODO: numpy actually throws a `numpy.exceptions.AxisError`
|
// TODO: numpy actually throws a `numpy.exceptions.AxisError`
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR, "axis {0} is out of bounds for array of dimension {1}", axis, ndims,
|
raise_exception(SizeT, EXN_VALUE_ERROR, "axis {0} is out of bounds for array of dimension {1}", axis, ndims,
|
||||||
|
|
|
@ -5,26 +5,22 @@
|
||||||
#include <irrt/int_types.hpp>
|
#include <irrt/int_types.hpp>
|
||||||
#include <irrt/math_util.hpp>
|
#include <irrt/math_util.hpp>
|
||||||
|
|
||||||
// The type of an index or a value describing the length of a
|
|
||||||
// range/slice is always `int32_t`.
|
|
||||||
using SliceIndex = int32_t;
|
|
||||||
|
|
||||||
namespace
|
namespace
|
||||||
{
|
{
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief A Python range.
|
* @brief A Python range.
|
||||||
*/
|
*/
|
||||||
struct Range
|
template <typename T> struct Range
|
||||||
{
|
{
|
||||||
SliceIndex start;
|
T start;
|
||||||
SliceIndex stop;
|
T stop;
|
||||||
SliceIndex step;
|
T step;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Calculate the `len()` of this range.
|
* @brief Calculate the `len()` of this range.
|
||||||
*/
|
*/
|
||||||
template <typename SizeT> SliceIndex len()
|
template <typename SizeT> T len()
|
||||||
{
|
{
|
||||||
// Reference: https://github.com/python/cpython/blob/9dbd12375561a393eaec4b21ee4ac568a407cdb0/Objects/rangeobject.c#L933
|
// Reference: https://github.com/python/cpython/blob/9dbd12375561a393eaec4b21ee4ac568a407cdb0/Objects/rangeobject.c#L933
|
||||||
debug_assert(SizeT, step != 0);
|
debug_assert(SizeT, step != 0);
|
||||||
|
@ -52,34 +48,31 @@ namespace slice
|
||||||
* `length - 1` (inclusive).
|
* `length - 1` (inclusive).
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
SliceIndex resolve_index_in_length_clamped(SliceIndex length, SliceIndex index)
|
template <typename T> T resolve_index_in_length_clamped(T length, T index)
|
||||||
{
|
{
|
||||||
if (index < 0)
|
if (index < 0)
|
||||||
{
|
{
|
||||||
return max<SliceIndex>(length + index, 0);
|
return max<T>(length + index, 0);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
return min<SliceIndex>(length, index);
|
return min<T>(length, index);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const SliceIndex OUT_OF_BOUNDS = -1;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Like `resolve_index_in_length_clamped`, but returns `OUT_OF_BOUNDS`
|
* @brief Like `resolve_index_in_length_clamped`, but returns `-1` if `index` is out of bounds.
|
||||||
* if `index` is out of bounds.
|
|
||||||
*/
|
*/
|
||||||
SliceIndex resolve_index_in_length(SliceIndex length, SliceIndex index)
|
template <typename T> T resolve_index_in_length(T length, T index)
|
||||||
{
|
{
|
||||||
SliceIndex resolved = index < 0 ? length + index : index;
|
T resolved = index < 0 ? length + index : index;
|
||||||
if (0 <= resolved && resolved < length)
|
if (0 <= resolved && resolved < length)
|
||||||
{
|
{
|
||||||
return resolved;
|
return resolved;
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
return OUT_OF_BOUNDS;
|
return -1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} // namespace slice
|
} // namespace slice
|
||||||
|
@ -87,16 +80,16 @@ SliceIndex resolve_index_in_length(SliceIndex length, SliceIndex index)
|
||||||
/**
|
/**
|
||||||
* @brief A Python-like slice with **unresolved** indices.
|
* @brief A Python-like slice with **unresolved** indices.
|
||||||
*/
|
*/
|
||||||
struct Slice
|
template <typename T> struct Slice
|
||||||
{
|
{
|
||||||
bool start_defined;
|
bool start_defined;
|
||||||
SliceIndex start;
|
T start;
|
||||||
|
|
||||||
bool stop_defined;
|
bool stop_defined;
|
||||||
SliceIndex stop;
|
T stop;
|
||||||
|
|
||||||
bool step_defined;
|
bool step_defined;
|
||||||
SliceIndex step;
|
T step;
|
||||||
|
|
||||||
Slice()
|
Slice()
|
||||||
{
|
{
|
||||||
|
@ -110,19 +103,19 @@ struct Slice
|
||||||
this->step_defined = false;
|
this->step_defined = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
void set_start(SliceIndex start)
|
void set_start(T start)
|
||||||
{
|
{
|
||||||
this->start_defined = true;
|
this->start_defined = true;
|
||||||
this->start = start;
|
this->start = start;
|
||||||
}
|
}
|
||||||
|
|
||||||
void set_stop(SliceIndex stop)
|
void set_stop(T stop)
|
||||||
{
|
{
|
||||||
this->stop_defined = true;
|
this->stop_defined = true;
|
||||||
this->stop = stop;
|
this->stop = stop;
|
||||||
}
|
}
|
||||||
|
|
||||||
void set_step(SliceIndex step)
|
void set_step(T step)
|
||||||
{
|
{
|
||||||
this->step_defined = true;
|
this->step_defined = true;
|
||||||
this->step = step;
|
this->step = step;
|
||||||
|
@ -133,17 +126,17 @@ struct Slice
|
||||||
*
|
*
|
||||||
* In Python, this would be `range(*slice(start, stop, step).indices(length))`.
|
* In Python, this would be `range(*slice(start, stop, step).indices(length))`.
|
||||||
*/
|
*/
|
||||||
template <typename SizeT> Range indices(SliceIndex length)
|
template <typename SizeT> Range<T> indices(T length)
|
||||||
{
|
{
|
||||||
// Reference: https://github.com/python/cpython/blob/main/Objects/sliceobject.c#L388
|
// Reference: https://github.com/python/cpython/blob/main/Objects/sliceobject.c#L388
|
||||||
debug_assert(SizeT, length >= 0);
|
debug_assert(SizeT, length >= 0);
|
||||||
|
|
||||||
Range result;
|
Range<T> result;
|
||||||
|
|
||||||
result.step = step_defined ? step : 1;
|
result.step = step_defined ? step : 1;
|
||||||
bool step_is_negative = result.step < 0;
|
bool step_is_negative = result.step < 0;
|
||||||
|
|
||||||
SliceIndex lower, upper;
|
T lower, upper;
|
||||||
if (step_is_negative)
|
if (step_is_negative)
|
||||||
{
|
{
|
||||||
lower = -1;
|
lower = -1;
|
||||||
|
@ -179,7 +172,7 @@ struct Slice
|
||||||
/**
|
/**
|
||||||
* @brief Like `.indices()` but with assertions.
|
* @brief Like `.indices()` but with assertions.
|
||||||
*/
|
*/
|
||||||
template <typename SizeT> Range indices_checked(SliceIndex length)
|
template <typename SizeT> Range<T> indices_checked(T length)
|
||||||
{
|
{
|
||||||
// TODO: Switch to `SizeT length`
|
// TODO: Switch to `SizeT length`
|
||||||
|
|
||||||
|
|
|
@ -25,49 +25,51 @@ impl<'ctx> StructKind<'ctx> for NDIndex {
|
||||||
|
|
||||||
/// Fields of [`Slice`]
|
/// Fields of [`Slice`]
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct SliceFields<'ctx, F: FieldTraversal<'ctx>> {
|
pub struct SliceFields<'ctx, F: FieldTraversal<'ctx>, N: IntKind<'ctx>> {
|
||||||
pub start_defined: F::Out<Int<Bool>>,
|
pub start_defined: F::Out<Int<Bool>>,
|
||||||
pub start: F::Out<Int<Int32>>,
|
pub start: F::Out<Int<N>>,
|
||||||
pub stop_defined: F::Out<Int<Bool>>,
|
pub stop_defined: F::Out<Int<Bool>>,
|
||||||
pub stop: F::Out<Int<Int32>>,
|
pub stop: F::Out<Int<N>>,
|
||||||
pub step_defined: F::Out<Int<Bool>>,
|
pub step_defined: F::Out<Int<Bool>>,
|
||||||
pub step: F::Out<Int<Int32>>,
|
pub step: F::Out<Int<N>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// An IRRT representation of an (unresolved) slice.
|
/// An IRRT representation of an (unresolved) slice.
|
||||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
|
||||||
pub struct Slice;
|
pub struct Slice<N> {
|
||||||
|
int_kind: N,
|
||||||
|
}
|
||||||
|
|
||||||
impl<'ctx> StructKind<'ctx> for Slice {
|
impl<'ctx, N: IntKind<'ctx>> StructKind<'ctx> for Slice<N> {
|
||||||
type Fields<F: FieldTraversal<'ctx>> = SliceFields<'ctx, F>;
|
type Fields<F: FieldTraversal<'ctx>> = SliceFields<'ctx, F, N>;
|
||||||
|
|
||||||
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
|
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
|
||||||
Self::Fields {
|
Self::Fields {
|
||||||
start_defined: traversal.add_auto("start_defined"),
|
start_defined: traversal.add_auto("start_defined"),
|
||||||
start: traversal.add_auto("start"),
|
start: traversal.add("start", Int(self.int_kind)),
|
||||||
stop_defined: traversal.add_auto("stop_defined"),
|
stop_defined: traversal.add_auto("stop_defined"),
|
||||||
stop: traversal.add_auto("stop"),
|
stop: traversal.add("stop", Int(self.int_kind)),
|
||||||
step_defined: traversal.add_auto("step_defined"),
|
step_defined: traversal.add_auto("step_defined"),
|
||||||
step: traversal.add_auto("step"),
|
step: traversal.add("step", Int(self.int_kind)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A convenience structure to prepare a [`Slice`].
|
/// A convenience structure to prepare a [`Slice`].
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct RustSlice<'ctx> {
|
pub struct RustSlice<'ctx, N: IntKind<'ctx>> {
|
||||||
pub start: Option<Instance<'ctx, Int<Int32>>>,
|
pub start: Option<Instance<'ctx, Int<N>>>,
|
||||||
pub stop: Option<Instance<'ctx, Int<Int32>>>,
|
pub stop: Option<Instance<'ctx, Int<N>>>,
|
||||||
pub step: Option<Instance<'ctx, Int<Int32>>>,
|
pub step: Option<Instance<'ctx, Int<N>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ctx> RustSlice<'ctx> {
|
impl<'ctx, N: IntKind<'ctx>> RustSlice<'ctx, N> {
|
||||||
/// Write the contents to an LLVM [`Slice`].
|
/// Write the contents to an LLVM [`Slice`].
|
||||||
pub fn write_to_slice<G: CodeGenerator + ?Sized>(
|
pub fn write_to_slice<G: CodeGenerator + ?Sized>(
|
||||||
&self,
|
&self,
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
dst_slice_ptr: Instance<'ctx, Ptr<Struct<Slice>>>,
|
dst_slice_ptr: Instance<'ctx, Ptr<Struct<Slice<N>>>>,
|
||||||
) {
|
) {
|
||||||
let false_ = Int(Bool).const_false(generator, ctx.ctx);
|
let false_ = Int(Bool).const_false(generator, ctx.ctx);
|
||||||
let true_ = Int(Bool).const_true(generator, ctx.ctx);
|
let true_ = Int(Bool).const_true(generator, ctx.ctx);
|
||||||
|
@ -102,7 +104,7 @@ impl<'ctx> RustSlice<'ctx> {
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum RustNDIndex<'ctx> {
|
pub enum RustNDIndex<'ctx> {
|
||||||
SingleElement(Instance<'ctx, Int<Int32>>), // TODO: To be SizeT
|
SingleElement(Instance<'ctx, Int<Int32>>), // TODO: To be SizeT
|
||||||
Slice(RustSlice<'ctx>),
|
Slice(RustSlice<'ctx, Int32>),
|
||||||
NewAxis,
|
NewAxis,
|
||||||
Ellipsis,
|
Ellipsis,
|
||||||
}
|
}
|
||||||
|
@ -143,7 +145,7 @@ impl<'ctx> RustNDIndex<'ctx> {
|
||||||
.store(ctx, index_ptr.pointer_cast(generator, ctx, Int(Byte)));
|
.store(ctx, index_ptr.pointer_cast(generator, ctx, Int(Byte)));
|
||||||
}
|
}
|
||||||
RustNDIndex::Slice(in_rust_slice) => {
|
RustNDIndex::Slice(in_rust_slice) => {
|
||||||
let user_slice_ptr = Struct(Slice).alloca(generator, ctx);
|
let user_slice_ptr = Struct(Slice { int_kind: Int32 }).alloca(generator, ctx);
|
||||||
in_rust_slice.write_to_slice(generator, ctx, user_slice_ptr);
|
in_rust_slice.write_to_slice(generator, ctx, user_slice_ptr);
|
||||||
|
|
||||||
dst_ndindex_ptr
|
dst_ndindex_ptr
|
||||||
|
|
Loading…
Reference in New Issue