[core] codegen/values/ndarray: Add fold utilities
Needed for np_{any,all}.
This commit is contained in:
parent
357970a793
commit
18e8e5269f
101
nac3core/src/codegen/values/ndarray/fold.rs
Normal file
101
nac3core/src/codegen/values/ndarray/fold.rs
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
use inkwell::values::{BasicValue, BasicValueEnum};
|
||||||
|
|
||||||
|
use super::{NDArrayValue, NDIterValue, ScalarOrNDArray};
|
||||||
|
use crate::codegen::{
|
||||||
|
stmt::{gen_for_callback, BreakContinueHooks},
|
||||||
|
types::ndarray::NDIterType,
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
};
|
||||||
|
|
||||||
|
impl<'ctx> NDArrayValue<'ctx> {
|
||||||
|
/// Folds the elements of this ndarray into an accumulator value by applying `f`, returning the
|
||||||
|
/// final value.
|
||||||
|
///
|
||||||
|
/// `f` has access to [`BreakContinueHooks`] to short-circuit the `fold` operation, an instance
|
||||||
|
/// of `V` representing the current accumulated value, and an [`NDIterValue`] to get the
|
||||||
|
/// properties of the current iterated element.
|
||||||
|
pub fn fold<'a, G, V, F>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
init: V,
|
||||||
|
f: F,
|
||||||
|
) -> Result<V, String>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
V: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>>,
|
||||||
|
<V as TryFrom<BasicValueEnum<'ctx>>>::Error: std::fmt::Debug,
|
||||||
|
F: FnOnce(
|
||||||
|
&mut G,
|
||||||
|
&mut CodeGenContext<'ctx, 'a>,
|
||||||
|
BreakContinueHooks<'ctx>,
|
||||||
|
V,
|
||||||
|
NDIterValue<'ctx>,
|
||||||
|
) -> Result<V, String>,
|
||||||
|
{
|
||||||
|
let acc_ptr =
|
||||||
|
generator.gen_var_alloc(ctx, init.as_basic_value_enum().get_type(), None).unwrap();
|
||||||
|
ctx.builder.build_store(acc_ptr, init).unwrap();
|
||||||
|
|
||||||
|
gen_for_callback(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
Some("ndarray_fold"),
|
||||||
|
|generator, ctx| Ok(NDIterType::new(ctx).construct(generator, ctx, *self)),
|
||||||
|
|_, ctx, nditer| Ok(nditer.has_element(ctx)),
|
||||||
|
|generator, ctx, hooks, nditer| {
|
||||||
|
let acc = V::try_from(ctx.builder.build_load(acc_ptr, "").unwrap()).unwrap();
|
||||||
|
let acc = f(generator, ctx, hooks, acc, nditer)?;
|
||||||
|
ctx.builder.build_store(acc_ptr, acc).unwrap();
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
|_, ctx, nditer| {
|
||||||
|
nditer.next(ctx);
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let acc = ctx.builder.build_load(acc_ptr, "").unwrap();
|
||||||
|
Ok(V::try_from(acc).unwrap())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||||
|
/// See [`NDArrayValue::fold`].
|
||||||
|
///
|
||||||
|
/// The primary differences between this function and `NDArrayValue::fold` are:
|
||||||
|
///
|
||||||
|
/// - The 3rd parameter of `f` is an `Option` of hooks, since `break`/`continue` hooks are not
|
||||||
|
/// available if this instance represents a scalar value.
|
||||||
|
/// - The 5th parameter of `f` is a [`BasicValueEnum`], since no [iterator][`NDIterValue`] will
|
||||||
|
/// be created if this instance represents a scalar value.
|
||||||
|
pub fn fold<'a, G, V, F>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
init: V,
|
||||||
|
f: F,
|
||||||
|
) -> Result<V, String>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
V: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>>,
|
||||||
|
<V as TryFrom<BasicValueEnum<'ctx>>>::Error: std::fmt::Debug,
|
||||||
|
F: FnOnce(
|
||||||
|
&mut G,
|
||||||
|
&mut CodeGenContext<'ctx, 'a>,
|
||||||
|
Option<&BreakContinueHooks<'ctx>>,
|
||||||
|
V,
|
||||||
|
BasicValueEnum<'ctx>,
|
||||||
|
) -> Result<V, String>,
|
||||||
|
{
|
||||||
|
match self {
|
||||||
|
ScalarOrNDArray::Scalar(v) => f(generator, ctx, None, init, *v),
|
||||||
|
ScalarOrNDArray::NDArray(v) => {
|
||||||
|
v.fold(generator, ctx, init, |generator, ctx, hooks, acc, nditer| {
|
||||||
|
let elem = nditer.get_scalar(ctx);
|
||||||
|
f(generator, ctx, Some(&hooks), acc, elem)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -30,6 +30,7 @@ pub use nditer::*;
|
|||||||
|
|
||||||
mod broadcast;
|
mod broadcast;
|
||||||
mod contiguous;
|
mod contiguous;
|
||||||
|
mod fold;
|
||||||
mod indexing;
|
mod indexing;
|
||||||
mod map;
|
mod map;
|
||||||
mod matmul;
|
mod matmul;
|
||||||
|
Loading…
Reference in New Issue
Block a user