[core] codegen: Implement NDArrayValue::atleast_nd

Based on 9cfa2622: core/ndstrides: add NDArrayObject::atleast_nd.
This commit is contained in:
David Mak 2024-12-12 11:14:48 +08:00
parent dc91d9e35a
commit 061747c67b
2 changed files with 38 additions and 0 deletions

View File

@ -19,10 +19,12 @@ use crate::codegen::{
pub use contiguous::*; pub use contiguous::*;
pub use indexing::*; pub use indexing::*;
pub use nditer::*; pub use nditer::*;
pub use view::*;
mod contiguous; mod contiguous;
mod indexing; mod indexing;
mod nditer; mod nditer;
mod view;
/// Proxy type for accessing an `NDArray` value in LLVM. /// Proxy type for accessing an `NDArray` value in LLVM.
#[derive(Copy, Clone)] #[derive(Copy, Clone)]

View File

@ -0,0 +1,36 @@
use std::iter::{once, repeat_n};
use itertools::Itertools;
use crate::codegen::{
values::ndarray::{NDArrayValue, RustNDIndex},
CodeGenContext, CodeGenerator,
};
impl<'ctx> NDArrayValue<'ctx> {
/// Make sure the ndarray is at least `ndmin`-dimensional.
///
/// If this ndarray's `ndims` is less than `ndmin`, a view is created on this with 1s prepended
/// to the shape. Otherwise, this function does nothing and return this ndarray.
#[must_use]
pub fn atleast_nd<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndmin: u64,
) -> Self {
assert!(self.ndims.is_some(), "NDArrayValue::atleast_nd is only supported for instances with compile-time known ndims (self.ndims = Some(...))");
let ndims = self.ndims.unwrap();
if ndims < ndmin {
// Extend the dimensions with np.newaxis.
let indices = repeat_n(RustNDIndex::NewAxis, (ndmin - ndims) as usize)
.chain(once(RustNDIndex::Ellipsis))
.collect_vec();
self.index(generator, ctx, &indices)
} else {
*self
}
}
}