From 635542a36dc3f3ef380ea7dc29c5b5265f555cd8 Mon Sep 17 00:00:00 2001 From: lyken Date: Thu, 11 Jul 2024 17:32:57 +0800 Subject: [PATCH] WIP: core: properly allocate dst_ndarray subscript --- nac3core/src/codegen/expr.rs | 4 +-- nac3core/src/codegen/irrt/mod.rs | 45 +-------------------------- nac3core/src/codegen/numpy.rs | 52 ++++++++++++++++++++++++++++++-- 3 files changed, 53 insertions(+), 48 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 25493590..c91d8bba 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -17,7 +17,7 @@ use crate::{ call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax, call_memcpy_generic, }, - need_sret, numpy, + need_sret, numpy::{self, call_ndarray_subscript_impl}, stmt::{ gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise, gen_var, @@ -2186,7 +2186,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( } // Finally, perform the actual subscript logic - let subndarray = call_nac3_ndarray_subscript_and_alloc_dst(generator, ctx, ndarray, &ndslices.iter().collect_vec()); + let subndarray = call_ndarray_subscript_impl(generator, ctx, ndarray, &ndslices.iter().collect_vec())?; // ...and return the result let result = ValueEnum::Dynamic(subndarray.ptr.into()); diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 0df3bc50..2f8f6ae7 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -1325,8 +1325,7 @@ pub fn call_nac3_ndarray_deduce_ndims_after_slicing_raw<'ctx>( .into_int_value() } -// TODO: RENAME ME AND MY FRIENDS -pub fn call_nac3_ndarray_subscript_raw<'ctx>( +pub fn call_nac3_ndarray_subscript<'ctx>( ctx: &CodeGenContext<'ctx, '_>, ndarray: NpArrayValue<'ctx>, num_slices: IntValue<'ctx>, @@ -1366,45 +1365,3 @@ pub fn call_nac3_ndarray_subscript_raw<'ctx>( ) .unwrap(); } - -pub fn call_nac3_ndarray_subscript_and_alloc_dst<'ctx, G>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - ndarray: NpArrayValue<'ctx>, - ndslices: &Vec<&NDSlice<'ctx>>, -) -> NpArrayValue<'ctx> -where - G: CodeGenerator + ?Sized, -{ - // First we will calculate the correct ndims of the dst_ndarray - // Then allocate for dst_ndarray (A known `ndims` value is required for this) - // Finally do call the IRRT function that actually does subscript - - let size_type = ndarray.ty.size_type; - - // Prepare the argument `ndims` - let ndims = ndarray.load_ndims(ctx); - - // Prepare the argument `num_slices` in LLVM - which conveniently is simply `ndslices.len()` - let num_slices = size_type.const_int(ndslices.len() as u64, false); - - // Prepare the argument `slices` - let ndslices_ptr = IrrtNDSlice::alloca_ndslices(ctx, ndslices); - - // Deduce the ndims - let dst_ndims = call_nac3_ndarray_deduce_ndims_after_slicing_raw( - ctx, - ndarray.ty.size_type, - ndims, - num_slices, - ndslices_ptr, - ); - - // Allocate `dst_ndarray` - let dst_ndarray = - ndarray.ty.var_alloc(generator, ctx, dst_ndims, Some("subscript_dst_ndarray")); - - call_nac3_ndarray_subscript_raw(ctx, ndarray, num_slices, ndslices_ptr, dst_ndarray); - - dst_ndarray -} diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index c4c80b35..317cfcff 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -36,8 +36,9 @@ use nac3parser::ast::{Operator, StrRef}; use super::{ classes::NpArrayValue, irrt::{ - call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_size, get_irrt_ndarray_ptr_type, - get_opaque_uint8_ptr_type, + call_nac3_ndarray_deduce_ndims_after_slicing_raw, call_nac3_ndarray_set_strides_by_shape, + call_nac3_ndarray_size, call_nac3_ndarray_subscript, get_irrt_ndarray_ptr_type, + get_opaque_uint8_ptr_type, IrrtNDSlice, NDSlice, }, stmt::gen_return, }; @@ -2319,6 +2320,53 @@ where } } +pub fn call_ndarray_subscript_impl<'ctx, G>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: NpArrayValue<'ctx>, + ndslices: &Vec<&NDSlice<'ctx>>, +) -> Result, String> +where + G: CodeGenerator + ?Sized, +{ + // First we will calculate the correct ndims of the dst_ndarray + // Then allocate for dst_ndarray (A known `ndims` value is required for this) + // Finally do call the IRRT function that actually does subscript + + let size_type = ndarray.ty.size_type; + + // Prepare the argument `ndims` + let ndims = ndarray.load_ndims(ctx); + + // Prepare the argument `num_slices` in LLVM - which conveniently is simply `ndslices.len()` + let num_slices = size_type.const_int(ndslices.len() as u64, false); + + // Prepare the argument `slices` + let ndslices_ptr = IrrtNDSlice::alloca_ndslices(ctx, ndslices); + + // Deduce the ndims + let dst_ndims = call_nac3_ndarray_deduce_ndims_after_slicing_raw( + ctx, + ndarray.ty.size_type, + ndims, + num_slices, + ndslices_ptr, + ); + + // Allocate `dst_ndarray` + let dst_ndarray = alloca_ndarray_and_init( + generator, + ctx, + ndarray.ty.elem_type, + NDArrayInitMode::SetNDim { ndim: dst_ndims }, + Some("subndarray"), + )?; + + call_nac3_ndarray_subscript(ctx, ndarray, num_slices, ndslices_ptr, dst_ndarray); + + Ok(dst_ndarray) +} + /// LLVM-typed implementation for generating the implementation for constructing an empty `NDArray`. fn call_ndarray_empty_impl<'ctx, G>( generator: &mut G,