1
0
forked from M-Labs/nac3

core/ndstrides: add NDArrayObject::atleast_nd

This commit is contained in:
lyken 2024-08-20 15:02:42 +08:00 committed by David Mak
parent b6980c3a39
commit 9cfa2622ca
2 changed files with 29 additions and 0 deletions

View File

@ -25,6 +25,7 @@ pub mod factory;
pub mod indexing; pub mod indexing;
pub mod nditer; pub mod nditer;
pub mod shape_util; pub mod shape_util;
pub mod view;
/// Fields of [`NDArray`] /// Fields of [`NDArray`]
pub struct NDArrayFields<'ctx, F: FieldTraversal<'ctx>> { pub struct NDArrayFields<'ctx, F: FieldTraversal<'ctx>> {

View File

@ -0,0 +1,28 @@
use super::{indexing::RustNDIndex, NDArrayObject};
use crate::codegen::{CodeGenContext, CodeGenerator};
impl<'ctx> NDArrayObject<'ctx> {
/// Make sure the ndarray is at least `ndmin`-dimensional.
///
/// If this ndarray's `ndims` is less than `ndmin`, a view is created on this with 1s prepended to the shape.
/// If this ndarray's `ndims` is not less than `ndmin`, this function does nothing and return this ndarray.
#[must_use]
pub fn atleast_nd<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndmin: u64,
) -> Self {
if self.ndims < ndmin {
// Extend the dimensions with np.newaxis.
let mut indices = vec![];
for _ in self.ndims..ndmin {
indices.push(RustNDIndex::NewAxis);
}
indices.push(RustNDIndex::Ellipsis);
self.index(generator, ctx, &indices)
} else {
*self
}
}
}