From 438943ac6fa87debc5fc1cc026056b231c51a041 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 10 Dec 2024 16:32:44 +0800 Subject: [PATCH] [core] codegen: Implement indexing for NDArray Based on 8f9d2d82: core/ndstrides: implement ndarray indexing The functionality for `...` and `np.newaxis` is there in IRRT, but there is no implementation of them for @kernel Python expressions because of https://git.m-labs.hk/M-Labs/nac3/issues/486. --- nac3core/irrt/irrt.cpp | 3 +- nac3core/irrt/irrt/ndarray/indexing.hpp | 220 +++++++++++++++ nac3core/src/codegen/irrt/ndarray/indexing.rs | 29 ++ nac3core/src/codegen/irrt/ndarray/mod.rs | 2 + .../src/codegen/types/ndarray/indexing.rs | 215 ++++++++++++++ nac3core/src/codegen/types/ndarray/mod.rs | 2 + .../src/codegen/values/ndarray/indexing.rs | 262 ++++++++++++++++++ nac3core/src/codegen/values/ndarray/mod.rs | 2 + 8 files changed, 734 insertions(+), 1 deletion(-) create mode 100644 nac3core/irrt/irrt/ndarray/indexing.hpp create mode 100644 nac3core/src/codegen/irrt/ndarray/indexing.rs create mode 100644 nac3core/src/codegen/types/ndarray/indexing.rs create mode 100644 nac3core/src/codegen/values/ndarray/indexing.rs diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 5eafe013..8447fc5a 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -6,4 +6,5 @@ #include "irrt/slice.hpp" #include "irrt/ndarray/basic.hpp" #include "irrt/ndarray/def.hpp" -#include "irrt/ndarray/iter.hpp" \ No newline at end of file +#include "irrt/ndarray/iter.hpp" +#include "irrt/ndarray/indexing.hpp" diff --git a/nac3core/irrt/irrt/ndarray/indexing.hpp b/nac3core/irrt/irrt/ndarray/indexing.hpp new file mode 100644 index 00000000..9e9e7b6e --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/indexing.hpp @@ -0,0 +1,220 @@ +#pragma once + +#include "irrt/exception.hpp" +#include "irrt/int_types.hpp" +#include "irrt/ndarray/basic.hpp" +#include "irrt/ndarray/def.hpp" +#include "irrt/range.hpp" +#include "irrt/slice.hpp" + +namespace { 
+typedef uint8_t NDIndexType; + +/** + * @brief A single element index + * + * `data` points to an `int32_t`. + */ +const NDIndexType ND_INDEX_TYPE_SINGLE_ELEMENT = 0; + +/** + * @brief A slice index + * + * `data` points to a `Slice`. + */ +const NDIndexType ND_INDEX_TYPE_SLICE = 1; + +/** + * @brief `np.newaxis` / `None` + * + * `data` is unused. + */ +const NDIndexType ND_INDEX_TYPE_NEWAXIS = 2; + +/** + * @brief `Ellipsis` / `...` + * + * `data` is unused. + */ +const NDIndexType ND_INDEX_TYPE_ELLIPSIS = 3; + +/** + * @brief An index used in ndarray indexing + * + * That is: + * ``` + * my_ndarray[::-1, 3, ..., np.newaxis] + * ^^^^ ^ ^^^ ^^^^^^^^^^ each of these is represented by an NDIndex. + * ``` + */ +struct NDIndex { + /** + * @brief Enum tag to specify the type of index. + * + * Please see the comment of each enum constant. + */ + NDIndexType type; + + /** + * @brief The accompanying data associated with `type`. + * + * Please see the comment of each enum constant. + */ + uint8_t* data; +}; +} // namespace + +namespace { +namespace ndarray { +namespace indexing { +/** + * @brief Perform ndarray "basic indexing" (https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing) + * + * This function is very similar to performing `dst_ndarray = src_ndarray[indices]` in Python. + * + * This function also does proper assertions on `indices` to check for out of bounds access and more. + * + * # Notes on `dst_ndarray` + * The caller is responsible for allocating space for the resulting ndarray. + * Here is what this function expects from `dst_ndarray` when called: + * - `dst_ndarray->data` does not have to be initialized. + * - `dst_ndarray->itemsize` does not have to be initialized. + * - `dst_ndarray->ndims` must be initialized, and it must be equal to the expected `ndims` of the `dst_ndarray` after + * indexing `src_ndarray` with `indices`. + * - `dst_ndarray->shape` must be allocated, though it can contain uninitialized values. 
+ * - `dst_ndarray->strides` must be allocated, though it can contain uninitialized values. + * When this function call ends: + * - `dst_ndarray->data` is set to `src_ndarray->data`. + * - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`. + * - `dst_ndarray->ndims` is unchanged. + * - `dst_ndarray->shape` is updated according to how `src_ndarray` is indexed. + * - `dst_ndarray->strides` is updated according to how ndarray indexing works. + * + * @param indices indices to index `src_ndarray`, ordered in the same way you would write them in Python. + * @param src_ndarray The NDArray to be indexed. + * @param dst_ndarray The resulting NDArray after indexing. Further details in the comments above. + */ +template +void index(SizeT num_indices, const NDIndex* indices, const NDArray* src_ndarray, NDArray* dst_ndarray) { + // Validate `indices`. + + // Expected value of `dst_ndarray->ndims`. + SizeT expected_dst_ndims = src_ndarray->ndims; + // To check for "too many indices for array: array is ?-dimensional, but ? were indexed" + SizeT num_indexed = 0; + // There may be ellipsis `...` in `indices`. There can only be 0 or 1 ellipsis. 
+ SizeT num_ellipsis = 0; + + for (SizeT i = 0; i < num_indices; i++) { + if (indices[i].type == ND_INDEX_TYPE_SINGLE_ELEMENT) { + expected_dst_ndims--; + num_indexed++; + } else if (indices[i].type == ND_INDEX_TYPE_SLICE) { + num_indexed++; + } else if (indices[i].type == ND_INDEX_TYPE_NEWAXIS) { + expected_dst_ndims++; + } else if (indices[i].type == ND_INDEX_TYPE_ELLIPSIS) { + num_ellipsis++; + if (num_ellipsis > 1) { + raise_exception(SizeT, EXN_INDEX_ERROR, "an index can only have a single ellipsis ('...')", NO_PARAM, + NO_PARAM, NO_PARAM); + } + } else { + __builtin_unreachable(); + } + } + + debug_assert_eq(SizeT, expected_dst_ndims, dst_ndarray->ndims); + + if (src_ndarray->ndims - num_indexed < 0) { + raise_exception(SizeT, EXN_INDEX_ERROR, + "too many indices for array: array is {0}-dimensional, " + "but {1} were indexed", + src_ndarray->ndims, num_indices, NO_PARAM); + } + + dst_ndarray->data = src_ndarray->data; + dst_ndarray->itemsize = src_ndarray->itemsize; + + // Reference code: + // https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652 + SizeT src_axis = 0; + SizeT dst_axis = 0; + + for (int32_t i = 0; i < num_indices; i++) { + const NDIndex* index = &indices[i]; + if (index->type == ND_INDEX_TYPE_SINGLE_ELEMENT) { + SizeT input = (SizeT) * ((int32_t*)index->data); + + SizeT k = slice::resolve_index_in_length(src_ndarray->shape[src_axis], input); + if (k == -1) { + raise_exception(SizeT, EXN_INDEX_ERROR, + "index {0} is out of bounds for axis {1} " + "with size {2}", + input, src_axis, src_ndarray->shape[src_axis]); + } + + dst_ndarray->data = static_cast(dst_ndarray->data) + k * src_ndarray->strides[src_axis]; + + src_axis++; + } else if (index->type == ND_INDEX_TYPE_SLICE) { + Slice* slice = (Slice*)index->data; + + Range range = slice->indices_checked(src_ndarray->shape[src_axis]); + + dst_ndarray->data = static_cast(dst_ndarray->data) + (SizeT)range.start * src_ndarray->strides[src_axis]; 
+ dst_ndarray->strides[dst_axis] = ((SizeT)range.step) * src_ndarray->strides[src_axis]; + dst_ndarray->shape[dst_axis] = (SizeT)range.len(); + + dst_axis++; + src_axis++; + } else if (index->type == ND_INDEX_TYPE_NEWAXIS) { + dst_ndarray->strides[dst_axis] = 0; + dst_ndarray->shape[dst_axis] = 1; + + dst_axis++; + } else if (index->type == ND_INDEX_TYPE_ELLIPSIS) { + // The number of ':' entries this '...' implies. + SizeT ellipsis_size = src_ndarray->ndims - num_indexed; + + for (SizeT j = 0; j < ellipsis_size; j++) { + dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis]; + dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis]; + + dst_axis++; + src_axis++; + } + } else { + __builtin_unreachable(); + } + } + + for (; dst_axis < dst_ndarray->ndims; dst_axis++, src_axis++) { + dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis]; + dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis]; + } + + debug_assert_eq(SizeT, src_ndarray->ndims, src_axis); + debug_assert_eq(SizeT, dst_ndarray->ndims, dst_axis); +} +} // namespace indexing +} // namespace ndarray +} // namespace + +extern "C" { +using namespace ndarray::indexing; + +void __nac3_ndarray_index(int32_t num_indices, + NDIndex* indices, + NDArray* src_ndarray, + NDArray* dst_ndarray) { + index(num_indices, indices, src_ndarray, dst_ndarray); +} + +void __nac3_ndarray_index64(int64_t num_indices, + NDIndex* indices, + NDArray* src_ndarray, + NDArray* dst_ndarray) { + index(num_indices, indices, src_ndarray, dst_ndarray); +} +} \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/ndarray/indexing.rs b/nac3core/src/codegen/irrt/ndarray/indexing.rs new file mode 100644 index 00000000..0821b2cd --- /dev/null +++ b/nac3core/src/codegen/irrt/ndarray/indexing.rs @@ -0,0 +1,29 @@ +use crate::codegen::{ + expr::infer_and_call_function, + irrt::get_usize_dependent_function_name, + values::{ndarray::NDArrayValue, ArrayLikeValue, ArraySliceValue, ProxyValue}, + CodeGenContext, 
CodeGenerator, +}; + +pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + indices: ArraySliceValue<'ctx>, + src_ndarray: NDArrayValue<'ctx>, + dst_ndarray: NDArrayValue<'ctx>, +) { + let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_index"); + infer_and_call_function( + ctx, + &name, + None, + &[ + indices.size(ctx, generator).into(), + indices.base_ptr(ctx, generator).into(), + src_ndarray.as_base_value().into(), + dst_ndarray.as_base_value().into(), + ], + None, + None, + ); +} diff --git a/nac3core/src/codegen/irrt/ndarray/mod.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs index 4a344410..a05e0ce3 100644 --- a/nac3core/src/codegen/irrt/ndarray/mod.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -16,9 +16,11 @@ use crate::codegen::{ CodeGenContext, CodeGenerator, }; pub use basic::*; +pub use indexing::*; pub use iter::*; mod basic; +mod indexing; mod iter; /// Generates a call to `__nac3_ndarray_calc_size`. 
Returns an [`IntValue`] representing the diff --git a/nac3core/src/codegen/types/ndarray/indexing.rs b/nac3core/src/codegen/types/ndarray/indexing.rs new file mode 100644 index 00000000..959d4f57 --- /dev/null +++ b/nac3core/src/codegen/types/ndarray/indexing.rs @@ -0,0 +1,215 @@ +use inkwell::{ + context::{AsContextRef, Context}, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, + values::{IntValue, PointerValue}, + AddressSpace, +}; +use itertools::Itertools; + +use nac3core_derive::StructFields; + +use crate::codegen::{ + types::{ + structure::{check_struct_type_matches_fields, StructField, StructFields}, + ProxyType, + }, + values::{ + ndarray::{NDIndexValue, RustNDIndex}, + ArrayLikeIndexer, ArraySliceValue, ProxyValue, + }, + CodeGenContext, CodeGenerator, +}; + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct NDIndexType<'ctx> { + ty: PointerType<'ctx>, + llvm_usize: IntType<'ctx>, +} + +#[derive(PartialEq, Eq, Clone, Copy, StructFields)] +pub struct NDIndexStructFields<'ctx> { + #[value_type(i8_type())] + pub type_: StructField<'ctx, IntValue<'ctx>>, + #[value_type(i8_type().ptr_type(AddressSpace::default()))] + pub data: StructField<'ctx, PointerValue<'ctx>>, +} + +impl<'ctx> NDIndexType<'ctx> { + /// Checks whether `llvm_ty` represents a `ndindex` type, returning [Err] if it does not. 
+ pub fn is_representable( + llvm_ty: PointerType<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + let ctx = llvm_ty.get_context(); + + let llvm_ty = llvm_ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else { + return Err(format!( + "Expected struct type for `ContiguousNDArray` type, got {llvm_ty}" + )); + }; + + let fields = NDIndexStructFields::new(ctx, llvm_usize); + + check_struct_type_matches_fields(fields, llvm_ty, "NDIndex", &[]) + } + + #[must_use] + fn fields( + ctx: impl AsContextRef<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> NDIndexStructFields<'ctx> { + NDIndexStructFields::new(ctx, llvm_usize) + } + + #[must_use] + pub fn get_fields(&self) -> NDIndexStructFields<'ctx> { + Self::fields(self.ty.get_context(), self.llvm_usize) + } + + #[must_use] + fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> { + let field_tys = + Self::fields(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec(); + + ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) + } + + #[must_use] + pub fn new(generator: &G, ctx: &'ctx Context) -> Self { + let llvm_usize = generator.get_size_type(ctx); + let llvm_ndindex = Self::llvm_type(ctx, llvm_usize); + + Self { ty: llvm_ndindex, llvm_usize } + } + + #[must_use] + pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); + + Self { ty: ptr_ty, llvm_usize } + } + + #[must_use] + pub fn alloca( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca(generator, ctx, name), + self.llvm_usize, + name, + ) + } + + /// Serialize a list of [`RustNDIndex`] as a newly allocated LLVM array of [`NDIndexValue`]. 
+ #[must_use] + pub fn construct_ndindices( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + in_ndindices: &[RustNDIndex<'ctx>], + ) -> ArraySliceValue<'ctx> { + // Allocate the LLVM ndindices. + let num_ndindices = self.llvm_usize.const_int(in_ndindices.len() as u64, false); + let ndindices = self.array_alloca(generator, ctx, num_ndindices, None); + + // Initialize all of them. + for (i, in_ndindex) in in_ndindices.iter().enumerate() { + let pndindex = unsafe { + ndindices.ptr_offset_unchecked( + ctx, + generator, + &ctx.ctx.i64_type().const_int(u64::try_from(i).unwrap(), false), + None, + ) + }; + + in_ndindex.write_to_ndindex( + generator, + ctx, + NDIndexValue::from_pointer_value(pndindex, self.llvm_usize, None), + ); + } + + ndindices + } + + #[must_use] + pub fn map_value( + &self, + value: <>::Value as ProxyValue<'ctx>>::Base, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value(value, self.llvm_usize, name) + } +} + +impl<'ctx> ProxyType<'ctx> for NDIndexType<'ctx> { + type Base = PointerType<'ctx>; + type Value = NDIndexValue<'ctx>; + + fn is_type( + generator: &G, + ctx: &'ctx Context, + llvm_ty: impl BasicType<'ctx>, + ) -> Result<(), String> { + if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { + >::is_representable(generator, ctx, ty) + } else { + Err(format!("Expected pointer type, got {llvm_ty:?}")) + } + } + + fn is_representable( + generator: &G, + ctx: &'ctx Context, + llvm_ty: Self::Base, + ) -> Result<(), String> { + Self::is_representable(llvm_ty, generator.get_size_type(ctx)) + } + + fn raw_alloca( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Base { + generator + .gen_var_alloc( + ctx, + self.as_base_type().get_element_type().into_struct_type().into(), + name, + ) + .unwrap() + } + + fn array_alloca( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + size: IntValue<'ctx>, + name: Option<&'ctx str>, + ) 
-> ArraySliceValue<'ctx> { + generator + .gen_array_var_alloc( + ctx, + self.as_base_type().get_element_type().into_struct_type().into(), + size, + name, + ) + .unwrap() + } + + fn as_base_type(&self) -> Self::Base { + self.ty + } +} + +impl<'ctx> From> for PointerType<'ctx> { + fn from(value: NDIndexType<'ctx>) -> Self { + value.as_base_type() + } +} diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index 05e11659..c65deb6e 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -21,9 +21,11 @@ use crate::{ typecheck::typedef::Type, }; pub use contiguous::*; +pub use indexing::*; pub use nditer::*; mod contiguous; +mod indexing; mod nditer; /// Proxy type for a `ndarray` type in LLVM. diff --git a/nac3core/src/codegen/values/ndarray/indexing.rs b/nac3core/src/codegen/values/ndarray/indexing.rs new file mode 100644 index 00000000..69c00807 --- /dev/null +++ b/nac3core/src/codegen/values/ndarray/indexing.rs @@ -0,0 +1,262 @@ +use inkwell::{ + types::IntType, + values::{IntValue, PointerValue}, + AddressSpace, +}; +use itertools::Itertools; + +use nac3parser::ast::{Expr, ExprKind}; + +use crate::{ + codegen::{ + irrt, + types::{ + ndarray::{NDArrayType, NDIndexType}, + structure::StructField, + utils::SliceType, + }, + values::{ndarray::NDArrayValue, utils::RustSlice, ProxyValue}, + CodeGenContext, CodeGenerator, + }, + typecheck::typedef::Type, +}; + +/// An IRRT representation of an ndarray subscript index. +#[derive(Copy, Clone)] +pub struct NDIndexValue<'ctx> { + value: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, +} + +impl<'ctx> NDIndexValue<'ctx> { + /// Checks whether `value` is an instance of `ndindex`, returning [Err] if `value` is not an + /// instance. 
+ pub fn is_representable( + value: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + >::Type::is_representable(value.get_type(), llvm_usize) + } + + /// Creates an [`NDIndexValue`] from a [`PointerValue`]. + #[must_use] + pub fn from_pointer_value( + ptr: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); + + Self { value: ptr, llvm_usize, name } + } + + fn type_field(&self) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields().type_ + } + + pub fn load_type(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + self.type_field().get(ctx, self.value, self.name) + } + + pub fn store_type(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) { + self.type_field().set(ctx, self.value, value, self.name); + } + + fn data_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { + self.get_type().get_fields().data + } + + pub fn load_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + self.data_field().get(ctx, self.value, self.name) + } + + pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { + self.data_field().set(ctx, self.value, value, self.name); + } +} + +impl<'ctx> ProxyValue<'ctx> for NDIndexValue<'ctx> { + type Base = PointerValue<'ctx>; + type Type = NDIndexType<'ctx>; + + fn get_type(&self) -> Self::Type { + Self::Type::from_type(self.value.get_type(), self.llvm_usize) + } + + fn as_base_value(&self) -> Self::Base { + self.value + } +} + +impl<'ctx> From> for PointerValue<'ctx> { + fn from(value: NDIndexValue<'ctx>) -> Self { + value.as_base_value() + } +} + +impl<'ctx> NDArrayValue<'ctx> { + /// Get the expected `ndims` after indexing with `indices`. 
+ #[must_use] + fn deduce_ndims_after_indexing_with(&self, indices: &[RustNDIndex<'ctx>]) -> Option { + let mut ndims = self.ndims?; + + for index in indices { + match index { + RustNDIndex::SingleElement(_) => { + ndims -= 1; // A single-element index decrements ndims + } + RustNDIndex::NewAxis => { + ndims += 1; // `np.newaxis` / `None` adds a new axis + } + RustNDIndex::Ellipsis | RustNDIndex::Slice(_) => {} + } + } + + Some(ndims) + } + + /// Index into the ndarray, and return a newly-allocated view on this ndarray. + /// + /// This function behaves like NumPy's ndarray indexing, but if the indices index + /// into a single element, an unsized ndarray is returned. + #[must_use] + pub fn index( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + indices: &[RustNDIndex<'ctx>], + ) -> Self { + assert!(self.ndims.is_some(), "NDArrayValue::index is only supported for instances with compile-time known ndims (self.ndims = Some(...))"); + + let dst_ndims = self.deduce_ndims_after_indexing_with(indices); + let dst_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, dst_ndims) + .construct_uninitialized(generator, ctx, None); + + let indices = + NDIndexType::new(generator, ctx.ctx).construct_ndindices(generator, ctx, indices); + irrt::ndarray::call_nac3_ndarray_index(generator, ctx, indices, *self, dst_ndarray); + + dst_ndarray + } +} + +/// A convenience enum representing a [`NDIndexValue`]. 
+// TODO: Rename to CTConstNDIndex +#[derive(Debug, Clone)] +pub enum RustNDIndex<'ctx> { + SingleElement(IntValue<'ctx>), + Slice(RustSlice<'ctx>), + NewAxis, + Ellipsis, +} + +impl<'ctx> RustNDIndex<'ctx> { + /// Generate LLVM code to transform an ndarray subscript expression to + /// its list of [`RustNDIndex`] + /// + /// i.e., + /// ```python + /// my_ndarray[::3, 1, :2:] + /// ^^^^^^^^^^^ These are turned into three `RustNDIndex`es + /// ``` + pub fn from_subscript_expr( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + subscript: &Expr>, + ) -> Result>, String> { + // Annoying notes about `slice` + // - `my_array[5]` + // - slice is a `Constant` + // - `my_array[:5]` + // - slice is a `Slice` + // - `my_array[:]` + // - slice is a `Slice`, but lower, upper, and step would all be `Option::None` + // - `my_array[:, :]` + // - slice is now a `Tuple` of two `Slice`-s + // + // In summary: + // - when there is a comma "," within [], `slice` will be a `Tuple` of the entries. + // - when there is no comma "," within [] (i.e., just a single entry), `slice` will be that entry itself. + // + // So we first "flatten" out the slice expression + let index_exprs = match &subscript.node { + ExprKind::Tuple { elts, .. } => elts.iter().collect_vec(), + _ => vec![subscript], + }; + + // Process all index expressions + let mut rust_ndindices: Vec = Vec::with_capacity(index_exprs.len()); // Not using iterators here because `?` is used here. + for index_expr in index_exprs { + // NOTE: Currently nac3core's slices do not have an object representation, + // so the code/implementation looks awkward - we have to do pattern matching on the expression + let ndindex = if let ExprKind::Slice { lower, upper, step } = &index_expr.node { + // Handle slices + let slice = RustSlice::from_slice_expr(generator, ctx, lower, upper, step)?; + RustNDIndex::Slice(slice) + } else { + // Treat and handle everything else as a single element index. 
+ let index = generator.gen_expr(ctx, index_expr)?.unwrap().to_basic_value_enum( + ctx, + generator, + ctx.primitives.int32, // Must be int32, this checks for illegal values + )?; + let index = index.into_int_value(); + + RustNDIndex::SingleElement(index) + }; + rust_ndindices.push(ndindex); + } + Ok(rust_ndindices) + } + + /// Get the value to set `NDIndex::type` for this variant. + #[must_use] + pub fn get_type_id(&self) -> u64 { + // Defined in IRRT, must be in sync + match self { + RustNDIndex::SingleElement(_) => 0, + RustNDIndex::Slice(_) => 1, + RustNDIndex::NewAxis => 2, + RustNDIndex::Ellipsis => 3, + } + } + + /// Serialize this [`RustNDIndex`] by writing it into an LLVM [`NDIndexValue`]. + pub fn write_to_ndindex( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dst_ndindex: NDIndexValue<'ctx>, + ) { + let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); + + // Set `dst_ndindex.type` + dst_ndindex.store_type(ctx, ctx.ctx.i8_type().const_int(self.get_type_id(), false)); + + // Set `dst_ndindex_ptr->data` + match self { + RustNDIndex::SingleElement(in_index) => { + let index_ptr = ctx.builder.build_alloca(ctx.ctx.i32_type(), "").unwrap(); + ctx.builder.build_store(index_ptr, *in_index).unwrap(); + + dst_ndindex.store_data( + ctx, + ctx.builder.build_pointer_cast(index_ptr, llvm_pi8, "").unwrap(), + ); + } + RustNDIndex::Slice(in_rust_slice) => { + let user_slice_ptr = + SliceType::new(ctx.ctx, ctx.ctx.i32_type(), generator.get_size_type(ctx.ctx)) + .alloca(generator, ctx, None); + in_rust_slice.write_to_slice(ctx, user_slice_ptr); + + dst_ndindex.store_data( + ctx, + ctx.builder.build_pointer_cast(user_slice_ptr.into(), llvm_pi8, "").unwrap(), + ); + } + RustNDIndex::NewAxis | RustNDIndex::Ellipsis => {} + } + } +} diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 3d70bf93..fdf11dd2 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ 
b/nac3core/src/codegen/values/ndarray/mod.rs @@ -17,9 +17,11 @@ use crate::codegen::{ CodeGenContext, CodeGenerator, }; pub use contiguous::*; +pub use indexing::*; pub use nditer::*; mod contiguous; +mod indexing; mod nditer; /// Proxy type for accessing an `NDArray` value in LLVM.