From bbc68b8b1a71562b171bba72d8c89fed1507ee7b Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 27 Nov 2024 17:29:40 +0800 Subject: [PATCH] [core] codegen: implement ndarray iterator NDIter Based on 50f960ab: core/ndstrides: implement ndarray iterator NDIter A necessary utility to iterate through all elements in a possibly strided ndarray. --- nac3core/irrt/irrt.cpp | 1 + nac3core/irrt/irrt/ndarray/iter.hpp | 146 ++++++++++ nac3core/src/codegen/irrt/ndarray/iter.rs | 67 +++++ nac3core/src/codegen/irrt/ndarray/mod.rs | 2 + nac3core/src/codegen/types/ndarray/mod.rs | 1 + nac3core/src/codegen/types/ndarray/nditer.rs | 256 ++++++++++++++++++ nac3core/src/codegen/values/array.rs | 2 +- nac3core/src/codegen/values/ndarray/mod.rs | 2 + nac3core/src/codegen/values/ndarray/nditer.rs | 176 ++++++++++++ 9 files changed, 652 insertions(+), 1 deletion(-) create mode 100644 nac3core/irrt/irrt/ndarray/iter.hpp create mode 100644 nac3core/src/codegen/irrt/ndarray/iter.rs create mode 100644 nac3core/src/codegen/types/ndarray/nditer.rs create mode 100644 nac3core/src/codegen/values/ndarray/nditer.rs diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 088b84fb..1093e8e0 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -5,3 +5,4 @@ #include "irrt/slice.hpp" #include "irrt/ndarray/basic.hpp" #include "irrt/ndarray/def.hpp" +#include "irrt/ndarray/iter.hpp" \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/iter.hpp b/nac3core/irrt/irrt/ndarray/iter.hpp new file mode 100644 index 00000000..69aaaa42 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/iter.hpp @@ -0,0 +1,146 @@ +#pragma once + +#include "irrt/int_types.hpp" +#include "irrt/ndarray/def.hpp" + +namespace { +/** + * @brief Helper struct to enumerate through an ndarray *efficiently*. + * + * Example usage (in pseudo-code): + * ``` + * // Suppose my_ndarray has been initialized, with shape [2, 3] and dtype `double` + * NDIter nditer; + * nditer.initialize(my_ndarray); + * while (nditer.has_element()) { + * // This body is run 6 (= my_ndarray.size) times. + * + * // [0, 0] -> [0, 1] -> [0, 2] -> [1, 0] -> [1, 1] -> [1, 2] -> end + * print(nditer.indices); + * + * // 0 -> 1 -> 2 -> 3 -> 4 -> 5 + * print(nditer.nth); + * + * // <1st element> -> <2nd element> -> ... -> <6th element> -> end + * print(*((double *) nditer.element)) + * + * nditer.next(); // Go to next element. + * } + * ``` + * + * Interesting cases: + * - If `my_ndarray.ndims` == 0, there is one iteration. + * - If `my_ndarray.shape` contains zeroes, there are no iterations. + */ +template +struct NDIter { + // Information about the ndarray being iterated over. + SizeT ndims; + SizeT* shape; + SizeT* strides; + + /** + * @brief The current indices. + * + * Must be allocated by the caller. + */ + SizeT* indices; + + /** + * @brief The nth (0-based) index of the current indices. + * + * Initially this is 0. + */ + SizeT nth; + + /** + * @brief Pointer to the current element. + * + * Initially this points to first element of the ndarray. + */ + void* element; + + /** + * @brief Cache for the product of shape. + * + * Could be 0 if `shape` has 0s in it. + */ + SizeT size; + + void initialize(SizeT ndims, SizeT* shape, SizeT* strides, void* element, SizeT* indices) { + this->ndims = ndims; + this->shape = shape; + this->strides = strides; + + this->indices = indices; + this->element = element; + + // Compute size + this->size = 1; + for (SizeT i = 0; i < ndims; i++) { + this->size *= shape[i]; + } + + // `indices` starts on all 0s. + for (SizeT axis = 0; axis < ndims; axis++) + indices[axis] = 0; + nth = 0; + } + + void initialize_by_ndarray(NDArray* ndarray, SizeT* indices) { + // NOTE: ndarray->data is pointing to the first element, and `NDIter`'s `element` should also point to the first + // element as well. + this->initialize(ndarray->ndims, ndarray->shape, ndarray->strides, ndarray->data, indices); + } + + // Is the current iteration valid? + // If true, then `element`, `indices` and `nth` contain details about the current element. + bool has_element() { return nth < size; } + + // Go to the next element. + void next() { + for (SizeT i = 0; i < ndims; i++) { + SizeT axis = ndims - i - 1; + indices[axis]++; + if (indices[axis] >= shape[axis]) { + indices[axis] = 0; + + // TODO: There is something called backstrides to speedup iteration. + // See https://ajcr.net/stride-guide-part-1/, and + // https://docs.scipy.org/doc/numpy-1.13.0/reference/c-api.types-and-structures.html#c.PyArrayIterObject.PyArrayIterObject.backstrides. + element = static_cast(reinterpret_cast(element) - strides[axis] * (shape[axis] - 1)); + } else { + element = static_cast(reinterpret_cast(element) + strides[axis]); + break; + } + } + nth++; + } +}; +} // namespace + +extern "C" { +void __nac3_nditer_initialize(NDIter* iter, NDArray* ndarray, int32_t* indices) { + iter->initialize_by_ndarray(ndarray, indices); +} + +void __nac3_nditer_initialize64(NDIter* iter, NDArray* ndarray, int64_t* indices) { + iter->initialize_by_ndarray(ndarray, indices); +} + +bool __nac3_nditer_has_element(NDIter* iter) { + return iter->has_element(); +} + +bool __nac3_nditer_has_element64(NDIter* iter) { + return iter->has_element(); +} + +void __nac3_nditer_next(NDIter* iter) { + iter->next(); +} + +void __nac3_nditer_next64(NDIter* iter) { + iter->next(); +} +} \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/ndarray/iter.rs b/nac3core/src/codegen/irrt/ndarray/iter.rs new file mode 100644 index 00000000..f62a1989 --- /dev/null +++ b/nac3core/src/codegen/irrt/ndarray/iter.rs @@ -0,0 +1,67 @@ +use inkwell::{ + values::{BasicValueEnum, IntValue}, + AddressSpace, +}; + +use crate::codegen::{ + expr::{create_and_call_function, infer_and_call_function}, + irrt::get_usize_dependent_function_name, + types::ProxyType, + values::{nditer::NDIterValue, ArrayLikeValue, ArraySliceValue, NDArrayValue, ProxyValue}, + CodeGenContext, CodeGenerator, +}; + +pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + iter: NDIterValue<'ctx>, + ndarray: NDArrayValue<'ctx>, + indices: ArraySliceValue<'ctx>, +) { + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + + let name = get_usize_dependent_function_name(generator, ctx, "__nac3_nditer_initialize"); + + create_and_call_function( + ctx, + &name, + None, + &[ + (iter.get_type().as_base_type().into(), iter.as_base_value().into()), + (ndarray.get_type().as_base_type().into(), ndarray.as_base_value().into()), + (llvm_pusize.into(), indices.base_ptr(ctx, generator).into()), + ], + None, + None, + ); +} + +pub fn call_nac3_nditer_has_element<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + iter: NDIterValue<'ctx>, +) -> IntValue<'ctx> { + let name = get_usize_dependent_function_name(generator, ctx, "__nac3_nditer_has_element"); + + infer_and_call_function( + ctx, + &name, + Some(ctx.ctx.bool_type().into()), + &[iter.as_base_value().into()], + None, + None, + ) + .map(BasicValueEnum::into_int_value) + .unwrap() +} + +pub fn call_nac3_nditer_next<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + iter: NDIterValue<'ctx>, +) { + let name = get_usize_dependent_function_name(generator, ctx, "__nac3_nditer_next"); + + infer_and_call_function(ctx, &name, None, &[iter.as_base_value().into()], None, None); +} diff --git a/nac3core/src/codegen/irrt/ndarray/mod.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs index b9d02d12..dca5979c 100644 --- a/nac3core/src/codegen/irrt/ndarray/mod.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -16,8 +16,10 @@ use crate::codegen::{ CodeGenContext, CodeGenerator, }; pub use basic::*; +pub use iter::*; mod basic; +mod iter; /// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the /// calculated total size. diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index c53430f5..4da53122 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -23,6 +23,7 @@ use crate::{ pub use contiguous::*; mod contiguous; +pub mod nditer; /// Proxy type for a `ndarray` type in LLVM. #[derive(Debug, PartialEq, Eq, Clone, Copy)] diff --git a/nac3core/src/codegen/types/ndarray/nditer.rs b/nac3core/src/codegen/types/ndarray/nditer.rs new file mode 100644 index 00000000..46a8064c --- /dev/null +++ b/nac3core/src/codegen/types/ndarray/nditer.rs @@ -0,0 +1,256 @@ +use inkwell::{ + context::{AsContextRef, Context}, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, + values::{IntValue, PointerValue}, + AddressSpace, +}; +use itertools::Itertools; + +use nac3core_derive::StructFields; + +use super::ProxyType; +use crate::codegen::{ + irrt, + types::structure::{StructField, StructFields}, + values::{nditer::NDIterValue, ArraySliceValue, NDArrayValue, ProxyValue}, + CodeGenContext, CodeGenerator, +}; + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct NDIterType<'ctx> { + ty: PointerType<'ctx>, + llvm_usize: IntType<'ctx>, +} + +#[derive(PartialEq, Eq, Clone, Copy, StructFields)] +pub struct NDIterStructFields<'ctx> { + #[value_type(usize)] + pub ndims: StructField<'ctx, IntValue<'ctx>>, + #[value_type(usize.ptr_type(AddressSpace::default()))] + pub shape: StructField<'ctx, PointerValue<'ctx>>, + #[value_type(usize.ptr_type(AddressSpace::default()))] + pub strides: StructField<'ctx, PointerValue<'ctx>>, + #[value_type(usize.ptr_type(AddressSpace::default()))] + pub indices: StructField<'ctx, PointerValue<'ctx>>, + #[value_type(usize)] + pub nth: StructField<'ctx, IntValue<'ctx>>, + #[value_type(i8_type().ptr_type(AddressSpace::default()))] + pub element: StructField<'ctx, PointerValue<'ctx>>, + #[value_type(usize)] + pub size: StructField<'ctx, IntValue<'ctx>>, +} + +impl<'ctx> NDIterType<'ctx> { + /// Checks whether `llvm_ty` represents a `nditer` type, returning [Err] if it does not. + pub fn is_representable( + llvm_ty: PointerType<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + let ctx = llvm_ty.get_context(); + + let llvm_expected_ty = Self::fields(ctx, llvm_usize).into_vec(); + + let llvm_ndarray_ty = llvm_ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { + return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}")); + }; + if llvm_ndarray_ty.count_fields() != u32::try_from(llvm_expected_ty.len()).unwrap() { + return Err(format!( + "Expected {} fields in `NDArray`, got {}", + llvm_expected_ty.len(), + llvm_ndarray_ty.count_fields() + )); + } + + llvm_expected_ty + .iter() + .enumerate() + .map(|(i, expected_ty)| { + (expected_ty.1, llvm_ndarray_ty.get_field_type_at_index(i as u32).unwrap()) + }) + .try_for_each(|(expected_ty, actual_ty)| { + if expected_ty == actual_ty { + Ok(()) + } else { + Err(format!("Expected {expected_ty} for `ndarray.data`, got {actual_ty}")) + } + })?; + + Ok(()) + } + + /// Returns an instance of [`StructFields`] containing all field accessors for this type. + #[must_use] + fn fields(ctx: impl AsContextRef<'ctx>, llvm_usize: IntType<'ctx>) -> NDIterStructFields<'ctx> { + NDIterStructFields::new(ctx, llvm_usize) + } + + /// See [`NDIterType::fields`]. + // TODO: Move this into e.g. StructProxyType + #[must_use] + pub fn get_fields(&self, ctx: impl AsContextRef<'ctx>) -> NDIterStructFields<'ctx> { + Self::fields(ctx, self.llvm_usize) + } + + /// Creates an LLVM type corresponding to the expected structure of an `NDIter`. + #[must_use] + fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> { + let field_tys = + Self::fields(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec(); + + ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) + } + + /// Creates an instance of [`NDIter`]. + #[must_use] + pub fn new(generator: &G, ctx: &'ctx Context) -> Self { + let llvm_usize = generator.get_size_type(ctx); + let llvm_nditer = Self::llvm_type(ctx, llvm_usize); + + Self { ty: llvm_nditer, llvm_usize } + } + + /// Creates an [`NDIterType`] from a [`PointerType`] representing an `NDIter`. + #[must_use] + pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); + + Self { ty: ptr_ty, llvm_usize } + } + + /// Returns the type of the `size` field of this `nditer` type. + #[must_use] + pub fn size_type(&self) -> IntType<'ctx> { + self.llvm_usize + } + + #[must_use] + pub fn alloca( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + parent: NDArrayValue<'ctx>, + indices: ArraySliceValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca(generator, ctx, name), + parent, + indices, + self.llvm_usize, + name, + ) + } + + /// Allocate an [`NDIter`] that iterates through the given `ndarray`. + #[must_use] + pub fn construct( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: NDArrayValue<'ctx>, + ) -> >::Value { + let nditer = self.raw_alloca(generator, ctx, None); + let ndims = ndarray.load_ndims(ctx); + + // The caller has the responsibility to allocate 'indices' for `NDIter`. + let indices = + generator.gen_array_var_alloc(ctx, self.llvm_usize.into(), ndims, None).unwrap(); + + let nditer = >::Value::from_pointer_value( + nditer, + ndarray, + indices, + self.llvm_usize, + None, + ); + + irrt::ndarray::call_nac3_nditer_initialize(generator, ctx, nditer, ndarray, indices); + + nditer + } + + #[must_use] + pub fn map_value( + &self, + value: <>::Value as ProxyValue<'ctx>>::Base, + parent: NDArrayValue<'ctx>, + indices: ArraySliceValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + value, + parent, + indices, + self.llvm_usize, + name, + ) + } +} + +impl<'ctx> ProxyType<'ctx> for NDIterType<'ctx> { + type Base = PointerType<'ctx>; + type Value = NDIterValue<'ctx>; + + fn is_type( + generator: &G, + ctx: &'ctx Context, + llvm_ty: impl BasicType<'ctx>, + ) -> Result<(), String> { + if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { + >::is_representable(generator, ctx, ty) + } else { + Err(format!("Expected pointer type, got {llvm_ty:?}")) + } + } + + fn is_representable( + generator: &G, + ctx: &'ctx Context, + llvm_ty: Self::Base, + ) -> Result<(), String> { + Self::is_representable(llvm_ty, generator.get_size_type(ctx)) + } + + fn raw_alloca( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Base { + generator + .gen_var_alloc( + ctx, + self.as_base_type().get_element_type().into_struct_type().into(), + name, + ) + .unwrap() + } + + fn array_alloca( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + size: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> ArraySliceValue<'ctx> { + generator + .gen_array_var_alloc( + ctx, + self.as_base_type().get_element_type().into_struct_type().into(), + size, + name, + ) + .unwrap() + } + + fn as_base_type(&self) -> Self::Base { + self.ty + } +} + +impl<'ctx> From> for PointerType<'ctx> { + fn from(value: NDIterType<'ctx>) -> Self { + value.as_base_type() + } +} diff --git a/nac3core/src/codegen/values/array.rs b/nac3core/src/codegen/values/array.rs index 8d14fe8a..78975f06 100644 --- a/nac3core/src/codegen/values/array.rs +++ b/nac3core/src/codegen/values/array.rs @@ -207,7 +207,7 @@ pub trait TypedArrayLikeMutator<'ctx, T, Index = IntValue<'ctx>>: /// Type alias for a function that casts a [`BasicValueEnum`] into a `T`. type ValueDowncastFn<'ctx, T> = - Box, BasicValueEnum<'ctx>) -> T>; + Box, BasicValueEnum<'ctx>) -> T + 'ctx>; /// Type alias for a function that casts a `T` into a [`BasicValueEnum`]. type ValueUpcastFn<'ctx, T> = Box, T) -> BasicValueEnum<'ctx>>; diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index b6ed150a..65e3313f 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -20,6 +20,8 @@ pub use contiguous::*; mod contiguous; +pub mod nditer; + /// Proxy type for accessing an `NDArray` value in LLVM. #[derive(Copy, Clone)] pub struct NDArrayValue<'ctx> { diff --git a/nac3core/src/codegen/values/ndarray/nditer.rs b/nac3core/src/codegen/values/ndarray/nditer.rs new file mode 100644 index 00000000..37df01bf --- /dev/null +++ b/nac3core/src/codegen/values/ndarray/nditer.rs @@ -0,0 +1,176 @@ +use inkwell::{ + types::{BasicType, IntType}, + values::{BasicValueEnum, IntValue, PointerValue}, + AddressSpace, +}; + +use super::{NDArrayValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeMutator}; +use crate::codegen::{ + irrt, + stmt::{gen_for_callback, BreakContinueHooks}, + types::{nditer::NDIterType, structure::StructField}, + values::{ArraySliceValue, TypedArrayLikeAdapter}, + CodeGenContext, CodeGenerator, +}; + +#[derive(Copy, Clone)] +pub struct NDIterValue<'ctx> { + value: PointerValue<'ctx>, + parent: NDArrayValue<'ctx>, + indices: ArraySliceValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, +} + +impl<'ctx> NDIterValue<'ctx> { + /// Checks whether `value` is an instance of `NDArray`, returning [Err] if `value` is not an + /// instance. + pub fn is_representable( + value: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + ::Type::is_representable(value.get_type(), llvm_usize) + } + + /// Creates an [`NDArrayValue`] from a [`PointerValue`]. + #[must_use] + pub fn from_pointer_value( + ptr: PointerValue<'ctx>, + parent: NDArrayValue<'ctx>, + indices: ArraySliceValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); + + Self { value: ptr, parent, indices, llvm_usize, name } + } + + /// Is the current iteration valid? + /// + /// If true, then `element`, `indices` and `nth` contain details about the current element. + /// + /// If `ndarray` is unsized, this returns true only for the first iteration. + /// If `ndarray` is 0-sized, this always returns false. + #[must_use] + pub fn has_element( + &self, + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + ) -> IntValue<'ctx> { + irrt::ndarray::call_nac3_nditer_has_element(generator, ctx, *self) + } + + /// Go to the next element. If `has_element()` is false, then this has undefined behavior. + /// + /// If `ndarray` is unsized, this can only be called once. + /// If `ndarray` is 0-sized, this can never be called. + pub fn next(&self, generator: &G, ctx: &CodeGenContext<'ctx, '_>) { + irrt::ndarray::call_nac3_nditer_next(generator, ctx, *self); + } + + fn element(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> { + self.get_type().get_fields(ctx.ctx).element + } + + /// Get pointer to the current element. + #[must_use] + pub fn get_pointer(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let elem_ty = self.parent.dtype; + + let p = self.element(ctx).get(ctx, self.as_base_value(), None); + ctx.builder + .build_pointer_cast(p, elem_ty.ptr_type(AddressSpace::default()), "element") + .unwrap() + } + + /// Get the value of the current element. + #[must_use] + pub fn get_scalar(&self, ctx: &CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx> { + let p = self.get_pointer(ctx); + ctx.builder.build_load(p, "value").unwrap() + } + + fn nth(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields(ctx.ctx).nth + } + + /// Get the index of the current element if this ndarray were a flat ndarray. + #[must_use] + pub fn get_index(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + self.nth(ctx).get(ctx, self.as_base_value(), None) + } + + /// Get the indices of the current element. + #[must_use] + pub fn get_indices( + &'ctx self, + ) -> impl TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> + TypedArrayLikeMutator<'ctx, IntValue<'ctx>> + { + TypedArrayLikeAdapter::from( + self.indices, + Box::new(|ctx, val| { + ctx.builder + .build_int_z_extend_or_bit_cast(val.into_int_value(), self.llvm_usize, "") + .unwrap() + }), + Box::new(|_, val| val.into()), + ) + } +} + +impl<'ctx> ProxyValue<'ctx> for NDIterValue<'ctx> { + type Base = PointerValue<'ctx>; + type Type = NDIterType<'ctx>; + + fn get_type(&self) -> Self::Type { + NDIterType::from_type(self.as_base_value().get_type(), self.llvm_usize) + } + + fn as_base_value(&self) -> Self::Base { + self.value + } +} + +impl<'ctx> From> for PointerValue<'ctx> { + fn from(value: NDIterValue<'ctx>) -> Self { + value.as_base_value() + } +} + +impl<'ctx> NDArrayValue<'ctx> { + /// Iterate through every element in the ndarray. + /// + /// `body` has access to [`BreakContinueHooks`] to short-circuit and [`NDIterHandle`] to + /// get properties of the current iteration (e.g., the current element, indices, etc.) + pub fn foreach<'a, G, F>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + body: F, + ) -> Result<(), String> + where + G: CodeGenerator + ?Sized, + F: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + BreakContinueHooks<'ctx>, + NDIterValue<'ctx>, + ) -> Result<(), String>, + { + gen_for_callback( + generator, + ctx, + Some("ndarray_foreach"), + |generator, ctx| { + Ok(NDIterType::new(generator, ctx.ctx).construct(generator, ctx, *self)) + }, + |generator, ctx, nditer| Ok(nditer.has_element(generator, ctx)), + |generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer), + |generator, ctx, nditer| { + nditer.next(generator, ctx); + Ok(()) + }, + ) + } +}