core: add gen_ndarray_iter_scalar_callback

This commit is contained in:
lyken 2024-06-20 15:59:33 +08:00
parent 18b4b150d8
commit e52c1c4423
1 changed files with 43 additions and 3 deletions

View File

@ -7,12 +7,11 @@ use crate::{
},
expr::gen_binop_expr_with_values,
irrt::{
calculate_len_for_slice_range, call_ndarray_calc_broadcast,
self, calculate_len_for_slice_range, call_ndarray_calc_broadcast,
call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices,
call_ndarray_calc_size,
},
llvm_intrinsics,
llvm_intrinsics::call_memcpy_generic,
llvm_intrinsics::{self, call_memcpy_generic},
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
CodeGenContext, CodeGenerator,
},
@ -32,6 +31,8 @@ use inkwell::{
};
use nac3parser::ast::{Operator, StrRef};
use super::stmt::BreakContinueHooks;
/// Creates an uninitialized `NDArray` instance.
fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
@ -1984,3 +1985,42 @@ pub fn gen_ndarray_fill<'ctx>(
Ok(())
}
/// Generate a construct that iterates through all scalar elements within an `ndarray`.
///
/// * `ndarray`: The input ndarray [`NDArrayValue`]
/// * `body_fn`: A lambda containing IR statements that acts on every scalar element within `ndarray`.
pub fn gen_ndarray_iter_scalar_callback<'ctx, G, BodyFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
body: BodyFn,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
BodyFn: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, '_>,
BreakContinueHooks,
BasicValueEnum<'ctx>,
) -> Result<(), String>,
{
let llvm_usize = generator.get_size_type(ctx.ctx);
let ndarray_size =
irrt::call_ndarray_calc_size(generator, ctx, &ndarray.dim_sizes(), (None, None));
gen_for_callback_incrementing(
generator,
ctx,
llvm_usize.const_int(0, false),
(ndarray_size, false),
|generator, ctx, hooks, idx| {
let scalar = unsafe { ndarray.data().get_unchecked(ctx, generator, &idx, None) };
body(generator, ctx, hooks, scalar)
},
llvm_usize.const_int(1, false),
)?;
Ok(())
}