diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index b082d1e..3bd3a3d 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -7,12 +7,11 @@ use crate::{ }, expr::gen_binop_expr_with_values, irrt::{ - calculate_len_for_slice_range, call_ndarray_calc_broadcast, + self, calculate_len_for_slice_range, call_ndarray_calc_broadcast, call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices, call_ndarray_calc_size, }, - llvm_intrinsics, - llvm_intrinsics::call_memcpy_generic, + llvm_intrinsics::{self, call_memcpy_generic}, stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, CodeGenContext, CodeGenerator, }, @@ -32,6 +31,8 @@ use inkwell::{ }; use nac3parser::ast::{Operator, StrRef}; +use super::stmt::BreakContinueHooks; + /// Creates an uninitialized `NDArray` instance. fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, @@ -1984,3 +1985,42 @@ pub fn gen_ndarray_fill<'ctx>( Ok(()) } + +/// Generate a construct that iterates through all scalar elements within an `ndarray`. +/// +/// * `ndarray`: The input ndarray [`NDArrayValue`] +/// * `body_fn`: A lambda containing IR statements that acts on every scalar element within `ndarray`. +pub fn gen_ndarray_iter_scalar_callback<'ctx, G, BodyFn>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: NDArrayValue<'ctx>, + body: BodyFn, +) -> Result<(), String> +where + G: CodeGenerator + ?Sized, + BodyFn: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, '_>, + BreakContinueHooks, + BasicValueEnum<'ctx>, + ) -> Result<(), String>, +{ + let llvm_usize = generator.get_size_type(ctx.ctx); + + let ndarray_size = + irrt::call_ndarray_calc_size(generator, ctx, &ndarray.dim_sizes(), (None, None)); + + gen_for_callback_incrementing( + generator, + ctx, + llvm_usize.const_int(0, false), + (ndarray_size, false), + |generator, ctx, hooks, idx| { + let scalar = unsafe { ndarray.data().get_unchecked(ctx, generator, &idx, None) }; + body(generator, ctx, hooks, scalar) + }, + llvm_usize.const_int(1, false), + )?; + + Ok(()) +}