1
0
forked from M-Labs/nac3

core/ndstrides: implement ndarray iterator NDIter

A necessary utility to iterate through all elements in a possibly strided ndarray.
This commit is contained in:
lyken 2024-08-20 12:00:31 +08:00 committed by DSLstandard
parent 92e7103ec7
commit 40c42b571a
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
5 changed files with 358 additions and 1 deletions

View File

@ -6,3 +6,4 @@
#include "irrt/slice.hpp"
#include "irrt/ndarray/basic.hpp"
#include "irrt/ndarray/def.hpp"
#include "irrt/ndarray/iter.hpp"

View File

@ -0,0 +1,146 @@
#pragma once
#include "irrt/int_types.hpp"
#include "irrt/ndarray/def.hpp"
namespace {
/**
* @brief Helper struct to enumerate through an ndarray *efficiently*.
*
* Example usage (in pseudo-code):
* ```
* // Suppose my_ndarray has been initialized, with shape [2, 3] and dtype `double`
* NDIter nditer;
* nditer.initialize(my_ndarray);
* while (nditer.has_element()) {
* // This body is run 6 (= my_ndarray.size) times.
*
* // [0, 0] -> [0, 1] -> [0, 2] -> [1, 0] -> [1, 1] -> [1, 2] -> end
* print(nditer.indices);
*
* // 0 -> 1 -> 2 -> 3 -> 4 -> 5
* print(nditer.nth);
*
* // <1st element> -> <2nd element> -> ... -> <6th element> -> end
* print(*((double *) nditer.element))
*
* nditer.next(); // Go to next element.
* }
* ```
*
* Interesting cases:
* - If `my_ndarray.ndims` == 0, there is one iteration.
* - If `my_ndarray.shape` contains zeroes, there are no iterations.
*/
template<typename SizeT>
struct NDIter {
// Information about the ndarray being iterated over.
SizeT ndims;
SizeT* shape;
SizeT* strides;
/**
* @brief The current indices.
*
* Must be allocated by the caller.
*/
SizeT* indices;
/**
* @brief The nth (0-based) index of the current indices.
*
* Initially this is 0.
*/
SizeT nth;
/**
* @brief Pointer to the current element.
*
* Initially this points to first element of the ndarray.
*/
uint8_t* element;
/**
* @brief Cache for the product of shape.
*
* Could be 0 if `shape` has 0s in it.
*/
SizeT size;
void initialize(SizeT ndims, SizeT* shape, SizeT* strides, uint8_t* element, SizeT* indices) {
this->ndims = ndims;
this->shape = shape;
this->strides = strides;
this->indices = indices;
this->element = element;
// Compute size
this->size = 1;
for (SizeT i = 0; i < ndims; i++) {
this->size *= shape[i];
}
// `indices` starts on all 0s.
for (SizeT axis = 0; axis < ndims; axis++)
indices[axis] = 0;
nth = 0;
}
void initialize_by_ndarray(NDArray<SizeT>* ndarray, SizeT* indices) {
// NOTE: ndarray->data is pointing to the first element, and `NDIter`'s `element` should also point to the first
// element as well.
this->initialize(ndarray->ndims, ndarray->shape, ndarray->strides, ndarray->data, indices);
}
// Is the current iteration valid?
// If true, then `element`, `indices` and `nth` contain details about the current element.
bool has_element() { return nth < size; }
// Go to the next element.
void next() {
for (SizeT i = 0; i < ndims; i++) {
SizeT axis = ndims - i - 1;
indices[axis]++;
if (indices[axis] >= shape[axis]) {
indices[axis] = 0;
// TODO: There is something called backstrides to speedup iteration.
// See https://ajcr.net/stride-guide-part-1/, and
// https://docs.scipy.org/doc/numpy-1.13.0/reference/c-api.types-and-structures.html#c.PyArrayIterObject.PyArrayIterObject.backstrides.
element -= strides[axis] * (shape[axis] - 1);
} else {
element += strides[axis];
break;
}
}
nth++;
}
};
} // namespace
extern "C" {
void __nac3_nditer_initialize(NDIter<int32_t>* iter, NDArray<int32_t>* ndarray, int32_t* indices) {
iter->initialize_by_ndarray(ndarray, indices);
}
void __nac3_nditer_initialize64(NDIter<int64_t>* iter, NDArray<int64_t>* ndarray, int64_t* indices) {
iter->initialize_by_ndarray(ndarray, indices);
}
bool __nac3_nditer_has_element(NDIter<int32_t>* iter) {
return iter->has_element();
}
bool __nac3_nditer_has_element64(NDIter<int64_t>* iter) {
return iter->has_element();
}
void __nac3_nditer_next(NDIter<int32_t>* iter) {
iter->next();
}
void __nac3_nditer_next64(NDIter<int64_t>* iter) {
iter->next();
}
}

View File

@ -8,7 +8,7 @@ use super::{
llvm_intrinsics,
macros::codegen_unreachable,
model::*,
object::ndarray::NDArray,
object::ndarray::{nditer::NDIter, NDArray},
stmt::gen_for_callback_incrementing,
CodeGenContext, CodeGenerator,
};
@ -1083,3 +1083,32 @@ pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>(
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_copy_data");
FnCall::builder(generator, ctx, &name).arg(src_ndarray).arg(dst_ndarray).returning_void();
}
pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
iter: Instance<'ctx, Ptr<Struct<NDIter>>>,
ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
indices: Instance<'ctx, Ptr<Int<SizeT>>>,
) {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_initialize");
FnCall::builder(generator, ctx, &name).arg(iter).arg(ndarray).arg(indices).returning_void();
}
pub fn call_nac3_nditer_has_element<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
iter: Instance<'ctx, Ptr<Struct<NDIter>>>,
) -> Instance<'ctx, Int<Bool>> {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_has_element");
FnCall::builder(generator, ctx, &name).arg(iter).returning_auto("has_element")
}
pub fn call_nac3_nditer_next<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
iter: Instance<'ctx, Ptr<Struct<NDIter>>>,
) {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_next");
FnCall::builder(generator, ctx, &name).arg(iter).returning_void();
}

