1
0
forked from M-Labs/nac3

core/ndstrides: add NDArrayObject::atleast_nd

This commit is contained in:
lyken 2024-08-20 15:02:42 +08:00
parent 3d734aef17
commit 9c5273ae09
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
2 changed files with 30 additions and 0 deletions

View File

@ -2,6 +2,7 @@ pub mod factory;
pub mod indexing;
pub mod nditer;
pub mod shape_util;
pub mod view;
use inkwell::{
context::Context,

View File

@ -0,0 +1,29 @@
use crate::codegen::{CodeGenContext, CodeGenerator};
use super::{indexing::RustNDIndex, NDArrayObject};
impl<'ctx> NDArrayObject<'ctx> {
/// Make sure the ndarray is at least `ndmin`-dimensional.
///
/// If this ndarray's `ndims` is less than `ndmin`, a view is created on this with 1s prepended to the shape.
/// If this ndarray's `ndims` is not less than `ndmin`, this function does nothing and return this ndarray.
#[must_use]
pub fn atleast_nd<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndmin: u64,
) -> Self {
if self.ndims < ndmin {
// Extend the dimensions with np.newaxis.
let mut indices = vec![];
for _ in self.ndims..ndmin {
indices.push(RustNDIndex::NewAxis);
}
indices.push(RustNDIndex::Ellipsis);
self.index(generator, ctx, &indices)
} else {
*self
}
}
}