core/ndstrides: add NDArrayObject::atleast_nd
This commit is contained in:
parent
b6980c3a39
commit
9cfa2622ca
@ -25,6 +25,7 @@ pub mod factory;
|
|||||||
pub mod indexing;
|
pub mod indexing;
|
||||||
pub mod nditer;
|
pub mod nditer;
|
||||||
pub mod shape_util;
|
pub mod shape_util;
|
||||||
|
pub mod view;
|
||||||
|
|
||||||
/// Fields of [`NDArray`]
|
/// Fields of [`NDArray`]
|
||||||
pub struct NDArrayFields<'ctx, F: FieldTraversal<'ctx>> {
|
pub struct NDArrayFields<'ctx, F: FieldTraversal<'ctx>> {
|
||||||
|
28
nac3core/src/codegen/object/ndarray/view.rs
Normal file
28
nac3core/src/codegen/object/ndarray/view.rs
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
use super::{indexing::RustNDIndex, NDArrayObject};
|
||||||
|
use crate::codegen::{CodeGenContext, CodeGenerator};
|
||||||
|
|
||||||
|
impl<'ctx> NDArrayObject<'ctx> {
|
||||||
|
/// Make sure the ndarray is at least `ndmin`-dimensional.
|
||||||
|
///
|
||||||
|
/// If this ndarray's `ndims` is less than `ndmin`, a view is created on this with 1s prepended to the shape.
|
||||||
|
/// If this ndarray's `ndims` is not less than `ndmin`, this function does nothing and return this ndarray.
|
||||||
|
#[must_use]
|
||||||
|
pub fn atleast_nd<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ndmin: u64,
|
||||||
|
) -> Self {
|
||||||
|
if self.ndims < ndmin {
|
||||||
|
// Extend the dimensions with np.newaxis.
|
||||||
|
let mut indices = vec![];
|
||||||
|
for _ in self.ndims..ndmin {
|
||||||
|
indices.push(RustNDIndex::NewAxis);
|
||||||
|
}
|
||||||
|
indices.push(RustNDIndex::Ellipsis);
|
||||||
|
self.index(generator, ctx, &indices)
|
||||||
|
} else {
|
||||||
|
*self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user