diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index d91f734d6..6c8c9aabb 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -12,13 +12,16 @@ use super::{ TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }; -use crate::codegen::{ - irrt, - llvm_intrinsics::{call_int_umin, call_memcpy_generic_array}, - stmt::gen_for_callback_incrementing, - type_aligned_alloca, - types::{ndarray::NDArrayType, structure::StructField, TupleType}, - CodeGenContext, CodeGenerator, +use crate::{ + codegen::{ + irrt, + llvm_intrinsics::{call_int_umin, call_memcpy_generic_array}, + stmt::gen_for_callback_incrementing, + type_aligned_alloca, + types::{ndarray::NDArrayType, structure::StructField, TupleType}, + CodeGenContext, CodeGenerator, + }, + typecheck::typedef::{Type, TypeEnum}, }; pub use broadcast::*; pub use contiguous::*; @@ -501,22 +504,38 @@ impl<'ctx> NDArrayValue<'ctx> { self.ndims.map(|ndims| ndims == 0) } - /// If this ndarray is unsized, return its sole value as an [`BasicValueEnum`]. - /// Otherwise, do nothing and return the ndarray itself. - // TODO: Rename to get_unsized_element - pub fn split_unsized( + /// Returns the element present in this `ndarray` if this is unsized. + pub fn get_unsized_element( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - ) -> ScalarOrNDArray<'ctx> { - let Some(is_unsized) = self.is_unsized() else { todo!() }; + ) -> Option> { + let Some(is_unsized) = self.is_unsized() else { + panic!("NDArrayValue::get_unsized_element can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))"); + }; if is_unsized { // NOTE: `np.size(self) == 0` here is never possible. let zero = generator.get_size_type(ctx.ctx).const_zero(); let value = unsafe { self.data().get_unchecked(ctx, generator, &zero, None) }; - ScalarOrNDArray::Scalar(value) + Some(value) + } else { + None + } + } + + /// If this ndarray is unsized, return its sole value as an [`BasicValueEnum`]. + /// Otherwise, do nothing and return the ndarray itself. + pub fn split_unsized( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> ScalarOrNDArray<'ctx> { + assert!(self.ndims.is_some(), "NDArrayValue::split_unsized can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))"); + + if let Some(unsized_elem) = self.get_unsized_element(generator, ctx) { + ScalarOrNDArray::Scalar(unsized_elem) } else { ScalarOrNDArray::NDArray(*self) } @@ -978,7 +997,52 @@ pub enum ScalarOrNDArray<'ctx> { NDArray(NDArrayValue<'ctx>), } +impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for BasicValueEnum<'ctx> { + type Error = (); + + fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result { + match value { + ScalarOrNDArray::Scalar(scalar) => Ok(*scalar), + ScalarOrNDArray::NDArray(_) => Err(()), + } + } +} + +impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for NDArrayValue<'ctx> { + type Error = (); + + fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result { + match value { + ScalarOrNDArray::Scalar(_) => Err(()), + ScalarOrNDArray::NDArray(ndarray) => Ok(*ndarray), + } + } +} + impl<'ctx> ScalarOrNDArray<'ctx> { + /// Split on `object` either into a scalar or an ndarray. + /// + /// If `object` is an ndarray, [`ScalarOrNDArray::NDArray`]. + /// + /// For everything else, it is wrapped with [`ScalarOrNDArray::Scalar`]. + pub fn from_value( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + (object_ty, object): (Type, BasicValueEnum<'ctx>), + ) -> ScalarOrNDArray<'ctx> { + match &*ctx.unifier.get_ty(object_ty) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => + { + let ndarray = NDArrayType::from_unifier_type(generator, ctx, object_ty) + .map_value(object.into_pointer_value(), None); + ScalarOrNDArray::NDArray(ndarray) + } + + _ => ScalarOrNDArray::Scalar(object), + } + } + /// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`]. #[must_use] pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> { @@ -987,4 +1051,33 @@ impl<'ctx> ScalarOrNDArray<'ctx> { ScalarOrNDArray::NDArray(ndarray) => ndarray.as_base_value().into(), } } + + /// Convert this [`ScalarOrNDArray`] to an ndarray - behaves like `np.asarray`. + /// + /// - If this is an ndarray, the ndarray is returned. + /// - If this is a scalar, this function returns new ndarray created with + /// [`NDArrayType::construct_unsized`]. + pub fn to_ndarray( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> NDArrayValue<'ctx> { + match self { + ScalarOrNDArray::NDArray(ndarray) => *ndarray, + ScalarOrNDArray::Scalar(scalar) => { + NDArrayType::new_unsized(generator, ctx.ctx, scalar.get_type()) + .construct_unsized(generator, ctx, scalar, None) + } + } + } + + /// Get the dtype of the ndarray created if this were called with + /// [`ScalarOrNDArray::to_ndarray`]. + #[must_use] + pub fn get_dtype(&self) -> BasicTypeEnum<'ctx> { + match self { + ScalarOrNDArray::NDArray(ndarray) => ndarray.dtype, + ScalarOrNDArray::Scalar(scalar) => scalar.get_type(), + } + } }