From 48ce2d6c8a860deafe757ada61d2991fe9d97e30 Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 20 Aug 2024 16:47:57 +0800 Subject: [PATCH] core/ndstrides: add more ScalarOrNDArray and NDArrayObject utils --- nac3core/src/codegen/object/ndarray/mod.rs | 84 +++++++++++++++++++++- 1 file changed, 83 insertions(+), 1 deletion(-) diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs index 1310540c..f88828b0 100644 --- a/nac3core/src/codegen/object/ndarray/mod.rs +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -25,7 +25,7 @@ use crate::{ CodeGenContext, CodeGenerator, }, toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys}, - typecheck::typedef::Type, + typecheck::typedef::{Type, TypeEnum}, }; use super::{any::AnyObject, tuple::TupleObject}; @@ -482,6 +482,22 @@ impl<'ctx> NDArrayObject<'ctx> { TupleObject::from_objects(generator, ctx, objects) } + + /// Create an unsized ndarray to contain `object`. + pub fn make_unsized( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + object: AnyObject<'ctx>, + ) -> NDArrayObject<'ctx> { + // We have to put the value on the stack to get a data pointer. + let data = ctx.builder.build_alloca(object.value.get_type(), "make_unsized").unwrap(); + ctx.builder.build_store(data, object.value).unwrap(); + let data = Ptr(Int(Byte)).pointer_cast(generator, ctx, data); + + let ndarray = NDArrayObject::alloca(generator, ctx, object.ty, 0); + ndarray.instance.set(ctx, |f| f.data, data); + ndarray + } } /// A convenience enum for implementing functions that acts on scalars or ndarrays or both. @@ -491,7 +507,50 @@ pub enum ScalarOrNDArray<'ctx> { NDArray(NDArrayObject<'ctx>), } +impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for AnyObject<'ctx> { + type Error = (); + + fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result { + match value { + ScalarOrNDArray::Scalar(scalar) => Ok(*scalar), + ScalarOrNDArray::NDArray(_ndarray) => Err(()), + } + } +} + +impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for NDArrayObject<'ctx> { + type Error = (); + + fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result { + match value { + ScalarOrNDArray::Scalar(_scalar) => Err(()), + ScalarOrNDArray::NDArray(ndarray) => Ok(*ndarray), + } + } +} + impl<'ctx> ScalarOrNDArray<'ctx> { + /// Split on `object` either into a scalar or an ndarray. + /// + /// If `object` is an ndarray, [`ScalarOrNDArray::NDArray`]. + /// + /// For everything else, it is wrapped with [`ScalarOrNDArray::Scalar`]. + pub fn split_object( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + object: AnyObject<'ctx>, + ) -> ScalarOrNDArray<'ctx> { + match &*ctx.unifier.get_ty(object.ty) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => + { + let ndarray = NDArrayObject::from_object(generator, ctx, object); + ScalarOrNDArray::NDArray(ndarray) + } + _ => ScalarOrNDArray::Scalar(object), + } + } + /// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`]. #[must_use] pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> { @@ -500,4 +559,27 @@ impl<'ctx> ScalarOrNDArray<'ctx> { ScalarOrNDArray::NDArray(ndarray) => ndarray.instance.value.as_basic_value_enum(), } } + + /// Convert this [`ScalarOrNDArray`] to an ndarray - behaves like `np.asarray`. + /// - If this is an ndarray, the ndarray is returned. + /// - If this is a scalar, this function returns new ndarray created with [`NDArrayObject::make_unsized`]. + pub fn to_ndarray( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> NDArrayObject<'ctx> { + match self { + ScalarOrNDArray::NDArray(ndarray) => *ndarray, + ScalarOrNDArray::Scalar(scalar) => NDArrayObject::make_unsized(generator, ctx, *scalar), + } + } + + /// Get the dtype of the ndarray created if this were called with [`ScalarOrNDArray::to_ndarray`]. + #[must_use] + pub fn get_dtype(&self) -> Type { + match self { + ScalarOrNDArray::NDArray(ndarray) => ndarray.dtype, + ScalarOrNDArray::Scalar(scalar) => scalar.ty, + } + } }