From dc91d9e35a52bff3b88789d14916ba037fd1fbd7 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 10 Dec 2024 16:43:57 +0800 Subject: [PATCH] [core] codegen: Implement ScalarOrNDArray and use it in indexing Based on 8f9d2d82: core/ndstrides: implement ndarray indexing. --- nac3core/src/codegen/expr.rs | 25 ++++++------ nac3core/src/codegen/values/ndarray/mod.rs | 45 ++++++++++++++++++++++ 2 files changed, 59 insertions(+), 11 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 158dfe96..0118ca43 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -34,7 +34,8 @@ use super::{ }, types::{ndarray::NDArrayType, ListType}, values::{ - ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue, + ndarray::{NDArrayValue, RustNDIndex}, + ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor, }, CodeGenContext, CodeGenTask, CodeGenerator, @@ -3486,19 +3487,21 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } } TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - let (ty, ndims) = - unpack_ndarray_var_tys(&mut ctx.unifier, value.custom.unwrap()); - - let v = if let Some(v) = generator.gen_expr(ctx, value)? { - v.to_basic_value_enum(ctx, generator, value.custom.unwrap())? - .into_pointer_value() - } else { + let Some(ndarray) = generator.gen_expr(ctx, value)? else { return Ok(None); }; - let v = NDArrayType::from_unifier_type(generator, ctx, value.custom.unwrap()) - .map_value(v, None); - return gen_ndarray_subscript_expr(generator, ctx, ty, ndims, v, slice); + let ndarray_ty = value.custom.unwrap(); + let ndarray = ndarray.to_basic_value_enum(ctx, generator, ndarray_ty)?; + let ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty) + .map_value(ndarray.into_pointer_value(), None); + + let indices = RustNDIndex::from_subscript_expr(generator, ctx, slice)?; + let result = ndarray + .index(generator, ctx, &indices) + .split_unsized(generator, ctx) + .to_basic_value_enum(); + return Ok(Some(ValueEnum::Dynamic(result))); } TypeEnum::TTuple { .. } => { let index: u32 = diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index fdf11dd2..b6d86de5 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -403,6 +403,33 @@ impl<'ctx> NDArrayValue<'ctx> { assert_eq!(self.dtype, src.dtype, "self and src dtype should match"); irrt::ndarray::call_nac3_ndarray_copy_data(generator, ctx, src, *self); } + + /// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar. + #[must_use] + pub fn is_unsized(&self) -> Option { + self.ndims.map(|ndims| ndims == 0) + } + + /// If this ndarray is unsized, return its sole value as an [`AnyObject`]. + /// Otherwise, do nothing and return the ndarray itself. + // TODO: Rename to get_unsized_element + pub fn split_unsized( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> ScalarOrNDArray<'ctx> { + let Some(is_unsized) = self.is_unsized() else { todo!() }; + + if is_unsized { + // NOTE: `np.size(self) == 0` here is never possible. + let zero = generator.get_size_type(ctx.ctx).const_zero(); + let value = unsafe { self.data().get_unchecked(ctx, generator, &zero, None) }; + + ScalarOrNDArray::Scalar(value) + } else { + ScalarOrNDArray::NDArray(*self) + } + } } impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> { @@ -884,3 +911,21 @@ pub fn make_contiguous_strides(itemsize: u64, ndims: u64, shape: &[u64]) -> Vec< } strides } + +/// A convenience enum for implementing functions that acts on scalars or ndarrays or both. +#[derive(Clone, Copy)] +pub enum ScalarOrNDArray<'ctx> { + Scalar(BasicValueEnum<'ctx>), + NDArray(NDArrayValue<'ctx>), +} + +impl<'ctx> ScalarOrNDArray<'ctx> { + /// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`]. + #[must_use] + pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> { + match self { + ScalarOrNDArray::Scalar(scalar) => scalar, + ScalarOrNDArray::NDArray(ndarray) => ndarray.as_base_value().into(), + } + } +}