forked from M-Labs/nac3
1
0
Fork 0

core/ndstrides: add NDArray with strides definition

This commit is contained in:
lyken 2024-07-28 14:23:31 +08:00
parent 7502b14d55
commit fd3d02bff0
4 changed files with 73 additions and 0 deletions

View File

@ -0,0 +1,44 @@
#pragma once
namespace {
/**
* @brief The NDArray object
*
* The official numpy implementations: https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst
*/
template <typename SizeT>
struct NDArray {
/**
* @brief The underlying data this `ndarray` is pointing to.
*
* Must be set to `nullptr` to indicate that this NDArray's `data` is uninitialized.
*/
uint8_t* data;
/**
* @brief The number of bytes of a single element in `data`.
*/
SizeT itemsize;
/**
* @brief The number of dimensions of this shape.
*/
SizeT ndims;
/**
* @brief The NDArray shape, with length equal to `ndims`.
*
* Note that it may contain 0.
*/
SizeT* shape;
/**
* @brief Array strides, with length equal to `ndims`
*
* The stride values are in units of bytes, not number of elements.
*
* Note that `strides` can have negative values.
*/
SizeT* strides;
};
} // namespace

View File

@ -4,4 +4,5 @@
#include <irrt/core.hpp>
#include <irrt/error_context.hpp>
#include <irrt/int_defs.hpp>
#include <irrt/ndarray/def.hpp>
#include <irrt/utils.hpp>

View File

@ -1,2 +1,3 @@
pub mod cslice;
pub mod exception;
pub mod ndarray;

View File

@ -0,0 +1,27 @@
use crate::codegen::*;
pub struct NpArrayFields<F: FieldVisitor> {
pub data: F::Field<PtrModel<IntModel<Byte>>>,
pub itemsize: F::Field<IntModel<SizeT>>,
pub ndims: F::Field<IntModel<SizeT>>,
pub shape: F::Field<PtrModel<IntModel<SizeT>>>,
pub strides: F::Field<PtrModel<IntModel<SizeT>>>,
}
// TODO: Rename to `NDArray` when the old NDArray is removed.
#[derive(Debug, Clone, Copy, Default)]
pub struct NpArray;
impl StructKind for NpArray {
type Fields<F: FieldVisitor> = NpArrayFields<F>;
fn visit_fields<F: FieldVisitor>(&self, visitor: &mut F) -> Self::Fields<F> {
Self::Fields {
data: visitor.add("data"),
itemsize: visitor.add("itemsize"),
ndims: visitor.add("ndims"),
shape: visitor.add("shape"),
strides: visitor.add("strides"),
}
}
}