From f731e604c78076b27c1e823b10f894043053f0e8 Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 20 Aug 2024 16:47:57 +0800 Subject: [PATCH 1/2] core/ndstrides: add more ScalarOrNDArray and NDArrayObject utils --- nac3core/src/codegen/object/ndarray/mod.rs | 84 +++++++++++++++++++++- 1 file changed, 83 insertions(+), 1 deletion(-) diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs index 5a22a610..655f8e76 100644 --- a/nac3core/src/codegen/object/ndarray/mod.rs +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -18,7 +18,7 @@ use crate::{ CodeGenContext, CodeGenerator, }, toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys}, - typecheck::typedef::Type, + typecheck::typedef::{Type, TypeEnum}, }; pub mod array; @@ -483,6 +483,22 @@ impl<'ctx> NDArrayObject<'ctx> { TupleObject::from_objects(generator, ctx, objects) } + + /// Create an unsized ndarray to contain `object`. + pub fn make_unsized( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + object: AnyObject<'ctx>, + ) -> NDArrayObject<'ctx> { + // We have to put the value on the stack to get a data pointer. + let data = ctx.builder.build_alloca(object.value.get_type(), "make_unsized").unwrap(); + ctx.builder.build_store(data, object.value).unwrap(); + let data = Ptr(Int(Byte)).pointer_cast(generator, ctx, data); + + let ndarray = NDArrayObject::alloca(generator, ctx, object.ty, 0); + ndarray.instance.set(ctx, |f| f.data, data); + ndarray + } } /// A convenience enum for implementing functions that acts on scalars or ndarrays or both. @@ -492,7 +508,50 @@ pub enum ScalarOrNDArray<'ctx> { NDArray(NDArrayObject<'ctx>), } +impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for AnyObject<'ctx> { + type Error = (); + + fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result { + match value { + ScalarOrNDArray::Scalar(scalar) => Ok(*scalar), + ScalarOrNDArray::NDArray(_ndarray) => Err(()), + } + } +} + +impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for NDArrayObject<'ctx> { + type Error = (); + + fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result { + match value { + ScalarOrNDArray::Scalar(_scalar) => Err(()), + ScalarOrNDArray::NDArray(ndarray) => Ok(*ndarray), + } + } +} + impl<'ctx> ScalarOrNDArray<'ctx> { + /// Split on `object` either into a scalar or an ndarray. + /// + /// If `object` is an ndarray, [`ScalarOrNDArray::NDArray`]. + /// + /// For everything else, it is wrapped with [`ScalarOrNDArray::Scalar`]. + pub fn split_object( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + object: AnyObject<'ctx>, + ) -> ScalarOrNDArray<'ctx> { + match &*ctx.unifier.get_ty(object.ty) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => + { + let ndarray = NDArrayObject::from_object(generator, ctx, object); + ScalarOrNDArray::NDArray(ndarray) + } + _ => ScalarOrNDArray::Scalar(object), + } + } + /// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`]. #[must_use] pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> { @@ -501,4 +560,27 @@ impl<'ctx> ScalarOrNDArray<'ctx> { ScalarOrNDArray::NDArray(ndarray) => ndarray.instance.value.as_basic_value_enum(), } } + + /// Convert this [`ScalarOrNDArray`] to an ndarray - behaves like `np.asarray`. + /// - If this is an ndarray, the ndarray is returned. + /// - If this is a scalar, this function returns new ndarray created with [`NDArrayObject::make_unsized`]. + pub fn to_ndarray( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> NDArrayObject<'ctx> { + match self { + ScalarOrNDArray::NDArray(ndarray) => *ndarray, + ScalarOrNDArray::Scalar(scalar) => NDArrayObject::make_unsized(generator, ctx, *scalar), + } + } + + /// Get the dtype of the ndarray created if this were called with [`ScalarOrNDArray::to_ndarray`]. + #[must_use] + pub fn get_dtype(&self) -> Type { + match self { + ScalarOrNDArray::NDArray(ndarray) => ndarray.dtype, + ScalarOrNDArray::Scalar(scalar) => scalar.ty, + } + } } -- 2.44.2 From 5bed394ef7bee4622e8520f2340397ee07c11b3c Mon Sep 17 00:00:00 2001 From: lyken Date: Wed, 21 Aug 2024 12:35:22 +0800 Subject: [PATCH 2/2] core/ndstrides: implement subscript assignment Overlapping is not handled. Currently it has undefined behavior. --- nac3core/src/codegen/stmt.rs | 48 +++++++++++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index cfc188c8..8427cab0 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -17,6 +17,12 @@ use super::{ gen_in_range_check, irrt::{handle_slice_indices, list_slice_assignment}, macros::codegen_unreachable, + object::{ + any::AnyObject, + ndarray::{ + indexing::util::gen_ndarray_subscript_ndindices, NDArrayObject, ScalarOrNDArray, + }, + }, CodeGenContext, CodeGenerator, }; use crate::{ @@ -411,7 +417,47 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>( if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => { // Handle NDArray item assignment - todo!("ndarray subscript assignment is not yet implemented"); + // Process target + let target = generator + .gen_expr(ctx, target)? + .unwrap() + .to_basic_value_enum(ctx, generator, target_ty)?; + let target = AnyObject { value: target, ty: target_ty }; + + // Process key + let key = gen_ndarray_subscript_ndindices(generator, ctx, key)?; + + // Process value + let value = value.to_basic_value_enum(ctx, generator, value_ty)?; + let value = AnyObject { value, ty: value_ty }; + + /* + Reference code: + ```python + target = target[key] + value = np.asarray(value) + + shape = np.broadcast_shape((target, value)) + + target = np.broadcast_to(target, shape) + value = np.broadcast_to(value, shape) + + ...and finally copy 1-1 from value to target. + ``` + */ + + let target = NDArrayObject::from_object(generator, ctx, target); + let target = target.index(generator, ctx, &key); + + let value = + ScalarOrNDArray::split_object(generator, ctx, value).to_ndarray(generator, ctx); + + let broadcast_result = NDArrayObject::broadcast(generator, ctx, &[target, value]); + + let target = broadcast_result.ndarrays[0]; + let value = broadcast_result.ndarrays[1]; + + target.copy_data_from(generator, ctx, value); } _ => { panic!("encountered unknown target type: {}", ctx.unifier.stringify(target_ty)); -- 2.44.2