diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 3595528..edebb4f 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -16,7 +16,11 @@ use super::{ gen_in_range_check, irrt::{handle_slice_indices, list_slice_assignment}, macros::codegen_unreachable, - values::{ArrayLikeIndexer, ArraySliceValue, ListValue, RangeValue}, + types::ndarray::NDArrayType, + values::{ + ndarray::{RustNDIndex, ScalarOrNDArray}, + ArrayLikeIndexer, ArraySliceValue, ListValue, ProxyValue, RangeValue, + }, CodeGenContext, CodeGenerator, }; use crate::{ @@ -411,7 +415,54 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>( if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => { // Handle NDArray item assignment - todo!("ndarray subscript assignment is not yet implemented"); + // Process target + let target = generator + .gen_expr(ctx, target)? + .unwrap() + .to_basic_value_enum(ctx, generator, target_ty)?; + + // Process key + let key = RustNDIndex::from_subscript_expr(generator, ctx, key)?; + + // Process value + let value = value.to_basic_value_enum(ctx, generator, value_ty)?; + + // Reference code: + // ```python + // target = target[key] + // value = np.asarray(value) + // + // shape = np.broadcast_shape((target, value)) + // + // target = np.broadcast_to(target, shape) + // value = np.broadcast_to(value, shape) + // + // # ...and finally copy 1-1 from value to target. + // ``` + + let target = NDArrayType::from_unifier_type(generator, ctx, target_ty) + .map_value(target.into_pointer_value(), None); + let target = target.index(generator, ctx, &key); + + let value = ScalarOrNDArray::from_value(generator, ctx, (value_ty, value)) + .to_ndarray(generator, ctx); + + let broadcast_ndims = [target.get_type().ndims(), value.get_type().ndims()] + .iter() + .filter_map(|ndims| *ndims) + .max(); + let broadcast_result = NDArrayType::new( + generator, + ctx.ctx, + value.get_type().element_type(), + broadcast_ndims, + ) + .broadcast(generator, ctx, &[target, value]); + + let target = broadcast_result.ndarrays[0]; + let value = broadcast_result.ndarrays[1]; + + target.copy_data_from(generator, ctx, value); } _ => { panic!("encountered unknown target type: {}", ctx.unifier.stringify(target_ty)); diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 170ac14..b668860 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -276,6 +276,38 @@ def test_ndarray_broadcast_to(): output_int32(np_shape(zs)[2]) output_ndarray_float_3(zs) +def test_ndarray_subscript_assignment(): + xs = np_array([[11.0, 22.0, 33.0, 44.0], [55.0, 66.0, 77.0, 88.0]]) + + xs[0, 0] = 99.0 + output_ndarray_float_2(xs) + + xs[0] = 100.0 + output_ndarray_float_2(xs) + + xs[:, ::2] = 101.0 + output_ndarray_float_2(xs) + + xs[1:, 0] = 102.0 + output_ndarray_float_2(xs) + + xs[0] = np_array([-1.0, -2.0, -3.0, -4.0]) + output_ndarray_float_2(xs) + + xs[:] = np_array([-5.0, -6.0, -7.0, -8.0]) + output_ndarray_float_2(xs) + + # Test assignment with memory sharing + ys1 = np_reshape(xs, (2, 4)) + ys2 = np_transpose(ys1) + ys3 = ys2[::-1, 0] + ys3[0] = -999.0 + + output_ndarray_float_2(xs) + output_ndarray_float_2(ys1) + output_ndarray_float_2(ys2) + output_ndarray_float_1(ys3) + def test_ndarray_add(): x = np_identity(2) y = x + np_ones([2, 2]) @@ -1653,6 +1685,7 @@ def run() -> int32: test_ndarray_transpose() test_ndarray_reshape() test_ndarray_broadcast_to() + test_ndarray_subscript_assignment() test_ndarray_add() test_ndarray_add_broadcast()