From 9d8358c0e76696d69800bb0a31b8f86e43c97ed8 Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 20 Aug 2024 12:00:31 +0800 Subject: [PATCH] core/ndstrides: implement ndarray iterator NDIter A necessary utility to iterate through all elements in a possibly strided ndarray. --- nac3core/irrt/irrt.cpp | 1 + nac3core/irrt/irrt/ndarray/iter.hpp | 146 ++++++++++++++ nac3core/src/codegen/irrt/mod.rs | 31 ++- nac3core/src/codegen/object/ndarray/mod.rs | 2 + nac3core/src/codegen/object/ndarray/nditer.rs | 178 ++++++++++++++++++ 5 files changed, 357 insertions(+), 1 deletion(-) create mode 100644 nac3core/irrt/irrt/ndarray/iter.hpp create mode 100644 nac3core/src/codegen/object/ndarray/nditer.rs diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 2b3760d6..58d18f8a 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -6,3 +6,4 @@ #include "irrt/slice.hpp" #include "irrt/ndarray/basic.hpp" #include "irrt/ndarray/def.hpp" +#include "irrt/ndarray/iter.hpp" \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/iter.hpp b/nac3core/irrt/irrt/ndarray/iter.hpp new file mode 100644 index 00000000..d419b7e8 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/iter.hpp @@ -0,0 +1,146 @@ +#pragma once + +#include "irrt/int_types.hpp" +#include "irrt/ndarray/def.hpp" + +namespace { +/** + * @brief Helper struct to enumerate through an ndarray *efficiently*. + * + * Example usage (in pseudo-code): + * ``` + * // Suppose my_ndarray has been initialized, with shape [2, 3] and dtype `double` + * NDIter nditer; + * nditer.initialize(my_ndarray); + * while (nditer.has_element()) { + * // This body is run 6 (= my_ndarray.size) times. + * + * // [0, 0] -> [0, 1] -> [0, 2] -> [1, 0] -> [1, 1] -> [1, 2] -> end + * print(nditer.indices); + * + * // 0 -> 1 -> 2 -> 3 -> 4 -> 5 + * print(nditer.nth); + * + * // <1st element> -> <2nd element> -> ... -> <6th element> -> end + * print(*((double *) nditer.element)) + * + * nditer.next(); // Go to next element. + * } + * ``` + * + * Interesting cases: + * - If `my_ndarray.ndims` == 0, there is one iteration. + * - If `my_ndarray.shape` contains zeroes, there are no iterations. + */ +template +struct NDIter { + // Information about the ndarray being iterated over. + SizeT ndims; + SizeT* shape; + SizeT* strides; + + /** + * @brief The current indices. + * + * Must be allocated by the caller. + */ + SizeT* indices; + + /** + * @brief The nth (0-based) index of the current indices. + * + * Initially this is 0. + */ + SizeT nth; + + /** + * @brief Pointer to the current element. + * + * Initially this points to first element of the ndarray. + */ + uint8_t* element; + + /** + * @brief Cache for the product of shape. + * + * Could be 0 if `shape` has 0s in it. + */ + SizeT size; + + void initialize(SizeT ndims, SizeT* shape, SizeT* strides, uint8_t* element, SizeT* indices) { + this->ndims = ndims; + this->shape = shape; + this->strides = strides; + + this->indices = indices; + this->element = element; + + // Compute size + this->size = 1; + for (SizeT i = 0; i < ndims; i++) { + this->size *= shape[i]; + } + + // `indices` starts on all 0s. + for (SizeT axis = 0; axis < ndims; axis++) + indices[axis] = 0; + nth = 0; + } + + void initialize_by_ndarray(NDArray* ndarray, SizeT* indices) { + // NOTE: ndarray->data is pointing to the first element, and `NDIter`'s `element` should also point to the first + // element as well. + this->initialize(ndarray->ndims, ndarray->shape, ndarray->strides, ndarray->data, indices); + } + + // Is the current iteration valid? + // If true, then `element`, `indices` and `nth` contain details about the current element. + bool has_element() { return nth < size; } + + // Go to the next element. + void next() { + for (SizeT i = 0; i < ndims; i++) { + SizeT axis = ndims - i - 1; + indices[axis]++; + if (indices[axis] >= shape[axis]) { + indices[axis] = 0; + + // TODO: There is something called backstrides to speedup iteration. + // See https://ajcr.net/stride-guide-part-1/, and + // https://docs.scipy.org/doc/numpy-1.13.0/reference/c-api.types-and-structures.html#c.PyArrayIterObject.PyArrayIterObject.backstrides. + element -= strides[axis] * (shape[axis] - 1); + } else { + element += strides[axis]; + break; + } + } + nth++; + } +}; +} // namespace + +extern "C" { +void __nac3_nditer_initialize(NDIter* iter, NDArray* ndarray, int32_t* indices) { + iter->initialize_by_ndarray(ndarray, indices); +} + +void __nac3_nditer_initialize64(NDIter* iter, NDArray* ndarray, int64_t* indices) { + iter->initialize_by_ndarray(ndarray, indices); +} + +bool __nac3_nditer_has_element(NDIter* iter) { + return iter->has_element(); +} + +bool __nac3_nditer_has_element64(NDIter* iter) { + return iter->has_element(); +} + +void __nac3_nditer_next(NDIter* iter) { + iter->next(); +} + +void __nac3_nditer_next64(NDIter* iter) { + iter->next(); +} +} \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 641b174d..472aaaa1 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -19,7 +19,7 @@ use super::{ llvm_intrinsics, macros::codegen_unreachable, model::*, - object::ndarray::NDArray, + object::ndarray::{nditer::NDIter, NDArray}, stmt::gen_for_callback_incrementing, CodeGenContext, CodeGenerator, }; @@ -1084,3 +1084,32 @@ pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>( let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_copy_data"); FnCall::builder(generator, ctx, &name).arg(src_ndarray).arg(dst_ndarray).returning_void(); } + +pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + iter: Instance<'ctx, Ptr>>, + ndarray: Instance<'ctx, Ptr>>, + indices: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_initialize"); + FnCall::builder(generator, ctx, &name).arg(iter).arg(ndarray).arg(indices).returning_void(); +} + +pub fn call_nac3_nditer_has_element<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + iter: Instance<'ctx, Ptr>>, +) -> Instance<'ctx, Int> { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_has_element"); + FnCall::builder(generator, ctx, &name).arg(iter).returning_auto("has_element") +} + +pub fn call_nac3_nditer_next<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + iter: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_next"); + FnCall::builder(generator, ctx, &name).arg(iter).returning_void(); +} diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs index 9178b88e..d1de92d1 100644 --- a/nac3core/src/codegen/object/ndarray/mod.rs +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -17,6 +17,8 @@ use crate::{ use super::any::AnyObject; +pub mod nditer; + /// Fields of [`NDArray`] pub struct NDArrayFields<'ctx, F: FieldTraversal<'ctx>> { pub data: F::Output>>, diff --git a/nac3core/src/codegen/object/ndarray/nditer.rs b/nac3core/src/codegen/object/ndarray/nditer.rs new file mode 100644 index 00000000..b0d29b7a --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/nditer.rs @@ -0,0 +1,178 @@ +use inkwell::{types::BasicType, values::PointerValue, AddressSpace}; + +use super::NDArrayObject; +use crate::codegen::{ + irrt::{call_nac3_nditer_has_element, call_nac3_nditer_initialize, call_nac3_nditer_next}, + model::*, + object::any::AnyObject, + stmt::{gen_for_callback, BreakContinueHooks}, + CodeGenContext, CodeGenerator, +}; + +/// Fields of [`NDIter`] +pub struct NDIterFields<'ctx, F: FieldTraversal<'ctx>> { + pub ndims: F::Output>, + pub shape: F::Output>>, + pub strides: F::Output>>, + + pub indices: F::Output>>, + pub nth: F::Output>, + pub element: F::Output>>, + + pub size: F::Output>, +} + +/// An IRRT helper structure used to iterate through an ndarray. +#[derive(Debug, Clone, Copy, Default)] +pub struct NDIter; + +impl<'ctx> StructKind<'ctx> for NDIter { + type Fields> = NDIterFields<'ctx, F>; + + fn iter_fields>(&self, traversal: &mut F) -> Self::Fields { + Self::Fields { + ndims: traversal.add_auto("ndims"), + shape: traversal.add_auto("shape"), + strides: traversal.add_auto("strides"), + + indices: traversal.add_auto("indices"), + nth: traversal.add_auto("nth"), + element: traversal.add_auto("element"), + + size: traversal.add_auto("size"), + } + } +} + +/// A helper structure with a convenient interface to interact with [`NDIter`]. +#[derive(Debug, Clone)] +pub struct NDIterHandle<'ctx> { + instance: Instance<'ctx, Ptr>>, + /// The ndarray this [`NDIter`] to iterating over. + ndarray: NDArrayObject<'ctx>, + /// The current indices of [`NDIter`]. + indices: Instance<'ctx, Ptr>>, +} + +impl<'ctx> NDIterHandle<'ctx> { + /// Allocate an [`NDIter`] that iterates through an ndarray. + pub fn new( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: NDArrayObject<'ctx>, + ) -> Self { + let nditer = Struct(NDIter).alloca(generator, ctx); + let ndims = ndarray.ndims_llvm(generator, ctx.ctx); + + // The caller has the responsibility to allocate 'indices' for `NDIter`. + let indices = Int(SizeT).array_alloca(generator, ctx, ndims.value); + call_nac3_nditer_initialize(generator, ctx, nditer, ndarray.instance, indices); + + NDIterHandle { ndarray, instance: nditer, indices } + } + + /// Is the current iteration valid? + /// + /// If true, then `element`, `indices` and `nth` contain details about the current element. + /// + /// If `ndarray` is unsized, this returns true only for the first iteration. + /// If `ndarray` is 0-sized, this always returns false. + #[must_use] + pub fn has_element( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Instance<'ctx, Int> { + call_nac3_nditer_has_element(generator, ctx, self.instance) + } + + /// Go to the next element. If `has_element()` is false, then this has undefined behavior. + /// + /// If `ndarray` is unsized, this can only be called once. + /// If `ndarray` is 0-sized, this can never be called. + pub fn next( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) { + call_nac3_nditer_next(generator, ctx, self.instance); + } + + /// Get pointer to the current element. + #[must_use] + pub fn get_pointer( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> PointerValue<'ctx> { + let elem_ty = ctx.get_llvm_type(generator, self.ndarray.dtype); + + let p = self.instance.get(generator, ctx, |f| f.element); + ctx.builder + .build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), "element") + .unwrap() + } + + /// Get the value of the current element. + #[must_use] + pub fn get_scalar( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> AnyObject<'ctx> { + let p = self.get_pointer(generator, ctx); + let value = ctx.builder.build_load(p, "value").unwrap(); + AnyObject { ty: self.ndarray.dtype, value } + } + + /// Get the index of the current element if this ndarray were a flat ndarray. + #[must_use] + pub fn get_index( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Instance<'ctx, Int> { + self.instance.get(generator, ctx, |f| f.nth) + } + + /// Get the indices of the current element. + #[must_use] + pub fn get_indices(&self) -> Instance<'ctx, Ptr>> { + self.indices + } +} + +impl<'ctx> NDArrayObject<'ctx> { + /// Iterate through every element in the ndarray. + /// + /// `body` has access to [`BreakContinueHooks`] to short-circuit and [`NDIterHandle`] to + /// get properties of the current iteration (e.g., the current element, indices, etc.) + pub fn foreach<'a, G, F>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + body: F, + ) -> Result<(), String> + where + G: CodeGenerator + ?Sized, + F: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + BreakContinueHooks<'ctx>, + NDIterHandle<'ctx>, + ) -> Result<(), String>, + { + gen_for_callback( + generator, + ctx, + Some("ndarray_foreach"), + |generator, ctx| Ok(NDIterHandle::new(generator, ctx, *self)), + |generator, ctx, nditer| Ok(nditer.has_element(generator, ctx).value), + |generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer), + |generator, ctx, nditer| { + nditer.next(generator, ctx); + Ok(()) + }, + ) + } +}