diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs index 3d0b26f4..a888ea70 100644 --- a/nac3core/src/codegen/object/ndarray/mod.rs +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -25,6 +25,7 @@ pub mod factory; pub mod indexing; pub mod nditer; pub mod shape_util; +pub mod view; /// Fields of [`NDArray`] pub struct NDArrayFields<'ctx, F: FieldTraversal<'ctx>> { diff --git a/nac3core/src/codegen/object/ndarray/view.rs b/nac3core/src/codegen/object/ndarray/view.rs new file mode 100644 index 00000000..8afd5b4c --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/view.rs @@ -0,0 +1,28 @@ +use super::{indexing::RustNDIndex, NDArrayObject}; +use crate::codegen::{CodeGenContext, CodeGenerator}; + +impl<'ctx> NDArrayObject<'ctx> { + /// Make sure the ndarray is at least `ndmin`-dimensional. + /// + /// If this ndarray's `ndims` is less than `ndmin`, a view is created on this with 1s prepended to the shape. + /// If this ndarray's `ndims` is not less than `ndmin`, this function does nothing and return this ndarray. + #[must_use] + pub fn atleast_nd( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndmin: u64, + ) -> Self { + if self.ndims < ndmin { + // Extend the dimensions with np.newaxis. + let mut indices = vec![]; + for _ in self.ndims..ndmin { + indices.push(RustNDIndex::NewAxis); + } + indices.push(RustNDIndex::Ellipsis); + self.index(generator, ctx, &indices) + } else { + *self + } + } +}