From 01c96396463fdfc0316b10c79bf83bf1488459cb Mon Sep 17 00:00:00 2001 From: lyken Date: Sat, 24 Aug 2024 15:37:45 +0800 Subject: [PATCH 1/3] core/irrt: add Slice and Range Needed for implementing general ndarray indexing. Currently IRRT slice and range have nothing to do with NAC3's slice and range. The IRRT slice and range are currently there to implement ndarray specific features. However, in the future their definitions may be used to replace that of NAC3's. (NAC3's range is a [i32 x 3], IRRT's range is a proper struct. NAC3 does not have a slice struct). --- nac3core/irrt/irrt.cpp | 1 + nac3core/irrt/irrt/range.hpp | 47 +++++++++++ nac3core/irrt/irrt/slice.hpp | 150 ++++++++++++++++++++++++++++++++--- 3 files changed, 187 insertions(+), 11 deletions(-) create mode 100644 nac3core/irrt/irrt/range.hpp diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 58d18f8a..43f15e8f 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -3,6 +3,7 @@ #include "irrt/list.hpp" #include "irrt/math.hpp" #include "irrt/ndarray.hpp" +#include "irrt/range.hpp" #include "irrt/slice.hpp" #include "irrt/ndarray/basic.hpp" #include "irrt/ndarray/def.hpp" diff --git a/nac3core/irrt/irrt/range.hpp b/nac3core/irrt/irrt/range.hpp new file mode 100644 index 00000000..e9d4e612 --- /dev/null +++ b/nac3core/irrt/irrt/range.hpp @@ -0,0 +1,47 @@ +#pragma once + +#include "irrt/debug.hpp" +#include "irrt/int_types.hpp" + +namespace { +namespace range { +template +T len(T start, T stop, T step) { + // Reference: + // https://github.com/python/cpython/blob/9dbd12375561a393eaec4b21ee4ac568a407cdb0/Objects/rangeobject.c#L933 + if (step > 0 && start < stop) + return 1 + (stop - 1 - start) / step; + else if (step < 0 && start > stop) + return 1 + (start - 1 - stop) / (-step); + else + return 0; +} +} // namespace range + +/** + * @brief A Python range. + */ +template +struct Range { + T start; + T stop; + T step; + + /** + * @brief Calculate the `len()` of this range. 
+ */ + template + T len() { + debug_assert(SizeT, step != 0); + return range::len(start, stop, step); + } +}; +} // namespace + +extern "C" { +using namespace range; + +SliceIndex __nac3_range_slice_len(const SliceIndex start, const SliceIndex end, const SliceIndex step) { + return len(start, end, step); +} +} \ No newline at end of file diff --git a/nac3core/irrt/irrt/slice.hpp b/nac3core/irrt/irrt/slice.hpp index a1523ddc..4cf13e05 100644 --- a/nac3core/irrt/irrt/slice.hpp +++ b/nac3core/irrt/irrt/slice.hpp @@ -1,6 +1,145 @@ #pragma once +#include "irrt/debug.hpp" +#include "irrt/exception.hpp" #include "irrt/int_types.hpp" +#include "irrt/math_util.hpp" +#include "irrt/range.hpp" + +namespace { +namespace slice { +/** + * @brief Resolve a possibly negative index in a list of a known length. + * + * Returns -1 if the resolved index is out of the list's bounds. + */ +template +T resolve_index_in_length(T length, T index) { + T resolved = index < 0 ? length + index : index; + if (0 <= resolved && resolved < length) { + return resolved; + } else { + return -1; + } +} + +/** + * @brief Resolve a slice as a range. + * + * This is equivalent to `range(*slice(start, stop, step).indices(length))` in Python. + */ +template +void indices(bool start_defined, + T start, + bool stop_defined, + T stop, + bool step_defined, + T step, + T length, + T* range_start, + T* range_stop, + T* range_step) { + // Reference: https://github.com/python/cpython/blob/main/Objects/sliceobject.c#L388 + *range_step = step_defined ? step : 1; + bool step_is_negative = *range_step < 0; + + T lower, upper; + if (step_is_negative) { + lower = -1; + upper = length - 1; + } else { + lower = 0; + upper = length; + } + + if (start_defined) { + *range_start = start < 0 ? max(lower, start + length) : min(upper, start); + } else { + *range_start = step_is_negative ? upper : lower; + } + + if (stop_defined) { + *range_stop = stop < 0 ? 
max(lower, stop + length) : min(upper, stop); + } else { + *range_stop = step_is_negative ? lower : upper; + } +} +} // namespace slice + +/** + * @brief A Python-like slice with **unresolved** indices. + */ +template +struct Slice { + bool start_defined; + T start; + + bool stop_defined; + T stop; + + bool step_defined; + T step; + + Slice() { this->reset(); } + + void reset() { + this->start_defined = false; + this->stop_defined = false; + this->step_defined = false; + } + + void set_start(T start) { + this->start_defined = true; + this->start = start; + } + + void set_stop(T stop) { + this->stop_defined = true; + this->stop = stop; + } + + void set_step(T step) { + this->step_defined = true; + this->step = step; + } + + /** + * @brief Resolve this slice as a range. + * + * In Python, this would be `range(*slice(start, stop, step).indices(length))`. + */ + template + Range indices(T length) { + // Reference: + // https://github.com/python/cpython/blob/main/Objects/sliceobject.c#L388 + debug_assert(SizeT, length >= 0); + + Range result; + slice::indices(start_defined, start, stop_defined, stop, step_defined, step, length, &result.start, + &result.stop, &result.step); + return result; + } + + /** + * @brief Like `.indices()` but with assertions. 
+ */ + template + Range indices_checked(T length) { + // TODO: Switch to `SizeT length` + + if (length < 0) { + raise_exception(SizeT, EXN_VALUE_ERROR, "length should not be negative, got {0}", length, NO_PARAM, + NO_PARAM); + } + + if (this->step_defined && this->step == 0) { + raise_exception(SizeT, EXN_VALUE_ERROR, "slice step cannot be zero", NO_PARAM, NO_PARAM, NO_PARAM); + } + + return this->indices(length); + } +}; +} // namespace extern "C" { SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) { @@ -14,15 +153,4 @@ SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) { } return i; } - -SliceIndex __nac3_range_slice_len(const SliceIndex start, const SliceIndex end, const SliceIndex step) { - SliceIndex diff = end - start; - if (diff > 0 && step > 0) { - return ((diff - 1) / step) + 1; - } else if (diff < 0 && step < 0) { - return ((diff + 1) / step) + 1; - } else { - return 0; - } } -} \ No newline at end of file -- 2.44.2 From 9d0bfd965cf4a9db8af0090076b8fb7c40e044d2 Mon Sep 17 00:00:00 2001 From: lyken Date: Thu, 15 Aug 2024 22:28:23 +0800 Subject: [PATCH 2/3] core/irrt: rename NDIndex to NDIndexInt Unfortunately the name `NDIndex` is used in later commits. Renaming this typedef to `NDIndexInt` to avoid amending. `NDIndexInt` will be removed anyway when ndarray strides is completed. --- nac3core/irrt/irrt/int_types.hpp | 2 +- nac3core/irrt/irrt/ndarray.hpp | 31 ++++++++++++++++++------------- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/nac3core/irrt/irrt/int_types.hpp b/nac3core/irrt/irrt/int_types.hpp index 656a060e..694b5e35 100644 --- a/nac3core/irrt/irrt/int_types.hpp +++ b/nac3core/irrt/irrt/int_types.hpp @@ -17,6 +17,6 @@ using uint64_t = unsigned _ExtInt(64); #endif // NDArray indices are always `uint32_t`. -using NDIndex = uint32_t; +using NDIndexInt = uint32_t; // The type of an index or a value describing the length of a range/slice is always `int32_t`. 
using SliceIndex = int32_t; diff --git a/nac3core/irrt/irrt/ndarray.hpp b/nac3core/irrt/irrt/ndarray.hpp index cacdd2a2..72ca0b9e 100644 --- a/nac3core/irrt/irrt/ndarray.hpp +++ b/nac3core/irrt/irrt/ndarray.hpp @@ -19,7 +19,7 @@ SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, Size } template -void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT num_dims, NDIndex* idxs) { +void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT num_dims, NDIndexInt* idxs) { SizeT stride = 1; for (SizeT dim = 0; dim < num_dims; dim++) { SizeT i = num_dims - dim - 1; @@ -30,7 +30,10 @@ void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT n } template -SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims, SizeT num_dims, const NDIndex* indices, SizeT num_indices) { +SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims, + SizeT num_dims, + const NDIndexInt* indices, + SizeT num_indices) { SizeT idx = 0; SizeT stride = 1; for (SizeT i = 0; i < num_dims; ++i) { @@ -77,8 +80,8 @@ void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims, template void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT* src_dims, SizeT src_ndims, - const NDIndex* in_idx, - NDIndex* out_idx) { + const NDIndexInt* in_idx, + NDIndexInt* out_idx) { for (SizeT i = 0; i < src_ndims; ++i) { SizeT src_i = src_ndims - i - 1; out_idx[src_i] = src_dims[src_i] == 1 ? 
0 : in_idx[src_i]; @@ -96,21 +99,23 @@ __nac3_ndarray_calc_size64(const uint64_t* list_data, uint64_t list_len, uint64_ return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx); } -void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32_t num_dims, NDIndex* idxs) { +void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32_t num_dims, NDIndexInt* idxs) { __nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs); } -void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndex* idxs) { +void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndexInt* idxs) { __nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs); } uint32_t -__nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims, const NDIndex* indices, uint32_t num_indices) { +__nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims, const NDIndexInt* indices, uint32_t num_indices) { return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices); } -uint64_t -__nac3_ndarray_flatten_index64(const uint64_t* dims, uint64_t num_dims, const NDIndex* indices, uint64_t num_indices) { +uint64_t __nac3_ndarray_flatten_index64(const uint64_t* dims, + uint64_t num_dims, + const NDIndexInt* indices, + uint64_t num_indices) { return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices); } @@ -132,15 +137,15 @@ void __nac3_ndarray_calc_broadcast64(const uint64_t* lhs_dims, void __nac3_ndarray_calc_broadcast_idx(const uint32_t* src_dims, uint32_t src_ndims, - const NDIndex* in_idx, - NDIndex* out_idx) { + const NDIndexInt* in_idx, + NDIndexInt* out_idx) { __nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx); } void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims, uint64_t src_ndims, - const NDIndex* in_idx, - NDIndex* out_idx) { + const NDIndexInt* in_idx, + 
NDIndexInt* out_idx) { __nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx); } } \ No newline at end of file -- 2.44.2 From 8f9d2d82dda593a216be1f6c5ce5b2db681a1068 Mon Sep 17 00:00:00 2001 From: lyken Date: Wed, 21 Aug 2024 13:43:07 +0800 Subject: [PATCH 3/3] core/ndstrides: implement ndarray indexing The functionality for `...` and `np.newaxis` is there in IRRT, but there is no implementation of them for @kernel Python expressions because of https://git.m-labs.hk/M-Labs/nac3/issues/486. --- nac3core/irrt/irrt.cpp | 3 +- nac3core/irrt/irrt/ndarray/indexing.hpp | 220 +++++++++++ nac3core/src/codegen/expr.rs | 370 ++---------------- nac3core/src/codegen/irrt/mod.rs | 19 +- nac3core/src/codegen/object/mod.rs | 1 + .../src/codegen/object/ndarray/indexing.rs | 226 +++++++++++ nac3core/src/codegen/object/ndarray/mod.rs | 45 ++- nac3core/src/codegen/object/utils/mod.rs | 1 + nac3core/src/codegen/object/utils/slice.rs | 125 ++++++ 9 files changed, 660 insertions(+), 350 deletions(-) create mode 100644 nac3core/irrt/irrt/ndarray/indexing.hpp create mode 100644 nac3core/src/codegen/object/ndarray/indexing.rs create mode 100644 nac3core/src/codegen/object/utils/mod.rs create mode 100644 nac3core/src/codegen/object/utils/slice.rs diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 43f15e8f..b586ce57 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -7,4 +7,5 @@ #include "irrt/slice.hpp" #include "irrt/ndarray/basic.hpp" #include "irrt/ndarray/def.hpp" -#include "irrt/ndarray/iter.hpp" \ No newline at end of file +#include "irrt/ndarray/iter.hpp" +#include "irrt/ndarray/indexing.hpp" diff --git a/nac3core/irrt/irrt/ndarray/indexing.hpp b/nac3core/irrt/irrt/ndarray/indexing.hpp new file mode 100644 index 00000000..b2597d0d --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/indexing.hpp @@ -0,0 +1,220 @@ +#pragma once + +#include "irrt/exception.hpp" +#include "irrt/int_types.hpp" +#include "irrt/ndarray/basic.hpp" +#include 
"irrt/ndarray/def.hpp" +#include "irrt/range.hpp" +#include "irrt/slice.hpp" + +namespace { +typedef uint8_t NDIndexType; + +/** + * @brief A single element index + * + * `data` points to a `int32_t`. + */ +const NDIndexType ND_INDEX_TYPE_SINGLE_ELEMENT = 0; + +/** + * @brief A slice index + * + * `data` points to a `Slice`. + */ +const NDIndexType ND_INDEX_TYPE_SLICE = 1; + +/** + * @brief `np.newaxis` / `None` + * + * `data` is unused. + */ +const NDIndexType ND_INDEX_TYPE_NEWAXIS = 2; + +/** + * @brief `Ellipsis` / `...` + * + * `data` is unused. + */ +const NDIndexType ND_INDEX_TYPE_ELLIPSIS = 3; + +/** + * @brief An index used in ndarray indexing + * + * That is: + * ``` + * my_ndarray[::-1, 3, ..., np.newaxis] + * ^^^^ ^ ^^^ ^^^^^^^^^^ each of these is represented by an NDIndex. + * ``` + */ +struct NDIndex { + /** + * @brief Enum tag to specify the type of index. + * + * Please see the comment of each enum constant. + */ + NDIndexType type; + + /** + * @brief The accompanying data associated with `type`. + * + * Please see the comment of each enum constant. + */ + uint8_t* data; +}; +} // namespace + +namespace { +namespace ndarray { +namespace indexing { +/** + * @brief Perform ndarray "basic indexing" (https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing) + * + * This function is very similar to performing `dst_ndarray = src_ndarray[indices]` in Python. + * + * This function also does proper assertions on `indices` to check for out of bounds access and more. + * + * # Notes on `dst_ndarray` + * The caller is responsible for allocating space for the resulting ndarray. + * Here is what this function expects from `dst_ndarray` when called: + * - `dst_ndarray->data` does not have to be initialized. + * - `dst_ndarray->itemsize` does not have to be initialized. + * - `dst_ndarray->ndims` must be initialized, and it must be equal to the expected `ndims` of the `dst_ndarray` after + * indexing `src_ndarray` with `indices`. 
+ * - `dst_ndarray->shape` must be allocated, though it can contain uninitialized values. + * - `dst_ndarray->strides` must be allocated, though it can contain uninitialized values. + * When this function call ends: + * - `dst_ndarray->data` is set to `src_ndarray->data`. + * - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`. + * - `dst_ndarray->ndims` is unchanged. + * - `dst_ndarray->shape` is updated according to how `src_ndarray` is indexed. + * - `dst_ndarray->strides` is updated according to how ndarray indexing works. + * + * @param indices indices to index `src_ndarray`, ordered in the same way you would write them in Python. + * @param src_ndarray The NDArray to be indexed. + * @param dst_ndarray The resulting NDArray after indexing. Further details in the comments above. + */ +template +void index(SizeT num_indices, const NDIndex* indices, const NDArray* src_ndarray, NDArray* dst_ndarray) { + // Validate `indices`. + + // Expected value of `dst_ndarray->ndims`. + SizeT expected_dst_ndims = src_ndarray->ndims; + // To check for "too many indices for array: array is ?-dimensional, but ? were indexed" + SizeT num_indexed = 0; + // There may be ellipsis `...` in `indices`. There can only be 0 or 1 ellipsis. 
+ SizeT num_ellipsis = 0; + + for (SizeT i = 0; i < num_indices; i++) { + if (indices[i].type == ND_INDEX_TYPE_SINGLE_ELEMENT) { + expected_dst_ndims--; + num_indexed++; + } else if (indices[i].type == ND_INDEX_TYPE_SLICE) { + num_indexed++; + } else if (indices[i].type == ND_INDEX_TYPE_NEWAXIS) { + expected_dst_ndims++; + } else if (indices[i].type == ND_INDEX_TYPE_ELLIPSIS) { + num_ellipsis++; + if (num_ellipsis > 1) { + raise_exception(SizeT, EXN_INDEX_ERROR, "an index can only have a single ellipsis ('...')", NO_PARAM, + NO_PARAM, NO_PARAM); + } + } else { + __builtin_unreachable(); + } + } + + debug_assert_eq(SizeT, expected_dst_ndims, dst_ndarray->ndims); + + if (src_ndarray->ndims - num_indexed < 0) { + raise_exception(SizeT, EXN_INDEX_ERROR, + "too many indices for array: array is {0}-dimensional, " + "but {1} were indexed", + src_ndarray->ndims, num_indices, NO_PARAM); + } + + dst_ndarray->data = src_ndarray->data; + dst_ndarray->itemsize = src_ndarray->itemsize; + + // Reference code: + // https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652 + SizeT src_axis = 0; + SizeT dst_axis = 0; + + for (int32_t i = 0; i < num_indices; i++) { + const NDIndex* index = &indices[i]; + if (index->type == ND_INDEX_TYPE_SINGLE_ELEMENT) { + SizeT input = (SizeT) * ((int32_t*)index->data); + + SizeT k = slice::resolve_index_in_length(src_ndarray->shape[src_axis], input); + if (k == -1) { + raise_exception(SizeT, EXN_INDEX_ERROR, + "index {0} is out of bounds for axis {1} " + "with size {2}", + input, src_axis, src_ndarray->shape[src_axis]); + } + + dst_ndarray->data += k * src_ndarray->strides[src_axis]; + + src_axis++; + } else if (index->type == ND_INDEX_TYPE_SLICE) { + Slice* slice = (Slice*)index->data; + + Range range = slice->indices_checked(src_ndarray->shape[src_axis]); + + dst_ndarray->data += (SizeT)range.start * src_ndarray->strides[src_axis]; + dst_ndarray->strides[dst_axis] = ((SizeT)range.step) * 
src_ndarray->strides[src_axis]; + dst_ndarray->shape[dst_axis] = (SizeT)range.len(); + + dst_axis++; + src_axis++; + } else if (index->type == ND_INDEX_TYPE_NEWAXIS) { + dst_ndarray->strides[dst_axis] = 0; + dst_ndarray->shape[dst_axis] = 1; + + dst_axis++; + } else if (index->type == ND_INDEX_TYPE_ELLIPSIS) { + // The number of ':' entries this '...' implies. + SizeT ellipsis_size = src_ndarray->ndims - num_indexed; + + for (SizeT j = 0; j < ellipsis_size; j++) { + dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis]; + dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis]; + + dst_axis++; + src_axis++; + } + } else { + __builtin_unreachable(); + } + } + + for (; dst_axis < dst_ndarray->ndims; dst_axis++, src_axis++) { + dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis]; + dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis]; + } + + debug_assert_eq(SizeT, src_ndarray->ndims, src_axis); + debug_assert_eq(SizeT, dst_ndarray->ndims, dst_axis); +} +} // namespace indexing +} // namespace ndarray +} // namespace + +extern "C" { +using namespace ndarray::indexing; + +void __nac3_ndarray_index(int32_t num_indices, + NDIndex* indices, + NDArray* src_ndarray, + NDArray* dst_ndarray) { + index(num_indices, indices, src_ndarray, dst_ndarray); +} + +void __nac3_ndarray_index64(int64_t num_indices, + NDIndex* indices, + NDArray* src_ndarray, + NDArray* dst_ndarray) { + index(num_indices, indices, src_ndarray, dst_ndarray); +} +} \ No newline at end of file diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 8a010420..84a4c3f4 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -21,7 +21,7 @@ use nac3parser::ast::{ use super::{ classes::{ ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayValue, ProxyType, ProxyValue, - RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor, + RangeValue, UntypedArrayLikeAccessor, }, concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, 
ConcreteTypeStore}, gen_in_range_check, get_llvm_abi_type, get_llvm_type, get_va_count_arg_name, @@ -32,6 +32,10 @@ use super::{ }, macros::codegen_unreachable, need_sret, numpy, + object::{ + any::AnyObject, + ndarray::{indexing::util::gen_ndarray_subscript_ndindices, NDArrayObject}, + }, stmt::{ gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise, gen_var, @@ -40,11 +44,7 @@ use super::{ }; use crate::{ symbol_resolver::{SymbolValue, ValueEnum}, - toplevel::{ - helper::PrimDef, - numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, - DefinitionId, TopLevelDef, - }, + toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef}, typecheck::{ magic_methods::{Binop, BinopVariant, HasOpInfo}, typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap}, @@ -2505,338 +2505,6 @@ pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>( ) } -/// Generates code for a subscript expression on an `ndarray`. -/// -/// * `ty` - The `Type` of the `NDArray` elements. -/// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`. -/// * `v` - The `NDArray` value. -/// * `slice` - The slice expression used to subscript into the `ndarray`. -fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - ty: Type, - ndims: Type, - v: NDArrayValue<'ctx>, - slice: &Expr>, -) -> Result>, String> { - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let TypeEnum::TLiteral { values, .. 
} = &*ctx.unifier.get_ty_immutable(ndims) else { - codegen_unreachable!(ctx) - }; - - let ndims = values - .iter() - .map(|ndim| u64::try_from(ndim.clone()).map_err(|()| ndim.clone())) - .collect::, _>>() - .map_err(|val| { - format!( - "Expected non-negative literal for ndarray.ndims, got {}", - i128::try_from(val).unwrap() - ) - })?; - - assert!(!ndims.is_empty()); - - // The number of dimensions subscripted by the index expression. - // Slicing a ndarray will yield the same number of dimensions, whereas indexing into a - // dimension will remove a dimension. - let subscripted_dims = match &slice.node { - ExprKind::Tuple { elts, .. } => elts.iter().fold(0, |acc, value_subexpr| { - if let ExprKind::Slice { .. } = &value_subexpr.node { - acc - } else { - acc + 1 - } - }), - - ExprKind::Slice { .. } => 0, - _ => 1, - }; - - let ndarray_ndims_ty = ctx.unifier.get_fresh_literal( - ndims.iter().map(|v| SymbolValue::U64(v - subscripted_dims)).collect(), - None, - ); - let ndarray_ty = - make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(ty), Some(ndarray_ndims_ty)); - let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type(); - let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type(); - let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum(); - let sizeof_elem = llvm_ndarray_data_t.size_of().unwrap(); - - // Check that len is non-zero - let len = v.load_ndims(ctx); - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::SGT, len, llvm_usize.const_zero(), "").unwrap(), - "0:IndexError", - "too many indices for array: array is {0}-dimensional but 1 were indexed", - [Some(len), None, None], - slice.location, - ); - - // Normalizes a possibly-negative index to its corresponding positive index - let normalize_index = |generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - index: IntValue<'ctx>, - dim: u64| { - gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx 
- .builder - .build_int_compare(IntPredicate::SGE, index, index.get_type().const_zero(), "") - .unwrap()) - }, - |_, _| Ok(Some(index)), - |generator, ctx| { - let llvm_i32 = ctx.ctx.i32_type(); - - let len = unsafe { - v.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(dim, true), - None, - ) - }; - - let index = ctx - .builder - .build_int_add( - len, - ctx.builder.build_int_s_extend(index, llvm_usize, "").unwrap(), - "", - ) - .unwrap(); - - Ok(Some(ctx.builder.build_int_truncate(index, llvm_i32, "").unwrap())) - }, - ) - .map(|v| v.map(BasicValueEnum::into_int_value)) - }; - - // Converts a slice expression into a slice-range tuple - let expr_to_slice = |generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - node: &ExprKind>, - dim: u64| { - match node { - ExprKind::Constant { value: Constant::Int(v), .. } => { - let Some(index) = - normalize_index(generator, ctx, llvm_i32.const_int(*v as u64, true), dim)? - else { - return Ok(None); - }; - - Ok(Some((index, index, llvm_i32.const_int(1, true)))) - } - - ExprKind::Slice { lower, upper, step } => { - let dim_sz = unsafe { - v.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(dim, false), - None, - ) - }; - - handle_slice_indices(lower, upper, step, ctx, generator, dim_sz) - } - - _ => { - let Some(index) = generator.gen_expr(ctx, slice)? else { return Ok(None) }; - let index = index - .to_basic_value_enum(ctx, generator, slice.custom.unwrap())? - .into_int_value(); - let Some(index) = normalize_index(generator, ctx, index, dim)? else { - return Ok(None); - }; - - Ok(Some((index, index, llvm_i32.const_int(1, true)))) - } - } - }; - - let make_indices_arr = |generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>| - -> Result<_, String> { - Ok(if let ExprKind::Tuple { elts, .. 
} = &slice.node { - let llvm_int_ty = ctx.get_llvm_type(generator, elts[0].custom.unwrap()); - let index_addr = generator.gen_array_var_alloc( - ctx, - llvm_int_ty, - llvm_usize.const_int(elts.len() as u64, false), - None, - )?; - - for (i, elt) in elts.iter().enumerate() { - let Some(index) = generator.gen_expr(ctx, elt)? else { - return Ok(None); - }; - - let index = index - .to_basic_value_enum(ctx, generator, elt.custom.unwrap())? - .into_int_value(); - let Some(index) = normalize_index(generator, ctx, index, 0)? else { - return Ok(None); - }; - - let store_ptr = unsafe { - index_addr.ptr_offset_unchecked( - ctx, - generator, - &llvm_usize.const_int(i as u64, false), - None, - ) - }; - ctx.builder.build_store(store_ptr, index).unwrap(); - } - - Some(index_addr) - } else if let Some(index) = generator.gen_expr(ctx, slice)? { - let llvm_int_ty = ctx.get_llvm_type(generator, slice.custom.unwrap()); - let index_addr = generator.gen_array_var_alloc( - ctx, - llvm_int_ty, - llvm_usize.const_int(1u64, false), - None, - )?; - - let index = - index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value(); - let Some(index) = normalize_index(generator, ctx, index, 0)? else { return Ok(None) }; - - let store_ptr = unsafe { - index_addr.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - ctx.builder.build_store(store_ptr, index).unwrap(); - - Some(index_addr) - } else { - None - }) - }; - - Ok(Some(if ndims.len() == 1 && ndims[0] - subscripted_dims == 0 { - let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) }; - - v.data().get(ctx, generator, &index_addr, None).into() - } else { - match &slice.node { - ExprKind::Tuple { elts, .. 
} => { - let slices = elts - .iter() - .enumerate() - .map(|(dim, elt)| expr_to_slice(generator, ctx, &elt.node, dim as u64)) - .take_while_inclusive(|slice| slice.as_ref().is_ok_and(Option::is_some)) - .collect::, _>>()?; - if slices.len() < elts.len() { - return Ok(None); - } - - let slices = slices.into_iter().map(Option::unwrap).collect_vec(); - - numpy::ndarray_sliced_copy(generator, ctx, ty, v, &slices)?.as_base_value().into() - } - - ExprKind::Slice { .. } => { - let Some(slice) = expr_to_slice(generator, ctx, &slice.node, 0)? else { - return Ok(None); - }; - - numpy::ndarray_sliced_copy(generator, ctx, ty, v, &[slice])?.as_base_value().into() - } - - _ => { - // Accessing an element from a multi-dimensional `ndarray` - - let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) }; - - // Create a new array, remove the top dimension from the dimension-size-list, and copy the - // elements over - let subscripted_ndarray = - generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?; - let ndarray = NDArrayValue::from_ptr_val(subscripted_ndarray, llvm_usize, None); - - let num_dims = v.load_ndims(ctx); - ndarray.store_ndims( - ctx, - generator, - ctx.builder - .build_int_sub(num_dims, llvm_usize.const_int(1, false), "") - .unwrap(), - ); - - let ndarray_num_dims = ndarray.load_ndims(ctx); - ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims); - - let ndarray_num_dims = ctx - .builder - .build_int_z_extend_or_bit_cast( - ndarray.load_ndims(ctx), - llvm_usize.size_of().get_type(), - "", - ) - .unwrap(); - let v_dims_src_ptr = unsafe { - v.dim_sizes().ptr_offset_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - call_memcpy_generic( - ctx, - ndarray.dim_sizes().base_ptr(ctx, generator), - v_dims_src_ptr, - ctx.builder - .build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "") - .map(Into::into) - .unwrap(), - llvm_i1.const_zero(), - ); - - let ndarray_num_elems = call_ndarray_calc_size( - 
generator, - ctx, - &ndarray.dim_sizes().as_slice_value(ctx, generator), - (None, None), - ); - let ndarray_num_elems = ctx - .builder - .build_int_z_extend_or_bit_cast(ndarray_num_elems, sizeof_elem.get_type(), "") - .unwrap(); - ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems); - - let v_data_src_ptr = v.data().ptr_offset(ctx, generator, &index_addr, None); - call_memcpy_generic( - ctx, - ndarray.data().base_ptr(ctx, generator), - v_data_src_ptr, - ctx.builder - .build_int_mul( - ndarray_num_elems, - llvm_ndarray_data_t.size_of().unwrap(), - "", - ) - .map(Into::into) - .unwrap(), - llvm_i1.const_zero(), - ); - - ndarray.as_base_value().into() - } - } - })) -} - /// See [`CodeGenerator::gen_expr`]. pub fn gen_expr<'ctx, G: CodeGenerator>( generator: &mut G, @@ -3493,18 +3161,26 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( v.data().get(ctx, generator, &index, None).into() } } - TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::NDArray.id() => { - let (ty, ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap(); - - let v = if let Some(v) = generator.gen_expr(ctx, value)? { - v.to_basic_value_enum(ctx, generator, value.custom.unwrap())? - .into_pointer_value() - } else { + TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { + let Some(ndarray) = generator.gen_expr(ctx, value)? 
else { return Ok(None); }; - let v = NDArrayValue::from_ptr_val(v, usize, None); - return gen_ndarray_subscript_expr(generator, ctx, *ty, *ndims, v, slice); + let ndarray_ty = value.custom.unwrap(); + let ndarray = ndarray.to_basic_value_enum(ctx, generator, ndarray_ty)?; + + let ndarray = NDArrayObject::from_object( + generator, + ctx, + AnyObject { ty: ndarray_ty, value: ndarray }, + ); + + let indices = gen_ndarray_subscript_ndindices(generator, ctx, slice)?; + let result = ndarray + .index(generator, ctx, &indices) + .split_unsized(generator, ctx) + .to_basic_value_enum(); + return Ok(Some(ValueEnum::Dynamic(result))); } TypeEnum::TTuple { .. } => { let index: u32 = diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 7f2cd687..94996cd3 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -19,7 +19,7 @@ use super::{ llvm_intrinsics, macros::codegen_unreachable, model::{function::FnCall, *}, - object::ndarray::{nditer::NDIter, NDArray}, + object::ndarray::{indexing::NDIndex, nditer::NDIter, NDArray}, stmt::gen_for_callback_incrementing, CodeGenContext, CodeGenerator, }; @@ -1112,3 +1112,20 @@ pub fn call_nac3_nditer_next<'ctx, G: CodeGenerator + ?Sized>( let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_next"); FnCall::builder(generator, ctx, &name).arg(iter).returning_void(); } + +pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + num_indices: Instance<'ctx, Int>, + indices: Instance<'ctx, Ptr>>, + src_ndarray: Instance<'ctx, Ptr>>, + dst_ndarray: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_index"); + FnCall::builder(generator, ctx, &name) + .arg(num_indices) + .arg(indices) + .arg(src_ndarray) + .arg(dst_ndarray) + .returning_void(); +} diff --git a/nac3core/src/codegen/object/mod.rs b/nac3core/src/codegen/object/mod.rs index 
9466d70b..17b0b940 100644 --- a/nac3core/src/codegen/object/mod.rs +++ b/nac3core/src/codegen/object/mod.rs @@ -2,3 +2,4 @@ pub mod any; pub mod list; pub mod ndarray; pub mod tuple; +pub mod utils; diff --git a/nac3core/src/codegen/object/ndarray/indexing.rs b/nac3core/src/codegen/object/ndarray/indexing.rs new file mode 100644 index 00000000..d4fbbb35 --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/indexing.rs @@ -0,0 +1,226 @@ +use super::NDArrayObject; +use crate::codegen::{ + irrt::call_nac3_ndarray_index, + model::*, + object::utils::slice::{RustSlice, Slice}, + CodeGenContext, CodeGenerator, +}; + +pub type NDIndexType = Byte; + +/// Fields of [`NDIndex`] +#[derive(Debug, Clone, Copy)] +pub struct NDIndexFields<'ctx, F: FieldTraversal<'ctx>> { + pub type_: F::Output>, + pub data: F::Output>>, +} + +/// An IRRT representation of an ndarray subscript index. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub struct NDIndex; + +impl<'ctx> StructKind<'ctx> for NDIndex { + type Fields> = NDIndexFields<'ctx, F>; + + fn iter_fields>(&self, traversal: &mut F) -> Self::Fields { + Self::Fields { type_: traversal.add_auto("type"), data: traversal.add_auto("data") } + } +} + +// A convenience enum representing a [`NDIndex`]. +#[derive(Debug, Clone)] +pub enum RustNDIndex<'ctx> { + SingleElement(Instance<'ctx, Int>), + Slice(RustSlice<'ctx, Int32>), + NewAxis, + Ellipsis, +} + +impl<'ctx> RustNDIndex<'ctx> { + /// Get the value to set `NDIndex::type` for this variant. + fn get_type_id(&self) -> u64 { + // Defined in IRRT, must be in sync + match self { + RustNDIndex::SingleElement(_) => 0, + RustNDIndex::Slice(_) => 1, + RustNDIndex::NewAxis => 2, + RustNDIndex::Ellipsis => 3, + } + } + + /// Serialize this [`RustNDIndex`] by writing it into an LLVM [`NDIndex`]. 
+ fn write_to_ndindex( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + dst_ndindex_ptr: Instance<'ctx, Ptr>>, + ) { + // Set `dst_ndindex_ptr->type` + dst_ndindex_ptr.gep(ctx, |f| f.type_).store( + ctx, + Int(NDIndexType::default()).const_int(generator, ctx.ctx, self.get_type_id(), false), + ); + + // Set `dst_ndindex_ptr->data` + match self { + RustNDIndex::SingleElement(in_index) => { + let index_ptr = Int(Int32).alloca(generator, ctx); + index_ptr.store(ctx, *in_index); + + dst_ndindex_ptr + .gep(ctx, |f| f.data) + .store(ctx, index_ptr.pointer_cast(generator, ctx, Int(Byte))); + } + RustNDIndex::Slice(in_rust_slice) => { + let user_slice_ptr = Struct(Slice(Int32)).alloca(generator, ctx); + in_rust_slice.write_to_slice(generator, ctx, user_slice_ptr); + + dst_ndindex_ptr + .gep(ctx, |f| f.data) + .store(ctx, user_slice_ptr.pointer_cast(generator, ctx, Int(Byte))); + } + RustNDIndex::NewAxis | RustNDIndex::Ellipsis => {} + } + } + + /// Serialize a list of `RustNDIndex` as a newly allocated LLVM array of `NDIndex`. + pub fn make_ndindices( + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + in_ndindices: &[RustNDIndex<'ctx>], + ) -> (Instance<'ctx, Int>, Instance<'ctx, Ptr>>) { + let ndindex_model = Struct(NDIndex); + + // Allocate the LLVM ndindices. + let num_ndindices = + Int(SizeT).const_int(generator, ctx.ctx, in_ndindices.len() as u64, false); + let ndindices = ndindex_model.array_alloca(generator, ctx, num_ndindices.value); + + // Initialize all of them. + for (i, in_ndindex) in in_ndindices.iter().enumerate() { + let pndindex = ndindices.offset_const(ctx, i64::try_from(i).unwrap()); + in_ndindex.write_to_ndindex(generator, ctx, pndindex); + } + + (num_ndindices, ndindices) + } +} + +impl<'ctx> NDArrayObject<'ctx> { + /// Get the expected `ndims` after indexing with `indices`. 
+ #[must_use] + fn deduce_ndims_after_indexing_with(&self, indices: &[RustNDIndex<'ctx>]) -> u64 { + let mut ndims = self.ndims; + for index in indices { + match index { + RustNDIndex::SingleElement(_) => { + ndims -= 1; // A single-element index decrements ndims + } + RustNDIndex::NewAxis => { + ndims += 1; // `np.newaxis` / `None` adds a new axis + } + RustNDIndex::Ellipsis | RustNDIndex::Slice(_) => {} + } + } + ndims + } + + /// Index into the ndarray, and return a newly-allocated view on this ndarray. + /// + /// This function behaves like NumPy's ndarray indexing, but if the indices index + /// into a single element, an unsized ndarray is returned. + #[must_use] + pub fn index( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + indices: &[RustNDIndex<'ctx>], + ) -> Self { + let dst_ndims = self.deduce_ndims_after_indexing_with(indices); + let dst_ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, dst_ndims); + + let (num_indices, indices) = RustNDIndex::make_ndindices(generator, ctx, indices); + call_nac3_ndarray_index( + generator, + ctx, + num_indices, + indices, + self.instance, + dst_ndarray.instance, + ); + + dst_ndarray + } +} + +pub mod util { + use itertools::Itertools; + use nac3parser::ast::{Expr, ExprKind}; + + use crate::{ + codegen::{model::*, object::utils::slice::util::gen_slice, CodeGenContext, CodeGenerator}, + typecheck::typedef::Type, + }; + + use super::RustNDIndex; + + /// Generate LLVM code to transform an ndarray subscript expression to + /// its list of [`RustNDIndex`] + /// + /// i.e., + /// ```python + /// my_ndarray[::3, 1, :2:] + /// ^^^^^^^^^^^ Turn these into three `RustNDIndex`es + /// ``` + pub fn gen_ndarray_subscript_ndindices<'ctx, G: CodeGenerator>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + subscript: &Expr>, + ) -> Result>, String> { + // TODO: Support https://numpy.org/doc/stable/user/basics.indexing.html#dimensional-indexing-tools + + // Annoying notes about `slice` + // - 
`my_array[5]` + // - slice is a `Constant` + // - `my_array[:5]` + // - slice is a `Slice` + // - `my_array[:]` + // - slice is a `Slice`, but `lower`, `upper` and `step` would all be `Option::None` + // - `my_array[:, :]` + // - slice is now a `Tuple` of two `Slice`-s + // + // In summary: + // - when there is a comma "," within [], `slice` will be a `Tuple` of the entries. + // - when there is no comma "," within [] (i.e., just a single entry), `slice` will be that entry itself. + // + // So we first "flatten" out the slice expression + let index_exprs = match &subscript.node { + ExprKind::Tuple { elts, .. } => elts.iter().collect_vec(), + _ => vec![subscript], + }; + + // Process all index expressions + let mut rust_ndindices: Vec = Vec::with_capacity(index_exprs.len()); // Not using iterators here because `?` is used here. + for index_expr in index_exprs { + // NOTE: Currently nac3core's slices do not have an object representation, + // so the code/implementation looks awkward - we have to do pattern matching on the expression + let ndindex = if let ExprKind::Slice { lower, upper, step } = &index_expr.node { + // Handle slices + let slice = gen_slice(generator, ctx, lower, upper, step)?; + RustNDIndex::Slice(slice) + } else { + // Treat and handle everything else as a single element index.
+ let index = generator.gen_expr(ctx, index_expr)?.unwrap().to_basic_value_enum( + ctx, + generator, + ctx.primitives.int32, // Must be int32, this checks for illegal values + )?; + let index = Int(Int32).check_value(generator, ctx.ctx, index).unwrap(); + + RustNDIndex::SingleElement(index) + }; + rust_ndindices.push(ndindex); + } + Ok(rust_ndindices) + } +} diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs index 8b55c2e1..7fa3365a 100644 --- a/nac3core/src/codegen/object/ndarray/mod.rs +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -1,7 +1,7 @@ use inkwell::{ context::Context, types::BasicType, - values::{BasicValueEnum, PointerValue}, + values::{BasicValue, BasicValueEnum, PointerValue}, AddressSpace, }; @@ -22,6 +22,7 @@ use crate::{ }; pub mod factory; +pub mod indexing; pub mod nditer; pub mod shape_util; @@ -352,6 +353,30 @@ impl<'ctx> NDArrayObject<'ctx> { call_nac3_ndarray_copy_data(generator, ctx, src.instance, self.instance); } + /// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar. + #[must_use] + pub fn is_unsized(&self) -> bool { + self.ndims == 0 + } + + /// If this ndarray is unsized, return its sole value as an [`AnyObject`]. + /// Otherwise, do nothing and return the ndarray itself. + pub fn split_unsized( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> ScalarOrNDArray<'ctx> { + if self.is_unsized() { + // NOTE: `np.size(self) == 0` here is never possible. + let zero = Int(SizeT).const_0(generator, ctx.ctx); + let value = self.get_nth_scalar(generator, ctx, zero).value; + + ScalarOrNDArray::Scalar(AnyObject { ty: self.dtype, value }) + } else { + ScalarOrNDArray::NDArray(*self) + } + } + /// Fill the ndarray with a scalar. /// /// `fill_value` must have the same LLVM type as the `dtype` of this ndarray. 
@@ -369,3 +394,21 @@ impl<'ctx> NDArrayObject<'ctx> { .unwrap(); } } + +/// A convenience enum for implementing functions that act on scalars or ndarrays or both. +#[derive(Debug, Clone, Copy)] +pub enum ScalarOrNDArray<'ctx> { + Scalar(AnyObject<'ctx>), + NDArray(NDArrayObject<'ctx>), +} + +impl<'ctx> ScalarOrNDArray<'ctx> { + /// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`]. + #[must_use] + pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> { + match self { + ScalarOrNDArray::Scalar(scalar) => scalar.value, + ScalarOrNDArray::NDArray(ndarray) => ndarray.instance.value.as_basic_value_enum(), + } + } +} diff --git a/nac3core/src/codegen/object/utils/mod.rs b/nac3core/src/codegen/object/utils/mod.rs new file mode 100644 index 00000000..913812d4 --- /dev/null +++ b/nac3core/src/codegen/object/utils/mod.rs @@ -0,0 +1 @@ +pub mod slice; diff --git a/nac3core/src/codegen/object/utils/slice.rs b/nac3core/src/codegen/object/utils/slice.rs new file mode 100644 index 00000000..f4bb3f5c --- /dev/null +++ b/nac3core/src/codegen/object/utils/slice.rs @@ -0,0 +1,125 @@ +use crate::codegen::{model::*, CodeGenContext, CodeGenerator}; + +/// Fields of [`Slice`] +#[derive(Debug, Clone)] +pub struct SliceFields<'ctx, F: FieldTraversal<'ctx>, N: IntKind<'ctx>> { + pub start_defined: F::Output>, + pub start: F::Output>, + pub stop_defined: F::Output>, + pub stop: F::Output>, + pub step_defined: F::Output>, + pub step: F::Output>, +} + +/// An IRRT representation of an (unresolved) slice.
+#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub struct Slice(pub N); + +impl<'ctx, N: IntKind<'ctx>> StructKind<'ctx> for Slice { + type Fields> = SliceFields<'ctx, F, N>; + + fn iter_fields>(&self, traversal: &mut F) -> Self::Fields { + Self::Fields { + start_defined: traversal.add_auto("start_defined"), + start: traversal.add("start", Int(self.0)), + stop_defined: traversal.add_auto("stop_defined"), + stop: traversal.add("stop", Int(self.0)), + step_defined: traversal.add_auto("step_defined"), + step: traversal.add("step", Int(self.0)), + } + } +} + +/// A Rust structure that has [`Slice`] utilities and looks like a [`Slice`] but +/// `start`, `stop` and `step` are held by LLVM registers only and possibly +/// [`Option::None`] if unspecified. +#[derive(Debug, Clone)] +pub struct RustSlice<'ctx, N: IntKind<'ctx>> { + // It is possible that `start`, `stop`, and `step` are all `None`. + // We need to know the `int_kind` even when that is the case. + pub int_kind: N, + pub start: Option>>, + pub stop: Option>>, + pub step: Option>>, +} + +impl<'ctx, N: IntKind<'ctx>> RustSlice<'ctx, N> { + /// Write the contents to an LLVM [`Slice`]. 
+ pub fn write_to_slice( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + dst_slice_ptr: Instance<'ctx, Ptr>>>, + ) { + let false_ = Int(Bool).const_false(generator, ctx.ctx); + let true_ = Int(Bool).const_true(generator, ctx.ctx); + + match self.start { + Some(start) => { + dst_slice_ptr.gep(ctx, |f| f.start_defined).store(ctx, true_); + dst_slice_ptr.gep(ctx, |f| f.start).store(ctx, start); + } + None => dst_slice_ptr.gep(ctx, |f| f.start_defined).store(ctx, false_), + } + + match self.stop { + Some(stop) => { + dst_slice_ptr.gep(ctx, |f| f.stop_defined).store(ctx, true_); + dst_slice_ptr.gep(ctx, |f| f.stop).store(ctx, stop); + } + None => dst_slice_ptr.gep(ctx, |f| f.stop_defined).store(ctx, false_), + } + + match self.step { + Some(step) => { + dst_slice_ptr.gep(ctx, |f| f.step_defined).store(ctx, true_); + dst_slice_ptr.gep(ctx, |f| f.step).store(ctx, step); + } + None => dst_slice_ptr.gep(ctx, |f| f.step_defined).store(ctx, false_), + } + } +} + +pub mod util { + use nac3parser::ast::Expr; + + use crate::{ + codegen::{model::*, CodeGenContext, CodeGenerator}, + typecheck::typedef::Type, + }; + + use super::RustSlice; + + /// Generate LLVM IR for an [`ExprKind::Slice`] and convert it into a [`RustSlice`]. + #[allow(clippy::type_complexity)] + pub fn gen_slice<'ctx, G: CodeGenerator>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + lower: &Option>>>, + upper: &Option>>>, + step: &Option>>>, + ) -> Result, String> { + let mut help = |value_expr: &Option>>>| -> Result<_, String> { + Ok(match value_expr { + None => None, + Some(value_expr) => { + let value_expr = generator + .gen_expr(ctx, value_expr)? 
+ .unwrap() + .to_basic_value_enum(ctx, generator, ctx.primitives.int32)?; + + let value_expr = + Int(Int32).check_value(generator, ctx.ctx, value_expr).unwrap(); + + Some(value_expr) + } + }) + }; + + let start = help(lower)?; + let stop = help(upper)?; + let step = help(step)?; + + Ok(RustSlice { int_kind: Int32, start, stop, step }) + } +} -- 2.44.2