forked from M-Labs/nac3
core/ndstrides: add NDArray definition with strides
This commit is contained in:
parent
9c3a10377f
commit
92b97a9f4f
|
@ -0,0 +1,44 @@
|
|||
#pragma once
|
||||
|
||||
namespace {
|
||||
/**
|
||||
* @brief The NDArray object
|
||||
*
|
||||
* The official numpy implementations: https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst
|
||||
*/
|
||||
template <typename SizeT>
|
||||
struct NDArray {
|
||||
/**
|
||||
* @brief The underlying data this `ndarray` is pointing to.
|
||||
*
|
||||
* Must be set to `nullptr` to indicate that this NDArray's `data` is uninitialized.
|
||||
*/
|
||||
uint8_t* data;
|
||||
|
||||
/**
|
||||
* @brief The number of bytes of a single element in `data`.
|
||||
*/
|
||||
SizeT itemsize;
|
||||
|
||||
/**
|
||||
* @brief The number of dimensions of this shape.
|
||||
*/
|
||||
SizeT ndims;
|
||||
|
||||
/**
|
||||
* @brief The NDArray shape, with length equal to `ndims`.
|
||||
*
|
||||
* Note that it may contain 0.
|
||||
*/
|
||||
SizeT* shape;
|
||||
|
||||
/**
|
||||
* @brief Array strides, with length equal to `ndims`
|
||||
*
|
||||
* The stride values are in units of bytes, not number of elements.
|
||||
*
|
||||
* Note that `strides` can have negative values.
|
||||
*/
|
||||
SizeT* strides;
|
||||
};
|
||||
} // namespace
|
|
@ -4,4 +4,5 @@
|
|||
#include <irrt/core.hpp>
|
||||
#include <irrt/error_context.hpp>
|
||||
#include <irrt/int_defs.hpp>
|
||||
#include <irrt/ndarray/def.hpp>
|
||||
#include <irrt/utils.hpp>
|
|
@ -1,2 +1,3 @@
|
|||
pub mod cslice;
|
||||
pub mod exception;
|
||||
pub mod ndarray;
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
use crate::codegen::*;
|
||||
|
||||
pub struct NpArrayFields<'ctx> {
|
||||
pub data: Field<PointerModel<ByteModel>>,
|
||||
pub itemsize: Field<SizeTModel<'ctx>>,
|
||||
pub ndims: Field<SizeTModel<'ctx>>,
|
||||
pub shape: Field<PointerModel<SizeTModel<'ctx>>>,
|
||||
pub strides: Field<PointerModel<SizeTModel<'ctx>>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct NpArray<'ctx> {
|
||||
pub sizet: SizeTModel<'ctx>,
|
||||
}
|
||||
|
||||
impl<'ctx> StructKind<'ctx> for NpArray<'ctx> {
|
||||
type Fields = NpArrayFields<'ctx>;
|
||||
|
||||
fn struct_name(&self) -> &'static str {
|
||||
"NDArray"
|
||||
}
|
||||
|
||||
fn build_fields(&self, builder: &mut FieldBuilder<'ctx>) -> Self::Fields {
|
||||
NpArrayFields {
|
||||
data: builder.add_field_auto("data"),
|
||||
itemsize: builder.add_field("itemsize", self.sizet),
|
||||
ndims: builder.add_field("ndims", self.sizet),
|
||||
shape: builder.add_field("shape", PointerModel(self.sizet)),
|
||||
strides: builder.add_field("strides", PointerModel(self.sizet)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> Pointer<'ctx, StructModel<NpArray<'ctx>>> {
|
||||
/// Get an [`ArraySlice`] of [`NpArrayFields::shape`] with [`NpArrayFields::ndims`] as its length.
|
||||
pub fn shape_slice(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
) -> ArraySlice<'ctx, SizeTModel<'ctx>, SizeTModel<'ctx>> {
|
||||
let ndims = self.gep(ctx, |f| f.ndims).load(ctx, "ndims");
|
||||
let shape_base_ptr = self.gep(ctx, |f| f.shape).load(ctx, "shape");
|
||||
ArraySlice { num_elements: ndims, pointer: shape_base_ptr }
|
||||
}
|
||||
|
||||
/// Get an [`ArraySlice`] of [`NpArrayFields::strides`] with [`NpArrayFields::ndims`] as its length.
|
||||
pub fn strides_slice(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
) -> ArraySlice<'ctx, SizeTModel<'ctx>, SizeTModel<'ctx>> {
|
||||
let ndims = self.gep(ctx, |f| f.ndims).load(ctx, "ndims");
|
||||
let strides_base_ptr = self.gep(ctx, |f| f.strides).load(ctx, "strides");
|
||||
ArraySlice { num_elements: ndims, pointer: strides_base_ptr }
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue