From 2ab7b299b8ed32956a22eced6a9cfd6a03a4eae0 Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 30 Jul 2024 17:52:28 +0800 Subject: [PATCH] core/ndstrides: refactor numpy indexing --- nac3core/src/codegen/expr.rs | 103 +++++------------- nac3core/src/codegen/irrt/ndarray/indexing.rs | 10 +- nac3core/src/codegen/numpy_new/indexing.rs | 76 +++++++++++++ nac3core/src/codegen/numpy_new/mod.rs | 1 + nac3core/src/codegen/numpy_new/object.rs | 1 + nac3core/src/codegen/numpy_new/util.rs | 37 ++++++- 6 files changed, 144 insertions(+), 84 deletions(-) create mode 100644 nac3core/src/codegen/numpy_new/indexing.rs diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 990ac83a..01d24a01 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -2,6 +2,7 @@ use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip}; use super::{ irrt::slice::{RustUserSlice, SliceIndex}, + numpy_new::object::{NDArrayObject, ScalarOrNDArray}, structure::ndarray::NpArray, }; use crate::{ @@ -18,7 +19,6 @@ use crate::{ call_memcpy_generic, }, need_sret, numpy, - numpy_new::util::alloca_ndarray, stmt::{ gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise, gen_var, @@ -35,7 +35,7 @@ use crate::{ use inkwell::{ attributes::{Attribute, AttributeLoc}, types::{AnyType, BasicType, BasicTypeEnum}, - values::{BasicValue, BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue}, + values::{BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue}, AddressSpace, IntPredicate, OptimizationLevel, }; use itertools::{chain, izip, Either, Itertools}; @@ -44,7 +44,7 @@ use nac3parser::ast::{ StrRef, Unaryop, }; -use ndarray::indexing::{call_nac3_ndarray_index, RustNDIndex}; +use ndarray::indexing::RustNDIndex; use super::{ model::*, @@ -2130,23 +2130,13 @@ pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>( ) } -/// Generates code for a subscript expression on an `ndarray`. -/// -/// * `elem_ty` - The `Type` of the `NDArray` elements. -/// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`. -/// * `src_ndarray` - The `NDArray` value. -/// * `subscript` - The subscript expression used to index into the `ndarray`. -fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( +pub fn gen_ndarray_subscript_ndindexes<'ctx, G: CodeGenerator>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - ndims: Type, - src_ndarray: Ptr<'ctx, StructModel>, subscript: &Expr>, -) -> Result>, String> { +) -> Result>, String> { // TODO: Support https://numpy.org/doc/stable/user/basics.indexing.html#dimensional-indexing-tools let tyctx = generator.type_context(ctx.ctx); - let sizet_model = IntModel(SizeT); let slice_index_model = IntModel(SliceIndex::default()); // Annoying notes about `slice` @@ -2215,66 +2205,23 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( }; rust_ndindexes.push(ndindex); } + Ok(rust_ndindexes) +} - // Extract the `ndims` from a `Type` to `i128` - // We *HAVE* to know this statically, this is used to determine - // whether this subscript expression returns a scalar or an ndarray - let TypeEnum::TLiteral { values: ndims_values, .. } = &*ctx.unifier.get_ty_immutable(ndims) - else { - unreachable!() - }; - assert_eq!(ndims_values.len(), 1); - let src_ndims = i128::try_from(ndims_values[0].clone()).unwrap(); - - // Check for "too many indices for array: array is ..." error - if src_ndims < rust_ndindexes.len() as i128 { - ctx.make_assert( - generator, - ctx.ctx.bool_type().const_int(1, false), - "0:IndexError", - "too many indices for array: array is {0}-dimensional, but {1} were indexed", - [None, None, None], - ctx.current_loc, - ); - } - - let dst_ndims = RustNDIndex::deduce_ndims_after_slicing(&rust_ndindexes, src_ndims as i32); - let dst_ndarray = alloca_ndarray( - generator, - ctx, - sizet_model.constant(tyctx, ctx.ctx, dst_ndims as u64), - "subndarray", - ); - - // Prepare the subscripts - let (num_ndindexes, ndindexes) = RustNDIndex::alloca_ndindexes(tyctx, ctx, &rust_ndindexes); - - // NOTE: IRRT does check for indexing errors - call_nac3_ndarray_index(generator, ctx, num_ndindexes, ndindexes, src_ndarray, dst_ndarray); - - // ...and return the result, with two cases - let result_llvm_value: BasicValueEnum<'_> = if dst_ndims == 0 { - // 1) ndims == 0 (this happens when you do `np.zerps((3, 4))[1, 1]`), return the element - - let pelement = dst_ndarray.gep(ctx, |f| f.data).load(tyctx, ctx, "pelement"); // `*data` points to the first element by definition - - // Cast the opaque `pelement` ptr to `elem_ty` - let elem_ty = ctx.get_llvm_type(generator, elem_ty); - let pelement = ctx - .builder - .build_pointer_cast( - pelement.value, - elem_ty.ptr_type(AddressSpace::default()), - "pelement_casted", - ) - .unwrap(); - - ctx.builder.build_load(pelement, "element").unwrap().as_basic_value_enum() - } else { - // 2) ndims > 0 (other cases), return subndarray - dst_ndarray.value.as_basic_value_enum() - }; - Ok(Some(ValueEnum::Dynamic(result_llvm_value))) +/// Generates code for a subscript expression on an `ndarray`. +/// +/// * `elem_ty` - The `Type` of the `NDArray` elements. +/// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`. +/// * `src_ndarray` - The `NDArray` value. +/// * `subscript` - The subscript expression used to index into the `ndarray`. +pub fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: NDArrayObject<'ctx>, + subscript: &Expr>, +) -> Result, String> { + let indexes = gen_ndarray_subscript_ndindexes(generator, ctx, subscript)?; + Ok(ndarray.index(generator, ctx, &indexes, "subndarray")) } /// See [`CodeGenerator::gen_expr`]. @@ -2920,7 +2867,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( let tyctx = generator.type_context(ctx.ctx); let pndarray_model = PtrModel(StructModel(NpArray)); - let (dtype, ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap(); + let (&dtype, &ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap(); let Some(ndarray) = generator.gen_expr(ctx, value)? else { return Ok(None); @@ -2929,10 +2876,10 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( let ndarray = ndarray.to_basic_value_enum(ctx, generator, value.custom.unwrap())?; let ndarray = pndarray_model.check_value(tyctx, ctx.ctx, ndarray).unwrap(); + let ndarray = NDArrayObject { dtype, ndims, instance: ndarray }; - return gen_ndarray_subscript_expr( - generator, ctx, *dtype, *ndims, ndarray, slice, - ); + let result = gen_ndarray_subscript_expr(generator, ctx, ndarray, slice)?; + return Ok(Some(ValueEnum::Dynamic(result.to_basic_value_enum()))); } TypeEnum::TTuple { .. } => { let index: u32 = diff --git a/nac3core/src/codegen/irrt/ndarray/indexing.rs b/nac3core/src/codegen/irrt/ndarray/indexing.rs index 285d726e..8e526251 100644 --- a/nac3core/src/codegen/irrt/ndarray/indexing.rs +++ b/nac3core/src/codegen/irrt/ndarray/indexing.rs @@ -76,7 +76,7 @@ impl<'ctx> RustNDIndex<'ctx> { dst_ndindex_ptr.gep(ctx, |f| f.data).store(ctx, data); } - /// Allocate an array of `NDIndex`es onto the stack and return its stack pointer. + /// Allocate an array of `NDIndex`es on the stack and return its stack pointer. pub fn alloca_ndindexes( tyctx: TypeContext<'ctx>, ctx: &CodeGenContext<'ctx, '_>, @@ -97,10 +97,10 @@ impl<'ctx> RustNDIndex<'ctx> { } #[must_use] - pub fn deduce_ndims_after_slicing(slices: &[RustNDIndex], original_ndims: i32) -> i32 { - let mut final_ndims: i32 = original_ndims; - for slice in slices { - match slice { + pub fn deduce_ndims_after_indexing(indices: &[RustNDIndex], original_ndims: u64) -> u64 { + let mut final_ndims = original_ndims; + for index in indices { + match index { RustNDIndex::SingleElement(_) => { final_ndims -= 1; } diff --git a/nac3core/src/codegen/numpy_new/indexing.rs b/nac3core/src/codegen/numpy_new/indexing.rs new file mode 100644 index 00000000..5b99d1ad --- /dev/null +++ b/nac3core/src/codegen/numpy_new/indexing.rs @@ -0,0 +1,76 @@ +use crate::{ + codegen::{ + irrt::ndarray::indexing::{call_nac3_ndarray_index, RustNDIndex}, + model::*, + CodeGenContext, CodeGenerator, + }, + typecheck::typedef::{Type, Unifier}, +}; + +use super::{ + object::{NDArrayObject, ScalarObject, ScalarOrNDArray}, + util::{create_ndims, extract_ndims}, +}; + +impl<'ctx> NDArrayObject<'ctx> { + pub fn deduce_ndims_after_indexing_with( + &self, + unifier: &mut Unifier, + indexes: &[RustNDIndex<'ctx>], + ) -> Type { + let ndims = extract_ndims(unifier, self.ndims); + let new_ndims = RustNDIndex::deduce_ndims_after_indexing(indexes, ndims); + create_ndims(unifier, new_ndims) + } + + #[must_use] + pub fn index_always_ndarray( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + indexes: &[RustNDIndex<'ctx>], + name: &str, + ) -> Self { + let tyctx = generator.type_context(ctx.ctx); + + let dst_ndims = self.deduce_ndims_after_indexing_with(&mut ctx.unifier, indexes); + let dst_ndarray = NDArrayObject::alloca(generator, ctx, dst_ndims, self.dtype, name); + + let (num_indexes, indexes) = RustNDIndex::alloca_ndindexes(tyctx, ctx, indexes); + call_nac3_ndarray_index( + generator, + ctx, + num_indexes, + indexes, + self.instance, + dst_ndarray.instance, + ); + + dst_ndarray + } + + pub fn index( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + indexes: &[RustNDIndex<'ctx>], + name: &str, + ) -> ScalarOrNDArray<'ctx> { + let tyctx = generator.type_context(ctx.ctx); + let sizet_model = IntModel(SizeT); + + let subndarray = self.index_always_ndarray(generator, ctx, indexes, name); + if subndarray.is_unsized(&ctx.unifier) { + // TODO: This actually never fails, don't use the `checked_` version. + let value = subndarray.checked_get_nth_element( + generator, + ctx, + sizet_model.const_0(tyctx, ctx.ctx), + name, + ); + ScalarOrNDArray::Scalar(ScalarObject { dtype: self.dtype, value }) + } else { + ScalarOrNDArray::NDArray(subndarray) + } + } +} diff --git a/nac3core/src/codegen/numpy_new/mod.rs b/nac3core/src/codegen/numpy_new/mod.rs index f6f8a317..42936115 100644 --- a/nac3core/src/codegen/numpy_new/mod.rs +++ b/nac3core/src/codegen/numpy_new/mod.rs @@ -1,5 +1,6 @@ pub mod broadcast; pub mod factory; +pub mod indexing; pub mod object; pub mod util; pub mod view; diff --git a/nac3core/src/codegen/numpy_new/object.rs b/nac3core/src/codegen/numpy_new/object.rs index 7c944148..ef156088 100644 --- a/nac3core/src/codegen/numpy_new/object.rs +++ b/nac3core/src/codegen/numpy_new/object.rs @@ -29,6 +29,7 @@ pub enum ScalarOrNDArray<'ctx> { impl<'ctx> ScalarOrNDArray<'ctx> { /// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`]. + #[must_use] pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> { match self { ScalarOrNDArray::Scalar(scalar) => scalar.value, diff --git a/nac3core/src/codegen/numpy_new/util.rs b/nac3core/src/codegen/numpy_new/util.rs index 43f6c8cd..58e9c1b0 100644 --- a/nac3core/src/codegen/numpy_new/util.rs +++ b/nac3core/src/codegen/numpy_new/util.rs @@ -1,4 +1,8 @@ -use inkwell::types::BasicType; +use inkwell::{ + types::BasicType, + values::{BasicValueEnum, PointerValue}, + AddressSpace, +}; use util::gen_model_memcpy; use crate::{ @@ -249,6 +253,37 @@ impl<'ctx> NDArrayObject<'ctx> { call_nac3_ndarray_set_strides_by_shape(generator, ctx, self.instance); } + pub fn checked_get_nth_pelement( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + i: Int<'ctx, SizeT>, + name: &str, + ) -> PointerValue<'ctx> { + let elem_ty = ctx.get_llvm_type(generator, self.dtype); + + let p = call_nac3_ndarray_get_nth_pelement(generator, ctx, self.instance, i); + ctx.builder + .build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), name) + .unwrap() + } + + pub fn checked_get_nth_element( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + i: Int<'ctx, SizeT>, + name: &str, + ) -> BasicValueEnum<'ctx> { + let pelement = self.checked_get_nth_pelement(generator, ctx, i, "pelement"); + ctx.builder.build_load(pelement, name).unwrap() + } + + #[must_use] + pub fn is_unsized(&self, unifier: &Unifier) -> bool { + extract_ndims(unifier, self.ndims) == 0 + } + pub fn size( &self, generator: &mut G,