diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 449b1a61..a4d4229b 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -12,13 +12,16 @@ use super::{ TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }; -use crate::codegen::{ - irrt, - llvm_intrinsics::{call_int_umin, call_memcpy_generic_array}, - stmt::gen_for_callback_incrementing, - type_aligned_alloca, - types::{ndarray::NDArrayType, structure::StructField, TupleType}, - CodeGenContext, CodeGenerator, +use crate::{ + codegen::{ + irrt, + llvm_intrinsics::{call_int_umin, call_memcpy_generic_array}, + stmt::gen_for_callback_incrementing, + type_aligned_alloca, + types::{ndarray::NDArrayType, structure::StructField, TupleType}, + CodeGenContext, CodeGenerator, + }, + typecheck::typedef::{Type, TypeEnum}, }; pub use broadcast::*; pub use contiguous::*; @@ -976,7 +979,52 @@ pub enum ScalarOrNDArray<'ctx> { NDArray(NDArrayValue<'ctx>), } +impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for BasicValueEnum<'ctx> { + type Error = (); + + fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result { + match value { + ScalarOrNDArray::Scalar(scalar) => Ok(*scalar), + ScalarOrNDArray::NDArray(_) => Err(()), + } + } +} + +impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for NDArrayValue<'ctx> { + type Error = (); + + fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result { + match value { + ScalarOrNDArray::Scalar(_) => Err(()), + ScalarOrNDArray::NDArray(ndarray) => Ok(*ndarray), + } + } +} + impl<'ctx> ScalarOrNDArray<'ctx> { + /// Split on `object` either into a scalar or an ndarray. + /// + /// If `object` is an ndarray, [`ScalarOrNDArray::NDArray`]. + /// + /// For everything else, it is wrapped with [`ScalarOrNDArray::Scalar`]. + pub fn from_value( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + (object_ty, object): (Type, BasicValueEnum<'ctx>), + ) -> ScalarOrNDArray<'ctx> { + match &*ctx.unifier.get_ty(object_ty) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => + { + let ndarray = NDArrayType::from_unifier_type(generator, ctx, object_ty) + .map_value(object.into_pointer_value(), None); + ScalarOrNDArray::NDArray(ndarray) + } + + _ => ScalarOrNDArray::Scalar(object), + } + } + /// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`]. #[must_use] pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> { @@ -985,4 +1033,31 @@ impl<'ctx> ScalarOrNDArray<'ctx> { ScalarOrNDArray::NDArray(ndarray) => ndarray.as_base_value().into(), } } + + /// Convert this [`ScalarOrNDArray`] to an ndarray - behaves like `np.asarray`. + /// + /// - If this is an ndarray, the ndarray is returned. + /// - If this is a scalar, this function returns new ndarray created with [`NDArrayObject::make_unsized`]. + pub fn to_ndarray( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> NDArrayValue<'ctx> { + match self { + ScalarOrNDArray::NDArray(ndarray) => *ndarray, + ScalarOrNDArray::Scalar(scalar) => { + NDArrayType::new_unsized(generator, ctx.ctx, scalar.get_type()) + .construct_unsized(generator, ctx, scalar, None) + } + } + } + + /// Get the dtype of the ndarray created if this were called with [`ScalarOrNDArray::to_ndarray`]. + #[must_use] + pub fn get_dtype(&self) -> BasicTypeEnum<'ctx> { + match self { + ScalarOrNDArray::NDArray(ndarray) => ndarray.dtype, + ScalarOrNDArray::Scalar(scalar) => scalar.get_type(), + } + } }