From 3782791323ed604e2f236624a182adc91828e75a Mon Sep 17 00:00:00 2001 From: lyken Date: Fri, 23 Aug 2024 09:52:08 +0800 Subject: [PATCH] core: make IRRT slice and range work for any int types --- nac3core/irrt/irrt/ndarray/indexing.hpp | 16 +++--- nac3core/irrt/irrt/ndarray/transpose.hpp | 2 +- nac3core/irrt/irrt/slice.hpp | 53 ++++++++----------- .../src/codegen/object/ndarray/indexing.rs | 38 ++++++------- 4 files changed, 52 insertions(+), 57 deletions(-) diff --git a/nac3core/irrt/irrt/ndarray/indexing.hpp b/nac3core/irrt/irrt/ndarray/indexing.hpp index e6babe66..8d971230 100644 --- a/nac3core/irrt/irrt/ndarray/indexing.hpp +++ b/nac3core/irrt/irrt/ndarray/indexing.hpp @@ -13,14 +13,14 @@ typedef uint8_t NDIndexType; /** * @brief A single element index * - * `data` points to a `SliceIndex`. + * `data` points to a `int32_t`. */ const NDIndexType ND_INDEX_TYPE_SINGLE_ELEMENT = 0; /** * @brief A slice index * - * `data` points to a `Slice`. + * `data` points to a `Slice`. */ const NDIndexType ND_INDEX_TYPE_SLICE = 1; @@ -155,15 +155,15 @@ void index(SizeT num_indices, const NDIndex *indices, const NDArray *src_ SizeT src_axis = 0; SizeT dst_axis = 0; - for (SliceIndex i = 0; i < num_indices; i++) + for (int32_t i = 0; i < num_indices; i++) { const NDIndex *index = &indices[i]; if (index->type == ND_INDEX_TYPE_SINGLE_ELEMENT) { - SliceIndex input = *((SliceIndex *)index->data); - SliceIndex k = slice::resolve_index_in_length(src_ndarray->shape[src_axis], input); + SizeT input = (SizeT) * ((int32_t *)index->data); + SizeT k = slice::resolve_index_in_length(src_ndarray->shape[src_axis], input); - if (k == slice::OUT_OF_BOUNDS) + if (k == -1) { raise_exception(SizeT, EXN_INDEX_ERROR, "index {0} is out of bounds for axis {1} " @@ -177,9 +177,9 @@ void index(SizeT num_indices, const NDIndex *indices, const NDArray *src_ } else if (index->type == ND_INDEX_TYPE_SLICE) { - Slice *slice = (Slice *)index->data; + Slice *slice = (Slice *)index->data; - Range range = slice->indices_checked(src_ndarray->shape[src_axis]); + Range range = slice->indices_checked(src_ndarray->shape[src_axis]); dst_ndarray->data += (SizeT)range.start * src_ndarray->strides[src_axis]; dst_ndarray->strides[dst_axis] = ((SizeT)range.step) * src_ndarray->strides[src_axis]; diff --git a/nac3core/irrt/irrt/ndarray/transpose.hpp b/nac3core/irrt/irrt/ndarray/transpose.hpp index 1ac73f4d..ab5fe009 100644 --- a/nac3core/irrt/irrt/ndarray/transpose.hpp +++ b/nac3core/irrt/irrt/ndarray/transpose.hpp @@ -45,7 +45,7 @@ template void assert_transpose_axes(SizeT ndims, SizeT num_axes for (SizeT i = 0; i < ndims; i++) { SizeT axis = slice::resolve_index_in_length(ndims, axes[i]); - if (axis == slice::OUT_OF_BOUNDS) + if (axis == -1) { // TODO: numpy actually throws a `numpy.exceptions.AxisError` raise_exception(SizeT, EXN_VALUE_ERROR, "axis {0} is out of bounds for array of dimension {1}", axis, ndims, diff --git a/nac3core/irrt/irrt/slice.hpp b/nac3core/irrt/irrt/slice.hpp index b385587b..c064924b 100644 --- a/nac3core/irrt/irrt/slice.hpp +++ b/nac3core/irrt/irrt/slice.hpp @@ -5,26 +5,22 @@ #include #include -// The type of an index or a value describing the length of a -// range/slice is always `int32_t`. -using SliceIndex = int32_t; - namespace { /** * @brief A Python range. */ -struct Range +template struct Range { - SliceIndex start; - SliceIndex stop; - SliceIndex step; + T start; + T stop; + T step; /** * @brief Calculate the `len()` of this range. */ - template SliceIndex len() + template T len() { // Reference: https://github.com/python/cpython/blob/9dbd12375561a393eaec4b21ee4ac568a407cdb0/Objects/rangeobject.c#L933 debug_assert(SizeT, step != 0); @@ -52,34 +48,31 @@ namespace slice * `length - 1` (inclusive). * */ -SliceIndex resolve_index_in_length_clamped(SliceIndex length, SliceIndex index) +template T resolve_index_in_length_clamped(T length, T index) { if (index < 0) { - return max(length + index, 0); + return max(length + index, 0); } else { - return min(length, index); + return min(length, index); } } -const SliceIndex OUT_OF_BOUNDS = -1; - /** - * @brief Like `resolve_index_in_length_clamped`, but returns `OUT_OF_BOUNDS` - * if `index` is out of bounds. + * @brief Like `resolve_index_in_length_clamped`, but returns `-1` if `index` is out of bounds. */ -SliceIndex resolve_index_in_length(SliceIndex length, SliceIndex index) +template T resolve_index_in_length(T length, T index) { - SliceIndex resolved = index < 0 ? length + index : index; + T resolved = index < 0 ? length + index : index; if (0 <= resolved && resolved < length) { return resolved; } else { - return OUT_OF_BOUNDS; + return -1; } } } // namespace slice @@ -87,16 +80,16 @@ SliceIndex resolve_index_in_length(SliceIndex length, SliceIndex index) /** * @brief A Python-like slice with **unresolved** indices. */ -struct Slice +template struct Slice { bool start_defined; - SliceIndex start; + T start; bool stop_defined; - SliceIndex stop; + T stop; bool step_defined; - SliceIndex step; + T step; Slice() { @@ -110,19 +103,19 @@ struct Slice this->step_defined = false; } - void set_start(SliceIndex start) + void set_start(T start) { this->start_defined = true; this->start = start; } - void set_stop(SliceIndex stop) + void set_stop(T stop) { this->stop_defined = true; this->stop = stop; } - void set_step(SliceIndex step) + void set_step(T step) { this->step_defined = true; this->step = step; @@ -133,17 +126,17 @@ struct Slice * * In Python, this would be `range(*slice(start, stop, step).indices(length))`. */ - template Range indices(SliceIndex length) + template Range indices(T length) { // Reference: https://github.com/python/cpython/blob/main/Objects/sliceobject.c#L388 debug_assert(SizeT, length >= 0); - Range result; + Range result; result.step = step_defined ? step : 1; bool step_is_negative = result.step < 0; - SliceIndex lower, upper; + T lower, upper; if (step_is_negative) { lower = -1; @@ -179,7 +172,7 @@ struct Slice /** * @brief Like `.indices()` but with assertions. */ - template Range indices_checked(SliceIndex length) + template Range indices_checked(T length) { // TODO: Switch to `SizeT length` diff --git a/nac3core/src/codegen/object/ndarray/indexing.rs b/nac3core/src/codegen/object/ndarray/indexing.rs index ae8b2d66..71300d37 100644 --- a/nac3core/src/codegen/object/ndarray/indexing.rs +++ b/nac3core/src/codegen/object/ndarray/indexing.rs @@ -25,49 +25,51 @@ impl<'ctx> StructKind<'ctx> for NDIndex { /// Fields of [`Slice`] #[derive(Debug, Clone)] -pub struct SliceFields<'ctx, F: FieldTraversal<'ctx>> { +pub struct SliceFields<'ctx, F: FieldTraversal<'ctx>, N: IntKind<'ctx>> { pub start_defined: F::Out>, - pub start: F::Out>, + pub start: F::Out>, pub stop_defined: F::Out>, - pub stop: F::Out>, + pub stop: F::Out>, pub step_defined: F::Out>, - pub step: F::Out>, + pub step: F::Out>, } /// An IRRT representation of an (unresolved) slice. #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] -pub struct Slice; +pub struct Slice { + int_kind: N, +} -impl<'ctx> StructKind<'ctx> for Slice { - type Fields> = SliceFields<'ctx, F>; +impl<'ctx, N: IntKind<'ctx>> StructKind<'ctx> for Slice { + type Fields> = SliceFields<'ctx, F, N>; fn traverse_fields>(&self, traversal: &mut F) -> Self::Fields { Self::Fields { start_defined: traversal.add_auto("start_defined"), - start: traversal.add_auto("start"), + start: traversal.add("start", Int(self.int_kind)), stop_defined: traversal.add_auto("stop_defined"), - stop: traversal.add_auto("stop"), + stop: traversal.add("stop", Int(self.int_kind)), step_defined: traversal.add_auto("step_defined"), - step: traversal.add_auto("step"), + step: traversal.add("step", Int(self.int_kind)), } } } /// A convenience structure to prepare a [`Slice`]. #[derive(Debug, Clone)] -pub struct RustSlice<'ctx> { - pub start: Option>>, - pub stop: Option>>, - pub step: Option>>, +pub struct RustSlice<'ctx, N: IntKind<'ctx>> { + pub start: Option>>, + pub stop: Option>>, + pub step: Option>>, } -impl<'ctx> RustSlice<'ctx> { +impl<'ctx, N: IntKind<'ctx>> RustSlice<'ctx, N> { /// Write the contents to an LLVM [`Slice`]. pub fn write_to_slice( &self, generator: &mut G, ctx: &CodeGenContext<'ctx, '_>, - dst_slice_ptr: Instance<'ctx, Ptr>>, + dst_slice_ptr: Instance<'ctx, Ptr>>>, ) { let false_ = Int(Bool).const_false(generator, ctx.ctx); let true_ = Int(Bool).const_true(generator, ctx.ctx); @@ -102,7 +104,7 @@ impl<'ctx> RustSlice<'ctx> { #[derive(Debug, Clone)] pub enum RustNDIndex<'ctx> { SingleElement(Instance<'ctx, Int>), // TODO: To be SizeT - Slice(RustSlice<'ctx>), + Slice(RustSlice<'ctx, Int32>), NewAxis, Ellipsis, } @@ -143,7 +145,7 @@ impl<'ctx> RustNDIndex<'ctx> { .store(ctx, index_ptr.pointer_cast(generator, ctx, Int(Byte))); } RustNDIndex::Slice(in_rust_slice) => { - let user_slice_ptr = Struct(Slice).alloca(generator, ctx); + let user_slice_ptr = Struct(Slice { int_kind: Int32 }).alloca(generator, ctx); in_rust_slice.write_to_slice(generator, ctx, user_slice_ptr); dst_ndindex_ptr