1
0
forked from M-Labs/nac3

WIP: core: properly allocate dst_ndarray subscript

This commit is contained in:
lyken 2024-07-11 17:32:57 +08:00
parent 39a05d6be6
commit 635542a36d
3 changed files with 53 additions and 48 deletions

View File

@ -17,7 +17,7 @@ use crate::{
call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax,
call_memcpy_generic,
},
need_sret, numpy,
need_sret, numpy::{self, call_ndarray_subscript_impl},
stmt::{
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
gen_var,
@ -2186,7 +2186,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
}
// Finally, perform the actual subscript logic
let subndarray = call_nac3_ndarray_subscript_and_alloc_dst(generator, ctx, ndarray, &ndslices.iter().collect_vec());
let subndarray = call_ndarray_subscript_impl(generator, ctx, ndarray, &ndslices.iter().collect_vec())?;
// ...and return the result
let result = ValueEnum::Dynamic(subndarray.ptr.into());

View File

@ -1325,8 +1325,7 @@ pub fn call_nac3_ndarray_deduce_ndims_after_slicing_raw<'ctx>(
.into_int_value()
}
// TODO: RENAME ME AND MY FRIENDS
pub fn call_nac3_ndarray_subscript_raw<'ctx>(
pub fn call_nac3_ndarray_subscript<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NpArrayValue<'ctx>,
num_slices: IntValue<'ctx>,
@ -1366,45 +1365,3 @@ pub fn call_nac3_ndarray_subscript_raw<'ctx>(
)
.unwrap();
}
pub fn call_nac3_ndarray_subscript_and_alloc_dst<'ctx, G>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NpArrayValue<'ctx>,
ndslices: &Vec<&NDSlice<'ctx>>,
) -> NpArrayValue<'ctx>
where
G: CodeGenerator + ?Sized,
{
// First we will calculate the correct ndims of the dst_ndarray
// Then allocate for dst_ndarray (A known `ndims` value is required for this)
// Finally do call the IRRT function that actually does subscript
let size_type = ndarray.ty.size_type;
// Prepare the argument `ndims`
let ndims = ndarray.load_ndims(ctx);
// Prepare the argument `num_slices` in LLVM - which conveniently is simply `ndslices.len()`
let num_slices = size_type.const_int(ndslices.len() as u64, false);
// Prepare the argument `slices`
let ndslices_ptr = IrrtNDSlice::alloca_ndslices(ctx, ndslices);
// Deduce the ndims
let dst_ndims = call_nac3_ndarray_deduce_ndims_after_slicing_raw(
ctx,
ndarray.ty.size_type,
ndims,
num_slices,
ndslices_ptr,
);
// Allocate `dst_ndarray`
let dst_ndarray =
ndarray.ty.var_alloc(generator, ctx, dst_ndims, Some("subscript_dst_ndarray"));
call_nac3_ndarray_subscript_raw(ctx, ndarray, num_slices, ndslices_ptr, dst_ndarray);
dst_ndarray
}

View File

@ -36,8 +36,9 @@ use nac3parser::ast::{Operator, StrRef};
use super::{
classes::NpArrayValue,
irrt::{
call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_size, get_irrt_ndarray_ptr_type,
get_opaque_uint8_ptr_type,
call_nac3_ndarray_deduce_ndims_after_slicing_raw, call_nac3_ndarray_set_strides_by_shape,
call_nac3_ndarray_size, call_nac3_ndarray_subscript, get_irrt_ndarray_ptr_type,
get_opaque_uint8_ptr_type, IrrtNDSlice, NDSlice,
},
stmt::gen_return,
};
@ -2319,6 +2320,53 @@ where
}
}
pub fn call_ndarray_subscript_impl<'ctx, G>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NpArrayValue<'ctx>,
ndslices: &Vec<&NDSlice<'ctx>>,
) -> Result<NpArrayValue<'ctx>, String>
where
G: CodeGenerator + ?Sized,
{
// First we will calculate the correct ndims of the dst_ndarray
// Then allocate for dst_ndarray (A known `ndims` value is required for this)
// Finally do call the IRRT function that actually does subscript
let size_type = ndarray.ty.size_type;
// Prepare the argument `ndims`
let ndims = ndarray.load_ndims(ctx);
// Prepare the argument `num_slices` in LLVM - which conveniently is simply `ndslices.len()`
let num_slices = size_type.const_int(ndslices.len() as u64, false);
// Prepare the argument `slices`
let ndslices_ptr = IrrtNDSlice::alloca_ndslices(ctx, ndslices);
// Deduce the ndims
let dst_ndims = call_nac3_ndarray_deduce_ndims_after_slicing_raw(
ctx,
ndarray.ty.size_type,
ndims,
num_slices,
ndslices_ptr,
);
// Allocate `dst_ndarray`
let dst_ndarray = alloca_ndarray_and_init(
generator,
ctx,
ndarray.ty.elem_type,
NDArrayInitMode::SetNDim { ndim: dst_ndims },
Some("subndarray"),
)?;
call_nac3_ndarray_subscript(ctx, ndarray, num_slices, ndslices_ptr, dst_ndarray);
Ok(dst_ndarray)
}
/// LLVM-typed implementation for generating the implementation for constructing an empty `NDArray`.
fn call_ndarray_empty_impl<'ctx, G>(
generator: &mut G,