View File

@ -1,3 +1,5 @@
pub mod nditer;
use inkwell::{context::Context, types::BasicType, values::PointerValue, AddressSpace};
use crate::{

View File

@ -0,0 +1,179 @@
use inkwell::{types::BasicType, values::PointerValue, AddressSpace};
use crate::codegen::{
irrt::{call_nac3_nditer_has_element, call_nac3_nditer_initialize, call_nac3_nditer_next},
model::*,
object::any::AnyObject,
stmt::{gen_for_callback, BreakContinueHooks},
CodeGenContext, CodeGenerator,
};
use super::NDArrayObject;
/// Fields of [`NDIter`]
pub struct NDIterFields<'ctx, F: FieldTraversal<'ctx>> {
pub ndims: F::Output<Int<SizeT>>,
pub shape: F::Output<Ptr<Int<SizeT>>>,
pub strides: F::Output<Ptr<Int<SizeT>>>,
pub indices: F::Output<Ptr<Int<SizeT>>>,
pub nth: F::Output<Int<SizeT>>,
pub element: F::Output<Ptr<Int<Byte>>>,
pub size: F::Output<Int<SizeT>>,
}
/// An IRRT helper structure used to iterate through an ndarray.
#[derive(Debug, Clone, Copy, Default)]
pub struct NDIter;
impl<'ctx> StructKind<'ctx> for NDIter {
type Fields<F: FieldTraversal<'ctx>> = NDIterFields<'ctx, F>;
fn iter_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
Self::Fields {
ndims: traversal.add_auto("ndims"),
shape: traversal.add_auto("shape"),
strides: traversal.add_auto("strides"),
indices: traversal.add_auto("indices"),
nth: traversal.add_auto("nth"),
element: traversal.add_auto("element"),
size: traversal.add_auto("size"),
}
}
}
/// A helper structure with a convenient interface to interact with [`NDIter`].
#[derive(Debug, Clone)]
pub struct NDIterHandle<'ctx> {
instance: Instance<'ctx, Ptr<Struct<NDIter>>>,
/// The ndarray this [`NDIter`] to iterating over.
ndarray: NDArrayObject<'ctx>,
/// The current indices of [`NDIter`].
indices: Instance<'ctx, Ptr<Int<SizeT>>>,
}
impl<'ctx> NDIterHandle<'ctx> {
/// Allocate an [`NDIter`] that iterates through an ndarray.
pub fn new<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayObject<'ctx>,
) -> Self {
let nditer = Struct(NDIter).alloca(generator, ctx);
let ndims = ndarray.ndims_llvm(generator, ctx.ctx);
// The caller has the responsibility to allocate 'indices' for `NDIter`.
let indices = Int(SizeT).array_alloca(generator, ctx, ndims.value);
call_nac3_nditer_initialize(generator, ctx, nditer, ndarray.instance, indices);
NDIterHandle { ndarray, instance: nditer, indices }
}
/// Is the current iteration valid?
///
/// If true, then `element`, `indices` and `nth` contain details about the current element.
///
/// If `ndarray` is unsized, this returns true only for the first iteration.
/// If `ndarray` is 0-sized, this always returns false.
#[must_use]
pub fn has_element<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Instance<'ctx, Int<Bool>> {
call_nac3_nditer_has_element(generator, ctx, self.instance)
}
/// Go to the next element. If `has_element()` is false, then this has undefined behavior.
///
/// If `ndarray` is unsized, this can only be called once.
/// If `ndarray` is 0-sized, this can never be called.
pub fn next<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) {
call_nac3_nditer_next(generator, ctx, self.instance);
}
/// Get pointer to the current element.
#[must_use]
pub fn get_pointer<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> PointerValue<'ctx> {
let elem_ty = ctx.get_llvm_type(generator, self.ndarray.dtype);
let p = self.instance.get(generator, ctx, |f| f.element);
ctx.builder
.build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), "element")
.unwrap()
}
/// Get the value of the current element.
#[must_use]
pub fn get_scalar<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> AnyObject<'ctx> {
let p = self.get_pointer(generator, ctx);
let value = ctx.builder.build_load(p, "value").unwrap();
AnyObject { ty: self.ndarray.dtype, value }
}
/// Get the index of the current element if this ndarray were a flat ndarray.
#[must_use]
pub fn get_index<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Instance<'ctx, Int<SizeT>> {
self.instance.get(generator, ctx, |f| f.nth)
}
/// Get the indices of the current element.
#[must_use]
pub fn get_indices(&self) -> Instance<'ctx, Ptr<Int<SizeT>>> {
self.indices
}
}
impl<'ctx> NDArrayObject<'ctx> {
/// Iterate through every element in the ndarray.
///
/// `body` has access to [`BreakContinueHooks`] to short-circuit and [`NDIterHandle`] to
/// get properties of the current iteration (e.g., the current element, indices, etc.)
pub fn foreach<'a, G, F>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
body: F,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
F: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks<'ctx>,
NDIterHandle<'ctx>,
) -> Result<(), String>,
{
gen_for_callback(
generator,
ctx,
Some("ndarray_foreach"),
|generator, ctx| Ok(NDIterHandle::new(generator, ctx, *self)),
|generator, ctx, nditer| Ok(nditer.has_element(generator, ctx).value),
|generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer),
|generator, ctx, nditer| {
nditer.next(generator, ctx);
Ok(())
},
)
}
}