diff --git a/nac3core/src/codegen/values/ndarray/fold.rs b/nac3core/src/codegen/values/ndarray/fold.rs new file mode 100644 index 00000000..7c8aebd4 --- /dev/null +++ b/nac3core/src/codegen/values/ndarray/fold.rs @@ -0,0 +1,101 @@ +use inkwell::values::{BasicValue, BasicValueEnum}; + +use super::{NDArrayValue, NDIterValue, ScalarOrNDArray}; +use crate::codegen::{ + stmt::{gen_for_callback, BreakContinueHooks}, + types::ndarray::NDIterType, + CodeGenContext, CodeGenerator, +}; + +impl<'ctx> NDArrayValue<'ctx> { + /// Folds the elements of this ndarray into an accumulator value by applying `f`, returning the + /// final value. + /// + /// `f` has access to [`BreakContinueHooks`] to short-circuit the `fold` operation, an instance + /// of `V` representing the current accumulated value, and an [`NDIterValue`] to get the + /// properties of the current iterated element. + pub fn fold<'a, G, V, F>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + init: V, + f: F, + ) -> Result + where + G: CodeGenerator + ?Sized, + V: BasicValue<'ctx> + TryFrom>, + >>::Error: std::fmt::Debug, + F: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + BreakContinueHooks<'ctx>, + V, + NDIterValue<'ctx>, + ) -> Result, + { + let acc_ptr = + generator.gen_var_alloc(ctx, init.as_basic_value_enum().get_type(), None).unwrap(); + ctx.builder.build_store(acc_ptr, init).unwrap(); + + gen_for_callback( + generator, + ctx, + Some("ndarray_fold"), + |generator, ctx| Ok(NDIterType::new(ctx).construct(generator, ctx, *self)), + |_, ctx, nditer| Ok(nditer.has_element(ctx)), + |generator, ctx, hooks, nditer| { + let acc = V::try_from(ctx.builder.build_load(acc_ptr, "").unwrap()).unwrap(); + let acc = f(generator, ctx, hooks, acc, nditer)?; + ctx.builder.build_store(acc_ptr, acc).unwrap(); + Ok(()) + }, + |_, ctx, nditer| { + nditer.next(ctx); + Ok(()) + }, + )?; + + let acc = ctx.builder.build_load(acc_ptr, "").unwrap(); + Ok(V::try_from(acc).unwrap()) + } +} + +impl<'ctx> ScalarOrNDArray<'ctx> { + /// See [`NDArrayValue::fold`]. + /// + /// The primary differences between this function and `NDArrayValue::fold` are: + /// + /// - The 3rd parameter of `f` is an `Option` of hooks, since `break`/`continue` hooks are not + /// available if this instance represents a scalar value. + /// - The 5th parameter of `f` is a [`BasicValueEnum`], since no [iterator][`NDIterValue`] will + /// be created if this instance represents a scalar value. + pub fn fold<'a, G, V, F>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + init: V, + f: F, + ) -> Result + where + G: CodeGenerator + ?Sized, + V: BasicValue<'ctx> + TryFrom>, + >>::Error: std::fmt::Debug, + F: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + Option<&BreakContinueHooks<'ctx>>, + V, + BasicValueEnum<'ctx>, + ) -> Result, + { + match self { + ScalarOrNDArray::Scalar(v) => f(generator, ctx, None, init, *v), + ScalarOrNDArray::NDArray(v) => { + v.fold(generator, ctx, init, |generator, ctx, hooks, acc, nditer| { + let elem = nditer.get_scalar(ctx); + f(generator, ctx, Some(&hooks), acc, elem) + }) + } + } + } +} diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 705412e0..1bf5db31 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -30,6 +30,7 @@ pub use nditer::*; mod broadcast; mod contiguous; +mod fold; mod indexing; mod map; mod matmul;