forked from M-Labs/nac3
core/ndstrides: implement ndarray iterator NDIter
This commit is contained in:
parent
a531d0127a
commit
f5827dae24
|
@ -3,4 +3,5 @@
|
||||||
#include <irrt/math_util.hpp>
|
#include <irrt/math_util.hpp>
|
||||||
#include <irrt/ndarray/basic.hpp>
|
#include <irrt/ndarray/basic.hpp>
|
||||||
#include <irrt/ndarray/def.hpp>
|
#include <irrt/ndarray/def.hpp>
|
||||||
|
#include <irrt/ndarray/iter.hpp>
|
||||||
#include <irrt/original.hpp>
|
#include <irrt/original.hpp>
|
|
@ -0,0 +1,139 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <irrt/int_types.hpp>
|
||||||
|
#include <irrt/ndarray/def.hpp>
|
||||||
|
|
||||||
|
namespace
|
||||||
|
{
|
||||||
|
/**
|
||||||
|
* @brief Helper struct to enumerate through an ndarray *efficiently*.
|
||||||
|
*
|
||||||
|
* Interesting cases:
|
||||||
|
* - If ndims == 0, there is one iteration.
|
||||||
|
* - If shape contains zeroes, there are no iterations.
|
||||||
|
*/
|
||||||
|
template <typename SizeT> struct NDIter
|
||||||
|
{
|
||||||
|
// Information about the ndarray being iterated over.
|
||||||
|
SizeT ndims;
|
||||||
|
SizeT *shape;
|
||||||
|
SizeT *strides;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief The current indices.
|
||||||
|
*
|
||||||
|
* Must be allocated by the caller.
|
||||||
|
*/
|
||||||
|
SizeT *indices;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief The nth (0-based) index of the current indices.
|
||||||
|
*
|
||||||
|
* Initially this is all 0s.
|
||||||
|
*/
|
||||||
|
SizeT nth;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Pointer to the current element.
|
||||||
|
*
|
||||||
|
* Initially this points to first element of the ndarray.
|
||||||
|
*/
|
||||||
|
uint8_t *element;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Cache for the product of shape.
|
||||||
|
*
|
||||||
|
* Could be 0 if `shape` has 0s in it.
|
||||||
|
*/
|
||||||
|
SizeT size;
|
||||||
|
|
||||||
|
// TODO:: Not implemented: There is something called backstrides to speedup iteration.
|
||||||
|
// See https://ajcr.net/stride-guide-part-1/, and https://docs.scipy.org/doc/numpy-1.13.0/reference/c-api.types-and-structures.html#c.PyArrayIterObject.PyArrayIterObject.backstrides.
|
||||||
|
|
||||||
|
void initialize(SizeT ndims, SizeT *shape, SizeT *strides, uint8_t *element, SizeT *indices)
|
||||||
|
{
|
||||||
|
this->ndims = ndims;
|
||||||
|
this->shape = shape;
|
||||||
|
this->strides = strides;
|
||||||
|
|
||||||
|
this->indices = indices;
|
||||||
|
this->element = element;
|
||||||
|
|
||||||
|
// Compute size
|
||||||
|
this->size = 1;
|
||||||
|
for (SizeT i = 0; i < ndims; i++)
|
||||||
|
{
|
||||||
|
this->size *= shape[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (SizeT axis = 0; axis < ndims; axis++)
|
||||||
|
indices[axis] = 0;
|
||||||
|
nth = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void initialize_by_ndarray(NDArray<SizeT> *ndarray, SizeT *indices)
|
||||||
|
{
|
||||||
|
this->initialize(ndarray->ndims, ndarray->shape, ndarray->strides, ndarray->data, indices);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool has_next()
|
||||||
|
{
|
||||||
|
return nth < size;
|
||||||
|
}
|
||||||
|
|
||||||
|
void next()
|
||||||
|
{
|
||||||
|
for (SizeT i = 0; i < ndims; i++)
|
||||||
|
{
|
||||||
|
SizeT axis = ndims - i - 1;
|
||||||
|
indices[axis]++;
|
||||||
|
if (indices[axis] >= shape[axis])
|
||||||
|
{
|
||||||
|
indices[axis] = 0;
|
||||||
|
|
||||||
|
// TODO: Can be optimized with backstrides.
|
||||||
|
element -= strides[axis] * (shape[axis] - 1);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
element += strides[axis];
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nth++;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
extern "C"
|
||||||
|
{
|
||||||
|
void __nac3_nditer_initialize(NDIter<int32_t> *iter, NDArray<int32_t> *ndarray, int32_t *indices)
|
||||||
|
{
|
||||||
|
iter->initialize_by_ndarray(ndarray, indices);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_nditer_initialize64(NDIter<int64_t> *iter, NDArray<int64_t> *ndarray, int64_t *indices)
|
||||||
|
{
|
||||||
|
iter->initialize_by_ndarray(ndarray, indices);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool __nac3_nditer_has_next(NDIter<int32_t> *iter)
|
||||||
|
{
|
||||||
|
return iter->has_next();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool __nac3_nditer_has_next64(NDIter<int64_t> *iter)
|
||||||
|
{
|
||||||
|
return iter->has_next();
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_nditer_next(NDIter<int32_t> *iter)
|
||||||
|
{
|
||||||
|
iter->next();
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_nditer_next64(NDIter<int64_t> *iter)
|
||||||
|
{
|
||||||
|
iter->next();
|
||||||
|
}
|
||||||
|
}
|
|
@ -7,7 +7,7 @@ use super::{
|
||||||
},
|
},
|
||||||
llvm_intrinsics,
|
llvm_intrinsics,
|
||||||
model::*,
|
model::*,
|
||||||
object::ndarray::NDArray,
|
object::ndarray::{nditer::NDIter, NDArray},
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
};
|
};
|
||||||
use crate::codegen::classes::TypedArrayLikeAccessor;
|
use crate::codegen::classes::TypedArrayLikeAccessor;
|
||||||
|
@ -1090,3 +1090,32 @@ pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_copy_data");
|
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_copy_data");
|
||||||
CallFunction::begin(generator, ctx, &name).arg(src_ndarray).arg(dst_ndarray).returning_void();
|
CallFunction::begin(generator, ctx, &name).arg(src_ndarray).arg(dst_ndarray).returning_void();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
iter: Instance<'ctx, Ptr<Struct<NDIter>>>,
|
||||||
|
ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
||||||
|
indices: Instance<'ctx, Ptr<Int<SizeT>>>,
|
||||||
|
) {
|
||||||
|
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_initialize");
|
||||||
|
CallFunction::begin(generator, ctx, &name).arg(iter).arg(ndarray).arg(indices).returning_void();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_nditer_has_next<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
iter: Instance<'ctx, Ptr<Struct<NDIter>>>,
|
||||||
|
) -> Instance<'ctx, Int<Bool>> {
|
||||||
|
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_has_next");
|
||||||
|
CallFunction::begin(generator, ctx, &name).arg(iter).returning_auto("has_next")
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_nditer_next<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
iter: Instance<'ctx, Ptr<Struct<NDIter>>>,
|
||||||
|
) {
|
||||||
|
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_next");
|
||||||
|
CallFunction::begin(generator, ctx, &name).arg(iter).returning_void();
|
||||||
|
}
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
pub mod nditer;
|
||||||
|
|
||||||
use inkwell::{context::Context, types::BasicType, values::PointerValue, AddressSpace};
|
use inkwell::{context::Context, types::BasicType, values::PointerValue, AddressSpace};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
|
|
@ -0,0 +1,177 @@
|
||||||
|
use inkwell::{types::BasicType, values::PointerValue, AddressSpace};
|
||||||
|
|
||||||
|
use crate::codegen::{
|
||||||
|
irrt::{call_nac3_nditer_has_next, call_nac3_nditer_initialize, call_nac3_nditer_next},
|
||||||
|
model::*,
|
||||||
|
object::any::AnyObject,
|
||||||
|
stmt::{gen_for_callback, BreakContinueHooks},
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::NDArrayObject;
|
||||||
|
|
||||||
|
/// Fields of [`NDIter`]
|
||||||
|
pub struct NDIterFields<'ctx, F: FieldTraversal<'ctx>> {
|
||||||
|
pub ndims: F::Out<Int<SizeT>>,
|
||||||
|
pub shape: F::Out<Ptr<Int<SizeT>>>,
|
||||||
|
pub strides: F::Out<Ptr<Int<SizeT>>>,
|
||||||
|
|
||||||
|
pub indices: F::Out<Ptr<Int<SizeT>>>,
|
||||||
|
pub nth: F::Out<Int<SizeT>>,
|
||||||
|
pub element: F::Out<Ptr<Int<Byte>>>,
|
||||||
|
|
||||||
|
pub size: F::Out<Int<SizeT>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An IRRT helper structure used to iterate through an ndarray.
|
||||||
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
|
pub struct NDIter;
|
||||||
|
|
||||||
|
impl<'ctx> StructKind<'ctx> for NDIter {
|
||||||
|
type Fields<F: FieldTraversal<'ctx>> = NDIterFields<'ctx, F>;
|
||||||
|
|
||||||
|
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
|
||||||
|
Self::Fields {
|
||||||
|
ndims: traversal.add_auto("ndims"),
|
||||||
|
shape: traversal.add_auto("shape"),
|
||||||
|
strides: traversal.add_auto("strides"),
|
||||||
|
|
||||||
|
indices: traversal.add_auto("indices"),
|
||||||
|
nth: traversal.add_auto("nth"),
|
||||||
|
element: traversal.add_auto("element"),
|
||||||
|
|
||||||
|
size: traversal.add_auto("size"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A helper structure with a convenient interface to interact with [`NDIter`].
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct NDIterHandle<'ctx> {
|
||||||
|
instance: Instance<'ctx, Ptr<Struct<NDIter>>>,
|
||||||
|
/// The ndarray this [`NDIter`] to iterating over.
|
||||||
|
ndarray: NDArrayObject<'ctx>,
|
||||||
|
/// The current indices of [`NDIter`].
|
||||||
|
indices: Instance<'ctx, Ptr<Int<SizeT>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> NDIterHandle<'ctx> {
|
||||||
|
/// Allocate an [`NDIter`] that iterates through an ndarray.
|
||||||
|
pub fn new<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray: NDArrayObject<'ctx>,
|
||||||
|
) -> Self {
|
||||||
|
let nditer = Struct(NDIter).alloca(generator, ctx);
|
||||||
|
let ndims = ndarray.ndims_llvm(generator, ctx.ctx);
|
||||||
|
|
||||||
|
// The caller has the responsibility to allocate 'indices' for `NDIter`.
|
||||||
|
let indices = Int(SizeT).array_alloca(generator, ctx, ndims.value);
|
||||||
|
call_nac3_nditer_initialize(generator, ctx, nditer, ndarray.instance, indices);
|
||||||
|
|
||||||
|
NDIterHandle { ndarray, instance: nditer, indices }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Is there a next element?
|
||||||
|
///
|
||||||
|
/// If `ndarray` is unsized, this returns true only for the first iteration.
|
||||||
|
/// If `ndarray` is 0-sized, this always returns false.
|
||||||
|
#[must_use]
|
||||||
|
pub fn has_next<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> Instance<'ctx, Int<Bool>> {
|
||||||
|
call_nac3_nditer_has_next(generator, ctx, self.instance)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Go to the next element. If `has_next()` is false, then this has undefined behavior.
|
||||||
|
///
|
||||||
|
/// If `ndarray` is unsized, this can only be called once.
|
||||||
|
/// If `ndarray` is 0-sized, this can never be called.
|
||||||
|
pub fn next<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) {
|
||||||
|
call_nac3_nditer_next(generator, ctx, self.instance);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get pointer to the current element.
|
||||||
|
#[must_use]
|
||||||
|
pub fn get_pointer<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
let elem_ty = ctx.get_llvm_type(generator, self.ndarray.dtype);
|
||||||
|
|
||||||
|
let p = self.instance.get(generator, ctx, |f| f.element);
|
||||||
|
ctx.builder
|
||||||
|
.build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), "element")
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the value of the current element.
|
||||||
|
#[must_use]
|
||||||
|
pub fn get_scalar<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> AnyObject<'ctx> {
|
||||||
|
let p = self.get_pointer(generator, ctx);
|
||||||
|
let value = ctx.builder.build_load(p, "value").unwrap();
|
||||||
|
AnyObject { ty: self.ndarray.dtype, value }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the index of the current element if this ndarray were a flat ndarray.
|
||||||
|
#[must_use]
|
||||||
|
pub fn get_index<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> Instance<'ctx, Int<SizeT>> {
|
||||||
|
self.instance.get(generator, ctx, |f| f.nth)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the indices of the current element.
|
||||||
|
#[must_use]
|
||||||
|
pub fn get_indices(&self) -> Instance<'ctx, Ptr<Int<SizeT>>> {
|
||||||
|
self.indices
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> NDArrayObject<'ctx> {
|
||||||
|
/// Iterate through every element in the ndarray.
|
||||||
|
///
|
||||||
|
/// `body` has access to [`BreakContinueHooks`] to short-circuit and [`NDIterHandle`] to
|
||||||
|
/// get properties of the current iteration (e.g., the current element, indices, etc.)
|
||||||
|
pub fn foreach<'a, G, F>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
body: F,
|
||||||
|
) -> Result<(), String>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
F: FnOnce(
|
||||||
|
&mut G,
|
||||||
|
&mut CodeGenContext<'ctx, 'a>,
|
||||||
|
BreakContinueHooks<'ctx>,
|
||||||
|
NDIterHandle<'ctx>,
|
||||||
|
) -> Result<(), String>,
|
||||||
|
{
|
||||||
|
gen_for_callback(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
Some("ndarray_foreach"),
|
||||||
|
|generator, ctx| Ok(NDIterHandle::new(generator, ctx, *self)),
|
||||||
|
|generator, ctx, nditer| Ok(nditer.has_next(generator, ctx).value),
|
||||||
|
|generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer),
|
||||||
|
|generator, ctx, nditer| {
|
||||||
|
nditer.next(generator, ctx);
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue