From 061747c67b5e032f31006798447b003ed816b914 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 12 Dec 2024 11:14:48 +0800 Subject: [PATCH] [core] codegen: Implement NDArrayValue::atleast_nd Based on 9cfa2622: core/ndstrides: add NDArrayObject::atleast_nd. --- nac3core/src/codegen/values/ndarray/mod.rs | 2 ++ nac3core/src/codegen/values/ndarray/view.rs | 36 +++++++++++++++++++++ 2 files changed, 38 insertions(+) create mode 100644 nac3core/src/codegen/values/ndarray/view.rs diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index b6d86de5..12fd8634 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -19,10 +19,12 @@ use crate::codegen::{ pub use contiguous::*; pub use indexing::*; pub use nditer::*; +pub use view::*; mod contiguous; mod indexing; mod nditer; +mod view; /// Proxy type for accessing an `NDArray` value in LLVM. #[derive(Copy, Clone)] diff --git a/nac3core/src/codegen/values/ndarray/view.rs b/nac3core/src/codegen/values/ndarray/view.rs new file mode 100644 index 00000000..70a9d659 --- /dev/null +++ b/nac3core/src/codegen/values/ndarray/view.rs @@ -0,0 +1,36 @@ +use std::iter::{once, repeat_n}; + +use itertools::Itertools; + +use crate::codegen::{ + values::ndarray::{NDArrayValue, RustNDIndex}, + CodeGenContext, CodeGenerator, +}; + +impl<'ctx> NDArrayValue<'ctx> { + /// Make sure the ndarray is at least `ndmin`-dimensional. + /// + /// If this ndarray's `ndims` is less than `ndmin`, a view is created on this with 1s prepended + /// to the shape. Otherwise, this function does nothing and return this ndarray. + #[must_use] + pub fn atleast_nd( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndmin: u64, + ) -> Self { + assert!(self.ndims.is_some(), "NDArrayValue::atleast_nd is only supported for instances with compile-time known ndims (self.ndims = Some(...))"); + + let ndims = self.ndims.unwrap(); + + if ndims < ndmin { + // Extend the dimensions with np.newaxis. + let indices = repeat_n(RustNDIndex::NewAxis, (ndmin - ndims) as usize) + .chain(once(RustNDIndex::Ellipsis)) + .collect_vec(); + self.index(generator, ctx, &indices) + } else { + *self + } + } +}