From c6cab4f35dd0cb3d5baffa1b2eaae1a5010a0128 Mon Sep 17 00:00:00 2001 From: lyken Date: Wed, 21 Aug 2024 12:35:22 +0800 Subject: [PATCH] core/ndstrides: implement ndarray subscript assignment --- nac3core/src/codegen/stmt.rs | 48 +++++++++++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 081a5cef..7e053c06 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -2,6 +2,12 @@ use super::{ super::symbol_resolver::ValueEnum, expr::destructure_range, irrt::{handle_slice_indices, list_slice_assignment}, + object::{ + any::AnyObject, + ndarray::{ + indexing::util::gen_ndarray_subscript_ndindices, NDArrayObject, ScalarOrNDArray, + }, + }, CodeGenContext, CodeGenerator, }; use crate::{ @@ -401,7 +407,47 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>( if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => { // Handle NDArray item assignment - todo!("ndarray subscript assignment is not yet implemented"); + // Process target + let target = generator + .gen_expr(ctx, target)? + .unwrap() + .to_basic_value_enum(ctx, generator, target_ty)?; + let target = AnyObject { value: target, ty: target_ty }; + + // Process key + let key = gen_ndarray_subscript_ndindices(generator, ctx, key)?; + + // Process value + let value = value.to_basic_value_enum(ctx, generator, value_ty)?; + let value = AnyObject { value, ty: value_ty }; + + /* + Reference code: + ```python + target = target[key] + value = np.asarray(value) + + shape = np.broadcast_shape((target, value)) + + target = np.broadcast_to(target, shape) + value = np.broadcast_to(value, shape) + + ...and finally copy 1-1 from value to target. + ``` + */ + + let target = NDArrayObject::from_object(generator, ctx, target); + let target = target.index(generator, ctx, &key); + + let value = + ScalarOrNDArray::split_object(generator, ctx, value).to_ndarray(generator, ctx); + + let broadcast_result = NDArrayObject::broadcast(generator, ctx, &[target, value]); + + let target = broadcast_result.ndarrays[0]; + let value = broadcast_result.ndarrays[1]; + + target.copy_data_from(generator, ctx, value); } _ => { panic!("encountered unknown target type: {}", ctx.unifier.stringify(target_ty));