From 9c5273ae09d21948fe43ac136743d815e0597133 Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 20 Aug 2024 15:02:42 +0800 Subject: [PATCH] core/ndstrides: add NDArrayObject::atleast_nd --- nac3core/src/codegen/object/ndarray/mod.rs | 1 + nac3core/src/codegen/object/ndarray/view.rs | 29 +++++++++++++++++++++ 2 files changed, 30 insertions(+) create mode 100644 nac3core/src/codegen/object/ndarray/view.rs diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs index 87fccae..ae44b07 100644 --- a/nac3core/src/codegen/object/ndarray/mod.rs +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -2,6 +2,7 @@ pub mod factory; pub mod indexing; pub mod nditer; pub mod shape_util; +pub mod view; use inkwell::{ context::Context, diff --git a/nac3core/src/codegen/object/ndarray/view.rs b/nac3core/src/codegen/object/ndarray/view.rs new file mode 100644 index 0000000..8776d94 --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/view.rs @@ -0,0 +1,29 @@ +use crate::codegen::{CodeGenContext, CodeGenerator}; + +use super::{indexing::RustNDIndex, NDArrayObject}; + +impl<'ctx> NDArrayObject<'ctx> { + /// Make sure the ndarray is at least `ndmin`-dimensional. + /// + /// If this ndarray's `ndims` is less than `ndmin`, a view is created on this with 1s prepended to the shape. + /// If this ndarray's `ndims` is not less than `ndmin`, this function does nothing and return this ndarray. + #[must_use] + pub fn atleast_nd( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndmin: u64, + ) -> Self { + if self.ndims < ndmin { + // Extend the dimensions with np.newaxis. + let mut indices = vec![]; + for _ in self.ndims..ndmin { + indices.push(RustNDIndex::NewAxis); + } + indices.push(RustNDIndex::Ellipsis); + self.index(generator, ctx, &indices) + } else { + *self + } + } +}