forked from M-Labs/nac3
core/ndstrides: add ndarray iterator (NDIter)
This commit is contained in:
parent
4777909543
commit
48fb3ff5ad
@ -3,4 +3,5 @@
|
||||
#include <irrt/math_util.hpp>
|
||||
#include <irrt/ndarray/basic.hpp>
|
||||
#include <irrt/ndarray/def.hpp>
|
||||
#include <irrt/ndarray/iter.hpp>
|
||||
#include <irrt/original.hpp>
|
142
nac3core/irrt/irrt/ndarray/iter.hpp
Normal file
142
nac3core/irrt/irrt/ndarray/iter.hpp
Normal file
@ -0,0 +1,142 @@
|
||||
#pragma once
|
||||
|
||||
#include <irrt/int_types.hpp>
|
||||
#include <irrt/ndarray/def.hpp>
|
||||
|
||||
namespace
|
||||
{
|
||||
/**
|
||||
* @brief Helper struct to enumerate through all indices under a shape.
|
||||
*
|
||||
* i.e., If `shape` is `[3, 2]`, by repeating `next()`, then you get:
|
||||
* - `[0, 0]`
|
||||
* - `[0, 1]`
|
||||
* - `[1, 0]`
|
||||
* - `[1, 1]`
|
||||
* - `[2, 0]`
|
||||
* - `[2, 1]`
|
||||
* - end.
|
||||
*
|
||||
* Interesting cases:
|
||||
* - If ndims == 0, there is one enumeration.
|
||||
* - If shape contains zeroes, there are no enumerations.
|
||||
*/
|
||||
template <typename SizeT> struct NDIter
|
||||
{
|
||||
SizeT ndims;
|
||||
SizeT *shape;
|
||||
SizeT *strides;
|
||||
|
||||
/**
|
||||
* @brief The current indices.
|
||||
*
|
||||
* Must be allocated by the caller.
|
||||
*/
|
||||
SizeT *indices;
|
||||
|
||||
/**
|
||||
* @brief The nth (0-based) index of the current indices.
|
||||
*/
|
||||
SizeT nth;
|
||||
|
||||
/**
|
||||
* @brief Pointer to the current element.
|
||||
*/
|
||||
uint8_t *element;
|
||||
|
||||
/**
|
||||
* @brief The product of shape.
|
||||
*/
|
||||
SizeT size;
|
||||
|
||||
// TODO:: There is something called backstrides to speedup iteration.
|
||||
// See https://ajcr.net/stride-guide-part-1/, and https://docs.scipy.org/doc/numpy-1.13.0/reference/c-api.types-and-structures.html#c.PyArrayIterObject.PyArrayIterObject.backstrides.
|
||||
// Maybe LLVM is clever and knows how to optimize.
|
||||
|
||||
void initialize(SizeT ndims, SizeT *shape, SizeT *strides, uint8_t *element, SizeT *indices)
|
||||
{
|
||||
this->ndims = ndims;
|
||||
this->shape = shape;
|
||||
this->strides = strides;
|
||||
|
||||
this->indices = indices;
|
||||
this->element = element;
|
||||
|
||||
// Compute size and backstrides
|
||||
this->size = 1;
|
||||
for (SizeT i = 0; i < ndims; i++)
|
||||
{
|
||||
this->size *= shape[i];
|
||||
}
|
||||
|
||||
for (SizeT axis = 0; axis < ndims; axis++)
|
||||
indices[axis] = 0;
|
||||
nth = 0;
|
||||
}
|
||||
|
||||
void initialize_by_ndarray(NDArray<SizeT> *ndarray, SizeT *indices)
|
||||
{
|
||||
this->initialize(ndarray->ndims, ndarray->shape, ndarray->strides, ndarray->data, indices);
|
||||
}
|
||||
|
||||
bool has_next()
|
||||
{
|
||||
return nth < size;
|
||||
}
|
||||
|
||||
void next()
|
||||
{
|
||||
for (SizeT i = 0; i < ndims; i++)
|
||||
{
|
||||
SizeT axis = ndims - i - 1;
|
||||
indices[axis]++;
|
||||
if (indices[axis] >= shape[axis])
|
||||
{
|
||||
indices[axis] = 0;
|
||||
|
||||
// TODO: Can be optimized with backstrides.
|
||||
element -= strides[axis] * (shape[axis] - 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
element += strides[axis];
|
||||
break;
|
||||
}
|
||||
}
|
||||
nth++;
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
extern "C"
|
||||
{
|
||||
void __nac3_nditer_initialize(NDIter<int32_t> *iter, NDArray<int32_t> *ndarray, int32_t *indices)
|
||||
{
|
||||
iter->initialize_by_ndarray(ndarray, indices);
|
||||
}
|
||||
|
||||
void __nac3_nditer_initialize64(NDIter<int64_t> *iter, NDArray<int64_t> *ndarray, int64_t *indices)
|
||||
{
|
||||
iter->initialize_by_ndarray(ndarray, indices);
|
||||
}
|
||||
|
||||
bool __nac3_nditer_has_next(NDIter<int32_t> *iter)
|
||||
{
|
||||
return iter->has_next();
|
||||
}
|
||||
|
||||
bool __nac3_nditer_has_next64(NDIter<int64_t> *iter)
|
||||
{
|
||||
return iter->has_next();
|
||||
}
|
||||
|
||||
void __nac3_nditer_next(NDIter<int32_t> *iter)
|
||||
{
|
||||
iter->next();
|
||||
}
|
||||
|
||||
void __nac3_nditer_next64(NDIter<int64_t> *iter)
|
||||
{
|
||||
iter->next();
|
||||
}
|
||||
}
|
@ -7,7 +7,7 @@ use super::{
|
||||
},
|
||||
llvm_intrinsics,
|
||||
model::*,
|
||||
object::ndarray::NDArray,
|
||||
object::ndarray::{nditer::NDIter, NDArray},
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
use crate::codegen::classes::TypedArrayLikeAccessor;
|
||||
@ -1090,3 +1090,32 @@ pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>(
|
||||
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_copy_data");
|
||||
CallFunction::begin(generator, ctx, &name).arg(src_ndarray).arg(dst_ndarray).returning_void();
|
||||
}
|
||||
|
||||
pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
iter: Instance<'ctx, Ptr<Struct<NDIter>>>,
|
||||
ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
||||
indices: Instance<'ctx, Ptr<Int<SizeT>>>,
|
||||
) {
|
||||
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_initialize");
|
||||
CallFunction::begin(generator, ctx, &name).arg(iter).arg(ndarray).arg(indices).returning_void();
|
||||
}
|
||||
|
||||
pub fn call_nac3_nditer_has_next<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
iter: Instance<'ctx, Ptr<Struct<NDIter>>>,
|
||||
) -> Instance<'ctx, Int<Bool>> {
|
||||
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_has_next");
|
||||
CallFunction::begin(generator, ctx, &name).arg(iter).returning_auto("has_next")
|
||||
}
|
||||
|
||||
pub fn call_nac3_nditer_next<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
iter: Instance<'ctx, Ptr<Struct<NDIter>>>,
|
||||
) {
|
||||
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_next");
|
||||
CallFunction::begin(generator, ctx, &name).arg(iter).returning_void();
|
||||
}
|
||||
|
@ -1,3 +1,5 @@
|
||||
pub mod nditer;
|
||||
|
||||
use inkwell::{context::Context, types::BasicType, values::PointerValue, AddressSpace};
|
||||
|
||||
use crate::{
|
||||
|
168
nac3core/src/codegen/object/ndarray/nditer.rs
Normal file
168
nac3core/src/codegen/object/ndarray/nditer.rs
Normal file
@ -0,0 +1,168 @@
|
||||
use inkwell::{types::BasicType, values::PointerValue, AddressSpace};
|
||||
|
||||
use crate::codegen::{
|
||||
irrt::{call_nac3_nditer_has_next, call_nac3_nditer_initialize, call_nac3_nditer_next},
|
||||
model::*,
|
||||
object::any::AnyObject,
|
||||
stmt::{gen_for_callback, BreakContinueHooks},
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
use super::NDArrayObject;
|
||||
|
||||
/// Fields of [`NDIter`]
|
||||
pub struct NDIterFields<'ctx, F: FieldTraversal<'ctx>> {
|
||||
pub ndims: F::Out<Int<SizeT>>,
|
||||
pub shape: F::Out<Ptr<Int<SizeT>>>,
|
||||
pub strides: F::Out<Ptr<Int<SizeT>>>,
|
||||
|
||||
pub indices: F::Out<Ptr<Int<SizeT>>>,
|
||||
pub nth: F::Out<Int<SizeT>>,
|
||||
pub element: F::Out<Ptr<Int<Byte>>>,
|
||||
|
||||
pub size: F::Out<Int<SizeT>>,
|
||||
}
|
||||
|
||||
/// An IRRT helper structure used to iterate through an ndarray.
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
pub struct NDIter;
|
||||
|
||||
impl<'ctx> StructKind<'ctx> for NDIter {
|
||||
type Fields<F: FieldTraversal<'ctx>> = NDIterFields<'ctx, F>;
|
||||
|
||||
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
|
||||
Self::Fields {
|
||||
ndims: traversal.add_auto("ndims"),
|
||||
shape: traversal.add_auto("shape"),
|
||||
strides: traversal.add_auto("strides"),
|
||||
|
||||
indices: traversal.add_auto("indices"),
|
||||
nth: traversal.add_auto("nth"),
|
||||
element: traversal.add_auto("element"),
|
||||
|
||||
size: traversal.add_auto("size"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A helper structure containing extra details of an [`NDIter`].
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct NDIterHandle<'ctx> {
|
||||
instance: Instance<'ctx, Ptr<Struct<NDIter>>>,
|
||||
/// The ndarray this [`NDIter`] to iterating over.
|
||||
ndarray: NDArrayObject<'ctx>,
|
||||
/// The current indices of [`NDIter`].
|
||||
indices: Instance<'ctx, Ptr<Int<SizeT>>>,
|
||||
}
|
||||
|
||||
impl<'ctx> NDIterHandle<'ctx> {
|
||||
/// Allocate an [`NDIter`] that iterates through an ndarray.
|
||||
pub fn new<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayObject<'ctx>,
|
||||
) -> Self {
|
||||
let nditer = Struct(NDIter).alloca(generator, ctx);
|
||||
let ndims = ndarray.ndims_llvm(generator, ctx.ctx);
|
||||
|
||||
// The caller has the responsibility to allocate 'indices' for `NDIter`.
|
||||
let indices = Int(SizeT).array_alloca(generator, ctx, ndims.value);
|
||||
call_nac3_nditer_initialize(generator, ctx, nditer, ndarray.instance, indices);
|
||||
|
||||
NDIterHandle { ndarray, instance: nditer, indices }
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn has_next<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> Instance<'ctx, Int<Bool>> {
|
||||
call_nac3_nditer_has_next(generator, ctx, self.instance)
|
||||
}
|
||||
|
||||
pub fn next<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) {
|
||||
call_nac3_nditer_next(generator, ctx, self.instance);
|
||||
}
|
||||
|
||||
/// Get pointer to the current element.
|
||||
#[must_use]
|
||||
pub fn get_pointer<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let elem_ty = ctx.get_llvm_type(generator, self.ndarray.dtype);
|
||||
|
||||
let p = self.instance.get(generator, ctx, |f| f.element);
|
||||
ctx.builder
|
||||
.build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), "element")
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Get the value of the current element.
|
||||
#[must_use]
|
||||
pub fn get_scalar<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> AnyObject<'ctx> {
|
||||
let p = self.get_pointer(generator, ctx);
|
||||
let value = ctx.builder.build_load(p, "value").unwrap();
|
||||
AnyObject { ty: self.ndarray.dtype, value }
|
||||
}
|
||||
|
||||
/// Get the index of the current element.
|
||||
#[must_use]
|
||||
pub fn get_index<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> Instance<'ctx, Int<SizeT>> {
|
||||
self.instance.get(generator, ctx, |f| f.nth)
|
||||
}
|
||||
|
||||
/// Get the indices of the current element.
|
||||
#[must_use]
|
||||
pub fn get_indices(&self) -> Instance<'ctx, Ptr<Int<SizeT>>> {
|
||||
self.indices
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> NDArrayObject<'ctx> {
|
||||
/// Iterate through every element in the ndarray.
|
||||
///
|
||||
/// `body` also access to [`BreakContinueHooks`] to short-circuit.
|
||||
pub fn foreach<'a, G, F>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
body: F,
|
||||
) -> Result<(), String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
F: FnOnce(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
BreakContinueHooks<'ctx>,
|
||||
NDIterHandle<'ctx>,
|
||||
) -> Result<(), String>,
|
||||
{
|
||||
gen_for_callback(
|
||||
generator,
|
||||
ctx,
|
||||
Some("ndarray_foreach"),
|
||||
|generator, ctx| Ok(NDIterHandle::new(generator, ctx, *self)),
|
||||
|generator, ctx, nditer| Ok(nditer.has_next(generator, ctx).value),
|
||||
|generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer),
|
||||
|generator, ctx, nditer| {
|
||||
nditer.next(generator, ctx);
|
||||
Ok(())
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user